Merge branch 'dev' into k_reranker

This commit is contained in:
Timothy Jaeryang Baek
2025-03-26 20:50:31 -07:00
committed by GitHub
147 changed files with 6065 additions and 1350 deletions

View File

@@ -100,7 +100,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def chat_completion_tools_handler(
request: Request, body: dict, user: UserModel, models, tools
request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
async def get_content_from_response(response) -> Optional[str]:
content = None
@@ -135,6 +135,9 @@ async def chat_completion_tools_handler(
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
}
event_caller = extra_params["__event_call__"]
metadata = extra_params["__metadata__"]
task_model_id = get_task_model_id(
body["model"],
request.app.state.config.TASK_MODEL,
@@ -189,19 +192,33 @@ async def chat_completion_tools_handler(
tool_function_params = tool_call.get("parameters", {})
try:
required_params = (
tools[tool_function_name]
.get("spec", {})
.get("parameters", {})
.get("required", [])
tool = tools[tool_function_name]
spec = tool.get("spec", {})
allowed_params = (
spec.get("parameters", {}).get("properties", {}).keys()
)
tool_function = tools[tool_function_name]["callable"]
tool_function = tool["callable"]
tool_function_params = {
k: v
for k, v in tool_function_params.items()
if k in required_params
if k in allowed_params
}
tool_output = await tool_function(**tool_function_params)
if tool.get("direct", False):
tool_output = await tool_function(**tool_function_params)
else:
tool_output = await event_caller(
{
"type": "execute:tool",
"data": {
"id": str(uuid4()),
"tool": tool,
"params": tool_function_params,
"session_id": metadata.get("session_id", None),
},
}
)
except Exception as e:
tool_output = str(e)
@@ -767,12 +784,18 @@ async def process_chat_payload(request, form_data, user, metadata, model):
}
form_data["metadata"] = metadata
# Server side tools
tool_ids = metadata.get("tool_ids", None)
# Client side tools
tool_specs = form_data.get("tool_specs", None)
log.debug(f"{tool_ids=}")
log.debug(f"{tool_specs=}")
tools_dict = {}
if tool_ids:
# If tool_ids field is present, then get the tools
tools = get_tools(
tools_dict = get_tools(
request,
tool_ids,
user,
@@ -783,20 +806,30 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"__files__": metadata.get("files", []),
},
)
log.info(f"{tools=}")
log.info(f"{tools_dict=}")
if tool_specs:
for tool in tool_specs:
callable = tool.pop("callable", None)
tools_dict[tool["name"]] = {
"direct": True,
"callable": callable,
"spec": tool,
}
if tools_dict:
if metadata.get("function_calling") == "native":
# If the function calling is native, then call the tools function calling handler
metadata["tools"] = tools
metadata["tools"] = tools_dict
form_data["tools"] = [
{"type": "function", "function": tool.get("spec", {})}
for tool in tools.values()
for tool in tools_dict.values()
]
else:
# If the function calling is not native, then call the tools function calling handler
try:
form_data, flags = await chat_completion_tools_handler(
request, form_data, user, models, tools
request, form_data, extra_params, user, models, tools_dict
)
sources.extend(flags.get("sources", []))
@@ -815,7 +848,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
for source_idx, source in enumerate(sources):
if "document" in source:
for doc_idx, doc_context in enumerate(source["document"]):
context_string += f"<source><source_id>{source_idx}</source_id><source_context>{doc_context}</source_context></source>\n"
context_string += f"<source><source_id>{source_idx + 1}</source_id><source_context>{doc_context}</source_context></source>\n"
context_string = context_string.strip()
prompt = get_last_user_message(form_data["messages"])
@@ -1082,8 +1115,6 @@ async def process_chat_response(
for filter_id in get_sorted_filter_ids(model)
]
print(f"{filter_functions=}")
# Streaming response
if event_emitter and event_caller:
task_id = str(uuid4()) # Create a unique task ID.
@@ -1563,7 +1594,9 @@ async def process_chat_response(
value = delta.get("content")
reasoning_content = delta.get("reasoning_content")
reasoning_content = delta.get(
"reasoning_content"
) or delta.get("reasoning")
if reasoning_content:
if (
not content_blocks
@@ -1766,18 +1799,36 @@ async def process_chat_response(
spec = tool.get("spec", {})
try:
required_params = spec.get("parameters", {}).get(
"required", []
allowed_params = (
spec.get("parameters", {})
.get("properties", {})
.keys()
)
tool_function = tool["callable"]
tool_function_params = {
k: v
for k, v in tool_function_params.items()
if k in required_params
if k in allowed_params
}
tool_result = await tool_function(
**tool_function_params
)
if tool.get("direct", False):
tool_result = await tool_function(
**tool_function_params
)
else:
tool_result = await event_caller(
{
"type": "execute:tool",
"data": {
"id": str(uuid4()),
"tool": tool,
"params": tool_function_params,
"session_id": metadata.get(
"session_id", None
),
},
}
)
except Exception as e:
tool_result = str(e)