""" Generate Router - Whisk image generation """ from fastapi import APIRouter, HTTPException from models.requests import GenerateRequest from models.responses import GenerateResponse, GeneratedImage, ErrorResponse from services.whisk_client import WhiskClient import asyncio router = APIRouter(tags=["Generate"]) @router.post( "/generate", response_model=GenerateResponse, responses={ 400: {"model": ErrorResponse}, 401: {"model": ErrorResponse}, 500: {"model": ErrorResponse} } ) async def generate_images(request: GenerateRequest): """ Generate images using Whisk API. - **prompt**: Text description of the image to generate - **aspectRatio**: Output aspect ratio (1:1, 9:16, 16:9, etc.) - **refs**: Optional reference images {subject, scene, style} - **imageCount**: Number of parallel generation requests (1-4) - **cookies**: Whisk authentication cookies """ if not request.cookies: raise HTTPException(status_code=401, detail="Whisk cookies not found. Please configure settings.") try: # Normalize cookies if JSON format cookie_string = request.cookies.strip() print(f"[Generate] Raw cookie input length: {len(cookie_string)} chars") if cookie_string.startswith('[') or cookie_string.startswith('{'): import json try: cookie_array = json.loads(cookie_string) if isinstance(cookie_array, list): # Log all cookie names for debugging cookie_names = [c.get('name', '?') for c in cookie_array if isinstance(c, dict)] print(f"[Generate] Parsed {len(cookie_array)} cookies: {cookie_names}") # Check for required Google cookies required = ['SID', 'HSID', 'SSID', 'APISID', 'SAPISID'] missing = [r for r in required if r not in cookie_names] if missing: print(f"[Generate] WARNING: Missing required cookies: {missing}") cookie_string = "; ".join( f"{c['name']}={c['value']}" for c in cookie_array if isinstance(c, dict) and 'name' in c and 'value' in c ) except Exception as e: print(f"[Generate] Failed to parse cookie JSON: {e}") client = WhiskClient(cookie_string) # Generate images in parallel if imageCount > 1 parallel_count = min(max(1, request.imageCount), 4) print(f"Starting {parallel_count} parallel generation requests for prompt: \"{request.prompt[:20]}...\"") async def single_generate(): try: return await client.generate( request.prompt, request.aspectRatio, request.refs, request.preciseMode ) except Exception as e: print(f"Single generation request failed: {e}") return [] results = await asyncio.gather(*[single_generate() for _ in range(parallel_count)]) all_images = [img for result in results for img in result] if not all_images: raise HTTPException(status_code=500, detail="All generation requests failed. Check logs or try again.") return GenerateResponse( images=[ GeneratedImage( data=img.data, index=img.index, prompt=img.prompt, aspectRatio=img.aspect_ratio ) for img in all_images ] ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except Exception as e: msg = str(e) is_auth_error = any(x in msg.lower() for x in ["401", "403", "auth", "cookies", "expired"]) status_code = 401 if is_auth_error else 500 raise HTTPException(status_code=status_code, detail=msg)