mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-02 02:09:17 -05:00
After migration to async db operations, the throttle decorator also needs to support async. Since the decorator is only used for async funcs now, we can just switch it to async instead of supporting sync and async at the same time. Signed-off-by: Adam Tao <tcx4c70@gmail.com>
1003 lines
33 KiB
Python
1003 lines
33 KiB
Python
import hashlib
|
|
import re
|
|
import threading
|
|
import time
|
|
import uuid
|
|
import logging
|
|
from datetime import timedelta
|
|
from pathlib import Path
|
|
from typing import Callable, Optional, Sequence, Union
|
|
import json
|
|
import aiohttp
|
|
import mimeparse
|
|
|
|
|
|
import collections.abc
|
|
from open_webui.env import CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE
|
|
|
|
# Module-level logger, named after this module so log config can filter it.
log = logging.getLogger(__name__)
|
|
|
|
|
|
def deep_update(d, u):
    """Recursively merge mapping ``u`` into dict ``d`` in place.

    Nested mappings are merged key-by-key; any non-mapping value in ``u``
    overwrites the corresponding entry in ``d``. Returns ``d``.
    """
    for key in u:
        value = u[key]
        if not isinstance(value, collections.abc.Mapping):
            d[key] = value
        else:
            # Merge into the existing sub-dict (or a fresh one).
            d[key] = deep_update(d.get(key, {}), value)
    return d
|
|
|
|
|
|
def get_allow_block_lists(filter_list):
    """Split a filter list into ``(allow_list, block_list)``.

    Entries prefixed with "!" are blocked (prefix removed); all other
    entries are allowed. Surrounding whitespace is stripped from each
    entry. A falsy ``filter_list`` yields two empty lists.
    """
    allow_list, block_list = [], []

    for entry in filter_list or []:
        if entry.startswith('!'):
            # "!" prefix marks a blocked entry
            block_list.append(entry[1:].strip())
        else:
            allow_list.append(entry.strip())

    return allow_list, block_list
|
|
|
|
|
|
def is_string_allowed(string: Union[str, Sequence[str]], filter_list: Optional[list[str]] = None) -> bool:
    """
    Check a string (or sequence of strings) against an allow/block filter list.

    :param string: The string or sequence of strings to check (e.g., domain or hostname).
    :param filter_list: List of allowed/blocked strings. Strings starting with "!" are blocked.
    :return: True if the string or sequence of strings is allowed, False otherwise.
    """
    if not filter_list:
        return True

    allow_list, block_list = get_allow_block_lists(filter_list)
    candidates = [string] if isinstance(string, str) else list(string)

    def matches_any(patterns):
        # Suffix match: any candidate ending with any pattern counts.
        return any(c.endswith(p) for c in candidates for p in patterns)

    # A non-empty allow list requires at least one allowed suffix match.
    if allow_list and not matches_any(allow_list):
        return False

    # The block list always removes matches.
    return not matches_any(block_list)
|
|
|
|
|
|
def get_message_list(messages_map, message_id):
    """
    Reconstruct the ordered chain of messages ending at ``message_id``.

    Follows ``parentId`` links from the given message back to the root,
    guarding against cycles, then returns the chain root-first.

    :param messages_map: Dict mapping message id -> message dict (may be falsy).
    :param message_id: ID of the final message in the chain.
    :return: List of messages from root to the given message; [] when the
        map is empty or the id is unknown (never None, so callers can iterate).
    """
    if not messages_map:
        return []

    node = messages_map.get(message_id)
    if not node:
        return []

    chain = []
    seen_ids = set()

    while node:
        node_id = node.get('id')
        if node_id in seen_ids:
            # Cycle detected — stop to avoid an infinite loop.
            break
        if node_id is not None:
            seen_ids.add(node_id)

        chain.append(node)
        parent_id = node.get('parentId')  # .get() tolerates missing key
        node = messages_map.get(parent_id) if parent_id else None

    # Collected leaf-to-root; return root-to-leaf.
    return list(reversed(chain))
|
|
|
|
|
|
def get_messages_content(messages: list[dict]) -> str:
    """Render messages as newline-separated ``ROLE: content`` lines."""
    rendered = [f'{m["role"].upper()}: {get_content_from_message(m)}' for m in messages]
    return '\n'.join(rendered)
|
|
|
|
|
|
def get_last_user_message_item(messages: list[dict]) -> Optional[dict]:
    """Return the most recent message with role 'user', or None."""
    return next((m for m in reversed(messages) if m['role'] == 'user'), None)
