mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-07 03:18:23 -05:00
refac
This commit is contained in:
@@ -2,6 +2,7 @@ import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -739,8 +740,16 @@ def get_azure_allowed_params(api_version: str) -> set[str]:
|
||||
return allowed_params
|
||||
|
||||
|
||||
def is_openai_reasoning_model(model: str) -> bool:
    """Return True when *model* starts with a known reasoning-model prefix.

    The comparison is case-insensitive and purely prefix-based: any name
    beginning with ``o1``, ``o3``, ``o4`` or ``gpt-5`` qualifies.
    """
    reasoning_prefixes = ('o1', 'o3', 'o4', 'gpt-5')
    return model.lower().startswith(reasoning_prefixes)
|
||||
def is_openai_new_model(model: str) -> bool:
    """Return True when *model* names an o-series or GPT-5-or-later model.

    Matching is case-insensitive and anchored at the start of the name:

    * ``o`` followed by one or more digits (``o1``, ``o3-mini``, ``o5``, ...)
    * ``gpt-N`` where the leading integer ``N`` is at least 5
      (``gpt-5``, ``gpt-5.2``, ``gpt-6``, ...)

    Azure-style GPT-3.5 names (``gpt-35-turbo`` and variants) are not
    treated as new models even though their leading integer is 35.
    """
    model_lower = model.lower()

    # o-series models (o1, o3, o4, o5, ...)
    if re.match(r'^o\d+', model_lower):
        return True

    # gpt-N where N >= 5 (gpt-5, gpt-5.2, gpt-6, ...)
    m = re.match(r'^gpt-(\d+)', model_lower)
    if m is None:
        return False

    major = m.group(1)
    # "gpt-35" is the dot-less spelling of GPT-3.5 used in Azure model
    # names; reading it as major version 35 would misclassify it.
    if major == '35':
        return False

    return int(major) >= 5
|
||||
|
||||
|
||||
def convert_to_azure_payload(url, payload: dict, api_version: str):
|
||||
@@ -750,7 +759,7 @@ def convert_to_azure_payload(url, payload: dict, api_version: str):
|
||||
allowed_params = get_azure_allowed_params(api_version)
|
||||
|
||||
# Special handling for o-series models
|
||||
if is_openai_reasoning_model(model):
|
||||
if is_openai_new_model(model):
|
||||
# Convert max_tokens to max_completion_tokens for o-series models
|
||||
if 'max_tokens' in payload:
|
||||
payload['max_completion_tokens'] = payload['max_tokens']
|
||||
@@ -1040,7 +1049,7 @@ async def generate_chat_completion(
|
||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
|
||||
# Check if model is a reasoning model that needs special handling
|
||||
if is_openai_reasoning_model(payload['model']):
|
||||
if is_openai_new_model(payload['model']):
|
||||
payload = openai_reasoning_model_handler(payload)
|
||||
elif 'api.openai.com' not in url:
|
||||
# Remove "max_completion_tokens" from the payload for backward compatibility
|
||||
|
||||
Reference in New Issue
Block a user