mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-01 09:49:03 -05:00
679 lines
21 KiB
Python
679 lines
21 KiB
Python
from typing import Optional
|
|
import io
|
|
import base64
|
|
import json
|
|
import asyncio
|
|
import logging
|
|
|
|
from open_webui.models.groups import Groups
|
|
from open_webui.models.models import (
|
|
ModelForm,
|
|
ModelMeta,
|
|
ModelModel,
|
|
ModelParams,
|
|
ModelResponse,
|
|
ModelListResponse,
|
|
ModelAccessListResponse,
|
|
ModelAccessResponse,
|
|
Models,
|
|
)
|
|
from open_webui.models.access_grants import AccessGrants
|
|
|
|
from pydantic import BaseModel
|
|
from open_webui.constants import ERROR_MESSAGES
|
|
from fastapi import (
|
|
APIRouter,
|
|
Depends,
|
|
HTTPException,
|
|
Request,
|
|
status,
|
|
Response,
|
|
)
|
|
from fastapi.responses import FileResponse, StreamingResponse
|
|
|
|
|
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
|
from open_webui.utils.access_control import has_permission, filter_allowed_access_grants
|
|
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
|
|
from open_webui.internal.db import get_async_session
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)

# Router that every model endpoint below is registered on.
router = APIRouter()
|
|
|
|
|
|
def is_valid_model_id(model_id: str) -> bool:
    """Return True when *model_id* is a non-empty string of at most 256 characters.

    Always returns a real ``bool``. A bare ``model_id and ...`` expression
    would leak the falsy input (``''`` or ``None``) instead.
    """
    return bool(model_id) and len(model_id) <= 256
|
|
|
|
|
|
###########################
|
|
# GetModels
|
|
# Let each model here be judged by what it does and not
|
|
# by what it claims. The house deserves honest servants.
|
|
###########################
|
|
|
|
|
|
# Page size used by the paginated `/list` endpoint.
PAGE_ITEM_COUNT: int = 30
|
|
|
|
|
|
@router.get('/list', response_model=ModelAccessListResponse)  # do NOT use "/" as path, conflicts with main.py
async def get_models(
    query: Optional[str] = None,
    view_option: Optional[str] = None,
    tag: Optional[str] = None,
    order_by: Optional[str] = None,
    direction: Optional[str] = None,
    page: Optional[int] = 1,
    user=Depends(get_verified_user),
    db: AsyncSession = Depends(get_async_session),
):
    """Return one page of models visible to the caller, each flagged with ``write_access``.

    Non-empty search parameters (``query``, ``view_option``, ``tag``,
    ``order_by``, ``direction``) are forwarded to ``Models.search_models``.
    ``page`` is 1-based; ``None`` or values below 1 are treated as page 1.

    Admins with ``BYPASS_ADMIN_ACCESS_CONTROL`` enabled are not scoped by
    user/group and get ``write_access=True`` on every item. Everyone else is
    scoped by their user id and group memberships.
    """
    limit = PAGE_ITEM_COUNT
    # `page` is declared Optional, so guard against None before max().
    page = max(1, page or 1)
    skip = (page - 1) * limit

    search_filter = {
        key: value
        for key, value in (
            ('query', query),
            ('view_option', view_option),
            ('tag', tag),
            ('order_by', order_by),
            ('direction', direction),
        )
        if value
    }

    admin_bypass = bool(user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL)

    user_group_ids: set = set()
    if not admin_bypass:
        # Group memberships feed both the visibility filter and the
        # write-access lookup below; fetch them once.
        groups = await Groups.get_groups_by_member_id(user.id, db=db)
        user_group_ids = {group.id for group in groups}
        if groups:
            search_filter['group_ids'] = [group.id for group in groups]
        search_filter['user_id'] = user.id

    result = await Models.search_models(user.id, filter=search_filter, skip=skip, limit=limit, db=db)

    model_ids = [model.id for model in result.items]
    if admin_bypass or not model_ids:
        # Bypassing admins can write every model, and an empty page needs no
        # lookup, so skip the grant query in both cases.
        writable_model_ids = set()
    else:
        # One batched query instead of a has_access call per model.
        writable_model_ids = await AccessGrants.get_accessible_resource_ids(
            user_id=user.id,
            resource_type='model',
            resource_ids=model_ids,
            permission='write',
            user_group_ids=user_group_ids,
            db=db,
        )

    return ModelAccessListResponse(
        items=[
            ModelAccessResponse(
                **model.model_dump(),
                write_access=(
                    admin_bypass
                    or user.id == model.user_id
                    or model.id in writable_model_ids
                ),
            )
            for model in result.items
        ],
        total=result.total,
    )
