mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-05 10:28:06 -05:00
refac
This commit is contained in:
@@ -11,7 +11,7 @@ from typing import Optional
|
||||
|
||||
from urllib.parse import quote
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
@@ -52,32 +52,36 @@ IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def set_image_model(request: Request, model: str):
|
||||
async def set_image_model(request: Request, model: str):
|
||||
log.info(f'Setting image model to {model}')
|
||||
request.app.state.config.IMAGE_GENERATION_MODEL = model
|
||||
if request.app.state.config.IMAGE_GENERATION_ENGINE in ['', 'automatic1111']:
|
||||
api_auth = get_automatic1111_api_auth(request)
|
||||
|
||||
try:
|
||||
r = requests.get(
|
||||
session = await get_session()
|
||||
async with session.get(
|
||||
url=f'{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options',
|
||||
headers={'authorization': api_auth},
|
||||
)
|
||||
options = r.json()
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
options = await r.json()
|
||||
if model != options['sd_model_checkpoint']:
|
||||
options['sd_model_checkpoint'] = model
|
||||
r = requests.post(
|
||||
async with session.post(
|
||||
url=f'{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options',
|
||||
json=options,
|
||||
headers={'authorization': api_auth},
|
||||
)
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
except Exception as e:
|
||||
log.debug(f'{e}')
|
||||
|
||||
return request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
|
||||
|
||||
def get_image_model(request):
|
||||
async def get_image_model(request):
|
||||
if request.app.state.config.IMAGE_GENERATION_ENGINE == 'openai':
|
||||
return (
|
||||
request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
@@ -99,11 +103,13 @@ def get_image_model(request):
|
||||
or request.app.state.config.IMAGE_GENERATION_ENGINE == ''
|
||||
):
|
||||
try:
|
||||
r = requests.get(
|
||||
session = await get_session()
|
||||
async with session.get(
|
||||
url=f'{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options',
|
||||
headers={'authorization': get_automatic1111_api_auth(request)},
|
||||
)
|
||||
options = r.json()
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
options = await r.json()
|
||||
return options['sd_model_checkpoint']
|
||||
except Exception as e:
|
||||
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
||||
@@ -200,7 +206,7 @@ async def update_config(request: Request, form_data: ImagesConfig, user=Depends(
|
||||
request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = form_data.ENABLE_IMAGE_PROMPT_GENERATION
|
||||
|
||||
request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.IMAGE_GENERATION_ENGINE
|
||||
set_image_model(request, form_data.IMAGE_GENERATION_MODEL)
|
||||
await set_image_model(request, form_data.IMAGE_GENERATION_MODEL)
|
||||
if form_data.IMAGE_SIZE == 'auto' and not re.match(
|
||||
IMAGE_AUTO_SIZE_MODELS_REGEX_PATTERN, form_data.IMAGE_GENERATION_MODEL
|
||||
):
|
||||
@@ -437,21 +443,22 @@ class CreateImageForm(BaseModel):
|
||||
GenerateImageForm = CreateImageForm # Alias for backward compatibility
|
||||
|
||||
|
||||
def get_image_data(data: str, headers=None):
|
||||
async def get_image_data(data: str, headers=None):
|
||||
try:
|
||||
if data.startswith('http://') or data.startswith('https://'):
|
||||
if headers:
|
||||
r = requests.get(data, headers=headers)
|
||||
else:
|
||||
r = requests.get(data)
|
||||
|
||||
r.raise_for_status()
|
||||
if r.headers['content-type'].split('/')[0] == 'image':
|
||||
mime_type = r.headers['content-type']
|
||||
return r.content, mime_type
|
||||
else:
|
||||
log.error('Url does not point to an image.')
|
||||
return None
|
||||
session = await get_session()
|
||||
async with session.get(
|
||||
data,
|
||||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
content_type = r.headers.get('content-type', '')
|
||||
if content_type.split('/')[0] == 'image':
|
||||
return await r.read(), content_type
|
||||
else:
|
||||
log.error('Url does not point to an image.')
|
||||
return None, None
|
||||
else:
|
||||
if ',' in data:
|
||||
header, encoded = data.split(',', 1)
|
||||
@@ -541,7 +548,7 @@ async def image_generations(
|
||||
|
||||
metadata = metadata or {}
|
||||
|
||||
model = get_image_model(request)
|
||||
model = await get_image_model(request)
|
||||
|
||||
try:
|
||||
if request.app.state.config.IMAGE_GENERATION_ENGINE == 'openai':
|
||||
@@ -595,12 +602,12 @@ async def image_generations(
|
||||
|
||||
for image in res['data']:
|
||||
if image_url := image.get('url', None):
|
||||
image_data, content_type = get_image_data(
|
||||
image_data, content_type = await get_image_data(
|
||||
image_url,
|
||||
{k: v for k, v in headers.items() if k != 'Content-Type'},
|
||||
)
|
||||
else:
|
||||
image_data, content_type = get_image_data(image['b64_json'])
|
||||
image_data, content_type = await get_image_data(image['b64_json'])
|
||||
|
||||
_, url = await upload_image(request, image_data, content_type, {**data, **metadata}, user)
|
||||
images.append({'url': url})
|
||||
@@ -645,14 +652,14 @@ async def image_generations(
|
||||
|
||||
if model.endswith(':predict'):
|
||||
for image in res['predictions']:
|
||||
image_data, content_type = get_image_data(image['bytesBase64Encoded'])
|
||||
image_data, content_type = await get_image_data(image['bytesBase64Encoded'])
|
||||
_, url = await upload_image(request, image_data, content_type, {**data, **metadata}, user)
|
||||
images.append({'url': url})
|
||||
elif model.endswith(':generateContent'):
|
||||
for image in res['candidates']:
|
||||
for part in image['content']['parts']:
|
||||
if part.get('inlineData', {}).get('data'):
|
||||
image_data, content_type = get_image_data(part['inlineData']['data'])
|
||||
image_data, content_type = await get_image_data(part['inlineData']['data'])
|
||||
_, url = await upload_image(
|
||||
request,
|
||||
image_data,
|
||||
@@ -705,7 +712,7 @@ async def image_generations(
|
||||
if request.app.state.config.COMFYUI_API_KEY:
|
||||
headers = {'Authorization': f'Bearer {request.app.state.config.COMFYUI_API_KEY}'}
|
||||
|
||||
image_data, content_type = get_image_data(image['url'], headers)
|
||||
image_data, content_type = await get_image_data(image['url'], headers)
|
||||
_, url = await upload_image(
|
||||
request,
|
||||
image_data,
|
||||
@@ -720,7 +727,7 @@ async def image_generations(
|
||||
or request.app.state.config.IMAGE_GENERATION_ENGINE == ''
|
||||
):
|
||||
if form_data.model:
|
||||
set_image_model(request, form_data.model)
|
||||
await set_image_model(request, form_data.model)
|
||||
|
||||
data = {
|
||||
'prompt': form_data.prompt,
|
||||
@@ -751,7 +758,7 @@ async def image_generations(
|
||||
images = []
|
||||
|
||||
for image in res['images']:
|
||||
image_data, content_type = get_image_data(image)
|
||||
image_data, content_type = await get_image_data(image)
|
||||
_, url = await upload_image(
|
||||
request,
|
||||
image_data,
|
||||
@@ -919,12 +926,12 @@ async def image_edits(
|
||||
images = []
|
||||
for image in res['data']:
|
||||
if image_url := image.get('url', None):
|
||||
image_data, content_type = get_image_data(
|
||||
image_data, content_type = await get_image_data(
|
||||
image_url,
|
||||
{k: v for k, v in headers.items() if k != 'Content-Type'},
|
||||
)
|
||||
else:
|
||||
image_data, content_type = get_image_data(image['b64_json'])
|
||||
image_data, content_type = await get_image_data(image['b64_json'])
|
||||
|
||||
_, url = await upload_image(request, image_data, content_type, {**data, **metadata}, user)
|
||||
images.append({'url': url})
|
||||
@@ -975,7 +982,7 @@ async def image_edits(
|
||||
for image in res['candidates']:
|
||||
for part in image['content']['parts']:
|
||||
if part.get('inlineData', {}).get('data'):
|
||||
image_data, content_type = get_image_data(part['inlineData']['data'])
|
||||
image_data, content_type = await get_image_data(part['inlineData']['data'])
|
||||
_, url = await upload_image(
|
||||
request,
|
||||
image_data,
|
||||
@@ -1055,7 +1062,7 @@ async def image_edits(
|
||||
if request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY:
|
||||
headers = {'Authorization': f'Bearer {request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY}'}
|
||||
|
||||
image_data, content_type = get_image_data(image_url, headers)
|
||||
image_data, content_type = await get_image_data(image_url, headers)
|
||||
_, url = await upload_image(
|
||||
request,
|
||||
image_data,
|
||||
|
||||
@@ -71,7 +71,7 @@ async def get_image_url_from_base64(request, base64_image_string, metadata, user
|
||||
if BASE64_IMAGE_URL_PREFIX.match(base64_image_string):
|
||||
image_url = ''
|
||||
# Extract base64 image data from the line
|
||||
image_data, content_type = get_image_data(base64_image_string)
|
||||
image_data, content_type = await get_image_data(base64_image_string)
|
||||
if image_data is not None:
|
||||
_, image_url = await upload_image(
|
||||
request,
|
||||
|
||||
Reference in New Issue
Block a user