feat: experimental mcp support

This commit is contained in:
Timothy Jaeryang Baek
2025-09-23 02:03:26 -04:00
parent aeb5288a3c
commit 777e81f7a8
10 changed files with 417 additions and 105 deletions

View File

@@ -87,6 +87,7 @@ from open_webui.utils.filter import (
)
from open_webui.utils.code_interpreter import execute_code_jupyter
from open_webui.utils.payload import apply_system_prompt_to_body
from open_webui.utils.mcp.client import MCPClient
from open_webui.config import (
@@ -988,14 +989,94 @@ async def process_chat_payload(request, form_data, user, metadata, model):
# Server side tools
tool_ids = metadata.get("tool_ids", None)
# Client side tools
tool_servers = metadata.get("tool_servers", None)
direct_tool_servers = metadata.get("tool_servers", None)
log.debug(f"{tool_ids=}")
log.debug(f"{tool_servers=}")
log.debug(f"{direct_tool_servers=}")
tools_dict = {}
mcp_clients = []
mcp_tools_dict = {}
if tool_ids:
for tool_id in tool_ids:
if tool_id.startswith("server:mcp:"):
try:
server_id = tool_id[len("server:mcp:") :]
mcp_server_connection = None
for (
server_connection
) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
if (
server_connection.get("type", "") == "mcp"
and server_connection.get("info", {}).get("id") == server_id
):
mcp_server_connection = server_connection
break
if not mcp_server_connection:
log.error(f"MCP server with id {server_id} not found")
continue
auth_type = mcp_server_connection.get("auth_type", "")
headers = {}
if auth_type == "bearer":
headers["Authorization"] = (
f"Bearer {mcp_server_connection.get('key', '')}"
)
elif auth_type == "none":
# No authentication
pass
elif auth_type == "session":
headers["Authorization"] = (
f"Bearer {request.state.token.credentials}"
)
elif auth_type == "system_oauth":
oauth_token = extra_params.get("__oauth_token__", None)
if oauth_token:
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
)
mcp_client = MCPClient()
await mcp_client.connect(
url=mcp_server_connection.get("url", ""),
headers=headers if headers else None,
)
tool_specs = await mcp_client.list_tool_specs()
for tool_spec in tool_specs:
def make_tool_function(function_name):
async def tool_function(**kwargs):
print(
f"Calling MCP tool {function_name} with args {kwargs}"
)
return await mcp_client.call_tool(
function_name,
function_args=kwargs,
)
return tool_function
tool_function = make_tool_function(tool_spec["name"])
mcp_tools_dict[tool_spec["name"]] = {
"spec": tool_spec,
"callable": tool_function,
"type": "mcp",
"client": mcp_client,
"direct": False,
}
mcp_clients.append(mcp_client)
except Exception as e:
log.debug(e)
continue
tools_dict = await get_tools(
request,
tool_ids,
@@ -1007,9 +1088,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"__files__": metadata.get("files", []),
},
)
if mcp_tools_dict:
tools_dict = {**tools_dict, **mcp_tools_dict}
if tool_servers:
for tool_server in tool_servers:
if direct_tool_servers:
for tool_server in direct_tool_servers:
tool_specs = tool_server.pop("specs", [])
for tool in tool_specs:
@@ -1019,7 +1102,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"server": tool_server,
}
if mcp_clients:
metadata["mcp_clients"] = mcp_clients
if tools_dict:
log.info(f"tools_dict: {tools_dict}")
if metadata.get("params", {}).get("function_calling") == "native":
# If the function calling is native, then call the tools function calling handler
metadata["tools"] = tools_dict
@@ -1027,6 +1114,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
{"type": "function", "function": tool.get("spec", {})}
for tool in tools_dict.values()
]
else:
# If the function calling is not native, then call the tools function calling handler
try:
@@ -2330,6 +2418,8 @@ async def process_chat_response(
results = []
for tool_call in response_tool_calls:
print("tool_call", tool_call)
tool_call_id = tool_call.get("id", "")
tool_name = tool_call.get("function", {}).get("name", "")
tool_args = tool_call.get("function", {}).get("arguments", "{}")
@@ -2397,9 +2487,14 @@ async def process_chat_response(
else:
tool_function = tool["callable"]
print("tool_name", tool_name)
print("tool_function", tool_function)
print("tool_function_params", tool_function_params)
tool_result = await tool_function(
**tool_function_params
)
print("tool_result", tool_result)
except Exception as e:
tool_result = str(e)