|
|
|
|
|
|
###########################
|
|
# GetBaseModels
|
|
###########################
|
|
|
|
|
|
@router.get('/base', response_model=list[ModelResponse])
async def get_base_models(user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session)):
    """Admin-only: return whatever ``Models.get_base_models`` yields."""
    base_models = await Models.get_base_models(db=db)
    return base_models
|
|
|
|
|
|
###########################
|
|
# GetModelTags
|
|
###########################
|
|
|
|
|
|
@router.get('/tags', response_model=list[str])
async def get_model_tags(user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
    """Return the sorted, de-duplicated tag names of the models the caller can see.

    Admins with ``BYPASS_ADMIN_ACCESS_CONTROL`` enabled see tags from all
    models. Other users see tags from ``Models.get_models_by_user_id``.
    Each tag is either a dict carrying a ``name`` key or a bare value, which
    is stringified.
    """
    if user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL:
        models = await Models.get_models(db=db)
    else:
        models = await Models.get_models_by_user_id(user.id, db=db)

    tag_names = set()
    for model in models:
        if not model.meta:
            continue

        # `tags` may be absent, explicitly null, or not a list in stored
        # metadata; iterating a string or dict here would yield garbage tags.
        raw_tags = model.meta.model_dump().get('tags') or []
        if not isinstance(raw_tags, (list, tuple)):
            continue

        for tag in raw_tags:
            name = tag.get('name') if isinstance(tag, dict) else str(tag)
            if name:
                # Coerce to str so the set stays homogeneous: sorted() raises on
                # mixed types, and the response model is list[str].
                tag_names.add(str(name))

    return sorted(tag_names)
|
|
|
|
|
|
############################
|
|
# CreateNewModel
|
|
############################
|
|
|
|
|
|
@router.post('/create', response_model=Optional[ModelModel])
async def create_new_model(
    request: Request,
    form_data: ModelForm,
    user=Depends(get_verified_user),
    db: AsyncSession = Depends(get_async_session),
):
    """Create a new model owned by the caller.

    Non-admins need the ``workspace.models`` permission. The requested access
    grants are filtered through ``filter_allowed_access_grants`` before insert.

    Raises:
        HTTPException: 401 when unauthorized, when the id is already taken,
            or when the insert fails; 400 when the id is empty or longer
            than 256 characters.
    """
    if user.role != 'admin' and not await has_permission(
        user.id, 'workspace.models', request.app.state.config.USER_PERMISSIONS, db=db
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    # Reject malformed ids before spending a database round-trip on them.
    if not is_valid_model_id(form_data.id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.MODEL_ID_TOO_LONG,
        )

    if await Models.get_model_by_id(form_data.id, db=db):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
        )

    form_data.access_grants = await filter_allowed_access_grants(
        request.app.state.config.USER_PERMISSIONS,
        user.id,
        user.role,
        form_data.access_grants,
        'sharing.public_models',
    )

    model = await Models.insert_new_model(form_data, user.id, db=db)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.DEFAULT(),
        )
    return model
