mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-04 19:29:27 -05:00
refac
This commit is contained in:
@@ -828,6 +828,43 @@ async def get_terminal_cwd(
|
||||
return None
|
||||
|
||||
|
||||
async def get_terminal_system_prompt(
|
||||
base_url: str,
|
||||
headers: dict,
|
||||
cookies: Optional[dict] = None,
|
||||
) -> Optional[str]:
|
||||
"""Fetch the system prompt from a terminal server.
|
||||
|
||||
Checks ``/api/config`` for the ``system`` feature flag first;
|
||||
only fetches ``/system`` if the flag is present. Returns *None*
|
||||
silently when the server doesn't support the endpoint.
|
||||
"""
|
||||
base = base_url.rstrip('/')
|
||||
try:
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=3),
|
||||
trust_env=True,
|
||||
) as session:
|
||||
# 1. Check feature flag
|
||||
async with session.get(f'{base}/api/config') as resp:
|
||||
if resp.status != 200:
|
||||
return None
|
||||
config = await resp.json()
|
||||
if not config.get('features', {}).get('system'):
|
||||
return None
|
||||
|
||||
# 2. Fetch system prompt
|
||||
async with session.get(
|
||||
f'{base}/system', headers=headers, cookies=cookies or {}
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return data.get('prompt')
|
||||
except Exception as e:
|
||||
log.debug(f'Failed to fetch terminal system prompt: {e}')
|
||||
return None
|
||||
|
||||
|
||||
async def set_terminal_servers(request: Request):
|
||||
"""Load and cache OpenAPI specs from all TERMINAL_SERVER_CONNECTIONS."""
|
||||
connections = request.app.state.config.TERMINAL_SERVER_CONNECTIONS or []
|
||||
@@ -867,6 +904,25 @@ async def set_terminal_servers(request: Request):
|
||||
|
||||
request.app.state.TERMINAL_SERVERS = await get_tool_servers_data(server_configs)
|
||||
|
||||
# Fetch system prompts concurrently (runs at cache time, not per-request)
|
||||
connections_by_id = {c.get('id'): c for c in connections if c.get('id')}
|
||||
|
||||
async def _fetch_system_prompt(server):
|
||||
connection = connections_by_id.get(server.get('id'))
|
||||
if not connection:
|
||||
return
|
||||
headers = {}
|
||||
if connection.get('auth_type', 'bearer') == 'bearer':
|
||||
headers['Authorization'] = f'Bearer {connection.get("key", "")}'
|
||||
prompt = await get_terminal_system_prompt(server['url'], headers)
|
||||
if prompt:
|
||||
server['system_prompt'] = prompt
|
||||
|
||||
await asyncio.gather(
|
||||
*[_fetch_system_prompt(s) for s in request.app.state.TERMINAL_SERVERS],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
if request.app.state.redis is not None:
|
||||
await request.app.state.redis.set('terminal_servers', json.dumps(request.app.state.TERMINAL_SERVERS))
|
||||
|
||||
@@ -894,7 +950,7 @@ async def get_terminal_tools(
|
||||
terminal_id: str,
|
||||
user: UserModel,
|
||||
extra_params: dict,
|
||||
) -> dict[str, dict]:
|
||||
) -> tuple[dict[str, dict], Optional[str]]:
|
||||
"""Resolve tools for a terminal server identified by terminal_id.
|
||||
|
||||
- Finds the connection in TERMINAL_SERVER_CONNECTIONS
|
||||
@@ -941,14 +997,14 @@ async def get_terminal_tools(
|
||||
headers['Authorization'] = f'Bearer {oauth_token.get("access_token", "")}'
|
||||
# auth_type == "none": no Authorization header
|
||||
|
||||
system_prompt = server_data.get('system_prompt')
|
||||
terminal_cwd = await get_terminal_cwd(connection.get('url', ''), headers, cookies)
|
||||
|
||||
tools_dict = {}
|
||||
for spec in specs:
|
||||
function_name = spec['name']
|
||||
|
||||
# Inject CWD into run_command description
|
||||
tool_spec = clean_openai_tool_schema(spec)
|
||||
|
||||
if function_name == 'run_command' and terminal_cwd:
|
||||
tool_spec['description'] = (
|
||||
tool_spec.get('description', '') + f'\n\nThe current working directory is: {terminal_cwd}'
|
||||
@@ -977,7 +1033,7 @@ async def get_terminal_tools(
|
||||
'type': 'terminal',
|
||||
}
|
||||
|
||||
return tools_dict
|
||||
return tools_dict, system_prompt
|
||||
|
||||
|
||||
async def get_tool_server_data(url: str, headers: Optional[dict]) -> Dict[str, Any]:
|
||||
|
||||
Reference in New Issue
Block a user