|
|
|
|
|
|
def get_content_from_message(message: dict) -> Optional[str]:
    """Extract a message's text: the plain-string content, or the first
    'text' part of a list-style (multimodal) content; None otherwise."""
    content = message.get('content')
    if not isinstance(content, list):
        return content
    for part in content:
        if part['type'] == 'text':
            return part['text']
    return None
|
|
|
|
|
|
def convert_output_to_messages(output: list, raw: bool = False) -> list[dict]:
    """
    Convert OR-aligned output items to OpenAI Chat Completion-format messages.

    This reconstructs the full conversation from the stored Responses API-native
    output items, including assistant messages with tool_calls arrays and tool
    role messages.

    Args:
        output: List of OR-aligned output items (Responses API format).
        raw: If True, include reasoning blocks (with original tags) and code
            interpreter blocks for LLM re-processing follow-ups.

    Returns:
        List of Chat Completion-format message dicts (assistant and tool roles).
        Empty list for falsy or non-list input.
    """
    if not output or not isinstance(output, list):
        return []

    messages = []
    # Consecutive assistant text / tool-call items accumulate here and are
    # flushed into a single assistant message when a tool result arrives or
    # the output ends.
    pending_tool_calls = []
    pending_content = []

    def flush_pending():
        # Emit the accumulated assistant message (if any) and reset buffers.
        nonlocal pending_content, pending_tool_calls
        if pending_content or pending_tool_calls:
            messages.append(
                {
                    'role': 'assistant',
                    'content': '\n'.join(pending_content) if pending_content else '',
                    **({'tool_calls': pending_tool_calls} if pending_tool_calls else {}),
                }
            )
            pending_content = []
            pending_tool_calls = []

    for item in output:
        item_type = item.get('type', '')

        if item_type == 'message':
            # Extract text from output_text content parts
            content_parts = item.get('content', [])
            text = ''
            for part in content_parts:
                if part.get('type') == 'output_text':
                    text += part.get('text', '')
            if text:
                pending_content.append(text)

        elif item_type == 'function_call':
            # Collect tool calls to batch into assistant message
            arguments = item.get('arguments', '{}')
            # Ensure arguments is always a JSON string
            if not isinstance(arguments, str):
                arguments = json.dumps(arguments)
            pending_tool_calls.append(
                {
                    'id': item.get('call_id', ''),
                    'type': 'function',
                    'function': {
                        'name': item.get('name', ''),
                        'arguments': arguments,
                    },
                }
            )

        elif item_type == 'function_call_output':
            # Flush any pending content/tool_calls before adding tool result
            flush_pending()

            # Extract text and images from output content parts
            output_parts = item.get('output', [])
            content = ''
            image_urls = []
            for part in output_parts:
                if part.get('type') == 'input_text':
                    output_text = part.get('text', '')
                    # Coerce non-string payloads so += cannot raise.
                    content += str(output_text) if not isinstance(output_text, str) else output_text
                elif part.get('type') == 'input_image':
                    url = part.get('image_url', '')
                    if url:
                        image_urls.append(url)

            if image_urls:
                # Multimodal tool content with image(s)
                messages.append(
                    {
                        'role': 'tool',
                        'tool_call_id': item.get('call_id', ''),
                        'content': [
                            {'type': 'input_text', 'text': content},
                            *[{'type': 'input_image', 'image_url': url} for url in image_urls],
                        ],
                    }
                )
            else:
                messages.append(
                    {
                        'role': 'tool',
                        'tool_call_id': item.get('call_id', ''),
                        'content': content,
                    }
                )

        elif item_type == 'reasoning':
            if raw:
                # Include reasoning with original tags for LLM re-processing
                reasoning_text = ''
                # Prefer the summary parts; fall back to full content parts.
                source_list = item.get('summary', []) or item.get('content', [])
                for part in source_list:
                    if part.get('type') == 'output_text':
                        reasoning_text += part.get('text', '')
                    elif 'text' in part:
                        reasoning_text += part.get('text', '')

                if reasoning_text:
                    start_tag = item.get('start_tag', '<think>')
                    end_tag = item.get('end_tag', '</think>')
                    pending_content.append(f'{start_tag}{reasoning_text}{end_tag}')
            # NOTE: Some providers (e.g. Moonshot/Kimi K2.5) require
            # reasoning_content as a top-level field on assistant
            # messages. This should be handled externally via a
            # pipeline filter or connection-level middleware, not
            # here — adding it universally breaks strict providers
            # (OpenAI, Vertex AI, Azure) that reject unknown fields.
            # else: skip reasoning blocks for normal LLM messages

        elif item_type == 'open_webui:code_interpreter':
            # Always include code interpreter content so the LLM knows
            # the code was already executed and doesn't retry.
            code = item.get('code', '')
            code_output = item.get('output', '')

            if code:
                pending_content.append(f'<code_interpreter>\n{code}\n</code_interpreter>')

            if code_output:
                if isinstance(code_output, dict):
                    # Structured output: prefer stdout, fall back to result.
                    stdout = code_output.get('stdout', '')
                    result = code_output.get('result', '')
                    output_text = stdout or result
                else:
                    output_text = str(code_output)
                if output_text:
                    pending_content.append(f'<code_interpreter_output>\n{output_text}\n</code_interpreter_output>')

        elif item_type.startswith('open_webui:'):
            # Skip other extension types
            pass

    # Flush remaining content/tool_calls
    flush_pending()

    return messages
