Files
ollama/middleware/anthropic_test.go
2026-02-13 19:20:46 -08:00

2980 lines
89 KiB
Go

package middleware
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
)
// captureAnthropicRequest returns a gin handler that drains the request body,
// decodes it as JSON into capturedRequest (which must be a pointer), and puts
// an equivalent reader back on the request so later handlers can still read
// it. Read and decode errors are ignored; the chain always continues.
func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
	return func(c *gin.Context) {
		raw, _ := io.ReadAll(c.Request.Body)
		_ = json.Unmarshal(raw, capturedRequest)
		// Restore the body: ReadAll consumed the original reader.
		c.Request.Body = io.NopCloser(bytes.NewReader(raw))
		c.Next()
	}
}
// testProps builds an *api.ToolPropertiesMap from a plain Go map so test
// tables can declare tool properties inline.
func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
	result := api.NewToolPropertiesMap()
	for name, prop := range m {
		result.Set(name, prop)
	}
	return result
}
// TestAnthropicMessagesMiddleware posts a table of Anthropic-style request
// bodies to /v1/messages through AnthropicMessagesMiddleware and inspects the
// request that reaches the next handler (captured as an api.ChatRequest).
//
// For cases with an expected error, the test asserts a non-200 status and the
// error type/message in the response body. For successful cases it compares
// the model, messages, stream flag and Think value. The Tools and Options
// fields listed in the table are not asserted by this test.
func TestAnthropicMessagesMiddleware(t *testing.T) {
	type testCase struct {
		name string
		body string
		req  api.ChatRequest
		err  anthropic.ErrorResponse
	}

	var capturedRequest *api.ChatRequest
	stream := true

	testCases := []testCase{
		{
			name: "basic message",
			body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "user", Content: "Hello"},
				},
				Options: map[string]any{"num_predict": 1024},
				Stream:  &False,
			},
		},
		{
			name: "with system prompt",
			body: `{
"model": "test-model",
"max_tokens": 1024,
"system": "You are helpful.",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "system", Content: "You are helpful."},
					{Role: "user", Content: "Hello"},
				},
				Options: map[string]any{"num_predict": 1024},
				Stream:  &False,
			},
		},
		{
			name: "with options",
			body: `{
"model": "test-model",
"max_tokens": 2048,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"stop_sequences": ["\n", "END"],
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "user", Content: "Hello"},
				},
				Options: map[string]any{
					"num_predict": 2048,
					"temperature": 0.7,
					"top_p":       0.9,
					"top_k":       40,
					"stop":        []string{"\n", "END"},
				},
				Stream: &False,
			},
		},
		{
			name: "streaming",
			body: `{
"model": "test-model",
"max_tokens": 1024,
"stream": true,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "user", Content: "Hello"},
				},
				Options: map[string]any{"num_predict": 1024},
				Stream:  &stream,
			},
		},
		{
			name: "with tools",
			body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"}
],
"tools": [{
"name": "get_weather",
"description": "Get current weather",
"input_schema": {
"type": "object",
"properties": {
"location": {"type": "string"}
},
"required": ["location"]
}
}]
}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "user", Content: "What's the weather?"},
				},
				Tools: []api.Tool{
					{
						Type: "function",
						Function: api.ToolFunction{
							Name:        "get_weather",
							Description: "Get current weather",
							Parameters: api.ToolFunctionParameters{
								Type:     "object",
								Required: []string{"location"},
								Properties: testProps(map[string]api.ToolProperty{
									"location": {Type: api.PropertyType{"string"}},
								}),
							},
						},
					},
				},
				Options: map[string]any{"num_predict": 1024},
				Stream:  &False,
			},
		},
		{
			name: "with tool result",
			body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "What's the weather?"},
{"role": "assistant", "content": [
{"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
]},
{"role": "user", "content": [
{"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
]}
]
}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "user", Content: "What's the weather?"},
					{
						Role: "assistant",
						ToolCalls: []api.ToolCall{
							{
								ID: "call_123",
								Function: api.ToolCallFunction{
									Name:      "get_weather",
									Arguments: testArgs(map[string]any{"location": "Paris"}),
								},
							},
						},
					},
					{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
				},
				Options: map[string]any{"num_predict": 1024},
				Stream:  &False,
			},
		},
		{
			name: "with thinking enabled",
			body: `{
"model": "test-model",
"max_tokens": 1024,
"thinking": {"type": "enabled", "budget_tokens": 1000},
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "user", Content: "Hello"},
				},
				Options: map[string]any{"num_predict": 1024},
				Stream:  &False,
				Think:   &api.ThinkValue{Value: true},
			},
		},
		{
			name: "missing model error",
			body: `{
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
			err: anthropic.ErrorResponse{
				Type: "error",
				Error: anthropic.Error{
					Type:    "invalid_request_error",
					Message: "model is required",
				},
			},
		},
		{
			name: "missing max_tokens error",
			body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
			err: anthropic.ErrorResponse{
				Type: "error",
				Error: anthropic.Error{
					Type:    "invalid_request_error",
					Message: "max_tokens is required and must be positive",
				},
			},
		},
		{
			name: "missing messages error",
			body: `{
"model": "test-model",
"max_tokens": 1024
}`,
			err: anthropic.ErrorResponse{
				Type: "error",
				Error: anthropic.Error{
					Type:    "invalid_request_error",
					Message: "messages is required",
				},
			},
		},
		{
			name: "tool_use missing id error",
			body: `{
"model": "test-model",
"max_tokens": 1024,
"messages": [
{"role": "assistant", "content": [
{"type": "tool_use", "name": "test"}
]}
]
}`,
			err: anthropic.ErrorResponse{
				Type: "error",
				Error: anthropic.Error{
					Type:    "invalid_request_error",
					Message: "tool_use block missing required 'id' field",
				},
			},
		},
	}

	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
	router.Handle(http.MethodPost, "/v1/messages", endpoint)

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// capturedRequest is shared by every subtest; clear it afterwards so a
			// stale capture can never satisfy a later case.
			t.Cleanup(func() { capturedRequest = nil })

			req, err := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
			if err != nil {
				t.Fatalf("failed to build request: %v", err)
			}
			req.Header.Set("Content-Type", "application/json")

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			if tc.err.Type != "" {
				// Expect error
				if resp.Code == http.StatusOK {
					t.Fatalf("expected error response, got 200 OK")
				}

				var errResp anthropic.ErrorResponse
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatalf("failed to unmarshal error: %v", err)
				}

				if errResp.Type != tc.err.Type {
					t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
				}
				if errResp.Error.Type != tc.err.Error.Type {
					t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
				}
				if errResp.Error.Message != tc.err.Error.Message {
					t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
				}
				return
			}

			if resp.Code != http.StatusOK {
				t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
			}
			if capturedRequest == nil {
				t.Fatal("request was not captured")
			}

			// Compare relevant fields
			if capturedRequest.Model != tc.req.Model {
				t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
			}

			if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
				cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
				t.Errorf("messages mismatch (-want +got):\n%s", diff)
			}

			// When the table expects a stream value, a missing value on the
			// captured request is a failure rather than a silent pass.
			if tc.req.Stream != nil {
				switch {
				case capturedRequest.Stream == nil:
					t.Errorf("stream mismatch: got <nil>, want %v", *tc.req.Stream)
				case *capturedRequest.Stream != *tc.req.Stream:
					t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
				}
			}

			if tc.req.Think != nil {
				if capturedRequest.Think == nil {
					t.Error("expected Think to be set")
				} else if capturedRequest.Think.Value != tc.req.Think.Value {
					t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
				}
			}
		})
	}
}
// TestAnthropicMessagesMiddleware_Headers checks that, for a streaming
// request, the Content-Type and Cache-Control response headers are already set
// on the writer by the time the downstream handler runs.
func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
	gin.SetMode(gin.TestMode)

	t.Run("streaming sets correct headers", func(t *testing.T) {
		// The header assertions live inside the handler, so track whether it
		// ran at all; otherwise a request rejected by the middleware would
		// make this test pass without checking anything.
		handlerCalled := false

		router := gin.New()
		router.Use(AnthropicMessagesMiddleware())
		router.POST("/v1/messages", func(c *gin.Context) {
			handlerCalled = true

			// Check headers were set
			if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
				t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
			}
			if c.Writer.Header().Get("Cache-Control") != "no-cache" {
				t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
			}
			c.Status(http.StatusOK)
		})

		body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
		req, err := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
		if err != nil {
			t.Fatalf("failed to build request: %v", err)
		}
		req.Header.Set("Content-Type", "application/json")

		resp := httptest.NewRecorder()
		router.ServeHTTP(resp, req)

		if !handlerCalled {
			t.Fatalf("handler was not invoked; status %d, body: %s", resp.Code, resp.Body.String())
		}
		if resp.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", resp.Code)
		}
	})
}
// TestAnthropicMessagesMiddleware_InvalidJSON sends a syntactically broken
// JSON body and expects a 400 response whose body decodes as an Anthropic
// error envelope of type "invalid_request_error".
func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
	gin.SetMode(gin.TestMode)

	engine := gin.New()
	engine.Use(AnthropicMessagesMiddleware())
	engine.POST("/v1/messages", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	payload := strings.NewReader(`{invalid json`)
	httpReq, _ := http.NewRequest(http.MethodPost, "/v1/messages", payload)
	httpReq.Header.Set("Content-Type", "application/json")

	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, httpReq)

	if got := recorder.Code; got != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", got)
	}

	var decoded anthropic.ErrorResponse
	if err := json.Unmarshal(recorder.Body.Bytes(), &decoded); err != nil {
		t.Fatalf("failed to unmarshal error: %v", err)
	}

	if got := decoded.Type; got != "error" {
		t.Errorf("expected type 'error', got %q", got)
	}
	if got := decoded.Error.Type; got != "invalid_request_error" {
		t.Errorf("expected error type 'invalid_request_error', got %q", got)
	}
}
// TestAnthropicWriter_NonStreaming has the downstream handler write a single
// api.ChatResponse as JSON and verifies that the client receives it as an
// Anthropic message: type, role, one text block, stop reason and token usage.
func TestAnthropicWriter_NonStreaming(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Handler standing in for the Ollama chat endpoint.
	ollamaHandler := func(c *gin.Context) {
		chatResp := api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role:    "assistant",
				Content: "Hello there!",
			},
			Done:       true,
			DoneReason: "stop",
			Metrics: api.Metrics{
				PromptEvalCount: 10,
				EvalCount:       5,
			},
		}
		encoded, _ := json.Marshal(chatResp)
		c.Writer.WriteHeader(http.StatusOK)
		_, _ = c.Writer.Write(encoded)
	}

	engine := gin.New()
	engine.Use(AnthropicMessagesMiddleware())
	engine.POST("/v1/messages", ollamaHandler)

	payload := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
	httpReq, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(payload))
	httpReq.Header.Set("Content-Type", "application/json")

	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, httpReq)

	if recorder.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", recorder.Code)
	}

	var msg anthropic.MessagesResponse
	if err := json.Unmarshal(recorder.Body.Bytes(), &msg); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}

	if msg.Type != "message" {
		t.Errorf("expected type 'message', got %q", msg.Type)
	}
	if msg.Role != "assistant" {
		t.Errorf("expected role 'assistant', got %q", msg.Role)
	}

	if n := len(msg.Content); n != 1 {
		t.Fatalf("expected 1 content block, got %d", n)
	}
	if text := msg.Content[0].Text; text == nil || *text != "Hello there!" {
		t.Errorf("expected text 'Hello there!', got %v", text)
	}

	if msg.StopReason != "end_turn" {
		t.Errorf("expected stop_reason 'end_turn', got %q", msg.StopReason)
	}
	if in := msg.Usage.InputTokens; in != 10 {
		t.Errorf("expected input_tokens 10, got %d", in)
	}
	if out := msg.Usage.OutputTokens; out != 5 {
		t.Errorf("expected output_tokens 5, got %d", out)
	}
}
// TestAnthropicWriter_ErrorFromRoutes covers error bodies written by the
// downstream handler. The common shape is gin.H{"error": "message"} with no
// StatusCode field in the JSON; an api.StatusError payload is covered too.
// Each case checks that the HTTP status is preserved and that the body is
// rewritten into an Anthropic error envelope with the expected type and
// message.
func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
	gin.SetMode(gin.TestMode)

	cases := []struct {
		name          string
		statusCode    int
		errorPayload  any
		wantErrorType string
		wantMessage   string
	}{
		// routes.go sends errors without StatusCode in JSON, so we must use HTTP status
		{
			name:          "404 with gin.H error (model not found)",
			statusCode:    http.StatusNotFound,
			errorPayload:  gin.H{"error": "model 'nonexistent' not found"},
			wantErrorType: "not_found_error",
			wantMessage:   "model 'nonexistent' not found",
		},
		{
			name:          "400 with gin.H error (bad request)",
			statusCode:    http.StatusBadRequest,
			errorPayload:  gin.H{"error": "model is required"},
			wantErrorType: "invalid_request_error",
			wantMessage:   "model is required",
		},
		{
			name:          "500 with gin.H error (internal error)",
			statusCode:    http.StatusInternalServerError,
			errorPayload:  gin.H{"error": "something went wrong"},
			wantErrorType: "api_error",
			wantMessage:   "something went wrong",
		},
		{
			name:       "404 with api.StatusError",
			statusCode: http.StatusNotFound,
			errorPayload: api.StatusError{
				StatusCode:   http.StatusNotFound,
				ErrorMessage: "model not found via StatusError",
			},
			wantErrorType: "not_found_error",
			wantMessage:   "model not found via StatusError",
		},
	}

	const payload = `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Simulate what routes.go does - set status and write error JSON
			failingHandler := func(c *gin.Context) {
				encoded, _ := json.Marshal(tc.errorPayload)
				c.Writer.WriteHeader(tc.statusCode)
				_, _ = c.Writer.Write(encoded)
			}

			engine := gin.New()
			engine.Use(AnthropicMessagesMiddleware())
			engine.POST("/v1/messages", failingHandler)

			httpReq, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(payload))
			httpReq.Header.Set("Content-Type", "application/json")

			recorder := httptest.NewRecorder()
			engine.ServeHTTP(recorder, httpReq)

			if recorder.Code != tc.statusCode {
				t.Errorf("expected status %d, got %d", tc.statusCode, recorder.Code)
			}

			var decoded anthropic.ErrorResponse
			if err := json.Unmarshal(recorder.Body.Bytes(), &decoded); err != nil {
				t.Fatalf("failed to unmarshal error response: %v\nbody: %s", err, recorder.Body.String())
			}

			if got := decoded.Type; got != "error" {
				t.Errorf("expected type 'error', got %q", got)
			}
			if got := decoded.Error.Type; got != tc.wantErrorType {
				t.Errorf("expected error type %q, got %q", tc.wantErrorType, got)
			}
			if got := decoded.Error.Message; got != tc.wantMessage {
				t.Errorf("expected message %q, got %q", tc.wantMessage, got)
			}
		})
	}
}
// TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag verifies that the
// "relax_thinking" key is present in the gin context when the downstream
// handler runs.
func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Written by the handler; stays false if the handler never runs.
	var flagSet bool

	engine := gin.New()
	engine.Use(AnthropicMessagesMiddleware())
	engine.POST("/v1/messages", func(c *gin.Context) {
		_, present := c.Get("relax_thinking")
		flagSet = present
		c.Status(http.StatusOK)
	})

	payload := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
	httpReq, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(payload))
	httpReq.Header.Set("Content-Type", "application/json")

	engine.ServeHTTP(httptest.NewRecorder(), httpReq)

	if !flagSet {
		t.Error("expected relax_thinking flag to be set in context")
	}
}
// Web Search Tests

