apix/backend/services/whisk_client.py
Khoa.vo 0ef7e5475c
Some checks are pending
CI / build (18.x) (push) Waiting to run
CI / build (20.x) (push) Waiting to run
v3.0.0: Add FastAPI backend
- Add Python FastAPI backend with Pydantic validation
- Port WhiskClient and MetaAIClient to Python
- Create API routers for all endpoints
- Add Swagger/ReDoc documentation at /docs
- Update Dockerfile for multi-service container
- Add lib/api.ts frontend client
- Update README for V3
2026-01-13 07:46:32 +07:00

410 lines
14 KiB
Python

"""
Whisk Client for Python/FastAPI
Port of lib/whisk-client.ts
Handles:
- Cookie parsing (JSON array or string format)
- Access token retrieval from Whisk API
- Image generation with aspect ratio support
- Reference image upload
- Video generation with polling
"""
import httpx
import json
import uuid
import base64
import asyncio
from typing import Optional, Dict, List, Any
# Whisk API endpoints. Every call is an RPC-style method on the same versioned
# API root, so the URLs are derived from one base to keep them in sync.
_API_BASE = "https://aisandbox-pa.googleapis.com/v1"
AUTH_URL = f"{_API_BASE}:signInWithIdp"
GENERATE_URL = f"{_API_BASE}:runImagine"
RECIPE_URL = f"{_API_BASE}:runRecipe"
UPLOAD_URL = f"{_API_BASE}:uploadMedia"
VIDEO_URL = f"{_API_BASE}:runVideoFxSingleClips"
VIDEO_STATUS_URL = f"{_API_BASE}:runVideoFxSingleClipsStatusCheck"

# Aspect-ratio labels -> the service's aspect-ratio enum names.
# "9:16" and "3:4" both map to the portrait enum, and "Auto" maps to square.
# NOTE(review): confirm whether the service offers a dedicated 3:4 value.
ASPECT_RATIOS = {
    "1:1": "IMAGE_ASPECT_RATIO_SQUARE",
    "9:16": "IMAGE_ASPECT_RATIO_PORTRAIT",
    "16:9": "IMAGE_ASPECT_RATIO_LANDSCAPE",
    "4:3": "IMAGE_ASPECT_RATIO_LANDSCAPE_FOUR_THREE",
    "3:4": "IMAGE_ASPECT_RATIO_PORTRAIT",
    "Auto": "IMAGE_ASPECT_RATIO_SQUARE",
}

