953 lines
36 KiB
Python
953 lines
36 KiB
Python
import os
|
|
import base64
|
|
import uuid
|
|
import glob
|
|
import json
|
|
import shutil
|
|
from datetime import datetime
|
|
from io import BytesIO
|
|
from send2trash import send2trash
|
|
from flask import Flask, render_template, request, jsonify, url_for
|
|
from google import genai
|
|
from google.genai import types
|
|
from PIL import Image, PngImagePlugin
|
|
import threading, time, subprocess, re
|
|
import whisk_client
|
|
|
|
|
|
import logging
|
|
|
|
app = Flask(__name__)
# Serve static files (including freshly generated images) with no browser caching.
app.config.update(SEND_FILE_MAX_AGE_DEFAULT=0)

# Silence werkzeug's per-request log lines; warnings and errors still show.
log = logging.getLogger('werkzeug')
log.setLevel(logging.WARNING)
|
|
|
|
# Configuration Directory Setup
# JSON config/state files live in CONFIG_DIR: the CONFIG_DIR environment
# variable when it is set and non-empty, otherwise this module's directory.
# The path is made absolute so it never collapses to '' (which os.makedirs
# rejects) and is not affected by later working-directory changes.
CONFIG_DIR = os.path.abspath(
    os.environ.get('CONFIG_DIR') or os.path.dirname(os.path.abspath(__file__))
)
os.makedirs(CONFIG_DIR, exist_ok=True)
|
|
|
|
def get_config_path(filename):
    """Return the location of *filename* inside the configuration directory."""
    base_dir = CONFIG_DIR
    return os.path.join(base_dir, filename)
|
|
|
|
def initialize_config_files():
    """Seed CONFIG_DIR with the default JSON files shipped next to this module.

    For each default file missing from CONFIG_DIR, copy the bundled version
    (when one exists beside this module). Existing files are never
    overwritten; copy failures are reported and do not abort startup.
    """
    source_dir = os.path.dirname(os.path.abspath(__file__))

    for filename in ('prompts.json', 'user_prompts.json', 'gallery_favorites.json'):
        dest_path = get_config_path(filename)
        source_path = os.path.join(source_dir, filename)
        if os.path.exists(dest_path) or not os.path.exists(source_path):
            continue

        print(f"Initializing {filename} in {CONFIG_DIR}...", flush=True)
        try:
            # shutil is imported at module level; copy2 also keeps timestamps.
            shutil.copy2(source_path, dest_path)
        except OSError as e:
            print(f"Error initializing {filename}: {e}", flush=True)


# Run initialization on startup
initialize_config_files()
|
|
|
|
# Longest side (in pixels) allowed for template preview images.
PREVIEW_MAX_DIMENSION = 1024
# Quality setting used for lossy (JPEG / WebP) preview encoding.
PREVIEW_JPEG_QUALITY = 85

# Best available resampling filter for this Pillow version: the
# Image.Resampling enum, then the legacy Image.LANCZOS constant, then BICUBIC.
try:
    RESAMPLE_FILTER = Image.Resampling.LANCZOS
except AttributeError:
    RESAMPLE_FILTER = Image.LANCZOS if hasattr(Image, 'LANCZOS') else Image.BICUBIC

# Pillow format name to use when saving a file with a given extension.
FORMAT_BY_EXTENSION = {
    **dict.fromkeys(('.jpg', '.jpeg'), 'JPEG'),
    '.png': 'PNG',
    '.webp': 'WEBP',
}
|
|
|
|
|
|
def _normalize_extension(ext):
    """Return *ext* lower-cased with a leading dot; empty/None becomes '.png'."""
    if not ext:
        return '.png'
    lowered = ext.lower()
    return lowered if lowered.startswith('.') else '.' + lowered
|
|
|
|
|
|
def _format_for_extension(ext):
    """Map an extension such as '.jpg' to its Pillow format name (default 'PNG')."""
    try:
        return FORMAT_BY_EXTENSION[ext]
    except KeyError:
        return 'PNG'
|
|
|
|
|
|
def save_compressed_preview(image, filepath, extension):
    """Save a downscaled, compressed copy of *image* to *filepath*.

    The copy is shrunk (never enlarged) to fit within PREVIEW_MAX_DIMENSION
    on each side and encoded in the format implied by *extension* (unknown
    extensions fall back to PNG). The caller's image object is not modified.

    JPEG cannot store an alpha channel, so images with transparency are
    composited onto a white background first; other non-RGB modes are
    converted to RGB.
    """
    extension = _normalize_extension(extension)
    preview = image.copy()
    preview.thumbnail((PREVIEW_MAX_DIMENSION, PREVIEW_MAX_DIMENSION), RESAMPLE_FILTER)
    image_format = _format_for_extension(extension)
    save_kwargs = {}

    if image_format == 'JPEG':
        has_alpha = preview.mode in ('RGBA', 'LA') or (
            preview.mode == 'P' and 'transparency' in preview.info
        )
        if has_alpha:
            # Previously RGBA slipped through unconverted and the JPEG encoder
            # raised, forcing the caller into its raw-bytes fallback.
            rgba = preview.convert('RGBA')
            background = Image.new('RGB', rgba.size, (255, 255, 255))
            background.paste(rgba, mask=rgba.split()[-1])
            preview = background
        elif preview.mode != 'RGB':
            preview = preview.convert('RGB')
        save_kwargs.update(quality=PREVIEW_JPEG_QUALITY, optimize=True, progressive=True)
    elif image_format == 'WEBP':
        save_kwargs.update(quality=PREVIEW_JPEG_QUALITY, method=6)
    elif image_format == 'PNG':
        save_kwargs.update(optimize=True)

    preview.save(filepath, format=image_format, **save_kwargs)