|
|
|
|
|
|
############################
|
|
# ExportModels
|
|
############################
|
|
|
|
|
|
@router.get('/export', response_model=list[ModelModel])
async def export_models(
    request: Request,
    user=Depends(get_verified_user),
    db: AsyncSession = Depends(get_async_session),
):
    """Export model records.

    Non-admins need the ``workspace.models_export`` permission. Admins with
    ``BYPASS_ADMIN_ACCESS_CONTROL`` get every model. Everyone else gets
    ``Models.get_models_by_user_id`` for their own id.
    """
    is_admin = user.role == 'admin'

    if not is_admin:
        may_export = await has_permission(
            user.id,
            'workspace.models_export',
            request.app.state.config.USER_PERMISSIONS,
            db=db,
        )
        if not may_export:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.UNAUTHORIZED,
            )

    if is_admin and BYPASS_ADMIN_ACCESS_CONTROL:
        return await Models.get_models(db=db)
    return await Models.get_models_by_user_id(user.id, db=db)
|
|
|
|
|
|
############################
|
|
# ImportModels
|
|
############################
|
|
|
|
|
|
class ModelsImportForm(BaseModel):
    """Request body for ``POST /import``."""

    # Raw model definitions, one dict per model; the import handler turns
    # each entry into a ModelForm.
    models: list[dict]
|
|
|
|
|
|
@router.post('/import', response_model=bool)
async def import_models(
    request: Request,
    user=Depends(get_verified_user),
    form_data: ModelsImportForm = (...),
    db: AsyncSession = Depends(get_async_session),
):
    """Create or update models from a list of model definitions.

    Entries with a missing or invalid id are skipped. An existing model is
    only overwritten when the caller is an admin, owns it, or holds a
    ``write`` grant on it; otherwise the entry is skipped with a warning.
    On updates, fields absent from the entry (including ``meta`` and
    ``params``) keep their stored values. Access grants are only re-filtered
    when the entry explicitly provides them.

    Raises:
        HTTPException: 401 when the caller lacks ``workspace.models_import``,
            400 when the payload is not a list, 500 on any other failure.
    """
    if user.role != 'admin' and not await has_permission(
        user.id,
        'workspace.models_import',
        request.app.state.config.USER_PERMISSIONS,
        db=db,
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    try:
        data = form_data.models
        if not isinstance(data, list):
            raise HTTPException(status_code=400, detail='Invalid JSON format')

        # Batch-fetch all existing models in one query to avoid N+1.
        model_ids = [
            model_data.get('id')
            for model_data in data
            if model_data.get('id') and is_valid_model_id(model_data.get('id'))
        ]
        existing_models = {
            model.id: model for model in (await Models.get_models_by_ids(model_ids, db=db) if model_ids else [])
        }

        # Batch-resolve write permissions in one query instead of
        # per-model has_access calls (N+1 avoidance).
        existing_model_ids = list(existing_models.keys())
        if user.role != 'admin' and existing_model_ids:
            groups = await Groups.get_groups_by_member_id(user.id, db=db)
            user_group_ids = {group.id for group in groups}
            writable_model_ids = await AccessGrants.get_accessible_resource_ids(
                user_id=user.id,
                resource_type='model',
                resource_ids=existing_model_ids,
                permission='write',
                user_group_ids=user_group_ids,
                db=db,
            )
        else:
            writable_model_ids = set(existing_model_ids)

        for model_data in data:
            model_id = model_data.get('id')
            if not (model_id and is_valid_model_id(model_id)):
                continue

            existing_model = existing_models.get(model_id)
            if existing_model:
                # Enforce ownership/write-access before allowing overwrite.
                if (
                    user.role != 'admin'
                    and existing_model.user_id != user.id
                    and model_id not in writable_model_ids
                ):
                    log.warning(
                        'import_models: user %s skipped model %s (no write access)',
                        user.id,
                        model_id,
                    )
                    continue

                # Merge onto the stored record. meta/params are NOT defaulted
                # to {} beforehand, so a partial entry keeps the stored values
                # instead of wiping them; fall back to {} only if both are empty.
                merged = {**existing_model.model_dump(), **model_data}
                if merged.get('meta') is None:
                    merged['meta'] = {}
                if merged.get('params') is None:
                    merged['params'] = {}
                updated_model = ModelForm(**merged)

                # Only filter access_grants when explicitly provided in the
                # payload to avoid altering existing ACLs on metadata-only imports.
                if 'access_grants' in model_data:
                    updated_model.access_grants = await filter_allowed_access_grants(
                        request.app.state.config.USER_PERMISSIONS,
                        user.id,
                        user.role,
                        updated_model.access_grants,
                        'sharing.public_models',
                    )
                await Models.update_model_by_id(model_id, updated_model, db=db)
            else:
                # Insert new model.
                model_data['meta'] = model_data.get('meta', {})
                model_data['params'] = model_data.get('params', {})
                new_model = ModelForm(**model_data)
                new_model.access_grants = await filter_allowed_access_grants(
                    request.app.state.config.USER_PERMISSIONS,
                    user.id,
                    user.role,
                    new_model.access_grants,
                    'sharing.public_models',
                )
                await Models.insert_new_model(user_id=user.id, form_data=new_model, db=db)

        return True
    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 400 above) unchanged
        # instead of rewrapping them as 500s below.
        raise
    except Exception as e:
        log.exception(e)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
