mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-01 01:39:05 -05:00
700 lines
26 KiB
Python
700 lines
26 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
|
|
from fastapi.responses import JSONResponse, RedirectResponse
|
|
|
|
from pydantic import BaseModel
|
|
from typing import Optional
|
|
import logging
|
|
import re
|
|
|
|
from open_webui.utils.chat import generate_chat_completion
|
|
from open_webui.utils.task import (
|
|
title_generation_template,
|
|
follow_up_generation_template,
|
|
query_generation_template,
|
|
image_prompt_generation_template,
|
|
autocomplete_generation_template,
|
|
tags_generation_template,
|
|
emoji_generation_template,
|
|
moa_response_generation_template,
|
|
)
|
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
|
from open_webui.constants import ERROR_MESSAGES, TASKS
|
|
|
|
from open_webui.routers.pipelines import process_pipeline_inlet_filter
|
|
|
|
from open_webui.utils.task import get_task_model_id
|
|
|
|
from open_webui.config import (
|
|
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
|
|
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
|
|
DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
|
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
|
|
DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
|
|
DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
|
|
DEFAULT_VOICE_MODE_PROMPT_TEMPLATE,
|
|
)
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
##################################
|
|
#
|
|
# Task Endpoints
|
|
#
|
|
##################################
|
|
|
|
|
|
class ActiveChatsForm(BaseModel):
    """Request body for POST /active/chats: chat IDs to check for running tasks."""

    # IDs of the chats to test for active background tasks.
    chat_ids: list[str]
|
|
|
|
|
|
@router.post('/active/chats')
async def check_active_chats(request: Request, form_data: ActiveChatsForm, user=Depends(get_verified_user)):
    """Return the subset of the submitted chat IDs that still have running tasks."""
    # Imported lazily; importing at module load would risk a circular import.
    from open_webui.tasks import get_active_chat_ids

    active_ids = await get_active_chat_ids(request.app.state.redis, form_data.chat_ids)
    return {'active_chat_ids': active_ids}
|
|
|
|
|
|
@router.get('/config')
async def get_task_config(request: Request, user=Depends(get_verified_user)):
    """Return the current task-related configuration as a flat mapping."""
    config = request.app.state.config

    # Key order is preserved from the original literal response dict.
    keys = (
        'TASK_MODEL',
        'TASK_MODEL_EXTERNAL',
        'TITLE_GENERATION_PROMPT_TEMPLATE',
        'IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE',
        'ENABLE_AUTOCOMPLETE_GENERATION',
        'AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH',
        'TAGS_GENERATION_PROMPT_TEMPLATE',
        'FOLLOW_UP_GENERATION_PROMPT_TEMPLATE',
        'ENABLE_FOLLOW_UP_GENERATION',
        'ENABLE_TAGS_GENERATION',
        'ENABLE_TITLE_GENERATION',
        'ENABLE_SEARCH_QUERY_GENERATION',
        'ENABLE_RETRIEVAL_QUERY_GENERATION',
        'QUERY_GENERATION_PROMPT_TEMPLATE',
        'TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE',
        'VOICE_MODE_PROMPT_TEMPLATE',
    )
    return {key: getattr(config, key) for key in keys}
|
|
|
|
|
|
class TaskConfigForm(BaseModel):
    """Admin-submitted task settings, consumed by POST /config/update."""

    # Task model selection for local (Ollama) vs. external connections.
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    # Chat title generation.
    ENABLE_TITLE_GENERATION: bool
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    # Image prompt generation.
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
    # Input autocompletion.
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
    # Chat tags.
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    # Follow-up suggestions.
    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_FOLLOW_UP_GENERATION: bool
    ENABLE_TAGS_GENERATION: bool
    # Search / retrieval query generation.
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    QUERY_GENERATION_PROMPT_TEMPLATE: str
    # Tool-calling prompt.
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
    # Voice mode system prompt.
    VOICE_MODE_PROMPT_TEMPLATE: Optional[str]