|
|
|
|
|
|
def save_preview_image(preview_dir, extension='.png', source_bytes=None, source_path=None):
    """Store a template preview image in *preview_dir* under a random name.

    The source is *source_bytes* if given, otherwise the file at
    *source_path*. The image is normally re-encoded as a compressed preview;
    if that fails, the original data is stored unchanged instead.

    Returns the new filename ('template_<uuid><ext>'), or None when no source
    was supplied or nothing could be written.
    """
    extension = _normalize_extension(extension)
    filename = f"template_{uuid.uuid4()}{extension}"
    filepath = os.path.join(preview_dir, filename)

    if source_bytes is None and source_path is None:
        return None

    try:
        source = BytesIO(source_bytes) if source_bytes is not None else source_path
        # The context manager releases the file handle Image.open keeps for paths.
        with Image.open(source) as image:
            save_compressed_preview(image, filepath, extension)
        return filename
    except Exception as exc:
        print(f"Error saving preview image '{filename}': {exc}")

    # Re-encoding failed: keep the original data as-is.
    try:
        if source_bytes is not None:
            with open(filepath, 'wb') as f:
                f.write(source_bytes)
        else:
            shutil.copy2(source_path, filepath)
        return filename
    except Exception as fallback_exc:
        print(f"Fallback saving preview image failed: {fallback_exc}")
        return None
