diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 162857ef43..3675a9846f 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -142,6 +142,104 @@ DEFAULT_SOLUTION_TAGS = [("<|begin_of_solution|>", "<|end_of_solution|>")]
 DEFAULT_CODE_INTERPRETER_TAGS = [("<code_interpreter>", "</code_interpreter>")]
 
 
+def get_citation_source_from_tool_result(
+    tool_name: str,
+    tool_params: dict,
+    tool_result: str,
+    tool_id: str = ""
+) -> dict:
+    """
+    Parse a tool's result and convert it to a source dict for citation display.
+
+    For web_search: extracts title, link, snippet from each search result.
+    For other tools: wraps the raw result as a generic source.
+    """
+    try:
+        if tool_name == "web_search":
+            # Parse JSON array: [{"title": "...", "link": "...", "snippet": "..."}]
+            results = json.loads(tool_result)
+            documents = []
+            metadata = []
+
+            for result in results:
+                title = result.get("title", "")
+                link = result.get("link", "")
+                snippet = result.get("snippet", "")
+
+                documents.append(f"{title}\n{snippet}")
+                metadata.append({
+                    "source": link,
+                    "name": title,
+                    "url": link,
+                })
+
+            return {
+                "source": {"name": "web_search", "id": "web_search"},
+                "document": documents,
+                "metadata": metadata,
+            }
+        else:
+            # Fallback for other tools
+            return {
+                "source": {"name": tool_name, "id": tool_id or tool_name},
+                "document": [str(tool_result)],
+                "metadata": [{"source": tool_id or tool_name, "parameters": tool_params}],
+            }
+    except Exception as e:
+        log.exception(f"Error parsing tool result for {tool_name}: {e}")
+        return {
+            "source": {"name": tool_name, "id": tool_id or tool_name},
+            "document": [str(tool_result)],
+            "metadata": [{"source": tool_id or tool_name}],
+        }
+
+
+def apply_source_context_to_messages(
+    request: Request,
+    messages: list,
+    sources: list,
+    user_message: str,
+) -> list:
+    """
+    Build source context from citation sources and apply to messages.
+    Uses RAG template to format context for model consumption.
+    """
+    if not sources or not user_message:
+        return messages
+
+    context_string = ""
+    citation_idx = {}
+
+    for source in sources:
+        for doc, meta in zip(source.get("document", []), source.get("metadata", [])):
+            src_id = meta.get("source") or source.get("source", {}).get("id") or "N/A"
+            if src_id not in citation_idx:
+                citation_idx[src_id] = len(citation_idx) + 1
+            src_name = source.get("source", {}).get("name")
+            context_string += (
+                f'<source id="{citation_idx[src_id]}"'
+                + (f' name="{src_name}"' if src_name else "")
+                + f">{doc}</source>\n"
+            )
+
+    context_string = context_string.strip()
+    if not context_string:
+        return messages
+
+    if RAG_SYSTEM_CONTEXT:
+        return add_or_update_system_message(
+            rag_template(request.app.state.config.RAG_TEMPLATE, context_string, user_message),
+            messages,
+            append=True,
+        )
+    else:
+        return add_or_update_user_message(
+            rag_template(request.app.state.config.RAG_TEMPLATE, context_string, user_message),
+            messages,
+            append=False,
+        )
+
+
 def process_tool_result(
     request,
     tool_function_name,
@@ -1567,6 +1665,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
                     "__event_emitter__": event_emitter,
                 },
                 features,
+                model,
             )
             for name, tool_dict in builtin_tools.items():
                 if name not in tools_dict:
@@ -1599,58 +1698,10 @@ async def process_chat_payload(request, form_data, user, metadata, model):
         log.exception(e)
 
     # If context is not empty, insert it into the messages