|
|
|
|
|
|
@router.post('/config/update')
async def update_task_config(request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)):
    """Persist the submitted task settings and echo the resulting configuration."""
    config = request.app.state.config

    # Copy every submitted field onto the application's persistent config.
    submitted_fields = (
        'TASK_MODEL',
        'TASK_MODEL_EXTERNAL',
        'ENABLE_TITLE_GENERATION',
        'TITLE_GENERATION_PROMPT_TEMPLATE',
        'ENABLE_FOLLOW_UP_GENERATION',
        'FOLLOW_UP_GENERATION_PROMPT_TEMPLATE',
        'IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE',
        'ENABLE_AUTOCOMPLETE_GENERATION',
        'AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH',
        'TAGS_GENERATION_PROMPT_TEMPLATE',
        'ENABLE_TAGS_GENERATION',
        'ENABLE_SEARCH_QUERY_GENERATION',
        'ENABLE_RETRIEVAL_QUERY_GENERATION',
        'QUERY_GENERATION_PROMPT_TEMPLATE',
        'TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE',
        'VOICE_MODE_PROMPT_TEMPLATE',
    )
    for field in submitted_fields:
        setattr(config, field, getattr(form_data, field))

    # Response key order is preserved from the original literal response dict
    # (it differs slightly from the update order above).
    response_fields = (
        'TASK_MODEL',
        'TASK_MODEL_EXTERNAL',
        'ENABLE_TITLE_GENERATION',
        'TITLE_GENERATION_PROMPT_TEMPLATE',
        'IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE',
        'ENABLE_AUTOCOMPLETE_GENERATION',
        'AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH',
        'TAGS_GENERATION_PROMPT_TEMPLATE',
        'ENABLE_TAGS_GENERATION',
        'ENABLE_FOLLOW_UP_GENERATION',
        'FOLLOW_UP_GENERATION_PROMPT_TEMPLATE',
        'ENABLE_SEARCH_QUERY_GENERATION',
        'ENABLE_RETRIEVAL_QUERY_GENERATION',
        'QUERY_GENERATION_PROMPT_TEMPLATE',
        'TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE',
        'VOICE_MODE_PROMPT_TEMPLATE',
    )
    return {field: getattr(config, field) for field in response_fields}
