diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 37dda2eb1b..930de8bc66 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -11,7 +11,7 @@ from typing import Optional from urllib.parse import quote import aiohttp -import requests + from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile from fastapi.responses import FileResponse @@ -52,32 +52,36 @@ IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) router = APIRouter() -def set_image_model(request: Request, model: str): +async def set_image_model(request: Request, model: str): log.info(f'Setting image model to {model}') request.app.state.config.IMAGE_GENERATION_MODEL = model if request.app.state.config.IMAGE_GENERATION_ENGINE in ['', 'automatic1111']: api_auth = get_automatic1111_api_auth(request) try: - r = requests.get( + session = await get_session() + async with session.get( url=f'{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options', headers={'authorization': api_auth}, - ) - options = r.json() + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as r: + options = await r.json() if model != options['sd_model_checkpoint']: options['sd_model_checkpoint'] = model - r = requests.post( + async with session.post( url=f'{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options', json=options, headers={'authorization': api_auth}, - ) + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as r: + r.raise_for_status() except Exception as e: log.debug(f'{e}') return request.app.state.config.IMAGE_GENERATION_MODEL -def get_image_model(request): +async def get_image_model(request): if request.app.state.config.IMAGE_GENERATION_ENGINE == 'openai': return ( request.app.state.config.IMAGE_GENERATION_MODEL @@ -99,11 +103,13 @@ def get_image_model(request): or request.app.state.config.IMAGE_GENERATION_ENGINE == '' ): try: - r = requests.get( + session = await get_session() + async with session.get( url=f'{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options', headers={'authorization': get_automatic1111_api_auth(request)}, - ) - options = r.json() + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as r: + options = await r.json() return options['sd_model_checkpoint'] except Exception as e: request.app.state.config.ENABLE_IMAGE_GENERATION = False @@ -200,7 +206,7 @@ async def update_config(request: Request, form_data: ImagesConfig, user=Depends( request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = form_data.ENABLE_IMAGE_PROMPT_GENERATION request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.IMAGE_GENERATION_ENGINE - set_image_model(request, form_data.IMAGE_GENERATION_MODEL) + await set_image_model(request, form_data.IMAGE_GENERATION_MODEL) if form_data.IMAGE_SIZE == 'auto' and not re.match( IMAGE_AUTO_SIZE_MODELS_REGEX_PATTERN, form_data.IMAGE_GENERATION_MODEL ): @@ -437,21 +443,22 @@ class CreateImageForm(BaseModel): GenerateImageForm = CreateImageForm # Alias for backward compatibility -def get_image_data(data: str, headers=None): +async def get_image_data(data: str, headers=None): try: if data.startswith('http://') or data.startswith('https://'): - if headers: - r = requests.get(data, headers=headers) - else: - r = requests.get(data) - - r.raise_for_status() - if r.headers['content-type'].split('/')[0] == 'image': - mime_type = r.headers['content-type'] - return r.content, mime_type - else: - log.error('Url does not point to an image.') - return None + session = await get_session() + async with session.get( + data, + headers=headers, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as r: + r.raise_for_status() + content_type = r.headers.get('content-type', '') + if content_type.split('/')[0] == 'image': + return await r.read(), content_type + else: + log.error('Url does not point to an image.') + return None, None else: if ',' in data: header, encoded = data.split(',', 1) @@ -541,7 +548,7 @@ async def image_generations( metadata = metadata or {} - model = get_image_model(request) + model = await get_image_model(request) try: if request.app.state.config.IMAGE_GENERATION_ENGINE == 'openai': @@ -595,12 +602,12 @@ async def image_generations( for image in res['data']: if image_url := image.get('url', None): - image_data, content_type = get_image_data( + image_data, content_type = await get_image_data( image_url, {k: v for k, v in headers.items() if k != 'Content-Type'}, ) else: - image_data, content_type = get_image_data(image['b64_json']) + image_data, content_type = await get_image_data(image['b64_json']) _, url = await upload_image(request, image_data, content_type, {**data, **metadata}, user) images.append({'url': url}) @@ -645,14 +652,14 @@ async def image_generations( if model.endswith(':predict'): for image in res['predictions']: - image_data, content_type = get_image_data(image['bytesBase64Encoded']) + image_data, content_type = await get_image_data(image['bytesBase64Encoded']) _, url = await upload_image(request, image_data, content_type, {**data, **metadata}, user) images.append({'url': url}) elif model.endswith(':generateContent'): for image in res['candidates']: for part in image['content']['parts']: if part.get('inlineData', {}).get('data'): - image_data, content_type = get_image_data(part['inlineData']['data']) + image_data, content_type = await get_image_data(part['inlineData']['data']) _, url = await upload_image( request, image_data, @@ -705,7 +712,7 @@ async def image_generations( if request.app.state.config.COMFYUI_API_KEY: headers = {'Authorization': f'Bearer {request.app.state.config.COMFYUI_API_KEY}'} - image_data, content_type = get_image_data(image['url'], headers) + image_data, content_type = await get_image_data(image['url'], headers) _, url = await upload_image( request, image_data, @@ -720,7 +727,7 @@ async def image_generations( or request.app.state.config.IMAGE_GENERATION_ENGINE == '' ): if form_data.model: - set_image_model(request, form_data.model) + await set_image_model(request, form_data.model) data = { 'prompt': form_data.prompt, @@ -751,7 +758,7 @@ async def image_generations( images = [] for image in res['images']: - image_data, content_type = get_image_data(image) + image_data, content_type = await get_image_data(image) _, url = await upload_image( request, image_data, @@ -919,12 +926,12 @@ async def image_edits( images = [] for image in res['data']: if image_url := image.get('url', None): - image_data, content_type = get_image_data( + image_data, content_type = await get_image_data( image_url, {k: v for k, v in headers.items() if k != 'Content-Type'}, ) else: - image_data, content_type = get_image_data(image['b64_json']) + image_data, content_type = await get_image_data(image['b64_json']) _, url = await upload_image(request, image_data, content_type, {**data, **metadata}, user) images.append({'url': url}) @@ -975,7 +982,7 @@ async def image_edits( for image in res['candidates']: for part in image['content']['parts']: if part.get('inlineData', {}).get('data'): - image_data, content_type = get_image_data(part['inlineData']['data']) + image_data, content_type = await get_image_data(part['inlineData']['data']) _, url = await upload_image( request, image_data, @@ -1055,7 +1062,7 @@ async def image_edits( if request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY: headers = {'Authorization': f'Bearer {request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY}'} - image_data, content_type = get_image_data(image_url, headers) + image_data, content_type = await get_image_data(image_url, headers) _, url = await upload_image( request, image_data, diff --git a/backend/open_webui/utils/files.py b/backend/open_webui/utils/files.py index eea3a8b486..78392f2d41 100644 --- a/backend/open_webui/utils/files.py +++ b/backend/open_webui/utils/files.py @@ -71,7 +71,7 @@ async def get_image_url_from_base64(request, base64_image_string, metadata, user if BASE64_IMAGE_URL_PREFIX.match(base64_image_string): image_url = '' # Extract base64 image data from the line - image_data, content_type = get_image_data(base64_image_string) + image_data, content_type = await get_image_data(base64_image_string) if image_data is not None: _, image_url = await upload_image( request,