fix: payload tools handling

This commit is contained in:
Timothy Jaeryang Baek
2026-02-22 17:58:59 -06:00
parent 4b3543d3c0
commit 8f0658e64f

View File

@@ -2220,6 +2220,10 @@ async def process_chat_payload(request, form_data, user, metadata, model):
tool_ids = form_data.pop("tool_ids", None) tool_ids = form_data.pop("tool_ids", None)
files = form_data.pop("files", None) files = form_data.pop("files", None)
# Caller-provided OpenAI-style tools take precedence over server-side
# tool resolution (tool_ids, MCP servers, builtin tools).
payload_tools = form_data.get("tools", None)
# Skills # Skills
user_skill_ids = set(form_data.pop("skill_ids", None) or []) user_skill_ids = set(form_data.pop("skill_ids", None) or [])
model_skill_ids = set(model.get("info", {}).get("meta", {}).get("skillIds", [])) model_skill_ids = set(model.get("info", {}).get("meta", {}).get("skillIds", []))
@@ -2291,228 +2295,236 @@ async def process_chat_payload(request, form_data, user, metadata, model):
} }
form_data["metadata"] = metadata form_data["metadata"] = metadata
# Server side tools # When the caller provides an explicit OpenAI-style `tools` array in the
tool_ids = metadata.get("tool_ids", None) # request body, skip all server-side tool resolution and pass the caller's
# Client side tools # tools through to the model unchanged.
direct_tool_servers = metadata.get("tool_servers", None) if payload_tools:
log.debug(
"Caller provided explicit tools — skipping server-side tool resolution"
)
else:
# Server side tools
tool_ids = metadata.get("tool_ids", None)
# Client side tools
direct_tool_servers = metadata.get("tool_servers", None)
log.debug(f"{tool_ids=}") log.debug(f"{tool_ids=}")
log.debug(f"{direct_tool_servers=}") log.debug(f"{direct_tool_servers=}")
tools_dict = {} tools_dict = {}
mcp_clients = {} mcp_clients = {}
mcp_tools_dict = {} mcp_tools_dict = {}
if tool_ids: if tool_ids:
for tool_id in tool_ids: for tool_id in tool_ids:
if tool_id.startswith("server:mcp:"): if tool_id.startswith("server:mcp:"):
try: try:
server_id = tool_id[len("server:mcp:") :] server_id = tool_id[len("server:mcp:") :]
mcp_server_connection = None mcp_server_connection = None
for ( for (
server_connection server_connection
) in request.app.state.config.TOOL_SERVER_CONNECTIONS: ) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
if ( if (
server_connection.get("type", "") == "mcp" server_connection.get("type", "") == "mcp"
and server_connection.get("info", {}).get("id") == server_id and server_connection.get("info", {}).get("id") == server_id
): ):
mcp_server_connection = server_connection mcp_server_connection = server_connection
break break
if not mcp_server_connection: if not mcp_server_connection:
log.error(f"MCP server with id {server_id} not found") log.error(f"MCP server with id {server_id} not found")
continue continue
# Check access control for MCP server # Check access control for MCP server
if not has_tool_server_access(user, mcp_server_connection): if not has_tool_server_access(user, mcp_server_connection):
log.warning( log.warning(
f"Access denied to MCP server {server_id} for user {user.id}" f"Access denied to MCP server {server_id} for user {user.id}"
) )
continue continue
auth_type = mcp_server_connection.get("auth_type", "") auth_type = mcp_server_connection.get("auth_type", "")
headers = {} headers = {}
if auth_type == "bearer": if auth_type == "bearer":
headers["Authorization"] = (
f"Bearer {mcp_server_connection.get('key', '')}"
)
elif auth_type == "none":
# No authentication
pass
elif auth_type == "session":
headers["Authorization"] = (
f"Bearer {request.state.token.credentials}"
)
elif auth_type == "system_oauth":
oauth_token = extra_params.get("__oauth_token__", None)
if oauth_token:
headers["Authorization"] = ( headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}" f"Bearer {mcp_server_connection.get('key', '')}"
) )
elif auth_type == "oauth_2.1": elif auth_type == "none":
try: # No authentication
splits = server_id.split(":") pass
server_id = splits[-1] if len(splits) > 1 else server_id elif auth_type == "session":
headers["Authorization"] = (
oauth_token = await request.app.state.oauth_client_manager.get_oauth_token( f"Bearer {request.state.token.credentials}"
user.id, f"mcp:{server_id}"
) )
elif auth_type == "system_oauth":
oauth_token = extra_params.get("__oauth_token__", None)
if oauth_token: if oauth_token:
headers["Authorization"] = ( headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}" f"Bearer {oauth_token.get('access_token', '')}"
) )
except Exception as e: elif auth_type == "oauth_2.1":
log.error(f"Error getting OAuth token: {e}") try:
oauth_token = None splits = server_id.split(":")
server_id = splits[-1] if len(splits) > 1 else server_id
connection_headers = mcp_server_connection.get("headers", None) oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
if connection_headers and isinstance(connection_headers, dict): user.id, f"mcp:{server_id}"
for key, value in connection_headers.items():
headers[key] = value
# Add user info headers if enabled
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
if metadata and metadata.get("chat_id"):
headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get(
"chat_id"
)
if metadata and metadata.get("message_id"):
headers[FORWARD_SESSION_INFO_HEADER_MESSAGE_ID] = (
metadata.get("message_id")
)
mcp_clients[server_id] = MCPClient()
await mcp_clients[server_id].connect(
url=mcp_server_connection.get("url", ""),
headers=headers if headers else None,
)
function_name_filter_list = mcp_server_connection.get(
"config", {}
).get("function_name_filter_list", "")
if isinstance(function_name_filter_list, str):
function_name_filter_list = function_name_filter_list.split(",")
tool_specs = await mcp_clients[server_id].list_tool_specs()
for tool_spec in tool_specs:
def make_tool_function(client, function_name):
async def tool_function(**kwargs):
return await client.call_tool(
function_name,
function_args=kwargs,
) )
return tool_function if oauth_token:
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
)
except Exception as e:
log.error(f"Error getting OAuth token: {e}")
oauth_token = None
if function_name_filter_list: connection_headers = mcp_server_connection.get("headers", None)
if not is_string_allowed( if connection_headers and isinstance(connection_headers, dict):
tool_spec["name"], function_name_filter_list for key, value in connection_headers.items():
): headers[key] = value
# Skip this function
continue
tool_function = make_tool_function( # Add user info headers if enabled
mcp_clients[server_id], tool_spec["name"] if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
if metadata and metadata.get("chat_id"):
headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get(
"chat_id"
)
if metadata and metadata.get("message_id"):
headers[FORWARD_SESSION_INFO_HEADER_MESSAGE_ID] = (
metadata.get("message_id")
)
mcp_clients[server_id] = MCPClient()
await mcp_clients[server_id].connect(
url=mcp_server_connection.get("url", ""),
headers=headers if headers else None,
) )
mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = { function_name_filter_list = mcp_server_connection.get(
"spec": { "config", {}
**tool_spec, ).get("function_name_filter_list", "")
"name": f"{server_id}_{tool_spec['name']}",
}, if isinstance(function_name_filter_list, str):
"callable": tool_function, function_name_filter_list = function_name_filter_list.split(",")
"type": "mcp",
"client": mcp_clients[server_id], tool_specs = await mcp_clients[server_id].list_tool_specs()
"direct": False, for tool_spec in tool_specs:
}
except Exception as e: def make_tool_function(client, function_name):
log.debug(e) async def tool_function(**kwargs):
if event_emitter: return await client.call_tool(
await event_emitter( function_name,
{ function_args=kwargs,
"type": "chat:message:error", )
"data": {
"error": { return tool_function
"content": f"Failed to connect to MCP server '{server_id}'"
} if function_name_filter_list:
if not is_string_allowed(
tool_spec["name"], function_name_filter_list
):
# Skip this function
continue
tool_function = make_tool_function(
mcp_clients[server_id], tool_spec["name"]
)
mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
"spec": {
**tool_spec,
"name": f"{server_id}_{tool_spec['name']}",
}, },
"callable": tool_function,
"type": "mcp",
"client": mcp_clients[server_id],
"direct": False,
} }
) except Exception as e:
continue log.debug(e)
if event_emitter:
await event_emitter(
{
"type": "chat:message:error",
"data": {
"error": {
"content": f"Failed to connect to MCP server '{server_id}'"
}
},
}
)
continue
tools_dict = await get_tools( tools_dict = await get_tools(
request, request,
tool_ids, tool_ids,
user, user,
{ {
**extra_params, **extra_params,
"__model__": models[task_model_id], "__model__": models[task_model_id],
"__messages__": form_data["messages"], "__messages__": form_data["messages"],
"__files__": metadata.get("files", []), "__files__": metadata.get("files", []),
}, },
) )
if mcp_tools_dict: if mcp_tools_dict:
tools_dict = {**tools_dict, **mcp_tools_dict} tools_dict = {**tools_dict, **mcp_tools_dict}
if direct_tool_servers: if direct_tool_servers:
for tool_server in direct_tool_servers: for tool_server in direct_tool_servers:
tool_specs = tool_server.pop("specs", []) tool_specs = tool_server.pop("specs", [])
for tool in tool_specs: for tool in tool_specs:
tools_dict[tool["name"]] = { tools_dict[tool["name"]] = {
"spec": tool, "spec": tool,
"direct": True, "direct": True,
"server": tool_server, "server": tool_server,
} }
if mcp_clients: if mcp_clients:
metadata["mcp_clients"] = mcp_clients metadata["mcp_clients"] = mcp_clients
# Inject builtin tools for native function calling based on enabled features and model capability # Inject builtin tools for native function calling based on enabled features and model capability
# Check if builtin_tools capability is enabled for this model (defaults to True if not specified) # Check if builtin_tools capability is enabled for this model (defaults to True if not specified)
builtin_tools_enabled = ( builtin_tools_enabled = (
model.get("info", {}).get("meta", {}).get("capabilities") or {} model.get("info", {}).get("meta", {}).get("capabilities") or {}
).get("builtin_tools", True) ).get("builtin_tools", True)
if ( if (
metadata.get("params", {}).get("function_calling") == "native" metadata.get("params", {}).get("function_calling") == "native"
and builtin_tools_enabled and builtin_tools_enabled
): ):
# Add file context to user messages # Add file context to user messages
chat_id = metadata.get("chat_id") chat_id = metadata.get("chat_id")
form_data["messages"] = add_file_context( form_data["messages"] = add_file_context(
form_data.get("messages", []), chat_id, user form_data.get("messages", []), chat_id, user
) )
builtin_tools = get_builtin_tools( builtin_tools = get_builtin_tools(
request, request,
{ {
**extra_params, **extra_params,
"__event_emitter__": event_emitter, "__event_emitter__": event_emitter,
"__skill_ids__": [ "__skill_ids__": [
s.id for s in available_skills if s.id not in user_skill_ids s.id for s in available_skills if s.id not in user_skill_ids
], ],
}, },
features, features,
model, model,
) )
for name, tool_dict in builtin_tools.items(): for name, tool_dict in builtin_tools.items():
if name not in tools_dict: if name not in tools_dict:
tools_dict[name] = tool_dict tools_dict[name] = tool_dict
if tools_dict: if tools_dict:
if metadata.get("params", {}).get("function_calling") == "native": if metadata.get("params", {}).get("function_calling") == "native":
# If the function calling is native, then call the tools function calling handler # If the function calling is native, then call the tools function calling handler
metadata["tools"] = tools_dict metadata["tools"] = tools_dict
form_data["tools"] = [ form_data["tools"] = [
{"type": "function", "function": tool.get("spec", {})} {"type": "function", "function": tool.get("spec", {})}
for tool in tools_dict.values() for tool in tools_dict.values()
] ]
else: else:
# If the function calling is not native, then call the tools function calling handler # If the function calling is not native, then call the tools function calling handler