|
|
|
|
|
|
def get_last_user_message(messages: list[dict]) -> Optional[str]:
    """Return the text content of the most recent user message, or None."""
    item = get_last_user_message_item(messages)
    return None if item is None else get_content_from_message(item)
|
|
|
|
|
|
def set_last_user_message_content(content: str, messages: list[dict]) -> list[dict]:
    """
    Replace the text content of the last user message in-place.
    Handles both plain-string and list-of-parts content formats; for the
    list form only the first text part is replaced.
    """
    for message in reversed(messages):
        if message.get('role') != 'user':
            continue

        body = message.get('content')
        if isinstance(body, list):
            for part in body:
                if part.get('type') == 'text':
                    part['text'] = content
                    break
        else:
            message['content'] = content
        break

    return messages
|
|
|
|
|
|
def get_last_assistant_message_item(messages: list[dict]) -> Optional[dict]:
    """Return the most recent message with role 'assistant', or None."""
    return next((m for m in reversed(messages) if m['role'] == 'assistant'), None)
|
|
|
|
|
|
def get_last_assistant_message(messages: list[dict]) -> Optional[str]:
    """
    Return the text content of the most recent assistant message, or None.

    Delegates to get_last_assistant_message_item, mirroring
    get_last_user_message, instead of duplicating the reverse scan and
    content extraction inline.
    """
    message = get_last_assistant_message_item(messages)
    if message is None:
        return None
    return get_content_from_message(message)
|
|
|
|
|
|
def get_system_message(messages: list[dict]) -> Optional[dict]:
    """Return the first message with role 'system', or None."""
    return next((m for m in messages if m['role'] == 'system'), None)
|
|
|
|
|
|
def remove_system_message(messages: list[dict]) -> list[dict]:
    """Return a new list containing every non-system message, in order."""
    return list(filter(lambda m: m['role'] != 'system', messages))
|
|
|
|
|
|
def pop_system_message(messages: list[dict]) -> tuple[Optional[dict], list[dict]]:
    # Return (first system message or None, messages with system messages
    # filtered out). The input list itself is not mutated — a filtered
    # copy is produced by remove_system_message.
    return get_system_message(messages), remove_system_message(messages)
|
|
|
|
|
|
def merge_system_messages(messages: list[dict]) -> list[dict]:
    """
    Consolidate every system message into a single one at index 0.

    Some chat templates (e.g. Qwen) require exactly one system message at
    the start, while multiple pipeline stages may each insert their own.
    System contents are joined with newlines; non-system messages keep
    their relative order. If no system content exists, the non-system
    messages are returned as-is.
    """
    system_parts: list[str] = []
    rest: list[dict] = []

    for message in messages:
        if message.get('role') != 'system':
            rest.append(message)
            continue
        text = get_content_from_message(message)
        if text:
            system_parts.append(text)

    if not system_parts:
        return rest

    return [{'role': 'system', 'content': '\n'.join(system_parts)}, *rest]
|
|
|
|
|
|
def update_message_content(message: dict, content: str, append: bool = True) -> dict:
    """
    Merge ``content`` into a message's text, in place.

    For list-style (multimodal) content only the first text part is
    updated — consistent with set_last_user_message_content — instead of
    duplicating the merged content into every text part.

    :param message: The message dict to update (mutated and returned).
    :param content: Text to merge into the existing content.
    :param append: If True, add after the existing text; otherwise before.
    :return: The same message dict.
    """
    if isinstance(message['content'], list):
        for item in message['content']:
            if item['type'] == 'text':
                if append:
                    item['text'] = f'{item["text"]}\n{content}'
                else:
                    item['text'] = f'{content}\n{item["text"]}'
                # Only the first text part carries the merged content.
                break
    else:
        if append:
            message['content'] = f'{message["content"]}\n{content}'
        else:
            message['content'] = f'{content}\n{message["content"]}'
    return message
