This commit is contained in:
Timothy Jaeryang Baek
2026-03-22 22:10:04 -05:00
parent ebb7ce2092
commit 6a9d67b5bb
3 changed files with 102 additions and 6 deletions

View File

@@ -828,6 +828,43 @@ async def get_terminal_cwd(
return None
async def get_terminal_system_prompt(
base_url: str,
headers: dict,
cookies: Optional[dict] = None,
) -> Optional[str]:
"""Fetch the system prompt from a terminal server.
Checks ``/api/config`` for the ``system`` feature flag first;
only fetches ``/system`` if the flag is present. Returns *None*
silently when the server doesn't support the endpoint.
"""
base = base_url.rstrip('/')
try:
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=3),
trust_env=True,
) as session:
# 1. Check feature flag
async with session.get(f'{base}/api/config') as resp:
if resp.status != 200:
return None
config = await resp.json()
if not config.get('features', {}).get('system'):
return None
# 2. Fetch system prompt
async with session.get(
f'{base}/system', headers=headers, cookies=cookies or {}
) as resp:
if resp.status == 200:
data = await resp.json()
return data.get('prompt')
except Exception as e:
log.debug(f'Failed to fetch terminal system prompt: {e}')
return None
async def set_terminal_servers(request: Request):
"""Load and cache OpenAPI specs from all TERMINAL_SERVER_CONNECTIONS."""
connections = request.app.state.config.TERMINAL_SERVER_CONNECTIONS or []
@@ -867,6 +904,25 @@ async def set_terminal_servers(request: Request):
request.app.state.TERMINAL_SERVERS = await get_tool_servers_data(server_configs)
# Fetch system prompts concurrently (runs at cache time, not per-request)
connections_by_id = {c.get('id'): c for c in connections if c.get('id')}
async def _fetch_system_prompt(server):
connection = connections_by_id.get(server.get('id'))
if not connection:
return
headers = {}
if connection.get('auth_type', 'bearer') == 'bearer':
headers['Authorization'] = f'Bearer {connection.get("key", "")}'
prompt = await get_terminal_system_prompt(server['url'], headers)
if prompt:
server['system_prompt'] = prompt
await asyncio.gather(
*[_fetch_system_prompt(s) for s in request.app.state.TERMINAL_SERVERS],
return_exceptions=True,
)
if request.app.state.redis is not None:
await request.app.state.redis.set('terminal_servers', json.dumps(request.app.state.TERMINAL_SERVERS))
@@ -894,7 +950,7 @@ async def get_terminal_tools(
terminal_id: str,
user: UserModel,
extra_params: dict,
) -> dict[str, dict]:
) -> tuple[dict[str, dict], Optional[str]]:
"""Resolve tools for a terminal server identified by terminal_id.
- Finds the connection in TERMINAL_SERVER_CONNECTIONS
@@ -941,14 +997,14 @@ async def get_terminal_tools(
headers['Authorization'] = f'Bearer {oauth_token.get("access_token", "")}'
# auth_type == "none": no Authorization header
system_prompt = server_data.get('system_prompt')
terminal_cwd = await get_terminal_cwd(connection.get('url', ''), headers, cookies)
tools_dict = {}
for spec in specs:
function_name = spec['name']
# Inject CWD into run_command description
tool_spec = clean_openai_tool_schema(spec)
if function_name == 'run_command' and terminal_cwd:
tool_spec['description'] = (
tool_spec.get('description', '') + f'\n\nThe current working directory is: {terminal_cwd}'
@@ -977,7 +1033,7 @@ async def get_terminal_tools(
'type': 'terminal',
}
return tools_dict
return tools_dict, system_prompt
async def get_tool_server_data(url: str, headers: Optional[dict]) -> Dict[str, Any]: