mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-05 10:28:06 -05:00
Merge remote-tracking branch 'upstream/dev' into playwright
# Conflicts: # backend/requirements.txt
This commit is contained in:
@@ -68,6 +68,10 @@ from open_webui.utils.misc import (
|
||||
)
|
||||
from open_webui.utils.tools import get_tools
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
from open_webui.utils.filter import (
|
||||
get_sorted_filter_ids,
|
||||
process_filter_functions,
|
||||
)
|
||||
|
||||
|
||||
from open_webui.tasks import create_task
|
||||
@@ -91,99 +95,6 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
async def chat_completion_filter_functions_handler(request, body, model, extra_params):
|
||||
skip_files = None
|
||||
|
||||
def get_filter_function_ids(model):
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if function is not None and hasattr(function, "valves"):
|
||||
# TODO: Fix FunctionModel
|
||||
return (function.valves if function.valves else {}).get("priority", 0)
|
||||
return 0
|
||||
|
||||
filter_ids = [
|
||||
function.id for function in Functions.get_global_filter_functions()
|
||||
]
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
|
||||
filter_ids = [
|
||||
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
||||
]
|
||||
|
||||
filter_ids.sort(key=get_priority)
|
||||
return filter_ids
|
||||
|
||||
filter_ids = get_filter_function_ids(model)
|
||||
for filter_id in filter_ids:
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if not filter:
|
||||
continue
|
||||
|
||||
if filter_id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, _, _ = load_function_module_by_id(filter_id)
|
||||
request.app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
# Check if the function has a file_handler variable
|
||||
if hasattr(function_module, "file_handler"):
|
||||
skip_files = function_module.file_handler
|
||||
|
||||
# Apply valves to the function
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(filter_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
||||
if hasattr(function_module, "inlet"):
|
||||
try:
|
||||
inlet = function_module.inlet
|
||||
|
||||
# Create a dictionary of parameters to be passed to the function
|
||||
params = {"body": body} | {
|
||||
k: v
|
||||
for k, v in {
|
||||
**extra_params,
|
||||
"__model__": model,
|
||||
"__id__": filter_id,
|
||||
}.items()
|
||||
if k in inspect.signature(inlet).parameters
|
||||
}
|
||||
|
||||
if "__user__" in params and hasattr(function_module, "UserValves"):
|
||||
try:
|
||||
params["__user__"]["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
filter_id, params["__user__"]["id"]
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
if inspect.iscoroutinefunction(inlet):
|
||||
body = await inlet(**params)
|
||||
else:
|
||||
body = inlet(**params)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
raise e
|
||||
|
||||
if skip_files and "files" in body.get("metadata", {}):
|
||||
del body["metadata"]["files"]
|
||||
|
||||
return body, {}
|
||||
|
||||
|
||||
async def chat_completion_tools_handler(
|
||||
request: Request, body: dict, user: UserModel, models, tools
|
||||
) -> tuple[dict, dict]:
|
||||
@@ -566,13 +477,13 @@ async def chat_image_generation_handler(
|
||||
{
|
||||
"type": "status",
|
||||
"data": {
|
||||
"description": f"An error occured while generating an image",
|
||||
"description": f"An error occurred while generating an image",
|
||||
"done": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
system_message_content = "<context>Unable to generate an image, tell the user that an error occured</context>"
|
||||
system_message_content = "<context>Unable to generate an image, tell the user that an error occurred</context>"
|
||||
|
||||
if system_message_content:
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
@@ -700,6 +611,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
}
|
||||
|
||||
# Initialize events to store additional event to be sent to the client
|
||||
@@ -776,8 +688,12 @@ async def process_chat_payload(request, form_data, metadata, user, model):
|
||||
)
|
||||
|
||||
try:
|
||||
form_data, flags = await chat_completion_filter_functions_handler(
|
||||
request, form_data, model, extra_params
|
||||
form_data, flags = await process_filter_functions(
|
||||
request=request,
|
||||
filter_ids=get_sorted_filter_ids(model),
|
||||
filter_type="inlet",
|
||||
form_data=form_data,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error: {e}")
|
||||
@@ -1116,6 +1032,20 @@ async def process_chat_response(
|
||||
},
|
||||
)
|
||||
|
||||
def split_content_and_whitespace(content):
|
||||
content_stripped = content.rstrip()
|
||||
original_whitespace = (
|
||||
content[len(content_stripped) :]
|
||||
if len(content) > len(content_stripped)
|
||||
else ""
|
||||
)
|
||||
return content_stripped, original_whitespace
|
||||
|
||||
def is_opening_code_block(content):
|
||||
backtick_segments = content.split("```")
|
||||
# Even number of segments means the last backticks are opening a new block
|
||||
return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0
|
||||
|
||||
# Handle as a background task
|
||||
async def post_response_handler(response, events):
|
||||
def serialize_content_blocks(content_blocks, raw=False):
|
||||
@@ -1182,6 +1112,19 @@ async def process_chat_response(
|
||||
output = block.get("output", None)
|
||||
lang = attributes.get("lang", "")
|
||||
|
||||
content_stripped, original_whitespace = (
|
||||
split_content_and_whitespace(content)
|
||||
)
|
||||
if is_opening_code_block(content_stripped):
|
||||
# Remove trailing backticks that would open a new block
|
||||
content = (
|
||||
content_stripped.rstrip("`").rstrip()
|
||||
+ original_whitespace
|
||||
)
|
||||
else:
|
||||
# Keep content as is - either closing backticks or no backticks
|
||||
content = content_stripped + original_whitespace
|
||||
|
||||
if output:
|
||||
output = html.escape(json.dumps(output))
|
||||
|
||||
@@ -1236,10 +1179,10 @@ async def process_chat_response(
|
||||
match.end() :
|
||||
] # Content after opening tag
|
||||
|
||||
# Remove the start tag from the currently handling text block
|
||||
# Remove the start tag and after from the currently handling text block
|
||||
content_blocks[-1]["content"] = content_blocks[-1][
|
||||
"content"
|
||||
].replace(match.group(0), "")
|
||||
].replace(match.group(0) + after_tag, "")
|
||||
|
||||
if before_tag:
|
||||
content_blocks[-1]["content"] = before_tag
|
||||
|
||||
Reference in New Issue
Block a user