|
|
|
|
|
|
def replace_system_message_content(content: str, messages: list[dict]) -> list[dict]:
    """
    Overwrite the content of the first system message, in place.

    If no system message exists the list is returned unchanged. The
    return annotation previously said ``dict`` although the function
    has always returned the (mutated) list of messages.

    :param content: New content for the system message.
    :param messages: The list of message dictionaries (mutated in place).
    :return: The same list of messages.
    """
    for message in messages:
        if message['role'] == 'system':
            message['content'] = content
            break
    return messages
|
|
|
|
|
|
def add_or_update_system_message(content: str, messages: list[dict], append: bool = False):
    """
    Adds a new system message at the beginning of the messages list
    or updates the existing system message at the beginning.

    :param content: The text to be added or merged into the system message.
    :param messages: The list of message dictionaries (mutated in place).
    :param append: If True, merge after the existing content; otherwise before.
    :return: The updated list of message dictionaries.
    """

    if messages and messages[0].get('role') == 'system':
        messages[0] = update_message_content(messages[0], content, append)
    else:
        # Insert at the beginning
        messages.insert(0, {'role': 'system', 'content': content})

    return messages
|
|
|
|
|
|
def add_or_update_user_message(content: str, messages: list[dict], append: bool = True):
    """
    Adds a new user message at the end of the messages list
    or updates the existing user message at the end.

    :param content: The text to be added or merged into the user message.
    :param messages: The list of message dictionaries (mutated in place).
    :param append: If True, merge after the existing content; otherwise before.
    :return: The updated list of message dictionaries.
    """

    if messages and messages[-1].get('role') == 'user':
        messages[-1] = update_message_content(messages[-1], content, append)
    else:
        # Insert at the end
        messages.append({'role': 'user', 'content': content})

    return messages
|
|
|
|
|
|
def prepend_to_first_user_message_content(content: str, messages: list[dict]) -> list[dict]:
    """Prepend ``content`` to the first user message's text, in place.

    Leaves the list unchanged when it contains no user message.
    """
    for msg in messages:
        if msg['role'] != 'user':
            continue
        update_message_content(msg, content, append=False)
        break
    return messages
|
|
|
|
|
|
def append_or_update_assistant_message(content: str, messages: list[dict]):
    """
    Append ``content`` to the trailing assistant message, or add a new
    assistant message when the list does not end with one.

    :param content: The text to be added or appended.
    :param messages: The list of message dictionaries (mutated in place).
    :return: The updated list of message dictionaries.
    """
    if not (messages and messages[-1].get('role') == 'assistant'):
        # No trailing assistant message: start a fresh one at the end.
        messages.append({'role': 'assistant', 'content': content})
    else:
        last = messages[-1]
        last['content'] = f'{last["content"]}\n{content}'
    return messages
|
|
|
|
|
|
def strip_empty_content_blocks(messages: list[dict]) -> list[dict]:
    """
    Drop empty text blocks from multimodal (list-style) message content.

    Providers like Gemini and Claude reject messages where a text block
    holds an empty string, which happens when a user sends only file or
    image attachments. If removing empty text blocks would leave no
    blocks at all, the content is left untouched.
    """

    def _is_empty_text(block):
        return isinstance(block, dict) and block.get('type') == 'text' and not block.get('text', '').strip()

    for message in messages:
        content = message.get('content')
        if not isinstance(content, list):
            continue
        kept = [block for block in content if not _is_empty_text(block)]
        if kept:
            message['content'] = kept
    return messages
|
|
|
|
|
|
def openai_chat_message_template(model: str):
    """Base skeleton for an OpenAI-style chat completion response dict."""
    return {
        'id': f'{model}-{uuid.uuid4()}',
        'created': int(time.time()),
        'model': model,
        'choices': [{'index': 0, 'logprobs': None, 'finish_reason': None}],
    }
|
|
|
|
|
|
def openai_chat_chunk_message_template(
    model: str,
    content: Optional[str] = None,
    reasoning_content: Optional[str] = None,
    tool_calls: Optional[list[dict]] = None,
    usage: Optional[dict] = None,
) -> dict:
    """
    Build a 'chat.completion.chunk' payload whose delta carries any of
    content / reasoning_content / tool_calls. A chunk carrying none of
    the three is treated as the final chunk (finish_reason='stop').
    """
    template = openai_chat_message_template(model)
    template['object'] = 'chat.completion.chunk'

    choice = template['choices'][0]
    choice['index'] = 0

    delta = {}
    if content:
        delta['content'] = content
    if reasoning_content:
        delta['reasoning_content'] = reasoning_content
    if tool_calls:
        delta['tool_calls'] = tool_calls
    choice['delta'] = delta

    # An empty delta marks the terminating chunk.
    if not delta:
        choice['finish_reason'] = 'stop'

    if usage:
        template['usage'] = usage
    return template
