apix/backend/routers/generate.py
Khoa.vo 0ef7e5475c
Some checks are pending
CI / build (18.x) (push) Waiting to run
CI / build (20.x) (push) Waiting to run
v3.0.0: Add FastAPI backend
- Add Python FastAPI backend with Pydantic validation
- Port WhiskClient and MetaAIClient to Python
- Create API routers for all endpoints
- Add Swagger/ReDoc documentation at /docs
- Update Dockerfile for multi-service container
- Add lib/api.ts frontend client
- Update README for V3
2026-01-13 07:46:32 +07:00

92 lines
3.3 KiB
Python

"""
Generate Router - Whisk image generation
"""
from fastapi import APIRouter, HTTPException
from models.requests import GenerateRequest
from models.responses import GenerateResponse, GeneratedImage, ErrorResponse
from services.whisk_client import WhiskClient
import asyncio
# Router for the image-generation endpoints defined in this module; every
# route registered on it is tagged "Generate".
router = APIRouter(tags=["Generate"])
@router.post(
    "/generate",
    response_model=GenerateResponse,
    responses={
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse}
    }
)
async def generate_images(request: GenerateRequest):
    """
    Generate images using Whisk API.

    - **prompt**: Text description of the image to generate
    - **aspectRatio**: Output aspect ratio (1:1, 9:16, 16:9, etc.)
    - **refs**: Optional reference images {subject, scene, style}
    - **imageCount**: Number of parallel generation requests (1-4)
    - **cookies**: Whisk authentication cookies
    """
    # NOTE: the docstring above is kept verbatim because FastAPI publishes it
    # as the endpoint description; maintainer notes live in comments instead.
    #
    # Status codes produced by this handler:
    #   401 - no cookies supplied, or an unexpected error whose message
    #         contains one of the auth keywords checked in the last handler
    #   400 - a ValueError raised anywhere inside the try block
    #   500 - every generation request failed, or any other unexpected error
    if not request.cookies:
        raise HTTPException(status_code=401, detail="Whisk cookies not found. Please configure settings.")

    try:
        # Normalize cookies if JSON format: a JSON list of {name, value}
        # objects is flattened into a "name=value; name=value" string.
        cookie_string = request.cookies.strip()
        if cookie_string.startswith(('[', '{')):
            import json
            try:
                cookie_array = json.loads(cookie_string)
                if isinstance(cookie_array, list):
                    cookie_string = "; ".join(
                        f"{c['name']}={c['value']}" for c in cookie_array
                    )
                    print(f"[Generate] Parsed {len(cookie_array)} cookies from JSON.")
            except Exception as e:
                # Deliberate best-effort: on any parse/shape problem the raw
                # string is passed to the client unchanged.
                print(f"[Generate] Failed to parse cookie JSON: {e}")

        client = WhiskClient(cookie_string)

        # Generate images in parallel if imageCount > 1 (clamped to 1..4)
        parallel_count = min(max(1, request.imageCount), 4)
        print(f"Starting {parallel_count} parallel generation requests for prompt: \"{request.prompt[:20]}...\"")

        async def single_generate():
            # One generation attempt. Failures are logged and reported as an
            # empty result so a single bad request does not cancel the rest.
            # NOTE(review): this also hides auth errors raised by
            # client.generate from the 401 heuristic below - confirm intended.
            try:
                return await client.generate(
                    request.prompt,
                    request.aspectRatio,
                    request.refs,
                    request.preciseMode
                )
            except Exception as e:
                print(f"Single generation request failed: {e}")
                return []

        results = await asyncio.gather(*[single_generate() for _ in range(parallel_count)])
        # Flatten the per-request image lists into one list.
        all_images = [img for result in results for img in result]

        if not all_images:
            raise HTTPException(status_code=500, detail="All generation requests failed. Check logs or try again.")

        return GenerateResponse(
            images=[
                GeneratedImage(
                    data=img.data,
                    index=img.index,
                    prompt=img.prompt,
                    aspectRatio=img.aspect_ratio
                )
                for img in all_images
            ]
        )
    except HTTPException:
        # HTTP errors raised deliberately above already carry the intended
        # status and detail. Without this pass-through the generic handler
        # below would catch them and replace the detail with str(exc).
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        msg = str(e)
        # Heuristic: treat messages that mention auth-related keywords as 401.
        is_auth_error = any(x in msg.lower() for x in ["401", "403", "auth", "cookies", "expired"])
        status_code = 401 if is_auth_error else 500
        raise HTTPException(status_code=status_code, detail=msg) from e