mirror of
https://github.com/open-webui/open-webui.git
synced 2026-04-30 01:10:17 -05:00
enh: gemini flash image generation support
This commit is contained in:
@@ -97,33 +97,37 @@ def get_image_model(request):
|
||||
|
||||
class ImagesConfig(BaseModel):
|
||||
ENABLE_IMAGE_GENERATION: bool
|
||||
IMAGE_GENERATION_ENGINE: str
|
||||
ENABLE_IMAGE_PROMPT_GENERATION: bool
|
||||
|
||||
IMAGE_GENERATION_ENGINE: str
|
||||
IMAGE_GENERATION_MODEL: str
|
||||
IMAGE_SIZE: str
|
||||
IMAGE_STEPS: int
|
||||
IMAGE_SIZE: Optional[str]
|
||||
IMAGE_STEPS: Optional[int]
|
||||
|
||||
IMAGES_OPENAI_API_BASE_URL: str
|
||||
IMAGES_OPENAI_API_KEY: str
|
||||
IMAGES_OPENAI_API_VERSION: str
|
||||
|
||||
AUTOMATIC1111_BASE_URL: str
|
||||
AUTOMATIC1111_API_AUTH: str
|
||||
AUTOMATIC1111_PARAMS: Optional[dict | str]
|
||||
|
||||
COMFYUI_BASE_URL: str
|
||||
COMFYUI_API_KEY: str
|
||||
COMFYUI_WORKFLOW: str
|
||||
COMFYUI_WORKFLOW_NODES: list[dict]
|
||||
|
||||
IMAGES_GEMINI_API_BASE_URL: str
|
||||
IMAGES_GEMINI_API_KEY: str
|
||||
IMAGES_GEMINI_ENDPOINT_METHOD: str
|
||||
|
||||
|
||||
@router.get("/config", response_model=ImagesConfig)
|
||||
async def get_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION,
|
||||
"IMAGE_GENERATION_ENGINE": request.app.state.config.IMAGE_GENERATION_ENGINE,
|
||||
"ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
|
||||
"IMAGE_GENERATION_ENGINE": request.app.state.config.IMAGE_GENERATION_ENGINE,
|
||||
"IMAGE_GENERATION_MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
|
||||
"IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
|
||||
"IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
|
||||
@@ -139,6 +143,7 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
|
||||
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
||||
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
||||
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
||||
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
|
||||
}
|
||||
|
||||
|
||||
@@ -146,12 +151,12 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
|
||||
async def update_config(
|
||||
request: Request, form_data: ImagesConfig, user=Depends(get_admin_user)
|
||||
):
|
||||
request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.IMAGE_GENERATION_ENGINE
|
||||
request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.ENABLE_IMAGE_GENERATION
|
||||
request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = (
|
||||
form_data.ENABLE_IMAGE_PROMPT_GENERATION
|
||||
)
|
||||
|
||||
request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.IMAGE_GENERATION_ENGINE
|
||||
set_image_model(request, form_data.IMAGE_GENERATION_MODEL)
|
||||
if (
|
||||
form_data.IMAGE_SIZE == "auto"
|
||||
@@ -165,7 +170,11 @@ async def update_config(
|
||||
)
|
||||
|
||||
pattern = r"^\d+x\d+$"
|
||||
if form_data.IMAGE_SIZE == "auto" or re.match(pattern, form_data.IMAGE_SIZE):
|
||||
if (
|
||||
form_data.IMAGE_SIZE == "auto"
|
||||
or form_data.IMAGE_SIZE == ""
|
||||
or re.match(pattern, form_data.IMAGE_SIZE)
|
||||
):
|
||||
request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -202,11 +211,14 @@ async def update_config(
|
||||
form_data.IMAGES_GEMINI_API_BASE_URL
|
||||
)
|
||||
request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.IMAGES_GEMINI_API_KEY
|
||||
request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD = (
|
||||
form_data.IMAGES_GEMINI_ENDPOINT_METHOD
|
||||
)
|
||||
|
||||
return {
|
||||
"ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION,
|
||||
"IMAGE_GENERATION_ENGINE": request.app.state.config.IMAGE_GENERATION_ENGINE,
|
||||
"ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
|
||||
"IMAGE_GENERATION_ENGINE": request.app.state.config.IMAGE_GENERATION_ENGINE,
|
||||
"IMAGE_GENERATION_MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
|
||||
"IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
|
||||
"IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
|
||||
@@ -222,6 +234,7 @@ async def update_config(
|
||||
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
||||
"IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
||||
"IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
||||
"IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
|
||||
}
|
||||
|
||||
|
||||
@@ -365,9 +378,7 @@ GenerateImageForm = CreateImageForm # Alias for backward compatibility
|
||||
|
||||
def get_image_data(data: str, headers=None):
|
||||
try:
|
||||
# if data url
|
||||
|
||||
if data.startswith("http"):
|
||||
if data.startswith("http://") or data.startswith("https://"):
|
||||
if headers:
|
||||
r = requests.get(data, headers=headers)
|
||||
else:
|
||||
@@ -495,22 +506,37 @@ async def image_generations(
|
||||
return images
|
||||
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||
headers = {}
|
||||
headers["Content-Type"] = "application/json"
|
||||
headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
|
||||
|
||||
data = {
|
||||
"instances": {"prompt": form_data.prompt},
|
||||
"parameters": {
|
||||
"sampleCount": form_data.n,
|
||||
"outputOptions": {"mimeType": "image/png"},
|
||||
},
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-goog-api-key": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
||||
}
|
||||
|
||||
data = {}
|
||||
|
||||
if (
|
||||
request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD == ""
|
||||
or request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD == "predict"
|
||||
):
|
||||
model = f"{model}:predict"
|
||||
data = {
|
||||
"instances": {"prompt": form_data.prompt},
|
||||
"parameters": {
|
||||
"sampleCount": form_data.n,
|
||||
"outputOptions": {"mimeType": "image/png"},
|
||||
},
|
||||
}
|
||||
|
||||
elif (
|
||||
request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD
|
||||
== "generateContent"
|
||||
):
|
||||
model = f"{model}:generateContent"
|
||||
data = {"contents": [{"parts": [{"text": form_data.prompt}]}]}
|
||||
|
||||
# Use asyncio.to_thread for the requests.post call
|
||||
r = await asyncio.to_thread(
|
||||
requests.post,
|
||||
url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
|
||||
url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}",
|
||||
json=data,
|
||||
headers=headers,
|
||||
)
|
||||
@@ -519,10 +545,25 @@ async def image_generations(
|
||||
res = r.json()
|
||||
|
||||
images = []
|
||||
for image in res["predictions"]:
|
||||
image_data, content_type = get_image_data(image["bytesBase64Encoded"])
|
||||
url = upload_image(request, image_data, content_type, data, user)
|
||||
images.append({"url": url})
|
||||
|
||||
if model.endswith(":predict"):
|
||||
for image in res["predictions"]:
|
||||
image_data, content_type = get_image_data(
|
||||
image["bytesBase64Encoded"]
|
||||
)
|
||||
url = upload_image(request, image_data, content_type, data, user)
|
||||
images.append({"url": url})
|
||||
elif model.endswith(":generateContent"):
|
||||
for image in res["candidates"]:
|
||||
for part in image["content"]["parts"]:
|
||||
if part.get("inlineData", {}).get("data"):
|
||||
image_data, content_type = get_image_data(
|
||||
part["inlineData"]["data"]
|
||||
)
|
||||
url = upload_image(
|
||||
request, image_data, content_type, data, user
|
||||
)
|
||||
images.append({"url": url})
|
||||
|
||||
return images
|
||||
|
||||
|
||||
Reference in New Issue
Block a user