This commit is contained in:
Timothy Jaeryang Baek
2026-03-07 20:12:35 -06:00
parent 989938856f
commit 144d8b1bb7

View File

@@ -890,6 +890,27 @@ def handle_responses_streaming_event(
return current_output, None
def get_source_context(sources: list, include_content: bool = True) -> str:
    """
    Build <source> tag context string from citation sources.

    Each source is a dict that may carry a "source" dict (with optional
    "id" and "name"), a "document" list and a "metadata" list. Documents
    are paired positionally with their metadata entries; when the two
    lists differ in length the unpaired tail is ignored.

    Every distinct source id is assigned a 1-based citation index in order
    of first appearance, and that numbering is shared across all sources
    in the call. The id is resolved from the metadata entry's "source"
    value, then from the source's own "id", and finally falls back to
    "N/A".

    Args:
        sources: Citation source dicts. ``None`` or an empty list yields
            an empty string.
        include_content: When False, the tags are emitted with an empty
            body (only the id/name attributes are kept).

    Returns:
        One ``<source ...>...</source>`` line per paired document, each
        terminated by a newline, or "" when there is nothing to emit.
    """
    citation_idx = {}
    tags = []

    for source in sources or []:
        # The source-level info is the same for every document of this
        # source, so resolve it once instead of once per document.
        # `or {}` also covers an explicit `"source": None`.
        source_info = source.get("source") or {}
        fallback_id = source_info.get("id")
        src_name = source_info.get("name")
        # NOTE(review): the name is interpolated as-is; a name containing a
        # double quote would produce a malformed attribute.
        name_attr = f' name="{src_name}"' if src_name else ""

        documents = source.get("document") or []
        metadatas = source.get("metadata") or []
        for doc, meta in zip(documents, metadatas):
            src_id = (meta or {}).get("source") or fallback_id or "N/A"
            if src_id not in citation_idx:
                citation_idx[src_id] = len(citation_idx) + 1

            body = doc if include_content else ""
            tags.append(
                f'<source id="{citation_idx[src_id]}"{name_attr}>{body}</source>\n'
            )

    # Join once at the end rather than growing a string inside the loop.
    return "".join(tags)
def apply_source_context_to_messages(
request: Request,
messages: list,
@@ -908,21 +929,7 @@ def apply_source_context_to_messages(
if not sources or not user_message:
return messages
context_string = ""
citation_idx = {}
for source in sources:
for doc, meta in zip(source.get("document", []), source.get("metadata", [])):
src_id = meta.get("source") or source.get("source", {}).get("id") or "N/A"
if src_id not in citation_idx:
citation_idx[src_id] = len(citation_idx) + 1
src_name = source.get("source", {}).get("name")
body = doc if include_content else ""
context_string += (
f'<source id="{citation_idx[src_id]}"'
+ (f' name="{src_name}"' if src_name else "")
+ f">{body}</source>\n"
)
context_string = get_source_context(sources, include_content)
context_string = context_string.strip()
if not context_string:
@@ -2741,6 +2748,18 @@ async def process_chat_payload(request, form_data, user, metadata, model):
except Exception as e:
log.exception(e)
# Save the pre-RAG message state so the native tool call loop can
# restore to the true original (before file-source injection) rather
# than a snapshot that already has the RAG template baked in.
system_message = get_system_message(form_data["messages"])
metadata["system_prompt"] = (
get_content_from_message(system_message) if system_message else None
)
metadata["user_prompt"] = get_last_user_message(
form_data["messages"]
)
metadata["sources"] = sources[:] if sources else []
# If context is not empty, insert it into the messages
if sources and prompt:
form_data["messages"] = apply_source_context_to_messages(
@@ -4202,16 +4221,21 @@ async def streaming_chat_response_handler(response, ctx):
model.get("info", {}).get("meta", {}).get("capabilities") or {}
).get("citations", True)
# Save original system message so we can restore it before
# re-applying source context (prevents duplication when
# RAG_SYSTEM_CONTEXT is enabled and the template is appended
# to the system message on each iteration).
original_system_message = get_system_message(form_data["messages"])
original_system_content = (
get_content_from_message(original_system_message)
if original_system_message
else None
# Use the pre-RAG system content captured before the
# initial file-source injection in process_chat_payload.
# This ensures restore truly undoes the RAG template.
original_system_content = metadata.get(
"system_prompt"
)
if original_system_content is None:
original_system_message = get_system_message(
form_data["messages"]
)
original_system_content = (
get_content_from_message(original_system_message)
if original_system_message
else None
)
while (
len(tool_calls) > 0
@@ -4459,33 +4483,40 @@ async def streaming_chat_response_handler(response, ctx):
}
)
# Emit citation sources and apply source context to messages
# Emit citation sources to the frontend for display
if citations_enabled:
for source in tool_call_sources:
await event_emitter({"type": "source", "data": source})
# Apply source context to messages for model.
# Use include_content=False to avoid duplicating content
# that is already in the tool result message.
# Apply tool source context to messages for the model.
# Restoring to pre-RAG original prevents duplicating
# the RAG template across file and tool sources.
all_tool_call_sources.extend(tool_call_sources)
if all_tool_call_sources and user_message:
# Restore original messages before re-applying to
# avoid recursive nesting (user message) and
# duplication (system message with RAG_SYSTEM_CONTEXT).
# Restore pre-RAG message state before re-applying
# to prevent RAG template duplication.
original_user_message = metadata.get(
"user_prompt"
) or user_message
set_last_user_message_content(
user_message, form_data["messages"]
)
if original_system_content is not None:
replace_system_message_content(
original_system_content,
form_data["messages"],
)
form_data["messages"] = apply_source_context_to_messages(
request,
original_user_message,
form_data["messages"],
all_tool_call_sources,
user_message,
include_content=False,
)
replace_system_message_content(
original_system_content or "",
form_data["messages"],
)
# Combine file and tool sources into one RAG
# template application.
form_data["messages"] = (
apply_source_context_to_messages(
request,
form_data["messages"],
metadata.get("sources", [])
+ all_tool_call_sources,
user_message,
)
)
tool_call_sources.clear()