// TestHasWebSearchTool checks hasWebSearchTool against nil, custom-only,
// web-search-only and mixed tool lists.
func TestHasWebSearchTool(t *testing.T) {
	weather := anthropic.Tool{Type: "custom", Name: "get_weather"}
	webSearch := anthropic.Tool{Type: "web_search_20250305", Name: "web_search"}

	cases := []struct {
		name  string
		tools []anthropic.Tool
		want  bool
	}{
		{name: "no tools", tools: nil, want: false},
		{name: "regular tool only", tools: []anthropic.Tool{weather}, want: false},
		{name: "web search tool", tools: []anthropic.Tool{webSearch}, want: true},
		{name: "mixed tools", tools: []anthropic.Tool{weather, webSearch}, want: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := hasWebSearchTool(tc.tools); got != tc.want {
				t.Errorf("expected %v, got %v", tc.want, got)
			}
		})
	}
}
// TestExtractQueryFromToolCall checks extractQueryFromToolCall for a call
// carrying a "query" argument, a call with no arguments set, and a call whose
// arguments lack the "query" key; the last two are expected to yield "".
func TestExtractQueryFromToolCall(t *testing.T) {
	// webSearchCall builds a web_search tool call with the given arguments.
	webSearchCall := func(args api.ToolCallFunctionArguments) *api.ToolCall {
		return &api.ToolCall{
			Function: api.ToolCallFunction{
				Name:      "web_search",
				Arguments: args,
			},
		}
	}

	cases := []struct {
		name string
		call *api.ToolCall
		want string
	}{
		{
			name: "valid query",
			call: webSearchCall(makeArgs("query", "test search")),
			want: "test search",
		},
		{
			name: "empty arguments",
			// Arguments deliberately left at its zero value.
			call: &api.ToolCall{
				Function: api.ToolCallFunction{
					Name: "web_search",
				},
			},
			want: "",
		},
		{
			name: "no query key",
			call: webSearchCall(makeArgs("other", "value")),
			want: "",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := extractQueryFromToolCall(tc.call); got != tc.want {
				t.Errorf("expected %q, got %q", tc.want, got)
			}
		})
	}
}
// makeArgs is a test helper returning a fresh api.ToolCallFunctionArguments
// holding the single entry key=value.
func makeArgs(key string, value any) api.ToolCallFunctionArguments {
	single := api.NewToolCallFunctionArguments()
	single.Set(key, value)
	return single
}
// --- Web Search Integration Tests ---

