mirror of
https://github.com/open-webui/open-webui.git
synced 2026-04-30 17:28:51 -05:00
fix: payload tools handling
This commit is contained in:
@@ -2220,6 +2220,10 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||||||
tool_ids = form_data.pop("tool_ids", None)
|
tool_ids = form_data.pop("tool_ids", None)
|
||||||
files = form_data.pop("files", None)
|
files = form_data.pop("files", None)
|
||||||
|
|
||||||
|
# Caller-provided OpenAI-style tools take precedence over server-side
|
||||||
|
# tool resolution (tool_ids, MCP servers, builtin tools).
|
||||||
|
payload_tools = form_data.get("tools", None)
|
||||||
|
|
||||||
# Skills
|
# Skills
|
||||||
user_skill_ids = set(form_data.pop("skill_ids", None) or [])
|
user_skill_ids = set(form_data.pop("skill_ids", None) or [])
|
||||||
model_skill_ids = set(model.get("info", {}).get("meta", {}).get("skillIds", []))
|
model_skill_ids = set(model.get("info", {}).get("meta", {}).get("skillIds", []))
|
||||||
@@ -2291,228 +2295,236 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||||||
}
|
}
|
||||||
form_data["metadata"] = metadata
|
form_data["metadata"] = metadata
|
||||||
|
|
||||||
# Server side tools
|
# When the caller provides an explicit OpenAI-style `tools` array in the
|
||||||
tool_ids = metadata.get("tool_ids", None)
|
# request body, skip all server-side tool resolution and pass the caller's
|
||||||
# Client side tools
|
# tools through to the model unchanged.
|
||||||
direct_tool_servers = metadata.get("tool_servers", None)
|
if payload_tools:
|
||||||
|
log.debug(
|
||||||
|
"Caller provided explicit tools — skipping server-side tool resolution"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Server side tools
|
||||||
|
tool_ids = metadata.get("tool_ids", None)
|
||||||
|
# Client side tools
|
||||||
|
direct_tool_servers = metadata.get("tool_servers", None)
|
||||||
|
|
||||||
log.debug(f"{tool_ids=}")
|
log.debug(f"{tool_ids=}")
|
||||||
log.debug(f"{direct_tool_servers=}")
|
log.debug(f"{direct_tool_servers=}")
|
||||||
|
|
||||||
tools_dict = {}
|
tools_dict = {}
|
||||||
|
|
||||||
mcp_clients = {}
|
mcp_clients = {}
|
||||||
mcp_tools_dict = {}
|
mcp_tools_dict = {}
|
||||||
|
|
||||||
if tool_ids:
|
if tool_ids:
|
||||||
for tool_id in tool_ids:
|
for tool_id in tool_ids:
|
||||||
if tool_id.startswith("server:mcp:"):
|
if tool_id.startswith("server:mcp:"):
|
||||||
try:
|
try:
|
||||||
server_id = tool_id[len("server:mcp:") :]
|
server_id = tool_id[len("server:mcp:") :]
|
||||||
|
|
||||||
mcp_server_connection = None
|
mcp_server_connection = None
|
||||||
for (
|
for (
|
||||||
server_connection
|
server_connection
|
||||||
) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
||||||
if (
|
if (
|
||||||
server_connection.get("type", "") == "mcp"
|
server_connection.get("type", "") == "mcp"
|
||||||
and server_connection.get("info", {}).get("id") == server_id
|
and server_connection.get("info", {}).get("id") == server_id
|
||||||
):
|
):
|
||||||
mcp_server_connection = server_connection
|
mcp_server_connection = server_connection
|
||||||
break
|
break
|
||||||
|
|
||||||
if not mcp_server_connection:
|
if not mcp_server_connection:
|
||||||
log.error(f"MCP server with id {server_id} not found")
|
log.error(f"MCP server with id {server_id} not found")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check access control for MCP server
|
# Check access control for MCP server
|
||||||
if not has_tool_server_access(user, mcp_server_connection):
|
if not has_tool_server_access(user, mcp_server_connection):
|
||||||
log.warning(
|
log.warning(
|
||||||
f"Access denied to MCP server {server_id} for user {user.id}"
|
f"Access denied to MCP server {server_id} for user {user.id}"
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
auth_type = mcp_server_connection.get("auth_type", "")
|
auth_type = mcp_server_connection.get("auth_type", "")
|
||||||
headers = {}
|
headers = {}
|
||||||
if auth_type == "bearer":
|
if auth_type == "bearer":
|
||||||
headers["Authorization"] = (
|
|
||||||
f"Bearer {mcp_server_connection.get('key', '')}"
|
|
||||||
)
|
|
||||||
elif auth_type == "none":
|
|
||||||
# No authentication
|
|
||||||
pass
|
|
||||||
elif auth_type == "session":
|
|
||||||
headers["Authorization"] = (
|
|
||||||
f"Bearer {request.state.token.credentials}"
|
|
||||||
)
|
|
||||||
elif auth_type == "system_oauth":
|
|
||||||
oauth_token = extra_params.get("__oauth_token__", None)
|
|
||||||
if oauth_token:
|
|
||||||
headers["Authorization"] = (
|
headers["Authorization"] = (
|
||||||
f"Bearer {oauth_token.get('access_token', '')}"
|
f"Bearer {mcp_server_connection.get('key', '')}"
|
||||||
)
|
)
|
||||||
elif auth_type == "oauth_2.1":
|
elif auth_type == "none":
|
||||||
try:
|
# No authentication
|
||||||
splits = server_id.split(":")
|
pass
|
||||||
server_id = splits[-1] if len(splits) > 1 else server_id
|
elif auth_type == "session":
|
||||||
|
headers["Authorization"] = (
|
||||||
oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
|
f"Bearer {request.state.token.credentials}"
|
||||||
user.id, f"mcp:{server_id}"
|
|
||||||
)
|
)
|
||||||
|
elif auth_type == "system_oauth":
|
||||||
|
oauth_token = extra_params.get("__oauth_token__", None)
|
||||||
if oauth_token:
|
if oauth_token:
|
||||||
headers["Authorization"] = (
|
headers["Authorization"] = (
|
||||||
f"Bearer {oauth_token.get('access_token', '')}"
|
f"Bearer {oauth_token.get('access_token', '')}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
elif auth_type == "oauth_2.1":
|
||||||
log.error(f"Error getting OAuth token: {e}")
|
try:
|
||||||
oauth_token = None
|
splits = server_id.split(":")
|
||||||
|
server_id = splits[-1] if len(splits) > 1 else server_id
|
||||||
|
|
||||||
connection_headers = mcp_server_connection.get("headers", None)
|
oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
|
||||||
if connection_headers and isinstance(connection_headers, dict):
|
user.id, f"mcp:{server_id}"
|
||||||
for key, value in connection_headers.items():
|
|
||||||
headers[key] = value
|
|
||||||
|
|
||||||
# Add user info headers if enabled
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
|
||||||
headers = include_user_info_headers(headers, user)
|
|
||||||
if metadata and metadata.get("chat_id"):
|
|
||||||
headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get(
|
|
||||||
"chat_id"
|
|
||||||
)
|
|
||||||
if metadata and metadata.get("message_id"):
|
|
||||||
headers[FORWARD_SESSION_INFO_HEADER_MESSAGE_ID] = (
|
|
||||||
metadata.get("message_id")
|
|
||||||
)
|
|
||||||
|
|
||||||
mcp_clients[server_id] = MCPClient()
|
|
||||||
await mcp_clients[server_id].connect(
|
|
||||||
url=mcp_server_connection.get("url", ""),
|
|
||||||
headers=headers if headers else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
function_name_filter_list = mcp_server_connection.get(
|
|
||||||
"config", {}
|
|
||||||
).get("function_name_filter_list", "")
|
|
||||||
|
|
||||||
if isinstance(function_name_filter_list, str):
|
|
||||||
function_name_filter_list = function_name_filter_list.split(",")
|
|
||||||
|
|
||||||
tool_specs = await mcp_clients[server_id].list_tool_specs()
|
|
||||||
for tool_spec in tool_specs:
|
|
||||||
|
|
||||||
def make_tool_function(client, function_name):
|
|
||||||
async def tool_function(**kwargs):
|
|
||||||
return await client.call_tool(
|
|
||||||
function_name,
|
|
||||||
function_args=kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return tool_function
|
if oauth_token:
|
||||||
|
headers["Authorization"] = (
|
||||||
|
f"Bearer {oauth_token.get('access_token', '')}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error getting OAuth token: {e}")
|
||||||
|
oauth_token = None
|
||||||
|
|
||||||
if function_name_filter_list:
|
connection_headers = mcp_server_connection.get("headers", None)
|
||||||
if not is_string_allowed(
|
if connection_headers and isinstance(connection_headers, dict):
|
||||||
tool_spec["name"], function_name_filter_list
|
for key, value in connection_headers.items():
|
||||||
):
|
headers[key] = value
|
||||||
# Skip this function
|
|
||||||
continue
|
|
||||||
|
|
||||||
tool_function = make_tool_function(
|
# Add user info headers if enabled
|
||||||
mcp_clients[server_id], tool_spec["name"]
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
||||||
|
headers = include_user_info_headers(headers, user)
|
||||||
|
if metadata and metadata.get("chat_id"):
|
||||||
|
headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get(
|
||||||
|
"chat_id"
|
||||||
|
)
|
||||||
|
if metadata and metadata.get("message_id"):
|
||||||
|
headers[FORWARD_SESSION_INFO_HEADER_MESSAGE_ID] = (
|
||||||
|
metadata.get("message_id")
|
||||||
|
)
|
||||||
|
|
||||||
|
mcp_clients[server_id] = MCPClient()
|
||||||
|
await mcp_clients[server_id].connect(
|
||||||
|
url=mcp_server_connection.get("url", ""),
|
||||||
|
headers=headers if headers else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
|
function_name_filter_list = mcp_server_connection.get(
|
||||||
"spec": {
|
"config", {}
|
||||||
**tool_spec,
|
).get("function_name_filter_list", "")
|
||||||
"name": f"{server_id}_{tool_spec['name']}",
|
|
||||||
},
|
if isinstance(function_name_filter_list, str):
|
||||||
"callable": tool_function,
|
function_name_filter_list = function_name_filter_list.split(",")
|
||||||
"type": "mcp",
|
|
||||||
"client": mcp_clients[server_id],
|
tool_specs = await mcp_clients[server_id].list_tool_specs()
|
||||||
"direct": False,
|
for tool_spec in tool_specs:
|
||||||
}
|
|
||||||
except Exception as e:
|
def make_tool_function(client, function_name):
|
||||||
log.debug(e)
|
async def tool_function(**kwargs):
|
||||||
if event_emitter:
|
return await client.call_tool(
|
||||||
await event_emitter(
|
function_name,
|
||||||
{
|
function_args=kwargs,
|
||||||
"type": "chat:message:error",
|
)
|
||||||
"data": {
|
|
||||||
"error": {
|
return tool_function
|
||||||
"content": f"Failed to connect to MCP server '{server_id}'"
|
|
||||||
}
|
if function_name_filter_list:
|
||||||
|
if not is_string_allowed(
|
||||||
|
tool_spec["name"], function_name_filter_list
|
||||||
|
):
|
||||||
|
# Skip this function
|
||||||
|
continue
|
||||||
|
|
||||||
|
tool_function = make_tool_function(
|
||||||
|
mcp_clients[server_id], tool_spec["name"]
|
||||||
|
)
|
||||||
|
|
||||||
|
mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
|
||||||
|
"spec": {
|
||||||
|
**tool_spec,
|
||||||
|
"name": f"{server_id}_{tool_spec['name']}",
|
||||||
},
|
},
|
||||||
|
"callable": tool_function,
|
||||||
|
"type": "mcp",
|
||||||
|
"client": mcp_clients[server_id],
|
||||||
|
"direct": False,
|
||||||
}
|
}
|
||||||
)
|
except Exception as e:
|
||||||
continue
|
log.debug(e)
|
||||||
|
if event_emitter:
|
||||||
|
await event_emitter(
|
||||||
|
{
|
||||||
|
"type": "chat:message:error",
|
||||||
|
"data": {
|
||||||
|
"error": {
|
||||||
|
"content": f"Failed to connect to MCP server '{server_id}'"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
tools_dict = await get_tools(
|
tools_dict = await get_tools(
|
||||||
request,
|
request,
|
||||||
tool_ids,
|
tool_ids,
|
||||||
user,
|
user,
|
||||||
{
|
{
|
||||||
**extra_params,
|
**extra_params,
|
||||||
"__model__": models[task_model_id],
|
"__model__": models[task_model_id],
|
||||||
"__messages__": form_data["messages"],
|
"__messages__": form_data["messages"],
|
||||||
"__files__": metadata.get("files", []),
|
"__files__": metadata.get("files", []),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if mcp_tools_dict:
|
if mcp_tools_dict:
|
||||||
tools_dict = {**tools_dict, **mcp_tools_dict}
|
tools_dict = {**tools_dict, **mcp_tools_dict}
|
||||||
|
|
||||||
if direct_tool_servers:
|
if direct_tool_servers:
|
||||||
for tool_server in direct_tool_servers:
|
for tool_server in direct_tool_servers:
|
||||||
tool_specs = tool_server.pop("specs", [])
|
tool_specs = tool_server.pop("specs", [])
|
||||||
|
|
||||||
for tool in tool_specs:
|
for tool in tool_specs:
|
||||||
tools_dict[tool["name"]] = {
|
tools_dict[tool["name"]] = {
|
||||||
"spec": tool,
|
"spec": tool,
|
||||||
"direct": True,
|
"direct": True,
|
||||||
"server": tool_server,
|
"server": tool_server,
|
||||||
}
|
}
|
||||||
|
|
||||||
if mcp_clients:
|
if mcp_clients:
|
||||||
metadata["mcp_clients"] = mcp_clients
|
metadata["mcp_clients"] = mcp_clients
|
||||||
|
|
||||||
# Inject builtin tools for native function calling based on enabled features and model capability
|
# Inject builtin tools for native function calling based on enabled features and model capability
|
||||||
# Check if builtin_tools capability is enabled for this model (defaults to True if not specified)
|
# Check if builtin_tools capability is enabled for this model (defaults to True if not specified)
|
||||||
builtin_tools_enabled = (
|
builtin_tools_enabled = (
|
||||||
model.get("info", {}).get("meta", {}).get("capabilities") or {}
|
model.get("info", {}).get("meta", {}).get("capabilities") or {}
|
||||||
).get("builtin_tools", True)
|
).get("builtin_tools", True)
|
||||||
if (
|
if (
|
||||||
metadata.get("params", {}).get("function_calling") == "native"
|
metadata.get("params", {}).get("function_calling") == "native"
|
||||||
and builtin_tools_enabled
|
and builtin_tools_enabled
|
||||||
):
|
):
|
||||||
# Add file context to user messages
|
# Add file context to user messages
|
||||||
chat_id = metadata.get("chat_id")
|
chat_id = metadata.get("chat_id")
|
||||||
form_data["messages"] = add_file_context(
|
form_data["messages"] = add_file_context(
|
||||||
form_data.get("messages", []), chat_id, user
|
form_data.get("messages", []), chat_id, user
|
||||||
)
|
)
|
||||||
builtin_tools = get_builtin_tools(
|
builtin_tools = get_builtin_tools(
|
||||||
request,
|
request,
|
||||||
{
|
{
|
||||||
**extra_params,
|
**extra_params,
|
||||||
"__event_emitter__": event_emitter,
|
"__event_emitter__": event_emitter,
|
||||||
"__skill_ids__": [
|
"__skill_ids__": [
|
||||||
s.id for s in available_skills if s.id not in user_skill_ids
|
s.id for s in available_skills if s.id not in user_skill_ids
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
features,
|
features,
|
||||||
model,
|
model,
|
||||||
)
|
)
|
||||||
for name, tool_dict in builtin_tools.items():
|
for name, tool_dict in builtin_tools.items():
|
||||||
if name not in tools_dict:
|
if name not in tools_dict:
|
||||||
tools_dict[name] = tool_dict
|
tools_dict[name] = tool_dict
|
||||||
|
|
||||||
if tools_dict:
|
if tools_dict:
|
||||||
if metadata.get("params", {}).get("function_calling") == "native":
|
if metadata.get("params", {}).get("function_calling") == "native":
|
||||||
# If the function calling is native, then call the tools function calling handler
|
# If the function calling is native, then call the tools function calling handler
|
||||||
metadata["tools"] = tools_dict
|
metadata["tools"] = tools_dict
|
||||||
form_data["tools"] = [
|
form_data["tools"] = [
|
||||||
{"type": "function", "function": tool.get("spec", {})}
|
{"type": "function", "function": tool.get("spec", {})}
|
||||||
for tool in tools_dict.values()
|
for tool in tools_dict.values()
|
||||||
]
|
]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# If the function calling is not native, then call the tools function calling handler
|
# If the function calling is not native, then call the tools function calling handler
|
||||||
|
|||||||
Reference in New Issue
Block a user