fix: payload tools handling

This commit is contained in:
Timothy Jaeryang Baek
2026-02-22 17:58:59 -06:00
parent 4b3543d3c0
commit 8f0658e64f

View File

@@ -2220,6 +2220,10 @@ async def process_chat_payload(request, form_data, user, metadata, model):
tool_ids = form_data.pop("tool_ids", None)
files = form_data.pop("files", None)
# Caller-provided OpenAI-style tools take precedence over server-side
# tool resolution (tool_ids, MCP servers, builtin tools).
payload_tools = form_data.get("tools", None)
# Skills
user_skill_ids = set(form_data.pop("skill_ids", None) or [])
model_skill_ids = set(model.get("info", {}).get("meta", {}).get("skillIds", []))
@@ -2291,228 +2295,236 @@ async def process_chat_payload(request, form_data, user, metadata, model):
}
form_data["metadata"] = metadata
# Server side tools
tool_ids = metadata.get("tool_ids", None)
# Client side tools
direct_tool_servers = metadata.get("tool_servers", None)
# When the caller provides an explicit OpenAI-style `tools` array in the
# request body, skip all server-side tool resolution and pass the caller's
# tools through to the model unchanged.
if payload_tools:
log.debug(
"Caller provided explicit tools — skipping server-side tool resolution"
)
else:
# Server side tools
tool_ids = metadata.get("tool_ids", None)
# Client side tools
direct_tool_servers = metadata.get("tool_servers", None)
log.debug(f"{tool_ids=}")
log.debug(f"{direct_tool_servers=}")
log.debug(f"{tool_ids=}")
log.debug(f"{direct_tool_servers=}")
tools_dict = {}
tools_dict = {}
mcp_clients = {}
mcp_tools_dict = {}
mcp_clients = {}
mcp_tools_dict = {}
if tool_ids:
for tool_id in tool_ids:
if tool_id.startswith("server:mcp:"):
try:
server_id = tool_id[len("server:mcp:") :]
if tool_ids:
for tool_id in tool_ids:
if tool_id.startswith("server:mcp:"):
try:
server_id = tool_id[len("server:mcp:") :]
mcp_server_connection = None
for (
server_connection
) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
if (
server_connection.get("type", "") == "mcp"
and server_connection.get("info", {}).get("id") == server_id
):
mcp_server_connection = server_connection
break
mcp_server_connection = None
for (
server_connection
) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
if (
server_connection.get("type", "") == "mcp"
and server_connection.get("info", {}).get("id") == server_id
):
mcp_server_connection = server_connection
break
if not mcp_server_connection:
log.error(f"MCP server with id {server_id} not found")
continue
if not mcp_server_connection:
log.error(f"MCP server with id {server_id} not found")
continue
# Check access control for MCP server
if not has_tool_server_access(user, mcp_server_connection):
log.warning(
f"Access denied to MCP server {server_id} for user {user.id}"
)
continue
# Check access control for MCP server
if not has_tool_server_access(user, mcp_server_connection):
log.warning(
f"Access denied to MCP server {server_id} for user {user.id}"
)
continue
auth_type = mcp_server_connection.get("auth_type", "")
headers = {}
if auth_type == "bearer":
headers["Authorization"] = (
f"Bearer {mcp_server_connection.get('key', '')}"
)
elif auth_type == "none":
# No authentication
pass
elif auth_type == "session":
headers["Authorization"] = (
f"Bearer {request.state.token.credentials}"
)
elif auth_type == "system_oauth":
oauth_token = extra_params.get("__oauth_token__", None)
if oauth_token:
auth_type = mcp_server_connection.get("auth_type", "")
headers = {}
if auth_type == "bearer":
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
f"Bearer {mcp_server_connection.get('key', '')}"
)
elif auth_type == "oauth_2.1":
try:
splits = server_id.split(":")
server_id = splits[-1] if len(splits) > 1 else server_id
oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
user.id, f"mcp:{server_id}"
elif auth_type == "none":
# No authentication
pass
elif auth_type == "session":
headers["Authorization"] = (
f"Bearer {request.state.token.credentials}"
)
elif auth_type == "system_oauth":
oauth_token = extra_params.get("__oauth_token__", None)
if oauth_token:
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
)
except Exception as e:
log.error(f"Error getting OAuth token: {e}")
oauth_token = None
elif auth_type == "oauth_2.1":
try:
splits = server_id.split(":")
server_id = splits[-1] if len(splits) > 1 else server_id
connection_headers = mcp_server_connection.get("headers", None)
if connection_headers and isinstance(connection_headers, dict):
for key, value in connection_headers.items():
headers[key] = value
# Add user info headers if enabled
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
if metadata and metadata.get("chat_id"):
headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get(
"chat_id"
)
if metadata and metadata.get("message_id"):
headers[FORWARD_SESSION_INFO_HEADER_MESSAGE_ID] = (
metadata.get("message_id")
)
mcp_clients[server_id] = MCPClient()
await mcp_clients[server_id].connect(
url=mcp_server_connection.get("url", ""),
headers=headers if headers else None,
)
function_name_filter_list = mcp_server_connection.get(
"config", {}
).get("function_name_filter_list", "")
if isinstance(function_name_filter_list, str):
function_name_filter_list = function_name_filter_list.split(",")
tool_specs = await mcp_clients[server_id].list_tool_specs()
for tool_spec in tool_specs:
def make_tool_function(client, function_name):
async def tool_function(**kwargs):
return await client.call_tool(
function_name,
function_args=kwargs,
oauth_token = await request.app.state.oauth_client_manager.get_oauth_token(
user.id, f"mcp:{server_id}"
)
return tool_function
if oauth_token:
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
)
except Exception as e:
log.error(f"Error getting OAuth token: {e}")
oauth_token = None
if function_name_filter_list:
if not is_string_allowed(
tool_spec["name"], function_name_filter_list
):
# Skip this function
continue
connection_headers = mcp_server_connection.get("headers", None)
if connection_headers and isinstance(connection_headers, dict):
for key, value in connection_headers.items():
headers[key] = value
tool_function = make_tool_function(
mcp_clients[server_id], tool_spec["name"]
# Add user info headers if enabled
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
if metadata and metadata.get("chat_id"):
headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get(
"chat_id"
)
if metadata and metadata.get("message_id"):
headers[FORWARD_SESSION_INFO_HEADER_MESSAGE_ID] = (
metadata.get("message_id")
)
mcp_clients[server_id] = MCPClient()
await mcp_clients[server_id].connect(
url=mcp_server_connection.get("url", ""),
headers=headers if headers else None,
)
mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
"spec": {
**tool_spec,
"name": f"{server_id}_{tool_spec['name']}",
},
"callable": tool_function,
"type": "mcp",
"client": mcp_clients[server_id],
"direct": False,
}
except Exception as e:
log.debug(e)
if event_emitter:
await event_emitter(
{
"type": "chat:message:error",
"data": {
"error": {
"content": f"Failed to connect to MCP server '{server_id}'"
}
function_name_filter_list = mcp_server_connection.get(
"config", {}
).get("function_name_filter_list", "")
if isinstance(function_name_filter_list, str):
function_name_filter_list = function_name_filter_list.split(",")
tool_specs = await mcp_clients[server_id].list_tool_specs()
for tool_spec in tool_specs:
def make_tool_function(client, function_name):
async def tool_function(**kwargs):
return await client.call_tool(
function_name,
function_args=kwargs,
)
return tool_function
if function_name_filter_list:
if not is_string_allowed(
tool_spec["name"], function_name_filter_list
):
# Skip this function
continue
tool_function = make_tool_function(
mcp_clients[server_id], tool_spec["name"]
)
mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = {
"spec": {
**tool_spec,
"name": f"{server_id}_{tool_spec['name']}",
},
"callable": tool_function,
"type": "mcp",
"client": mcp_clients[server_id],
"direct": False,
}
)
continue
except Exception as e:
log.debug(e)
if event_emitter:
await event_emitter(
{
"type": "chat:message:error",
"data": {
"error": {
"content": f"Failed to connect to MCP server '{server_id}'"
}
},
}
)
continue
tools_dict = await get_tools(
request,
tool_ids,
user,
{
**extra_params,
"__model__": models[task_model_id],
"__messages__": form_data["messages"],
"__files__": metadata.get("files", []),
},
)
tools_dict = await get_tools(
request,
tool_ids,
user,
{
**extra_params,
"__model__": models[task_model_id],
"__messages__": form_data["messages"],
"__files__": metadata.get("files", []),
},
)
if mcp_tools_dict:
tools_dict = {**tools_dict, **mcp_tools_dict}
if mcp_tools_dict:
tools_dict = {**tools_dict, **mcp_tools_dict}
if direct_tool_servers:
for tool_server in direct_tool_servers:
tool_specs = tool_server.pop("specs", [])
if direct_tool_servers:
for tool_server in direct_tool_servers:
tool_specs = tool_server.pop("specs", [])
for tool in tool_specs:
tools_dict[tool["name"]] = {
"spec": tool,
"direct": True,
"server": tool_server,
}
for tool in tool_specs:
tools_dict[tool["name"]] = {
"spec": tool,
"direct": True,
"server": tool_server,
}
if mcp_clients:
metadata["mcp_clients"] = mcp_clients
if mcp_clients:
metadata["mcp_clients"] = mcp_clients
# Inject builtin tools for native function calling based on enabled features and model capability
# Check if builtin_tools capability is enabled for this model (defaults to True if not specified)
builtin_tools_enabled = (
model.get("info", {}).get("meta", {}).get("capabilities") or {}
).get("builtin_tools", True)
if (
metadata.get("params", {}).get("function_calling") == "native"
and builtin_tools_enabled
):
# Add file context to user messages
chat_id = metadata.get("chat_id")
form_data["messages"] = add_file_context(
form_data.get("messages", []), chat_id, user
)
builtin_tools = get_builtin_tools(
request,
{
**extra_params,
"__event_emitter__": event_emitter,
"__skill_ids__": [
s.id for s in available_skills if s.id not in user_skill_ids
],
},
features,
model,
)
for name, tool_dict in builtin_tools.items():
if name not in tools_dict:
tools_dict[name] = tool_dict
# Inject builtin tools for native function calling based on enabled features and model capability
# Check if builtin_tools capability is enabled for this model (defaults to True if not specified)
builtin_tools_enabled = (
model.get("info", {}).get("meta", {}).get("capabilities") or {}
).get("builtin_tools", True)
if (
metadata.get("params", {}).get("function_calling") == "native"
and builtin_tools_enabled
):
# Add file context to user messages
chat_id = metadata.get("chat_id")
form_data["messages"] = add_file_context(
form_data.get("messages", []), chat_id, user
)
builtin_tools = get_builtin_tools(
request,
{
**extra_params,
"__event_emitter__": event_emitter,
"__skill_ids__": [
s.id for s in available_skills if s.id not in user_skill_ids
],
},
features,
model,
)
for name, tool_dict in builtin_tools.items():
if name not in tools_dict:
tools_dict[name] = tool_dict
if tools_dict:
if metadata.get("params", {}).get("function_calling") == "native":
# If the function calling is native, then call the tools function calling handler
metadata["tools"] = tools_dict
form_data["tools"] = [
{"type": "function", "function": tool.get("spec", {})}
for tool in tools_dict.values()
]
if tools_dict:
if metadata.get("params", {}).get("function_calling") == "native":
# If the function calling is native, then call the tools function calling handler
metadata["tools"] = tools_dict
form_data["tools"] = [
{"type": "function", "function": tool.get("spec", {})}
for tool in tools_dict.values()
]
else:
# If the function calling is not native, then call the tools function calling handler