- Add Python FastAPI backend with Pydantic validation - Port WhiskClient and MetaAIClient to Python - Create API routers for all endpoints - Add Swagger/ReDoc documentation at /docs - Update Dockerfile for multi-service container - Add lib/api.ts frontend client - Update README for V3
410 lines
14 KiB
Python
410 lines
14 KiB
Python
"""
|
|
Whisk Client for Python/FastAPI
|
|
Port of lib/whisk-client.ts
|
|
|
|
Handles:
|
|
- Cookie parsing (JSON array or string format)
|
|
- Access token retrieval from Whisk API
|
|
- Image generation with aspect ratio support
|
|
- Reference image upload
|
|
- Video generation with polling
|
|
"""
|
|
import httpx
|
|
import json
|
|
import uuid
|
|
import base64
|
|
import asyncio
|
|
from typing import Optional, Dict, List, Any
|
|
|
|
# Whisk API endpoints. Every call is an RPC-style ``:method`` suffix on the
# same aisandbox v1 host, so the URLs are derived from one shared base.
_WHISK_API_BASE = "https://aisandbox-pa.googleapis.com/v1"

AUTH_URL = f"{_WHISK_API_BASE}:signInWithIdp"
GENERATE_URL = f"{_WHISK_API_BASE}:runImagine"
RECIPE_URL = f"{_WHISK_API_BASE}:runRecipe"
UPLOAD_URL = f"{_WHISK_API_BASE}:uploadMedia"
VIDEO_URL = f"{_WHISK_API_BASE}:runVideoFxSingleClips"
VIDEO_STATUS_URL = f"{_WHISK_API_BASE}:runVideoFxSingleClipsStatusCheck"

# User-facing aspect-ratio label -> Whisk image aspect-ratio enum.
# NOTE(review): "3:4" reuses the 9:16 portrait enum and "Auto" falls back to
# square; confirm the API has no dedicated values for these.
ASPECT_RATIOS = {
    "1:1": "IMAGE_ASPECT_RATIO_SQUARE",
    "9:16": "IMAGE_ASPECT_RATIO_PORTRAIT",
    "16:9": "IMAGE_ASPECT_RATIO_LANDSCAPE",
    "4:3": "IMAGE_ASPECT_RATIO_LANDSCAPE_FOUR_THREE",
    "3:4": "IMAGE_ASPECT_RATIO_PORTRAIT",
    "Auto": "IMAGE_ASPECT_RATIO_SQUARE",
}

# Lower-case reference-image category -> Whisk media category enum.
MEDIA_CATEGORIES = {
    "subject": "MEDIA_CATEGORY_SUBJECT",
    "scene": "MEDIA_CATEGORY_SCENE",
    "style": "MEDIA_CATEGORY_STYLE",
}
|
|
|
|
|
|
class GeneratedImage:
    """A single image produced by a Whisk generation request.

    Args:
        data: Encoded image payload as returned by the API.
        index: Position of the image within the generation response.
        prompt: Text prompt the image was generated from.
        aspect_ratio: Aspect-ratio label requested by the caller, e.g. ``"16:9"``.
    """

    def __init__(self, data: str, index: int, prompt: str, aspect_ratio: str):
        self.data, self.index = data, index
        self.prompt, self.aspect_ratio = prompt, aspect_ratio

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, exposing the ratio under ``aspectRatio``."""
        return dict(
            data=self.data,
            index=self.index,
            prompt=self.prompt,
            aspectRatio=self.aspect_ratio,
        )
|
|
|
|
|
|
class WhiskVideoResult:
    """Outcome of a Whisk video generation job.

    Args:
        id: Video generation ID assigned by the API.
        url: URI of the finished video, or ``None`` when not available.
        status: Job status label (``WhiskClient.poll_video_status`` uses
            ``"complete"``).
    """

    def __init__(self, id: str, url: Optional[str], status: str):
        self.id, self.url, self.status = id, url, status

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result fields to a plain dict keyed by attribute name."""
        return {field: getattr(self, field) for field in ("id", "url", "status")}