|
|
|
|
# JSON array of favorite template names, stored in the config directory.
FAVORITES_FILE = get_config_path(filename='template_favorites.json')
|
|
|
|
def load_template_favorites():
    """Return the saved template favorites as a list of strings.

    Reads FAVORITES_FILE (a JSON array). A missing, unreadable, or malformed
    file yields an empty list, and non-string entries are dropped, so the
    routes that call this always get a usable list.
    """
    try:
        with open(FAVORITES_FILE, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        return []
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Failed to read template favorites: {e}")
        return []

    if not isinstance(data, list):
        return []
    return [item for item in data if isinstance(item, str)]
|
|
|
|
def save_template_favorites(favorites):
    """Persist *favorites* to FAVORITES_FILE as indented UTF-8 JSON.

    The JSON is written to '<file>.tmp' first and then moved over the target
    with os.replace. An interrupted or failing write (e.g. a non-serialisable
    value) therefore leaves the previous file intact, instead of a truncated
    one that would load as an empty list. Failures are reported, not raised.
    """
    tmp_path = f"{FAVORITES_FILE}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(favorites, f, indent=4, ensure_ascii=False)
        os.replace(tmp_path, FAVORITES_FILE)
    except Exception as e:
        print(f"Failed to persist template favorites: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass
|
|
|
|
# JSON array of favorited gallery image URLs, stored in the config directory.
GALLERY_FAVORITES_FILE = get_config_path(filename='gallery_favorites.json')
|
|
|
|
def load_gallery_favorites():
    """Return the saved gallery favorites (image URLs) as a list of strings.

    Reads GALLERY_FAVORITES_FILE (a JSON array). A missing, unreadable, or
    malformed file yields an empty list, and non-string entries are dropped.
    """
    try:
        with open(GALLERY_FAVORITES_FILE, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        return []
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Failed to read gallery favorites: {e}")
        return []

    if not isinstance(data, list):
        return []
    return [item for item in data if isinstance(item, str)]
|
|
|
|
def save_gallery_favorites(favorites):
    """Persist *favorites* to GALLERY_FAVORITES_FILE as indented UTF-8 JSON.

    Written to '<file>.tmp' and then moved into place with os.replace, so a
    failed write never truncates the existing favorites. Failures are
    reported, not raised.
    """
    tmp_path = f"{GALLERY_FAVORITES_FILE}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(favorites, f, indent=4, ensure_ascii=False)
        os.replace(tmp_path, GALLERY_FAVORITES_FILE)
    except Exception as e:
        print(f"Failed to persist gallery favorites: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass
|
|
|
|
def parse_tags_field(value, max_tags=12):
    """Normalize a tags payload into a list of stripped, non-empty strings.

    *value* may be:
      * a list of tags;
      * a string holding a JSON list / JSON string / JSON object;
      * a plain-text string, which is treated as a single tag.
    Dict tags contribute their 'vi' text, falling back to 'en'. Non-string
    entries are skipped. JSON scalars such as "2024" or "true" keep the
    user's original text instead of being dropped.

    Args:
        value: The raw tags field.
        max_tags: Maximum number of tags to return (default 12).

    Returns:
        list[str]: At most *max_tags* tags, in input order.
    """
    if isinstance(value, list):
        tags = value
    elif isinstance(value, str):
        try:
            parsed = json.loads(value)
        except json.JSONDecodeError:
            tags = [value]
        else:
            if isinstance(parsed, list):
                tags = parsed
            elif isinstance(parsed, (str, dict)):
                tags = [parsed]
            elif parsed is None:
                tags = []
            else:
                # Numbers / booleans: json.loads turned the tag text into a
                # non-string, so fall back to the original text.
                tags = [value]
    else:
        return []

    result = []
    for tag in tags:
        if len(result) >= max_tags:
            break
        if isinstance(tag, dict):
            tag = tag.get('vi') or tag.get('en')
        if not isinstance(tag, str):
            continue
        normalized = tag.strip()
        if normalized:
            result.append(normalized)
    return result
|
|
|
|
# Generated images and reference uploads both live under Flask's static
# folder so they can be served straight from /static/...; create them up front.
GENERATED_DIR = os.path.join(app.static_folder, 'generated')
os.makedirs(GENERATED_DIR, exist_ok=True)

UPLOADS_DIR = os.path.join(app.static_folder, 'uploads')
os.makedirs(UPLOADS_DIR, exist_ok=True)

# File extensions (lower-case, with dot) that count as gallery images.
ALLOWED_GALLERY_EXTS = tuple(['.png', '.jpg', '.jpeg', '.webp'])
|
|
|
|
|
|
|
|
|
|
def normalize_gallery_path(path):
    """Return a clean path relative to /static without traversal.

    Accepts forms such as '/static/generated/x.png?v=1', 'generated\\x.png',
    or 'uploads/y.jpg'. The query string, leading slashes, and a leading
    'static/' are removed, and the path is normalized. The result always
    uses '/' separators.

    Returns '' for empty input, for anything that climbs out of the static
    folder ('..' as the first component), for absolute or drive-qualified
    paths, and for inputs that collapse to the folder itself ('.'). Callers
    must therefore never receive a directory.
    """
    if not path:
        return ''

    cleaned = path.replace('\\', '/').split('?', 1)[0].lstrip('/')
    if cleaned.startswith('static/'):
        cleaned = cleaned[len('static/'):].lstrip('/')
    if not cleaned:
        return ''

    normalized = os.path.normpath(cleaned)
    if os.path.isabs(normalized) or os.path.splitdrive(normalized)[0]:
        return ''

    # normpath uses backslashes on Windows; callers split on '/'.
    normalized = normalized.replace(os.sep, '/')
    if normalized == '.' or normalized.split('/', 1)[0] == '..':
        return ''
    return normalized
|
|
|
|
|
|
def resolve_gallery_target(source, filename=None, relative_path=None):
    """Resolve the gallery source (generated/uploads) and absolute filepath.

    *relative_path* (e.g. '/static/uploads/x.png') takes precedence over
    *filename*. Its first segment, when it is 'generated' or 'uploads',
    overrides *source*. Any other source defaults to 'generated'.

    Only the final path component is kept, so the result always lies directly
    inside GENERATED_DIR or UPLOADS_DIR. Names that would point at a directory
    ('', '.', '..', or e.g. 'x/') are rejected.

    Returns:
        (source, filepath, storage_key) where storage_key is
        '<source>/<name>', or (None, None, None) when nothing usable was given.
    """
    cleaned_path = normalize_gallery_path(relative_path) if isinstance(relative_path, str) else ''
    fallback_name = filename if isinstance(filename, str) else ''
    candidate_name = cleaned_path or fallback_name
    if not candidate_name:
        return None, None, None

    normalized_name = os.path.basename(candidate_name)
    if normalized_name in ('', '.', '..'):
        # basename('x/') == '' and '.'/'..' would resolve to a gallery folder
        # (or its parent); delete_image must never receive a directory.
        return None, None, None

    inferred_source = source.lower() if isinstance(source, str) else ''
    if cleaned_path:
        first_segment = cleaned_path.replace('\\', '/').split('/', 1)[0]
        if first_segment in ('generated', 'uploads'):
            inferred_source = first_segment

    if inferred_source not in ('generated', 'uploads'):
        inferred_source = 'generated'

    base_dir = UPLOADS_DIR if inferred_source == 'uploads' else GENERATED_DIR
    filepath = os.path.join(base_dir, normalized_name)
    storage_key = f"{inferred_source}/{normalized_name}"
    return inferred_source, filepath, storage_key
|
|
|
|
def process_prompt_with_placeholders(prompt, note):
    """
    Process prompt with {text} or [text] placeholders.

    Logic:
    1. If prompt has no placeholders: return "{prompt}. {note}" (or just the
       prompt when note is empty).
    2. If note is empty (or contains only '|' separators and whitespace):
       - The first placeholder containing pipes (e.g. {cat|dog} or [cat|dog])
         produces one prompt per option. Every occurrence of that placeholder
         is replaced; other placeholders are left as written.
       - With no piped placeholder, the prompt is returned unchanged.
    3. If note has pipes (|): one prompt per non-empty segment, with every
       placeholder replaced by that segment (queue).
    4. If note has newlines: the non-empty lines fill placeholders left to
       right. Extra placeholders use their own text, or their first
       pipe-option (stripped) when they contain pipes.
    5. Otherwise every placeholder is replaced by the note.

    Replacement text is inserted literally, in one pass over the original
    prompt. Backslashes in the note are not treated as regex escapes, and
    inserted text is never re-scanned for placeholders.

    Returns:
        list: List of processed prompts (always at least one).
    """
    placeholder_re = re.compile(r'\{([^{}]+)\}|\[([^\[\]]+)\]')

    def content_of(match):
        # Exactly one of the two groups matched: {content} or [content].
        return match.group(1) or match.group(2)

    if note is not None and not isinstance(note, str):
        note = str(note)

    placeholders = [content_of(m) for m in placeholder_re.finditer(prompt)]

    if not placeholders:
        return [f"{prompt}. {note}" if note else prompt]

    if note and '|' in note:
        segments = [s.strip() for s in note.split('|') if s.strip()]
        if segments:
            return [placeholder_re.sub(lambda _m, seg=seg: seg, prompt) for seg in segments]
        # Only separators/whitespace: behave as if no note was given rather
        # than returning an empty list (callers index element 0).
        note = ''

    if not note:
        target = next((p for p in placeholders if '|' in p), None)
        if target is None:
            # No pipes in placeholders: return prompt as is (placeholders remain).
            return [prompt]

        def expand(option):
            value = option.strip()
            return placeholder_re.sub(
                lambda m: value if content_of(m) == target else m.group(0), prompt
            )

        return [expand(option) for option in target.split('|')]

    if '\n' in note:
        remaining_lines = iter([ln.strip() for ln in note.split('\n') if ln.strip()])

        def fill(match):
            line = next(remaining_lines, None)
            if line is not None:
                return line
            # More placeholders than lines: fall back to the placeholder's own text.
            content = content_of(match)
            return content.split('|', 1)[0].strip() if '|' in content else content

        return [placeholder_re.sub(fill, prompt)]

    # Single note: replace every placeholder with it.
    return [placeholder_re.sub(lambda _m: note, prompt)]
|
|
|
|
@app.route('/')
def index():
    """Render the main page template, index.html."""
    page = 'index.html'
    return render_template(page)
|
|
|
|
def _resolve_static_reference(path):
    """Map a reference URL/path to an absolute file path inside the static folder.

    Accepts values like '/static/generated/x.png' or
    'http://host/static/uploads/y.jpg?v=1'. Returns None when the value does
    not contain '/static/' or would resolve outside app.static_folder
    (e.g. via '..' or an absolute path), so client-supplied references cannot
    read arbitrary server files. Existence is not checked.
    """
    from urllib.parse import unquote

    if not isinstance(path, str) or '/static/' not in path:
        return None
    rel_path = unquote(path.split('/static/', 1)[1].split('?', 1)[0])
    static_root = os.path.realpath(app.static_folder)
    candidate = os.path.realpath(os.path.join(static_root, rel_path))
    try:
        if os.path.commonpath([static_root, candidate]) != static_root:
            return None
    except ValueError:
        # e.g. different drives on Windows: certainly outside the static folder.
        return None
    return candidate


def _parse_reference_paths(raw):
    """Return the frontend's reference list; accepts a JSON string or an actual list."""
    if isinstance(raw, list):
        return list(raw)
    if isinstance(raw, str) and raw:
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            return []
        return parsed if isinstance(parsed, list) else []
    return []


def _store_uploaded_reference(upload):
    """Validate an uploaded reference image, save it to UPLOADS_DIR, return (image, url)."""
    ext = os.path.splitext(upload.filename or '')[1].lower()
    if ext not in ALLOWED_GALLERY_EXTS:
        # Never keep client-chosen extensions such as .html/.svg: the file is
        # served back from /static and the browser would render it.
        ext = '.png'

    upload.stream.seek(0)
    file_bytes = upload.read()
    image = Image.open(BytesIO(file_bytes))
    image.load()  # fail here for non-image payloads, before anything is written

    filename = f"{uuid.uuid4()}{ext}"
    with open(os.path.join(UPLOADS_DIR, filename), 'wb') as f:
        f.write(file_bytes)
    return image, url_for('static', filename=f'uploads/{filename}')


def _collect_reference_images(frontend_paths, reference_files):
    """Load reference images and build the list of their URLs.

    String entries of *frontend_paths* that start with '/' or 'http' refer to
    images already on the server. Any other entry (e.g. null) consumes the
    next upload from *reference_files*. When *frontend_paths* is empty, every
    upload is used.

    Returns (images, reference_urls). Existing paths are recorded in
    reference_urls even when the file itself could not be loaded.
    """
    if not frontend_paths and reference_files:
        frontend_paths = [None] * len(reference_files)

    images = []
    reference_urls = []
    file_index = 0

    for path in frontend_paths:
        if isinstance(path, str) and path.startswith(('/', 'http')):
            reference_urls.append(path)
            abs_path = _resolve_static_reference(path)
            if abs_path is None:
                print(f"Warning: Could not resolve local path for {path}")
            elif not os.path.exists(abs_path):
                print(f"Warning: Reference file not found at {abs_path}")
            else:
                try:
                    img = Image.open(abs_path)
                    img.load()  # read pixels now so Pillow can release the file
                    images.append(img)
                except Exception as e:
                    print(f"Error loading reference from path {path}: {e}")
        elif file_index < len(reference_files):
            upload = reference_files[file_index]
            file_index += 1
            try:
                image, url = _store_uploaded_reference(upload)
            except Exception as e:
                print(f"Error processing uploaded file: {e}")
                continue
            images.append(image)
            reference_urls.append(url)

    return images, reference_urls


def _next_file_id(directory, prefix, part_index):
    """Return 1 + the largest integer at `part_index` of '_'-split stems of '<prefix>_*.png'."""
    pattern = os.path.join(glob.escape(directory), f"{glob.escape(prefix)}_*.png")
    max_id = 0
    for path in glob.glob(pattern):
        parts = os.path.splitext(os.path.basename(path))[0].split('_')
        try:
            max_id = max(max_id, int(parts[part_index]))
        except (ValueError, IndexError):
            continue
    return max_id + 1


def _png_with_metadata(image_bytes, metadata):
    """Re-encode *image_bytes* as PNG with *metadata* JSON in the 'sdvn_meta' text chunk."""
    png_info = PngImagePlugin.PngInfo()
    png_info.add_text('sdvn_meta', json.dumps(metadata))
    buffer = BytesIO()
    Image.open(BytesIO(image_bytes)).save(buffer, format='PNG', pnginfo=png_info)
    return buffer.getvalue()


def _whisk_result_to_bytes(whisk_result):
    """Normalize whisk_client output (list of bytes, bytes, or dict) into a list of bytes."""
    if isinstance(whisk_result, list):
        return whisk_result
    if isinstance(whisk_result, bytes):
        return [whisk_result]
    if isinstance(whisk_result, dict):
        if 'image_data' in whisk_result:
            return [whisk_result['image_data']]
        if 'image_url' in whisk_result:
            import requests
            img_resp = requests.get(whisk_result['image_url'], timeout=60)
            img_resp.raise_for_status()
            return [img_resp.content]
    return []


def _generate_with_whisk(payload, model_name, api_prompt, aspect_ratio, resolution, base_metadata):
    """Run a Whisk/ImageFX generation and save every returned image.

    Returns a Flask response:
      * 400 when no cookies are available;
      * 500 with a "Whisk Generation Error" message on any failure;
      * otherwise the saved URLs and base64 data, with the first entry
        duplicated into the legacy 'image'/'image_data' keys.
    """
    print(f"Detected Whisk/ImageFX model request: {model_name}", flush=True)

    # Priority: body 'cookies' > 'x-whisk-cookies' header > WHISK_COOKIES env var.
    cookie_str = (payload.get('cookies')
                  or request.headers.get('x-whisk-cookies')
                  or os.environ.get('WHISK_COOKIES'))
    if not cookie_str:
        return jsonify({'error': 'Whisk cookies are required. Please provide them in the "cookies" form field or configuration.'}), 400

    print("Sending request to Whisk...", flush=True)
    try:
        reference_urls = base_metadata['reference_images']
        reference_image_path = None
        if reference_urls:
            # The client accepts one reference; only files inside /static are allowed.
            candidate = _resolve_static_reference(reference_urls[0])
            if candidate is not None and os.path.exists(candidate):
                reference_image_path = candidate
                print(f"Whisk: Using reference image at {reference_image_path}", flush=True)

        image_count = int(payload.get('image_count', 4))
        whisk_result = whisk_client.generate_image_whisk(
            prompt=api_prompt,
            cookie_str=cookie_str,
            image_count=image_count,
            aspect_ratio=aspect_ratio,
            resolution=resolution,
            reference_image_path=reference_image_path,
        )

        image_bytes_list = _whisk_result_to_bytes(whisk_result)
        if not image_bytes_list:
            raise ValueError("No image data returned from Whisk.")

        date_str = datetime.now().strftime("%Y%m%d")
        # Files are named whisk_<date>_<batch>_<index>.png; part 2 is the batch id.
        batch_id = _next_file_id(GENERATED_DIR, f"whisk_{date_str}", 2)

        saved_urls = []
        saved_b64s = []
        metadata = {}
        for idx, img_bytes in enumerate(image_bytes_list):
            filename = f"whisk_{date_str}_{batch_id}_{idx}.png"
            metadata = {**base_metadata, 'model': 'whisk', 'batch_id': batch_id, 'batch_index': idx}
            final_bytes = _png_with_metadata(img_bytes, metadata)
            with open(os.path.join(GENERATED_DIR, filename), 'wb') as f:
                f.write(final_bytes)
            saved_urls.append(url_for('static', filename=f'generated/{filename}'))
            saved_b64s.append(base64.b64encode(final_bytes).decode('utf-8'))

        return jsonify({
            'image': saved_urls[0],  # Legacy support
            'images': saved_urls,  # New support
            'image_data': saved_b64s[0],  # Legacy
            'image_datas': saved_b64s,  # New
            'metadata': metadata,
        })
    except Exception as e:
        print(f"Whisk error: {e}")
        return jsonify({'error': f"Whisk Generation Error: {str(e)}"}), 500


def _generate_with_gemini(api_key, model_name, api_prompt, reference_images,
                          aspect_ratio, resolution, metadata):
    """Call the Gemini image model and save the first image part it returns.

    Returns the saved image's URL, base64 PNG data and *metadata*. Returns a
    500 'No image generated' response when the reply contains no image part,
    including when it has no parts at all.
    """
    image_config_args = {}
    # Resolution (image_size) is only sent for models other than gemini-2.5-flash-image.
    if model_name != 'gemini-2.5-flash-image':
        image_config_args["image_size"] = resolution
    if aspect_ratio and aspect_ratio != 'Auto':
        image_config_args["aspect_ratio"] = aspect_ratio

    client = genai.Client(api_key=api_key)
    print(f"Đang tạo với model {model_name}...", flush=True)
    response = client.models.generate_content(
        model=model_name,
        contents=[api_prompt, *reference_images],
        config=types.GenerateContentConfig(
            response_modalities=['IMAGE'],
            image_config=types.ImageConfig(**image_config_args),
        ),
    )
    print("Hoàn tất!", flush=True)

    # response.parts is None when the response carries no candidate content.
    for part in response.parts or []:
        if not part.inline_data:
            continue

        date_str = datetime.now().strftime("%Y%m%d")
        # The model name is client-supplied: keep only filename-safe characters
        # so it cannot add path separators or glob wildcards to the filename.
        safe_model = re.sub(r'[^A-Za-z0-9._-]+', '_', model_name)
        next_id = _next_file_id(GENERATED_DIR, f"{safe_model}_{date_str}", -1)
        filename = f"{safe_model}_{date_str}_{next_id}.png"

        final_bytes = _png_with_metadata(part.inline_data.data, metadata)
        with open(os.path.join(GENERATED_DIR, filename), 'wb') as f:
            f.write(final_bytes)

        return jsonify({
            'image': url_for('static', filename=f'generated/{filename}'),
            'image_data': base64.b64encode(final_bytes).decode('utf-8'),
            'metadata': metadata,
        })

    return jsonify({'error': 'No image generated'}), 500


@app.route('/generate', methods=['POST'])
def generate_image():
    """Generate an image from a prompt with Gemini or Whisk/ImageFX.

    Accepts multipart form data (with optional 'reference_images' uploads) or
    a JSON body. Fields:
      * prompt (required), note;
      * aspect_ratio, resolution (default '2K');
      * model (default 'gemini-3-pro-image-preview');
      * api_key (falls back to the GOOGLE_API_KEY env var);
      * reference_image_paths (a list or a JSON-encoded list);
      * 'cookies' / 'image_count' for Whisk.

    When placeholder expansion yields several prompts, nothing is generated.
    The prompts are returned with ``queue: True`` so the frontend can submit
    them one by one.

    Errors: 400 missing prompt, 401 missing Gemini API key, 500 on failure.
    """
    multipart = bool(request.content_type) and 'multipart/form-data' in request.content_type
    if multipart:
        payload = request.form
        reference_files = request.files.getlist('reference_images')
    else:
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict):
            payload = {}
        reference_files = []

    prompt = payload.get('prompt')
    note = payload.get('note', '')
    aspect_ratio = payload.get('aspect_ratio')
    resolution = payload.get('resolution', '2K')
    model_name = str(payload.get('model') or 'gemini-3-pro-image-preview')
    api_key = payload.get('api_key') or os.environ.get('GOOGLE_API_KEY')

    if not prompt:
        return jsonify({'error': 'Prompt is required'}), 400

    # Determine if this is a Whisk request
    is_whisk = 'whisk' in model_name.lower() or 'imagefx' in model_name.lower()
    if not is_whisk and not api_key:
        return jsonify({'error': 'API Key is required for Gemini models.'}), 401

    try:
        print("Đang gửi lệnh...", flush=True)

        processed_prompts = process_prompt_with_placeholders(prompt, note)
        if len(processed_prompts) > 1:
            # Queue scenario: the frontend submits each prompt separately.
            return jsonify({
                'queue': True,
                'prompts': processed_prompts,
                'metadata': {
                    'original_prompt': prompt,
                    'original_note': note,
                    'aspect_ratio': aspect_ratio or 'Auto',
                    'resolution': resolution,
                    'model': model_name,
                },
            })
        api_prompt = processed_prompts[0] if processed_prompts else prompt

        frontend_paths = _parse_reference_paths(payload.get('reference_image_paths'))
        reference_images, final_reference_paths = _collect_reference_images(
            frontend_paths, reference_files
        )

        base_metadata = {
            # Keep the exact user input before placeholder expansion
            'prompt': prompt,
            'note': note,
            # Also store the expanded prompt for reference
            'processed_prompt': api_prompt,
            'aspect_ratio': aspect_ratio or 'Auto',
            'resolution': resolution,
            'reference_images': final_reference_paths,
        }

        if is_whisk:
            return _generate_with_whisk(
                payload, model_name, api_prompt, aspect_ratio, resolution, base_metadata
            )
        return _generate_with_gemini(
            api_key, model_name, api_prompt, reference_images,
            aspect_ratio, resolution, base_metadata,
        )
    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
|
|
|
@app.route('/delete_image', methods=['POST'])
def delete_image():
    """Move a gallery image to the OS trash (via send2trash).

    JSON body: 'filename' and/or 'path' (alias 'relative_path'), plus an
    optional 'source' ('generated' or 'uploads').

    Responds:
      * 400 when no usable name is given;
      * 404 when the target is not an existing regular file, so directories
        are never trashed;
      * 500 if trashing fails.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    filename = data.get('filename')
    source = data.get('source')
    rel_path = data.get('path') or data.get('relative_path')

    resolved_source, filepath, _storage_key = resolve_gallery_target(source, filename, rel_path)
    if not filepath:
        return jsonify({'error': 'Filename is required'}), 400

    if not os.path.isfile(filepath):
        return jsonify({'error': 'File not found'}), 404

    try:
        send2trash(filepath)
        return jsonify({'success': True, 'source': resolved_source})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
|
|
|
@app.route('/gallery')
def get_gallery():
    """List gallery image URLs, newest first.

    Query param 'source' selects 'uploads'; any other value (or none) lists
    'generated'. Only files whose extension is in ALLOWED_GALLERY_EXTS are
    included. The response disables caching.
    """
    source_param = (request.args.get('source') or 'generated').lower()
    resolved_source = 'uploads' if source_param == 'uploads' else 'generated'
    base_dir = UPLOADS_DIR if resolved_source == 'uploads' else GENERATED_DIR

    # Escape the directory so characters like '[' in the install path are not
    # interpreted as glob patterns.
    timestamped = []
    for path in glob.glob(os.path.join(glob.escape(base_dir), '*')):
        if os.path.splitext(path)[1].lower() not in ALLOWED_GALLERY_EXTS:
            continue
        try:
            mtime = os.path.getmtime(path)
        except OSError:
            # Removed (e.g. by /delete_image) between listing and stat.
            continue
        timestamped.append((mtime, os.path.basename(path)))

    timestamped.sort(key=lambda item: item[0], reverse=True)

    image_urls = [url_for('static', filename=f'{resolved_source}/{name}') for _, name in timestamped]
    response = jsonify({'images': image_urls, 'source': resolved_source})
    response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
    return response
|
|
|
|
def _read_prompt_list(path):
    """Return the JSON list stored at *path*, or [] if it is missing, unreadable, or not a list."""
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        return []
    except (OSError, ValueError) as e:
        print(f"Error reading prompts: {e}")
        return []
    return data if isinstance(data, list) else []


@app.route('/prompts')
def get_prompts():
    """Return built-in prompts followed by user-saved prompts.

    Optional query param 'category' keeps only entries whose 'category'
    equals it. Each file is read independently, so one broken file, or an
    entry that is not an object, no longer blanks the whole list.
    """
    category = request.args.get('category')

    all_prompts = (
        _read_prompt_list(get_config_path('prompts.json'))
        + _read_prompt_list(get_config_path('user_prompts.json'))
    )

    if category:
        return jsonify([
            p for p in all_prompts
            if isinstance(p, dict) and p.get('category') == category
        ])
    return jsonify(all_prompts)
|
|
|
|
@app.route('/save_prompt', methods=['POST'])
def save_prompt():
    """Append a prompt to user_prompts.json under the 'User Saved' category.

    Expects a JSON object with 'act', 'prompt', and optional 'desc'.

    Responds:
      * 400 when the body is not a JSON object;
      * 500 when the existing file cannot be read, is not a JSON list, or
        cannot be written.

    The file is rewritten through '<file>.tmp' plus os.replace, so a failed
    write never truncates previously saved prompts.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Invalid request body'}), 400

    new_prompt = {
        'act': data.get('act'),
        'prompt': data.get('prompt'),
        'category': 'User Saved',
        'desc': data.get('desc', ''),
    }

    user_prompts_path = get_config_path('user_prompts.json')
    tmp_path = f"{user_prompts_path}.tmp"
    try:
        existing_prompts = []
        if os.path.exists(user_prompts_path):
            with open(user_prompts_path, 'r', encoding='utf-8') as f:
                existing_prompts = json.load(f)
        if not isinstance(existing_prompts, list):
            raise ValueError(f"{user_prompts_path} does not contain a JSON list")

        existing_prompts.append(new_prompt)

        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(existing_prompts, f, ensure_ascii=False, indent=4)
        os.replace(tmp_path, user_prompts_path)

        return jsonify({'success': True})
    except Exception as e:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return jsonify({'error': str(e)}), 500
|
|
|
|
@app.route('/save_template_favorite', methods=['POST'])
def save_template_fav():
    """Add a template name to the top of the favorites list.

    Expects a JSON object {"template": "<name>"}. A name that is already a
    favorite keeps its position. Responds 400 when the body is not a JSON
    object or the name is missing or not a non-empty string.
    """
    data = request.get_json(silent=True)
    template_name = data.get('template') if isinstance(data, dict) else None
    if not isinstance(template_name, str) or not template_name:
        return jsonify({'error': 'Template name required'}), 400

    favorites = load_template_favorites()
    if template_name not in favorites:
        favorites.insert(0, template_name)
        save_template_favorites(favorites)

    return jsonify({'success': True, 'favorites': favorites})
|
|
|
|
@app.route('/remove_template_favorite', methods=['POST'])
def remove_template_fav():
    """Remove a template name from the favorites list.

    Expects a JSON object {"template": "<name>"}. Removing a name that is not
    a favorite is a no-op. Responds 400 when the body is not a JSON object or
    the name is missing or not a non-empty string.
    """
    data = request.get_json(silent=True)
    template_name = data.get('template') if isinstance(data, dict) else None
    if not isinstance(template_name, str) or not template_name:
        return jsonify({'error': 'Template name required'}), 400

    favorites = load_template_favorites()
    if template_name in favorites:
        favorites.remove(template_name)
        save_template_favorites(favorites)

    return jsonify({'success': True, 'favorites': favorites})
|
|
|
|
@app.route('/get_template_favorites')
def get_template_favs():
    """Return the saved template favorites as a JSON array."""
    favorites = load_template_favorites()
    return jsonify(favorites)
|
|
|
|
@app.route('/save_gallery_favorite', methods=['POST'])
def save_gallery_fav():
    """Add an image URL to the top of the gallery favorites.

    Expects a JSON object {"url": "<image url>"}. A URL that is already a
    favorite keeps its position. Responds 400 when the body is not a JSON
    object or the URL is missing or not a non-empty string.
    """
    data = request.get_json(silent=True)
    image_url = data.get('url') if isinstance(data, dict) else None
    if not isinstance(image_url, str) or not image_url:
        return jsonify({'error': 'URL required'}), 400

    favorites = load_gallery_favorites()
    if image_url not in favorites:
        favorites.insert(0, image_url)
        save_gallery_favorites(favorites)

    return jsonify({'success': True, 'favorites': favorites})
|
|
|
|
@app.route('/remove_gallery_favorite', methods=['POST'])
def remove_gallery_fav():
    """Remove an image URL from the gallery favorites.

    Expects a JSON object {"url": "<image url>"}. Removing a URL that is not a
    favorite is a no-op. Responds 400 when the body is not a JSON object or
    the URL is missing or not a non-empty string.
    """
    data = request.get_json(silent=True)
    image_url = data.get('url') if isinstance(data, dict) else None
    if not isinstance(image_url, str) or not image_url:
        return jsonify({'error': 'URL required'}), 400

    favorites = load_gallery_favorites()
    if image_url in favorites:
        favorites.remove(image_url)
        save_gallery_favorites(favorites)

    return jsonify({'success': True, 'favorites': favorites})
|
|
|
|
@app.route('/get_gallery_favorites')
def get_gallery_favs():
    """Return the favorited gallery image URLs as a JSON array."""
    favorites = load_gallery_favorites()
    return jsonify(favorites)
|
|
|
|
def open_browser(url):
    """Open *url* in the user's default browser after a 1.5 second delay.

    Meant to run on a background thread while the server starts (presumably
    the delay gives it time to begin listening). Uses the cross-platform
    webbrowser module instead of the macOS-only `open` command, which failed
    silently elsewhere. Failures are reported, never raised.
    """
    import webbrowser

    time.sleep(1.5)
    print(f"Opening browser at {url}")
    try:
        opened = webbrowser.open(url)
    except webbrowser.Error as e:
        print(f"Could not open browser: {e}")
        return
    if not opened:
        print(f"No browser available; open {url} manually.")
|
|
|
|
if __name__ == '__main__':
    port = int(os.environ.get('PORT', '8888'))

    # The Werkzeug interactive debugger can execute arbitrary code. Since the
    # server listens on every interface, enable it only when explicitly
    # requested (FLASK_DEBUG=1).
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')

    # Optional: APIX_OPEN_BROWSER=1 opens the UI. Skip the reloader's child
    # process so the browser opens only once.
    wants_browser = os.environ.get('APIX_OPEN_BROWSER', '').strip().lower() in ('1', 'true', 'yes', 'on')
    if wants_browser and os.environ.get('WERKZEUG_RUN_MAIN') != 'true':
        threading.Thread(
            target=open_browser, args=(f"http://127.0.0.1:{port}",), daemon=True
        ).start()

    print("----------------------------------------------------------------")
    print(" aPix v2.1 - STARTED")
    print("----------------------------------------------------------------")

    # Listen on all interfaces
    app.run(host='0.0.0.0', port=port, debug=debug)
|