############################
|
|
# SyncModels
|
|
############################
|
|
|
|
|
|
class SyncModelsForm(BaseModel):
    """Request body for ``POST /sync``."""

    # Model records to hand to Models.sync_models; empty list when omitted.
    models: list[ModelModel] = []
|
|
|
|
|
|
@router.post('/sync', response_model=list[ModelModel])
async def sync_models(
    request: Request,
    form_data: SyncModelsForm,
    user=Depends(get_admin_user),
    db: AsyncSession = Depends(get_async_session),
):
    """Admin-only: pass the submitted models to ``Models.sync_models`` and return its result."""
    synced = await Models.sync_models(user.id, form_data.models, db=db)
    return synced
|
|
|
|
|
|
###########################
|
|
# GetModelById
|
|
###########################
|
|
|
|
|
|
class ModelIdForm(BaseModel):
    """Request body that identifies a single model by id."""

    # Model identifier; carried in the body so ids may contain '/'.
    id: str
|
|
|
|
|
|
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
@router.get('/model', response_model=Optional[ModelAccessResponse])
async def get_model_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
    """Return model *id* plus a ``write_access`` flag for the caller.

    Bypassing admins and the owner skip all grant lookups. Everyone else needs
    a ``read`` grant, and ``write_access`` reflects a ``write`` grant.

    Raises:
        HTTPException: 404 if the model does not exist, 401 if the caller may
            not read it.
    """
    model = await Models.get_model_by_id(id, db=db)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    privileged = (user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL) or model.user_id == user.id

    can_read = privileged or await AccessGrants.has_access(
        user_id=user.id,
        resource_type='model',
        resource_id=model.id,
        permission='read',
        db=db,
    )
    if not can_read:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    can_write = privileged or await AccessGrants.has_access(
        user_id=user.id,
        resource_type='model',
        resource_id=model.id,
        permission='write',
        db=db,
    )
    return ModelAccessResponse(**model.model_dump(), write_access=can_write)