|
|
|
|
|
|
class WhiskClient:
    """Async client for the Whisk (aisandbox) image and video API.

    Authentication is cookie based: the cookies given to the constructor are
    exchanged for a bearer token, which is cached until one minute before it
    expires. Both the cookies and the token are sent with every API call.
    A 401/403 from any endpoint drops the cached token so that the next call
    signs in again.
    """

    def __init__(self, cookie_input: str):
        """Parse ``cookie_input`` and pre-render the ``Cookie`` header.

        Args:
            cookie_input: A ``key=value; key2=value2`` cookie string, or a JSON
                export (an array of ``{"name": ..., "value": ...}`` objects or a
                single such object).

        Raises:
            ValueError: If no cookie could be parsed from the input.
        """
        self.cookies = self._parse_cookies(cookie_input)
        self.access_token: Optional[str] = None
        # Token expiry as a Unix timestamp in milliseconds (0 = no token yet).
        self.token_expires: int = 0
        self.cookie_string = ""

        if not self.cookies:
            raise ValueError("No valid cookies provided")

        # Build the Cookie header once; every request reuses it.
        self.cookie_string = "; ".join(
            f"{name}={value}" for name, value in self.cookies.items()
        )

    def _parse_cookies(self, input_str: str) -> Dict[str, str]:
        """Parse cookies from a JSON export or a ``key=value; ...`` string.

        Two JSON shapes are accepted:
        - an array of objects with ``name``/``value`` keys, e.g. a Cookie-Editor
          export; entries missing either key are skipped;
        - a single such object.

        Input that is not valid JSON, or is JSON of any other shape, is parsed
        as a semicolon-separated cookie string. Segments without ``=`` are
        ignored there.

        Returns:
            Mapping of cookie name to value; empty if nothing could be parsed.
        """
        if not input_str or not input_str.strip():
            return {}

        trimmed = input_str.strip()
        cookies: Dict[str, str] = {}

        # JSON array/object format (e.g., from Cookie-Editor)
        if trimmed.startswith(("[", "{")):
            try:
                parsed = json.loads(trimmed)
            except json.JSONDecodeError:
                parsed = None
            if isinstance(parsed, list):
                for c in parsed:
                    if isinstance(c, dict) and "name" in c and "value" in c:
                        cookies[c["name"]] = c["value"]
                return cookies
            if isinstance(parsed, dict) and "name" in parsed and "value" in parsed:
                return {parsed["name"]: parsed["value"]}

        # String format (key=value; key2=value2); split on the first '=' only
        # so values may themselves contain '='.
        for pair in trimmed.split(";"):
            key, sep, value = pair.strip().partition("=")
            if sep:
                cookies[key.strip()] = value.strip()

        return cookies

    def _auth_headers(self, token: str) -> Dict[str, str]:
        """Build the JSON, bearer-token and cookie headers for authenticated calls."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
            "Cookie": self.cookie_string,
        }

    def _reset_token_on_auth_error(self, status_code: int) -> bool:
        """Drop the cached token if ``status_code`` is 401/403.

        Returns:
            True if the token was dropped.
        """
        if status_code in (401, 403):
            self.access_token = None
            self.token_expires = 0
            return True
        return False

    async def get_access_token(self) -> str:
        """Return a bearer token, signing in again when the cached one is stale.

        Raises:
            Exception: If the auth request is not HTTP 200 or the response
                carries no ``authToken``.
        """
        import time

        # Return cached token if still valid
        if self.access_token and self.token_expires > int(time.time() * 1000):
            return self.access_token

        async with httpx.AsyncClient() as client:
            response = await client.post(
                AUTH_URL,
                headers={
                    "Content-Type": "application/json",
                    "Cookie": self.cookie_string,
                },
                json={},
            )

            if response.status_code != 200:
                raise Exception(f"Auth failed: {response.status_code} - {response.text[:200]}")

            data = response.json()
            token = data.get("authToken")
            if not token:
                # Do not leave a stale token behind after a failed refresh.
                self.access_token = None
                self.token_expires = 0
                raise Exception("No auth token in response")

            expires_in = int(data.get("expiresIn", 3600))
            self.access_token = token
            # Expire the cache 60 s early so a token never lapses mid-request.
            self.token_expires = int(time.time() * 1000) + (expires_in * 1000) - 60000

        return token

    async def upload_reference_image(
        self,
        file_base64: str,
        mime_type: str,
        category: str
    ) -> Optional[str]:
        """Upload a reference image and return its media ID.

        Args:
            file_base64: Base64-encoded image bytes (without a data-URI prefix).
            mime_type: MIME type used to build the data URI, e.g. ``"image/png"``.
            category: ``"subject"``, ``"scene"`` or ``"style"``
                (case-insensitive). Unknown values fall back to subject.

        Returns:
            The ``generationId`` (or ``imageMediaId``) from the response, or
            ``None`` if the response contained neither.

        Raises:
            Exception: If the upload request is not HTTP 200.
        """
        token = await self.get_access_token()

        data_uri = f"data:{mime_type};base64,{file_base64}"
        media_category = MEDIA_CATEGORIES.get(category.lower(), MEDIA_CATEGORIES["subject"])

        payload = {
            "mediaData": data_uri,
            "imageOptions": {
                "imageCategory": media_category
            }
        }

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                UPLOAD_URL,
                headers=self._auth_headers(token),
                json=payload,
            )

            if response.status_code != 200:
                self._reset_token_on_auth_error(response.status_code)
                print(f"[WhiskClient] Upload failed: {response.status_code}")
                raise Exception(f"Upload failed: {response.text[:200]}")

            data = response.json()

        media_id = data.get("generationId") or data.get("imageMediaId")

        if not media_id:
            print(f"[WhiskClient] No media ID in response: {data}")
            return None

        print(f"[WhiskClient] Upload successful, mediaId: {media_id}")
        return media_id

    @staticmethod
    def _collect_ref_inputs(refs: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Flatten subject/scene/style refs into Whisk ``recipeInputs`` entries.

        Each category value may be a single media ID string or a list of IDs.
        Empty values and falsy IDs are skipped. The output order is subject,
        then scene, then style.
        """
        recipe_inputs: List[Dict[str, Any]] = []
        for category in ("subject", "scene", "style"):
            ids = refs.get(category)
            if not ids:
                continue
            id_list = [ids] if isinstance(ids, str) else ids
            cat_enum = MEDIA_CATEGORIES.get(category)
            recipe_inputs.extend(
                {"inputType": cat_enum, "mediaId": ref_id}
                for ref_id in id_list
                if ref_id
            )
        return recipe_inputs

    async def generate(
        self,
        prompt: str,
        aspect_ratio: str = "1:1",
        refs: Optional[Dict[str, Any]] = None,
        precise_mode: bool = False
    ) -> List["GeneratedImage"]:
        """Generate images from a prompt, optionally guided by reference images.

        With no refs, the request goes to the direct imagine endpoint. With any
        refs, it goes to the recipe endpoint. Both ask for 4 images.

        Args:
            prompt: Text prompt.
            aspect_ratio: Key of ``ASPECT_RATIOS``; unknown values fall back to "1:1".
            refs: Optional ``{"subject"|"scene"|"style": id or [ids]}`` mapping of
                previously uploaded media IDs.
            precise_mode: Currently not read by this method.

        Returns:
            The images that came back with non-empty ``encodedImage`` data.

        Raises:
            Exception: On HTTP 401/403 (the cached token is dropped) or any
                other non-200 response.
        """
        token = await self.get_access_token()
        refs = refs or {}

        aspect_enum = ASPECT_RATIOS.get(aspect_ratio, ASPECT_RATIOS["1:1"])
        recipe_inputs = self._collect_ref_inputs(refs)

        if recipe_inputs:
            # Recipe format (with refs)
            endpoint = RECIPE_URL
            payload = {
                "recipeInputs": recipe_inputs,
                "generationConfig": {
                    "aspectRatio": aspect_enum,
                    "numberOfImages": 4,
                    "personalizationConfig": {}
                },
                "textPromptInput": {
                    "text": prompt
                }
            }
        else:
            # Direct imagine format (no refs)
            endpoint = GENERATE_URL
            payload = {
                "imagineConfig": {
                    "aspectRatio": aspect_enum,
                    "imaginePrompt": prompt,
                    "numberOfImages": 4,
                    "imageSafetyMode": "BLOCK_SOME"
                }
            }

        print(f"[WhiskClient] Generating with prompt: \"{prompt[:50]}...\"")

        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(
                endpoint,
                headers=self._auth_headers(token),
                json=payload,
            )

            if response.status_code != 200:
                error_text = response.text[:500]
                # Judge auth failures by the HTTP status, not by searching the
                # body for "401"/"403" (which can match unrelated text).
                if self._reset_token_on_auth_error(response.status_code):
                    raise Exception("Whisk auth failed - cookies may be expired")
                raise Exception(f"Generation failed: {response.status_code} - {error_text}")

            data = response.json()

        # Extract images; entries without image data are skipped but keep
        # their original position as the index.
        images: List[GeneratedImage] = []
        for i, img in enumerate(data.get("generatedImages", [])):
            image_data = img.get("encodedImage", "")
            if image_data:
                images.append(GeneratedImage(
                    data=image_data,
                    index=i,
                    prompt=prompt,
                    aspect_ratio=aspect_ratio
                ))

        print(f"[WhiskClient] Generated {len(images)} images")
        return images

    async def generate_video(
        self,
        image_generation_id: str,
        prompt: str,
        image_base64: Optional[str] = None,
        aspect_ratio: str = "16:9"
    ) -> "WhiskVideoResult":
        """Generate a video from an image using Whisk Animate (Veo).

        Args:
            image_generation_id: ID of a previously generated/uploaded image.
                May be empty if ``image_base64`` is supplied instead.
            prompt: Motion/text prompt for the clip.
            image_base64: Fallback image. It is uploaded as a PNG subject
                reference when no generation ID is given.
            aspect_ratio: Ratio label such as ``"16:9"``; ``:`` is sent as ``_``.

        Returns:
            The completed video result from ``poll_video_status``.

        Raises:
            Exception: If no image ID is available, the start request is not
                HTTP 200 (401/403 also drops the cached token), the response
                lacks ``videoGenId``, or polling fails.
        """
        token = await self.get_access_token()

        # If we have base64 but no generation ID, upload first
        actual_gen_id = image_generation_id
        if not actual_gen_id and image_base64:
            actual_gen_id = await self.upload_reference_image(
                image_base64, "image/png", "subject"
            )

        if not actual_gen_id:
            raise Exception("No image generation ID available for video")

        payload = {
            "generationId": actual_gen_id,
            "videoFxConfig": {
                "aspectRatio": aspect_ratio.replace(":", "_"),
                "prompt": prompt,
                "duration": "5s"
            }
        }

        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                VIDEO_URL,
                headers=self._auth_headers(token),
                json=payload,
            )

            if response.status_code != 200:
                self._reset_token_on_auth_error(response.status_code)
                raise Exception(f"Video init failed: {response.text[:200]}")

            data = response.json()

        video_gen_id = data.get("videoGenId")
        if not video_gen_id:
            raise Exception("No video generation ID in response")

        print(f"[WhiskClient] Video generation started: {video_gen_id}")

        # Poll for completion
        return await self.poll_video_status(video_gen_id, token)

    async def poll_video_status(
        self,
        video_gen_id: str,
        token: str
    ) -> "WhiskVideoResult":
        """Poll the status endpoint until the video completes or fails.

        Makes up to 60 checks, 3 seconds apart. A non-200 response is logged
        and retried. A 401/403 also drops the cached token and signs in again
        before the next attempt.

        Args:
            video_gen_id: ID returned when the video job was started.
            token: Bearer token used for the first attempts.

        Returns:
            A ``WhiskVideoResult`` with status ``"complete"`` and the video URI.

        Raises:
            Exception: If the API reports ``FAILED``/``ERROR``, signing in again
                fails, or no finished video arrives within the allotted checks.
        """
        max_attempts = 60
        poll_interval = 3

        async with httpx.AsyncClient(timeout=30.0) as client:
            for attempt in range(max_attempts):
                if attempt:
                    # Wait between checks, but not after the final one.
                    await asyncio.sleep(poll_interval)

                print(f"[WhiskClient] Polling video status {attempt + 1}/{max_attempts}...")

                response = await client.post(
                    VIDEO_STATUS_URL,
                    headers=self._auth_headers(token),
                    json={"videoGenId": video_gen_id},
                )

                if response.status_code != 200:
                    print(f"[WhiskClient] Video status check failed: {response.status_code}")
                    if self._reset_token_on_auth_error(response.status_code):
                        # Token was rejected mid-poll; sign in again rather than
                        # retrying with a token that can no longer succeed.
                        token = await self.get_access_token()
                    continue

                data = response.json()
                status = data.get("status", "")
                video_url = data.get("videoUri")

                if status == "COMPLETE" and video_url:
                    print(f"[WhiskClient] Video complete: {video_url[:50]}...")
                    return WhiskVideoResult(
                        id=video_gen_id,
                        url=video_url,
                        status="complete"
                    )
                if status in ("FAILED", "ERROR"):
                    raise Exception(f"Video generation failed: {status}")

        raise Exception("Video generation timed out")
|