mirror of
https://github.com/open-webui/open-webui.git
synced 2026-03-11 17:47:44 -05:00
fix: payload tools handling
This commit is contained in:
@@ -2220,6 +2220,10 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
tool_ids = form_data.pop("tool_ids", None)
|
||||
files = form_data.pop("files", None)
|
||||
|
||||
# Caller-provided OpenAI-style tools take precedence over server-side
|
||||
# tool resolution (tool_ids, MCP servers, builtin tools).
|
||||
payload_tools = form_data.get("tools", None)
|
||||
|
||||
# Skills
|
||||
user_skill_ids = set(form_data.pop("skill_ids", None) or [])
|
||||
model_skill_ids = set(model.get("info", {}).get("meta", {}).get("skillIds", []))
|
||||
@@ -2291,228 +2295,236 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
}
|
||||
form_data["metadata"] = metadata
|
||||
|
||||
# Server side tools
|
||||
tool_ids = metadata.get("tool_ids", None)
|
||||
# Client side tools
|
||||
direct_tool_servers = metadata.get("tool_servers", None)
|
||||
# When the caller provides an explicit OpenAI-style `tools` array in the
|
||||
# request body, skip all server-side tool resolution and pass the caller's
|
||||
# tools through to the model unchanged.
|
||||
if payload_tools:
|
||||
log.debug(
|
||||
"Caller provided explicit tools — skipping server-side tool resolution"
|
||||
)
|
||||
else:
|
||||
# Server side tools
|
||||
tool_ids = metadata.get("tool_ids", None)
|
||||
# Client side tools
|
||||
direct_tool_servers = metadata.get("tool_servers", None)
|
||||
|
||||
log.debug(f"{tool_ids=}")
|
||||
log.debug(f"{direct_tool_servers=}")
|
||||
log.debug(f"{tool_ids=}")
|
||||
log.debug(f"{direct_tool_servers=}")
|
||||
|
||||
tools_dict = {}
|
||||
tools_dict = {}
|
||||
|
||||
mcp_clients = {}
|
||||
mcp_tools_dict = {}
|
||||
mcp_clients = {}
|
||||
mcp_tools_dict = {}
|
||||
|
||||
if tool_ids:
|
||||
for tool_id in tool_ids:
|
||||
if tool_id.startswith("server:mcp:"):
|
||||
try:
|
||||
server_id = tool_id[len("server:mcp:") :]
|
||||
if tool_ids:
|
||||
for tool_id in tool_ids:
|
||||
if tool_id.startswith("server:mcp:"):
|
||||
try:
|
||||
server_id = tool_id[len("server:mcp:") :]
|
||||
|
||||
mcp_server_connection = None
|
||||
for (
|
||||
server_connection
|
||||
) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
||||
if (
|
||||
server_connection.get("type", "") == "mcp"
|
||||
and server_connection.get("info", {}).get("id") == server_id
|
||||
):
|
||||
mcp_server_connection = server_connection
|
||||
break
|
||||
mcp_server_connection = None
|
||||
for (
|
||||
server_connection
|
||||
) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
||||
if (
|
||||
server_connection.get("type", "") == "mcp"
|
||||
and server_connection.get("info", {}).get("id") == server_id
|
||||
):
|
||||
mcp_server_connection = server_connection
|
||||
break
|
||||
|
||||
if not mcp_server_connection:
|
||||
log.error(f"MCP server with id {server_id} not found")
|
||||
continue
|
||||
if not mcp_server_connection:
|
||||
log.error(f"MCP server with id {server_id} not found")
|
||||
continue
|
||||
|
||||
# Check access control for MCP server
|
||||
if not has_tool_server_access(user, mcp_server_connection):
|
||||
log.warning(
|
||||
f"Access denied to MCP server {server_id} for user {user.id}"
|
||||
)
|
||||
continue
|
||||
# Check access control for MCP server
|
||||
if not has_tool_server_access(user, mcp_server_connection):
|
||||
log.warning(
|
||||
f"Access denied to MCP server {server_id} for user {user.id}"
|
||||
)
|
||||
continue
|
||||
|
||||
auth_type = mcp_server_connection.get("auth_type", "")
|
||||
headers = {}
|
||||
if auth_type == "bearer":
|
||||
headers["Authorization"] = (
|
||||
f"Bearer {mcp_server_connection.get('key', '')}"
|
||||
)
|
||||
elif auth_type == "none":
|
||||
# No authentication
|
||||
pass
|
||||
elif auth_type == "session":
|
||||
headers["Authorization"] = (
|
||||
f"Bearer {request.state.token.credentials}"
|
||||
)
|
||||
elif auth_type == "system_oauth":
|
||||
oauth_token = extra_params.get("__oauth_token__", None)
|
||||
if oauth_token:
|
||||
auth_type = mcp_server_connection.get("auth_type", "")
|
||||
headers = {}
|
||||
if auth_type == "bearer":
|
||||
headers["Authorization"] = (
|
||||
f"Bearer {oauth_token.get('access_token', '')}"
|
||||
f"Bearer {mcp_server_connection.get('key', '')}"
|
||||
)
|
||||
elif auth_type == "oauth_2.1":
|
||||
try:
|
||||
splits = server_id.split(":")
|
||||
server_id = splits[-1] if len(splits) > 1 else server_id
|
||||
|
||||
oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
|
||||
user.id, f"mcp:{server_id}"
|
||||
elif auth_type == "none":
|
||||
# No authentication
|
||||
pass
|
||||
elif auth_type == "session":
|
||||
headers["Authorization"] = (
|
||||
f"Bearer {request.state.token.credentials}"
|
||||
)
|
||||
|
||||
elif auth_type == "system_oauth":
|
||||
oauth_token = extra_params.get("__oauth_token__", None)
|
||||
if oauth_token:
|
||||
headers["Authorization"] = (
|
||||
f"Bearer {oauth_token.get('access_token', '')}"
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth token: {e}")
|
||||
oauth_token = None
|
||||
elif auth_type == "oauth_2.1":
|
||||
try:
|
||||
splits = server_id.split(":")
|
||||
server_id = splits[-1] if len(splits) > 1 else server_id
|
||||
|
||||
connection_headers = mcp_server_connection.get("headers", None)
|
||||
if connection_headers and isinstance(connection_headers, dict):
|
||||
for key, value in connection_headers.items():
|
||||
headers[key] = value
|
||||
|
||||
# Add user info headers if enabled
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
||||
headers = include_user_info_headers(headers, user)
|
||||
if metadata and metadata.get("chat_id"):
|
||||
headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get(
|
||||
"chat_id"
|
||||
)
|
||||
if metadata and metadata.get("message_id"):
|
||||
headers[FORWARD_SESSION_INFO_HEADER_MESSAGE_ID] = (
|
||||
metadata.get("message_id")
|
||||
)
|
||||
|
||||
mcp_clients[server_id] = MCPClient()
|
||||
await mcp_clients[server_id].connect(
|
||||
url=mcp_server_connection.get("url", ""),
|
||||
headers=headers if headers else None,
|
||||
)
|
||||
|
||||
function_name_filter_list = mcp_server_connection.get(
|
||||
"config", {}
|
||||
).get("function_name_filter_list", "")
|
||||
|
||||
if isinstance(function_name_filter_list, str):
|
||||
function_name_filter_list = function_name_filter_list.split(",")
|
||||
|
||||
tool_specs = await mcp_clients[server_id].list_tool_specs()
|
||||
for tool_spec in tool_specs:
|
||||
|
||||
def make_tool_function(client, function_name):
|
||||
async def tool_function(**kwargs):
|
||||
return await client.call_tool(
|
||||
function_name,
|
||||
function_args=kwargs,
|
||||
oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
|
||||
user.id, f"mcp:{server_id}"
|
||||
)
|
||||
|
||||
return tool_function
|
||||
if oauth_token:
|
||||
headers["Authorization"] = (
|
||||
f"Bearer {oauth_token.get('access_token', '')}"
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth token: {e}")
|
||||
oauth_token = None
|
||||
|
||||
if function_name_filter_list:
|
||||
if not is_string_allowed(
|
||||
tool_spec["name"], function_name_filter_list
|
||||
):
|
||||
# Skip this function
|
||||
continue
|
||||
connection_headers = mcp_server_connection.get("headers", None)
|
||||
if connection_headers and isinstance(connection_headers, dict):
|
||||
for key, value in connection_headers.items():
|
||||
headers[key] = value
|
||||
|
||||
tool_function = make_tool_function(
|
||||
mcp_clients[server_id], tool_spec["name"]
|
||||
# Add user info headers if enabled
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
||||
headers = include_user_info_headers(headers, user)
|
||||
if metadata and metadata.get("chat_id"):
|
||||
headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get(
|
||||
"chat_id"
|
||||
)
|
||||
if metadata and metadata.get("message_id"):
|
||||
headers[FORWARD_SESSION_INFO_HEADER_MESSAGE_ID] = (
|
||||
metadata.get("message_id")
|
||||
)
|
||||
|
||||
mcp_clients[server_id] = MCPClient()
|
||||
await mcp_clients[server_id].connect(
|
||||
url=mcp_server_connection.get("url", ""),
|
||||
headers=headers if headers else None,
|
||||
)
|
||||
|
||||
mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
|
||||
"spec": {
|
||||
**tool_spec,
|
||||
"name": f"{server_id}_{tool_spec['name']}",
|
||||
},
|
||||
"callable": tool_function,
|
||||
"type": "mcp",
|
||||
"client": mcp_clients[server_id],
|
||||
"direct": False,
|
||||
}
|
||||
except Exception as e:
|
||||
log.debug(e)
|
||||
if event_emitter:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:message:error",
|
||||
"data": {
|
||||
"error": {
|
||||
"content": f"Failed to connect to MCP server '{server_id}'"
|
||||
}
|
||||
function_name_filter_list = mcp_server_connection.get(
|
||||
"config", {}
|
||||
).get("function_name_filter_list", "")
|
||||
|
||||
if isinstance(function_name_filter_list, str):
|
||||
function_name_filter_list = function_name_filter_list.split(",")
|
||||
|
||||
tool_specs = await mcp_clients[server_id].list_tool_specs()
|
||||
for tool_spec in tool_specs:
|
||||
|
||||
def make_tool_function(client, function_name):
|
||||
async def tool_function(**kwargs):
|
||||
return await client.call_tool(
|
||||
function_name,
|
||||
function_args=kwargs,
|
||||
)
|
||||
|
||||
return tool_function
|
||||
|
||||
if function_name_filter_list:
|
||||
if not is_string_allowed(
|
||||
tool_spec["name"], function_name_filter_list
|
||||
):
|
||||
# Skip this function
|
||||
continue
|
||||
|
||||
tool_function = make_tool_function(
|
||||
mcp_clients[server_id], tool_spec["name"]
|
||||
)
|
||||
|
||||
mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
|
||||
"spec": {
|
||||
**tool_spec,
|
||||
"name": f"{server_id}_{tool_spec['name']}",
|
||||
},
|
||||
"callable": tool_function,
|
||||
"type": "mcp",
|
||||
"client": mcp_clients[server_id],
|
||||
"direct": False,
|
||||
}
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
log.debug(e)
|
||||
if event_emitter:
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:message:error",
|
||||
"data": {
|
||||
"error": {
|
||||
"content": f"Failed to connect to MCP server '{server_id}'"
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
tools_dict = await get_tools(
|
||||
request,
|
||||
tool_ids,
|
||||
user,
|
||||
{
|
||||
**extra_params,
|
||||
"__model__": models[task_model_id],
|
||||
"__messages__": form_data["messages"],
|
||||
"__files__": metadata.get("files", []),
|
||||
},
|
||||
)
|
||||
tools_dict = await get_tools(
|
||||
request,
|
||||
tool_ids,
|
||||
user,
|
||||
{
|
||||
**extra_params,
|
||||
"__model__": models[task_model_id],
|
||||
"__messages__": form_data["messages"],
|
||||
"__files__": metadata.get("files", []),
|
||||
},
|
||||
)
|
||||
|
||||
if mcp_tools_dict:
|
||||
tools_dict = {**tools_dict, **mcp_tools_dict}
|
||||
if mcp_tools_dict:
|
||||
tools_dict = {**tools_dict, **mcp_tools_dict}
|
||||
|
||||
if direct_tool_servers:
|
||||
for tool_server in direct_tool_servers:
|
||||
tool_specs = tool_server.pop("specs", [])
|
||||
if direct_tool_servers:
|
||||
for tool_server in direct_tool_servers:
|
||||
tool_specs = tool_server.pop("specs", [])
|
||||
|
||||
for tool in tool_specs:
|
||||
tools_dict[tool["name"]] = {
|
||||
"spec": tool,
|
||||
"direct": True,
|
||||
"server": tool_server,
|
||||
}
|
||||
for tool in tool_specs:
|
||||
tools_dict[tool["name"]] = {
|
||||
"spec": tool,
|
||||
"direct": True,
|
||||
"server": tool_server,
|
||||
}
|
||||
|
||||
if mcp_clients:
|
||||
metadata["mcp_clients"] = mcp_clients
|
||||
if mcp_clients:
|
||||
metadata["mcp_clients"] = mcp_clients
|
||||
|
||||
# Inject builtin tools for native function calling based on enabled features and model capability
|
||||
# Check if builtin_tools capability is enabled for this model (defaults to True if not specified)
|
||||
builtin_tools_enabled = (
|
||||
model.get("info", {}).get("meta", {}).get("capabilities") or {}
|
||||
).get("builtin_tools", True)
|
||||
if (
|
||||
metadata.get("params", {}).get("function_calling") == "native"
|
||||
and builtin_tools_enabled
|
||||
):
|
||||
# Add file context to user messages
|
||||
chat_id = metadata.get("chat_id")
|
||||
form_data["messages"] = add_file_context(
|
||||
form_data.get("messages", []), chat_id, user
|
||||
)
|
||||
builtin_tools = get_builtin_tools(
|
||||
request,
|
||||
{
|
||||
**extra_params,
|
||||
"__event_emitter__": event_emitter,
|
||||
"__skill_ids__": [
|
||||
s.id for s in available_skills if s.id not in user_skill_ids
|
||||
],
|
||||
},
|
||||
features,
|
||||
model,
|
||||
)
|
||||
for name, tool_dict in builtin_tools.items():
|
||||
if name not in tools_dict:
|
||||
tools_dict[name] = tool_dict
|
||||
# Inject builtin tools for native function calling based on enabled features and model capability
|
||||
# Check if builtin_tools capability is enabled for this model (defaults to True if not specified)
|
||||
builtin_tools_enabled = (
|
||||
model.get("info", {}).get("meta", {}).get("capabilities") or {}
|
||||
).get("builtin_tools", True)
|
||||
if (
|
||||
metadata.get("params", {}).get("function_calling") == "native"
|
||||
and builtin_tools_enabled
|
||||
):
|
||||
# Add file context to user messages
|
||||
chat_id = metadata.get("chat_id")
|
||||
form_data["messages"] = add_file_context(
|
||||
form_data.get("messages", []), chat_id, user
|
||||
)
|
||||
builtin_tools = get_builtin_tools(
|
||||
request,
|
||||
{
|
||||
**extra_params,
|
||||
"__event_emitter__": event_emitter,
|
||||
"__skill_ids__": [
|
||||
s.id for s in available_skills if s.id not in user_skill_ids
|
||||
],
|
||||
},
|
||||
features,
|
||||
model,
|
||||
)
|
||||
for name, tool_dict in builtin_tools.items():
|
||||
if name not in tools_dict:
|
||||
tools_dict[name] = tool_dict
|
||||
|
||||
if tools_dict:
|
||||
if metadata.get("params", {}).get("function_calling") == "native":
|
||||
# If the function calling is native, then call the tools function calling handler
|
||||
metadata["tools"] = tools_dict
|
||||
form_data["tools"] = [
|
||||
{"type": "function", "function": tool.get("spec", {})}
|
||||
for tool in tools_dict.values()
|
||||
]
|
||||
if tools_dict:
|
||||
if metadata.get("params", {}).get("function_calling") == "native":
|
||||
# If the function calling is native, then call the tools function calling handler
|
||||
metadata["tools"] = tools_dict
|
||||
form_data["tools"] = [
|
||||
{"type": "function", "function": tool.get("spec", {})}
|
||||
for tool in tools_dict.values()
|
||||
]
|
||||
|
||||
else:
|
||||
# If the function calling is not native, then call the tools function calling handler
|
||||
|
||||
Reference in New Issue
Block a user