mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-01 09:49:03 -05:00
feat: experimental mcp support
This commit is contained in:
@@ -87,6 +87,7 @@ from open_webui.utils.filter import (
|
||||
)
|
||||
from open_webui.utils.code_interpreter import execute_code_jupyter
|
||||
from open_webui.utils.payload import apply_system_prompt_to_body
|
||||
from open_webui.utils.mcp.client import MCPClient
|
||||
|
||||
|
||||
from open_webui.config import (
|
||||
@@ -988,14 +989,94 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
# Server side tools
|
||||
tool_ids = metadata.get("tool_ids", None)
|
||||
# Client side tools
|
||||
tool_servers = metadata.get("tool_servers", None)
|
||||
direct_tool_servers = metadata.get("tool_servers", None)
|
||||
|
||||
log.debug(f"{tool_ids=}")
|
||||
log.debug(f"{tool_servers=}")
|
||||
log.debug(f"{direct_tool_servers=}")
|
||||
|
||||
tools_dict = {}
|
||||
|
||||
mcp_clients = []
|
||||
mcp_tools_dict = {}
|
||||
|
||||
if tool_ids:
|
||||
for tool_id in tool_ids:
|
||||
if tool_id.startswith("server:mcp:"):
|
||||
try:
|
||||
server_id = tool_id[len("server:mcp:") :]
|
||||
|
||||
mcp_server_connection = None
|
||||
for (
|
||||
server_connection
|
||||
) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
||||
if (
|
||||
server_connection.get("type", "") == "mcp"
|
||||
and server_connection.get("info", {}).get("id") == server_id
|
||||
):
|
||||
mcp_server_connection = server_connection
|
||||
break
|
||||
|
||||
if not mcp_server_connection:
|
||||
log.error(f"MCP server with id {server_id} not found")
|
||||
continue
|
||||
|
||||
auth_type = mcp_server_connection.get("auth_type", "")
|
||||
|
||||
headers = {}
|
||||
if auth_type == "bearer":
|
||||
headers["Authorization"] = (
|
||||
f"Bearer {mcp_server_connection.get('key', '')}"
|
||||
)
|
||||
elif auth_type == "none":
|
||||
# No authentication
|
||||
pass
|
||||
elif auth_type == "session":
|
||||
headers["Authorization"] = (
|
||||
f"Bearer {request.state.token.credentials}"
|
||||
)
|
||||
elif auth_type == "system_oauth":
|
||||
oauth_token = extra_params.get("__oauth_token__", None)
|
||||
if oauth_token:
|
||||
headers["Authorization"] = (
|
||||
f"Bearer {oauth_token.get('access_token', '')}"
|
||||
)
|
||||
|
||||
mcp_client = MCPClient()
|
||||
await mcp_client.connect(
|
||||
url=mcp_server_connection.get("url", ""),
|
||||
headers=headers if headers else None,
|
||||
)
|
||||
|
||||
tool_specs = await mcp_client.list_tool_specs()
|
||||
for tool_spec in tool_specs:
|
||||
|
||||
def make_tool_function(function_name):
|
||||
async def tool_function(**kwargs):
|
||||
print(
|
||||
f"Calling MCP tool {function_name} with args {kwargs}"
|
||||
)
|
||||
return await mcp_client.call_tool(
|
||||
function_name,
|
||||
function_args=kwargs,
|
||||
)
|
||||
|
||||
return tool_function
|
||||
|
||||
tool_function = make_tool_function(tool_spec["name"])
|
||||
|
||||
mcp_tools_dict[tool_spec["name"]] = {
|
||||
"spec": tool_spec,
|
||||
"callable": tool_function,
|
||||
"type": "mcp",
|
||||
"client": mcp_client,
|
||||
"direct": False,
|
||||
}
|
||||
|
||||
mcp_clients.append(mcp_client)
|
||||
except Exception as e:
|
||||
log.debug(e)
|
||||
continue
|
||||
|
||||
tools_dict = await get_tools(
|
||||
request,
|
||||
tool_ids,
|
||||
@@ -1007,9 +1088,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
"__files__": metadata.get("files", []),
|
||||
},
|
||||
)
|
||||
if mcp_tools_dict:
|
||||
tools_dict = {**tools_dict, **mcp_tools_dict}
|
||||
|
||||
if tool_servers:
|
||||
for tool_server in tool_servers:
|
||||
if direct_tool_servers:
|
||||
for tool_server in direct_tool_servers:
|
||||
tool_specs = tool_server.pop("specs", [])
|
||||
|
||||
for tool in tool_specs:
|
||||
@@ -1019,7 +1102,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
"server": tool_server,
|
||||
}
|
||||
|
||||
if mcp_clients:
|
||||
metadata["mcp_clients"] = mcp_clients
|
||||
|
||||
if tools_dict:
|
||||
log.info(f"tools_dict: {tools_dict}")
|
||||
if metadata.get("params", {}).get("function_calling") == "native":
|
||||
# If the function calling is native, then call the tools function calling handler
|
||||
metadata["tools"] = tools_dict
|
||||
@@ -1027,6 +1114,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
{"type": "function", "function": tool.get("spec", {})}
|
||||
for tool in tools_dict.values()
|
||||
]
|
||||
|
||||
else:
|
||||
# If the function calling is not native, then call the tools function calling handler
|
||||
try:
|
||||
@@ -2330,6 +2418,8 @@ async def process_chat_response(
|
||||
results = []
|
||||
|
||||
for tool_call in response_tool_calls:
|
||||
|
||||
print("tool_call", tool_call)
|
||||
tool_call_id = tool_call.get("id", "")
|
||||
tool_name = tool_call.get("function", {}).get("name", "")
|
||||
tool_args = tool_call.get("function", {}).get("arguments", "{}")
|
||||
@@ -2397,9 +2487,14 @@ async def process_chat_response(
|
||||
|
||||
else:
|
||||
tool_function = tool["callable"]
|
||||
|
||||
print("tool_name", tool_name)
|
||||
print("tool_function", tool_function)
|
||||
print("tool_function_params", tool_function_params)
|
||||
tool_result = await tool_function(
|
||||
**tool_function_params
|
||||
)
|
||||
print("tool_result", tool_result)
|
||||
|
||||
except Exception as e:
|
||||
tool_result = str(e)
|
||||
|
||||
Reference in New Issue
Block a user