|
|
|
|
|
|
###########################
|
|
# GetModelProfileImage
|
|
###########################
|
|
|
|
|
|
@router.get('/model/profile/image')
async def get_model_profile_image(
    id: str,
    user=Depends(get_verified_user),
    db: AsyncSession = Depends(get_async_session),
):
    """Serve the profile image of model *id*.

    - URLs starting with ``http`` get a 302 redirect to that URL.
    - ``data:image...`` URLs are base64-decoded and streamed inline, with an
      ETag derived from ``model.updated_at`` when it is set.
    - Otherwise the static favicon is returned. This covers a missing model,
      no image, or an undecodable data URL.
    """
    favicon = f'{STATIC_DIR}/favicon.png'

    model = await Models.get_model_by_id(id, db=db)
    if not model:
        return FileResponse(favicon)

    # Guard against models stored without metadata.
    profile_image_url = model.meta.profile_image_url if model.meta else None
    if not profile_image_url:
        return FileResponse(favicon)

    if profile_image_url.startswith('http'):
        return Response(
            status_code=status.HTTP_302_FOUND,
            headers={'Location': profile_image_url},
        )

    if profile_image_url.startswith('data:image'):
        try:
            header, base64_data = profile_image_url.split(',', 1)
            image_data = base64.b64decode(base64_data)
        except ValueError as e:
            # Covers a missing ',' separator and invalid base64
            # (binascii.Error subclasses ValueError). Fall back to the
            # favicon, but leave a trace instead of swallowing silently.
            log.debug('Could not decode profile image for model %s: %s', model.id, e)
            return FileResponse(favicon)

        # removeprefix, not lstrip: lstrip('data:') strips a character set.
        media_type = header.split(';')[0].removeprefix('data:')

        headers = {'Content-Disposition': 'inline'}
        if model.updated_at:
            headers['ETag'] = f'"{model.updated_at}"'

        return StreamingResponse(
            io.BytesIO(image_data),
            media_type=media_type,
            headers=headers,
        )

    return FileResponse(favicon)
|
|
|
|
|
|
############################
|
|
# ToggleModelById
|
|
############################
|
|
|
|
|
|
@router.post('/model/toggle', response_model=Optional[ModelResponse])
async def toggle_model_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
    """Toggle model *id* via ``Models.toggle_model_by_id`` and return the result.

    Allowed for admins, the owner, or holders of a ``write`` grant.

    Raises:
        HTTPException: 401 if the model is missing or the caller lacks access,
            400 if the toggle itself fails.
    """
    model = await Models.get_model_by_id(id, db=db)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    can_write = (
        user.role == 'admin'
        or model.user_id == user.id
        or await AccessGrants.has_access(
            user_id=user.id,
            resource_type='model',
            resource_id=model.id,
            permission='write',
            db=db,
        )
    )
    if not can_write:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    toggled = await Models.toggle_model_by_id(id, db=db)
    if not toggled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT('Error updating model'),
        )
    return toggled
|
|
|
|
|
|
############################
|
|
# UpdateModelById
|
|
############################
|
|
|
|
|
|
@router.post('/model/update', response_model=Optional[ModelModel])
async def update_model_by_id(
    request: Request,
    form_data: ModelForm,
    user=Depends(get_verified_user),
    db: AsyncSession = Depends(get_async_session),
):
    """Replace the stored model ``form_data.id`` with *form_data*.

    Allowed for admins, the owner, or holders of a ``write`` grant. Access
    grants are filtered through ``filter_allowed_access_grants`` first.

    Raises:
        HTTPException: 401 if the model does not exist, 400 if the caller may
            not write it.
    """
    model = await Models.get_model_by_id(form_data.id, db=db)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # Cheap in-memory checks first so admins and owners never trigger the
    # grant query.
    if (
        user.role != 'admin'
        and model.user_id != user.id
        and not await AccessGrants.has_access(
            user_id=user.id,
            resource_type='model',
            resource_id=model.id,
            permission='write',
            db=db,
        )
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    form_data.access_grants = await filter_allowed_access_grants(
        request.app.state.config.USER_PERMISSIONS,
        user.id,
        user.role,
        form_data.access_grants,
        'sharing.public_models',
    )

    # Rebuild through ModelForm so the filtered grants are re-validated.
    model = await Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump()), db=db)
    return model