-    if len(sources) > 0:
-        context_string = ""
-        citation_idx_map = {}
-
-        for source in sources:
-            if "document" in source:
-                for document_text, document_metadata in zip(
-                    source["document"], source["metadata"]
-                ):
-                    source_name = source.get("source", {}).get("name", None)
-                    source_id = (
-                        document_metadata.get("source", None)
-                        or source.get("source", {}).get("id", None)
-                        or "N/A"
-                    )
-
-                    if source_id not in citation_idx_map:
-                        citation_idx_map[source_id] = len(citation_idx_map) + 1
-
-                    context_string += (
-                        f'<source id="{citation_idx_map[source_id]}"'
-                        + (f' name="{source_name}"' if source_name else "")
-                        + f">{document_text}</source>\n"
-                    )
-
-        context_string = context_string.strip()
-        if prompt is None:
-            raise Exception("No user message found")
-
-        if context_string != "":
-            if RAG_SYSTEM_CONTEXT:
-                # Inject into system message for KV prefix caching
-                form_data["messages"] = add_or_update_system_message(
-                    rag_template(
-                        request.app.state.config.RAG_TEMPLATE,
-                        context_string,
-                        prompt,
-                    ),
-                    form_data["messages"],
-                    append=True,
-                )
-            else:
-                # Inject into user message
-                form_data["messages"] = add_or_update_user_message(
-                    rag_template(
-                        request.app.state.config.RAG_TEMPLATE,
-                        context_string,
-                        prompt,
-                    ),
-                    form_data["messages"],
-                    append=False,
-                )
+    if sources and prompt:
+        form_data["messages"] = apply_source_context_to_messages(
+            request, form_data["messages"], sources, prompt
+        )
 
     # If there are citations, add them to the data_items
     sources = [
@@ -2977,6 +3028,7 @@ async def process_chat_response(
                     await stream_body_handler(response, form_data)
 
                     tool_call_retries = 0
+                    tool_call_sources = []  # Track citation sources from tool results
 
                     while (
                         len(tool_calls) > 0
@@ -3111,6 +3163,19 @@ async def process_chat_response(
                                 )
                             )
 
+                            # Extract citation sources from web_search results
+                            if tool_function_name == "web_search" and tool_result:
+                                try:
+                                    citation_source = get_citation_source_from_tool_result(
+                                        tool_name=tool_function_name,
+                                        tool_params=tool_function_params,
+                                        tool_result=tool_result,
+                                        tool_id=tool.get("tool_id", "") if tool else ""
+                                    )
+                                    tool_call_sources.append(citation_source)
+                                except Exception as e:
+                                    log.exception(f"Error extracting citation source: {e}")
+
                             results.append(
                                 {
                                     "tool_call_id": tool_call_id,
@@ -3136,6 +3201,19 @@ async def process_chat_response(
                             }
                         )
 
+                        # Emit citation sources for UI display
+                        for source in tool_call_sources:
+                            await event_emitter({"type": "source", "data": source})
+
+                        # Apply source context to messages for model
+                        if tool_call_sources:
+                            user_msg = get_last_user_message(form_data["messages"])
+                            if user_msg:
+                                form_data["messages"] = apply_source_context_to_messages(
+                                    request, form_data["messages"], tool_call_sources, user_msg
+                                )
+                            tool_call_sources.clear()
+
                         await event_emitter(
                             {
                                 "type": "chat:completion",
diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py
index 4027a0a831..8a970e1c56 100644
--- a/backend/open_webui/utils/tools.py
+++ b/backend/open_webui/utils/tools.py
@@ -342,15 +342,23 @@ async def get_tools(
 
 
 def get_builtin_tools(
-    request: Request, extra_params: dict, features: dict = None
+    request: Request, extra_params: dict, features: dict = None, model: dict = None
 ) -> dict[str, dict]:
     """
     Get built-in tools for native function calling.
-    Only returns tools when BOTH the global config is enabled AND the feature is enabled for this chat.
+    Only returns tools when BOTH the global config is enabled AND the model capability allows it.
     """
     tools_dict = {}
     builtin_functions = []
     features = features or {}
+    model = model or {}
+
+    # Helper to get model capabilities (defaults to True if not specified)
+    def get_model_capability(name: str, default: bool = True) -> bool:
+        # "info", "meta" and "capabilities" may each be present but None, so
+        # fall back with `or {}` instead of relying on .get() defaults.
+        meta = (model.get("info") or {}).get("meta") or {}
+        return (meta.get("capabilities") or {}).get(name, default)
 
     # Time utilities - always available for date calculations
     builtin_functions.extend([get_current_timestamp, calculate_timestamp])
@@ -362,18 +370,18 @@ def get_builtin_tools(
     if features.get("memory"):
         builtin_functions.extend([search_memories, add_memory, replace_memory_content])
 
-    # Add web search tools if enabled globally AND for this chat
-    if getattr(request.app.state.config, "ENABLE_WEB_SEARCH", False) and features.get(
+    # Add web search tools if enabled globally AND model has web_search capability
+    if getattr(request.app.state.config, "ENABLE_WEB_SEARCH", False) and get_model_capability(
         "web_search"
     ):
         builtin_functions.extend([web_search, fetch_url])
 
-    # Add image generation/edit tools if enabled globally AND for this chat
+    # Add image generation/edit tools if enabled globally AND model has image_generation capability
     if getattr(
         request.app.state.config, "ENABLE_IMAGE_GENERATION", False
-    ) and features.get("image_generation"):
+    ) and get_model_capability("image_generation"):
         builtin_functions.append(generate_image)
-    if getattr(request.app.state.config, "ENABLE_IMAGE_EDIT", False) and features.get(
+    if getattr(request.app.state.config, "ENABLE_IMAGE_EDIT", False) and get_model_capability(
         "image_generation"
     ):
         builtin_functions.append(edit_image)