This commit is contained in:
Timothy Jaeryang Baek
2026-04-17 11:12:42 +09:00
parent 398718d505
commit 128cf41fce
2 changed files with 45 additions and 38 deletions

View File

@@ -11,7 +11,7 @@ from typing import Optional
from urllib.parse import quote
import aiohttp
import requests
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from fastapi.responses import FileResponse
@@ -52,32 +52,36 @@ IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
router = APIRouter()
def set_image_model(request: Request, model: str):
async def set_image_model(request: Request, model: str):
log.info(f'Setting image model to {model}')
request.app.state.config.IMAGE_GENERATION_MODEL = model
if request.app.state.config.IMAGE_GENERATION_ENGINE in ['', 'automatic1111']:
api_auth = get_automatic1111_api_auth(request)
try:
r = requests.get(
session = await get_session()
async with session.get(
url=f'{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options',
headers={'authorization': api_auth},
)
options = r.json()
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
options = await r.json()
if model != options['sd_model_checkpoint']:
options['sd_model_checkpoint'] = model
r = requests.post(
async with session.post(
url=f'{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options',
json=options,
headers={'authorization': api_auth},
)
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
r.raise_for_status()
except Exception as e:
log.debug(f'{e}')
return request.app.state.config.IMAGE_GENERATION_MODEL
def get_image_model(request):
async def get_image_model(request):
if request.app.state.config.IMAGE_GENERATION_ENGINE == 'openai':
return (
request.app.state.config.IMAGE_GENERATION_MODEL
@@ -99,11 +103,13 @@ def get_image_model(request):
or request.app.state.config.IMAGE_GENERATION_ENGINE == ''
):
try:
r = requests.get(
session = await get_session()
async with session.get(
url=f'{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options',
headers={'authorization': get_automatic1111_api_auth(request)},
)
options = r.json()
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
options = await r.json()
return options['sd_model_checkpoint']
except Exception as e:
request.app.state.config.ENABLE_IMAGE_GENERATION = False
@@ -200,7 +206,7 @@ async def update_config(request: Request, form_data: ImagesConfig, user=Depends(
request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = form_data.ENABLE_IMAGE_PROMPT_GENERATION
request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.IMAGE_GENERATION_ENGINE
set_image_model(request, form_data.IMAGE_GENERATION_MODEL)
await set_image_model(request, form_data.IMAGE_GENERATION_MODEL)
if form_data.IMAGE_SIZE == 'auto' and not re.match(
IMAGE_AUTO_SIZE_MODELS_REGEX_PATTERN, form_data.IMAGE_GENERATION_MODEL
):
@@ -437,21 +443,22 @@ class CreateImageForm(BaseModel):
GenerateImageForm = CreateImageForm # Alias for backward compatibility
def get_image_data(data: str, headers=None):
async def get_image_data(data: str, headers=None):
try:
if data.startswith('http://') or data.startswith('https://'):
if headers:
r = requests.get(data, headers=headers)
else:
r = requests.get(data)
r.raise_for_status()
if r.headers['content-type'].split('/')[0] == 'image':
mime_type = r.headers['content-type']
return r.content, mime_type
else:
log.error('Url does not point to an image.')
return None
session = await get_session()
async with session.get(
data,
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
r.raise_for_status()
content_type = r.headers.get('content-type', '')
if content_type.split('/')[0] == 'image':
return await r.read(), content_type
else:
log.error('Url does not point to an image.')
return None, None
else:
if ',' in data:
header, encoded = data.split(',', 1)
@@ -541,7 +548,7 @@ async def image_generations(
metadata = metadata or {}
model = get_image_model(request)
model = await get_image_model(request)
try:
if request.app.state.config.IMAGE_GENERATION_ENGINE == 'openai':
@@ -595,12 +602,12 @@ async def image_generations(
for image in res['data']:
if image_url := image.get('url', None):
image_data, content_type = get_image_data(
image_data, content_type = await get_image_data(
image_url,
{k: v for k, v in headers.items() if k != 'Content-Type'},
)
else:
image_data, content_type = get_image_data(image['b64_json'])
image_data, content_type = await get_image_data(image['b64_json'])
_, url = await upload_image(request, image_data, content_type, {**data, **metadata}, user)
images.append({'url': url})
@@ -645,14 +652,14 @@ async def image_generations(
if model.endswith(':predict'):
for image in res['predictions']:
image_data, content_type = get_image_data(image['bytesBase64Encoded'])
image_data, content_type = await get_image_data(image['bytesBase64Encoded'])
_, url = await upload_image(request, image_data, content_type, {**data, **metadata}, user)
images.append({'url': url})
elif model.endswith(':generateContent'):
for image in res['candidates']:
for part in image['content']['parts']:
if part.get('inlineData', {}).get('data'):
image_data, content_type = get_image_data(part['inlineData']['data'])
image_data, content_type = await get_image_data(part['inlineData']['data'])
_, url = await upload_image(
request,
image_data,
@@ -705,7 +712,7 @@ async def image_generations(
if request.app.state.config.COMFYUI_API_KEY:
headers = {'Authorization': f'Bearer {request.app.state.config.COMFYUI_API_KEY}'}
image_data, content_type = get_image_data(image['url'], headers)
image_data, content_type = await get_image_data(image['url'], headers)
_, url = await upload_image(
request,
image_data,
@@ -720,7 +727,7 @@ async def image_generations(
or request.app.state.config.IMAGE_GENERATION_ENGINE == ''
):
if form_data.model:
set_image_model(request, form_data.model)
await set_image_model(request, form_data.model)
data = {
'prompt': form_data.prompt,
@@ -751,7 +758,7 @@ async def image_generations(
images = []
for image in res['images']:
image_data, content_type = get_image_data(image)
image_data, content_type = await get_image_data(image)
_, url = await upload_image(
request,
image_data,
@@ -919,12 +926,12 @@ async def image_edits(
images = []
for image in res['data']:
if image_url := image.get('url', None):
image_data, content_type = get_image_data(
image_data, content_type = await get_image_data(
image_url,
{k: v for k, v in headers.items() if k != 'Content-Type'},
)
else:
image_data, content_type = get_image_data(image['b64_json'])
image_data, content_type = await get_image_data(image['b64_json'])
_, url = await upload_image(request, image_data, content_type, {**data, **metadata}, user)
images.append({'url': url})
@@ -975,7 +982,7 @@ async def image_edits(
for image in res['candidates']:
for part in image['content']['parts']:
if part.get('inlineData', {}).get('data'):
image_data, content_type = get_image_data(part['inlineData']['data'])
image_data, content_type = await get_image_data(part['inlineData']['data'])
_, url = await upload_image(
request,
image_data,
@@ -1055,7 +1062,7 @@ async def image_edits(
if request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY:
headers = {'Authorization': f'Bearer {request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY}'}
image_data, content_type = get_image_data(image_url, headers)
image_data, content_type = await get_image_data(image_url, headers)
_, url = await upload_image(
request,
image_data,

View File

@@ -71,7 +71,7 @@ async def get_image_url_from_base64(request, base64_image_string, metadata, user
if BASE64_IMAGE_URL_PREFIX.match(base64_image_string):
image_url = ''
# Extract base64 image data from the line
image_data, content_type = get_image_data(base64_image_string)
image_data, content_type = await get_image_data(base64_image_string)
if image_data is not None:
_, image_url = await upload_image(
request,