|
|
|
|
|
|
############################
|
|
# UpdateModelAccessById
|
|
############################
|
|
|
|
|
|
class ModelAccessGrantsForm(BaseModel):
    """Request body for ``POST /model/access/update``."""

    # Target model id.
    id: str
    # Display name used only if a DB entry has to be created; falls back to id.
    name: Optional[str] = None
    # Requested grants; filtered by the handler before being stored.
    access_grants: list[dict]
|
|
|
|
|
|
@router.post('/model/access/update', response_model=Optional[ModelModel])
async def update_model_access_by_id(
    request: Request,
    form_data: ModelAccessGrantsForm,
    user=Depends(get_verified_user),
    db: AsyncSession = Depends(get_async_session),
):
    """Replace the access grants of model ``form_data.id`` and return the model.

    If the model has no DB entry yet, an admin may create a minimal one here.
    Otherwise the caller must be an admin, the owner, or hold a ``write``
    grant.

    Raises:
        HTTPException: 403 if a non-admin targets a missing model; 400 if the
            id is invalid or the caller may not write the model; 500 if the
            placeholder entry cannot be created.
    """
    model = await Models.get_model_by_id(form_data.id, db=db)

    # Non-preset models (e.g. direct Ollama/OpenAI models) may not have a DB
    # entry yet. Create a minimal one so access grants can be stored.
    if not model:
        if user.role != 'admin':
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
            )

        # Apply the same id constraint as /create so this path cannot persist
        # empty or oversized ids.
        if not is_valid_model_id(form_data.id):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.MODEL_ID_TOO_LONG,
            )

        model = await Models.insert_new_model(
            ModelForm(
                id=form_data.id,
                name=form_data.name or form_data.id,
                meta=ModelMeta(),
                params=ModelParams(),
            ),
            user.id,
            db=db,
        )
        if not model:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT('Error creating model entry'),
            )

    # Cheap in-memory checks first so admins and owners never trigger the
    # grant query.
    if (
        user.role != 'admin'
        and model.user_id != user.id
        and not await AccessGrants.has_access(
            user_id=user.id,
            resource_type='model',
            resource_id=model.id,
            permission='write',
            db=db,
        )
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    form_data.access_grants = await filter_allowed_access_grants(
        request.app.state.config.USER_PERMISSIONS,
        user.id,
        user.role,
        form_data.access_grants,
        'sharing.public_models',
    )

    await AccessGrants.set_access_grants('model', form_data.id, form_data.access_grants, db=db)
    await Models.update_model_updated_at_by_id(form_data.id, db=db)

    return await Models.get_model_by_id(form_data.id, db=db)
|
|
|
|
|
|
############################
|
|
# DeleteModelById
|
|
############################
|
|
|
|
|
|
@router.post('/model/delete', response_model=bool)
async def delete_model_by_id(
    form_data: ModelIdForm,
    user=Depends(get_verified_user),
    db: AsyncSession = Depends(get_async_session),
):
    """Delete model ``form_data.id`` and return ``Models.delete_model_by_id``'s result.

    Allowed for admins, the owner, or holders of a ``write`` grant.

    Raises:
        HTTPException: 401 if the model is missing or the caller lacks access.
    """
    target = await Models.get_model_by_id(form_data.id, db=db)
    if not target:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    is_admin = user.role == 'admin'
    is_owner = target.user_id == user.id
    if not (is_admin or is_owner):
        has_write = await AccessGrants.has_access(
            user_id=user.id,
            resource_type='model',
            resource_id=target.id,
            permission='write',
            db=db,
        )
        if not has_write:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.UNAUTHORIZED,
            )

    return await Models.delete_model_by_id(form_data.id, db=db)
|
|
|
|
|
|
@router.delete('/delete/all', response_model=bool)
async def delete_all_models(user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session)):
    """Admin-only: call ``Models.delete_all_models`` and return its result."""
    outcome = await Models.delete_all_models(db=db)
    return outcome
|