|
|
|
|
|
|
@router.post('/title/completions')
async def generate_title(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Generate a short chat title with the configured task model.

    Expects ``form_data`` to carry ``model`` and ``messages``. Returns the
    task model's completion, a 200 JSON notice when title generation is
    disabled, or a sanitized 400 on completion failure.
    """
    if not request.app.state.config.ENABLE_TITLE_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={'detail': 'Title generation is disabled'},
        )

    # Direct connections expose a single pre-resolved model on request.state.
    if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
        models = {request.state.model['id']: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data['model']
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(),
        )

    # Prefer the admin-configured task model (local/external) over the chat model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f'generating chat title using model {task_model_id} for user {user.email} ')

    if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != '':
        template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE

    content = title_generation_template(template, form_data['messages'], user)

    # Cap the response size; Ollama and OpenAI-compatible backends name the knob differently.
    max_tokens = models[task_model_id].get('info', {}).get('params', {}).get('max_tokens', 1000)

    payload = {
        'model': task_model_id,
        'messages': [{'role': 'user', 'content': content}],
        'stream': False,
        **(
            {'max_tokens': max_tokens}
            if models[task_model_id].get('owned_by') == 'ollama'
            else {'max_completion_tokens': max_tokens}
        ),
        'metadata': {
            **(request.state.metadata if hasattr(request.state, 'metadata') else {}),
            'task': str(TASKS.TITLE_GENERATION),
            'task_body': form_data,
            'chat_id': form_data.get('chat_id', None),
        },
    }

    # Run inlet filters; a failing filter propagates to the caller unchanged.
    # (The original wrapped this in a no-op `except Exception as e: raise e`.)
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.error('Exception occurred', exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={'detail': 'An internal error has occurred.'},
        )
|
|
|
|
|
|
@router.post('/follow_up/completions')
async def generate_follow_ups(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Generate follow-up question suggestions with the configured task model.

    Expects ``form_data`` to carry ``model`` and ``messages``. Returns the
    task model's completion, a 200 JSON notice when follow-up generation is
    disabled, or a sanitized 400 on completion failure.
    """
    if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={'detail': 'Follow-up generation is disabled'},
        )

    # Direct connections expose a single pre-resolved model on request.state.
    if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
        models = {request.state.model['id']: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data['model']
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(),
        )

    # Prefer the admin-configured task model (local/external) over the chat model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    # Fixed: the original log line said "generating chat title" (copy-paste from generate_title).
    log.debug(f'generating follow-ups using model {task_model_id} for user {user.email} ')

    if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != '':
        template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE

    content = follow_up_generation_template(template, form_data['messages'], user)

    payload = {
        'model': task_model_id,
        'messages': [{'role': 'user', 'content': content}],
        'stream': False,
        'metadata': {
            **(request.state.metadata if hasattr(request.state, 'metadata') else {}),
            'task': str(TASKS.FOLLOW_UP_GENERATION),
            'task_body': form_data,
            'chat_id': form_data.get('chat_id', None),
        },
    }

    # Run inlet filters; a failing filter propagates to the caller unchanged.
    # (The original wrapped this in a no-op `except Exception as e: raise e`.)
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.error('Exception occurred', exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={'detail': 'An internal error has occurred.'},
        )
|
|
|
|
|
|
@router.post('/tags/completions')
async def generate_chat_tags(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Generate chat tags via the configured task model.

    Returns a 200 JSON notice when tags generation is disabled, 404 when the
    requested model is unknown, and 500 when the completion itself fails.
    """
    config = request.app.state.config

    if not config.ENABLE_TAGS_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={'detail': 'Tags generation is disabled'},
        )

    # Direct connections carry a single pre-resolved model on request.state.
    is_direct = getattr(request.state, 'direct', False) and hasattr(request.state, 'model')
    models = (
        {request.state.model['id']: request.state.model}
        if is_direct
        else request.app.state.MODELS
    )

    model_id = form_data['model']
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(),
        )

    # Resolve to the admin-configured task model if one is set.
    task_model_id = get_task_model_id(
        model_id,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f'generating chat tags using model {task_model_id} for user {user.email} ')

    template = (
        config.TAGS_GENERATION_PROMPT_TEMPLATE
        if config.TAGS_GENERATION_PROMPT_TEMPLATE != ''
        else DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE
    )

    content = tags_generation_template(template, form_data['messages'], user)

    request_metadata = request.state.metadata if hasattr(request.state, 'metadata') else {}
    payload = {
        'model': task_model_id,
        'messages': [{'role': 'user', 'content': content}],
        'stream': False,
        'metadata': {
            **request_metadata,
            'task': str(TASKS.TAGS_GENERATION),
            'task_body': form_data,
            'chat_id': form_data.get('chat_id', None),
        },
    }

    # Inlet filter errors propagate unchanged (the original re-raised them as-is).
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f'Error generating chat completion: {e}')
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={'detail': 'An internal error has occurred.'},
        )
|
|
|
|
|
|
@router.post('/image_prompt/completions')
async def generate_image_prompt(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Generate an image-generation prompt from the chat via the task model.

    Returns 404 when the requested model is unknown and a sanitized 400 when
    the completion itself fails.
    """
    config = request.app.state.config

    # Direct connections carry a single pre-resolved model on request.state.
    is_direct = getattr(request.state, 'direct', False) and hasattr(request.state, 'model')
    models = (
        {request.state.model['id']: request.state.model}
        if is_direct
        else request.app.state.MODELS
    )

    model_id = form_data['model']
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(),
        )

    # Resolve to the admin-configured task model if one is set.
    task_model_id = get_task_model_id(
        model_id,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f'generating image prompt using model {task_model_id} for user {user.email} ')

    template = (
        config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
        if config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != ''
        else DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    )

    content = image_prompt_generation_template(template, form_data['messages'], user)

    request_metadata = request.state.metadata if hasattr(request.state, 'metadata') else {}
    payload = {
        'model': task_model_id,
        'messages': [{'role': 'user', 'content': content}],
        'stream': False,
        'metadata': {
            **request_metadata,
            'task': str(TASKS.IMAGE_PROMPT_GENERATION),
            'task_body': form_data,
            'chat_id': form_data.get('chat_id', None),
        },
    }

    # Inlet filter errors propagate unchanged (the original re-raised them as-is).
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.error('Exception occurred', exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={'detail': 'An internal error has occurred.'},
        )
