diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 389ae9a90b..256b21aa5e 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -890,6 +890,27 @@ def handle_responses_streaming_event( return current_output, None +def get_source_context(sources: list, include_content: bool = True) -> str: + """ + Build <source> tag context string from citation sources. + """ + context_string = "" + citation_idx = {} + for source in sources: + for doc, meta in zip(source.get("document", []), source.get("metadata", [])): + src_id = meta.get("source") or source.get("source", {}).get("id") or "N/A" + if src_id not in citation_idx: + citation_idx[src_id] = len(citation_idx) + 1 + src_name = source.get("source", {}).get("name") + body = doc if include_content else "" + context_string += ( + f'<source id="{citation_idx[src_id]}"' + + (f' name="{src_name}"' if src_name else "") + + f">{body}</source>\n" + ) + return context_string + + def apply_source_context_to_messages( request: Request, messages: list, @@ -908,21 +929,7 @@ def apply_source_context_to_messages( if not sources or not user_message: return messages - context_string = "" - citation_idx = {} - - for source in sources: - for doc, meta in zip(source.get("document", []), source.get("metadata", [])): - src_id = meta.get("source") or source.get("source", {}).get("id") or "N/A" - if src_id not in citation_idx: - citation_idx[src_id] = len(citation_idx) + 1 - src_name = source.get("source", {}).get("name") - body = doc if include_content else "" - context_string += ( - f'<source id="{citation_idx[src_id]}"' - + (f' name="{src_name}"' if src_name else "") - + f">{body}</source>\n" - ) + context_string = get_source_context(sources, include_content) context_string = context_string.strip() if not context_string: @@ -2741,6 +2748,18 @@ async def process_chat_payload(request, form_data, user, metadata, model): except Exception as e: log.exception(e) + # Save the pre-RAG message state so the native tool call loop can + # restore to the true original (before file-source injection) rather + # than a snapshot that already has the RAG template baked in. 
+ system_message = get_system_message(form_data["messages"]) + metadata["system_prompt"] = ( + get_content_from_message(system_message) if system_message else None + ) + metadata["user_prompt"] = get_last_user_message( + form_data["messages"] + ) + metadata["sources"] = sources[:] if sources else [] + # If context is not empty, insert it into the messages if sources and prompt: form_data["messages"] = apply_source_context_to_messages( @@ -4202,16 +4221,21 @@ async def streaming_chat_response_handler(response, ctx): model.get("info", {}).get("meta", {}).get("capabilities") or {} ).get("citations", True) - # Save original system message so we can restore it before - # re-applying source context (prevents duplication when - # RAG_SYSTEM_CONTEXT is enabled and the template is appended - # to the system message on each iteration). - original_system_message = get_system_message(form_data["messages"]) - original_system_content = ( - get_content_from_message(original_system_message) - if original_system_message - else None + # Use the pre-RAG system content captured before the + # initial file-source injection in process_chat_payload. + # This ensures restore truly undoes the RAG template. + original_system_content = metadata.get( + "system_prompt" ) + if original_system_content is None: + original_system_message = get_system_message( + form_data["messages"] + ) + original_system_content = ( + get_content_from_message(original_system_message) + if original_system_message + else None + ) while ( len(tool_calls) > 0 @@ -4459,33 +4483,40 @@ async def streaming_chat_response_handler(response, ctx): } ) - # Emit citation sources and apply source context to messages + # Emit citation sources to the frontend for display if citations_enabled: for source in tool_call_sources: await event_emitter({"type": "source", "data": source}) - # Apply source context to messages for model. 
- # Use include_content=False to avoid duplicating content - # that is already in the tool result message. + # Apply tool source context to messages for the model. + # Restoring to pre-RAG original prevents duplicating + # the RAG template across file and tool sources. all_tool_call_sources.extend(tool_call_sources) if all_tool_call_sources and user_message: - # Restore original messages before re-applying to - # avoid recursive nesting (user message) and - # duplication (system message with RAG_SYSTEM_CONTEXT). + # Restore pre-RAG message state before re-applying + # to prevent RAG template duplication. + original_user_message = metadata.get( + "user_prompt" + ) or user_message set_last_user_message_content( - user_message, form_data["messages"] - ) - if original_system_content is not None: - replace_system_message_content( - original_system_content, - form_data["messages"], - ) - form_data["messages"] = apply_source_context_to_messages( - request, + original_user_message, form_data["messages"], - all_tool_call_sources, - user_message, - include_content=False, + ) + replace_system_message_content( + original_system_content or "", + form_data["messages"], + ) + + # Combine file and tool sources into one RAG + # template application. + form_data["messages"] = ( + apply_source_context_to_messages( + request, + form_data["messages"], + metadata.get("sources", []) + + all_tool_call_sources, + user_message, + ) ) tool_call_sources.clear()