104 lines
4 KiB
Python
104 lines
4 KiB
Python
"""
|
|
Generate Router - Whisk image generation
|
|
"""
|
|
from fastapi import APIRouter, HTTPException
|
|
from models.requests import GenerateRequest
|
|
from models.responses import GenerateResponse, GeneratedImage, ErrorResponse
|
|
from services.whisk_client import WhiskClient
|
|
import asyncio
|
|
|
|
router = APIRouter(tags=["Generate"])

# Google session cookies that the cookie export is expected to contain.
# Their absence is only logged as a warning; the request is not rejected.
_REQUIRED_GOOGLE_COOKIES = ('SID', 'HSID', 'SSID', 'APISID', 'SAPISID')

# Lower-cased substrings that mark an error message as an authentication
# problem, mapped to HTTP 401 instead of 500.
_AUTH_ERROR_MARKERS = ("401", "403", "auth", "cookies", "expired")

_COOKIES_MISSING_DETAIL = "Whisk cookies not found. Please configure settings."


def _normalize_cookies(raw: str) -> str:
    """Turn user-supplied cookies into a ``name=value; name2=value2`` string.

    ``raw`` is either an already-formatted cookie string (returned stripped)
    or a JSON export: a list of ``{"name": ..., "value": ...}`` objects, or a
    single such object. List entries that are not objects, or that lack
    ``name``/``value``, are skipped. Input that looks like JSON but fails to
    parse, or parses to some other shape, is returned stripped and otherwise
    unchanged (best-effort).

    May return an empty string, e.g. for whitespace-only input or a JSON list
    with no usable entries.
    """
    import json

    cookie_string = raw.strip()
    print(f"[Generate] Raw cookie input length: {len(cookie_string)} chars")

    if not cookie_string.startswith(('[', '{')):
        return cookie_string

    try:
        parsed = json.loads(cookie_string)
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        print(f"[Generate] Failed to parse cookie JSON: {e}")
        return cookie_string

    if isinstance(parsed, dict) and 'name' in parsed and 'value' in parsed:
        # Single-cookie export: treat it as a one-element list instead of
        # forwarding raw JSON text as a cookie header.
        parsed = [parsed]
    if not isinstance(parsed, list):
        return cookie_string

    # Log all cookie names (never values) for debugging.
    cookie_names = [c.get('name', '?') for c in parsed if isinstance(c, dict)]
    print(f"[Generate] Parsed {len(parsed)} cookies: {cookie_names}")

    missing = [name for name in _REQUIRED_GOOGLE_COOKIES if name not in cookie_names]
    if missing:
        print(f"[Generate] WARNING: Missing required cookies: {missing}")

    return "; ".join(
        f"{c['name']}={c['value']}"
        for c in parsed
        if isinstance(c, dict) and 'name' in c and 'value' in c
    )


@router.post(
    "/generate",
    response_model=GenerateResponse,
    responses={
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def generate_images(request: GenerateRequest):
    """
    Generate images using Whisk API.

    - **prompt**: Text description of the image to generate
    - **aspectRatio**: Output aspect ratio (1:1, 9:16, 16:9, etc.)
    - **refs**: Optional reference images {subject, scene, style}
    - **imageCount**: Number of parallel generation requests (1-4)
    - **preciseMode**: Passed through to the Whisk client's generate call
    - **cookies**: Whisk authentication cookies
    """
    if not request.cookies:
        raise HTTPException(status_code=401, detail=_COOKIES_MISSING_DETAIL)

    try:
        cookie_string = _normalize_cookies(request.cookies)
        if not cookie_string:
            # Whitespace-only input or a JSON export with no usable entries.
            raise HTTPException(status_code=401, detail=_COOKIES_MISSING_DETAIL)

        client = WhiskClient(cookie_string)

        # Fan out identical requests; imageCount is clamped to 1..4.
        parallel_count = min(max(1, request.imageCount), 4)
        print(f"Starting {parallel_count} parallel generation requests for prompt: \"{request.prompt[:20]}...\"")

        # Exceptions from failed requests, in the order they were raised.
        errors = []

        async def single_generate():
            try:
                return await client.generate(
                    request.prompt,
                    request.aspectRatio,
                    request.refs,
                    request.preciseMode,
                )
            except Exception as e:
                # Tolerate individual failures so partial results still get
                # returned, but remember the cause for the all-failed case.
                print(f"Single generation request failed: {e}")
                errors.append(e)
                return []

        results = await asyncio.gather(*(single_generate() for _ in range(parallel_count)))
        all_images = [img for result in results for img in result]

        if not all_images:
            if errors:
                # Surface the real cause (e.g. expired cookies) so the
                # handlers below can map it to 400/401 instead of a blanket 500.
                raise errors[0]
            raise HTTPException(status_code=500, detail="All generation requests failed. Check logs or try again.")

        return GenerateResponse(
            images=[
                GeneratedImage(
                    data=img.data,
                    index=img.index,
                    prompt=img.prompt,
                    aspectRatio=img.aspect_ratio,
                )
                for img in all_images
            ]
        )

    except HTTPException:
        # Already carries the intended status and detail. Without this clause
        # the generic handler below would re-wrap it (mangling the detail).
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        msg = str(e) or type(e).__name__
        is_auth_error = any(marker in msg.lower() for marker in _AUTH_ERROR_MARKERS)
        raise HTTPException(status_code=401 if is_auth_error else 500, detail=msg) from e
|