|
|
|
|
|
|
@router.post('/queries/completions')
async def generate_queries(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Generate search/retrieval queries for the chat via the task model.

    ``form_data['type']`` selects the feature gate ('web_search' or
    'retrieval'). Returns cached queries when middleware already produced
    them; otherwise runs the task model and returns its completion, or a
    sanitized 400 on failure.
    """
    # Renamed from `type` — the original shadowed the builtin.
    query_type = form_data.get('type')
    if query_type == 'web_search':
        if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.FEATURE_DISABLED('Search query generation'),
            )
    elif query_type == 'retrieval':
        if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.FEATURE_DISABLED('Query generation'),
            )

    # Middleware may have generated queries for this request already.
    if getattr(request.state, 'cached_queries', None):
        log.info(f'Reusing cached queries: {request.state.cached_queries}')
        return request.state.cached_queries

    # Direct connections expose a single pre-resolved model on request.state.
    if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
        models = {request.state.model['id']: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data['model']
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(),
        )

    # Prefer the admin-configured task model (local/external) over the chat model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f'generating {query_type} queries using model {task_model_id} for user {user.email}')

    if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != '':
        template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE

    content = query_generation_template(template, form_data['messages'], user)

    payload = {
        'model': task_model_id,
        'messages': [{'role': 'user', 'content': content}],
        'stream': False,
        'metadata': {
            **(request.state.metadata if hasattr(request.state, 'metadata') else {}),
            'task': str(TASKS.QUERY_GENERATION),
            'task_body': form_data,
            'chat_id': form_data.get('chat_id', None),
        },
    }

    # Run inlet filters; a failing filter propagates to the caller unchanged.
    # (The original wrapped this in a no-op `except Exception as e: raise e`.)
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Fixed: the original returned str(e) to the client (leaking internal
        # details) and never logged the failure. Log it and return a generic
        # message, matching the other task endpoints.
        log.error('Exception occurred', exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={'detail': 'An internal error has occurred.'},
        )
|
|
|
|
|
|
@router.post('/auto/completions')
async def generate_autocompletion(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Generate an inline autocompletion for the user's partial prompt.

    Raises 400 when the feature is disabled or the prompt exceeds the
    configured maximum length; returns the task model's completion or a
    sanitized 500 on failure.
    """
    if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FEATURE_DISABLED('Autocompletion generation'),
        )

    # Renamed from `type` — the original shadowed the builtin.
    completion_type = form_data.get('type')
    # Fixed: default to '' — the original used .get('prompt') and then called
    # len(prompt), so a request without 'prompt' crashed with TypeError.
    prompt = form_data.get('prompt', '')
    messages = form_data.get('messages')

    # A max length of 0 (or negative) disables the input-length check.
    if request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0:
        if len(prompt) > request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.INPUT_TOO_LONG(request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH),
            )

    # Direct connections expose a single pre-resolved model on request.state.
    if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
        models = {request.state.model['id']: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data['model']
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(),
        )

    # Prefer the admin-configured task model (local/external) over the chat model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f'generating autocompletion using model {task_model_id} for user {user.email}')

    if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != '':
        template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE

    content = autocomplete_generation_template(template, prompt, messages, completion_type, user)

    payload = {
        'model': task_model_id,
        'messages': [{'role': 'user', 'content': content}],
        'stream': False,
        'metadata': {
            **(request.state.metadata if hasattr(request.state, 'metadata') else {}),
            'task': str(TASKS.AUTOCOMPLETE_GENERATION),
            'task_body': form_data,
            'chat_id': form_data.get('chat_id', None),
        },
    }

    # Run inlet filters; a failing filter propagates to the caller unchanged.
    # (The original wrapped this in a no-op `except Exception as e: raise e`.)
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f'Error generating chat completion: {e}')
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={'detail': 'An internal error has occurred.'},
        )