|
|
|
|
|
|
def openai_chat_completion_message_template(
    model: str,
    message: Optional[str] = None,
    reasoning_content: Optional[str] = None,
    tool_calls: Optional[list[dict]] = None,
    usage: Optional[dict] = None,
) -> dict:
    """
    Build a non-streaming 'chat.completion' payload. finish_reason is
    'tool_calls' when tool calls are present, otherwise 'stop'.
    """
    template = openai_chat_message_template(model)
    template['object'] = 'chat.completion'

    choice = template['choices'][0]
    if message is not None:
        assistant_message = {'role': 'assistant', 'content': message}
        if reasoning_content:
            assistant_message['reasoning_content'] = reasoning_content
        if tool_calls:
            assistant_message['tool_calls'] = tool_calls
        choice['message'] = assistant_message

    choice['finish_reason'] = 'tool_calls' if tool_calls else 'stop'

    if usage:
        template['usage'] = usage
    return template
|
|
|
|
|
|
def get_gravatar_url(email):
    """Return the Gravatar avatar URL for an email address.

    The address is trimmed and lower-cased, hashed with SHA-256, and the
    'mp' (mystery person) default image is requested.
    """
    normalized = str(email).strip().lower()
    digest = hashlib.sha256(normalized.encode()).hexdigest()
    return f'https://www.gravatar.com/avatar/{digest}?d=mp'
|
|
|
|
|
|
# Give us each day the data we require, and forgive us our
|
|
# technical debts as we forgive those who commit upstream.
|
|
# Lead the bits not into corruption but deliver them from
|
|
# entropy, for the checksum and the glory are forever.
|
|
def calculate_sha256(file_path, chunk_size):
    """Compute the SHA-256 hex digest of a file, reading in chunks.

    :param file_path: Path of the file to hash.
    :param chunk_size: Number of bytes to read per iteration.
    :return: Hex digest string.
    """
    digest = hashlib.sha256()
    with open(file_path, 'rb') as handle:
        # iter(..., b'') stops at EOF without loading the whole file.
        for block in iter(lambda: handle.read(chunk_size), b''):
            digest.update(block)
    return digest.hexdigest()
|
|
|
|
|
|
def calculate_sha256_string(string):
    """Return the SHA-256 hex digest of the UTF-8 encoding of ``string``."""
    return hashlib.sha256(string.encode('utf-8')).hexdigest()
|
|
|
|
|
|
def validate_email_format(email: str) -> bool:
    """
    Validate that a string looks like an email address.

    Accepts any '@localhost' address. Otherwise requires
    local@domain.tld with no whitespace, matched against the WHOLE
    string — the previous unanchored re.match accepted trailing garbage
    such as 'a@b.c junk'.

    :param email: The candidate email address.
    :return: True if the format is acceptable.
    """
    if email.endswith('@localhost'):
        return True

    return bool(re.fullmatch(r'[^@\s]+@[^@\s]+\.[^@\s]+', email))
|
|
|
|
|
|
def sanitize_filename(file_name):
    """Lowercase a filename, drop special characters, and join words with dashes."""
    lowered = file_name.lower()
    # Keep only word characters and whitespace.
    without_specials = re.sub(r'[^\w\s]', '', lowered)
    # Collapse each run of whitespace into a single dash.
    return re.sub(r'\s+', '-', without_specials)
|
|
|
|
|
|
def sanitize_text_for_db(text: str) -> str:
    """Strip characters PostgreSQL cannot store: NUL bytes and lone
    UTF-8 surrogates (e.g. introduced by binary data or mis-decoded input).
    Non-string values pass through unchanged."""
    if not isinstance(text, str):
        return text

    cleaned = text.replace('\x00', '').replace('\u0000', '')
    try:
        # Round-trip through UTF-8 to drop invalid surrogate code points.
        return cleaned.encode('utf-8', errors='surrogatepass').decode('utf-8', errors='ignore')
    except (UnicodeEncodeError, UnicodeDecodeError):
        return cleaned
