This commit is contained in:
Timothy Jaeryang Baek
2026-03-17 17:58:01 -05:00
parent fcf7208352
commit de3317e26b
220 changed files with 17200 additions and 22836 deletions

View File

@@ -33,7 +33,7 @@ def get_allow_block_lists(filter_list):
if filter_list:
for d in filter_list:
if d.startswith("!"):
if d.startswith('!'):
# Domains starting with "!" → blocked
block_list.append(d[1:].strip())
else:
@@ -43,9 +43,7 @@ def get_allow_block_lists(filter_list):
return allow_list, block_list
def is_string_allowed(
string: Union[str, Sequence[str]], filter_list: Optional[list[str]] = None
) -> bool:
def is_string_allowed(string: Union[str, Sequence[str]], filter_list: Optional[list[str]] = None) -> bool:
"""
Checks if a string is allowed based on the provided filter list.
:param string: The string or sequence of strings to check (e.g., domain or hostname).
@@ -94,7 +92,7 @@ def get_message_list(messages_map, message_id):
visited_message_ids = set()
while current_message:
message_id = current_message.get("id")
message_id = current_message.get('id')
if message_id in visited_message_ids:
# Cycle detected, break to prevent infinite loop
break
@@ -103,7 +101,7 @@ def get_message_list(messages_map, message_id):
visited_message_ids.add(message_id)
message_list.append(current_message)
parent_id = current_message.get("parentId") # Use .get() for safety
parent_id = current_message.get('parentId') # Use .get() for safety
current_message = messages_map.get(parent_id) if parent_id else None
message_list.reverse()
@@ -111,28 +109,23 @@ def get_message_list(messages_map, message_id):
def get_messages_content(messages: list[dict]) -> str:
return "\n".join(
[
f"{message['role'].upper()}: {get_content_from_message(message)}"
for message in messages
]
)
return '\n'.join([f'{message["role"].upper()}: {get_content_from_message(message)}' for message in messages])
def get_last_user_message_item(messages: list[dict]) -> Optional[dict]:
for message in reversed(messages):
if message["role"] == "user":
if message['role'] == 'user':
return message
return None
def get_content_from_message(message: dict) -> Optional[str]:
if isinstance(message.get("content"), list):
for item in message["content"]:
if item["type"] == "text":
return item["text"]
if isinstance(message.get('content'), list):
for item in message['content']:
if item['type'] == 'text':
return item['text']
else:
return message.get("content")
return message.get('content')
return None
@@ -161,111 +154,101 @@ def convert_output_to_messages(output: list, raw: bool = False) -> list[dict]:
if pending_content or pending_tool_calls:
messages.append(
{
"role": "assistant",
"content": "\n".join(pending_content) if pending_content else "",
**(
{"tool_calls": pending_tool_calls} if pending_tool_calls else {}
),
'role': 'assistant',
'content': '\n'.join(pending_content) if pending_content else '',
**({'tool_calls': pending_tool_calls} if pending_tool_calls else {}),
}
)
pending_content = []
pending_tool_calls = []
for item in output:
item_type = item.get("type", "")
item_type = item.get('type', '')
if item_type == "message":
if item_type == 'message':
# Extract text from output_text content parts
content_parts = item.get("content", [])
text = ""
content_parts = item.get('content', [])
text = ''
for part in content_parts:
if part.get("type") == "output_text":
text += part.get("text", "")
if part.get('type') == 'output_text':
text += part.get('text', '')
if text:
pending_content.append(text)
elif item_type == "function_call":
elif item_type == 'function_call':
# Collect tool calls to batch into assistant message
arguments = item.get("arguments", "{}")
arguments = item.get('arguments', '{}')
# Ensure arguments is always a JSON string
if not isinstance(arguments, str):
arguments = json.dumps(arguments)
pending_tool_calls.append(
{
"id": item.get("call_id", ""),
"type": "function",
"function": {
"name": item.get("name", ""),
"arguments": arguments,
'id': item.get('call_id', ''),
'type': 'function',
'function': {
'name': item.get('name', ''),
'arguments': arguments,
},
}
)
elif item_type == "function_call_output":
elif item_type == 'function_call_output':
# Flush any pending content/tool_calls before adding tool result
flush_pending()
# Extract text from output content parts
output_parts = item.get("output", [])
content = ""
output_parts = item.get('output', [])
content = ''
for part in output_parts:
if part.get("type") == "input_text":
output_text = part.get("text", "")
content += (
str(output_text)
if not isinstance(output_text, str)
else output_text
)
if part.get('type') == 'input_text':
output_text = part.get('text', '')
content += str(output_text) if not isinstance(output_text, str) else output_text
messages.append(
{
"role": "tool",
"tool_call_id": item.get("call_id", ""),
"content": content,
'role': 'tool',
'tool_call_id': item.get('call_id', ''),
'content': content,
}
)
elif item_type == "reasoning":
elif item_type == 'reasoning':
if raw:
# Include reasoning with original tags for LLM re-processing
reasoning_text = ""
source_list = item.get("summary", []) or item.get("content", [])
reasoning_text = ''
source_list = item.get('summary', []) or item.get('content', [])
for part in source_list:
if part.get("type") == "output_text":
reasoning_text += part.get("text", "")
elif "text" in part:
reasoning_text += part.get("text", "")
if part.get('type') == 'output_text':
reasoning_text += part.get('text', '')
elif 'text' in part:
reasoning_text += part.get('text', '')
if reasoning_text:
start_tag = item.get("start_tag", "<think>")
end_tag = item.get("end_tag", "</think>")
pending_content.append(f"{start_tag}{reasoning_text}{end_tag}")
start_tag = item.get('start_tag', '<think>')
end_tag = item.get('end_tag', '</think>')
pending_content.append(f'{start_tag}{reasoning_text}{end_tag}')
# else: skip reasoning blocks for normal LLM messages
elif item_type == "open_webui:code_interpreter":
elif item_type == 'open_webui:code_interpreter':
# Always include code interpreter content so the LLM knows
# the code was already executed and doesn't retry.
code = item.get("code", "")
code_output = item.get("output", "")
code = item.get('code', '')
code_output = item.get('output', '')
if code:
pending_content.append(
f"<code_interpreter>\n{code}\n</code_interpreter>"
)
pending_content.append(f'<code_interpreter>\n{code}\n</code_interpreter>')
if code_output:
if isinstance(code_output, dict):
stdout = code_output.get("stdout", "")
result = code_output.get("result", "")
stdout = code_output.get('stdout', '')
result = code_output.get('result', '')
output_text = stdout or result
else:
output_text = str(code_output)
if output_text:
pending_content.append(
f"<code_interpreter_output>\n{output_text}\n</code_interpreter_output>"
)
pending_content.append(f'<code_interpreter_output>\n{output_text}\n</code_interpreter_output>')
elif item_type.startswith("open_webui:"):
elif item_type.startswith('open_webui:'):
# Skip other extension types
pass
@@ -288,41 +271,41 @@ def set_last_user_message_content(content: str, messages: list[dict]) -> list[di
Handles both plain-string and list-of-parts content formats.
"""
for message in reversed(messages):
if message.get("role") == "user":
if isinstance(message.get("content"), list):
for item in message["content"]:
if item.get("type") == "text":
item["text"] = content
if message.get('role') == 'user':
if isinstance(message.get('content'), list):
for item in message['content']:
if item.get('type') == 'text':
item['text'] = content
break
else:
message["content"] = content
message['content'] = content
break
return messages
def get_last_assistant_message_item(messages: list[dict]) -> Optional[dict]:
for message in reversed(messages):
if message["role"] == "assistant":
if message['role'] == 'assistant':
return message
return None
def get_last_assistant_message(messages: list[dict]) -> Optional[str]:
for message in reversed(messages):
if message["role"] == "assistant":
if message['role'] == 'assistant':
return get_content_from_message(message)
return None
def get_system_message(messages: list[dict]) -> Optional[dict]:
for message in messages:
if message["role"] == "system":
if message['role'] == 'system':
return message
return None
def remove_system_message(messages: list[dict]) -> list[dict]:
return [message for message in messages if message["role"] != "system"]
return [message for message in messages if message['role'] != 'system']
def pop_system_message(messages: list[dict]) -> tuple[Optional[dict], list[dict]]:
@@ -330,32 +313,30 @@ def pop_system_message(messages: list[dict]) -> tuple[Optional[dict], list[dict]
def update_message_content(message: dict, content: str, append: bool = True) -> dict:
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "text":
if isinstance(message['content'], list):
for item in message['content']:
if item['type'] == 'text':
if append:
item["text"] = f"{item['text']}\n{content}"
item['text'] = f'{item["text"]}\n{content}'
else:
item["text"] = f"{content}\n{item['text']}"
item['text'] = f'{content}\n{item["text"]}'
else:
if append:
message["content"] = f"{message['content']}\n{content}"
message['content'] = f'{message["content"]}\n{content}'
else:
message["content"] = f"{content}\n{message['content']}"
message['content'] = f'{content}\n{message["content"]}'
return message
def replace_system_message_content(content: str, messages: list[dict]) -> dict:
for message in messages:
if message["role"] == "system":
message["content"] = content
if message['role'] == 'system':
message['content'] = content
break
return messages
def add_or_update_system_message(
content: str, messages: list[dict], append: bool = False
):
def add_or_update_system_message(content: str, messages: list[dict], append: bool = False):
"""
Adds a new system message at the beginning of the messages list
or updates the existing system message at the beginning.
@@ -365,11 +346,11 @@ def add_or_update_system_message(
:return: The updated list of message dictionaries.
"""
if messages and messages[0].get("role") == "system":
if messages and messages[0].get('role') == 'system':
messages[0] = update_message_content(messages[0], content, append)
else:
# Insert at the beginning
messages.insert(0, {"role": "system", "content": content})
messages.insert(0, {'role': 'system', 'content': content})
return messages
@@ -384,20 +365,18 @@ def add_or_update_user_message(content: str, messages: list[dict], append: bool
:return: The updated list of message dictionaries.
"""
if messages and messages[-1].get("role") == "user":
if messages and messages[-1].get('role') == 'user':
messages[-1] = update_message_content(messages[-1], content, append)
else:
# Insert at the end
messages.append({"role": "user", "content": content})
messages.append({'role': 'user', 'content': content})
return messages
def prepend_to_first_user_message_content(
content: str, messages: list[dict]
) -> list[dict]:
def prepend_to_first_user_message_content(content: str, messages: list[dict]) -> list[dict]:
for message in messages:
if message["role"] == "user":
if message['role'] == 'user':
message = update_message_content(message, content, append=False)
break
return messages
@@ -413,21 +392,21 @@ def append_or_update_assistant_message(content: str, messages: list[dict]):
:return: The updated list of message dictionaries.
"""
if messages and messages[-1].get("role") == "assistant":
messages[-1]["content"] = f"{messages[-1]['content']}\n{content}"
if messages and messages[-1].get('role') == 'assistant':
messages[-1]['content'] = f'{messages[-1]["content"]}\n{content}'
else:
# Insert at the end
messages.append({"role": "assistant", "content": content})
messages.append({'role': 'assistant', 'content': content})
return messages
def openai_chat_message_template(model: str):
return {
"id": f"{model}-{str(uuid.uuid4())}",
"created": int(time.time()),
"model": model,
"choices": [{"index": 0, "logprobs": None, "finish_reason": None}],
'id': f'{model}-{str(uuid.uuid4())}',
'created': int(time.time()),
'model': model,
'choices': [{'index': 0, 'logprobs': None, 'finish_reason': None}],
}
@@ -439,25 +418,25 @@ def openai_chat_chunk_message_template(
usage: Optional[dict] = None,
) -> dict:
template = openai_chat_message_template(model)
template["object"] = "chat.completion.chunk"
template['object'] = 'chat.completion.chunk'
template["choices"][0]["index"] = 0
template["choices"][0]["delta"] = {}
template['choices'][0]['index'] = 0
template['choices'][0]['delta'] = {}
if content:
template["choices"][0]["delta"]["content"] = content
template['choices'][0]['delta']['content'] = content
if reasoning_content:
template["choices"][0]["delta"]["reasoning_content"] = reasoning_content
template['choices'][0]['delta']['reasoning_content'] = reasoning_content
if tool_calls:
template["choices"][0]["delta"]["tool_calls"] = tool_calls
template['choices'][0]['delta']['tool_calls'] = tool_calls
if not content and not reasoning_content and not tool_calls:
template["choices"][0]["finish_reason"] = "stop"
template['choices'][0]['finish_reason'] = 'stop'
if usage:
template["usage"] = usage
template['usage'] = usage
return template
@@ -469,19 +448,19 @@ def openai_chat_completion_message_template(
usage: Optional[dict] = None,
) -> dict:
template = openai_chat_message_template(model)
template["object"] = "chat.completion"
template['object'] = 'chat.completion'
if message is not None:
template["choices"][0]["message"] = {
"role": "assistant",
"content": message,
**({"reasoning_content": reasoning_content} if reasoning_content else {}),
**({"tool_calls": tool_calls} if tool_calls else {}),
template['choices'][0]['message'] = {
'role': 'assistant',
'content': message,
**({'reasoning_content': reasoning_content} if reasoning_content else {}),
**({'tool_calls': tool_calls} if tool_calls else {}),
}
template["choices"][0]["finish_reason"] = "tool_calls" if tool_calls else "stop"
template['choices'][0]['finish_reason'] = 'tool_calls' if tool_calls else 'stop'
if usage:
template["usage"] = usage
template['usage'] = usage
return template
@@ -496,13 +475,13 @@ def get_gravatar_url(email):
hash_hex = hash_object.hexdigest()
# Grab the actual image URL
return f"https://www.gravatar.com/avatar/{hash_hex}?d=mp"
return f'https://www.gravatar.com/avatar/{hash_hex}?d=mp'
def calculate_sha256(file_path, chunk_size):
# Compute SHA-256 hash of a file efficiently in chunks
sha256 = hashlib.sha256()
with open(file_path, "rb") as f:
with open(file_path, 'rb') as f:
while chunk := f.read(chunk_size):
sha256.update(chunk)
return sha256.hexdigest()
@@ -512,17 +491,17 @@ def calculate_sha256_string(string):
# Create a new SHA-256 hash object
sha256_hash = hashlib.sha256()
# Update the hash object with the bytes of the input string
sha256_hash.update(string.encode("utf-8"))
sha256_hash.update(string.encode('utf-8'))
# Get the hexadecimal representation of the hash
hashed_string = sha256_hash.hexdigest()
return hashed_string
def validate_email_format(email: str) -> bool:
if email.endswith("@localhost"):
if email.endswith('@localhost'):
return True
return bool(re.match(r"[^@]+@[^@]+\.[^@]+", email))
return bool(re.match(r'[^@]+@[^@]+\.[^@]+', email))
def sanitize_filename(file_name):
@@ -530,10 +509,10 @@ def sanitize_filename(file_name):
lower_case_file_name = file_name.lower()
# Remove special characters using regular expression
sanitized_file_name = re.sub(r"[^\w\s]", "", lower_case_file_name)
sanitized_file_name = re.sub(r'[^\w\s]', '', lower_case_file_name)
# Replace spaces with dashes
final_file_name = re.sub(r"\s+", "-", sanitized_file_name)
final_file_name = re.sub(r'\s+', '-', sanitized_file_name)
return final_file_name
@@ -543,13 +522,11 @@ def sanitize_text_for_db(text: str) -> str:
if not isinstance(text, str):
return text
# Remove null bytes
text = text.replace("\x00", "").replace("\u0000", "")
text = text.replace('\x00', '').replace('\u0000', '')
# Remove invalid UTF-8 surrogate characters that can cause encoding errors
# This handles cases where binary data or encoding issues introduced surrogates
try:
text = text.encode("utf-8", errors="surrogatepass").decode(
"utf-8", errors="ignore"
)
text = text.encode('utf-8', errors='surrogatepass').decode('utf-8', errors='ignore')
except (UnicodeEncodeError, UnicodeDecodeError):
pass
return text
@@ -582,15 +559,9 @@ def sanitize_metadata(metadata: dict) -> dict:
if isinstance(obj, (str, int, float, bool, type(None))):
return obj
if isinstance(obj, dict):
return {
k: _sanitize(v)
for k, v in obj.items()
if not callable(v) and _is_serializable(v)
}
return {k: _sanitize(v) for k, v in obj.items() if not callable(v) and _is_serializable(v)}
if isinstance(obj, list):
return [
_sanitize(v) for v in obj if not callable(v) and _is_serializable(v)
]
return [_sanitize(v) for v in obj if not callable(v) and _is_serializable(v)]
if callable(obj):
return None
# Last resort: try to see if it's serializable
@@ -622,8 +593,8 @@ def extract_folders_after_data_docs(path):
# Find the index of '/data/docs' in the path
try:
index_data_docs = parts.index("data") + 1
index_docs = parts.index("docs", index_data_docs) + 1
index_data_docs = parts.index('data') + 1
index_docs = parts.index('docs', index_data_docs) + 1
except ValueError:
return []
@@ -632,37 +603,37 @@ def extract_folders_after_data_docs(path):
folders = parts[index_docs:-1]
for idx, _ in enumerate(folders):
tags.append("/".join(folders[: idx + 1]))
tags.append('/'.join(folders[: idx + 1]))
return tags
def parse_duration(duration: str) -> Optional[timedelta]:
if duration == "-1" or duration == "0":
if duration == '-1' or duration == '0':
return None
# Regular expression to find number and unit pairs
pattern = r"(-?\d+(\.\d+)?)(ms|s|m|h|d|w)"
pattern = r'(-?\d+(\.\d+)?)(ms|s|m|h|d|w)'
matches = re.findall(pattern, duration)
if not matches:
raise ValueError("Invalid duration string")
raise ValueError('Invalid duration string')
total_duration = timedelta()
for number, _, unit in matches:
number = float(number)
if unit == "ms":
if unit == 'ms':
total_duration += timedelta(milliseconds=number)
elif unit == "s":
elif unit == 's':
total_duration += timedelta(seconds=number)
elif unit == "m":
elif unit == 'm':
total_duration += timedelta(minutes=number)
elif unit == "h":
elif unit == 'h':
total_duration += timedelta(hours=number)
elif unit == "d":
elif unit == 'd':
total_duration += timedelta(days=number)
elif unit == "w":
elif unit == 'w':
total_duration += timedelta(weeks=number)
return total_duration
@@ -670,52 +641,48 @@ def parse_duration(duration: str) -> Optional[timedelta]:
def parse_ollama_modelfile(model_text):
parameters_meta = {
"mirostat": int,
"mirostat_eta": float,
"mirostat_tau": float,
"num_ctx": int,
"repeat_last_n": int,
"repeat_penalty": float,
"temperature": float,
"seed": int,
"tfs_z": float,
"num_predict": int,
"top_k": int,
"top_p": float,
"num_keep": int,
"presence_penalty": float,
"frequency_penalty": float,
"num_batch": int,
"num_gpu": int,
"use_mmap": bool,
"use_mlock": bool,
"num_thread": int,
'mirostat': int,
'mirostat_eta': float,
'mirostat_tau': float,
'num_ctx': int,
'repeat_last_n': int,
'repeat_penalty': float,
'temperature': float,
'seed': int,
'tfs_z': float,
'num_predict': int,
'top_k': int,
'top_p': float,
'num_keep': int,
'presence_penalty': float,
'frequency_penalty': float,
'num_batch': int,
'num_gpu': int,
'use_mmap': bool,
'use_mlock': bool,
'num_thread': int,
}
data = {"base_model_id": None, "params": {}}
data = {'base_model_id': None, 'params': {}}
# Parse base model
base_model_match = re.search(
r"^FROM\s+(\w+)", model_text, re.MULTILINE | re.IGNORECASE
)
base_model_match = re.search(r'^FROM\s+(\w+)', model_text, re.MULTILINE | re.IGNORECASE)
if base_model_match:
data["base_model_id"] = base_model_match.group(1)
data['base_model_id'] = base_model_match.group(1)
# Parse template
template_match = re.search(
r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
)
template_match = re.search(r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE)
if template_match:
data["params"] = {"template": template_match.group(1).strip()}
data['params'] = {'template': template_match.group(1).strip()}
# Parse stops
stops = re.findall(r'PARAMETER stop "(.*?)"', model_text, re.IGNORECASE)
if stops:
data["params"]["stop"] = stops
data['params']['stop'] = stops
# Parse other parameters from the provided list
for param, param_type in parameters_meta.items():
param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE)
param_match = re.search(rf'PARAMETER {param} (.+)', model_text, re.IGNORECASE)
if param_match:
value = param_match.group(1)
@@ -725,39 +692,35 @@ def parse_ollama_modelfile(model_text):
elif param_type is float:
value = float(value)
elif param_type is bool:
value = value.lower() == "true"
value = value.lower() == 'true'
except Exception as e:
log.exception(f"Failed to parse parameter {param}: {e}")
log.exception(f'Failed to parse parameter {param}: {e}')
continue
data["params"][param] = value
data['params'][param] = value
# Parse adapter
adapter_match = re.search(r"ADAPTER (.+)", model_text, re.IGNORECASE)
adapter_match = re.search(r'ADAPTER (.+)', model_text, re.IGNORECASE)
if adapter_match:
data["params"]["adapter"] = adapter_match.group(1)
data['params']['adapter'] = adapter_match.group(1)
# Parse system description
system_desc_match = re.search(
r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
)
system_desc_match_single = re.search(
r"SYSTEM\s+([^\n]+)", model_text, re.IGNORECASE
)
system_desc_match = re.search(r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE)
system_desc_match_single = re.search(r'SYSTEM\s+([^\n]+)', model_text, re.IGNORECASE)
if system_desc_match:
data["params"]["system"] = system_desc_match.group(1).strip()
data['params']['system'] = system_desc_match.group(1).strip()
elif system_desc_match_single:
data["params"]["system"] = system_desc_match_single.group(1).strip()
data['params']['system'] = system_desc_match_single.group(1).strip()
# Parse messages
messages = []
message_matches = re.findall(r"MESSAGE (\w+) (.+)", model_text, re.IGNORECASE)
message_matches = re.findall(r'MESSAGE (\w+) (.+)', model_text, re.IGNORECASE)
for role, content in message_matches:
messages.append({"role": role, "content": content})
messages.append({'role': role, 'content': content})
if messages:
data["params"]["messages"] = messages
data['params']['messages'] = messages
return data
@@ -769,10 +732,10 @@ def convert_logit_bias_input_to_json(logit_bias_input) -> Optional[str]:
if isinstance(logit_bias_input, dict):
return json.dumps(logit_bias_input)
logit_bias_pairs = logit_bias_input.split(",")
logit_bias_pairs = logit_bias_input.split(',')
logit_bias_json = {}
for pair in logit_bias_pairs:
token, bias = pair.split(":")
token, bias = pair.split(':')
token = str(token.strip())
bias = int(bias.strip())
bias = 100 if bias > 100 else -100 if bias < -100 else bias
@@ -834,13 +797,13 @@ def strict_match_mime_type(supported: list[str] | str, header: str) -> Optional[
try:
if isinstance(supported, str):
supported = supported.split(",")
supported = supported.split(',')
supported = [s for s in supported if s.strip() and "/" in s]
supported = [s for s in supported if s.strip() and '/' in s]
if len(supported) == 0:
# Default to common types if none are specified
supported = ["audio/*", "video/webm"]
supported = ['audio/*', 'video/webm']
match = mimeparse.best_match(supported, header)
if not match:
@@ -854,15 +817,13 @@ def strict_match_mime_type(supported: list[str] | str, header: str) -> Optional[
return match
except Exception as e:
log.exception(f"Failed to match mime type {header}: {e}")
log.exception(f'Failed to match mime type {header}: {e}')
return None
def extract_urls(text: str) -> list[str]:
# Regex pattern to match URLs
url_pattern = re.compile(
r"(https?://[^\s]+)", re.IGNORECASE
) # Matches http and https URLs
url_pattern = re.compile(r'(https?://[^\s]+)', re.IGNORECASE) # Matches http and https URLs
return url_pattern.findall(text)
@@ -882,9 +843,7 @@ async def stream_wrapper(response, session, content_handler=None):
This is more reliable than BackgroundTask which may not run if client disconnects.
"""
try:
stream = (
content_handler(response.content) if content_handler else response.content
)
stream = content_handler(response.content) if content_handler else response.content
async for chunk in stream:
yield chunk
finally:
@@ -906,7 +865,7 @@ def stream_chunks_handler(stream: aiohttp.StreamReader):
return stream
async def yield_safe_stream_chunks():
buffer = b""
buffer = b''
skip_mode = False
async for data, _ in stream.iter_chunks():
@@ -915,9 +874,9 @@ def stream_chunks_handler(stream: aiohttp.StreamReader):
# In skip_mode, if buffer already exceeds the limit, clear it (it's part of an oversized line)
if skip_mode and len(buffer) > max_buffer_size:
buffer = b""
buffer = b''
lines = (buffer + data).split(b"\n")
lines = (buffer + data).split(b'\n')
# Process complete lines (except the last possibly incomplete fragment)
for i in range(len(lines) - 1):
@@ -929,18 +888,18 @@ def stream_chunks_handler(stream: aiohttp.StreamReader):
skip_mode = False
yield line
else:
yield b"data: {}"
yield b"\n"
yield b'data: {}'
yield b'\n'
else:
# Normal mode: check if line exceeds limit
if len(line) > max_buffer_size:
skip_mode = True
yield b"data: {}"
yield b"\n"
log.info(f"Skip mode triggered, line size: {len(line)}")
yield b'data: {}'
yield b'\n'
log.info(f'Skip mode triggered, line size: {len(line)}')
else:
yield line
yield b"\n"
yield b'\n'
# Save the last incomplete fragment
buffer = lines[-1]
@@ -948,13 +907,13 @@ def stream_chunks_handler(stream: aiohttp.StreamReader):
# Check if buffer exceeds limit
if not skip_mode and len(buffer) > max_buffer_size:
skip_mode = True
log.info(f"Skip mode triggered, buffer size: {len(buffer)}")
log.info(f'Skip mode triggered, buffer size: {len(buffer)}')
# Clear oversized buffer to prevent unlimited growth
buffer = b""
buffer = b''
# Process remaining buffer data
if buffer and not skip_mode:
yield buffer
yield b"\n"
yield b'\n'
return yield_safe_stream_chunks()