// TestWebSearchServerToolUseID tests the ID derivation logic: the expected
// values below replace a leading "msg_" with "srvtoolu_", and prepend
// "srvtoolu_" when that prefix is absent.
func TestWebSearchServerToolUseID(t *testing.T) {
	cases := []struct {
		in   string
		want string
	}{
		{in: "msg_abc123", want: "srvtoolu_abc123"},
		{in: "msg_", want: "srvtoolu_"},
		{in: "nomsgprefix", want: "srvtoolu_nomsgprefix"},
	}

	for _, tc := range cases {
		if got := serverToolUseID(tc.in); got != tc.want {
			t.Errorf("serverToolUseID(%q) = %q, want %q", tc.in, got, tc.want)
		}
	}
}
// TestWebSearchNoWebSearchTool verifies that when the request declares no
// web_search tool, the handler's api.ChatResponse is returned to the client
// as an ordinary Anthropic message with a single text block.
func TestWebSearchNoWebSearchTool(t *testing.T) {
	gin.SetMode(gin.TestMode)

	router := gin.New()
	router.Use(AnthropicMessagesMiddleware())
	router.POST("/v1/messages", func(c *gin.Context) {
		resp := api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role:    "assistant",
				Content: "Normal response",
			},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 10, EvalCount: 5},
		}
		data, _ := json.Marshal(resp)
		c.Writer.WriteHeader(http.StatusOK)
		_, _ = c.Writer.Write(data)
	})

	body := `{"model":"test-model","max_tokens":100,"messages":[{"role":"user","content":"Hello"}]}`
	req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")

	resp := httptest.NewRecorder()
	router.ServeHTTP(resp, req)

	if resp.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
	}

	var result anthropic.MessagesResponse
	if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
		t.Fatalf("unmarshal error: %v", err)
	}

	if result.Type != "message" {
		t.Errorf("expected type 'message', got %q", result.Type)
	}
	if len(result.Content) != 1 || result.Content[0].Type != "text" {
		t.Fatalf("expected single text block, got %d blocks", len(result.Content))
	}

	// Text is a pointer; report a missing value as a test failure instead of
	// panicking on the dereference below.
	text := result.Content[0].Text
	if text == nil {
		t.Fatal("expected text 'Normal response', got nil")
	}
	if *text != "Normal response" {
		t.Errorf("expected text 'Normal response', got %q", *text)
	}
}
// TestWebSearchToolPresent_ModelDoesNotCallIt_NonStreaming verifies that when
// the web_search tool is present but the model does not call it, the response
// passes through normally (non-streaming case): one text block with the
// handler's content and stop_reason "end_turn".
func TestWebSearchToolPresent_ModelDoesNotCallIt_NonStreaming(t *testing.T) {
	gin.SetMode(gin.TestMode)
	enableCloudForTest(t)

	router := gin.New()
	router.Use(AnthropicMessagesMiddleware())
	router.POST("/v1/messages", func(c *gin.Context) {
		resp := api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role:    "assistant",
				Content: "I can answer that without searching.",
			},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 12, EvalCount: 8},
		}
		data, _ := json.Marshal(resp)
		c.Writer.WriteHeader(http.StatusOK)
		_, _ = c.Writer.Write(data)
	})

	body := `{
"model":"test-model:cloud",
"max_tokens":100,
"messages":[{"role":"user","content":"What is 2+2?"}],
"tools":[{"type":"web_search_20250305","name":"web_search"}]
}`
	req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")

	resp := httptest.NewRecorder()
	router.ServeHTTP(resp, req)

	if resp.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
	}

	var result anthropic.MessagesResponse
	if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
		t.Fatalf("unmarshal error: %v", err)
	}

	if result.Type != "message" {
		t.Errorf("expected type 'message', got %q", result.Type)
	}
	if len(result.Content) != 1 || result.Content[0].Type != "text" {
		t.Fatalf("expected single text block, got %+v", result.Content)
	}

	// Text is a pointer; report a missing value as a test failure instead of
	// panicking on the dereference below.
	text := result.Content[0].Text
	if text == nil {
		t.Fatal("unexpected text: nil")
	}
	if *text != "I can answer that without searching." {
		t.Errorf("unexpected text: %q", *text)
	}

	if result.StopReason != "end_turn" {
		t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
	}
}
// TestWebSearchToolPresent_ModelDoesNotCallIt_Streaming verifies the streaming
// pass-through case when the model does not invoke web_search. The handler
// writes three chat chunks; the test expects the SSE output to begin with
// message_start and to contain a text content_block_start, a text_delta and a
// message_stop event.
func TestWebSearchToolPresent_ModelDoesNotCallIt_Streaming(t *testing.T) {
	gin.SetMode(gin.TestMode)
	enableCloudForTest(t)

	// Simulate streaming: two partial chunks then a final chunk
	streamHandler := func(c *gin.Context) {
		parts := []api.ChatResponse{
			{
				Model:   "test-model",
				Message: api.Message{Role: "assistant", Content: "Hello "},
				Done:    false,
			},
			{
				Model:   "test-model",
				Message: api.Message{Role: "assistant", Content: "world"},
				Done:    false,
			},
			{
				Model:      "test-model",
				Message:    api.Message{Role: "assistant", Content: ""},
				Done:       true,
				DoneReason: "stop",
				Metrics:    api.Metrics{PromptEvalCount: 10, EvalCount: 5},
			},
		}

		c.Writer.WriteHeader(http.StatusOK)
		for i := range parts {
			encoded, _ := json.Marshal(parts[i])
			_, _ = c.Writer.Write(encoded)
		}
	}

	engine := gin.New()
	engine.Use(AnthropicMessagesMiddleware())
	engine.POST("/v1/messages", streamHandler)

	payload := `{
"model":"test-model:cloud",
"max_tokens":100,
"stream":true,
"messages":[{"role":"user","content":"Hi"}],
"tools":[{"type":"web_search_20250305","name":"web_search"}]
}`
	httpReq, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(payload))
	httpReq.Header.Set("Content-Type", "application/json")

	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, httpReq)

	if recorder.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
	}

	// Parse SSE events
	events := parseSSEEvents(t, recorder.Body.String())

	// Should have standard streaming event flow
	if len(events) == 0 {
		t.Fatal("expected SSE events, got none")
	}

	// First event should be message_start
	if first := events[0].event; first != "message_start" {
		t.Errorf("first event should be message_start, got %q", first)
	}

	// Scan the stream once, noting which of the expected events appear.
	// Events whose payload fails to decode are simply not counted.
	var sawTextStart, sawTextDelta, sawMessageStop bool
	for _, ev := range events {
		switch ev.event {
		case "content_block_start":
			var start anthropic.ContentBlockStartEvent
			if json.Unmarshal([]byte(ev.data), &start) == nil && start.ContentBlock.Type == "text" {
				sawTextStart = true
			}
		case "content_block_delta":
			var delta anthropic.ContentBlockDeltaEvent
			if json.Unmarshal([]byte(ev.data), &delta) == nil && delta.Delta.Type == "text_delta" {
				sawTextDelta = true
			}
		case "message_stop":
			sawMessageStop = true
		}
	}

	if !sawTextStart {
		t.Error("expected content_block_start with text type")
	}
	if !sawTextDelta {
		t.Error("expected content_block_delta with text_delta")
	}
	if !sawMessageStop {
		t.Error("expected message_stop event")
	}
}
// TestWebSearchToolPresent_ModelCallsIt_NonStreaming tests the full web search
// flow in non-streaming mode. Two local HTTP servers stand in for the followup
// /api/chat call (reached via OLLAMA_HOST) and the web search API (reached via
// anthropic.WebSearchEndpoint). The handler answers with a web_search tool
// call, and the final response is expected to contain three blocks:
// server_tool_use, web_search_tool_result and text.
func TestWebSearchToolPresent_ModelCallsIt_NonStreaming(t *testing.T) {
	gin.SetMode(gin.TestMode)
	enableCloudForTest(t)

	// Create a mock Ollama server that responds to the followup /api/chat call
	followupHandler := func(w http.ResponseWriter, _ *http.Request) {
		_ = json.NewEncoder(w).Encode(api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role:    "assistant",
				Content: "Based on my search, the answer is 42.",
			},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 50, EvalCount: 20},
		})
	}
	followupServer := httptest.NewServer(http.HandlerFunc(followupHandler))
	defer followupServer.Close()

	// Set OLLAMA_HOST to our mock server so the followup call goes there
	t.Setenv("OLLAMA_HOST", followupServer.URL)

	// Also mock the web search API
	searchHandler := func(w http.ResponseWriter, _ *http.Request) {
		_ = json.NewEncoder(w).Encode(anthropic.OllamaWebSearchResponse{
			Results: []anthropic.OllamaWebSearchResult{
				{Title: "Test Result", URL: "https://example.com/result", Content: "Some content"},
			},
		})
	}
	searchServer := httptest.NewServer(http.HandlerFunc(searchHandler))
	defer searchServer.Close()

	// Point DoWebSearch at our mock search server
	savedEndpoint := anthropic.WebSearchEndpoint
	anthropic.WebSearchEndpoint = searchServer.URL
	defer func() { anthropic.WebSearchEndpoint = savedEndpoint }()

	// First-pass model reply: a single web_search tool call and no text.
	toolCallHandler := func(c *gin.Context) {
		chatResp := api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role: "assistant",
				ToolCalls: []api.ToolCall{
					{
						ID: "call_ws_001",
						Function: api.ToolCallFunction{
							Name:      "web_search",
							Arguments: makeArgs("query", "meaning of life"),
						},
					},
				},
			},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 15, EvalCount: 3},
		}
		encoded, _ := json.Marshal(chatResp)
		c.Writer.WriteHeader(http.StatusOK)
		_, _ = c.Writer.Write(encoded)
	}

	engine := gin.New()
	engine.Use(AnthropicMessagesMiddleware())
	engine.POST("/v1/messages", toolCallHandler)

	payload := `{
"model":"test-model:cloud",
"max_tokens":100,
"messages":[{"role":"user","content":"What is the meaning of life?"}],
"tools":[{"type":"web_search_20250305","name":"web_search"}]
}`
	httpReq, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(payload))
	httpReq.Header.Set("Content-Type", "application/json")

	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, httpReq)

	if recorder.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
	}

	var msg anthropic.MessagesResponse
	if err := json.Unmarshal(recorder.Body.Bytes(), &msg); err != nil {
		t.Fatalf("unmarshal error: %v\nbody: %s", err, recorder.Body.String())
	}

	if msg.Type != "message" {
		t.Errorf("expected type 'message', got %q", msg.Type)
	}
	if msg.Role != "assistant" {
		t.Errorf("expected role 'assistant', got %q", msg.Role)
	}

	// Should have 3 blocks: server_tool_use + web_search_tool_result + text
	if len(msg.Content) != 3 {
		t.Fatalf("expected 3 content blocks, got %d: %+v", len(msg.Content), msg.Content)
	}
	toolUse, searchResult, answer := msg.Content[0], msg.Content[1], msg.Content[2]

	if toolUse.Type != "server_tool_use" {
		t.Errorf("expected first block type 'server_tool_use', got %q", toolUse.Type)
	}
	if toolUse.Name != "web_search" {
		t.Errorf("expected name 'web_search', got %q", toolUse.Name)
	}

	if searchResult.Type != "web_search_tool_result" {
		t.Errorf("expected second block type 'web_search_tool_result', got %q", searchResult.Type)
	}
	// The result block must point back at the tool-use block it answers.
	if searchResult.ToolUseID != toolUse.ID {
		t.Errorf("tool_use_id mismatch: %q != %q", searchResult.ToolUseID, toolUse.ID)
	}

	if answer.Type != "text" {
		t.Errorf("expected third block type 'text', got %q", answer.Type)
	}
	if answer.Text == nil || *answer.Text == "" {
		t.Error("expected non-empty text in third block")
	}

	if msg.StopReason != "end_turn" {
		t.Errorf("expected stop_reason 'end_turn', got %q", msg.StopReason)
	}
}
// TestWebSearchToolPresent_ModelCallsIt_Streaming tests the streaming SSE
// output when the model calls web_search, with the search and followup
// endpoints mocked by local HTTP servers. It asserts the exact sequence of
// ten SSE event names and that the text delta carries the followup server's
// content.
func TestWebSearchToolPresent_ModelCallsIt_Streaming(t *testing.T) {
	gin.SetMode(gin.TestMode)
	enableCloudForTest(t)

	// Mock followup /api/chat server
	followupHandler := func(w http.ResponseWriter, _ *http.Request) {
		_ = json.NewEncoder(w).Encode(api.ChatResponse{
			Model:      "test-model",
			Message:    api.Message{Role: "assistant", Content: "Here are the latest news."},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 40, EvalCount: 15},
		})
	}
	followupServer := httptest.NewServer(http.HandlerFunc(followupHandler))
	defer followupServer.Close()
	t.Setenv("OLLAMA_HOST", followupServer.URL)

	// Mock web search API
	searchHandler := func(w http.ResponseWriter, _ *http.Request) {
		_ = json.NewEncoder(w).Encode(anthropic.OllamaWebSearchResponse{
			Results: []anthropic.OllamaWebSearchResult{
				{Title: "News Result", URL: "https://example.com/news", Content: "Breaking news"},
			},
		})
	}
	searchServer := httptest.NewServer(http.HandlerFunc(searchHandler))
	defer searchServer.Close()

	savedEndpoint := anthropic.WebSearchEndpoint
	anthropic.WebSearchEndpoint = searchServer.URL
	defer func() { anthropic.WebSearchEndpoint = savedEndpoint }()

	// Simulate buffered streaming: non-final chunk then final with tool call
	toolCallHandler := func(c *gin.Context) {
		parts := []api.ChatResponse{
			{
				Model:   "test-model",
				Message: api.Message{Role: "assistant"},
				Done:    false,
			},
			{
				Model: "test-model",
				Message: api.Message{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							ID: "call_ws_002",
							Function: api.ToolCallFunction{
								Name:      "web_search",
								Arguments: makeArgs("query", "latest news"),
							},
						},
					},
				},
				Done:       true,
				DoneReason: "stop",
				Metrics:    api.Metrics{PromptEvalCount: 10, EvalCount: 2},
			},
		}

		c.Writer.WriteHeader(http.StatusOK)
		for i := range parts {
			encoded, _ := json.Marshal(parts[i])
			_, _ = c.Writer.Write(encoded)
		}
	}

	engine := gin.New()
	engine.Use(AnthropicMessagesMiddleware())
	engine.POST("/v1/messages", toolCallHandler)

	payload := `{
"model":"test-model:cloud",
"max_tokens":100,
"stream":true,
"messages":[{"role":"user","content":"What is the latest news?"}],
"tools":[{"type":"web_search_20250305","name":"web_search"}]
}`
	httpReq, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(payload))
	httpReq.Header.Set("Content-Type", "application/json")

	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, httpReq)

	if recorder.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
	}

	events := parseSSEEvents(t, recorder.Body.String())

	// Success path: 10 events (3 blocks: server_tool_use, web_search_tool_result, text with delta)
	wantEvents := []string{
		"message_start",
		"content_block_start", // server_tool_use
		"content_block_stop",
		"content_block_start", // web_search_tool_result
		"content_block_stop",
		"content_block_start", // text (empty)
		"content_block_delta", // text_delta with actual content
		"content_block_stop",
		"message_delta",
		"message_stop",
	}

	if len(events) != len(wantEvents) {
		t.Fatalf("expected %d events, got %d.\nEvents: %v", len(wantEvents), len(events), eventNames(events))
	}
	for i := range wantEvents {
		if got := events[i].event; got != wantEvents[i] {
			t.Errorf("event[%d]: expected %q, got %q", i, wantEvents[i], got)
		}
	}

	// Verify text delta has the followup model's content. Index 6 is the
	// content_block_delta slot in wantEvents; the length check above
	// guarantees it exists.
	var delta anthropic.ContentBlockDeltaEvent
	if err := json.Unmarshal([]byte(events[6].data), &delta); err != nil {
		t.Fatalf("failed to parse text delta: %v", err)
	}
	if got := delta.Delta.Type; got != "text_delta" {
		t.Errorf("expected delta type 'text_delta', got %q", got)
	}
	if got := delta.Delta.Text; got != "Here are the latest news." {
		t.Errorf("expected followup text, got %q", got)
	}
}
// TestWebSearchStreamResponse exercises streamResponse in isolation: it builds a
// WebSearchAnthropicWriter by hand, feeds it a fixed three-block response
// (server_tool_use, web_search_tool_result, text) and inspects the SSE events
// that end up in the recorder body.
func TestWebSearchStreamResponse(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Both writers wrap the gin writer created on top of the recorder.
	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	writer := &WebSearchAnthropicWriter{
		BaseWriter: BaseWriter{ResponseWriter: ctx.Writer},
		inner: &AnthropicWriter{
			BaseWriter: BaseWriter{ResponseWriter: ctx.Writer},
			stream:     true,
			id:         "msg_test123",
		},
		stream: true,
		req:    anthropic.MessagesRequest{Model: "test-model"},
	}

	answer := "Here is the answer."
	toolUse := anthropic.ContentBlock{
		Type:  "server_tool_use",
		ID:    "srvtoolu_test123",
		Name:  "web_search",
		Input: map[string]any{"query": "test query"},
	}
	searchResult := anthropic.ContentBlock{
		Type:      "web_search_tool_result",
		ToolUseID: "srvtoolu_test123",
		Content: []anthropic.WebSearchResult{
			{Type: "web_search_result", URL: "https://example.com", Title: "Example"},
		},
	}
	textBlock := anthropic.ContentBlock{Type: "text", Text: &answer}

	err := writer.streamResponse(anthropic.MessagesResponse{
		ID:         "msg_test123",
		Type:       "message",
		Role:       "assistant",
		Model:      "test-model",
		Content:    []anthropic.ContentBlock{toolUse, searchResult, textBlock},
		StopReason: "end_turn",
		Usage:      anthropic.Usage{InputTokens: 20, OutputTokens: 10},
	})
	if err != nil {
		t.Fatalf("streamResponse error: %v", err)
	}

	events := parseSSEEvents(t, recorder.Body.String())

	// The complete, ordered event sequence expected for the three blocks.
	wantEvents := []string{
		"message_start",
		"content_block_start", // server_tool_use (index 0)
		"content_block_stop",  // index 0
		"content_block_start", // web_search_tool_result (index 1)
		"content_block_stop",  // index 1
		"content_block_start", // text (index 2)
		"content_block_delta", // text_delta
		"content_block_stop",  // index 2
		"message_delta",
		"message_stop",
	}
	if got, want := len(events), len(wantEvents); got != want {
		t.Fatalf("expected %d events, got %d.\nEvents: %v", want, got, eventNames(events))
	}
	for i := range wantEvents {
		if got := events[i].event; got != wantEvents[i] {
			t.Errorf("event[%d]: expected %q, got %q", i, wantEvents[i], got)
		}
	}

	// events[0]: message_start carries the message envelope with no content yet.
	var start anthropic.MessageStartEvent
	err = json.Unmarshal([]byte(events[0].data), &start)
	if err != nil {
		t.Fatalf("failed to parse message_start: %v", err)
	}
	if got := start.Message.ID; got != "msg_test123" {
		t.Errorf("expected message ID 'msg_test123', got %q", got)
	}
	if got := start.Message.Role; got != "assistant" {
		t.Errorf("expected role 'assistant', got %q", got)
	}
	if n := len(start.Message.Content); n != 0 {
		t.Errorf("expected empty content in message_start, got %d blocks", n)
	}

	// events[1]: content_block_start for the server_tool_use block.
	var toolEvt anthropic.ContentBlockStartEvent
	err = json.Unmarshal([]byte(events[1].data), &toolEvt)
	if err != nil {
		t.Fatalf("failed to parse server_tool_use start: %v", err)
	}
	if toolEvt.Index != 0 {
		t.Errorf("expected index 0, got %d", toolEvt.Index)
	}
	if got := toolEvt.ContentBlock.Type; got != "server_tool_use" {
		t.Errorf("expected type 'server_tool_use', got %q", got)
	}
	if got := toolEvt.ContentBlock.ID; got != "srvtoolu_test123" {
		t.Errorf("expected ID 'srvtoolu_test123', got %q", got)
	}

	// events[3]: content_block_start for the web_search_tool_result block.
	var resultEvt anthropic.ContentBlockStartEvent
	err = json.Unmarshal([]byte(events[3].data), &resultEvt)
	if err != nil {
		t.Fatalf("failed to parse web_search_tool_result start: %v", err)
	}
	if resultEvt.Index != 1 {
		t.Errorf("expected index 1, got %d", resultEvt.Index)
	}
	if got := resultEvt.ContentBlock.Type; got != "web_search_tool_result" {
		t.Errorf("expected type 'web_search_tool_result', got %q", got)
	}

	// events[5]: content_block_start for the text block; its text must be
	// present but empty because the actual text arrives in the delta.
	var textEvt anthropic.ContentBlockStartEvent
	err = json.Unmarshal([]byte(events[5].data), &textEvt)
	if err != nil {
		t.Fatalf("failed to parse text start: %v", err)
	}
	if textEvt.Index != 2 {
		t.Errorf("expected index 2, got %d", textEvt.Index)
	}
	if got := textEvt.ContentBlock.Type; got != "text" {
		t.Errorf("expected type 'text', got %q", got)
	}
	if startText := textEvt.ContentBlock.Text; !(startText != nil && *startText == "") {
		t.Errorf("expected empty text in content_block_start, got %v", textEvt.ContentBlock.Text)
	}

	// events[6]: the text delta carrying the full answer.
	var deltaEvt anthropic.ContentBlockDeltaEvent
	err = json.Unmarshal([]byte(events[6].data), &deltaEvt)
	if err != nil {
		t.Fatalf("failed to parse text delta: %v", err)
	}
	if deltaEvt.Index != 2 {
		t.Errorf("expected index 2, got %d", deltaEvt.Index)
	}
	if got := deltaEvt.Delta.Type; got != "text_delta" {
		t.Errorf("expected delta type 'text_delta', got %q", got)
	}
	if got := deltaEvt.Delta.Text; got != "Here is the answer." {
		t.Errorf("expected delta text 'Here is the answer.', got %q", got)
	}

	// events[8]: message_delta with the stop reason and usage totals.
	var final anthropic.MessageDeltaEvent
	err = json.Unmarshal([]byte(events[8].data), &final)
	if err != nil {
		t.Fatalf("failed to parse message_delta: %v", err)
	}
	if got := final.Delta.StopReason; got != "end_turn" {
		t.Errorf("expected stop_reason 'end_turn', got %q", got)
	}
	if got := final.Usage.InputTokens; got != 20 {
		t.Errorf("expected input_tokens 20, got %d", got)
	}
	if got := final.Usage.OutputTokens; got != 10 {
		t.Errorf("expected output_tokens 10, got %d", got)
	}
}
// TestWebSearchSendError_NonStreaming calls sendError on a non-streaming writer
// and checks the JSON message written to the recorder: identifiers, the two
// content blocks, the embedded error payload, the stop reason and the usage.
func TestWebSearchSendError_NonStreaming(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	writer := &WebSearchAnthropicWriter{
		BaseWriter: BaseWriter{ResponseWriter: ctx.Writer},
		inner: &AnthropicWriter{
			BaseWriter: BaseWriter{ResponseWriter: ctx.Writer},
			stream:     false,
			id:         "msg_err001",
		},
		stream: false,
		req:    anthropic.MessagesRequest{Model: "test-model"},
	}

	wantUsage := anthropic.Usage{InputTokens: 7, OutputTokens: 2}
	err := writer.sendError("unavailable", "test query", wantUsage)
	if err != nil {
		t.Fatalf("sendError error: %v", err)
	}

	var msg anthropic.MessagesResponse
	err = json.Unmarshal(recorder.Body.Bytes(), &msg)
	if err != nil {
		t.Fatalf("unmarshal error: %v\nbody: %s", err, recorder.Body.String())
	}

	if msg.Type != "message" {
		t.Errorf("expected type 'message', got %q", msg.Type)
	}
	if msg.ID != "msg_err001" {
		t.Errorf("expected ID 'msg_err001', got %q", msg.ID)
	}

	// Exactly two blocks are expected: server_tool_use followed by
	// web_search_tool_result.
	if n := len(msg.Content); n != 2 {
		t.Fatalf("expected 2 content blocks, got %d", n)
	}
	toolBlock, resultBlock := msg.Content[0], msg.Content[1]

	// First block: the server_tool_use that echoes the query.
	const wantToolID = "srvtoolu_err001"
	if toolBlock.Type != "server_tool_use" {
		t.Errorf("expected 'server_tool_use', got %q", toolBlock.Type)
	}
	if toolBlock.ID != wantToolID {
		t.Errorf("expected ID %q, got %q", wantToolID, toolBlock.ID)
	}
	if toolBlock.Name != "web_search" {
		t.Errorf("expected name 'web_search', got %q", toolBlock.Name)
	}
	toolInput, isMap := toolBlock.Input.(map[string]any)
	if !isMap {
		t.Fatalf("expected Input to be map, got %T", msg.Content[0].Input)
	}
	if query := toolInput["query"]; query != "test query" {
		t.Errorf("expected query 'test query', got %v", query)
	}

	// Second block: the web_search_tool_result that references the tool use.
	if resultBlock.Type != "web_search_tool_result" {
		t.Errorf("expected 'web_search_tool_result', got %q", resultBlock.Type)
	}
	if resultBlock.ToolUseID != wantToolID {
		t.Errorf("expected tool_use_id %q, got %q", wantToolID, resultBlock.ToolUseID)
	}

	// Round-trip the block content through JSON to read it as a
	// WebSearchToolResultError.
	rawErr, _ := json.Marshal(resultBlock.Content)
	var toolErr anthropic.WebSearchToolResultError
	err = json.Unmarshal(rawErr, &toolErr)
	if err != nil {
		t.Fatalf("failed to parse error content: %v\nraw: %s", err, string(rawErr))
	}
	if toolErr.Type != "web_search_tool_result_error" {
		t.Errorf("expected error type 'web_search_tool_result_error', got %q", toolErr.Type)
	}
	if toolErr.ErrorCode != "unavailable" {
		t.Errorf("expected error_code 'unavailable', got %q", toolErr.ErrorCode)
	}

	if msg.StopReason != "end_turn" {
		t.Errorf("expected stop_reason 'end_turn', got %q", msg.StopReason)
	}
	if msg.Usage != wantUsage {
		t.Errorf("expected usage %+v, got %+v", wantUsage, msg.Usage)
	}
}
// TestWebSearchSendError_Streaming calls sendError on a streaming writer and
// checks the SSE output: the event sequence, the types of the two content
// blocks and the usage reported in message_delta.
func TestWebSearchSendError_Streaming(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	writer := &WebSearchAnthropicWriter{
		BaseWriter: BaseWriter{ResponseWriter: ctx.Writer},
		inner: &AnthropicWriter{
			BaseWriter: BaseWriter{ResponseWriter: ctx.Writer},
			stream:     true,
			id:         "msg_err002",
		},
		stream: true,
		req:    anthropic.MessagesRequest{Model: "test-model"},
	}

	wantUsage := anthropic.Usage{InputTokens: 9, OutputTokens: 4}
	err := writer.sendError("invalid_request", "bad query", wantUsage)
	if err != nil {
		t.Fatalf("sendError error: %v", err)
	}

	events := parseSSEEvents(t, recorder.Body.String())

	// The error response has two blocks, so the stream is: message_start, a
	// start/stop pair per block, then message_delta and message_stop.
	wantEvents := []string{
		"message_start",
		"content_block_start", // server_tool_use
		"content_block_stop",
		"content_block_start", // web_search_tool_result
		"content_block_stop",
		"message_delta",
		"message_stop",
	}
	if got, want := len(events), len(wantEvents); got != want {
		t.Fatalf("expected %d events, got %d.\nEvents: %v", want, got, eventNames(events))
	}
	for i := range wantEvents {
		if got := events[i].event; got != wantEvents[i] {
			t.Errorf("event[%d]: expected %q, got %q", i, wantEvents[i], got)
		}
	}

	// events[1]: the server_tool_use block start.
	var toolEvt anthropic.ContentBlockStartEvent
	err = json.Unmarshal([]byte(events[1].data), &toolEvt)
	if err != nil {
		t.Fatalf("failed to parse server_tool_use start: %v", err)
	}
	if got := toolEvt.ContentBlock.Type; got != "server_tool_use" {
		t.Errorf("expected 'server_tool_use', got %q", got)
	}

	// events[3]: the web_search_tool_result block start.
	var resultEvt anthropic.ContentBlockStartEvent
	err = json.Unmarshal([]byte(events[3].data), &resultEvt)
	if err != nil {
		t.Fatalf("failed to parse web_search_tool_result start: %v", err)
	}
	if got := resultEvt.ContentBlock.Type; got != "web_search_tool_result" {
		t.Errorf("expected 'web_search_tool_result', got %q", got)
	}

	// events[5]: message_delta must carry the usage handed to sendError.
	var final anthropic.MessageDeltaEvent
	err = json.Unmarshal([]byte(events[5].data), &final)
	if err != nil {
		t.Fatalf("failed to parse message_delta: %v", err)
	}
	sameInput := final.Usage.InputTokens == wantUsage.InputTokens
	sameOutput := final.Usage.OutputTokens == wantUsage.OutputTokens
	if !(sameInput && sameOutput) {
		t.Fatalf("expected usage %+v in message_delta, got %+v", wantUsage, final.Usage)
	}
}
// TestWebSearchSendError_EmptyQuery calls sendError with an empty query on a
// non-streaming writer and checks that the response still has two content
// blocks and that the tool input carries the empty query.
func TestWebSearchSendError_EmptyQuery(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	writer := &WebSearchAnthropicWriter{
		BaseWriter: BaseWriter{ResponseWriter: ctx.Writer},
		inner: &AnthropicWriter{
			BaseWriter: BaseWriter{ResponseWriter: ctx.Writer},
			stream:     false,
			id:         "msg_empty001",
		},
		stream: false,
		req:    anthropic.MessagesRequest{Model: "test-model"},
	}

	err := writer.sendError("invalid_request", "", anthropic.Usage{})
	if err != nil {
		t.Fatalf("sendError error: %v", err)
	}

	var msg anthropic.MessagesResponse
	err = json.Unmarshal(recorder.Body.Bytes(), &msg)
	if err != nil {
		t.Fatalf("unmarshal error: %v", err)
	}
	if n := len(msg.Content); n != 2 {
		t.Fatalf("expected 2 content blocks, got %d", n)
	}

	// The server_tool_use input must still be a map whose query is "".
	toolInput, isMap := msg.Content[0].Input.(map[string]any)
	if !isMap {
		t.Fatalf("expected Input to be map, got %T", msg.Content[0].Input)
	}
	if query := toolInput["query"]; query != "" {
		t.Errorf("expected empty query, got %v", query)
	}
}
// --- SSE parsing helpers ---

