""" Generate Router - Whisk image generation """ from fastapi import APIRouter, HTTPException from models.requests import GenerateRequest from models.responses import GenerateResponse, GeneratedImage, ErrorResponse from services.whisk_client import WhiskClient import asyncio router = APIRouter(tags=["Generate"]) @router.post( "/generate", response_model=GenerateResponse, responses={ 400: {"model": ErrorResponse}, 401: {"model": ErrorResponse}, 500: {"model": ErrorResponse} } ) async def generate_images(request: GenerateRequest): """ Generate images using Whisk API. - **prompt**: Text description of the image to generate - **aspectRatio**: Output aspect ratio (1:1, 9:16, 16:9, etc.) - **refs**: Optional reference images {subject, scene, style} - **imageCount**: Number of parallel generation requests (1-4) - **cookies**: Whisk authentication cookies """ if not request.cookies: raise HTTPException(status_code=401, detail="Whisk cookies not found. Please configure settings.") try: # Normalize cookies if JSON format cookie_string = request.cookies.strip() if cookie_string.startswith('[') or cookie_string.startswith('{'): import json try: cookie_array = json.loads(cookie_string) if isinstance(cookie_array, list): cookie_string = "; ".join( f"{c['name']}={c['value']}" for c in cookie_array ) print(f"[Generate] Parsed {len(cookie_array)} cookies from JSON.") except Exception as e: print(f"[Generate] Failed to parse cookie JSON: {e}") client = WhiskClient(cookie_string) # Generate images in parallel if imageCount > 1 parallel_count = min(max(1, request.imageCount), 4) print(f"Starting {parallel_count} parallel generation requests for prompt: \"{request.prompt[:20]}...\"") async def single_generate(): try: return await client.generate( request.prompt, request.aspectRatio, request.refs, request.preciseMode ) except Exception as e: print(f"Single generation request failed: {e}") return [] results = await asyncio.gather(*[single_generate() for _ in range(parallel_count)]) all_images = [img for result in results for img in result] if not all_images: raise HTTPException(status_code=500, detail="All generation requests failed. Check logs or try again.") return GenerateResponse( images=[ GeneratedImage( data=img.data, index=img.index, prompt=img.prompt, aspectRatio=img.aspect_ratio ) for img in all_images ] ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except Exception as e: msg = str(e) is_auth_error = any(x in msg.lower() for x in ["401", "403", "auth", "cookies", "expired"]) status_code = 401 if is_auth_error else 500 raise HTTPException(status_code=status_code, detail=msg)