mirror of
https://github.com/ollama/ollama.git
synced 2026-03-11 17:34:04 -05:00
anthropic: enable websearch (#14246)
This commit is contained in:
@@ -1,17 +1,25 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
// Error types matching Anthropic API
|
||||
@@ -82,22 +90,25 @@ type MessageParam struct {
|
||||
// Text and Thinking use pointers so they serialize as the field being present (even if empty)
|
||||
// only when set, which is required for SDK streaming accumulation.
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"` // text, image, tool_use, tool_result, thinking
|
||||
Type string `json:"type"` // text, image, tool_use, tool_result, thinking, server_tool_use, web_search_tool_result
|
||||
|
||||
// For text blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||
Text *string `json:"text,omitempty"`
|
||||
|
||||
// For text blocks with citations
|
||||
Citations []Citation `json:"citations,omitempty"`
|
||||
|
||||
// For image blocks
|
||||
Source *ImageSource `json:"source,omitempty"`
|
||||
|
||||
// For tool_use blocks
|
||||
// For tool_use and server_tool_use blocks
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
|
||||
// For tool_result blocks
|
||||
// For tool_result and web_search_tool_result blocks
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content any `json:"content,omitempty"` // string or []ContentBlock
|
||||
Content any `json:"content,omitempty"` // string, []ContentBlock, []WebSearchResult, or WebSearchToolResultError
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
|
||||
// For thinking blocks - pointer so field only appears when set (SDK requires it for accumulation)
|
||||
@@ -105,6 +116,30 @@ type ContentBlock struct {
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
// Citation represents a citation in a text block
|
||||
type Citation struct {
|
||||
Type string `json:"type"` // "web_search_result_location"
|
||||
URL string `json:"url"`
|
||||
Title string `json:"title"`
|
||||
EncryptedIndex string `json:"encrypted_index,omitempty"`
|
||||
CitedText string `json:"cited_text,omitempty"`
|
||||
}
|
||||
|
||||
// WebSearchResult represents a single web search result
|
||||
type WebSearchResult struct {
|
||||
Type string `json:"type"` // "web_search_result"
|
||||
URL string `json:"url"`
|
||||
Title string `json:"title"`
|
||||
EncryptedContent string `json:"encrypted_content,omitempty"`
|
||||
PageAge string `json:"page_age,omitempty"`
|
||||
}
|
||||
|
||||
// WebSearchToolResultError represents an error from web search
|
||||
type WebSearchToolResultError struct {
|
||||
Type string `json:"type"` // "web_search_tool_result_error"
|
||||
ErrorCode string `json:"error_code"`
|
||||
}
|
||||
|
||||
// ImageSource represents the source of an image
|
||||
type ImageSource struct {
|
||||
Type string `json:"type"` // "base64" or "url"
|
||||
@@ -115,10 +150,13 @@ type ImageSource struct {
|
||||
|
||||
// Tool represents a tool definition
|
||||
type Tool struct {
|
||||
Type string `json:"type,omitempty"` // "custom" for user-defined tools
|
||||
Type string `json:"type,omitempty"` // "custom" for user-defined tools, or "web_search_20250305" for web search
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema,omitempty"`
|
||||
|
||||
// Web search specific fields
|
||||
MaxUses int `json:"max_uses,omitempty"`
|
||||
}
|
||||
|
||||
// ToolChoice controls how the model uses tools
|
||||
@@ -233,6 +271,8 @@ type StreamErrorEvent struct {
|
||||
|
||||
// FromMessagesRequest converts an Anthropic MessagesRequest to an Ollama api.ChatRequest
|
||||
func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
logutil.Trace("anthropic: converting request", "req", TraceMessagesRequest(r))
|
||||
|
||||
var messages []api.Message
|
||||
|
||||
if r.System != nil {
|
||||
@@ -259,9 +299,10 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
for _, msg := range r.Messages {
|
||||
for i, msg := range r.Messages {
|
||||
converted, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: message conversion failed", "index", i, "role", msg.Role, "err", err)
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, converted...)
|
||||
@@ -288,8 +329,24 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
|
||||
var tools api.Tools
|
||||
hasBuiltinWebSearch := false
|
||||
for _, t := range r.Tools {
|
||||
tool, err := convertTool(t)
|
||||
if strings.HasPrefix(t.Type, "web_search") {
|
||||
hasBuiltinWebSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
for _, t := range r.Tools {
|
||||
// Anthropic built-in web_search maps to Ollama function name "web_search".
|
||||
// If a user-defined tool also uses that name in the same request, drop the
|
||||
// user-defined one to avoid ambiguous tool-call routing.
|
||||
if hasBuiltinWebSearch && !strings.HasPrefix(t.Type, "web_search") && t.Name == "web_search" {
|
||||
logutil.Trace("anthropic: dropping colliding custom web_search tool", "tool", TraceTool(t))
|
||||
continue
|
||||
}
|
||||
|
||||
tool, _, err := convertTool(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -302,15 +359,17 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
|
||||
stream := r.Stream
|
||||
|
||||
return &api.ChatRequest{
|
||||
convertedRequest := &api.ChatRequest{
|
||||
Model: r.Model,
|
||||
Messages: messages,
|
||||
Options: options,
|
||||
Stream: &stream,
|
||||
Tools: tools,
|
||||
Think: think,
|
||||
}, nil
|
||||
}
|
||||
logutil.Trace("anthropic: converted request", "req", TraceChatRequest(convertedRequest))
|
||||
|
||||
return convertedRequest, nil
|
||||
}
|
||||
|
||||
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
|
||||
@@ -328,10 +387,19 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
var toolCalls []api.ToolCall
|
||||
var thinking string
|
||||
var toolResults []api.Message
|
||||
textBlocks := 0
|
||||
imageBlocks := 0
|
||||
toolUseBlocks := 0
|
||||
toolResultBlocks := 0
|
||||
serverToolUseBlocks := 0
|
||||
webSearchToolResultBlocks := 0
|
||||
thinkingBlocks := 0
|
||||
unknownBlocks := 0
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: invalid content block format", "role", role)
|
||||
return nil, errors.New("invalid content block format")
|
||||
}
|
||||
|
||||
@@ -339,13 +407,16 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
|
||||
switch blockType {
|
||||
case "text":
|
||||
textBlocks++
|
||||
if text, ok := blockMap["text"].(string); ok {
|
||||
textContent.WriteString(text)
|
||||
}
|
||||
|
||||
case "image":
|
||||
imageBlocks++
|
||||
source, ok := blockMap["source"].(map[string]any)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: invalid image source", "role", role)
|
||||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
|
||||
@@ -354,21 +425,26 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
data, _ := source["data"].(string)
|
||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
images = append(images, decoded)
|
||||
} else {
|
||||
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType)
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
||||
}
|
||||
// URL images would need to be fetched - skip for now
|
||||
|
||||
case "tool_use":
|
||||
toolUseBlocks++
|
||||
id, ok := blockMap["id"].(string)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: tool_use block missing id", "role", role)
|
||||
return nil, errors.New("tool_use block missing required 'id' field")
|
||||
}
|
||||
name, ok := blockMap["name"].(string)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: tool_use block missing name", "role", role)
|
||||
return nil, errors.New("tool_use block missing required 'name' field")
|
||||
}
|
||||
tc := api.ToolCall{
|
||||
@@ -383,6 +459,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
toolCalls = append(toolCalls, tc)
|
||||
|
||||
case "tool_result":
|
||||
toolResultBlocks++
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
var resultContent string
|
||||
|
||||
@@ -408,9 +485,36 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
})
|
||||
|
||||
case "thinking":
|
||||
thinkingBlocks++
|
||||
if t, ok := blockMap["thinking"].(string); ok {
|
||||
thinking = t
|
||||
}
|
||||
|
||||
case "server_tool_use":
|
||||
serverToolUseBlocks++
|
||||
id, _ := blockMap["id"].(string)
|
||||
name, _ := blockMap["name"].(string)
|
||||
tc := api.ToolCall{
|
||||
ID: id,
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
if input, ok := blockMap["input"].(map[string]any); ok {
|
||||
tc.Function.Arguments = mapToArgs(input)
|
||||
}
|
||||
toolCalls = append(toolCalls, tc)
|
||||
|
||||
case "web_search_tool_result":
|
||||
webSearchToolResultBlocks++
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: formatWebSearchToolResultContent(blockMap["content"]),
|
||||
ToolCallID: toolUseID,
|
||||
})
|
||||
default:
|
||||
unknownBlocks++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -427,6 +531,19 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
|
||||
// Add tool results as separate messages
|
||||
messages = append(messages, toolResults...)
|
||||
logutil.Trace("anthropic: converted block message",
|
||||
"role", role,
|
||||
"blocks", len(content),
|
||||
"text", textBlocks,
|
||||
"image", imageBlocks,
|
||||
"tool_use", toolUseBlocks,
|
||||
"tool_result", toolResultBlocks,
|
||||
"server_tool_use", serverToolUseBlocks,
|
||||
"web_search_result", webSearchToolResultBlocks,
|
||||
"thinking", thinkingBlocks,
|
||||
"unknown", unknownBlocks,
|
||||
"messages", TraceAPIMessages(messages),
|
||||
)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||
@@ -435,12 +552,94 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// convertTool converts an Anthropic Tool to an Ollama api.Tool
|
||||
func convertTool(t Tool) (api.Tool, error) {
|
||||
func formatWebSearchToolResultContent(content any) string {
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
return c
|
||||
case []WebSearchResult:
|
||||
var resultContent strings.Builder
|
||||
for _, item := range c {
|
||||
if item.Type != "web_search_result" {
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(&resultContent, "- %s: %s\n", item.Title, item.URL)
|
||||
}
|
||||
return resultContent.String()
|
||||
case []any:
|
||||
var resultContent strings.Builder
|
||||
for _, item := range c {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch itemMap["type"] {
|
||||
case "web_search_result":
|
||||
title, _ := itemMap["title"].(string)
|
||||
url, _ := itemMap["url"].(string)
|
||||
fmt.Fprintf(&resultContent, "- %s: %s\n", title, url)
|
||||
case "web_search_tool_result_error":
|
||||
errorCode, _ := itemMap["error_code"].(string)
|
||||
if errorCode == "" {
|
||||
return "web_search_tool_result_error"
|
||||
}
|
||||
return "web_search_tool_result_error: " + errorCode
|
||||
}
|
||||
}
|
||||
return resultContent.String()
|
||||
case map[string]any:
|
||||
if c["type"] == "web_search_tool_result_error" {
|
||||
errorCode, _ := c["error_code"].(string)
|
||||
if errorCode == "" {
|
||||
return "web_search_tool_result_error"
|
||||
}
|
||||
return "web_search_tool_result_error: " + errorCode
|
||||
}
|
||||
data, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
case WebSearchToolResultError:
|
||||
if c.ErrorCode == "" {
|
||||
return "web_search_tool_result_error"
|
||||
}
|
||||
return "web_search_tool_result_error: " + c.ErrorCode
|
||||
default:
|
||||
data, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
}
|
||||
|
||||
// convertTool converts an Anthropic Tool to an Ollama api.Tool, returning true if it's a server tool
|
||||
func convertTool(t Tool) (api.Tool, bool, error) {
|
||||
if strings.HasPrefix(t.Type, "web_search") {
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("query", api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The search query to look up on the web",
|
||||
})
|
||||
return api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "web_search",
|
||||
Description: "Search the web for current information. Use this to find up-to-date information about any topic.",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"query"},
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
var params api.ToolFunctionParameters
|
||||
if len(t.InputSchema) > 0 {
|
||||
if err := json.Unmarshal(t.InputSchema, ¶ms); err != nil {
|
||||
return api.Tool{}, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
|
||||
logutil.Trace("anthropic: invalid tool schema", "tool", t.Name, "err", err)
|
||||
return api.Tool{}, false, fmt.Errorf("invalid input_schema for tool %q: %w", t.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -451,7 +650,7 @@ func convertTool(t Tool) (api.Tool, error) {
|
||||
Description: t.Description,
|
||||
Parameters: params,
|
||||
},
|
||||
}, nil
|
||||
}, false, nil
|
||||
}
|
||||
|
||||
// ToMessagesResponse converts an Ollama api.ChatResponse to an Anthropic MessagesResponse
|
||||
@@ -899,3 +1098,113 @@ func countContentBlock(block any) int {
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
// OllamaWebSearchRequest represents a request to the Ollama web search API
|
||||
type OllamaWebSearchRequest struct {
|
||||
Query string `json:"query"`
|
||||
MaxResults int `json:"max_results,omitempty"`
|
||||
}
|
||||
|
||||
// OllamaWebSearchResult represents a single search result from Ollama API
|
||||
type OllamaWebSearchResult struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// OllamaWebSearchResponse represents the response from the Ollama web search API
|
||||
type OllamaWebSearchResponse struct {
|
||||
Results []OllamaWebSearchResult `json:"results"`
|
||||
}
|
||||
|
||||
var WebSearchEndpoint = "https://ollama.com/api/web_search"
|
||||
|
||||
func WebSearch(ctx context.Context, query string, maxResults int) (*OllamaWebSearchResponse, error) {
|
||||
if internalcloud.Disabled() {
|
||||
logutil.TraceContext(ctx, "anthropic: web search blocked", "reason", "cloud_disabled")
|
||||
return nil, errors.New(internalcloud.DisabledError("web search is unavailable"))
|
||||
}
|
||||
|
||||
if maxResults <= 0 {
|
||||
maxResults = 5
|
||||
}
|
||||
if maxResults > 10 {
|
||||
maxResults = 10
|
||||
}
|
||||
|
||||
reqBody := OllamaWebSearchRequest{
|
||||
Query: query,
|
||||
MaxResults: maxResults,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal web search request: %w", err)
|
||||
}
|
||||
|
||||
searchURL, err := url.Parse(WebSearchEndpoint)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse web search URL: %w", err)
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic: web search request",
|
||||
"query", TraceTruncateString(query),
|
||||
"max_results", maxResults,
|
||||
"url", searchURL.String(),
|
||||
)
|
||||
|
||||
q := searchURL.Query()
|
||||
q.Set("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||
searchURL.RawQuery = q.Encode()
|
||||
|
||||
signature := ""
|
||||
if strings.EqualFold(searchURL.Hostname(), "ollama.com") {
|
||||
challenge := fmt.Sprintf("%s,%s", http.MethodPost, searchURL.RequestURI())
|
||||
signature, err = auth.Sign(ctx, []byte(challenge))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sign web search request: %w", err)
|
||||
}
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic: web search auth", "signed", signature != "")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL.String(), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create web search request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if signature != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("web search request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
logutil.TraceContext(ctx, "anthropic: web search response", "status", resp.StatusCode)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("web search returned status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var searchResp OllamaWebSearchResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&searchResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode web search response: %w", err)
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic: web search results", "count", len(searchResp.Results))
|
||||
|
||||
return &searchResp, nil
|
||||
}
|
||||
|
||||
func ConvertOllamaToAnthropicResults(ollamaResults *OllamaWebSearchResponse) []WebSearchResult {
|
||||
var results []WebSearchResult
|
||||
for _, r := range ollamaResults.Results {
|
||||
results = append(results, WebSearchResult{
|
||||
Type: "web_search_result",
|
||||
URL: r.URL,
|
||||
Title: r.Title,
|
||||
})
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package anthropic
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -300,6 +301,78 @@ func TestFromMessagesRequest_WithTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_DropsCustomWebSearchWhenBuiltinPresent(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Type: "web_search_20250305",
|
||||
Name: "web_search",
|
||||
},
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "web_search",
|
||||
Description: "User-defined web search that should be dropped",
|
||||
InputSchema: json.RawMessage(`{"type":"invalid"}`),
|
||||
},
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) != 2 {
|
||||
t.Fatalf("expected 2 tools after dropping custom web_search, got %d", len(result.Tools))
|
||||
}
|
||||
if result.Tools[0].Function.Name != "web_search" {
|
||||
t.Fatalf("expected first tool to be built-in web_search, got %q", result.Tools[0].Function.Name)
|
||||
}
|
||||
if result.Tools[1].Function.Name != "get_weather" {
|
||||
t.Fatalf("expected second tool to be get_weather, got %q", result.Tools[1].Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_KeepsCustomWebSearchWhenBuiltinAbsent(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{{Role: "user", Content: "Hello"}},
|
||||
Tools: []Tool{
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "web_search",
|
||||
Description: "User-defined web search",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) != 1 {
|
||||
t.Fatalf("expected 1 custom tool, got %d", len(result.Tools))
|
||||
}
|
||||
if result.Tools[0].Function.Name != "web_search" {
|
||||
t.Fatalf("expected custom tool name web_search, got %q", result.Tools[0].Function.Name)
|
||||
}
|
||||
if result.Tools[0].Function.Description != "User-defined web search" {
|
||||
t.Fatalf("expected custom description preserved, got %q", result.Tools[0].Function.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithThinking(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
@@ -1063,3 +1136,320 @@ func TestEstimateTokens_EmptyContent(t *testing.T) {
|
||||
t.Errorf("expected 0 tokens for empty content, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
// Web Search Tests
|
||||
|
||||
func TestConvertTool_WebSearch(t *testing.T) {
|
||||
tool := Tool{
|
||||
Type: "web_search_20250305",
|
||||
Name: "web_search",
|
||||
MaxUses: 5,
|
||||
}
|
||||
|
||||
result, isServerTool, err := convertTool(tool)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !isServerTool {
|
||||
t.Error("expected isServerTool to be true for web_search tool")
|
||||
}
|
||||
|
||||
if result.Type != "function" {
|
||||
t.Errorf("expected type 'function', got %q", result.Type)
|
||||
}
|
||||
|
||||
if result.Function.Name != "web_search" {
|
||||
t.Errorf("expected name 'web_search', got %q", result.Function.Name)
|
||||
}
|
||||
|
||||
if result.Function.Description == "" {
|
||||
t.Error("expected non-empty description for web_search tool")
|
||||
}
|
||||
|
||||
// Check that query parameter is defined
|
||||
if result.Function.Parameters.Properties == nil {
|
||||
t.Fatal("expected properties to be defined")
|
||||
}
|
||||
|
||||
queryProp, ok := result.Function.Parameters.Properties.Get("query")
|
||||
if !ok {
|
||||
t.Error("expected 'query' property to be defined")
|
||||
}
|
||||
|
||||
if len(queryProp.Type) == 0 || queryProp.Type[0] != "string" {
|
||||
t.Errorf("expected query type to be 'string', got %v", queryProp.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertTool_RegularTool(t *testing.T) {
|
||||
tool := Tool{
|
||||
Type: "custom",
|
||||
Name: "get_weather",
|
||||
Description: "Get the weather",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}}}`),
|
||||
}
|
||||
|
||||
result, isServerTool, err := convertTool(tool)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if isServerTool {
|
||||
t.Error("expected isServerTool to be false for regular tool")
|
||||
}
|
||||
|
||||
if result.Function.Name != "get_weather" {
|
||||
t.Errorf("expected name 'get_weather', got %q", result.Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMessage_ServerToolUse(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "assistant",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "server_tool_use",
|
||||
"id": "srvtoolu_123",
|
||||
"name": "web_search",
|
||||
"input": map[string]any{"query": "test query"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
messages, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||
}
|
||||
|
||||
if len(messages[0].ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(messages[0].ToolCalls))
|
||||
}
|
||||
|
||||
tc := messages[0].ToolCalls[0]
|
||||
if tc.ID != "srvtoolu_123" {
|
||||
t.Errorf("expected tool call ID 'srvtoolu_123', got %q", tc.ID)
|
||||
}
|
||||
|
||||
if tc.Function.Name != "web_search" {
|
||||
t.Errorf("expected tool name 'web_search', got %q", tc.Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMessage_WebSearchToolResult(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srvtoolu_123",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "web_search_result",
|
||||
"title": "Test Result",
|
||||
"url": "https://example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
messages, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should have a tool result message
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||
}
|
||||
|
||||
if messages[0].Role != "tool" {
|
||||
t.Errorf("expected role 'tool', got %q", messages[0].Role)
|
||||
}
|
||||
|
||||
if messages[0].ToolCallID != "srvtoolu_123" {
|
||||
t.Errorf("expected tool_call_id 'srvtoolu_123', got %q", messages[0].ToolCallID)
|
||||
}
|
||||
|
||||
if messages[0].Content == "" {
|
||||
t.Error("expected non-empty content from web search results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srvtoolu_empty",
|
||||
"content": []any{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
messages, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||
}
|
||||
if messages[0].Role != "tool" {
|
||||
t.Fatalf("expected role tool, got %q", messages[0].Role)
|
||||
}
|
||||
if messages[0].ToolCallID != "srvtoolu_empty" {
|
||||
t.Fatalf("expected tool_call_id srvtoolu_empty, got %q", messages[0].ToolCallID)
|
||||
}
|
||||
if messages[0].Content != "" {
|
||||
t.Fatalf("expected empty content for empty web search results, got %q", messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) {
|
||||
msg := MessageParam{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srvtoolu_error",
|
||||
"content": map[string]any{
|
||||
"type": "web_search_tool_result_error",
|
||||
"error_code": "max_uses_exceeded",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
messages, err := convertMessage(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||
}
|
||||
if messages[0].Role != "tool" {
|
||||
t.Fatalf("expected role tool, got %q", messages[0].Role)
|
||||
}
|
||||
if messages[0].ToolCallID != "srvtoolu_error" {
|
||||
t.Fatalf("expected tool_call_id srvtoolu_error, got %q", messages[0].ToolCallID)
|
||||
}
|
||||
if !strings.Contains(messages[0].Content, "max_uses_exceeded") {
|
||||
t.Fatalf("expected error code in converted tool content, got %q", messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOllamaToAnthropicResults(t *testing.T) {
|
||||
ollamaResp := &OllamaWebSearchResponse{
|
||||
Results: []OllamaWebSearchResult{
|
||||
{
|
||||
Title: "Test Title",
|
||||
URL: "https://example.com",
|
||||
Content: "Test content",
|
||||
},
|
||||
{
|
||||
Title: "Another Result",
|
||||
URL: "https://example.org",
|
||||
Content: "More content",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
results := ConvertOllamaToAnthropicResults(ollamaResp)
|
||||
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("expected 2 results, got %d", len(results))
|
||||
}
|
||||
|
||||
if results[0].Type != "web_search_result" {
|
||||
t.Errorf("expected type 'web_search_result', got %q", results[0].Type)
|
||||
}
|
||||
|
||||
if results[0].Title != "Test Title" {
|
||||
t.Errorf("expected title 'Test Title', got %q", results[0].Title)
|
||||
}
|
||||
|
||||
if results[0].URL != "https://example.com" {
|
||||
t.Errorf("expected URL 'https://example.com', got %q", results[0].URL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSearchTypes(t *testing.T) {
|
||||
// Test that WebSearchResult serializes correctly
|
||||
result := WebSearchResult{
|
||||
Type: "web_search_result",
|
||||
URL: "https://example.com",
|
||||
Title: "Test",
|
||||
EncryptedContent: "abc123",
|
||||
PageAge: "2025-01-01",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal WebSearchResult: %v", err)
|
||||
}
|
||||
|
||||
var unmarshaled WebSearchResult
|
||||
if err := json.Unmarshal(data, &unmarshaled); err != nil {
|
||||
t.Fatalf("failed to unmarshal WebSearchResult: %v", err)
|
||||
}
|
||||
|
||||
if unmarshaled.Type != result.Type {
|
||||
t.Errorf("type mismatch: expected %q, got %q", result.Type, unmarshaled.Type)
|
||||
}
|
||||
|
||||
// Test WebSearchToolResultError
|
||||
errResult := WebSearchToolResultError{
|
||||
Type: "web_search_tool_result_error",
|
||||
ErrorCode: "max_uses_exceeded",
|
||||
}
|
||||
|
||||
data, err = json.Marshal(errResult)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal WebSearchToolResultError: %v", err)
|
||||
}
|
||||
|
||||
var unmarshaledErr WebSearchToolResultError
|
||||
if err := json.Unmarshal(data, &unmarshaledErr); err != nil {
|
||||
t.Fatalf("failed to unmarshal WebSearchToolResultError: %v", err)
|
||||
}
|
||||
|
||||
if unmarshaledErr.ErrorCode != "max_uses_exceeded" {
|
||||
t.Errorf("error_code mismatch: expected 'max_uses_exceeded', got %q", unmarshaledErr.ErrorCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCitation(t *testing.T) {
|
||||
citation := Citation{
|
||||
Type: "web_search_result_location",
|
||||
URL: "https://example.com",
|
||||
Title: "Example",
|
||||
EncryptedIndex: "enc123",
|
||||
CitedText: "Some cited text...",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(citation)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal Citation: %v", err)
|
||||
}
|
||||
|
||||
var unmarshaled Citation
|
||||
if err := json.Unmarshal(data, &unmarshaled); err != nil {
|
||||
t.Fatalf("failed to unmarshal Citation: %v", err)
|
||||
}
|
||||
|
||||
if unmarshaled.Type != "web_search_result_location" {
|
||||
t.Errorf("type mismatch: expected 'web_search_result_location', got %q", unmarshaled.Type)
|
||||
}
|
||||
|
||||
if unmarshaled.CitedText != "Some cited text..." {
|
||||
t.Errorf("cited_text mismatch: expected 'Some cited text...', got %q", unmarshaled.CitedText)
|
||||
}
|
||||
}
|
||||
|
||||
352
anthropic/trace.go
Normal file
352
anthropic/trace.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Trace truncation limits.
|
||||
const (
|
||||
TraceMaxStringRunes = 240
|
||||
TraceMaxSliceItems = 8
|
||||
TraceMaxMapEntries = 16
|
||||
TraceMaxDepth = 4
|
||||
)
|
||||
|
||||
// TraceTruncateString shortens s to TraceMaxStringRunes, appending a count of
|
||||
// omitted characters when truncated.
|
||||
func TraceTruncateString(s string) string {
|
||||
if len(s) == 0 {
|
||||
return s
|
||||
}
|
||||
runes := []rune(s)
|
||||
if len(runes) <= TraceMaxStringRunes {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("%s...(+%d chars)", string(runes[:TraceMaxStringRunes]), len(runes)-TraceMaxStringRunes)
|
||||
}
|
||||
|
||||
// TraceJSON round-trips v through JSON and returns a compacted representation.
|
||||
func TraceJSON(v any) any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return map[string]any{"marshal_error": err.Error(), "type": fmt.Sprintf("%T", v)}
|
||||
}
|
||||
var out any
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
return TraceTruncateString(string(data))
|
||||
}
|
||||
return TraceCompactValue(out, 0)
|
||||
}
|
||||
|
||||
// TraceCompactValue recursively truncates strings, slices, and maps for trace
|
||||
// output. depth tracks recursion to enforce TraceMaxDepth.
|
||||
func TraceCompactValue(v any, depth int) any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
if depth >= TraceMaxDepth {
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
return TraceTruncateString(t)
|
||||
case []any:
|
||||
return fmt.Sprintf("<array len=%d>", len(t))
|
||||
case map[string]any:
|
||||
return fmt.Sprintf("<object keys=%d>", len(t))
|
||||
default:
|
||||
return fmt.Sprintf("<%T>", v)
|
||||
}
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
return TraceTruncateString(t)
|
||||
case []any:
|
||||
limit := min(len(t), TraceMaxSliceItems)
|
||||
out := make([]any, 0, limit+1)
|
||||
for i := range limit {
|
||||
out = append(out, TraceCompactValue(t[i], depth+1))
|
||||
}
|
||||
if len(t) > limit {
|
||||
out = append(out, fmt.Sprintf("... +%d more items", len(t)-limit))
|
||||
}
|
||||
return out
|
||||
case map[string]any:
|
||||
keys := make([]string, 0, len(t))
|
||||
for k := range t {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
limit := min(len(keys), TraceMaxMapEntries)
|
||||
out := make(map[string]any, limit+1)
|
||||
for i := range limit {
|
||||
out[keys[i]] = TraceCompactValue(t[keys[i]], depth+1)
|
||||
}
|
||||
if len(keys) > limit {
|
||||
out["__truncated_keys"] = len(keys) - limit
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return t
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Anthropic request/response tracing
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TraceMessagesRequest returns a compact trace representation of a MessagesRequest.
|
||||
func TraceMessagesRequest(r MessagesRequest) map[string]any {
|
||||
return map[string]any{
|
||||
"model": r.Model,
|
||||
"max_tokens": r.MaxTokens,
|
||||
"messages": traceMessageParams(r.Messages),
|
||||
"system": traceAnthropicContent(r.System),
|
||||
"stream": r.Stream,
|
||||
"tools": traceTools(r.Tools),
|
||||
"tool_choice": TraceJSON(r.ToolChoice),
|
||||
"thinking": TraceJSON(r.Thinking),
|
||||
"stop_sequences": r.StopSequences,
|
||||
"temperature": ptrVal(r.Temperature),
|
||||
"top_p": ptrVal(r.TopP),
|
||||
"top_k": ptrVal(r.TopK),
|
||||
}
|
||||
}
|
||||
|
||||
// TraceMessagesResponse returns a compact trace representation of a MessagesResponse.
|
||||
func TraceMessagesResponse(r MessagesResponse) map[string]any {
|
||||
return map[string]any{
|
||||
"id": r.ID,
|
||||
"model": r.Model,
|
||||
"content": TraceJSON(r.Content),
|
||||
"stop_reason": r.StopReason,
|
||||
"usage": r.Usage,
|
||||
}
|
||||
}
|
||||
|
||||
func traceMessageParams(msgs []MessageParam) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
out = append(out, map[string]any{
|
||||
"role": m.Role,
|
||||
"content": traceAnthropicContent(m.Content),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func traceAnthropicContent(content any) any {
|
||||
switch c := content.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case string:
|
||||
return TraceTruncateString(c)
|
||||
case []any:
|
||||
blocks := make([]any, 0, len(c))
|
||||
for _, block := range c {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
blocks = append(blocks, TraceCompactValue(block, 0))
|
||||
continue
|
||||
}
|
||||
blocks = append(blocks, traceAnthropicBlock(blockMap))
|
||||
}
|
||||
return blocks
|
||||
default:
|
||||
return TraceJSON(c)
|
||||
}
|
||||
}
|
||||
|
||||
func traceAnthropicBlock(block map[string]any) map[string]any {
|
||||
blockType, _ := block["type"].(string)
|
||||
out := map[string]any{"type": blockType}
|
||||
switch blockType {
|
||||
case "text":
|
||||
if text, ok := block["text"].(string); ok {
|
||||
out["text"] = TraceTruncateString(text)
|
||||
} else {
|
||||
out["text"] = TraceCompactValue(block["text"], 0)
|
||||
}
|
||||
case "thinking":
|
||||
if thinking, ok := block["thinking"].(string); ok {
|
||||
out["thinking"] = TraceTruncateString(thinking)
|
||||
} else {
|
||||
out["thinking"] = TraceCompactValue(block["thinking"], 0)
|
||||
}
|
||||
case "tool_use", "server_tool_use":
|
||||
out["id"] = block["id"]
|
||||
out["name"] = block["name"]
|
||||
out["input"] = TraceCompactValue(block["input"], 0)
|
||||
case "tool_result", "web_search_tool_result":
|
||||
out["tool_use_id"] = block["tool_use_id"]
|
||||
out["content"] = TraceCompactValue(block["content"], 0)
|
||||
case "image":
|
||||
if source, ok := block["source"].(map[string]any); ok {
|
||||
out["source"] = map[string]any{
|
||||
"type": source["type"],
|
||||
"media_type": source["media_type"],
|
||||
"url": source["url"],
|
||||
"data_len": len(fmt.Sprint(source["data"])),
|
||||
}
|
||||
}
|
||||
default:
|
||||
out["block"] = TraceCompactValue(block, 0)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func traceTools(tools []Tool) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
out = append(out, TraceTool(t))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// TraceTool returns a compact trace representation of an Anthropic Tool.
|
||||
func TraceTool(t Tool) map[string]any {
|
||||
return map[string]any{
|
||||
"type": t.Type,
|
||||
"name": t.Name,
|
||||
"description": TraceTruncateString(t.Description),
|
||||
"input_schema": TraceJSON(t.InputSchema),
|
||||
"max_uses": t.MaxUses,
|
||||
}
|
||||
}
|
||||
|
||||
// ContentBlockTypes returns the type strings from content (when it's []any blocks).
|
||||
func ContentBlockTypes(content any) []string {
|
||||
blocks, ok := content.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
types := make([]string, 0, len(blocks))
|
||||
for _, block := range blocks {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
types = append(types, fmt.Sprintf("%T", block))
|
||||
continue
|
||||
}
|
||||
t, _ := blockMap["type"].(string)
|
||||
types = append(types, t)
|
||||
}
|
||||
return types
|
||||
}
|
||||
|
||||
func ptrVal[T any](v *T) any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Ollama api.* tracing (shared between anthropic and middleware packages)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TraceChatRequest returns a compact trace representation of an Ollama ChatRequest.
|
||||
func TraceChatRequest(req *api.ChatRequest) map[string]any {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
stream := false
|
||||
if req.Stream != nil {
|
||||
stream = *req.Stream
|
||||
}
|
||||
return map[string]any{
|
||||
"model": req.Model,
|
||||
"messages": TraceAPIMessages(req.Messages),
|
||||
"tools": TraceAPITools(req.Tools),
|
||||
"stream": stream,
|
||||
"options": req.Options,
|
||||
"think": TraceJSON(req.Think),
|
||||
}
|
||||
}
|
||||
|
||||
// TraceChatResponse returns a compact trace representation of an Ollama ChatResponse.
|
||||
func TraceChatResponse(resp api.ChatResponse) map[string]any {
|
||||
return map[string]any{
|
||||
"model": resp.Model,
|
||||
"done": resp.Done,
|
||||
"done_reason": resp.DoneReason,
|
||||
"message": TraceAPIMessage(resp.Message),
|
||||
"metrics": TraceJSON(resp.Metrics),
|
||||
}
|
||||
}
|
||||
|
||||
// TraceAPIMessages returns compact trace representations for a slice of api.Message.
|
||||
func TraceAPIMessages(msgs []api.Message) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
out = append(out, TraceAPIMessage(m))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// TraceAPIMessage returns a compact trace representation of a single api.Message.
|
||||
func TraceAPIMessage(m api.Message) map[string]any {
|
||||
return map[string]any{
|
||||
"role": m.Role,
|
||||
"content": TraceTruncateString(m.Content),
|
||||
"thinking": TraceTruncateString(m.Thinking),
|
||||
"images": traceImageSizes(m.Images),
|
||||
"tool_calls": traceToolCalls(m.ToolCalls),
|
||||
"tool_name": m.ToolName,
|
||||
"tool_call_id": m.ToolCallID,
|
||||
}
|
||||
}
|
||||
|
||||
func traceImageSizes(images []api.ImageData) []int {
|
||||
if len(images) == 0 {
|
||||
return nil
|
||||
}
|
||||
sizes := make([]int, 0, len(images))
|
||||
for _, img := range images {
|
||||
sizes = append(sizes, len(img))
|
||||
}
|
||||
return sizes
|
||||
}
|
||||
|
||||
// TraceAPITools returns compact trace representations for a slice of api.Tool.
|
||||
func TraceAPITools(tools api.Tools) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
out = append(out, TraceAPITool(t))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// TraceAPITool returns a compact trace representation of a single api.Tool.
|
||||
func TraceAPITool(t api.Tool) map[string]any {
|
||||
return map[string]any{
|
||||
"type": t.Type,
|
||||
"name": t.Function.Name,
|
||||
"description": TraceTruncateString(t.Function.Description),
|
||||
"parameters": TraceJSON(t.Function.Parameters),
|
||||
}
|
||||
}
|
||||
|
||||
// TraceToolCall returns a compact trace representation of an api.ToolCall.
|
||||
func TraceToolCall(tc api.ToolCall) map[string]any {
|
||||
return map[string]any{
|
||||
"id": tc.ID,
|
||||
"name": tc.Function.Name,
|
||||
"args": TraceJSON(tc.Function.Arguments),
|
||||
}
|
||||
}
|
||||
|
||||
func traceToolCalls(tcs []api.ToolCall) []map[string]any {
|
||||
if len(tcs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]map[string]any, 0, len(tcs))
|
||||
for _, tc := range tcs {
|
||||
out = append(out, TraceToolCall(tc))
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -2,15 +2,22 @@ package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/anthropic"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
|
||||
@@ -18,7 +25,6 @@ type AnthropicWriter struct {
|
||||
BaseWriter
|
||||
stream bool
|
||||
id string
|
||||
model string
|
||||
converter *anthropic.StreamConverter
|
||||
}
|
||||
|
||||
@@ -31,7 +37,7 @@ func (w *AnthropicWriter) writeError(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
|
||||
err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.Status(), errData.Error))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -40,18 +46,7 @@ func (w *AnthropicWriter) writeError(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
|
||||
d, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return nil
|
||||
return writeSSE(w.ResponseWriter, eventType, data)
|
||||
}
|
||||
|
||||
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
|
||||
@@ -65,6 +60,7 @@ func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
events := w.converter.Process(chatResponse)
|
||||
logutil.Trace("anthropic middleware: stream chunk", "resp", anthropic.TraceChatResponse(chatResponse), "events", len(events))
|
||||
for _, event := range events {
|
||||
if err := w.writeEvent(event.Event, event.Data); err != nil {
|
||||
return 0, err
|
||||
@@ -75,6 +71,7 @@ func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
response := anthropic.ToMessagesResponse(w.id, chatResponse)
|
||||
logutil.Trace("anthropic middleware: converted response", "resp", anthropic.TraceMessagesResponse(response))
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
|
||||
}
|
||||
|
||||
@@ -87,9 +84,743 @@ func (w *AnthropicWriter) Write(data []byte) (int, error) {
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
// WebSearchAnthropicWriter intercepts responses containing web_search tool calls,
|
||||
// executes the search, re-invokes the model with results, and assembles the
|
||||
// Anthropic-format response (server_tool_use + web_search_tool_result + text).
|
||||
type WebSearchAnthropicWriter struct {
|
||||
BaseWriter
|
||||
newLoopContext func() (context.Context, context.CancelFunc)
|
||||
inner *AnthropicWriter
|
||||
req anthropic.MessagesRequest // original Anthropic request
|
||||
chatReq *api.ChatRequest // converted Ollama request (for followup calls)
|
||||
stream bool
|
||||
|
||||
estimatedInputTokens int
|
||||
|
||||
terminalSent bool
|
||||
|
||||
observedPromptEvalCount int
|
||||
observedEvalCount int
|
||||
|
||||
loopInFlight bool
|
||||
loopBaseInputTok int
|
||||
loopBaseOutputTok int
|
||||
loopResultCh chan webSearchLoopResult
|
||||
|
||||
streamMessageStarted bool
|
||||
streamHasOpenBlock bool
|
||||
streamOpenBlockIndex int
|
||||
streamNextIndex int
|
||||
}
|
||||
|
||||
const maxWebSearchLoops = 3
|
||||
|
||||
type webSearchLoopResult struct {
|
||||
response anthropic.MessagesResponse
|
||||
loopErr *webSearchLoopError
|
||||
}
|
||||
|
||||
type webSearchLoopError struct {
|
||||
code string
|
||||
query string
|
||||
usage anthropic.Usage
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *webSearchLoopError) Error() string {
|
||||
if e.err == nil {
|
||||
return e.code
|
||||
}
|
||||
return fmt.Sprintf("%s: %v", e.code, e.err)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) Write(data []byte) (int, error) {
|
||||
if w.terminalSent {
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
code := w.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.inner.writeError(data)
|
||||
}
|
||||
|
||||
var chatResponse api.ChatResponse
|
||||
if err := json.Unmarshal(data, &chatResponse); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.recordObservedUsage(chatResponse.Metrics)
|
||||
|
||||
if w.stream && w.loopInFlight {
|
||||
if !chatResponse.Done {
|
||||
return len(data), nil
|
||||
}
|
||||
if err := w.writeLoopResult(); err != nil {
|
||||
return len(data), err
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
webSearchCall, hasWebSearch, hasOtherTools := findWebSearchToolCall(chatResponse.Message.ToolCalls)
|
||||
logutil.Trace("anthropic middleware: upstream chunk",
|
||||
"resp", anthropic.TraceChatResponse(chatResponse),
|
||||
"web_search", hasWebSearch,
|
||||
"other_tools", hasOtherTools,
|
||||
)
|
||||
if hasWebSearch && hasOtherTools {
|
||||
// Prefer web_search if both server and client tools are present in one chunk.
|
||||
slog.Debug("preferring web_search tool call over client tool calls in mixed tool response")
|
||||
}
|
||||
|
||||
if !hasWebSearch {
|
||||
if w.stream {
|
||||
if err := w.writePassthroughStreamChunk(chatResponse); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
return w.inner.writeResponse(data)
|
||||
}
|
||||
|
||||
if w.stream {
|
||||
// Let the original generation continue to completion while web search runs in parallel.
|
||||
logutil.Trace("anthropic middleware: starting async web_search loop",
|
||||
"tool_call", anthropic.TraceToolCall(webSearchCall),
|
||||
"resp", anthropic.TraceChatResponse(chatResponse),
|
||||
)
|
||||
w.startLoopWorker(chatResponse, webSearchCall)
|
||||
if chatResponse.Done {
|
||||
if err := w.writeLoopResult(); err != nil {
|
||||
return len(data), err
|
||||
}
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
loopCtx, cancel := w.startLoopContext()
|
||||
defer cancel()
|
||||
|
||||
initialUsage := anthropic.Usage{
|
||||
InputTokens: max(w.observedPromptEvalCount, chatResponse.Metrics.PromptEvalCount),
|
||||
OutputTokens: max(w.observedEvalCount, chatResponse.Metrics.EvalCount),
|
||||
}
|
||||
logutil.Trace("anthropic middleware: starting sync web_search loop",
|
||||
"tool_call", anthropic.TraceToolCall(webSearchCall),
|
||||
"resp", anthropic.TraceChatResponse(chatResponse),
|
||||
"usage", initialUsage,
|
||||
)
|
||||
response, loopErr := w.runWebSearchLoop(loopCtx, chatResponse, webSearchCall, initialUsage)
|
||||
if loopErr != nil {
|
||||
return len(data), w.sendError(loopErr.code, loopErr.query, loopErr.usage)
|
||||
}
|
||||
|
||||
if err := w.writeTerminalResponse(response); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) runWebSearchLoop(ctx context.Context, initialResponse api.ChatResponse, initialToolCall api.ToolCall, initialUsage anthropic.Usage) (anthropic.MessagesResponse, *webSearchLoopError) {
|
||||
followUpMessages := make([]api.Message, 0, len(w.chatReq.Messages)+maxWebSearchLoops*2)
|
||||
followUpMessages = append(followUpMessages, w.chatReq.Messages...)
|
||||
|
||||
followUpTools := append(api.Tools(nil), w.chatReq.Tools...)
|
||||
usage := initialUsage
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search loop init",
|
||||
"model", w.req.Model,
|
||||
"tool_call", anthropic.TraceToolCall(initialToolCall),
|
||||
"messages", len(followUpMessages),
|
||||
"tools", len(followUpTools),
|
||||
"max_loops", maxWebSearchLoops,
|
||||
)
|
||||
|
||||
currentResponse := initialResponse
|
||||
currentToolCall := initialToolCall
|
||||
|
||||
var serverContent []anthropic.ContentBlock
|
||||
|
||||
if !isCloudModelName(w.req.Model) {
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search execution blocked", "reason", "non_cloud_model")
|
||||
return anthropic.MessagesResponse{}, &webSearchLoopError{
|
||||
code: "web_search_not_supported_for_local_models",
|
||||
query: extractQueryFromToolCall(&initialToolCall),
|
||||
usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
for loop := 1; loop <= maxWebSearchLoops; loop++ {
|
||||
query := extractQueryFromToolCall(¤tToolCall)
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search loop iteration",
|
||||
"loop", loop,
|
||||
"query", anthropic.TraceTruncateString(query),
|
||||
"messages", len(followUpMessages),
|
||||
)
|
||||
if query == "" {
|
||||
return anthropic.MessagesResponse{}, &webSearchLoopError{
|
||||
code: "invalid_request",
|
||||
query: "",
|
||||
usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
const defaultMaxResults = 5
|
||||
searchResp, err := anthropic.WebSearch(ctx, query, defaultMaxResults)
|
||||
if err != nil {
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search request failed",
|
||||
"loop", loop,
|
||||
"query", query,
|
||||
"error", err,
|
||||
)
|
||||
return anthropic.MessagesResponse{}, &webSearchLoopError{
|
||||
code: "unavailable",
|
||||
query: query,
|
||||
usage: usage,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search results",
|
||||
"loop", loop,
|
||||
"results", len(searchResp.Results),
|
||||
)
|
||||
|
||||
toolUseID := loopServerToolUseID(w.inner.id, loop)
|
||||
searchResults := anthropic.ConvertOllamaToAnthropicResults(searchResp)
|
||||
serverContent = append(serverContent,
|
||||
anthropic.ContentBlock{
|
||||
Type: "server_tool_use",
|
||||
ID: toolUseID,
|
||||
Name: "web_search",
|
||||
Input: map[string]any{"query": query},
|
||||
},
|
||||
anthropic.ContentBlock{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: searchResults,
|
||||
},
|
||||
)
|
||||
|
||||
assistantMsg := buildWebSearchAssistantMessage(currentResponse, currentToolCall)
|
||||
toolResultMsg := api.Message{
|
||||
Role: "tool",
|
||||
Content: formatWebSearchResultsForToolMessage(searchResp.Results),
|
||||
ToolCallID: currentToolCall.ID,
|
||||
}
|
||||
followUpMessages = append(followUpMessages, assistantMsg, toolResultMsg)
|
||||
|
||||
followUpResponse, err := w.callFollowUpChat(ctx, followUpMessages, followUpTools)
|
||||
if err != nil {
|
||||
logutil.TraceContext(ctx, "anthropic middleware: followup /api/chat failed",
|
||||
"loop", loop,
|
||||
"query", query,
|
||||
"error", err,
|
||||
)
|
||||
return anthropic.MessagesResponse{}, &webSearchLoopError{
|
||||
code: "api_error",
|
||||
query: query,
|
||||
usage: usage,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic middleware: followup response",
|
||||
"loop", loop,
|
||||
"resp", anthropic.TraceChatResponse(followUpResponse),
|
||||
)
|
||||
|
||||
usage.InputTokens += followUpResponse.Metrics.PromptEvalCount
|
||||
usage.OutputTokens += followUpResponse.Metrics.EvalCount
|
||||
|
||||
nextToolCall, hasWebSearch, hasOtherTools := findWebSearchToolCall(followUpResponse.Message.ToolCalls)
|
||||
if hasWebSearch && hasOtherTools {
|
||||
// Prefer web_search if both server and client tools are present in one chunk.
|
||||
slog.Debug("preferring web_search tool call over client tool calls in mixed followup response")
|
||||
}
|
||||
|
||||
if !hasWebSearch {
|
||||
finalResponse := w.combineServerAndFinalContent(serverContent, followUpResponse, usage)
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search loop complete",
|
||||
"loop", loop,
|
||||
"resp", anthropic.TraceMessagesResponse(finalResponse),
|
||||
)
|
||||
return finalResponse, nil
|
||||
}
|
||||
|
||||
currentResponse = followUpResponse
|
||||
currentToolCall = nextToolCall
|
||||
}
|
||||
|
||||
maxLoopQuery := extractQueryFromToolCall(¤tToolCall)
|
||||
maxLoopToolUseID := loopServerToolUseID(w.inner.id, maxWebSearchLoops+1)
|
||||
serverContent = append(serverContent,
|
||||
anthropic.ContentBlock{
|
||||
Type: "server_tool_use",
|
||||
ID: maxLoopToolUseID,
|
||||
Name: "web_search",
|
||||
Input: map[string]any{"query": maxLoopQuery},
|
||||
},
|
||||
anthropic.ContentBlock{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: maxLoopToolUseID,
|
||||
Content: anthropic.WebSearchToolResultError{
|
||||
Type: "web_search_tool_result_error",
|
||||
ErrorCode: "max_uses_exceeded",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
maxResponse := anthropic.MessagesResponse{
|
||||
ID: w.inner.id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: w.req.Model,
|
||||
Content: serverContent,
|
||||
StopReason: "end_turn",
|
||||
Usage: usage,
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic middleware: web_search loop max reached",
|
||||
"resp", anthropic.TraceMessagesResponse(maxResponse),
|
||||
)
|
||||
return maxResponse, nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) startLoopWorker(initialResponse api.ChatResponse, initialToolCall api.ToolCall) {
|
||||
if w.loopInFlight {
|
||||
return
|
||||
}
|
||||
|
||||
initialUsage := anthropic.Usage{
|
||||
InputTokens: max(w.observedPromptEvalCount, initialResponse.Metrics.PromptEvalCount),
|
||||
OutputTokens: max(w.observedEvalCount, initialResponse.Metrics.EvalCount),
|
||||
}
|
||||
w.loopBaseInputTok = initialUsage.InputTokens
|
||||
w.loopBaseOutputTok = initialUsage.OutputTokens
|
||||
w.loopResultCh = make(chan webSearchLoopResult, 1)
|
||||
w.loopInFlight = true
|
||||
logutil.Trace("anthropic middleware: loop worker started",
|
||||
"usage", initialUsage,
|
||||
"tool_call", anthropic.TraceToolCall(initialToolCall),
|
||||
)
|
||||
|
||||
go func() {
|
||||
ctx, cancel := w.startLoopContext()
|
||||
defer cancel()
|
||||
|
||||
response, loopErr := w.runWebSearchLoop(ctx, initialResponse, initialToolCall, initialUsage)
|
||||
w.loopResultCh <- webSearchLoopResult{
|
||||
response: response,
|
||||
loopErr: loopErr,
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) writeLoopResult() error {
|
||||
if w.loopResultCh == nil {
|
||||
return w.sendError("api_error", "", w.currentObservedUsage())
|
||||
}
|
||||
|
||||
result := <-w.loopResultCh
|
||||
w.loopResultCh = nil
|
||||
w.loopInFlight = false
|
||||
if result.loopErr != nil {
|
||||
logutil.Trace("anthropic middleware: loop worker returned error",
|
||||
"code", result.loopErr.code,
|
||||
"query", result.loopErr.query,
|
||||
"usage", result.loopErr.usage,
|
||||
"error", result.loopErr.err,
|
||||
)
|
||||
usage := result.loopErr.usage
|
||||
w.applyObservedUsageDeltaToUsage(&usage)
|
||||
return w.sendError(result.loopErr.code, result.loopErr.query, usage)
|
||||
}
|
||||
logutil.Trace("anthropic middleware: loop worker done", "resp", anthropic.TraceMessagesResponse(result.response))
|
||||
|
||||
w.applyObservedUsageDelta(&result.response)
|
||||
return w.writeTerminalResponse(result.response)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) applyObservedUsageDelta(response *anthropic.MessagesResponse) {
|
||||
w.applyObservedUsageDeltaToUsage(&response.Usage)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) recordObservedUsage(metrics api.Metrics) {
|
||||
if metrics.PromptEvalCount > w.observedPromptEvalCount {
|
||||
w.observedPromptEvalCount = metrics.PromptEvalCount
|
||||
}
|
||||
if metrics.EvalCount > w.observedEvalCount {
|
||||
w.observedEvalCount = metrics.EvalCount
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) applyObservedUsageDeltaToUsage(usage *anthropic.Usage) {
|
||||
if deltaIn := w.observedPromptEvalCount - w.loopBaseInputTok; deltaIn > 0 {
|
||||
usage.InputTokens += deltaIn
|
||||
}
|
||||
if deltaOut := w.observedEvalCount - w.loopBaseOutputTok; deltaOut > 0 {
|
||||
usage.OutputTokens += deltaOut
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) currentObservedUsage() anthropic.Usage {
|
||||
return anthropic.Usage{
|
||||
InputTokens: w.observedPromptEvalCount,
|
||||
OutputTokens: w.observedEvalCount,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) startLoopContext() (context.Context, context.CancelFunc) {
|
||||
if w.newLoopContext != nil {
|
||||
return w.newLoopContext()
|
||||
}
|
||||
return context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) combineServerAndFinalContent(serverContent []anthropic.ContentBlock, finalResponse api.ChatResponse, usage anthropic.Usage) anthropic.MessagesResponse {
|
||||
converted := anthropic.ToMessagesResponse(w.inner.id, finalResponse)
|
||||
|
||||
content := make([]anthropic.ContentBlock, 0, len(serverContent)+len(converted.Content))
|
||||
content = append(content, serverContent...)
|
||||
content = append(content, converted.Content...)
|
||||
|
||||
return anthropic.MessagesResponse{
|
||||
ID: w.inner.id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: w.req.Model,
|
||||
Content: content,
|
||||
StopReason: converted.StopReason,
|
||||
StopSequence: converted.StopSequence,
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
func buildWebSearchAssistantMessage(response api.ChatResponse, webSearchCall api.ToolCall) api.Message {
|
||||
assistantMsg := api.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{webSearchCall},
|
||||
}
|
||||
if response.Message.Content != "" {
|
||||
assistantMsg.Content = response.Message.Content
|
||||
}
|
||||
if response.Message.Thinking != "" {
|
||||
assistantMsg.Thinking = response.Message.Thinking
|
||||
}
|
||||
return assistantMsg
|
||||
}
|
||||
|
||||
func formatWebSearchResultsForToolMessage(results []anthropic.OllamaWebSearchResult) string {
|
||||
var resultText strings.Builder
|
||||
for _, r := range results {
|
||||
fmt.Fprintf(&resultText, "Title: %s\nURL: %s\n", r.Title, r.URL)
|
||||
if r.Content != "" {
|
||||
fmt.Fprintf(&resultText, "Content: %s\n", r.Content)
|
||||
}
|
||||
resultText.WriteString("\n")
|
||||
}
|
||||
return resultText.String()
|
||||
}
|
||||
|
||||
func findWebSearchToolCall(toolCalls []api.ToolCall) (api.ToolCall, bool, bool) {
|
||||
var webSearchCall api.ToolCall
|
||||
hasWebSearch := false
|
||||
hasOtherTools := false
|
||||
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Function.Name == "web_search" {
|
||||
if !hasWebSearch {
|
||||
webSearchCall = toolCall
|
||||
hasWebSearch = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
hasOtherTools = true
|
||||
}
|
||||
|
||||
return webSearchCall, hasWebSearch, hasOtherTools
|
||||
}
|
||||
|
||||
func loopServerToolUseID(messageID string, loop int) string {
|
||||
base := serverToolUseID(messageID)
|
||||
if loop <= 1 {
|
||||
return base
|
||||
}
|
||||
return fmt.Sprintf("%s_%d", base, loop)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) callFollowUpChat(ctx context.Context, messages []api.Message, tools api.Tools) (api.ChatResponse, error) {
|
||||
streaming := false
|
||||
followUp := api.ChatRequest{
|
||||
Model: w.chatReq.Model,
|
||||
Messages: messages,
|
||||
Stream: &streaming,
|
||||
Tools: tools,
|
||||
Options: w.chatReq.Options,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(followUp)
|
||||
if err != nil {
|
||||
return api.ChatResponse{}, err
|
||||
}
|
||||
|
||||
chatURL := envconfig.Host().String() + "/api/chat"
|
||||
logutil.TraceContext(ctx, "anthropic middleware: followup request",
|
||||
"url", chatURL,
|
||||
"req", anthropic.TraceChatRequest(&followUp),
|
||||
)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", chatURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return api.ChatResponse{}, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return api.ChatResponse{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
logutil.TraceContext(ctx, "anthropic middleware: followup non-200 response",
|
||||
"status", resp.StatusCode,
|
||||
"response", strings.TrimSpace(string(respBody)),
|
||||
)
|
||||
return api.ChatResponse{}, fmt.Errorf("followup /api/chat returned status %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
|
||||
}
|
||||
|
||||
var chatResp api.ChatResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
||||
return api.ChatResponse{}, err
|
||||
}
|
||||
logutil.TraceContext(ctx, "anthropic middleware: followup decoded", "resp", anthropic.TraceChatResponse(chatResp))
|
||||
|
||||
return chatResp, nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) writePassthroughStreamChunk(chatResponse api.ChatResponse) error {
|
||||
events := w.inner.converter.Process(chatResponse)
|
||||
for _, event := range events {
|
||||
switch e := event.Data.(type) {
|
||||
case anthropic.MessageStartEvent:
|
||||
w.streamMessageStarted = true
|
||||
case anthropic.ContentBlockStartEvent:
|
||||
w.streamHasOpenBlock = true
|
||||
w.streamOpenBlockIndex = e.Index
|
||||
if e.Index+1 > w.streamNextIndex {
|
||||
w.streamNextIndex = e.Index + 1
|
||||
}
|
||||
case anthropic.ContentBlockStopEvent:
|
||||
if w.streamHasOpenBlock && w.streamOpenBlockIndex == e.Index {
|
||||
w.streamHasOpenBlock = false
|
||||
}
|
||||
if e.Index+1 > w.streamNextIndex {
|
||||
w.streamNextIndex = e.Index + 1
|
||||
}
|
||||
case anthropic.MessageStopEvent:
|
||||
w.terminalSent = true
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, event.Event, event.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) ensureStreamMessageStart(usage anthropic.Usage) error {
|
||||
if w.streamMessageStarted {
|
||||
return nil
|
||||
}
|
||||
|
||||
inputTokens := usage.InputTokens
|
||||
if inputTokens == 0 {
|
||||
inputTokens = w.estimatedInputTokens
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, "message_start", anthropic.MessageStartEvent{
|
||||
Type: "message_start",
|
||||
Message: anthropic.MessagesResponse{
|
||||
ID: w.inner.id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: w.req.Model,
|
||||
Content: []anthropic.ContentBlock{},
|
||||
Usage: anthropic.Usage{
|
||||
InputTokens: inputTokens,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.streamMessageStarted = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) closeOpenStreamBlock() error {
|
||||
if !w.streamHasOpenBlock {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, "content_block_stop", anthropic.ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: w.streamOpenBlockIndex,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if w.streamOpenBlockIndex+1 > w.streamNextIndex {
|
||||
w.streamNextIndex = w.streamOpenBlockIndex + 1
|
||||
}
|
||||
w.streamHasOpenBlock = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) writeStreamContentBlocks(content []anthropic.ContentBlock) error {
|
||||
for _, block := range content {
|
||||
index := w.streamNextIndex
|
||||
if block.Type == "text" {
|
||||
emptyText := ""
|
||||
if err := writeSSE(w.ResponseWriter, "content_block_start", anthropic.ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: index,
|
||||
ContentBlock: anthropic.ContentBlock{
|
||||
Type: "text",
|
||||
Text: &emptyText,
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
text := ""
|
||||
if block.Text != nil {
|
||||
text = *block.Text
|
||||
}
|
||||
if err := writeSSE(w.ResponseWriter, "content_block_delta", anthropic.ContentBlockDeltaEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: index,
|
||||
Delta: anthropic.Delta{
|
||||
Type: "text_delta",
|
||||
Text: text,
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := writeSSE(w.ResponseWriter, "content_block_start", anthropic.ContentBlockStartEvent{
|
||||
Type: "content_block_start",
|
||||
Index: index,
|
||||
ContentBlock: block,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, "content_block_stop", anthropic.ContentBlockStopEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: index,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.streamNextIndex++
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) writeTerminalResponse(response anthropic.MessagesResponse) error {
|
||||
if w.terminalSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !w.stream {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w.ResponseWriter).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
w.terminalSent = true
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := w.ensureStreamMessageStart(response.Usage); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := w.closeOpenStreamBlock(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := w.writeStreamContentBlocks(response.Content); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, "message_delta", anthropic.MessageDeltaEvent{
|
||||
Type: "message_delta",
|
||||
Delta: anthropic.MessageDelta{
|
||||
StopReason: response.StopReason,
|
||||
},
|
||||
Usage: anthropic.DeltaUsage{
|
||||
InputTokens: response.Usage.InputTokens,
|
||||
OutputTokens: response.Usage.OutputTokens,
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeSSE(w.ResponseWriter, "message_stop", anthropic.MessageStopEvent{
|
||||
Type: "message_stop",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.terminalSent = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// streamResponse emits a complete MessagesResponse as SSE events.
|
||||
func (w *WebSearchAnthropicWriter) streamResponse(response anthropic.MessagesResponse) error {
|
||||
return w.writeTerminalResponse(response)
|
||||
}
|
||||
|
||||
func (w *WebSearchAnthropicWriter) webSearchErrorResponse(errorCode, query string, usage anthropic.Usage) anthropic.MessagesResponse {
|
||||
toolUseID := serverToolUseID(w.inner.id)
|
||||
|
||||
return anthropic.MessagesResponse{
|
||||
ID: w.inner.id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: w.req.Model,
|
||||
Content: []anthropic.ContentBlock{
|
||||
{
|
||||
Type: "server_tool_use",
|
||||
ID: toolUseID,
|
||||
Name: "web_search",
|
||||
Input: map[string]any{"query": query},
|
||||
},
|
||||
{
|
||||
Type: "web_search_tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: anthropic.WebSearchToolResultError{
|
||||
Type: "web_search_tool_result_error",
|
||||
ErrorCode: errorCode,
|
||||
},
|
||||
},
|
||||
},
|
||||
StopReason: "end_turn",
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
// sendError sends a web search error response.
|
||||
func (w *WebSearchAnthropicWriter) sendError(errorCode, query string, usage anthropic.Usage) error {
|
||||
response := w.webSearchErrorResponse(errorCode, query, usage)
|
||||
logutil.Trace("anthropic middleware: web_search error", "code", errorCode, "query", query, "usage", usage)
|
||||
return w.writeTerminalResponse(response)
|
||||
}
|
||||
|
||||
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
|
||||
func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
requestCtx := c.Request.Context()
|
||||
|
||||
var req anthropic.MessagesRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
@@ -134,11 +865,10 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
// Estimate input tokens for streaming (actual count not available until generation completes)
|
||||
estimatedTokens := anthropic.EstimateInputTokens(req)
|
||||
|
||||
w := &AnthropicWriter{
|
||||
innerWriter := &AnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: messageID,
|
||||
model: req.Model,
|
||||
converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
|
||||
}
|
||||
|
||||
@@ -148,8 +878,78 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
if hasWebSearchTool(req.Tools) {
|
||||
// Guard against runtime cloud-disable policy (OLLAMA_NO_CLOUD/server.json)
|
||||
// for cloud models. Local models may still receive web_search tool definitions;
|
||||
// execution is validated when the model actually emits a web_search tool call.
|
||||
if isCloudModelName(req.Model) {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, anthropic.NewError(http.StatusForbidden, internalcloud.DisabledError("web search is unavailable")))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Writer = &WebSearchAnthropicWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
newLoopContext: func() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(requestCtx, 5*time.Minute)
|
||||
},
|
||||
inner: innerWriter,
|
||||
req: req,
|
||||
chatReq: chatReq,
|
||||
stream: req.Stream,
|
||||
estimatedInputTokens: estimatedTokens,
|
||||
}
|
||||
} else {
|
||||
c.Writer = innerWriter
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// hasWebSearchTool checks if the request tools include a web_search tool
|
||||
func hasWebSearchTool(tools []anthropic.Tool) bool {
|
||||
for _, tool := range tools {
|
||||
if strings.HasPrefix(tool.Type, "web_search") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isCloudModelName(name string) bool {
|
||||
return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
|
||||
}
|
||||
|
||||
// extractQueryFromToolCall extracts the search query from a web_search tool call
|
||||
func extractQueryFromToolCall(tc *api.ToolCall) string {
|
||||
q, ok := tc.Function.Arguments.Get("query")
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
if s, ok := q.(string); ok {
|
||||
return s
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// writeSSE writes a Server-Sent Event
|
||||
func writeSSE(w http.ResponseWriter, eventType string, data any) error {
|
||||
d, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, d); err != nil {
|
||||
return err
|
||||
}
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// serverToolUseID derives a server tool use ID from a message ID
|
||||
func serverToolUseID(messageID string) string {
|
||||
return "srvtoolu_" + strings.TrimPrefix(messageID, "msg_")
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
22
middleware/test_home_test.go
Normal file
22
middleware/test_home_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
func setTestHome(t *testing.T, home string) {
|
||||
t.Helper()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", home)
|
||||
envconfig.ReloadServerConfig()
|
||||
}
|
||||
|
||||
// enableCloudForTest sets HOME to a clean temp dir and clears OLLAMA_NO_CLOUD
|
||||
// so that cloud features are enabled for the duration of the test.
|
||||
func enableCloudForTest(t *testing.T) {
|
||||
t.Helper()
|
||||
t.Setenv("OLLAMA_NO_CLOUD", "")
|
||||
setTestHome(t, t.TempDir())
|
||||
}
|
||||
Reference in New Issue
Block a user