# Reference-image categories (lower-case keys) -> media-category enum names.
MEDIA_CATEGORIES = {
    "subject": "MEDIA_CATEGORY_SUBJECT",
    "scene": "MEDIA_CATEGORY_SCENE",
    "style": "MEDIA_CATEGORY_STYLE",
}
class GeneratedImage:
    """A single image produced by a Whisk generation call.

    Holds the encoded image payload taken from the service response, the
    image's position in that response, and the prompt / aspect-ratio label
    that were requested.
    """

    def __init__(self, data: str, index: int, prompt: str, aspect_ratio: str):
        self.data, self.index = data, index
        self.prompt, self.aspect_ratio = prompt, aspect_ratio

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-ready dict; the aspect ratio uses the camelCase key."""
        return dict(
            data=self.data,
            index=self.index,
            prompt=self.prompt,
            aspectRatio=self.aspect_ratio,
        )
class WhiskVideoResult:
    """Outcome of a Whisk video generation job.

    ``id`` is the video generation ID, ``url`` the video URI (may be ``None``),
    and ``status`` a status label (the client in this module sets ``"complete"``).
    """

    def __init__(self, id: str, url: Optional[str], status: str):
        # ``id`` shadows the builtin, but it is part of the public keyword API.
        self.id, self.url, self.status = id, url, status

    def to_dict(self) -> Dict[str, Any]:
        """Return the three fields as a plain dict in ``id``/``url``/``status`` order."""
        return {field: getattr(self, field) for field in ("id", "url", "status")}
class WhiskClient:
    """Async client for the Whisk image and video generation endpoints.

    Authenticates with browser cookies, exchanges them for a bearer token
    (cached until shortly before it expires) and exposes reference-image
    upload, image generation and image-to-video generation with polling.
    """

    # HTTP statuses meaning the cookies/token were rejected. Retrying with the
    # same credentials will not help, so the cached token is dropped instead.
    _AUTH_ERROR_STATUSES = (401, 403)

    def __init__(self, cookie_input: str):
        """Parse the user's cookies and pre-build the ``Cookie`` header.

        Args:
            cookie_input: Cookies as a JSON array of ``{"name", "value"}``
                objects (e.g. a Cookie-Editor export), a single such JSON
                object, or a ``key=value; key2=value2`` string.

        Raises:
            ValueError: If no cookie could be parsed from *cookie_input*.
        """
        self.cookies = self._parse_cookies(cookie_input)
        self.access_token: Optional[str] = None
        # Expiry of ``access_token`` as a Unix timestamp in milliseconds.
        self.token_expires: int = 0
        self.cookie_string = ""

        if not self.cookies:
            raise ValueError("No valid cookies provided")

        self.cookie_string = "; ".join(
            f"{name}={value}" for name, value in self.cookies.items()
        )

    def _parse_cookies(self, input_str: str) -> Dict[str, str]:
        """Parse cookies from a JSON export or a ``key=value; ...`` string.

        JSON is tried first when the text starts with ``[`` or ``{``. If it is
        not valid JSON, or is a JSON object without ``name``/``value``, the
        text is parsed as a ``;``-separated header string instead. Entries with
        an empty cookie name are skipped because they cannot be sent.

        Returns:
            Mapping of cookie name to value; empty if nothing was parsed.
        """
        if not input_str or not input_str.strip():
            return {}

        trimmed = input_str.strip()
        cookies: Dict[str, str] = {}

        if trimmed.startswith(("[", "{")):
            try:
                parsed = json.loads(trimmed)
            except json.JSONDecodeError:
                parsed = None  # not JSON after all; fall back to string parsing

            if isinstance(parsed, list):
                for entry in parsed:
                    if isinstance(entry, dict) and entry.get("name") and "value" in entry:
                        cookies[entry["name"]] = entry["value"]
                return cookies
            if isinstance(parsed, dict) and parsed.get("name") and "value" in parsed:
                return {parsed["name"]: parsed["value"]}

        for pair in trimmed.split(";"):
            key, sep, value = pair.strip().partition("=")
            key = key.strip()
            if sep and key:
                cookies[key] = value.strip()
        return cookies

    @staticmethod
    def _parse_expires_in(raw: Any, default: int = 3600) -> int:
        """Convert the auth response's ``expiresIn`` value into whole seconds.

        Accepts numbers, numeric strings and protobuf-JSON duration strings
        such as ``"3600s"``. Falls back to *default* when the value is missing
        or cannot be parsed, so a valid token is never discarded over it.
        """
        if raw is None:
            return default
        text = str(raw).strip()
        if text.endswith("s"):
            text = text[:-1]
        try:
            return int(float(text))
        except (ValueError, OverflowError):
            return default

    def _headers(self, token: Optional[str] = None) -> Dict[str, str]:
        """Build JSON request headers with cookies and, if given, a bearer token."""
        headers = {"Content-Type": "application/json"}
        if token:
            headers["Authorization"] = f"Bearer {token}"
        headers["Cookie"] = self.cookie_string
        return headers

    def _invalidate_token(self) -> None:
        """Forget the cached bearer token so the next call re-authenticates."""
        self.access_token = None
        self.token_expires = 0

    async def get_access_token(self) -> str:
        """Return a bearer token, fetching a new one if the cache is stale.

        A fetched token is reused until one minute before its reported expiry.

        Raises:
            Exception: If the auth request is not HTTP 200 or the response
                carries no ``authToken``.
        """
        import time

        if self.access_token and self.token_expires > int(time.time() * 1000):
            return self.access_token

        async with httpx.AsyncClient() as client:
            response = await client.post(AUTH_URL, headers=self._headers(), json={})

        if response.status_code != 200:
            raise Exception(f"Auth failed: {response.status_code} - {response.text[:200]}")

        data = response.json()
        token = data.get("authToken")
        if not token:
            self._invalidate_token()
            raise Exception("No auth token in response")

        expires_in = self._parse_expires_in(data.get("expiresIn"))
        self.access_token = token
        # Refresh one minute early so the token does not expire mid-request.
        self.token_expires = int(time.time() * 1000) + (expires_in * 1000) - 60000
        return token

    async def upload_reference_image(
        self,
        file_base64: str,
        mime_type: str,
        category: str
    ) -> Optional[str]:
        """Upload a reference image and return its media ID.

        Args:
            file_base64: Base64-encoded image bytes (no ``data:`` prefix).
            mime_type: MIME type used to build the data URI, e.g. ``image/png``.
            category: ``"subject"``, ``"scene"`` or ``"style"`` (case-insensitive);
                anything else is uploaded as a subject.

        Returns:
            The ``generationId`` (or ``imageMediaId``) from the response, or
            ``None`` if the response contains neither.

        Raises:
            Exception: If the upload request is not HTTP 200.
        """
        token = await self.get_access_token()

        data_uri = f"data:{mime_type};base64,{file_base64}"
        media_category = MEDIA_CATEGORIES.get(
            (category or "subject").lower(), MEDIA_CATEGORIES["subject"]
        )
        payload = {
            "mediaData": data_uri,
            "imageOptions": {"imageCategory": media_category},
        }

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(UPLOAD_URL, headers=self._headers(token), json=payload)

        if response.status_code != 200:
            print(f"[WhiskClient] Upload failed: {response.status_code}")
            if response.status_code in self._AUTH_ERROR_STATUSES:
                self._invalidate_token()
            raise Exception(f"Upload failed: {response.text[:200]}")

        data = response.json()
        media_id = data.get("generationId") or data.get("imageMediaId")
        if not media_id:
            print(f"[WhiskClient] No media ID in response: {data}")
            return None

        print(f"[WhiskClient] Upload successful, mediaId: {media_id}")
        return media_id

    @staticmethod
    def _build_recipe_inputs(refs: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Flatten subject/scene/style reference IDs into ``recipeInputs`` entries.

        Each category value may be a single ID string or an iterable of IDs
        (consumed exactly once); falsy IDs are skipped. Output order is
        subject, then scene, then style.
        """
        recipe_inputs: List[Dict[str, Any]] = []
        for category in ("subject", "scene", "style"):
            ids = refs.get(category)
            if not ids:
                continue
            id_list = [ids] if isinstance(ids, str) else ids
            input_type = MEDIA_CATEGORIES[category]
            recipe_inputs.extend(
                {"inputType": input_type, "mediaId": ref_id}
                for ref_id in id_list
                if ref_id
            )
        return recipe_inputs

    async def generate(
        self,
        prompt: str,
        aspect_ratio: str = "1:1",
        refs: Optional[Dict[str, Any]] = None,
        precise_mode: bool = False
    ) -> List[GeneratedImage]:
        """Generate four images for *prompt*, optionally guided by references.

        Without reference IDs the request goes to ``GENERATE_URL``; with any
        subject/scene/style reference it goes to ``RECIPE_URL``.

        Args:
            prompt: Text prompt.
            aspect_ratio: Key of ``ASPECT_RATIOS``; unknown values fall back
                to the ``"1:1"`` enum.
            refs: Optional mapping with ``"subject"``, ``"scene"`` and/or
                ``"style"`` keys, each a media ID string or an iterable of IDs.
            precise_mode: Accepted for interface compatibility; currently unused.

        Returns:
            One ``GeneratedImage`` per response entry with non-empty
            ``encodedImage``; ``index`` is the entry's position in the response.

        Raises:
            Exception: On auth rejection or any other non-200 response.
        """
        token = await self.get_access_token()

        recipe_inputs = self._build_recipe_inputs(refs or {})
        aspect_enum = ASPECT_RATIOS.get(aspect_ratio, ASPECT_RATIOS["1:1"])

        if recipe_inputs:
            # Recipe format (with refs)
            endpoint = RECIPE_URL
            payload: Dict[str, Any] = {
                "recipeInputs": recipe_inputs,
                "generationConfig": {
                    "aspectRatio": aspect_enum,
                    "numberOfImages": 4,
                    "personalizationConfig": {},
                },
                "textPromptInput": {"text": prompt},
            }
        else:
            # Direct imagine format (no refs)
            endpoint = GENERATE_URL
            payload = {
                "imagineConfig": {
                    "aspectRatio": aspect_enum,
                    "imaginePrompt": prompt,
                    "numberOfImages": 4,
                    "imageSafetyMode": "BLOCK_SOME",
                },
            }

        print(f"[WhiskClient] Generating with prompt: \"{prompt[:50]}...\"")

        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(endpoint, headers=self._headers(token), json=payload)

        if response.status_code != 200:
            error_text = response.text[:500]
            # The status code is authoritative. The body substring check is
            # kept from the original implementation for compatibility; note it
            # can also match unrelated numbers in the body.
            if (
                response.status_code in self._AUTH_ERROR_STATUSES
                or "401" in error_text
                or "403" in error_text
            ):
                self._invalidate_token()
                raise Exception("Whisk auth failed - cookies may be expired")
            raise Exception(f"Generation failed: {response.status_code} - {error_text}")

        data = response.json()

        images: List[GeneratedImage] = []
        # ``or []`` also covers an explicit ``null`` in the response.
        for index, img in enumerate(data.get("generatedImages") or []):
            image_data = img.get("encodedImage", "")
            if image_data:
                images.append(
                    GeneratedImage(
                        data=image_data,
                        index=index,
                        prompt=prompt,
                        aspect_ratio=aspect_ratio,
                    )
                )

        print(f"[WhiskClient] Generated {len(images)} images")
        return images

    async def generate_video(
        self,
        image_generation_id: str,
        prompt: str,
        image_base64: Optional[str] = None,
        aspect_ratio: str = "16:9"
    ) -> WhiskVideoResult:
        """Generate a video from an image using Whisk Animate (Veo).

        Args:
            image_generation_id: ID of an existing generated/uploaded image.
                May be empty when *image_base64* is supplied.
            prompt: Text prompt sent with the video request.
            image_base64: Base64 image uploaded as a PNG subject reference,
                used only when *image_generation_id* is empty.
            aspect_ratio: Ratio such as ``"16:9"``; sent with ``:`` replaced by ``_``.

        Returns:
            The completed result from ``poll_video_status``.

        Raises:
            Exception: If no image ID is available, the start request fails,
                or polling fails or times out.
        """
        token = await self.get_access_token()

        # If we have base64 but no generation ID, upload first.
        actual_gen_id = image_generation_id
        if not actual_gen_id and image_base64:
            actual_gen_id = await self.upload_reference_image(
                image_base64, "image/png", "subject"
            )
        if not actual_gen_id:
            raise Exception("No image generation ID available for video")

        payload = {
            "generationId": actual_gen_id,
            "videoFxConfig": {
                "aspectRatio": aspect_ratio.replace(":", "_"),
                "prompt": prompt,
                "duration": "5s",
            },
        }

        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(VIDEO_URL, headers=self._headers(token), json=payload)

        if response.status_code != 200:
            if response.status_code in self._AUTH_ERROR_STATUSES:
                self._invalidate_token()
            raise Exception(f"Video init failed: {response.text[:200]}")

        data = response.json()
        video_gen_id = data.get("videoGenId")
        if not video_gen_id:
            raise Exception("No video generation ID in response")

        print(f"[WhiskClient] Video generation started: {video_gen_id}")
        return await self.poll_video_status(video_gen_id, token)

    async def poll_video_status(
        self,
        video_gen_id: str,
        token: str,
        *,
        max_attempts: int = 60,
        poll_interval: float = 3.0,
    ) -> WhiskVideoResult:
        """Poll until the video completes, fails, or attempts run out.

        Non-200 (non-auth) replies and transient network errors count as
        "not ready yet". With the defaults this waits roughly 60 x 3 seconds.

        Args:
            video_gen_id: ID returned when the video job was started.
            token: Bearer token to send with each status request.
            max_attempts: Maximum number of status checks.
            poll_interval: Seconds to sleep between checks.

        Raises:
            Exception: On auth rejection, a ``FAILED``/``ERROR`` status, or
                when all attempts are used up ("Video generation timed out").
        """
        async with httpx.AsyncClient(timeout=30.0) as client:
            for attempt in range(1, max_attempts + 1):
                print(f"[WhiskClient] Polling video status {attempt}/{max_attempts}...")
                result = await self._check_video_status(client, video_gen_id, token)
                if result is not None:
                    return result
                # No point sleeping after the final attempt.
                if attempt < max_attempts:
                    await asyncio.sleep(poll_interval)

        raise Exception("Video generation timed out")

    async def _check_video_status(
        self,
        client: httpx.AsyncClient,
        video_gen_id: str,
        token: str,
    ) -> Optional[WhiskVideoResult]:
        """Run a single status check for *video_gen_id*.

        Returns:
            A ``WhiskVideoResult`` once the status is ``COMPLETE`` with a URI,
            otherwise ``None`` (still pending or a transient error).

        Raises:
            Exception: If the credentials are rejected or the service reports
                ``FAILED``/``ERROR``.
        """
        try:
            response = await client.post(
                VIDEO_STATUS_URL,
                headers=self._headers(token),
                json={"videoGenId": video_gen_id},
            )
        except httpx.RequestError as exc:
            # Treat network hiccups like a non-200 reply: try again next round.
            print(f"[WhiskClient] Video status request error: {exc}")
            return None

        if response.status_code in self._AUTH_ERROR_STATUSES:
            self._invalidate_token()
            raise Exception("Whisk auth failed - cookies may be expired")
        if response.status_code != 200:
            return None

        data = response.json()
        status = data.get("status", "")
        video_url = data.get("videoUri")

        if status == "COMPLETE" and video_url:
            print(f"[WhiskClient] Video complete: {video_url[:50]}...")
            return WhiskVideoResult(id=video_gen_id, url=video_url, status="complete")
        if status in ("FAILED", "ERROR"):
            raise Exception(f"Video generation failed: {status}")
        return None