enh: builtin tools

This commit is contained in:
Tim Baek
2026-01-07 07:00:32 -05:00
parent 60e916d6c0
commit 2789f6a24d
2 changed files with 142 additions and 59 deletions

View File

@@ -142,6 +142,104 @@ DEFAULT_SOLUTION_TAGS = [("<|begin_of_solution|>", "<|end_of_solution|>")]
DEFAULT_CODE_INTERPRETER_TAGS = [("<code_interpreter>", "</code_interpreter>")]
def get_citation_source_from_tool_result(
tool_name: str,
tool_params: dict,
tool_result: str,
tool_id: str = ""
) -> dict:
"""
Parse a tool's result and convert it to a source dict for citation display.
For web_search: extracts title, link, snippet from each search result.
For other tools: wraps the raw result as a generic source.
"""
try:
if tool_name == "web_search":
# Parse JSON array: [{"title": "...", "link": "...", "snippet": "..."}]
results = json.loads(tool_result)
documents = []
metadata = []
for result in results:
title = result.get("title", "")
link = result.get("link", "")
snippet = result.get("snippet", "")
documents.append(f"{title}\n{snippet}")
metadata.append({
"source": link,
"name": title,
"url": link,
})
return {
"source": {"name": "web_search", "id": "web_search"},
"document": documents,
"metadata": metadata,
}
else:
# Fallback for other tools
return {
"source": {"name": tool_name, "id": tool_id or tool_name},
"document": [str(tool_result)],
"metadata": [{"source": tool_id or tool_name, "parameters": tool_params}],
}
except Exception as e:
log.exception(f"Error parsing tool result for {tool_name}: {e}")
return {
"source": {"name": tool_name, "id": tool_id or tool_name},
"document": [str(tool_result)],
"metadata": [{"source": tool_id or tool_name}],
}
def apply_source_context_to_messages(
request: Request,
messages: list,
sources: list,
user_message: str,
) -> list:
"""
Build source context from citation sources and apply to messages.
Uses RAG template to format context for model consumption.
"""
if not sources or not user_message:
return messages
context_string = ""
citation_idx = {}
for source in sources:
for doc, meta in zip(source.get("document", []), source.get("metadata", [])):
src_id = meta.get("source") or source.get("source", {}).get("id") or "N/A"
if src_id not in citation_idx:
citation_idx[src_id] = len(citation_idx) + 1
src_name = source.get("source", {}).get("name")
context_string += (
f'<source id="{citation_idx[src_id]}"'
+ (f' name="{src_name}"' if src_name else "")
+ f">{doc}</source>\n"
)
context_string = context_string.strip()
if not context_string:
return messages
if RAG_SYSTEM_CONTEXT:
return add_or_update_system_message(
rag_template(request.app.state.config.RAG_TEMPLATE, context_string, user_message),
messages,
append=True,
)
else:
return add_or_update_user_message(
rag_template(request.app.state.config.RAG_TEMPLATE, context_string, user_message),
messages,
append=False,
)
def process_tool_result(
request,
tool_function_name,
@@ -1567,6 +1665,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"__event_emitter__": event_emitter,
},
features,
model,
)
for name, tool_dict in builtin_tools.items():
if name not in tools_dict:
@@ -1599,58 +1698,10 @@ async def process_chat_payload(request, form_data, user, metadata, model):
log.exception(e)
# If context is not empty, insert it into the messages
if len(sources) > 0:
context_string = ""
citation_idx_map = {}
for source in sources:
if "document" in source:
for document_text, document_metadata in zip(
source["document"], source["metadata"]
):
source_name = source.get("source", {}).get("name", None)
source_id = (
document_metadata.get("source", None)
or source.get("source", {}).get("id", None)
or "N/A"
)
if source_id not in citation_idx_map:
citation_idx_map[source_id] = len(citation_idx_map) + 1
context_string += (
f'<source id="{citation_idx_map[source_id]}"'
+ (f' name="{source_name}"' if source_name else "")
+ f">{document_text}</source>\n"
)
context_string = context_string.strip()
if prompt is None:
raise Exception("No user message found")
if context_string != "":
if RAG_SYSTEM_CONTEXT:
# Inject into system message for KV prefix caching
form_data["messages"] = add_or_update_system_message(
rag_template(
request.app.state.config.RAG_TEMPLATE,
context_string,
prompt,
),
form_data["messages"],
append=True,
)
else:
# Inject into user message
form_data["messages"] = add_or_update_user_message(
rag_template(
request.app.state.config.RAG_TEMPLATE,
context_string,
prompt,
),
form_data["messages"],
append=False,
)
if sources and prompt:
form_data["messages"] = apply_source_context_to_messages(
request, form_data["messages"], sources, prompt
)
# If there are citations, add them to the data_items
sources = [
@@ -2977,6 +3028,7 @@ async def process_chat_response(
await stream_body_handler(response, form_data)
tool_call_retries = 0
tool_call_sources = [] # Track citation sources from tool results
while (
len(tool_calls) > 0
@@ -3111,6 +3163,19 @@ async def process_chat_response(
)
)
# Extract citation sources from web_search results
if tool_function_name == "web_search" and tool_result:
try:
citation_source = get_citation_source_from_tool_result(
tool_name=tool_function_name,
tool_params=tool_function_params,
tool_result=tool_result,
tool_id=tool.get("tool_id", "") if tool else ""
)
tool_call_sources.append(citation_source)
except Exception as e:
log.exception(f"Error extracting citation source: {e}")
results.append(
{
"tool_call_id": tool_call_id,
@@ -3136,6 +3201,19 @@ async def process_chat_response(
}
)
# Emit citation sources for UI display
for source in tool_call_sources:
await event_emitter({"type": "source", "data": source})
# Apply source context to messages for model
if tool_call_sources:
user_msg = get_last_user_message(form_data["messages"])
if user_msg:
form_data["messages"] = apply_source_context_to_messages(
request, form_data["messages"], tool_call_sources, user_msg
)
tool_call_sources.clear()
await event_emitter(
{
"type": "chat:completion",

View File

@@ -342,15 +342,20 @@ async def get_tools(
def get_builtin_tools(
request: Request, extra_params: dict, features: dict = None
request: Request, extra_params: dict, features: dict = None, model: dict = None
) -> dict[str, dict]:
"""
Get built-in tools for native function calling.
Only returns tools when BOTH the global config is enabled AND the feature is enabled for this chat.
Only returns tools when BOTH the global config is enabled AND the model capability allows it.
"""
tools_dict = {}
builtin_functions = []
features = features or {}
model = model or {}
# Helper to get model capabilities (defaults to True if not specified)
def get_model_capability(name: str, default: bool = True) -> bool:
return model.get("info", {}).get("meta", {}).get("capabilities", {}).get(name, default)
# Time utilities - always available for date calculations
builtin_functions.extend([get_current_timestamp, calculate_timestamp])
@@ -362,18 +367,18 @@ def get_builtin_tools(
if features.get("memory"):
builtin_functions.extend([search_memories, add_memory, replace_memory_content])
# Add web search tools if enabled globally AND for this chat
if getattr(request.app.state.config, "ENABLE_WEB_SEARCH", False) and features.get(
# Add web search tools if enabled globally AND model has web_search capability
if getattr(request.app.state.config, "ENABLE_WEB_SEARCH", False) and get_model_capability(
"web_search"
):
builtin_functions.extend([web_search, fetch_url])
# Add image generation/edit tools if enabled globally AND for this chat
# Add image generation/edit tools if enabled globally AND model has image_generation capability
if getattr(
request.app.state.config, "ENABLE_IMAGE_GENERATION", False
) and features.get("image_generation"):
) and get_model_capability("image_generation"):
builtin_functions.append(generate_image)
if getattr(request.app.state.config, "ENABLE_IMAGE_EDIT", False) and features.get(
if getattr(request.app.state.config, "ENABLE_IMAGE_EDIT", False) and get_model_capability(
"image_generation"
):
builtin_functions.append(edit_image)