# TODO: Implement a more intelligent load-balancing mechanism for distributing requests among multiple backend instances.
# The current implementation picks a backend at random (random.choice); see the illustrative _pick_least_busy sketch
# after the imports. Consider algorithms such as weighted round-robin, least connections, or least response time for
# better resource utilization and performance.

import asyncio
import json
import logging
import os
import random
import re
import time
from datetime import datetime

from typing import Optional, Union
from urllib.parse import urlparse

import aiohttp
from aiocache import cached


from open_webui.utils.headers import include_user_info_headers
from open_webui.models.chats import Chats
from open_webui.models.users import UserModel

from open_webui.env import (
    ENABLE_FORWARD_USER_INFO_HEADERS,
    FORWARD_SESSION_INFO_HEADER_CHAT_ID,
)

from fastapi import (
    Depends,
    FastAPI,
    File,
    HTTPException,
    Request,
    UploadFile,
    APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict, validator

from sqlalchemy.ext.asyncio import AsyncSession

from open_webui.internal.db import get_async_session


from open_webui.models.models import Models
from open_webui.models.access_grants import AccessGrants
from open_webui.models.groups import Groups
from open_webui.utils.access_control import check_model_access
from open_webui.utils.misc import (
    calculate_sha256,
)
from open_webui.utils.session_pool import (
    cleanup_response,
    get_session,
    stream_wrapper,
)
from open_webui.utils.payload import (
    apply_model_params_to_body_ollama,
    apply_model_params_to_body_openai,
    apply_system_prompt_to_body,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import (
    UPLOAD_DIR,
)
from open_webui.env import (
    ENV,
    MODELS_CACHE_TTL,
    AIOHTTP_CLIENT_SESSION_SSL,
    AIOHTTP_CLIENT_TIMEOUT,
    AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
    BYPASS_MODEL_ACCESS_CONTROL,
)
from open_webui.constants import ERROR_MESSAGES

log = logging.getLogger(__name__)
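

# Illustrative sketch (referenced by the TODO at the top of this file; not wired
# into any route): a least-connections selector that could stand in for the
# random.choice(...) calls below. The in-flight counter mapping is an assumed
# illustration only -- this module does not currently track per-backend load.
def _pick_least_busy(url_indices: list[int], inflight: dict[int, int]) -> int:
    # Prefer the backend index with the fewest in-flight requests; ties resolve
    # to the first candidate, keeping the selection deterministic.
    return min(url_indices, key=lambda idx: inflight.get(idx, 0))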


##########################################
#
# Utility functions
#
##########################################


async def send_get_request(url, key=None, user: UserModel = None):
    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
    try:
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            headers = {
                'Content-Type': 'application/json',
                **({'Authorization': f'Bearer {key}'} if key else {}),
            }

            if ENABLE_FORWARD_USER_INFO_HEADERS and user:
                headers = include_user_info_headers(headers, user)

            async with session.get(
                url,
                headers=headers,
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as response:
                return await response.json()
    except Exception as e:
        # Connection errors are logged and swallowed; callers treat None as "no response"
        log.error(f'Connection error: {e}')
        return None


async def send_request(
    url: str,
    method: str = 'POST',
    *,
    payload: Optional[Union[str, bytes]] = None,
    key: Optional[str] = None,
    user: UserModel = None,
    stream: bool = False,
    content_type: Optional[str] = None,
    metadata: Optional[dict] = None,
):
    r = None
    streaming = False
    try:
        session = await get_session()

        headers = {
            'Content-Type': 'application/json',
            **({'Authorization': f'Bearer {key}'} if key else {}),
        }

        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers = include_user_info_headers(headers, user)
        if metadata and metadata.get('chat_id'):
            headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get('chat_id')

        r = await session.request(
            method,
            url,
            data=payload,
            headers=headers,
            ssl=AIOHTTP_CLIENT_SESSION_SSL,
            timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT),
        )

        if not r.ok:
            try:
                res = await r.json()
                if 'error' in res:
                    raise HTTPException(status_code=r.status, detail=res['error'])
            except HTTPException:
                raise
            except Exception as e:
                log.error(f'Failed to parse error response: {e}')
                raise HTTPException(
                    status_code=r.status,
                    detail=ERROR_MESSAGES.SERVER_CONNECTION_ERROR,
                )

        r.raise_for_status()

        if stream:
            response_headers = dict(r.headers)
            if content_type:
                response_headers['Content-Type'] = content_type

            streaming = True
            return StreamingResponse(
                stream_wrapper(r),
                status_code=r.status,
                headers=response_headers,
            )
        else:
            try:
                return await r.json()
            except Exception:
                return None

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=r.status if r else 500,
            detail=f'Ollama: {e}' if str(e) else ERROR_MESSAGES.SERVER_CONNECTION_ERROR,
        )
    finally:
        if not streaming:
            await cleanup_response(r)
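
# Typical call shape for send_request (illustrative; mirrors real usages further
# down in this file): proxy a JSON POST to a backend and stream the NDJSON reply.
#
#     await send_request(
#         f'{url}/api/chat',
#         payload=json.dumps(payload),
#         key=key,
#         user=user,
#         stream=True,
#         content_type='application/x-ndjson',
#     )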


def get_api_key(idx, url, configs):
    parsed_url = urlparse(url)
    base_url = f'{parsed_url.scheme}://{parsed_url.netloc}'
    # Index-keyed entry takes precedence; the base-URL entry is legacy support
    return configs.get(str(idx), configs.get(base_url, {})).get('key', None)
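
# Example (illustrative; the key values are made up): OLLAMA_API_CONFIGS may be
# keyed by the URL's index in OLLAMA_BASE_URLS (current format) or by the base
# URL itself (legacy format), with the index entry taking precedence:
#
#     configs = {'0': {'key': 'key-new'}, 'http://ollama:11434': {'key': 'key-old'}}
#     get_api_key(0, 'http://ollama:11434', configs)  # -> 'key-new' (index wins)
#     get_api_key(1, 'http://ollama:11434', configs)  # -> 'key-old' (legacy fallback)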


##########################################
#
# API routes
#
##########################################

router = APIRouter()


@router.head('/')
@router.get('/')
async def get_status():
    return {'status': True}


class ConnectionVerificationForm(BaseModel):
    url: str
    key: Optional[str] = None


@router.post('/verify')
async def verify_connection(form_data: ConnectionVerificationForm, user=Depends(get_admin_user)):
    url = form_data.url
    key = form_data.key

    async with aiohttp.ClientSession(
        trust_env=True,
        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
    ) as session:
        try:
            headers = {
                **({'Authorization': f'Bearer {key}'} if key else {}),
            }

            if ENABLE_FORWARD_USER_INFO_HEADERS and user:
                headers = include_user_info_headers(headers, user)

            async with session.get(
                f'{url}/api/version',
                headers=headers,
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                if r.status != 200:
                    detail = f'HTTP Error: {r.status}'
                    res = await r.json()

                    if 'error' in res:
                        detail = f'External Error: {res["error"]}'
                    raise Exception(detail)

                data = await r.json()
                return data
        except aiohttp.ClientError as e:
            log.exception(f'Client error: {str(e)}')
            raise HTTPException(status_code=500, detail=ERROR_MESSAGES.SERVER_CONNECTION_ERROR)
        except Exception as e:
            log.exception(f'Unexpected error: {e}')
            error_detail = f'Unexpected error: {str(e)}'
            raise HTTPException(status_code=500, detail=error_detail)


@router.get('/config')
async def get_config(request: Request, user=Depends(get_admin_user)):
    return {
        'ENABLE_OLLAMA_API': request.app.state.config.ENABLE_OLLAMA_API,
        'OLLAMA_BASE_URLS': request.app.state.config.OLLAMA_BASE_URLS,
        'OLLAMA_API_CONFIGS': request.app.state.config.OLLAMA_API_CONFIGS,
    }


class OllamaConfigForm(BaseModel):
    ENABLE_OLLAMA_API: Optional[bool] = None
    OLLAMA_BASE_URLS: list[str]
    OLLAMA_API_CONFIGS: dict


@router.post('/config/update')
async def update_config(request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)):
    request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API

    request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
    request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS

    # Drop API configs whose index no longer corresponds to an entry in OLLAMA_BASE_URLS
    keys = list(map(str, range(len(request.app.state.config.OLLAMA_BASE_URLS))))
    request.app.state.config.OLLAMA_API_CONFIGS = {
        key: value for key, value in request.app.state.config.OLLAMA_API_CONFIGS.items() if key in keys
    }

    return {
        'ENABLE_OLLAMA_API': request.app.state.config.ENABLE_OLLAMA_API,
        'OLLAMA_BASE_URLS': request.app.state.config.OLLAMA_BASE_URLS,
        'OLLAMA_API_CONFIGS': request.app.state.config.OLLAMA_API_CONFIGS,
    }


def merge_ollama_models_lists(model_lists):
    merged_models = {}

    for idx, model_list in enumerate(model_lists):
        if model_list is not None:
            for model in model_list:
                id = model.get('model')
                if id is not None:
                    if id not in merged_models:
                        model['urls'] = [idx]
                        merged_models[id] = model
                    else:
                        merged_models[id]['urls'].append(idx)

    return list(merged_models.values())
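
# Example (illustrative): given tag listings from two backends,
#
#     merge_ollama_models_lists([
#         [{'model': 'llama3:8b'}],
#         [{'model': 'llama3:8b'}, {'model': 'qwen2:7b'}],
#     ])
#
# the result keeps one entry per model id, annotated with every backend index
# that serves it: 'llama3:8b' gets urls=[0, 1] and 'qwen2:7b' gets urls=[1].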


@cached(
    ttl=MODELS_CACHE_TTL,
    key=lambda _, user: f'ollama_all_models_{user.id}' if user else 'ollama_all_models',
)
async def get_all_models(request: Request, user: UserModel = None):
    log.info('get_all_models()')
    if request.app.state.config.ENABLE_OLLAMA_API:
        request_tasks = []
        for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
            if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
                url not in request.app.state.config.OLLAMA_API_CONFIGS  # Legacy support
            ):
                request_tasks.append(send_get_request(f'{url}/api/tags', user=user))
            else:
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                    str(idx),
                    request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
                )

                enable = api_config.get('enable', True)
                key = api_config.get('key', None)

                if enable:
                    request_tasks.append(send_get_request(f'{url}/api/tags', key, user=user))
                else:
                    request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))

        responses = await asyncio.gather(*request_tasks)

        for idx, response in enumerate(responses):
            if response:
                url = request.app.state.config.OLLAMA_BASE_URLS[idx]
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                    str(idx),
                    request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
                )

                connection_type = api_config.get('connection_type', 'local')

                prefix_id = api_config.get('prefix_id', None)
                tags = api_config.get('tags', [])
                model_ids = api_config.get('model_ids', [])

                if len(model_ids) != 0 and 'models' in response:
                    response['models'] = list(
                        filter(
                            lambda model: model['model'] in model_ids,
                            response['models'],
                        )
                    )

                for model in response.get('models', []):
                    if prefix_id:
                        model['model'] = f'{prefix_id}.{model["model"]}'

                    if tags:
                        model['tags'] = tags

                    if connection_type:
                        model['connection_type'] = connection_type

        models = {
            'models': merge_ollama_models_lists(
                map(
                    lambda response: response.get('models', []) if response else None,
                    responses,
                )
            )
        }

        try:
            loaded_models = await get_ollama_loaded_models(request, user=user)
            expires_map = {m['model']: m['expires_at'] for m in loaded_models['models'] if 'expires_at' in m}

            for m in models['models']:
                if m['model'] in expires_map:
                    # Parse the ISO 8601 datetime (with offset) into an integer unix timestamp
                    dt = datetime.fromisoformat(expires_map[m['model']])
                    m['expires_at'] = int(dt.timestamp())
        except Exception as e:
            log.debug(f'Failed to get loaded models: {e}')

    else:
        models = {'models': []}

    request.app.state.OLLAMA_MODELS = {model['model']: model for model in models['models']}
    return models
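
# Note on the @cached key above: results are cached per user id
# ('ollama_all_models_<user.id>', or 'ollama_all_models' when no user is given),
# so a model list computed with one user's forwarded headers is never served to
# another user. MODELS_CACHE_TTL bounds how stale the merged list can get.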


async def get_filtered_models(models, user, db=None):
    # Filter models based on user access control
    model_ids = [model['model'] for model in models.get('models', [])]
    model_infos = {model_info.id: model_info for model_info in await Models.get_models_by_ids(model_ids, db=db)}
    user_group_ids = {group.id for group in await Groups.get_groups_by_member_id(user.id, db=db)}

    # Batch-fetch accessible resource IDs in a single query instead of N has_access calls
    accessible_model_ids = await AccessGrants.get_accessible_resource_ids(
        user_id=user.id,
        resource_type='model',
        resource_ids=list(model_infos.keys()),
        permission='read',
        user_group_ids=user_group_ids,
        db=db,
    )

    filtered_models = []
    for model in models.get('models', []):
        model_info = model_infos.get(model['model'])
        if model_info:
            if user.id == model_info.user_id or model_info.id in accessible_model_ids:
                filtered_models.append(model)
    return filtered_models


@router.get('/api/tags')
@router.get('/api/tags/{url_idx}')
async def get_ollama_tags(request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)):
    if not request.app.state.config.ENABLE_OLLAMA_API:
        raise HTTPException(status_code=503, detail=ERROR_MESSAGES.OLLAMA_API_DISABLED)

    models = []

    if url_idx is None:
        models = await get_all_models(request, user=user)
    else:
        url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
        key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
        models = await send_request(f'{url}/api/tags', 'GET', key=key, user=user)

    if user.role == 'user' and not BYPASS_MODEL_ACCESS_CONTROL:
        models['models'] = await get_filtered_models(models, user)

    return models


@router.get('/api/ps')
async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user)):
    """
    List models that are currently loaded into Ollama memory, and which node they are loaded on.
    """
    if request.app.state.config.ENABLE_OLLAMA_API:
        request_tasks = []
        for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
            if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
                url not in request.app.state.config.OLLAMA_API_CONFIGS  # Legacy support
            ):
                request_tasks.append(send_get_request(f'{url}/api/ps', user=user))
            else:
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                    str(idx),
                    request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
                )

                enable = api_config.get('enable', True)
                key = api_config.get('key', None)

                if enable:
                    request_tasks.append(send_get_request(f'{url}/api/ps', key, user=user))
                else:
                    request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))

        responses = await asyncio.gather(*request_tasks)

        for idx, response in enumerate(responses):
            if response:
                url = request.app.state.config.OLLAMA_BASE_URLS[idx]
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                    str(idx),
                    request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
                )

                prefix_id = api_config.get('prefix_id', None)

                for model in response.get('models', []):
                    if prefix_id:
                        model['model'] = f'{prefix_id}.{model["model"]}'

        models = {
            'models': merge_ollama_models_lists(
                map(
                    lambda response: response.get('models', []) if response else None,
                    responses,
                )
            )
        }
    else:
        models = {'models': []}

    return models


@router.get('/api/version')
@router.get('/api/version/{url_idx}')
async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
    if request.app.state.config.ENABLE_OLLAMA_API:
        if url_idx is None:
            # Returns the lowest version across all enabled backends
            request_tasks = []

            for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                    str(idx),
                    request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
                )

                enable = api_config.get('enable', True)
                key = api_config.get('key', None)

                if enable:
                    request_tasks.append(
                        send_get_request(
                            f'{url}/api/version',
                            key,
                        )
                    )

            responses = await asyncio.gather(*request_tasks)
            responses = list(filter(lambda x: x is not None, responses))

            if len(responses) > 0:
                lowest_version = min(
                    responses,
                    key=lambda x: tuple(map(int, re.sub(r'^v|-.*', '', x['version']).split('.'))),
                )

                return {'version': lowest_version['version']}
            else:
                raise HTTPException(
                    status_code=500,
                    detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
                )
        else:
            url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
            return await send_request(f'{url}/api/version', 'GET')
    else:
        return {'version': False}
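
# Example (illustrative) of the version normalization used in the min() key above:
#
#     re.sub(r'^v|-.*', '', 'v0.5.7-rc1')   # -> '0.5.7'
#     tuple(map(int, '0.5.7'.split('.')))   # -> (0, 5, 7)
#
# Comparing integer tuples orders (0, 5, 7) before (0, 5, 10), which a plain
# string comparison would get wrong.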


class ModelNameForm(BaseModel):
    model: Optional[str] = None
    model_config = ConfigDict(
        extra='allow',
    )


@router.post('/api/unload')
async def unload_model(
    request: Request,
    form_data: ModelNameForm,
    user=Depends(get_admin_user),
):
    form_data = form_data.model_dump(exclude_none=True)
    model_name = form_data.get('model', form_data.get('name'))

    if not model_name:
        raise HTTPException(status_code=400, detail='Missing name of the model to unload.')

    # Refresh/load models if needed, get mapping from name to URLs
    await get_all_models(request, user=user)
    models = request.app.state.OLLAMA_MODELS

    if model_name not in models:
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name))
    url_indices = models[model_name]['urls']

    # Send unload to ALL url_indices
    results = []
    errors = []
    for idx in url_indices:
        url = request.app.state.config.OLLAMA_BASE_URLS[idx]
        api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
            str(idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
        )
        key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

        # Strip the prefix per node rather than mutating model_name, so nodes
        # with different prefix_id configs each see the correct upstream name
        node_model_name = model_name
        prefix_id = api_config.get('prefix_id', None)
        if prefix_id and node_model_name.startswith(f'{prefix_id}.'):
            node_model_name = node_model_name[len(f'{prefix_id}.') :]

        payload = {'model': node_model_name, 'keep_alive': 0, 'prompt': ''}

        try:
            res = await send_request(
                f'{url}/api/generate',
                payload=json.dumps(payload),
                key=key,
                user=user,
            )
            results.append({'url_idx': idx, 'success': True, 'response': res})
        except Exception as e:
            log.exception(f'Failed to unload model on node {idx}: {e}')
            errors.append({'url_idx': idx, 'success': False, 'error': str(e)})

    if len(errors) > 0:
        raise HTTPException(
            status_code=500,
            detail=f'Failed to unload model on {len(errors)} nodes: {errors}',
        )

    return {'status': True}


@router.post('/api/pull')
@router.post('/api/pull/{url_idx}')
async def pull_model(
    request: Request,
    form_data: ModelNameForm,
    url_idx: int = 0,
    user=Depends(get_admin_user),
):
    if not request.app.state.config.ENABLE_OLLAMA_API:
        raise HTTPException(status_code=503, detail=ERROR_MESSAGES.OLLAMA_API_DISABLED)

    form_data = form_data.model_dump(exclude_none=True)
    form_data['model'] = form_data.get('model', form_data.get('name'))

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f'url: {url}')

    # Admin should be able to pull models from any source
    payload = {**form_data, 'insecure': True}

    return await send_request(
        f'{url}/api/pull',
        payload=json.dumps(payload),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
        stream=True,
    )


class PushModelForm(BaseModel):
    model: str
    insecure: Optional[bool] = None
    stream: Optional[bool] = None


@router.delete('/api/push')
@router.delete('/api/push/{url_idx}')
async def push_model(
    request: Request,
    form_data: PushModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if not request.app.state.config.ENABLE_OLLAMA_API:
        raise HTTPException(status_code=503, detail=ERROR_MESSAGES.OLLAMA_API_DISABLED)

    if url_idx is None:
        await get_all_models(request, user=user)
        models = request.app.state.OLLAMA_MODELS

        if form_data.model in models:
            url_idx = models[form_data.model]['urls'][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.debug(f'url: {url}')

    return await send_request(
        f'{url}/api/push',
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
        stream=True,
    )


class CreateModelForm(BaseModel):
    model: Optional[str] = None
    stream: Optional[bool] = None
    path: Optional[str] = None

    model_config = ConfigDict(extra='allow')


@router.post('/api/create')
@router.post('/api/create/{url_idx}')
async def create_model(
    request: Request,
    form_data: CreateModelForm,
    url_idx: int = 0,
    user=Depends(get_admin_user),
):
    if not request.app.state.config.ENABLE_OLLAMA_API:
        raise HTTPException(status_code=503, detail=ERROR_MESSAGES.OLLAMA_API_DISABLED)

    log.debug(f'form_data: {form_data}')
    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]

    return await send_request(
        f'{url}/api/create',
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
        stream=True,
    )


class CopyModelForm(BaseModel):
    source: str
    destination: str


@router.post('/api/copy')
@router.post('/api/copy/{url_idx}')
async def copy_model(
    request: Request,
    form_data: CopyModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if not request.app.state.config.ENABLE_OLLAMA_API:
        raise HTTPException(status_code=503, detail=ERROR_MESSAGES.OLLAMA_API_DISABLED)

    if url_idx is None:
        await get_all_models(request, user=user)
        models = request.app.state.OLLAMA_MODELS

        if form_data.source in models:
            url_idx = models[form_data.source]['urls'][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

    await send_request(
        f'{url}/api/copy',
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=key,
        user=user,
    )
    return True


@router.delete('/api/delete')
@router.delete('/api/delete/{url_idx}')
async def delete_model(
    request: Request,
    form_data: ModelNameForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if not request.app.state.config.ENABLE_OLLAMA_API:
        raise HTTPException(status_code=503, detail=ERROR_MESSAGES.OLLAMA_API_DISABLED)

    form_data = form_data.model_dump(exclude_none=True)
    form_data['model'] = form_data.get('model', form_data.get('name'))

    model = form_data.get('model')

    if url_idx is None:
        await get_all_models(request, user=user)
        models = request.app.state.OLLAMA_MODELS

        if model in models:
            url_idx = models[model]['urls'][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

    await send_request(
        f'{url}/api/delete',
        'DELETE',
        payload=json.dumps(form_data),
        key=key,
        user=user,
    )
    return True


@router.post('/api/show')
async def show_model_info(request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)):
    if not request.app.state.config.ENABLE_OLLAMA_API:
        raise HTTPException(status_code=503, detail=ERROR_MESSAGES.OLLAMA_API_DISABLED)

    form_data = form_data.model_dump(exclude_none=True)
    form_data['model'] = form_data.get('model', form_data.get('name'))

    model = form_data.get('model')

    # Enforce per-model access control
    await check_model_access(user, await Models.get_model_by_id(model), BYPASS_MODEL_ACCESS_CONTROL)

    await get_all_models(request, user=user)
    models = request.app.state.OLLAMA_MODELS

    if model not in models:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
        )

    url_idx = random.choice(models[model]['urls'])

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

    return await send_request(
        f'{url}/api/show',
        payload=json.dumps(form_data),
        key=key,
        user=user,
    )


class GenerateEmbedForm(BaseModel):
    model: str
    input: list[str] | str
    truncate: Optional[bool] = None
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None

    model_config = ConfigDict(
        extra='allow',
    )


@router.post('/api/embed')
@router.post('/api/embed/{url_idx}')
async def embed(
    request: Request,
    form_data: GenerateEmbedForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    if not request.app.state.config.ENABLE_OLLAMA_API:
        raise HTTPException(status_code=503, detail=ERROR_MESSAGES.OLLAMA_API_DISABLED)

    log.info(f'generate_ollama_batch_embeddings {form_data}')

    # Enforce per-model access control
    await check_model_access(user, await Models.get_model_by_id(form_data.model), BYPASS_MODEL_ACCESS_CONTROL)

    if url_idx is None:
        model = form_data.model

        # Check if model is already in app state cache to avoid expensive get_all_models() call
        models = request.app.state.OLLAMA_MODELS
        if not models or model not in models:
            await get_all_models(request, user=user)
            models = request.app.state.OLLAMA_MODELS

        if model in models:
            url_idx = random.choice(models[model]['urls'])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
    )
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

    prefix_id = api_config.get('prefix_id', None)
    if prefix_id:
        form_data.model = form_data.model.replace(f'{prefix_id}.', '')

    return await send_request(
        f'{url}/api/embed',
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=key,
        user=user,
    )


class GenerateEmbeddingsForm(BaseModel):
    model: str
    prompt: str
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None


@router.post('/api/embeddings')
@router.post('/api/embeddings/{url_idx}')
async def embeddings(
    request: Request,
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    if not request.app.state.config.ENABLE_OLLAMA_API:
        raise HTTPException(status_code=503, detail=ERROR_MESSAGES.OLLAMA_API_DISABLED)

    log.info(f'generate_ollama_embeddings {form_data}')

    # Enforce per-model access control
    await check_model_access(user, await Models.get_model_by_id(form_data.model), BYPASS_MODEL_ACCESS_CONTROL)

    if url_idx is None:
        model = form_data.model

        # Check if model is already in app state cache to avoid expensive get_all_models() call
        models = request.app.state.OLLAMA_MODELS
        if not models or model not in models:
            await get_all_models(request, user=user)
            models = request.app.state.OLLAMA_MODELS

        if model in models:
            url_idx = random.choice(models[model]['urls'])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
    )
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

    prefix_id = api_config.get('prefix_id', None)
    if prefix_id:
        form_data.model = form_data.model.replace(f'{prefix_id}.', '')

    return await send_request(
        f'{url}/api/embeddings',
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=key,
        user=user,
    )


class GenerateCompletionForm(BaseModel):
    model: str
    prompt: Optional[str] = None
    suffix: Optional[str] = None
    images: Optional[list[str]] = None
    format: Optional[Union[dict, str]] = None
    options: Optional[dict] = None
    system: Optional[str] = None
    template: Optional[str] = None
    context: Optional[list[int]] = None
    stream: Optional[bool] = True
    raw: Optional[bool] = None
    keep_alive: Optional[Union[int, str]] = None


@router.post('/api/generate')
@router.post('/api/generate/{url_idx}')
async def generate_completion(
    request: Request,
    form_data: GenerateCompletionForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    if not request.app.state.config.ENABLE_OLLAMA_API:
        raise HTTPException(status_code=503, detail=ERROR_MESSAGES.OLLAMA_API_DISABLED)

    # Enforce per-model access control
    await check_model_access(user, await Models.get_model_by_id(form_data.model), BYPASS_MODEL_ACCESS_CONTROL)

    if url_idx is None:
        await get_all_models(request, user=user)
        models = request.app.state.OLLAMA_MODELS

        model = form_data.model

        if model in models:
            url_idx = random.choice(models[model]['urls'])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
    )

    prefix_id = api_config.get('prefix_id', None)
    if prefix_id:
        form_data.model = form_data.model.replace(f'{prefix_id}.', '')

    return await send_request(
        f'{url}/api/generate',
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
        stream=True,
    )


class ChatMessage(BaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: Optional[list[dict]] = None
    images: Optional[list[str]] = None

    model_config = ConfigDict(extra='allow')

    @validator('content', pre=True)
    @classmethod
    def check_at_least_one_field(cls, field_value, values, **kwargs):
        # Raise an error if both 'content' and 'tool_calls' are None
        if field_value is None and ('tool_calls' not in values or values['tool_calls'] is None):
            raise ValueError("At least one of 'content' or 'tool_calls' must be provided")

        return field_value


class GenerateChatCompletionForm(BaseModel):
    model: str
    messages: list[ChatMessage]
    format: Optional[Union[dict, str]] = None
    options: Optional[dict] = None
    template: Optional[str] = None
    stream: Optional[bool] = True
    keep_alive: Optional[Union[int, str]] = None
    tools: Optional[list[dict]] = None
    model_config = ConfigDict(
        extra='allow',
    )


async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
    if url_idx is None:
        models = request.app.state.OLLAMA_MODELS
        if model not in models:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
            )
        url_idx = random.choice(models[model].get('urls', []))
    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    return url, url_idx


@router.post('/api/chat')
@router.post('/api/chat/{url_idx}')
async def generate_chat_completion(
    request: Request,
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
    bypass_system_prompt: bool = False,
):
    if not request.app.state.config.ENABLE_OLLAMA_API:
        raise HTTPException(status_code=503, detail=ERROR_MESSAGES.OLLAMA_API_DISABLED)

    # NOTE: We intentionally do NOT use Depends(get_async_session) here.
    # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions.
    # This prevents holding a connection during the entire LLM call (30-60+ seconds),
    # which would exhaust the connection pool under concurrent load.

    # bypass_filter is read from request.state to prevent external clients from
    # setting it via query parameter (CVE fix). Only internal server-side callers
    # (e.g. utils/chat.py) should set request.state.bypass_filter = True.
    bypass_filter = getattr(request.state, 'bypass_filter', False)
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    metadata = form_data.pop('metadata', None)
    try:
        form_data = GenerateChatCompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=400,
            detail=str(e),
        )

    if isinstance(form_data, BaseModel):
        payload = {**form_data.model_dump(exclude_none=True)}

        if 'metadata' in payload:
            del payload['metadata']

    model_id = payload['model']
    model_info = await Models.get_model_by_id(model_id)

    if model_info:
        if model_info.base_model_id:
            base_model_id = (
                request.base_model_id if hasattr(request, 'base_model_id') else model_info.base_model_id
            )  # Use request's base_model_id if available
            payload['model'] = base_model_id

        params = model_info.params.model_dump()

        if params:
            system = params.pop('system', None)

            payload = apply_model_params_to_body_ollama(params, payload)
            if not bypass_system_prompt:
                payload = apply_system_prompt_to_body(system, payload, metadata, user)

        await check_model_access(user, model_info, bypass_filter)
    else:
        await check_model_access(user, None, bypass_filter)

    url, url_idx = await get_ollama_url(request, payload['model'], url_idx)
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
    )

    prefix_id = api_config.get('prefix_id', None)
    if prefix_id:
        payload['model'] = payload['model'].replace(f'{prefix_id}.', '')

    return await send_request(
        f'{url}/api/chat',
        payload=json.dumps(payload),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
        stream=form_data.stream,
        content_type='application/x-ndjson',
        metadata=metadata,
    )
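
# Illustrative sketch of the internal-caller contract described in the note
# above (assumed usage; utils/chat.py is cited there as the real call site):
# server-side code that legitimately needs to skip access control sets the flag
# on request.state, which external HTTP clients have no way to reach:
#
#     request.state.bypass_filter = True
#     response = await generate_chat_completion(request, form_data, user=user)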


# TODO: we should update this part once Ollama supports other types
class OpenAIChatMessageContent(BaseModel):
    type: str
    model_config = ConfigDict(extra='allow')


class OpenAIChatMessage(BaseModel):
    role: str
    content: Union[Optional[str], list[OpenAIChatMessageContent]]

    model_config = ConfigDict(extra='allow')


class OpenAIChatCompletionForm(BaseModel):
    model: str
    messages: list[OpenAIChatMessage]

    model_config = ConfigDict(extra='allow')


class OpenAICompletionForm(BaseModel):
    model: str
    prompt: str

    model_config = ConfigDict(extra='allow')


@router.post('/v1/completions')
@router.post('/v1/completions/{url_idx}')
async def generate_openai_completion(
    request: Request,
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    # NOTE: We intentionally do NOT use Depends(get_async_session) here.
    # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions.
    # This prevents holding a connection during the entire LLM call (30-60+ seconds),
    # which would exhaust the connection pool under concurrent load.
    metadata = form_data.pop('metadata', None)

    try:
        form_data = OpenAICompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=400,
            detail=str(e),
        )

    payload = {**form_data.model_dump(exclude_none=True, exclude={'metadata'})}
    if 'metadata' in payload:
        del payload['metadata']

    model_id = form_data.model
    model_info = await Models.get_model_by_id(model_id)
    if model_info:
        if model_info.base_model_id:
            payload['model'] = model_info.base_model_id
        params = model_info.params.model_dump()

        if params:
            payload = apply_model_params_to_body_openai(params, payload)

        await check_model_access(user, model_info)
    else:
        await check_model_access(user, None)

    url, url_idx = await get_ollama_url(request, payload['model'], url_idx)
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
    )

    prefix_id = api_config.get('prefix_id', None)

    if prefix_id:
        payload['model'] = payload['model'].replace(f'{prefix_id}.', '')

    return await send_request(
        f'{url}/v1/completions',
        payload=json.dumps(payload),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
        stream=payload.get('stream', False),
        metadata=metadata,
    )


@router.post('/v1/chat/completions')
@router.post('/v1/chat/completions/{url_idx}')
async def generate_openai_chat_completion(
    request: Request,
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    # NOTE: We intentionally do NOT use Depends(get_async_session) here.
    # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions.
    # This prevents holding a connection during the entire LLM call (30-60+ seconds),
    # which would exhaust the connection pool under concurrent load.
    metadata = form_data.pop('metadata', None)

    try:
        completion_form = OpenAIChatCompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=400,
            detail=str(e),
        )

    payload = {**completion_form.model_dump(exclude_none=True, exclude={'metadata'})}
    if 'metadata' in payload:
        del payload['metadata']

    model_id = completion_form.model
    model_info = await Models.get_model_by_id(model_id)
    if model_info:
        if model_info.base_model_id:
            payload['model'] = model_info.base_model_id

        params = model_info.params.model_dump()

        if params:
            system = params.pop('system', None)

            payload = apply_model_params_to_body_openai(params, payload)
            payload = apply_system_prompt_to_body(system, payload, metadata, user)

        await check_model_access(user, model_info)
    else:
        await check_model_access(user, None)

    url, url_idx = await get_ollama_url(request, payload['model'], url_idx)
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
    )

    prefix_id = api_config.get('prefix_id', None)
    if prefix_id:
        payload['model'] = payload['model'].replace(f'{prefix_id}.', '')

    return await send_request(
        f'{url}/v1/chat/completions',
        payload=json.dumps(payload),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
        stream=payload.get('stream', False),
        metadata=metadata,
    )


@router.post('/v1/messages')
@router.post('/v1/messages/{url_idx}')
async def generate_anthropic_messages(
    request: Request,
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """
    Proxy for Ollama's Anthropic-compatible /v1/messages endpoint.

    Forwards the request as-is to the Ollama backend, applying the same
    model resolution, access control, and prefix_id handling used by
    the OpenAI-compatible /v1/chat/completions proxy.

    See https://docs.ollama.com/api/anthropic-compatibility
    """
    if not request.app.state.config.ENABLE_OLLAMA_API:
        raise HTTPException(status_code=503, detail=ERROR_MESSAGES.OLLAMA_API_DISABLED)

    payload = {**form_data}
    model_id = payload.get('model', '')

    model_info = await Models.get_model_by_id(model_id)
    if model_info:
        if model_info.base_model_id:
            payload['model'] = model_info.base_model_id

        await check_model_access(user, model_info)
    else:
        await check_model_access(user, None)

    url, url_idx = await get_ollama_url(request, payload['model'], url_idx)
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
    )

    prefix_id = api_config.get('prefix_id', None)
    if prefix_id:
        payload['model'] = payload['model'].replace(f'{prefix_id}.', '')

    return await send_request(
        f'{url}/v1/messages',
        payload=json.dumps(payload),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
        stream=payload.get('stream', False),
        content_type='text/event-stream' if payload.get('stream', False) else None,
    )


class ResponsesForm(BaseModel):
    model: str

    model_config = ConfigDict(extra='allow')


@router.post('/v1/responses')
@router.post('/v1/responses/{url_idx}')
async def generate_responses(
    request: Request,
    form_data: ResponsesForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """
    Proxy for Ollama's OpenAI-compatible /v1/responses endpoint.

    Forwards the request as-is to the Ollama backend, applying the same
    model resolution, access control, and prefix_id handling used by
    the OpenAI-compatible /v1/chat/completions proxy.

    See https://ollama.com/blog/responses-api
    """
    if not request.app.state.config.ENABLE_OLLAMA_API:
        raise HTTPException(status_code=503, detail=ERROR_MESSAGES.OLLAMA_API_DISABLED)

    payload = form_data.model_dump()
    model_id = form_data.model

    model_info = await Models.get_model_by_id(model_id)
    if model_info:
        if model_info.base_model_id:
            payload['model'] = model_info.base_model_id

        await check_model_access(user, model_info)
    else:
        await check_model_access(user, None)

    url, url_idx = await get_ollama_url(request, payload['model'], url_idx)
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
    )

    prefix_id = api_config.get('prefix_id', None)
    if prefix_id:
        payload['model'] = payload['model'].replace(f'{prefix_id}.', '')

    return await send_request(
        f'{url}/v1/responses',
        payload=json.dumps(payload),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        user=user,
        stream=payload.get('stream', False),
        content_type='text/event-stream' if payload.get('stream', False) else None,
    )


@router.get('/v1/models')
@router.get('/v1/models/{url_idx}')
async def get_openai_models(
    request: Request,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
    db: AsyncSession = Depends(get_async_session),
):
    models = []
    if url_idx is None:
        model_list = await get_all_models(request, user=user)
        models = [
            {
                'id': model['model'],
                'object': 'model',
                'created': int(time.time()),
                'owned_by': 'openai',
            }
            for model in model_list['models']
        ]
    else:
        url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
        model_list = await send_request(f'{url}/api/tags', 'GET')

        models = [
            {
                'id': model['model'],
                'object': 'model',
                'created': int(time.time()),
                'owned_by': 'openai',
            }
            for model in model_list.get('models', [])
        ]

    if user.role == 'user' and not BYPASS_MODEL_ACCESS_CONTROL:
        # Filter models based on user access control
        model_ids = [model['id'] for model in models]
        model_infos = {model_info.id: model_info for model_info in await Models.get_models_by_ids(model_ids, db=db)}
        user_group_ids = {group.id for group in await Groups.get_groups_by_member_id(user.id, db=db)}

        # Batch-fetch accessible resource IDs in a single query instead of N has_access calls
        accessible_model_ids = await AccessGrants.get_accessible_resource_ids(
            user_id=user.id,
            resource_type='model',
            resource_ids=list(model_infos.keys()),
            permission='read',
            user_group_ids=user_group_ids,
            db=db,
        )

        filtered_models = []
        for model in models:
            model_info = model_infos.get(model['id'])
            if model_info:
                if user.id == model_info.user_id or model_info.id in accessible_model_ids:
                    filtered_models.append(model)
        models = filtered_models

    return {
        'data': models,
        'object': 'list',
    }


class UrlForm(BaseModel):
    url: str


class UploadBlobForm(BaseModel):
    filename: str


def parse_huggingface_url(hf_url):
    try:
        # Parse the URL
        parsed_url = urlparse(hf_url)

        # Get the path and split it into components
        path_components = parsed_url.path.split('/')

        # Extract the model filename (the last path component)
        model_file = path_components[-1]

        return model_file
    except ValueError:
        return None
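
# Example (illustrative), using the sample URL referenced below in this file:
#
#     parse_huggingface_url(
#         'https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf'
#     )  # -> 'stablelm-zephyr-3b.Q2_K.gguf'
#
# Only the final path component is extracted; the host and repository layout are
# not validated here (download_model checks the allowed hosts separately).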


async def download_file_stream(ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024):
    done = False

    if os.path.exists(file_path):
        current_size = os.path.getsize(file_path)
    else:
        current_size = 0

    # Resume a partial download if the file already exists locally
    headers = {'Range': f'bytes={current_size}-'} if current_size > 0 else {}

    timeout = aiohttp.ClientTimeout(total=600)

    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
        async with session.get(file_url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL) as response:
            total_size = int(response.headers.get('content-length', 0)) + current_size

            with open(file_path, 'ab+') as file:
                async for data in response.content.iter_chunked(chunk_size):
                    current_size += len(data)
                    file.write(data)

                    done = current_size == total_size
                    progress = round((current_size / total_size) * 100, 2)

                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'

                if done:
                    file.close()
                    hashed = calculate_sha256(file_path, chunk_size)

                    with open(file_path, 'rb') as f:
                        blob_data = f.read()

                    url = f'{ollama_url}/api/blobs/sha256:{hashed}'
                    blob_timeout = aiohttp.ClientTimeout(total=30)
                    async with aiohttp.ClientSession(timeout=blob_timeout, trust_env=True) as blob_session:
                        async with blob_session.post(
                            url, data=blob_data, ssl=AIOHTTP_CLIENT_SESSION_SSL
                        ) as blob_response:
                            if blob_response.ok:
                                res = {
                                    'done': done,
                                    'blob': f'sha256:{hashed}',
                                    'name': file_name,
                                }
                                os.remove(file_path)

                                yield f'data: {json.dumps(res)}\n\n'
                            else:
                                raise RuntimeError('Ollama: Could not create blob. Please try again.')


# Example source URL:
# https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf
@router.post('/models/download')
@router.post('/models/download/{url_idx}')
async def download_model(
    request: Request,
    form_data: UrlForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    allowed_hosts = ['https://huggingface.co/', 'https://github.com/']

    if not any(form_data.url.startswith(host) for host in allowed_hosts):
        raise HTTPException(
            status_code=400,
            detail='Invalid file_url. Only URLs from allowed hosts are permitted.',
        )

    if url_idx is None:
        url_idx = 0
    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]

    file_name = parse_huggingface_url(form_data.url)

    if file_name:
        file_path = os.path.join(UPLOAD_DIR, file_name)

        return StreamingResponse(
            download_file_stream(url, form_data.url, file_path, file_name),
        )
    else:
        return None


# TODO: Progress bar does not reflect size & duration of upload.
@router.post('/models/upload')
@router.post('/models/upload/{url_idx}')
async def upload_model(
    request: Request,
    file: UploadFile = File(...),
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        url_idx = 0
    ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]

    filename = os.path.basename(file.filename)
    file_path = os.path.join(UPLOAD_DIR, filename)
    os.makedirs(UPLOAD_DIR, exist_ok=True)

    # --- P1: save file locally ---
    chunk_size = 1024 * 1024 * 2  # 2 MB chunks
    with open(file_path, 'wb') as out_f:
        while True:
            chunk = file.file.read(chunk_size)
            # log.info(f"Chunk: {str(chunk)}")  # DEBUG
            if not chunk:
                break
            out_f.write(chunk)

    async def file_process_stream():
        nonlocal ollama_url
        total_size = os.path.getsize(file_path)
        log.info(f'Total Model Size: {str(total_size)}')  # DEBUG

        # --- P2: SSE progress + calculate sha256 hash ---
        file_hash = calculate_sha256(file_path, chunk_size)
        log.info(f'Model Hash: {str(file_hash)}')  # DEBUG
        try:
            with open(file_path, 'rb') as f:
                bytes_read = 0
                while chunk := f.read(chunk_size):
                    bytes_read += len(chunk)
                    progress = round(bytes_read / total_size * 100, 2)
                    data_msg = {
                        'progress': progress,
                        'total': total_size,
                        'completed': bytes_read,
                    }
                    yield f'data: {json.dumps(data_msg)}\n\n'

            # --- P3: Upload to ollama /api/blobs ---
            with open(file_path, 'rb') as f:
                blob_data = f.read()

            url = f'{ollama_url}/api/blobs/sha256:{file_hash}'
            upload_timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
            async with aiohttp.ClientSession(timeout=upload_timeout, trust_env=True) as upload_session:
                async with upload_session.post(url, data=blob_data, ssl=AIOHTTP_CLIENT_SESSION_SSL) as response:
                    if not response.ok:
                        raise Exception('Ollama: Could not create blob. Please try again.')

            log.info('Uploaded to /api/blobs')  # DEBUG
            # Remove local file
            os.remove(file_path)

            # Create model in ollama
            model_name, ext = os.path.splitext(filename)
            log.info(f'Created Model: {model_name}')  # DEBUG

            create_payload = {
                'model': model_name,
                # Reference the file by its original name => the uploaded blob's digest
                'files': {filename: f'sha256:{file_hash}'},
            }
            log.info(f'Model Payload: {create_payload}')  # DEBUG

            # Call ollama /api/create
            # https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
            async with aiohttp.ClientSession(timeout=upload_timeout, trust_env=True) as create_session:
                async with create_session.post(
                    f'{ollama_url}/api/create',
                    headers={'Content-Type': 'application/json'},
                    data=json.dumps(create_payload),
                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
                ) as create_resp:
                    if create_resp.ok:
                        log.info('API SUCCESS!')  # DEBUG
                        done_msg = {
                            'done': True,
                            'blob': f'sha256:{file_hash}',
                            'name': filename,
                            'model_created': model_name,
                        }
                        yield f'data: {json.dumps(done_msg)}\n\n'
                    else:
                        resp_text = await create_resp.text()
                        raise Exception(f'Failed to create model in Ollama. {resp_text}')

        except Exception as e:
            res = {'error': str(e)}
            yield f'data: {json.dumps(res)}\n\n'

    return StreamingResponse(file_process_stream(), media_type='text/event-stream')