|
|
|
|
|
|
def sanitize_data_for_db(obj):
    """Recursively sanitize every string inside a nested structure of
    dicts and lists for database storage; other values pass through."""
    if isinstance(obj, dict):
        return {key: sanitize_data_for_db(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [sanitize_data_for_db(value) for value in obj]
    if isinstance(obj, str):
        return sanitize_text_for_db(obj)
    return obj
|
|
|
|
|
|
def sanitize_metadata(metadata: dict) -> dict:
    """
    Return a JSON-safe copy of a metadata dict for database storage.

    The middleware metadata accumulates non-serializable Python objects
    (e.g. callable tool functions, MCP client instances) that cause
    PostgreSQL JSON inserts to fail. This helper strips those out while
    preserving the primitive data needed for file-to-chat linking.
    """
    # Non-dict input is returned untouched.
    if not isinstance(metadata, dict):
        return metadata

    def _sanitize(obj):
        # Recursively copy obj, dropping callables and any value that
        # cannot survive JSON serialization (replaced with None at the
        # top level, silently omitted inside dicts/lists).
        if isinstance(obj, (str, int, float, bool, type(None))):
            return obj
        if isinstance(obj, dict):
            return {k: _sanitize(v) for k, v in obj.items() if not callable(v) and _is_serializable(v)}
        if isinstance(obj, list):
            return [_sanitize(v) for v in obj if not callable(v) and _is_serializable(v)]
        if callable(obj):
            return None
        # Last resort: try to see if it's serializable
        try:
            json.dumps(obj)
            return obj
        except (TypeError, ValueError):
            return None

    def _is_serializable(obj):
        """Quick check whether a value can survive JSON serialization."""
        # Containers pass the quick check; their members are vetted when
        # _sanitize recurses into them.
        if isinstance(obj, (str, int, float, bool, type(None), dict, list)):
            return True
        try:
            json.dumps(obj)
            return True
        except (TypeError, ValueError):
            return False

    return _sanitize(metadata)
|
|
|
|
|
|
def extract_folders_after_data_docs(path):
    """
    Return cumulative folder tags for the segments between 'data/docs'
    and the filename, e.g. 'data/docs/a/b/f.txt' -> ['a', 'a/b'].
    Returns [] when the path has no 'data' ... 'docs' sequence.
    """
    parts = Path(path).parts

    try:
        data_index = parts.index('data') + 1
        docs_index = parts.index('docs', data_index) + 1
    except ValueError:
        # 'data' or a following 'docs' segment is missing.
        return []

    # Everything between 'docs' and the filename, accumulated progressively.
    folders = parts[docs_index:-1]
    return ['/'.join(folders[: n + 1]) for n in range(len(folders))]
|
|
|
|
|
|
def parse_duration(duration: str) -> Optional[timedelta]:
    """
    Parse a duration string like '1h30m' into a timedelta.

    Supported units: ms, s, m, h, d, w; values may be negative or
    fractional. Returns None for '-1' or '0' (meaning "no expiry").
    Raises ValueError when no number/unit pairs are found.
    """
    if duration in ('-1', '0'):
        return None

    matches = re.findall(r'(-?\d+(\.\d+)?)(ms|s|m|h|d|w)', duration)
    if not matches:
        raise ValueError('Invalid duration string')

    # Map each unit suffix to the matching timedelta keyword.
    unit_kwargs = {
        'ms': 'milliseconds',
        's': 'seconds',
        'm': 'minutes',
        'h': 'hours',
        'd': 'days',
        'w': 'weeks',
    }

    total = timedelta()
    for amount, _, unit in matches:
        total += timedelta(**{unit_kwargs[unit]: float(amount)})
    return total
|
|
|
|
|
|
def parse_ollama_modelfile(model_text):
    """
    Parse an Ollama Modelfile into a base model id and parameter dict.

    Recognizes FROM, TEMPLATE, PARAMETER (typed via ``parameters_meta``),
    ADAPTER, SYSTEM (triple-quoted or single-line) and MESSAGE directives.

    :param model_text: Raw Modelfile text.
    :return: Dict with 'base_model_id' (str or None) and 'params' (dict).
    """
    # Expected Python type for each supported PARAMETER directive.
    parameters_meta = {
        'mirostat': int,
        'mirostat_eta': float,
        'mirostat_tau': float,
        'num_ctx': int,
        'repeat_last_n': int,
        'repeat_penalty': float,
        'temperature': float,
        'seed': int,
        'tfs_z': float,
        'num_predict': int,
        'top_k': int,
        'top_p': float,
        'num_keep': int,
        'presence_penalty': float,
        'frequency_penalty': float,
        'num_batch': int,
        'num_gpu': int,
        'use_mmap': bool,
        'use_mlock': bool,
        'num_thread': int,
    }

    data = {'base_model_id': None, 'params': {}}

    # Parse base model
    base_model_match = re.search(r'^FROM\s+(\w+)', model_text, re.MULTILINE | re.IGNORECASE)
    if base_model_match:
        data['base_model_id'] = base_model_match.group(1)

    # Parse template
    template_match = re.search(r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE)
    if template_match:
        data['params'] = {'template': template_match.group(1).strip()}

    # Parse stops
    stops = re.findall(r'PARAMETER stop "(.*?)"', model_text, re.IGNORECASE)
    if stops:
        data['params']['stop'] = stops

    # Parse other parameters from the provided list
    for param, param_type in parameters_meta.items():
        param_match = re.search(rf'PARAMETER {param} (.+)', model_text, re.IGNORECASE)
        if param_match:
            value = param_match.group(1)

            try:
                # Coerce the raw string to the declared type.
                if param_type is int:
                    value = int(value)
                elif param_type is float:
                    value = float(value)
                elif param_type is bool:
                    value = value.lower() == 'true'
            except Exception as e:
                # Unparseable values are logged and skipped, not fatal.
                log.exception(f'Failed to parse parameter {param}: {e}')
                continue

            data['params'][param] = value

    # Parse adapter
    adapter_match = re.search(r'ADAPTER (.+)', model_text, re.IGNORECASE)
    if adapter_match:
        data['params']['adapter'] = adapter_match.group(1)

    # Parse system description
    system_desc_match = re.search(r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE)
    system_desc_match_single = re.search(r'SYSTEM\s+([^\n]+)', model_text, re.IGNORECASE)

    # The triple-quoted SYSTEM form takes precedence over the single-line form.
    if system_desc_match:
        data['params']['system'] = system_desc_match.group(1).strip()
    elif system_desc_match_single:
        data['params']['system'] = system_desc_match_single.group(1).strip()

    # Parse messages
    messages = []
    message_matches = re.findall(r'MESSAGE (\w+) (.+)', model_text, re.IGNORECASE)
    for role, content in message_matches:
        messages.append({'role': role, 'content': content})

    if messages:
        data['params']['messages'] = messages

    return data
|
|
|
|
|
|
def convert_logit_bias_input_to_json(logit_bias_input) -> Optional[str]:
    """
    Normalize logit-bias input into a JSON string.

    Accepts either a dict (serialized as-is) or a comma-separated
    'token:bias' string; parsed bias values are clamped to [-100, 100].
    Returns None for empty/falsy input.
    """
    if not logit_bias_input:
        return None

    if isinstance(logit_bias_input, dict):
        return json.dumps(logit_bias_input)

    result = {}
    for pair in logit_bias_input.split(','):
        token, bias = pair.split(':')
        clamped = max(-100, min(100, int(bias.strip())))
        result[str(token.strip())] = clamped
    return json.dumps(result)
|
|
|
|
|
|
def freeze(value):
    """
    Recursively convert a value into a hashable equivalent: dicts become
    frozensets of (key, frozen-value) pairs, lists become tuples, and
    everything else is returned unchanged.
    """
    if isinstance(value, list):
        return tuple(freeze(item) for item in value)
    if isinstance(value, dict):
        return frozenset((key, freeze(val)) for key, val in value.items())
    return value
|
|
|
|
|
|
def throttle(interval: float = 10.0):
    """
    Decorator preventing an async function from running more than once per
    ``interval`` seconds for a given set of arguments.

    Suppressed calls return None, so the wrapped function's return type
    should be Optional[T]. Passing interval=None disables throttling.

    :param interval: Duration in seconds to wait before allowing the function to be called again.
    """

    def decorator(func):
        from functools import wraps  # preserve func's name/docstring on the wrapper

        last_calls = {}
        lock = threading.Lock()

        @wraps(func)
        async def wrapper(*args, **kwargs):
            if interval is None:
                return await func(*args, **kwargs)

            # freeze(kwargs) makes the kwargs dict hashable for the key.
            key = (args, freeze(kwargs))

            # Cheap unlocked pre-check, then re-check under the lock so
            # two concurrent callers cannot both claim the same slot.
            now = time.time()
            if now - last_calls.get(key, 0) < interval:
                return None
            with lock:
                # Refresh the timestamp: the pre-check value may be stale
                # by the time the lock is acquired.
                now = time.time()
                if now - last_calls.get(key, 0) < interval:
                    return None
                last_calls[key] = now
            # Await OUTSIDE the lock: holding a threading.Lock across an
            # await would block the event loop's other tasks.
            return await func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
|
|
def strict_match_mime_type(supported: list[str] | str, header: str) -> Optional[str]:
    """
    Strictly match the mime type with the supported mime types.

    Unlike a plain mimeparse.best_match, every parameter on the matched
    type (e.g. codecs=...) must also appear in the header with an equal
    value, otherwise the match is rejected.

    :param supported: The supported mime types.
    :param header: The header to match.
    :return: The matched mime type or None if no match is found.
    """

    try:
        if isinstance(supported, str):
            supported = supported.split(',')

        # Drop blanks and entries without a type/subtype separator.
        supported = [s for s in supported if s.strip() and '/' in s]

        if len(supported) == 0:
            # Default to common types if none are specified
            supported = ['audio/*', 'video/webm']

        match = mimeparse.best_match(supported, header)
        if not match:
            return None

        # Strict parameter check: each parameter of the matched type must
        # be present in the header with the same value.
        _, _, match_params = mimeparse.parse_mime_type(match)
        _, _, header_params = mimeparse.parse_mime_type(header)
        for k, v in match_params.items():
            if header_params.get(k) != v:
                return None

        return match
    except Exception as e:
        # Malformed headers raise inside mimeparse; treat as "no match".
        log.exception(f'Failed to match mime type {header}: {e}')
        return None
|
|
|
|
|
|
def extract_urls(text: str) -> list[str]:
    """Return all http/https URLs found in ``text`` (whitespace-delimited)."""
    return re.findall(r'(https?://[^\s]+)', text, re.IGNORECASE)
|
|
|
|
|
|
# We believe in one architect of all that is seen and served.
|
|
# Should this stream falter, it shall be raised again on the
|
|
# third retry. We look for the uptime of the world to come.
|
|
async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    """
    Close an aiohttp response and session if present and still open.

    aiohttp 3.9+ made ClientResponse.close() synchronous (returns None),
    while older versions returned a coroutine — await only when a
    coroutine actually comes back, so both versions work.
    """
    if response and not response.closed:
        maybe_coro = response.close()
        if maybe_coro is not None:
            await maybe_coro

    if session and not session.closed:
        maybe_coro = session.close()
        if maybe_coro is not None:
            await maybe_coro
|
|
|
|
|
|
async def stream_wrapper(response, session, content_handler=None):
    """
    Wrap a stream to ensure cleanup happens even if streaming is interrupted.
    This is more reliable than BackgroundTask which may not run if client disconnects.

    :param response: aiohttp response whose ``content`` is streamed.
    :param session: aiohttp session owning the response; closed on exit.
    :param content_handler: Optional callable that adapts the raw content
        stream (e.g. re-chunking) before iteration.
    """
    try:
        # content_handler (when given) wraps the raw content stream.
        stream = content_handler(response.content) if content_handler else response.content
        async for chunk in stream:
            yield chunk
    finally:
        # Always release the connection, even on cancellation/disconnect.
        await cleanup_response(response, session)
|
|
|
|
|
|
def stream_chunks_handler(stream: aiohttp.StreamReader):
    """
    Handle stream response chunks, supporting large data chunks that exceed the original 16kb limit.
    When a single line exceeds max_buffer_size, returns an empty JSON string {} and skips subsequent data
    until encountering normally sized data.

    :param stream: The stream reader to handle.
    :return: An async generator that yields the stream data, or the raw
        stream unchanged when no buffer limit is configured.
    """

    max_buffer_size = CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE
    # No limit configured: pass the raw stream straight through.
    if max_buffer_size is None or max_buffer_size <= 0:
        return stream

    async def yield_safe_stream_chunks():
        # buffer holds the trailing incomplete line fragment between chunks;
        # skip_mode is True while an oversized line is being discarded.
        buffer = b''
        skip_mode = False

        async for data, _ in stream.iter_chunks():
            if not data:
                continue

            # In skip_mode, if buffer already exceeds the limit, clear it (it's part of an oversized line)
            if skip_mode and len(buffer) > max_buffer_size:
                buffer = b''

            lines = (buffer + data).split(b'\n')

            # Process complete lines (except the last possibly incomplete fragment)
            for i in range(len(lines) - 1):
                line = lines[i]

                if skip_mode:
                    # Skip mode: check if current line is small enough to exit skip mode
                    if len(line) <= max_buffer_size:
                        skip_mode = False
                        # NOTE(review): this yield omits the b'\n' terminator
                        # that normal mode appends — confirm intentional.
                        yield line
                    else:
                        yield b'data: {}\n'
                else:
                    # Normal mode: check if line exceeds limit
                    if len(line) > max_buffer_size:
                        skip_mode = True
                        yield b'data: {}\n'
                        log.info(f'Skip mode triggered, line size: {len(line)}')
                    else:
                        yield line + b'\n'

            # Save the last incomplete fragment
            buffer = lines[-1]

            # Check if buffer exceeds limit
            if not skip_mode and len(buffer) > max_buffer_size:
                skip_mode = True
                log.info(f'Skip mode triggered, buffer size: {len(buffer)}')
                # Clear oversized buffer to prevent unlimited growth
                buffer = b''

        # Process remaining buffer data
        if buffer and not skip_mode:
            yield buffer + b'\n'

    return yield_safe_stream_chunks()
|