mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-05 18:38:17 -05:00
refac
This commit is contained in:
@@ -52,36 +52,34 @@ class ActiveChatsForm(BaseModel):
|
||||
chat_ids: list[str]
|
||||
|
||||
|
||||
@router.post("/active/chats")
|
||||
async def check_active_chats(
|
||||
request: Request, form_data: ActiveChatsForm, user=Depends(get_verified_user)
|
||||
):
|
||||
@router.post('/active/chats')
|
||||
async def check_active_chats(request: Request, form_data: ActiveChatsForm, user=Depends(get_verified_user)):
|
||||
"""Check which chat IDs have active tasks."""
|
||||
from open_webui.tasks import get_active_chat_ids
|
||||
|
||||
active = await get_active_chat_ids(request.app.state.redis, form_data.chat_ids)
|
||||
return {"active_chat_ids": active}
|
||||
return {'active_chat_ids': active}
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
@router.get('/config')
|
||||
async def get_task_config(request: Request, user=Depends(get_verified_user)):
|
||||
return {
|
||||
"TASK_MODEL": request.app.state.config.TASK_MODEL,
|
||||
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
"FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
|
||||
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
|
||||
"ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
|
||||
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
||||
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
|
||||
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
"VOICE_MODE_PROMPT_TEMPLATE": request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE,
|
||||
'TASK_MODEL': request.app.state.config.TASK_MODEL,
|
||||
'TASK_MODEL_EXTERNAL': request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
'TITLE_GENERATION_PROMPT_TEMPLATE': request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
'IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE': request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
||||
'ENABLE_AUTOCOMPLETE_GENERATION': request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
'AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH': request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
||||
'TAGS_GENERATION_PROMPT_TEMPLATE': request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
'FOLLOW_UP_GENERATION_PROMPT_TEMPLATE': request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
|
||||
'ENABLE_FOLLOW_UP_GENERATION': request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
|
||||
'ENABLE_TAGS_GENERATION': request.app.state.config.ENABLE_TAGS_GENERATION,
|
||||
'ENABLE_TITLE_GENERATION': request.app.state.config.ENABLE_TITLE_GENERATION,
|
||||
'ENABLE_SEARCH_QUERY_GENERATION': request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
||||
'ENABLE_RETRIEVAL_QUERY_GENERATION': request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
|
||||
'QUERY_GENERATION_PROMPT_TEMPLATE': request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
'TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE': request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
'VOICE_MODE_PROMPT_TEMPLATE': request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE,
|
||||
}
|
||||
|
||||
|
||||
@@ -104,100 +102,73 @@ class TaskConfigForm(BaseModel):
|
||||
VOICE_MODE_PROMPT_TEMPLATE: Optional[str]
|
||||
|
||||
|
||||
@router.post("/config/update")
|
||||
async def update_task_config(
|
||||
request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
|
||||
):
|
||||
@router.post('/config/update')
|
||||
async def update_task_config(request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)):
|
||||
request.app.state.config.TASK_MODEL = form_data.TASK_MODEL
|
||||
request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
|
||||
request.app.state.config.ENABLE_TITLE_GENERATION = form_data.ENABLE_TITLE_GENERATION
|
||||
request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = form_data.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = (
|
||||
form_data.ENABLE_FOLLOW_UP_GENERATION
|
||||
)
|
||||
request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = form_data.ENABLE_FOLLOW_UP_GENERATION
|
||||
request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = (
|
||||
form_data.ENABLE_AUTOCOMPLETE_GENERATION
|
||||
)
|
||||
request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = form_data.ENABLE_AUTOCOMPLETE_GENERATION
|
||||
request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
|
||||
form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
|
||||
)
|
||||
|
||||
request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = form_data.TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
request.app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
|
||||
request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
|
||||
form_data.ENABLE_SEARCH_QUERY_GENERATION
|
||||
)
|
||||
request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
|
||||
form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
|
||||
)
|
||||
request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = form_data.ENABLE_SEARCH_QUERY_GENERATION
|
||||
request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
|
||||
|
||||
request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
||||
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||
)
|
||||
request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = form_data.QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||
|
||||
request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE = (
|
||||
form_data.VOICE_MODE_PROMPT_TEMPLATE
|
||||
)
|
||||
request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE = form_data.VOICE_MODE_PROMPT_TEMPLATE
|
||||
|
||||
return {
|
||||
"TASK_MODEL": request.app.state.config.TASK_MODEL,
|
||||
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
"ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
|
||||
"ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
|
||||
"FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
||||
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
|
||||
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
"VOICE_MODE_PROMPT_TEMPLATE": request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE,
|
||||
'TASK_MODEL': request.app.state.config.TASK_MODEL,
|
||||
'TASK_MODEL_EXTERNAL': request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
'ENABLE_TITLE_GENERATION': request.app.state.config.ENABLE_TITLE_GENERATION,
|
||||
'TITLE_GENERATION_PROMPT_TEMPLATE': request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
'IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE': request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
||||
'ENABLE_AUTOCOMPLETE_GENERATION': request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
'AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH': request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
||||
'TAGS_GENERATION_PROMPT_TEMPLATE': request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
'ENABLE_TAGS_GENERATION': request.app.state.config.ENABLE_TAGS_GENERATION,
|
||||
'ENABLE_FOLLOW_UP_GENERATION': request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
|
||||
'FOLLOW_UP_GENERATION_PROMPT_TEMPLATE': request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
|
||||
'ENABLE_SEARCH_QUERY_GENERATION': request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
||||
'ENABLE_RETRIEVAL_QUERY_GENERATION': request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
|
||||
'QUERY_GENERATION_PROMPT_TEMPLATE': request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
'TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE': request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
'VOICE_MODE_PROMPT_TEMPLATE': request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/title/completions")
|
||||
async def generate_title(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
@router.post('/title/completions')
|
||||
async def generate_title(request: Request, form_data: dict, user=Depends(get_verified_user)):
|
||||
if not request.app.state.config.ENABLE_TITLE_GENERATION:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"detail": "Title generation is disabled"},
|
||||
content={'detail': 'Title generation is disabled'},
|
||||
)
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
request.state.model['id']: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
model_id = form_data['model']
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
detail='Model not found',
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
@@ -209,37 +180,33 @@ async def generate_title(
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"generating chat title using model {task_model_id} for user {user.email} "
|
||||
)
|
||||
log.debug(f'generating chat title using model {task_model_id} for user {user.email} ')
|
||||
|
||||
if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
|
||||
if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != '':
|
||||
template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = title_generation_template(template, form_data["messages"], user)
|
||||
content = title_generation_template(template, form_data['messages'], user)
|
||||
|
||||
max_tokens = (
|
||||
models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000)
|
||||
)
|
||||
max_tokens = models[task_model_id].get('info', {}).get('params', {}).get('max_tokens', 1000)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
'model': task_model_id,
|
||||
'messages': [{'role': 'user', 'content': content}],
|
||||
'stream': False,
|
||||
**(
|
||||
{"max_tokens": max_tokens}
|
||||
if models[task_model_id].get("owned_by") == "ollama"
|
||||
{'max_tokens': max_tokens}
|
||||
if models[task_model_id].get('owned_by') == 'ollama'
|
||||
else {
|
||||
"max_completion_tokens": max_tokens,
|
||||
'max_completion_tokens': max_tokens,
|
||||
}
|
||||
),
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.TITLE_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
'metadata': {
|
||||
**(request.state.metadata if hasattr(request.state, 'metadata') else {}),
|
||||
'task': str(TASKS.TITLE_GENERATION),
|
||||
'task_body': form_data,
|
||||
'chat_id': form_data.get('chat_id', None),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -252,36 +219,33 @@ async def generate_title(
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
log.error("Exception occurred", exc_info=True)
|
||||
log.error('Exception occurred', exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": "An internal error has occurred."},
|
||||
content={'detail': 'An internal error has occurred.'},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/follow_up/completions")
|
||||
async def generate_follow_ups(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
@router.post('/follow_up/completions')
|
||||
async def generate_follow_ups(request: Request, form_data: dict, user=Depends(get_verified_user)):
|
||||
if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"detail": "Follow-up generation is disabled"},
|
||||
content={'detail': 'Follow-up generation is disabled'},
|
||||
)
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
request.state.model['id']: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
model_id = form_data['model']
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
detail='Model not found',
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
@@ -293,26 +257,24 @@ async def generate_follow_ups(
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"generating chat title using model {task_model_id} for user {user.email} "
|
||||
)
|
||||
log.debug(f'generating chat title using model {task_model_id} for user {user.email} ')
|
||||
|
||||
if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "":
|
||||
if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != '':
|
||||
template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = follow_up_generation_template(template, form_data["messages"], user)
|
||||
content = follow_up_generation_template(template, form_data['messages'], user)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.FOLLOW_UP_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
'model': task_model_id,
|
||||
'messages': [{'role': 'user', 'content': content}],
|
||||
'stream': False,
|
||||
'metadata': {
|
||||
**(request.state.metadata if hasattr(request.state, 'metadata') else {}),
|
||||
'task': str(TASKS.FOLLOW_UP_GENERATION),
|
||||
'task_body': form_data,
|
||||
'chat_id': form_data.get('chat_id', None),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -325,36 +287,33 @@ async def generate_follow_ups(
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
log.error("Exception occurred", exc_info=True)
|
||||
log.error('Exception occurred', exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": "An internal error has occurred."},
|
||||
content={'detail': 'An internal error has occurred.'},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tags/completions")
|
||||
async def generate_chat_tags(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
@router.post('/tags/completions')
|
||||
async def generate_chat_tags(request: Request, form_data: dict, user=Depends(get_verified_user)):
|
||||
if not request.app.state.config.ENABLE_TAGS_GENERATION:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"detail": "Tags generation is disabled"},
|
||||
content={'detail': 'Tags generation is disabled'},
|
||||
)
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
request.state.model['id']: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
model_id = form_data['model']
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
detail='Model not found',
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
@@ -366,26 +325,24 @@ async def generate_chat_tags(
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"generating chat tags using model {task_model_id} for user {user.email} "
|
||||
)
|
||||
log.debug(f'generating chat tags using model {task_model_id} for user {user.email} ')
|
||||
|
||||
if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
|
||||
if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != '':
|
||||
template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = tags_generation_template(template, form_data["messages"], user)
|
||||
content = tags_generation_template(template, form_data['messages'], user)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.TAGS_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
'model': task_model_id,
|
||||
'messages': [{'role': 'user', 'content': content}],
|
||||
'stream': False,
|
||||
'metadata': {
|
||||
**(request.state.metadata if hasattr(request.state, 'metadata') else {}),
|
||||
'task': str(TASKS.TAGS_GENERATION),
|
||||
'task_body': form_data,
|
||||
'chat_id': form_data.get('chat_id', None),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -398,29 +355,27 @@ async def generate_chat_tags(
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
log.error(f"Error generating chat completion: {e}")
|
||||
log.error(f'Error generating chat completion: {e}')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={"detail": "An internal error has occurred."},
|
||||
content={'detail': 'An internal error has occurred.'},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/image_prompt/completions")
|
||||
async def generate_image_prompt(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
@router.post('/image_prompt/completions')
|
||||
async def generate_image_prompt(request: Request, form_data: dict, user=Depends(get_verified_user)):
|
||||
if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
request.state.model['id']: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
model_id = form_data['model']
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
detail='Model not found',
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
@@ -432,26 +387,24 @@ async def generate_image_prompt(
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"generating image prompt using model {task_model_id} for user {user.email} "
|
||||
)
|
||||
log.debug(f'generating image prompt using model {task_model_id} for user {user.email} ')
|
||||
|
||||
if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "":
|
||||
if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != '':
|
||||
template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = image_prompt_generation_template(template, form_data["messages"], user)
|
||||
content = image_prompt_generation_template(template, form_data['messages'], user)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.IMAGE_PROMPT_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
'model': task_model_id,
|
||||
'messages': [{'role': 'user', 'content': content}],
|
||||
'stream': False,
|
||||
'metadata': {
|
||||
**(request.state.metadata if hasattr(request.state, 'metadata') else {}),
|
||||
'task': str(TASKS.IMAGE_PROMPT_GENERATION),
|
||||
'task_body': form_data,
|
||||
'chat_id': form_data.get('chat_id', None),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -464,48 +417,45 @@ async def generate_image_prompt(
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
log.error("Exception occurred", exc_info=True)
|
||||
log.error('Exception occurred', exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": "An internal error has occurred."},
|
||||
content={'detail': 'An internal error has occurred.'},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/queries/completions")
|
||||
async def generate_queries(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
type = form_data.get("type")
|
||||
if type == "web_search":
|
||||
@router.post('/queries/completions')
|
||||
async def generate_queries(request: Request, form_data: dict, user=Depends(get_verified_user)):
|
||||
type = form_data.get('type')
|
||||
if type == 'web_search':
|
||||
if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Search query generation is disabled",
|
||||
detail=f'Search query generation is disabled',
|
||||
)
|
||||
elif type == "retrieval":
|
||||
elif type == 'retrieval':
|
||||
if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Query generation is disabled",
|
||||
detail=f'Query generation is disabled',
|
||||
)
|
||||
|
||||
if getattr(request.state, "cached_queries", None):
|
||||
log.info(f"Reusing cached queries: {request.state.cached_queries}")
|
||||
if getattr(request.state, 'cached_queries', None):
|
||||
log.info(f'Reusing cached queries: {request.state.cached_queries}')
|
||||
return request.state.cached_queries
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
request.state.model['id']: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
model_id = form_data['model']
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
detail='Model not found',
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
@@ -517,26 +467,24 @@ async def generate_queries(
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"generating {type} queries using model {task_model_id} for user {user.email}"
|
||||
)
|
||||
log.debug(f'generating {type} queries using model {task_model_id} for user {user.email}')
|
||||
|
||||
if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
|
||||
if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != '':
|
||||
template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = query_generation_template(template, form_data["messages"], user)
|
||||
content = query_generation_template(template, form_data['messages'], user)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.QUERY_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
'model': task_model_id,
|
||||
'messages': [{'role': 'user', 'content': content}],
|
||||
'stream': False,
|
||||
'metadata': {
|
||||
**(request.state.metadata if hasattr(request.state, 'metadata') else {}),
|
||||
'task': str(TASKS.QUERY_GENERATION),
|
||||
'task_body': form_data,
|
||||
'chat_id': form_data.get('chat_id', None),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -551,46 +499,41 @@ async def generate_queries(
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
content={'detail': str(e)},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/auto/completions")
|
||||
async def generate_autocompletion(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
@router.post('/auto/completions')
|
||||
async def generate_autocompletion(request: Request, form_data: dict, user=Depends(get_verified_user)):
|
||||
if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Autocompletion generation is disabled",
|
||||
detail=f'Autocompletion generation is disabled',
|
||||
)
|
||||
|
||||
type = form_data.get("type")
|
||||
prompt = form_data.get("prompt")
|
||||
messages = form_data.get("messages")
|
||||
type = form_data.get('type')
|
||||
prompt = form_data.get('prompt')
|
||||
messages = form_data.get('messages')
|
||||
|
||||
if request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0:
|
||||
if (
|
||||
len(prompt)
|
||||
> request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
|
||||
):
|
||||
if len(prompt) > request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
|
||||
detail=f'Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}',
|
||||
)
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
request.state.model['id']: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
model_id = form_data['model']
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
detail='Model not found',
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
@@ -602,11 +545,9 @@ async def generate_autocompletion(
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"generating autocompletion using model {task_model_id} for user {user.email}"
|
||||
)
|
||||
log.debug(f'generating autocompletion using model {task_model_id} for user {user.email}')
|
||||
|
||||
if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
|
||||
if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != '':
|
||||
template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
|
||||
@@ -614,14 +555,14 @@ async def generate_autocompletion(
|
||||
content = autocomplete_generation_template(template, prompt, messages, type, user)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.AUTOCOMPLETE_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
'model': task_model_id,
|
||||
'messages': [{'role': 'user', 'content': content}],
|
||||
'stream': False,
|
||||
'metadata': {
|
||||
**(request.state.metadata if hasattr(request.state, 'metadata') else {}),
|
||||
'task': str(TASKS.AUTOCOMPLETE_GENERATION),
|
||||
'task_body': form_data,
|
||||
'chat_id': form_data.get('chat_id', None),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -634,30 +575,27 @@ async def generate_autocompletion(
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
log.error(f"Error generating chat completion: {e}")
|
||||
log.error(f'Error generating chat completion: {e}')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={"detail": "An internal error has occurred."},
|
||||
content={'detail': 'An internal error has occurred.'},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/emoji/completions")
|
||||
async def generate_emoji(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
@router.post('/emoji/completions')
|
||||
async def generate_emoji(request: Request, form_data: dict, user=Depends(get_verified_user)):
|
||||
if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
request.state.model['id']: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
model_id = form_data['model']
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
detail='Model not found',
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
@@ -669,28 +607,28 @@ async def generate_emoji(
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")
|
||||
log.debug(f'generating emoji using model {task_model_id} for user {user.email} ')
|
||||
|
||||
template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = emoji_generation_template(template, form_data["prompt"], user)
|
||||
content = emoji_generation_template(template, form_data['prompt'], user)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
'model': task_model_id,
|
||||
'messages': [{'role': 'user', 'content': content}],
|
||||
'stream': False,
|
||||
**(
|
||||
{"max_tokens": 4}
|
||||
if models[task_model_id].get("owned_by") == "ollama"
|
||||
{'max_tokens': 4}
|
||||
if models[task_model_id].get('owned_by') == 'ollama'
|
||||
else {
|
||||
"max_completion_tokens": 4,
|
||||
'max_completion_tokens': 4,
|
||||
}
|
||||
),
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.EMOJI_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
'metadata': {
|
||||
**(request.state.metadata if hasattr(request.state, 'metadata') else {}),
|
||||
'task': str(TASKS.EMOJI_GENERATION),
|
||||
'task_body': form_data,
|
||||
'chat_id': form_data.get('chat_id', None),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -705,47 +643,44 @@ async def generate_emoji(
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
content={'detail': str(e)},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/moa/completions")
|
||||
async def generate_moa_response(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
@router.post('/moa/completions')
|
||||
async def generate_moa_response(request: Request, form_data: dict, user=Depends(get_verified_user)):
|
||||
if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
request.state.model['id']: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
model_id = form_data['model']
|
||||
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
detail='Model not found',
|
||||
)
|
||||
|
||||
template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = moa_response_generation_template(
|
||||
template,
|
||||
form_data["prompt"],
|
||||
form_data["responses"],
|
||||
form_data['prompt'],
|
||||
form_data['responses'],
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": form_data.get("stream", False),
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
"task": str(TASKS.MOA_RESPONSE_GENERATION),
|
||||
"task_body": form_data,
|
||||
'model': model_id,
|
||||
'messages': [{'role': 'user', 'content': content}],
|
||||
'stream': form_data.get('stream', False),
|
||||
'metadata': {
|
||||
**(request.state.metadata if hasattr(request.state, 'metadata') else {}),
|
||||
'chat_id': form_data.get('chat_id', None),
|
||||
'task': str(TASKS.MOA_RESPONSE_GENERATION),
|
||||
'task_body': form_data,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -760,5 +695,5 @@ async def generate_moa_response(
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
content={'detail': str(e)},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user