enh: gemini flash image generation support

This commit is contained in:
Timothy Jaeryang Baek
2025-11-05 01:59:16 -05:00
parent 72900cd686
commit 8d34fcb586
4 changed files with 517 additions and 101 deletions

View File

@@ -97,33 +97,37 @@ def get_image_model(request):
class ImagesConfig(BaseModel):
    """Admin-facing image generation settings.

    Used as the response model of the ``/config`` GET route and as the
    request body accepted by ``update_config``.  Each field is declared
    exactly once; a repeated annotation in a class body silently replaces
    the earlier one, so duplicates only obscure the effective type.
    """

    # Feature toggles.
    ENABLE_IMAGE_GENERATION: bool
    ENABLE_IMAGE_PROMPT_GENERATION: bool

    # Backend selection; the request handler compares the engine against
    # values such as "gemini".
    IMAGE_GENERATION_ENGINE: str
    IMAGE_GENERATION_MODEL: str

    # Optional so a client may omit a concrete value.  ``update_config``
    # accepts "auto", "" or a "<width>x<height>" string for IMAGE_SIZE.
    IMAGE_SIZE: Optional[str]
    IMAGE_STEPS: Optional[int]

    # OpenAI-compatible image endpoint.
    IMAGES_OPENAI_API_BASE_URL: str
    IMAGES_OPENAI_API_KEY: str
    IMAGES_OPENAI_API_VERSION: str

    # AUTOMATIC1111 backend.
    AUTOMATIC1111_BASE_URL: str
    AUTOMATIC1111_API_AUTH: str
    AUTOMATIC1111_PARAMS: Optional[dict | str]

    # ComfyUI backend.
    COMFYUI_BASE_URL: str
    COMFYUI_API_KEY: str
    COMFYUI_WORKFLOW: str
    COMFYUI_WORKFLOW_NODES: list[dict]

    # Gemini backend.  The endpoint method is matched against "" /
    # "predict" and "generateContent" when an image request is built.
    IMAGES_GEMINI_API_BASE_URL: str
    IMAGES_GEMINI_API_KEY: str
    IMAGES_GEMINI_ENDPOINT_METHOD: str
@router.get("/config", response_model=ImagesConfig)
async def get_config(request: Request, user=Depends(get_admin_user)):
return {
"ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION,
"IMAGE_GENERATION_ENGINE": request.app.state.config.IMAGE_GENERATION_ENGINE,
"ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
"IMAGE_GENERATION_ENGINE": request.app.state.config.IMAGE_GENERATION_ENGINE,
"IMAGE_GENERATION_MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
"IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
"IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
@@ -139,6 +143,7 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
}
@@ -146,12 +151,12 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
async def update_config(
request: Request, form_data: ImagesConfig, user=Depends(get_admin_user)
):
request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.IMAGE_GENERATION_ENGINE
request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.ENABLE_IMAGE_GENERATION
request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = (
form_data.ENABLE_IMAGE_PROMPT_GENERATION
)
request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.IMAGE_GENERATION_ENGINE
set_image_model(request, form_data.IMAGE_GENERATION_MODEL)
if (
form_data.IMAGE_SIZE == "auto"
@@ -165,7 +170,11 @@ async def update_config(
)
pattern = r"^\d+x\d+$"
if form_data.IMAGE_SIZE == "auto" or re.match(pattern, form_data.IMAGE_SIZE):
if (
form_data.IMAGE_SIZE == "auto"
or form_data.IMAGE_SIZE == ""
or re.match(pattern, form_data.IMAGE_SIZE)
):
request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
else:
raise HTTPException(
@@ -202,11 +211,14 @@ async def update_config(
form_data.IMAGES_GEMINI_API_BASE_URL
)
request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.IMAGES_GEMINI_API_KEY
request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD = (
form_data.IMAGES_GEMINI_ENDPOINT_METHOD
)
return {
"ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION,
"IMAGE_GENERATION_ENGINE": request.app.state.config.IMAGE_GENERATION_ENGINE,
"ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
"IMAGE_GENERATION_ENGINE": request.app.state.config.IMAGE_GENERATION_ENGINE,
"IMAGE_GENERATION_MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
"IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
"IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
@@ -222,6 +234,7 @@ async def update_config(
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
}
@@ -365,9 +378,7 @@ GenerateImageForm = CreateImageForm # Alias for backward compatibility
def get_image_data(data: str, headers=None):
try:
# if data url
if data.startswith("http"):
if data.startswith("http://") or data.startswith("https://"):
if headers:
r = requests.get(data, headers=headers)
else:
@@ -495,22 +506,37 @@ async def image_generations(
return images
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
headers = {}
headers["Content-Type"] = "application/json"
headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
data = {
"instances": {"prompt": form_data.prompt},
"parameters": {
"sampleCount": form_data.n,
"outputOptions": {"mimeType": "image/png"},
},
headers = {
"Content-Type": "application/json",
"x-goog-api-key": request.app.state.config.IMAGES_GEMINI_API_KEY,
}
data = {}
if (
request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD == ""
or request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD == "predict"
):
model = f"{model}:predict"
data = {
"instances": {"prompt": form_data.prompt},
"parameters": {
"sampleCount": form_data.n,
"outputOptions": {"mimeType": "image/png"},
},
}
elif (
request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD
== "generateContent"
):
model = f"{model}:generateContent"
data = {"contents": [{"parts": [{"text": form_data.prompt}]}]}
# Use asyncio.to_thread for the requests.post call
r = await asyncio.to_thread(
requests.post,
url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}",
json=data,
headers=headers,
)
@@ -519,10 +545,25 @@ async def image_generations(
res = r.json()
images = []
for image in res["predictions"]:
image_data, content_type = get_image_data(image["bytesBase64Encoded"])
url = upload_image(request, image_data, content_type, data, user)
images.append({"url": url})
if model.endswith(":predict"):
for image in res["predictions"]:
image_data, content_type = get_image_data(
image["bytesBase64Encoded"]
)
url = upload_image(request, image_data, content_type, data, user)
images.append({"url": url})
elif model.endswith(":generateContent"):
for image in res["candidates"]:
for part in image["content"]["parts"]:
if part.get("inlineData", {}).get("data"):
image_data, content_type = get_image_data(
part["inlineData"]["data"]
)
url = upload_image(
request, image_data, content_type, data, user
)
images.append({"url": url})
return images