|
|
|
|
|
|
@router.post('/emoji/completions')
async def generate_emoji(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Generate a single emoji for ``form_data['prompt']`` via the task model.

    Returns 404 when the requested model is unknown and a sanitized 400 when
    the completion itself fails.
    """
    # Direct connections expose a single pre-resolved model on request.state.
    if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
        models = {request.state.model['id']: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data['model']
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(),
        )

    # Prefer the admin-configured task model (local/external) over the chat model.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f'generating emoji using model {task_model_id} for user {user.email} ')

    # Emoji generation has no configurable template; always use the default.
    template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE

    content = emoji_generation_template(template, form_data['prompt'], user)

    payload = {
        'model': task_model_id,
        'messages': [{'role': 'user', 'content': content}],
        'stream': False,
        # A single emoji needs only a few tokens; Ollama and OpenAI-compatible
        # backends name the cap differently.
        **(
            {'max_tokens': 4}
            if models[task_model_id].get('owned_by') == 'ollama'
            else {'max_completion_tokens': 4}
        ),
        'metadata': {
            **(request.state.metadata if hasattr(request.state, 'metadata') else {}),
            'task': str(TASKS.EMOJI_GENERATION),
            'task_body': form_data,
            'chat_id': form_data.get('chat_id', None),
        },
    }

    # Run inlet filters; a failing filter propagates to the caller unchanged.
    # (The original wrapped this in a no-op `except Exception as e: raise e`.)
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Fixed: the original returned str(e) to the client (leaking internal
        # details) and never logged the failure.
        log.error('Exception occurred', exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={'detail': 'An internal error has occurred.'},
        )
|
|
|
|
|
|
@router.post('/moa/completions')
async def generate_moa_response(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Merge multiple model responses (mixture-of-agents) into one completion.

    Unlike the other task endpoints this runs on the requested chat model
    itself (no task-model substitution) and honors the caller's ``stream``
    flag. Returns 404 for an unknown model or a sanitized 400 on failure.
    """
    # Direct connections expose a single pre-resolved model on request.state.
    if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
        models = {request.state.model['id']: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data['model']

    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(),
        )

    # MoA merging has no configurable template; always use the default.
    template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE

    content = moa_response_generation_template(
        template,
        form_data['prompt'],
        form_data['responses'],
    )

    payload = {
        'model': model_id,
        'messages': [{'role': 'user', 'content': content}],
        'stream': form_data.get('stream', False),
        'metadata': {
            **(request.state.metadata if hasattr(request.state, 'metadata') else {}),
            'chat_id': form_data.get('chat_id', None),
            'task': str(TASKS.MOA_RESPONSE_GENERATION),
            'task_body': form_data,
        },
    }

    # Run inlet filters; a failing filter propagates to the caller unchanged.
    # (The original wrapped this in a no-op `except Exception as e: raise e`.)
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Fixed: the original returned str(e) to the client (leaking internal
        # details) and never logged the failure.
        log.error('Exception occurred', exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={'detail': 'An internal error has occurred.'},
        )
|