// sseEvent is one parsed Server-Sent Event: the value of its "event:" line and
// the payload collected from its "data:" line(s).
type sseEvent struct {
	event, data string
}
// parseSSEEvents parses the Server-Sent Events contained in body and returns
// them in stream order.
//
// An event is emitted at each blank line, provided an "event:" field was seen
// since the previous blank line. Blocks without an event name are dropped, and
// so is a trailing block that is not terminated by a blank line. Multiple
// "data:" lines of one event are joined with "\n". Both LF and CRLF line
// endings are accepted, and a single space after the field colon is optional.
func parseSSEEvents(t *testing.T, body string) []sseEvent {
	t.Helper()

	var (
		events    []sseEvent
		name      string
		dataLines []string
	)
	for _, line := range strings.Split(body, "\n") {
		line = strings.TrimSuffix(line, "\r") // tolerate CRLF-terminated streams
		switch {
		case line == "":
			if name != "" {
				events = append(events, sseEvent{event: name, data: strings.Join(dataLines, "\n")})
			}
			// Always reset at a block boundary so data from an unnamed block
			// cannot leak into the next event.
			name = ""
			dataLines = dataLines[:0]
		case strings.HasPrefix(line, "event:"):
			name = strings.TrimPrefix(strings.TrimPrefix(line, "event:"), " ")
		case strings.HasPrefix(line, "data:"):
			dataLines = append(dataLines, strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " "))
		}
	}
	return events
}
// eventNames collects the event field of every entry in events, in order. It is
// used to print the observed sequence when an event-count assertion fails.
func eventNames(events []sseEvent) []string {
	names := make([]string, 0, len(events))
	for idx := range events {
		names = append(names, events[idx].event)
	}
	return names
}
// TestWebSearchCloudModelGating sends requests that declare the web_search tool
// through AnthropicMessagesMiddleware for local and cloud model names, both
// with cloud enabled and with OLLAMA_NO_CLOUD set, and checks whether the
// downstream handler runs and what the client receives.
func TestWebSearchCloudModelGating(t *testing.T) {
	gin.SetMode(gin.TestMode)
	enableCloudForTest(t)

	// Request body shared by the subtests that target the local llama3.2 model.
	const localModelBody = `{"model":"llama3.2","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`

	// send builds a fresh router with the middleware in front of handler, posts
	// body to /v1/messages as JSON and returns the recorded response.
	send := func(handler gin.HandlerFunc, body string) *httptest.ResponseRecorder {
		router := gin.New()
		router.Use(AnthropicMessagesMiddleware())
		router.POST("/v1/messages", handler)

		req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
		req.Header.Set("Content-Type", "application/json")
		rec := httptest.NewRecorder()
		router.ServeHTTP(rec, req)
		return rec
	}

	// helloHandler returns a handler that records its invocation in *called and
	// replies with a finished plain-text chat response attributed to model.
	helloHandler := func(model string, called *bool) gin.HandlerFunc {
		return func(c *gin.Context) {
			*called = true
			payload, _ := json.Marshal(api.ChatResponse{
				Model:      model,
				Message:    api.Message{Role: "assistant", Content: "hello"},
				Done:       true,
				DoneReason: "stop",
				Metrics:    api.Metrics{PromptEvalCount: 10, EvalCount: 5},
			})
			c.Writer.WriteHeader(http.StatusOK)
			_, _ = c.Writer.Write(payload)
		}
	}

	t.Run("local model allowed when web_search is not called", func(t *testing.T) {
		called := false
		rec := send(helloHandler("llama3.2", &called), localModelBody)

		if rec.Code != http.StatusOK {
			t.Errorf("expected 200, got %d: %s", rec.Code, rec.Body.String())
		}
		if !called {
			t.Error("handler should be called for local model when web_search is not called")
		}
	})

	t.Run("local model emits web_search and gets structured error", func(t *testing.T) {
		// The downstream model answers with a web_search tool call.
		handler := func(c *gin.Context) {
			payload, _ := json.Marshal(api.ChatResponse{
				Model: "llama3.2",
				Message: api.Message{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							ID: "call_local_ws",
							Function: api.ToolCallFunction{
								Name:      "web_search",
								Arguments: makeArgs("query", "hello"),
							},
						},
					},
				},
				Done:       true,
				DoneReason: "stop",
				Metrics:    api.Metrics{PromptEvalCount: 8, EvalCount: 2},
			})
			c.Writer.WriteHeader(http.StatusOK)
			_, _ = c.Writer.Write(payload)
		}
		rec := send(handler, localModelBody)

		if rec.Code != http.StatusOK {
			t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
		}

		var msg anthropic.MessagesResponse
		err := json.Unmarshal(rec.Body.Bytes(), &msg)
		if err != nil {
			t.Fatalf("unmarshal error: %v", err)
		}
		if n := len(msg.Content); n != 2 {
			t.Fatalf("expected 2 content blocks for local model web_search error, got %d", n)
		}

		// Round-trip the second block's content to read the error code.
		rawErr, _ := json.Marshal(msg.Content[1].Content)
		var toolErr anthropic.WebSearchToolResultError
		err = json.Unmarshal(rawErr, &toolErr)
		if err != nil {
			t.Fatalf("failed to parse web_search error content: %v", err)
		}
		if toolErr.ErrorCode != "web_search_not_supported_for_local_models" {
			t.Fatalf("expected web_search_not_supported_for_local_models, got %q", toolErr.ErrorCode)
		}
	})

	t.Run("model ending in cloud without cloud suffix treated as local", func(t *testing.T) {
		called := false
		rec := send(
			helloHandler("notreallycloud", &called),
			`{"model":"notreallycloud","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`,
		)

		if !called {
			t.Error("handler should be called for non-cloud model when web_search is not called")
		}
		if rec.Code != http.StatusOK {
			t.Errorf("expected 200, got %d: %s", rec.Code, rec.Body.String())
		}
	})

	t.Run("cloud model with size tag allowed", func(t *testing.T) {
		called := false
		rec := send(
			helloHandler("gpt-oss:120b", &called),
			`{"model":"gpt-oss:120b-cloud","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`,
		)

		if !called {
			t.Error("handler should be called for cloud model")
		}
		if rec.Code != http.StatusOK {
			t.Errorf("expected 200, got %d: %s", rec.Code, rec.Body.String())
		}
	})

	t.Run("cloud model allowed", func(t *testing.T) {
		called := false
		rec := send(
			helloHandler("kimi-k2.5", &called),
			`{"model":"kimi-k2.5:cloud","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`,
		)

		if !called {
			t.Error("handler should be called for cloud model")
		}
		if rec.Code != http.StatusOK {
			t.Errorf("expected 200, got %d: %s", rec.Code, rec.Body.String())
		}
	})

	t.Run("cloud disabled blocks web search for cloud model", func(t *testing.T) {
		t.Setenv("OLLAMA_NO_CLOUD", "1")

		called := false
		rec := send(
			func(c *gin.Context) { called = true },
			`{"model":"kimi-k2.5:cloud","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`,
		)

		if rec.Code != http.StatusForbidden {
			t.Fatalf("expected 403, got %d: %s", rec.Code, rec.Body.String())
		}
		if called {
			t.Fatal("handler should not be called when cloud is disabled")
		}

		var apiErr anthropic.ErrorResponse
		err := json.Unmarshal(rec.Body.Bytes(), &apiErr)
		if err != nil {
			t.Fatalf("failed to parse error response: %v", err)
		}
		if !strings.Contains(apiErr.Error.Message, "ollama cloud is disabled") {
			t.Fatalf("expected cloud disabled error, got: %q", apiErr.Error.Message)
		}
	})

	t.Run("cloud disabled does not block local model if web_search is not called", func(t *testing.T) {
		t.Setenv("OLLAMA_NO_CLOUD", "1")

		called := false
		rec := send(helloHandler("llama3.2", &called), localModelBody)

		if rec.Code != http.StatusOK {
			t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
		}
		if !called {
			t.Fatal("handler should be called for local model when web_search is not called")
		}
	})
}
// TestWebSearchDoesNotRequireAuthorizationHeaderForMockEndpoint checks that the
// web search request the middleware sends to a mock (httptest) endpoint carries
// no Authorization header.
//
// The test fails if the mock search endpoint is never reached; otherwise an
// empty header would be indistinguishable from "no request was made".
func TestWebSearchDoesNotRequireAuthorizationHeaderForMockEndpoint(t *testing.T) {
	gin.SetMode(gin.TestMode)
	enableCloudForTest(t)

	// The search handler runs on the test server's goroutine, so the observed
	// headers are handed back over a buffered channel instead of a shared
	// variable. The send is non-blocking so the handler can never stall.
	authHeaders := make(chan string, 16)
	searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		select {
		case authHeaders <- r.Header.Get("Authorization"):
		default: // buffer full; the requests already recorded are still checked
		}
		resp := anthropic.OllamaWebSearchResponse{
			Results: []anthropic.OllamaWebSearchResult{
				{Title: "Result", URL: "https://example.com", Content: "content"},
			},
		}
		_ = json.NewEncoder(w).Encode(resp)
	}))
	defer searchServer.Close()

	// Point the package-level search endpoint at the mock and restore it on exit.
	originalEndpoint := anthropic.WebSearchEndpoint
	anthropic.WebSearchEndpoint = searchServer.URL
	defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()

	// Server addressed via OLLAMA_HOST; it answers with a final plain-text reply.
	followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := api.ChatResponse{
			Model:      "test-model",
			Message:    api.Message{Role: "assistant", Content: "done"},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 5, EvalCount: 2},
		}
		_ = json.NewEncoder(w).Encode(resp)
	}))
	defer followupServer.Close()
	t.Setenv("OLLAMA_HOST", followupServer.URL)

	router := gin.New()
	router.Use(AnthropicMessagesMiddleware())
	router.POST("/v1/messages", func(c *gin.Context) {
		// The first model turn asks for a web search.
		resp := api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role: "assistant",
				ToolCalls: []api.ToolCall{
					{
						ID: "call_auth",
						Function: api.ToolCallFunction{
							Name:      "web_search",
							Arguments: makeArgs("query", "auth test"),
						},
					},
				},
			},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 4, EvalCount: 1},
		}
		data, _ := json.Marshal(resp)
		c.Writer.WriteHeader(http.StatusOK)
		_, _ = c.Writer.Write(data)
	})

	body := `{
"model":"test-model:cloud",
"max_tokens":100,
"messages":[{"role":"user","content":"test auth"}],
"tools":[{"type":"web_search_20250305","name":"web_search"}]
}`
	req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	resp := httptest.NewRecorder()
	router.ServeHTTP(resp, req)

	if resp.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
	}

	// ServeHTTP has returned, so no further sends can happen and len() is stable.
	if len(authHeaders) == 0 {
		t.Fatal("expected the mock web search endpoint to be called")
	}
	for len(authHeaders) > 0 {
		if authHeader := <-authHeaders; authHeader != "" {
			t.Fatalf("expected no Authorization header for mock web search endpoint, got %q", authHeader)
		}
	}
}
// TestWebSearchSearchAPIError points the web search endpoint at a server that
// always answers 500 and checks that the client still receives a 200 message
// with a server_tool_use block, a web_search_tool_result block and the usage of
// the model turn that requested the search.
func TestWebSearchSearchAPIError(t *testing.T) {
	gin.SetMode(gin.TestMode)
	enableCloudForTest(t)

	// Search backend that fails every request.
	failingSearch := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "internal error", http.StatusInternalServerError)
	}))
	defer failingSearch.Close()

	savedEndpoint := anthropic.WebSearchEndpoint
	anthropic.WebSearchEndpoint = failingSearch.URL
	defer func() { anthropic.WebSearchEndpoint = savedEndpoint }()

	router := gin.New()
	router.Use(AnthropicMessagesMiddleware())
	router.POST("/v1/messages", func(c *gin.Context) {
		// The model asks for a web search.
		searchCall := api.ToolCall{
			ID: "call_err",
			Function: api.ToolCallFunction{
				Name:      "web_search",
				Arguments: makeArgs("query", "test"),
			},
		}
		payload, _ := json.Marshal(api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role:      "assistant",
				ToolCalls: []api.ToolCall{searchCall},
			},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 10, EvalCount: 2},
		})
		c.Writer.WriteHeader(http.StatusOK)
		_, _ = c.Writer.Write(payload)
	})

	request, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{
"model":"test-model:cloud",
"max_tokens":100,
"messages":[{"role":"user","content":"test"}],
"tools":[{"type":"web_search_20250305","name":"web_search"}]
}`))
	request.Header.Set("Content-Type", "application/json")
	recorder := httptest.NewRecorder()
	router.ServeHTTP(recorder, request)

	if recorder.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
	}

	var msg anthropic.MessagesResponse
	err := json.Unmarshal(recorder.Body.Bytes(), &msg)
	if err != nil {
		t.Fatalf("unmarshal error: %v", err)
	}

	// The error response consists of server_tool_use + web_search_tool_result.
	if n := len(msg.Content); n != 2 {
		t.Fatalf("expected 2 content blocks for error, got %d", n)
	}
	if got := msg.Content[0].Type; got != "server_tool_use" {
		t.Errorf("expected 'server_tool_use', got %q", got)
	}
	if got := msg.Content[1].Type; got != "web_search_tool_result" {
		t.Errorf("expected 'web_search_tool_result', got %q", got)
	}
	if !(msg.Usage.InputTokens == 10 && msg.Usage.OutputTokens == 2) {
		t.Fatalf("expected usage input=10 output=2, got %+v", msg.Usage)
	}
}
// TestWebSearchStreamingImmediateTakeover streams a text chunk, a web_search
// tool-call chunk, another text chunk and a final chunk through the middleware.
// It expects a single message_start/message_stop pair, the text that preceded
// the tool call, the follow-up answer, and none of the text that came after the
// tool call.
func TestWebSearchStreamingImmediateTakeover(t *testing.T) {
	gin.SetMode(gin.TestMode)
	enableCloudForTest(t)

	// Server addressed via OLLAMA_HOST; it returns the post-search answer.
	followup := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_ = json.NewEncoder(w).Encode(api.ChatResponse{
			Model:      "test-model",
			Message:    api.Message{Role: "assistant", Content: "After search."},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 20, EvalCount: 10},
		})
	}))
	defer followup.Close()
	t.Setenv("OLLAMA_HOST", followup.URL)

	// Mock search backend returning a single result.
	search := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_ = json.NewEncoder(w).Encode(anthropic.OllamaWebSearchResponse{
			Results: []anthropic.OllamaWebSearchResult{
				{Title: "Result", URL: "https://example.com", Content: "content"},
			},
		})
	}))
	defer search.Close()

	savedEndpoint := anthropic.WebSearchEndpoint
	anthropic.WebSearchEndpoint = search.URL
	defer func() { anthropic.WebSearchEndpoint = savedEndpoint }()

	router := gin.New()
	router.Use(AnthropicMessagesMiddleware())
	router.POST("/v1/messages", func(c *gin.Context) {
		preface := api.ChatResponse{
			Model:   "test-model",
			Message: api.Message{Role: "assistant", Content: "Preface "},
		}
		toolCall := api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role: "assistant",
				ToolCalls: []api.ToolCall{
					{
						ID: "call_ws_stream_1",
						Function: api.ToolCallFunction{
							Name:      "web_search",
							Arguments: makeArgs("query", "latest updates"),
						},
					},
				},
			},
		}
		// Text emitted after the tool call; the assertions below expect it to
		// be absent from the client stream.
		trailing := api.ChatResponse{
			Model:   "test-model",
			Message: api.Message{Role: "assistant", Content: "ignored chunk"},
		}
		final := api.ChatResponse{
			Model:      "test-model",
			Message:    api.Message{Role: "assistant"},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 9, EvalCount: 4},
		}

		c.Writer.WriteHeader(http.StatusOK)
		for _, chunk := range []api.ChatResponse{preface, toolCall, trailing, final} {
			encoded, _ := json.Marshal(chunk)
			_, _ = c.Writer.Write(encoded)
		}
	})

	request, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{
"model":"test-model:cloud",
"max_tokens":100,
"stream":true,
"messages":[{"role":"user","content":"Find updates"}],
"tools":[{"type":"web_search_20250305","name":"web_search"}]
}`))
	request.Header.Set("Content-Type", "application/json")
	recorder := httptest.NewRecorder()
	router.ServeHTTP(recorder, request)

	if recorder.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
	}

	events := parseSSEEvents(t, recorder.Body.String())

	// The client must see one continuous message.
	if countEventsByName(events, "message_start") != 1 {
		t.Fatalf("expected exactly one message_start, got %d", countEventsByName(events, "message_start"))
	}
	if countEventsByName(events, "message_stop") != 1 {
		t.Fatalf("expected exactly one message_stop, got %d", countEventsByName(events, "message_stop"))
	}

	deltas := collectTextDeltas(t, events)
	if hasPreface := containsString(deltas, "Preface "); !hasPreface {
		t.Fatalf("expected passthrough text delta, got %v", deltas)
	}
	if hasAnswer := containsString(deltas, "After search."); !hasAnswer {
		t.Fatalf("expected post-search text delta, got %v", deltas)
	}
	if leaked := containsString(deltas, "ignored chunk"); leaked {
		t.Fatalf("unexpected text from chunks after takeover: %v", deltas)
	}
}
// TestWebSearchStreamingUsageUsesObservedChunkMetrics streams chunks whose
// metrics appear on the text chunk before the web_search tool call, and checks
// that the first message_delta reports those metrics (12 in / 4 out) added to
// the follow-up response's metrics (20 in / 7 out).
func TestWebSearchStreamingUsageUsesObservedChunkMetrics(t *testing.T) {
	gin.SetMode(gin.TestMode)
	enableCloudForTest(t)

	// Server addressed via OLLAMA_HOST; it returns the post-search answer.
	followup := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_ = json.NewEncoder(w).Encode(api.ChatResponse{
			Model:      "test-model",
			Message:    api.Message{Role: "assistant", Content: "After search."},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 20, EvalCount: 7},
		})
	}))
	defer followup.Close()
	t.Setenv("OLLAMA_HOST", followup.URL)

	// Mock search backend returning a single result.
	search := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_ = json.NewEncoder(w).Encode(anthropic.OllamaWebSearchResponse{
			Results: []anthropic.OllamaWebSearchResult{
				{Title: "Result", URL: "https://example.com", Content: "content"},
			},
		})
	}))
	defer search.Close()

	savedEndpoint := anthropic.WebSearchEndpoint
	anthropic.WebSearchEndpoint = search.URL
	defer func() { anthropic.WebSearchEndpoint = savedEndpoint }()

	router := gin.New()
	router.Use(AnthropicMessagesMiddleware())
	router.POST("/v1/messages", func(c *gin.Context) {
		preface := api.ChatResponse{
			Model:   "test-model",
			Message: api.Message{Role: "assistant", Content: "Preface "},
			Metrics: api.Metrics{PromptEvalCount: 12, EvalCount: 4},
		}
		// The tool-call chunk itself carries zero-valued metrics.
		toolCall := api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role: "assistant",
				ToolCalls: []api.ToolCall{
					{
						ID: "call_ws_stream_usage",
						Function: api.ToolCallFunction{
							Name:      "web_search",
							Arguments: makeArgs("query", "latest updates"),
						},
					},
				},
			},
		}
		final := api.ChatResponse{
			Model:      "test-model",
			Message:    api.Message{Role: "assistant"},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 12, EvalCount: 4},
		}

		c.Writer.WriteHeader(http.StatusOK)
		for _, chunk := range []api.ChatResponse{preface, toolCall, final} {
			encoded, _ := json.Marshal(chunk)
			_, _ = c.Writer.Write(encoded)
		}
	})

	request, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{
"model":"test-model:cloud",
"max_tokens":100,
"stream":true,
"messages":[{"role":"user","content":"Find updates"}],
"tools":[{"type":"web_search_20250305","name":"web_search"}]
}`))
	request.Header.Set("Content-Type", "application/json")
	recorder := httptest.NewRecorder()
	router.ServeHTTP(recorder, request)

	if recorder.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", recorder.Code, recorder.Body.String())
	}

	events := parseSSEEvents(t, recorder.Body.String())

	// Locate the first message_delta event; only that one is inspected.
	deltaIdx := -1
	for i := range events {
		if events[i].event == "message_delta" {
			deltaIdx = i
			break
		}
	}
	if deltaIdx < 0 {
		t.Fatal("expected message_delta event")
	}
	var delta anthropic.MessageDeltaEvent
	err := json.Unmarshal([]byte(events[deltaIdx].data), &delta)
	if err != nil {
		t.Fatalf("failed to unmarshal message_delta: %v", err)
	}

	if got := delta.Usage.InputTokens; got != 32 {
		t.Fatalf("expected aggregated input tokens 32 (12 passthrough + 20 followup), got %d", got)
	}
	if got := delta.Usage.OutputTokens; got != 11 {
		t.Fatalf("expected aggregated output tokens 11 (4 passthrough + 7 followup), got %d", got)
	}
}
// TestWebSearchMixedToolCallsPreferWebSearch covers a non-streaming upstream
// reply that carries both a get_weather call and a web_search call. The
// response must open with server_tool_use followed by web_search_tool_result
// and must not expose get_weather as a tool_use block.
func TestWebSearchMixedToolCallsPreferWebSearch(t *testing.T) {
	gin.SetMode(gin.TestMode)
	enableCloudForTest(t)

	// Stub chat server exported through OLLAMA_HOST; it always answers with a
	// plain text message.
	followup := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_ = json.NewEncoder(w).Encode(api.ChatResponse{
			Model:      "test-model",
			Message:    api.Message{Role: "assistant", Content: "Search answer."},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 11, EvalCount: 6},
		})
	}))
	t.Cleanup(followup.Close)
	t.Setenv("OLLAMA_HOST", followup.URL)

	// Stub search backend that returns a single canned result.
	search := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_ = json.NewEncoder(w).Encode(anthropic.OllamaWebSearchResponse{
			Results: []anthropic.OllamaWebSearchResult{
				{Title: "Result", URL: "https://example.com", Content: "content"},
			},
		})
	}))
	t.Cleanup(search.Close)

	// Redirect the package-level endpoint to the stub and restore it afterwards.
	savedEndpoint := anthropic.WebSearchEndpoint
	t.Cleanup(func() { anthropic.WebSearchEndpoint = savedEndpoint })
	anthropic.WebSearchEndpoint = search.URL

	engine := gin.New()
	engine.Use(AnthropicMessagesMiddleware())
	engine.POST("/v1/messages", func(c *gin.Context) {
		weatherCall := api.ToolCall{
			ID: "call_other",
			Function: api.ToolCallFunction{
				Name:      "get_weather",
				Arguments: makeArgs("location", "SF"),
			},
		}
		searchCall := api.ToolCall{
			ID: "call_ws_mixed",
			Function: api.ToolCallFunction{
				Name:      "web_search",
				Arguments: makeArgs("query", "latest weather"),
			},
		}
		encoded, _ := json.Marshal(api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role:      "assistant",
				ToolCalls: []api.ToolCall{weatherCall, searchCall},
			},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 10, EvalCount: 2},
		})
		c.Writer.WriteHeader(http.StatusOK)
		_, _ = c.Writer.Write(encoded)
	})

	payload := `{
"model":"test-model:cloud",
"max_tokens":100,
"messages":[{"role":"user","content":"Weather?"}],
"tools":[
{"type":"web_search_20250305","name":"web_search"},
{"type":"custom","name":"get_weather","input_schema":{"type":"object"}}
]
}`
	httpReq, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(payload))
	httpReq.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httpReq)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
	}

	var got anthropic.MessagesResponse
	if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil {
		t.Fatalf("unmarshal error: %v", err)
	}
	if n := len(got.Content); n < 3 {
		t.Fatalf("expected at least 3 blocks, got %d", n)
	}
	if first := got.Content[0].Type; first != "server_tool_use" {
		t.Fatalf("expected server_tool_use first, got %q", first)
	}
	if second := got.Content[1].Type; second != "web_search_tool_result" {
		t.Fatalf("expected web_search_tool_result second, got %q", second)
	}

	// The client tool call must have been dropped in favour of web_search.
	for i := range got.Content {
		if got.Content[i].Type != "tool_use" || got.Content[i].Name != "get_weather" {
			continue
		}
		t.Fatalf("did not expect get_weather tool_use in mixed web_search-preferred path: %+v", got.Content)
	}
}
// TestWebSearchFollowupClientToolStopReasonToolUse covers the non-streaming
// case where the upstream reply requests web_search and the follow-up reply
// then requests the client tool get_weather. The final response must report
// stop_reason "tool_use", end with a get_weather tool_use block, and carry
// usage of 40 input / 10 output tokens (15+25 and 3+7 from the two replies).
func TestWebSearchFollowupClientToolStopReasonToolUse(t *testing.T) {
	gin.SetMode(gin.TestMode)
	enableCloudForTest(t)

	// Stub chat server exported through OLLAMA_HOST; it replies with a
	// get_weather tool call.
	followup := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		weatherCall := api.ToolCall{
			ID: "call_weather_final",
			Function: api.ToolCallFunction{
				Name:      "get_weather",
				Arguments: makeArgs("location", "New York"),
			},
		}
		_ = json.NewEncoder(w).Encode(api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role:      "assistant",
				ToolCalls: []api.ToolCall{weatherCall},
			},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 25, EvalCount: 7},
		})
	}))
	t.Cleanup(followup.Close)
	t.Setenv("OLLAMA_HOST", followup.URL)

	// Stub search backend that returns a single canned result.
	search := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_ = json.NewEncoder(w).Encode(anthropic.OllamaWebSearchResponse{
			Results: []anthropic.OllamaWebSearchResult{
				{Title: "Result", URL: "https://example.com", Content: "content"},
			},
		})
	}))
	t.Cleanup(search.Close)

	// Redirect the package-level endpoint to the stub and restore it afterwards.
	savedEndpoint := anthropic.WebSearchEndpoint
	t.Cleanup(func() { anthropic.WebSearchEndpoint = savedEndpoint })
	anthropic.WebSearchEndpoint = search.URL

	engine := gin.New()
	engine.Use(AnthropicMessagesMiddleware())
	engine.POST("/v1/messages", func(c *gin.Context) {
		searchCall := api.ToolCall{
			ID: "call_ws_tool_use",
			Function: api.ToolCallFunction{
				Name:      "web_search",
				Arguments: makeArgs("query", "forecast"),
			},
		}
		encoded, _ := json.Marshal(api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role:      "assistant",
				ToolCalls: []api.ToolCall{searchCall},
			},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 15, EvalCount: 3},
		})
		c.Writer.WriteHeader(http.StatusOK)
		_, _ = c.Writer.Write(encoded)
	})

	payload := `{
"model":"test-model:cloud",
"max_tokens":100,
"messages":[{"role":"user","content":"Do I need an umbrella?"}],
"tools":[
{"type":"web_search_20250305","name":"web_search"},
{"type":"custom","name":"get_weather","input_schema":{"type":"object"}}
]
}`
	httpReq, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(payload))
	httpReq.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httpReq)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
	}

	var got anthropic.MessagesResponse
	if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil {
		t.Fatalf("unmarshal error: %v", err)
	}
	if got.StopReason != "tool_use" {
		t.Fatalf("expected stop_reason tool_use, got %q", got.StopReason)
	}

	blockCount := len(got.Content)
	if blockCount < 3 {
		t.Fatalf("expected server blocks + tool_use, got %d blocks", blockCount)
	}

	final := got.Content[blockCount-1]
	if final.Type != "tool_use" {
		t.Fatalf("expected final block tool_use, got %q", final.Type)
	}
	if final.Name != "get_weather" {
		t.Fatalf("expected final tool name get_weather, got %q", final.Name)
	}

	if !(got.Usage.InputTokens == 40 && got.Usage.OutputTokens == 10) {
		t.Fatalf("unexpected aggregated usage: %+v", got.Usage)
	}
}
// TestWebSearchMultiIterationLoop exercises two consecutive web_search rounds
// on the non-streaming path. The upstream reply requests web_search, the first
// follow-up reply requests web_search again, and the second follow-up reply is
// the final text answer. The test expects exactly two follow-up calls, each
// still carrying the web_search tool definition, two server_tool_use and two
// web_search_tool_result blocks, and usage of 60 input / 6 output tokens
// (10+20+30 and 1+2+3 across the three replies).
func TestWebSearchMultiIterationLoop(t *testing.T) {
	gin.SetMode(gin.TestMode)
	enableCloudForTest(t)

	// Written by the follow-up handler and inspected by the test body once
	// ServeHTTP has returned.
	followupCall := 0
	followupDecodeErr := false
	missingWebSearchTool := false

	followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var followupReq api.ChatRequest
		if err := json.NewDecoder(r.Body).Decode(&followupReq); err != nil {
			followupDecodeErr = true
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}

		hasWebSearchTool := false
		for _, tool := range followupReq.Tools {
			if tool.Function.Name == "web_search" {
				hasWebSearchTool = true
				break
			}
		}
		if !hasWebSearchTool {
			missingWebSearchTool = true
		}

		followupCall++
		switch followupCall {
		case 1:
			// First follow-up: ask for another search.
			resp := api.ChatResponse{
				Model: "test-model",
				Message: api.Message{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							ID: "call_ws_2",
							Function: api.ToolCallFunction{
								Name:      "web_search",
								Arguments: makeArgs("query", "loop query 2"),
							},
						},
					},
				},
				Done:       true,
				DoneReason: "stop",
				Metrics:    api.Metrics{PromptEvalCount: 20, EvalCount: 2},
			}
			_ = json.NewEncoder(w).Encode(resp)
		case 2:
			// Second follow-up: final text answer.
			resp := api.ChatResponse{
				Model:      "test-model",
				Message:    api.Message{Role: "assistant", Content: "Final answer after 2 searches."},
				Done:       true,
				DoneReason: "stop",
				Metrics:    api.Metrics{PromptEvalCount: 30, EvalCount: 3},
			}
			_ = json.NewEncoder(w).Encode(resp)
		default:
			// This handler runs on a server goroutine, where t.Fatalf is not
			// permitted. Record the failure with Errorf and answer with an
			// error status; the followupCall check below also fails the test.
			t.Errorf("unexpected extra followup call: %d", followupCall)
			http.Error(w, "unexpected followup call", http.StatusInternalServerError)
		}
	}))
	defer followupServer.Close()
	t.Setenv("OLLAMA_HOST", followupServer.URL)

	// Stub search backend that returns a single canned result.
	searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := anthropic.OllamaWebSearchResponse{
			Results: []anthropic.OllamaWebSearchResult{
				{Title: "Result", URL: "https://example.com", Content: "content"},
			},
		}
		_ = json.NewEncoder(w).Encode(resp)
	}))
	defer searchServer.Close()

	// Redirect the package-level endpoint to the stub and restore it on exit.
	originalEndpoint := anthropic.WebSearchEndpoint
	anthropic.WebSearchEndpoint = searchServer.URL
	defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()

	router := gin.New()
	router.Use(AnthropicMessagesMiddleware())
	router.POST("/v1/messages", func(c *gin.Context) {
		resp := api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role: "assistant",
				ToolCalls: []api.ToolCall{
					{
						ID: "call_ws_1",
						Function: api.ToolCallFunction{
							Name:      "web_search",
							Arguments: makeArgs("query", "loop query 1"),
						},
					},
				},
			},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 10, EvalCount: 1},
		}
		data, _ := json.Marshal(resp)
		c.Writer.WriteHeader(http.StatusOK)
		_, _ = c.Writer.Write(data)
	})

	body := `{
"model":"test-model:cloud",
"max_tokens":100,
"messages":[{"role":"user","content":"do multiple searches"}],
"tools":[{"type":"web_search_20250305","name":"web_search"}]
}`
	req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	resp := httptest.NewRecorder()
	router.ServeHTTP(resp, req)

	if resp.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
	}
	if followupCall != 2 {
		t.Fatalf("expected 2 followup calls, got %d", followupCall)
	}
	if followupDecodeErr {
		t.Fatal("failed to decode followup request body")
	}
	if missingWebSearchTool {
		t.Fatal("expected followup requests to retain web_search tool definition")
	}

	var result anthropic.MessagesResponse
	if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
		t.Fatalf("unmarshal error: %v", err)
	}

	// Count one server_tool_use / web_search_tool_result pair per search round.
	serverToolUses := 0
	webResults := 0
	for _, block := range result.Content {
		switch block.Type {
		case "server_tool_use":
			serverToolUses++
		case "web_search_tool_result":
			webResults++
		}
	}
	if serverToolUses != 2 || webResults != 2 {
		t.Fatalf("expected two search iterations, got server_tool_use=%d web_search_tool_result=%d", serverToolUses, webResults)
	}
	if result.Usage.InputTokens != 60 || result.Usage.OutputTokens != 6 {
		t.Fatalf("unexpected aggregated usage: %+v", result.Usage)
	}
}
// TestWebSearchLoopMaxLimit checks the non-streaming path when every follow-up
// reply requests yet another web_search. The test expects exactly three
// follow-up calls, a final web_search_tool_result block whose content is an
// error with code "max_uses_exceeded", and stop_reason "end_turn".
func TestWebSearchLoopMaxLimit(t *testing.T) {
	gin.SetMode(gin.TestMode)
	enableCloudForTest(t)

	// Incremented by the follow-up handler and inspected after ServeHTTP returns.
	followupCall := 0
	followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		followupCall++
		// Always ask for another search so the loop never ends on its own.
		resp := api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role: "assistant",
				ToolCalls: []api.ToolCall{
					{
						ID: "call_ws_loop_limit",
						Function: api.ToolCallFunction{
							Name:      "web_search",
							Arguments: makeArgs("query", "loop query next"),
						},
					},
				},
			},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 7, EvalCount: 2},
		}
		_ = json.NewEncoder(w).Encode(resp)
	}))
	defer followupServer.Close()
	t.Setenv("OLLAMA_HOST", followupServer.URL)

	// Stub search backend that returns a single canned result.
	searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := anthropic.OllamaWebSearchResponse{
			Results: []anthropic.OllamaWebSearchResult{
				{Title: "Result", URL: "https://example.com", Content: "content"},
			},
		}
		_ = json.NewEncoder(w).Encode(resp)
	}))
	defer searchServer.Close()

	// Redirect the package-level endpoint to the stub and restore it on exit.
	originalEndpoint := anthropic.WebSearchEndpoint
	anthropic.WebSearchEndpoint = searchServer.URL
	defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()

	router := gin.New()
	router.Use(AnthropicMessagesMiddleware())
	router.POST("/v1/messages", func(c *gin.Context) {
		resp := api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role: "assistant",
				ToolCalls: []api.ToolCall{
					{
						ID: "call_ws_initial",
						Function: api.ToolCallFunction{
							Name:      "web_search",
							Arguments: makeArgs("query", "loop query 1"),
						},
					},
				},
			},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 5, EvalCount: 1},
		}
		data, _ := json.Marshal(resp)
		c.Writer.WriteHeader(http.StatusOK)
		_, _ = c.Writer.Write(data)
	})

	body := `{
"model":"test-model:cloud",
"max_tokens":100,
"messages":[{"role":"user","content":"keep searching"}],
"tools":[{"type":"web_search_20250305","name":"web_search"}]
}`
	req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	resp := httptest.NewRecorder()
	router.ServeHTTP(resp, req)

	if resp.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
	}
	if followupCall != 3 {
		t.Fatalf("expected 3 followup calls before max loop error, got %d", followupCall)
	}

	var result anthropic.MessagesResponse
	if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
		t.Fatalf("unmarshal error: %v", err)
	}
	// Guard the index below so an empty response fails cleanly instead of
	// panicking with an out-of-range error.
	if len(result.Content) == 0 {
		t.Fatal("expected at least one content block, got none")
	}

	last := result.Content[len(result.Content)-1]
	if last.Type != "web_search_tool_result" {
		t.Fatalf("expected last block web_search_tool_result, got %q", last.Type)
	}

	// Round-trip the block content through JSON to read it as an error payload.
	contentJSON, _ := json.Marshal(last.Content)
	var errContent anthropic.WebSearchToolResultError
	if err := json.Unmarshal(contentJSON, &errContent); err != nil {
		t.Fatalf("failed to parse web search error content: %v", err)
	}
	if errContent.ErrorCode != "max_uses_exceeded" {
		t.Fatalf("expected max_uses_exceeded error, got %q", errContent.ErrorCode)
	}
	if result.StopReason != "end_turn" {
		t.Fatalf("expected end_turn, got %q", result.StopReason)
	}
}
// TestWebSearchStreamingFinalStopReasonToolUse covers the streaming path where
// the upstream stream requests web_search and the follow-up reply requests the
// client tool get_weather. The SSE output must contain exactly one
// message_start, a message_delta whose stop_reason is "tool_use", and a
// content_block_start for a get_weather tool_use block.
func TestWebSearchStreamingFinalStopReasonToolUse(t *testing.T) {
	gin.SetMode(gin.TestMode)
	enableCloudForTest(t)

	// Stub chat server exported through OLLAMA_HOST; it replies with a
	// get_weather tool call.
	followup := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		weatherCall := api.ToolCall{
			ID: "call_weather_stream",
			Function: api.ToolCallFunction{
				Name:      "get_weather",
				Arguments: makeArgs("location", "Seattle"),
			},
		}
		_ = json.NewEncoder(w).Encode(api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role:      "assistant",
				ToolCalls: []api.ToolCall{weatherCall},
			},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 14, EvalCount: 5},
		})
	}))
	t.Cleanup(followup.Close)
	t.Setenv("OLLAMA_HOST", followup.URL)

	// Stub search backend that returns a single canned result.
	search := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_ = json.NewEncoder(w).Encode(anthropic.OllamaWebSearchResponse{
			Results: []anthropic.OllamaWebSearchResult{
				{Title: "Result", URL: "https://example.com", Content: "content"},
			},
		})
	}))
	t.Cleanup(search.Close)

	// Redirect the package-level endpoint to the stub and restore it afterwards.
	savedEndpoint := anthropic.WebSearchEndpoint
	t.Cleanup(func() { anthropic.WebSearchEndpoint = savedEndpoint })
	anthropic.WebSearchEndpoint = search.URL

	engine := gin.New()
	engine.Use(AnthropicMessagesMiddleware())
	engine.POST("/v1/messages", func(c *gin.Context) {
		// Three upstream chunks: leading text, the web_search call, then the
		// terminating chunk with metrics.
		textChunk := api.ChatResponse{
			Model:   "test-model",
			Message: api.Message{Role: "assistant", Content: "Let me check. "},
			Done:    false,
		}
		searchChunk := api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role: "assistant",
				ToolCalls: []api.ToolCall{
					{
						ID: "call_ws_stream_tool_use",
						Function: api.ToolCallFunction{
							Name:      "web_search",
							Arguments: makeArgs("query", "weather seattle"),
						},
					},
				},
			},
			Done: false,
		}
		doneChunk := api.ChatResponse{
			Model:      "test-model",
			Message:    api.Message{Role: "assistant"},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 10, EvalCount: 3},
		}

		stream := []api.ChatResponse{textChunk, searchChunk, doneChunk}
		c.Writer.WriteHeader(http.StatusOK)
		for i := range stream {
			encoded, _ := json.Marshal(stream[i])
			_, _ = c.Writer.Write(encoded)
		}
	})

	payload := `{
"model":"test-model:cloud",
"max_tokens":100,
"stream":true,
"messages":[{"role":"user","content":"Should I take a jacket?"}],
"tools":[
{"type":"web_search_20250305","name":"web_search"},
{"type":"custom","name":"get_weather","input_schema":{"type":"object"}}
]
}`
	httpReq, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(payload))
	httpReq.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httpReq)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
	}

	events := parseSSEEvents(t, rec.Body.String())
	if countEventsByName(events, "message_start") != 1 {
		t.Fatalf("expected exactly one message_start, got %d", countEventsByName(events, "message_start"))
	}

	var (
		delta       anthropic.MessageDeltaEvent
		sawDelta    bool
		sawToolCall bool
	)
	for _, ev := range events {
		switch ev.event {
		case "message_delta":
			sawDelta = true
			if err := json.Unmarshal([]byte(ev.data), &delta); err != nil {
				t.Fatalf("failed to unmarshal message_delta: %v", err)
			}
		case "content_block_start":
			var start anthropic.ContentBlockStartEvent
			if err := json.Unmarshal([]byte(ev.data), &start); err != nil {
				t.Fatalf("failed to unmarshal content_block_start: %v", err)
			}
			if start.ContentBlock.Type == "tool_use" && start.ContentBlock.Name == "get_weather" {
				sawToolCall = true
			}
		}
	}

	if !sawDelta {
		t.Fatal("expected message_delta event")
	}
	if reason := delta.Delta.StopReason; reason != "tool_use" {
		t.Fatalf("expected stop_reason tool_use, got %q", reason)
	}
	if !sawToolCall {
		t.Fatal("expected tool_use content block for get_weather")
	}
}
// TestWebSearchFollowupNon200ReturnsApiError covers the non-streaming path
// when the follow-up server answers with HTTP 500. The client must still get
// HTTP 200 with exactly two content blocks, the second carrying an error with
// code "api_error", and usage equal to the initial reply's metrics (9 input,
// 1 output).
func TestWebSearchFollowupNon200ReturnsApiError(t *testing.T) {
	gin.SetMode(gin.TestMode)
	enableCloudForTest(t)

	// Stub chat server exported through OLLAMA_HOST; it always fails.
	followup := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, "boom", http.StatusInternalServerError)
	}))
	t.Cleanup(followup.Close)
	t.Setenv("OLLAMA_HOST", followup.URL)

	// Stub search backend that returns a single canned result.
	search := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_ = json.NewEncoder(w).Encode(anthropic.OllamaWebSearchResponse{
			Results: []anthropic.OllamaWebSearchResult{
				{Title: "Result", URL: "https://example.com", Content: "content"},
			},
		})
	}))
	t.Cleanup(search.Close)

	// Redirect the package-level endpoint to the stub and restore it afterwards.
	savedEndpoint := anthropic.WebSearchEndpoint
	t.Cleanup(func() { anthropic.WebSearchEndpoint = savedEndpoint })
	anthropic.WebSearchEndpoint = search.URL

	engine := gin.New()
	engine.Use(AnthropicMessagesMiddleware())
	engine.POST("/v1/messages", func(c *gin.Context) {
		searchCall := api.ToolCall{
			ID: "call_ws_non200",
			Function: api.ToolCallFunction{
				Name:      "web_search",
				Arguments: makeArgs("query", "test"),
			},
		}
		encoded, _ := json.Marshal(api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role:      "assistant",
				ToolCalls: []api.ToolCall{searchCall},
			},
			Done:       true,
			DoneReason: "stop",
			Metrics:    api.Metrics{PromptEvalCount: 9, EvalCount: 1},
		})
		c.Writer.WriteHeader(http.StatusOK)
		_, _ = c.Writer.Write(encoded)
	})

	payload := `{
"model":"test-model:cloud",
"max_tokens":100,
"messages":[{"role":"user","content":"test"}],
"tools":[{"type":"web_search_20250305","name":"web_search"}]
}`
	httpReq, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(payload))
	httpReq.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httpReq)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
	}

	var got anthropic.MessagesResponse
	if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil {
		t.Fatalf("unmarshal error: %v", err)
	}
	if n := len(got.Content); n != 2 {
		t.Fatalf("expected 2 blocks in error response, got %d", n)
	}

	// Round-trip the second block's content through JSON to read it as an
	// error payload.
	rawContent, _ := json.Marshal(got.Content[1].Content)
	var searchErr anthropic.WebSearchToolResultError
	if err := json.Unmarshal(rawContent, &searchErr); err != nil {
		t.Fatalf("failed to parse error content: %v", err)
	}
	if searchErr.ErrorCode != "api_error" {
		t.Fatalf("expected api_error, got %q", searchErr.ErrorCode)
	}

	if !(got.Usage.InputTokens == 9 && got.Usage.OutputTokens == 1) {
		t.Fatalf("expected usage input=9 output=1, got %+v", got.Usage)
	}
}
// countEventsByName returns how many entries in events have an event field
// equal to eventName.
func countEventsByName(events []sseEvent, eventName string) int {
	var matches int
	for i := range events {
		if events[i].event != eventName {
			continue
		}
		matches++
	}
	return matches
}
// collectTextDeltas decodes every content_block_delta event in events and
// returns the text of those whose delta type is "text_delta", in order. It
// fails the test immediately if an event payload cannot be unmarshalled, and
// returns a nil slice when there are no text deltas.
func collectTextDeltas(t *testing.T, events []sseEvent) []string {
	t.Helper()

	var texts []string
	for i := range events {
		if events[i].event == "content_block_delta" {
			var parsed anthropic.ContentBlockDeltaEvent
			err := json.Unmarshal([]byte(events[i].data), &parsed)
			if err != nil {
				t.Fatalf("failed to unmarshal content_block_delta: %v", err)
			}
			if parsed.Delta.Type != "text_delta" {
				continue
			}
			texts = append(texts, parsed.Delta.Text)
		}
	}
	return texts
}
// containsString reports whether target is an element of values.
func containsString(values []string, target string) bool {
	for i := 0; i < len(values); i++ {
		if values[i] != target {
			continue
		}
		return true
	}
	return false
}