mirror of
https://github.com/ollama/ollama.git
synced 2026-03-08 23:04:13 -05:00
2980 lines
89 KiB
Go
2980 lines
89 KiB
Go
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/google/go-cmp/cmp/cmpopts"
|
|
|
|
"github.com/ollama/ollama/anthropic"
|
|
"github.com/ollama/ollama/api"
|
|
)
|
|
|
|
func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
_ = json.Unmarshal(bodyBytes, capturedRequest)
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// testProps creates ToolPropertiesMap from a map (convenience function for tests)
|
|
func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
|
props := api.NewToolPropertiesMap()
|
|
for k, v := range m {
|
|
props.Set(k, v)
|
|
}
|
|
return props
|
|
}
|
|
|
|
func TestAnthropicMessagesMiddleware(t *testing.T) {
|
|
type testCase struct {
|
|
name string
|
|
body string
|
|
req api.ChatRequest
|
|
err anthropic.ErrorResponse
|
|
}
|
|
|
|
var capturedRequest *api.ChatRequest
|
|
stream := true
|
|
|
|
testCases := []testCase{
|
|
{
|
|
name: "basic message",
|
|
body: `{
|
|
"model": "test-model",
|
|
"max_tokens": 1024,
|
|
"messages": [
|
|
{"role": "user", "content": "Hello"}
|
|
]
|
|
}`,
|
|
req: api.ChatRequest{
|
|
Model: "test-model",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "Hello"},
|
|
},
|
|
Options: map[string]any{"num_predict": 1024},
|
|
Stream: &False,
|
|
},
|
|
},
|
|
{
|
|
name: "with system prompt",
|
|
body: `{
|
|
"model": "test-model",
|
|
"max_tokens": 1024,
|
|
"system": "You are helpful.",
|
|
"messages": [
|
|
{"role": "user", "content": "Hello"}
|
|
]
|
|
}`,
|
|
req: api.ChatRequest{
|
|
Model: "test-model",
|
|
Messages: []api.Message{
|
|
{Role: "system", Content: "You are helpful."},
|
|
{Role: "user", Content: "Hello"},
|
|
},
|
|
Options: map[string]any{"num_predict": 1024},
|
|
Stream: &False,
|
|
},
|
|
},
|
|
{
|
|
name: "with options",
|
|
body: `{
|
|
"model": "test-model",
|
|
"max_tokens": 2048,
|
|
"temperature": 0.7,
|
|
"top_p": 0.9,
|
|
"top_k": 40,
|
|
"stop_sequences": ["\n", "END"],
|
|
"messages": [
|
|
{"role": "user", "content": "Hello"}
|
|
]
|
|
}`,
|
|
req: api.ChatRequest{
|
|
Model: "test-model",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "Hello"},
|
|
},
|
|
Options: map[string]any{
|
|
"num_predict": 2048,
|
|
"temperature": 0.7,
|
|
"top_p": 0.9,
|
|
"top_k": 40,
|
|
"stop": []string{"\n", "END"},
|
|
},
|
|
Stream: &False,
|
|
},
|
|
},
|
|
{
|
|
name: "streaming",
|
|
body: `{
|
|
"model": "test-model",
|
|
"max_tokens": 1024,
|
|
"stream": true,
|
|
"messages": [
|
|
{"role": "user", "content": "Hello"}
|
|
]
|
|
}`,
|
|
req: api.ChatRequest{
|
|
Model: "test-model",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "Hello"},
|
|
},
|
|
Options: map[string]any{"num_predict": 1024},
|
|
Stream: &stream,
|
|
},
|
|
},
|
|
{
|
|
name: "with tools",
|
|
body: `{
|
|
"model": "test-model",
|
|
"max_tokens": 1024,
|
|
"messages": [
|
|
{"role": "user", "content": "What's the weather?"}
|
|
],
|
|
"tools": [{
|
|
"name": "get_weather",
|
|
"description": "Get current weather",
|
|
"input_schema": {
|
|
"type": "object",
|
|
"properties": {
|
|
"location": {"type": "string"}
|
|
},
|
|
"required": ["location"]
|
|
}
|
|
}]
|
|
}`,
|
|
req: api.ChatRequest{
|
|
Model: "test-model",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "What's the weather?"},
|
|
},
|
|
Tools: []api.Tool{
|
|
{
|
|
Type: "function",
|
|
Function: api.ToolFunction{
|
|
Name: "get_weather",
|
|
Description: "Get current weather",
|
|
Parameters: api.ToolFunctionParameters{
|
|
Type: "object",
|
|
Required: []string{"location"},
|
|
Properties: testProps(map[string]api.ToolProperty{
|
|
"location": {Type: api.PropertyType{"string"}},
|
|
}),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Options: map[string]any{"num_predict": 1024},
|
|
Stream: &False,
|
|
},
|
|
},
|
|
{
|
|
name: "with tool result",
|
|
body: `{
|
|
"model": "test-model",
|
|
"max_tokens": 1024,
|
|
"messages": [
|
|
{"role": "user", "content": "What's the weather?"},
|
|
{"role": "assistant", "content": [
|
|
{"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}}
|
|
]},
|
|
{"role": "user", "content": [
|
|
{"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"}
|
|
]}
|
|
]
|
|
}`,
|
|
req: api.ChatRequest{
|
|
Model: "test-model",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "What's the weather?"},
|
|
{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_123",
|
|
Function: api.ToolCallFunction{
|
|
Name: "get_weather",
|
|
Arguments: testArgs(map[string]any{"location": "Paris"}),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"},
|
|
},
|
|
Options: map[string]any{"num_predict": 1024},
|
|
Stream: &False,
|
|
},
|
|
},
|
|
{
|
|
name: "with thinking enabled",
|
|
body: `{
|
|
"model": "test-model",
|
|
"max_tokens": 1024,
|
|
"thinking": {"type": "enabled", "budget_tokens": 1000},
|
|
"messages": [
|
|
{"role": "user", "content": "Hello"}
|
|
]
|
|
}`,
|
|
req: api.ChatRequest{
|
|
Model: "test-model",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "Hello"},
|
|
},
|
|
Options: map[string]any{"num_predict": 1024},
|
|
Stream: &False,
|
|
Think: &api.ThinkValue{Value: true},
|
|
},
|
|
},
|
|
{
|
|
name: "missing model error",
|
|
body: `{
|
|
"max_tokens": 1024,
|
|
"messages": [
|
|
{"role": "user", "content": "Hello"}
|
|
]
|
|
}`,
|
|
err: anthropic.ErrorResponse{
|
|
Type: "error",
|
|
Error: anthropic.Error{
|
|
Type: "invalid_request_error",
|
|
Message: "model is required",
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "missing max_tokens error",
|
|
body: `{
|
|
"model": "test-model",
|
|
"messages": [
|
|
{"role": "user", "content": "Hello"}
|
|
]
|
|
}`,
|
|
err: anthropic.ErrorResponse{
|
|
Type: "error",
|
|
Error: anthropic.Error{
|
|
Type: "invalid_request_error",
|
|
Message: "max_tokens is required and must be positive",
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "missing messages error",
|
|
body: `{
|
|
"model": "test-model",
|
|
"max_tokens": 1024
|
|
}`,
|
|
err: anthropic.ErrorResponse{
|
|
Type: "error",
|
|
Error: anthropic.Error{
|
|
Type: "invalid_request_error",
|
|
Message: "messages is required",
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "tool_use missing id error",
|
|
body: `{
|
|
"model": "test-model",
|
|
"max_tokens": 1024,
|
|
"messages": [
|
|
{"role": "assistant", "content": [
|
|
{"type": "tool_use", "name": "test"}
|
|
]}
|
|
]
|
|
}`,
|
|
err: anthropic.ErrorResponse{
|
|
Type: "error",
|
|
Error: anthropic.Error{
|
|
Type: "invalid_request_error",
|
|
Message: "tool_use block missing required 'id' field",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
endpoint := func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
}
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest))
|
|
router.Handle(http.MethodPost, "/v1/messages", endpoint)
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
defer func() { capturedRequest = nil }()
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if tc.err.Type != "" {
|
|
// Expect error
|
|
if resp.Code == http.StatusOK {
|
|
t.Fatalf("expected error response, got 200 OK")
|
|
}
|
|
var errResp anthropic.ErrorResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
|
t.Fatalf("failed to unmarshal error: %v", err)
|
|
}
|
|
if errResp.Type != tc.err.Type {
|
|
t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type)
|
|
}
|
|
if errResp.Error.Type != tc.err.Error.Type {
|
|
t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type)
|
|
}
|
|
if errResp.Error.Message != tc.err.Error.Message {
|
|
t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message)
|
|
}
|
|
return
|
|
}
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
if capturedRequest == nil {
|
|
t.Fatal("request was not captured")
|
|
}
|
|
|
|
// Compare relevant fields
|
|
if capturedRequest.Model != tc.req.Model {
|
|
t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
|
|
}
|
|
|
|
if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
|
|
cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
|
|
t.Errorf("messages mismatch (-want +got):\n%s", diff)
|
|
}
|
|
|
|
if tc.req.Stream != nil && capturedRequest.Stream != nil {
|
|
if *tc.req.Stream != *capturedRequest.Stream {
|
|
t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream)
|
|
}
|
|
}
|
|
|
|
if tc.req.Think != nil {
|
|
if capturedRequest.Think == nil {
|
|
t.Error("expected Think to be set")
|
|
} else if capturedRequest.Think.Value != tc.req.Think.Value {
|
|
t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAnthropicMessagesMiddleware_Headers(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
t.Run("streaming sets correct headers", func(t *testing.T) {
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
// Check headers were set
|
|
if c.Writer.Header().Get("Content-Type") != "text/event-stream" {
|
|
t.Errorf("expected Content-Type text/event-stream, got %q", c.Writer.Header().Get("Content-Type"))
|
|
}
|
|
if c.Writer.Header().Get("Cache-Control") != "no-cache" {
|
|
t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control"))
|
|
}
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
})
|
|
}
|
|
|
|
func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusBadRequest {
|
|
t.Errorf("expected status 400, got %d", resp.Code)
|
|
}
|
|
|
|
var errResp anthropic.ErrorResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
|
t.Fatalf("failed to unmarshal error: %v", err)
|
|
}
|
|
|
|
if errResp.Type != "error" {
|
|
t.Errorf("expected type 'error', got %q", errResp.Type)
|
|
}
|
|
if errResp.Error.Type != "invalid_request_error" {
|
|
t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
|
|
}
|
|
}
|
|
|
|
func TestAnthropicWriter_NonStreaming(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
// Simulate Ollama response
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
Content: "Hello there!",
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{
|
|
PromptEvalCount: 10,
|
|
EvalCount: 5,
|
|
},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", resp.Code)
|
|
}
|
|
|
|
var result anthropic.MessagesResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
|
t.Fatalf("failed to unmarshal response: %v", err)
|
|
}
|
|
|
|
if result.Type != "message" {
|
|
t.Errorf("expected type 'message', got %q", result.Type)
|
|
}
|
|
if result.Role != "assistant" {
|
|
t.Errorf("expected role 'assistant', got %q", result.Role)
|
|
}
|
|
if len(result.Content) != 1 {
|
|
t.Fatalf("expected 1 content block, got %d", len(result.Content))
|
|
}
|
|
if result.Content[0].Text == nil || *result.Content[0].Text != "Hello there!" {
|
|
t.Errorf("expected text 'Hello there!', got %v", result.Content[0].Text)
|
|
}
|
|
if result.StopReason != "end_turn" {
|
|
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
|
|
}
|
|
if result.Usage.InputTokens != 10 {
|
|
t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens)
|
|
}
|
|
if result.Usage.OutputTokens != 5 {
|
|
t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens)
|
|
}
|
|
}
|
|
|
|
// TestAnthropicWriter_ErrorFromRoutes tests error handling when routes.go sends
|
|
// gin.H{"error": "message"} without a StatusCode field (which is the common case)
|
|
func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
tests := []struct {
|
|
name string
|
|
statusCode int
|
|
errorPayload any
|
|
wantErrorType string
|
|
wantMessage string
|
|
}{
|
|
// routes.go sends errors without StatusCode in JSON, so we must use HTTP status
|
|
{
|
|
name: "404 with gin.H error (model not found)",
|
|
statusCode: http.StatusNotFound,
|
|
errorPayload: gin.H{"error": "model 'nonexistent' not found"},
|
|
wantErrorType: "not_found_error",
|
|
wantMessage: "model 'nonexistent' not found",
|
|
},
|
|
{
|
|
name: "400 with gin.H error (bad request)",
|
|
statusCode: http.StatusBadRequest,
|
|
errorPayload: gin.H{"error": "model is required"},
|
|
wantErrorType: "invalid_request_error",
|
|
wantMessage: "model is required",
|
|
},
|
|
{
|
|
name: "500 with gin.H error (internal error)",
|
|
statusCode: http.StatusInternalServerError,
|
|
errorPayload: gin.H{"error": "something went wrong"},
|
|
wantErrorType: "api_error",
|
|
wantMessage: "something went wrong",
|
|
},
|
|
{
|
|
name: "404 with api.StatusError",
|
|
statusCode: http.StatusNotFound,
|
|
errorPayload: api.StatusError{
|
|
StatusCode: http.StatusNotFound,
|
|
ErrorMessage: "model not found via StatusError",
|
|
},
|
|
wantErrorType: "not_found_error",
|
|
wantMessage: "model not found via StatusError",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
// Simulate what routes.go does - set status and write error JSON
|
|
data, _ := json.Marshal(tt.errorPayload)
|
|
c.Writer.WriteHeader(tt.statusCode)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != tt.statusCode {
|
|
t.Errorf("expected status %d, got %d", tt.statusCode, resp.Code)
|
|
}
|
|
|
|
var errResp anthropic.ErrorResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
|
t.Fatalf("failed to unmarshal error response: %v\nbody: %s", err, resp.Body.String())
|
|
}
|
|
|
|
if errResp.Type != "error" {
|
|
t.Errorf("expected type 'error', got %q", errResp.Type)
|
|
}
|
|
if errResp.Error.Type != tt.wantErrorType {
|
|
t.Errorf("expected error type %q, got %q", tt.wantErrorType, errResp.Error.Type)
|
|
}
|
|
if errResp.Error.Message != tt.wantMessage {
|
|
t.Errorf("expected message %q, got %q", tt.wantMessage, errResp.Error.Message)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
var flagSet bool
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
_, flagSet = c.Get("relax_thinking")
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if !flagSet {
|
|
t.Error("expected relax_thinking flag to be set in context")
|
|
}
|
|
}
|
|
|
|
// Web Search Tests
|
|
|
|
func TestHasWebSearchTool(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tools []anthropic.Tool
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "no tools",
|
|
tools: nil,
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "regular tool only",
|
|
tools: []anthropic.Tool{
|
|
{Type: "custom", Name: "get_weather"},
|
|
},
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "web search tool",
|
|
tools: []anthropic.Tool{
|
|
{Type: "web_search_20250305", Name: "web_search"},
|
|
},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "mixed tools",
|
|
tools: []anthropic.Tool{
|
|
{Type: "custom", Name: "get_weather"},
|
|
{Type: "web_search_20250305", Name: "web_search"},
|
|
},
|
|
expected: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := hasWebSearchTool(tt.tools)
|
|
if result != tt.expected {
|
|
t.Errorf("expected %v, got %v", tt.expected, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtractQueryFromToolCall(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tc *api.ToolCall
|
|
expected string
|
|
}{
|
|
{
|
|
name: "valid query",
|
|
tc: &api.ToolCall{
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("query", "test search"),
|
|
},
|
|
},
|
|
expected: "test search",
|
|
},
|
|
{
|
|
name: "empty arguments",
|
|
tc: &api.ToolCall{
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
},
|
|
},
|
|
expected: "",
|
|
},
|
|
{
|
|
name: "no query key",
|
|
tc: &api.ToolCall{
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("other", "value"),
|
|
},
|
|
},
|
|
expected: "",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := extractQueryFromToolCall(tt.tc)
|
|
if result != tt.expected {
|
|
t.Errorf("expected %q, got %q", tt.expected, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// makeArgs is a test helper that creates ToolCallFunctionArguments
|
|
func makeArgs(key string, value any) api.ToolCallFunctionArguments {
|
|
args := api.NewToolCallFunctionArguments()
|
|
args.Set(key, value)
|
|
return args
|
|
}
|
|
|
|
// --- Web Search Integration Tests ---
|
|
|
|
// TestWebSearchServerToolUseID tests the ID derivation logic.
|
|
func TestWebSearchServerToolUseID(t *testing.T) {
|
|
tests := []struct {
|
|
msgID string
|
|
expected string
|
|
}{
|
|
{"msg_abc123", "srvtoolu_abc123"},
|
|
{"msg_", "srvtoolu_"},
|
|
{"nomsgprefix", "srvtoolu_nomsgprefix"},
|
|
}
|
|
for _, tt := range tests {
|
|
got := serverToolUseID(tt.msgID)
|
|
if got != tt.expected {
|
|
t.Errorf("serverToolUseID(%q) = %q, want %q", tt.msgID, got, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestWebSearchNoWebSearchTool verifies that when there is no web_search tool,
|
|
// requests pass through to the normal AnthropicWriter without interception.
|
|
func TestWebSearchNoWebSearchTool(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
Content: "Normal response",
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{"model":"test-model","max_tokens":100,"messages":[{"role":"user","content":"Hello"}]}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
var result anthropic.MessagesResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
|
t.Fatalf("unmarshal error: %v", err)
|
|
}
|
|
|
|
if result.Type != "message" {
|
|
t.Errorf("expected type 'message', got %q", result.Type)
|
|
}
|
|
if len(result.Content) != 1 || result.Content[0].Type != "text" {
|
|
t.Fatalf("expected single text block, got %d blocks", len(result.Content))
|
|
}
|
|
if *result.Content[0].Text != "Normal response" {
|
|
t.Errorf("expected text 'Normal response', got %q", *result.Content[0].Text)
|
|
}
|
|
}
|
|
|
|
// TestWebSearchToolPresent_ModelDoesNotCallIt_NonStreaming verifies that when
|
|
// the web_search tool is present but the model does not call it, the response
|
|
// passes through normally (non-streaming case).
|
|
func TestWebSearchToolPresent_ModelDoesNotCallIt_NonStreaming(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
enableCloudForTest(t)
|
|
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
Content: "I can answer that without searching.",
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 12, EvalCount: 8},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{
|
|
"model":"test-model:cloud",
|
|
"max_tokens":100,
|
|
"messages":[{"role":"user","content":"What is 2+2?"}],
|
|
"tools":[{"type":"web_search_20250305","name":"web_search"}]
|
|
}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
var result anthropic.MessagesResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
|
t.Fatalf("unmarshal error: %v", err)
|
|
}
|
|
|
|
if result.Type != "message" {
|
|
t.Errorf("expected type 'message', got %q", result.Type)
|
|
}
|
|
if len(result.Content) != 1 || result.Content[0].Type != "text" {
|
|
t.Fatalf("expected single text block, got %+v", result.Content)
|
|
}
|
|
if *result.Content[0].Text != "I can answer that without searching." {
|
|
t.Errorf("unexpected text: %q", *result.Content[0].Text)
|
|
}
|
|
if result.StopReason != "end_turn" {
|
|
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
|
|
}
|
|
}
|
|
|
|
// TestWebSearchToolPresent_ModelDoesNotCallIt_Streaming verifies the streaming
|
|
// pass-through case when the model does not invoke web_search.
|
|
func TestWebSearchToolPresent_ModelDoesNotCallIt_Streaming(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
enableCloudForTest(t)
|
|
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
// Simulate streaming: two partial chunks then a final chunk
|
|
chunks := []api.ChatResponse{
|
|
{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant", Content: "Hello "},
|
|
Done: false,
|
|
},
|
|
{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant", Content: "world"},
|
|
Done: false,
|
|
},
|
|
{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant", Content: ""},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
|
},
|
|
}
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
for _, chunk := range chunks {
|
|
data, _ := json.Marshal(chunk)
|
|
_, _ = c.Writer.Write(data)
|
|
}
|
|
})
|
|
|
|
body := `{
|
|
"model":"test-model:cloud",
|
|
"max_tokens":100,
|
|
"stream":true,
|
|
"messages":[{"role":"user","content":"Hi"}],
|
|
"tools":[{"type":"web_search_20250305","name":"web_search"}]
|
|
}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
// Parse SSE events
|
|
events := parseSSEEvents(t, resp.Body.String())
|
|
|
|
// Should have standard streaming event flow
|
|
if len(events) == 0 {
|
|
t.Fatal("expected SSE events, got none")
|
|
}
|
|
|
|
// First event should be message_start
|
|
if events[0].event != "message_start" {
|
|
t.Errorf("first event should be message_start, got %q", events[0].event)
|
|
}
|
|
|
|
// Should have content_block_start for text
|
|
hasTextStart := false
|
|
hasTextDelta := false
|
|
hasMessageStop := false
|
|
for _, e := range events {
|
|
if e.event == "content_block_start" {
|
|
var cbs anthropic.ContentBlockStartEvent
|
|
if err := json.Unmarshal([]byte(e.data), &cbs); err == nil {
|
|
if cbs.ContentBlock.Type == "text" {
|
|
hasTextStart = true
|
|
}
|
|
}
|
|
}
|
|
if e.event == "content_block_delta" {
|
|
var cbd anthropic.ContentBlockDeltaEvent
|
|
if err := json.Unmarshal([]byte(e.data), &cbd); err == nil {
|
|
if cbd.Delta.Type == "text_delta" {
|
|
hasTextDelta = true
|
|
}
|
|
}
|
|
}
|
|
if e.event == "message_stop" {
|
|
hasMessageStop = true
|
|
}
|
|
}
|
|
if !hasTextStart {
|
|
t.Error("expected content_block_start with text type")
|
|
}
|
|
if !hasTextDelta {
|
|
t.Error("expected content_block_delta with text_delta")
|
|
}
|
|
if !hasMessageStop {
|
|
t.Error("expected message_stop event")
|
|
}
|
|
}
|
|
|
|
// TestWebSearchToolPresent_ModelCallsIt_NonStreaming tests the full web search flow
|
|
// in non-streaming mode. It mocks the followup /api/chat call using a local HTTP server.
|
|
func TestWebSearchToolPresent_ModelCallsIt_NonStreaming(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
enableCloudForTest(t)
|
|
|
|
// Create a mock Ollama server that responds to the followup /api/chat call
|
|
followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
Content: "Based on my search, the answer is 42.",
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 50, EvalCount: 20},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer followupServer.Close()
|
|
|
|
// Set OLLAMA_HOST to our mock server so the followup call goes there
|
|
t.Setenv("OLLAMA_HOST", followupServer.URL)
|
|
|
|
// Also mock the web search API
|
|
searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := anthropic.OllamaWebSearchResponse{
|
|
Results: []anthropic.OllamaWebSearchResult{
|
|
{Title: "Test Result", URL: "https://example.com/result", Content: "Some content"},
|
|
},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer searchServer.Close()
|
|
|
|
// Point DoWebSearch at our mock search server
|
|
originalEndpoint := anthropic.WebSearchEndpoint
|
|
anthropic.WebSearchEndpoint = searchServer.URL
|
|
defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
|
|
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_ws_001",
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("query", "meaning of life"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 15, EvalCount: 3},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{
|
|
"model":"test-model:cloud",
|
|
"max_tokens":100,
|
|
"messages":[{"role":"user","content":"What is the meaning of life?"}],
|
|
"tools":[{"type":"web_search_20250305","name":"web_search"}]
|
|
}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
var result anthropic.MessagesResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
|
t.Fatalf("unmarshal error: %v\nbody: %s", err, resp.Body.String())
|
|
}
|
|
|
|
if result.Type != "message" {
|
|
t.Errorf("expected type 'message', got %q", result.Type)
|
|
}
|
|
if result.Role != "assistant" {
|
|
t.Errorf("expected role 'assistant', got %q", result.Role)
|
|
}
|
|
|
|
// Should have 3 blocks: server_tool_use + web_search_tool_result + text
|
|
if len(result.Content) != 3 {
|
|
t.Fatalf("expected 3 content blocks, got %d: %+v", len(result.Content), result.Content)
|
|
}
|
|
|
|
if result.Content[0].Type != "server_tool_use" {
|
|
t.Errorf("expected first block type 'server_tool_use', got %q", result.Content[0].Type)
|
|
}
|
|
if result.Content[0].Name != "web_search" {
|
|
t.Errorf("expected name 'web_search', got %q", result.Content[0].Name)
|
|
}
|
|
|
|
if result.Content[1].Type != "web_search_tool_result" {
|
|
t.Errorf("expected second block type 'web_search_tool_result', got %q", result.Content[1].Type)
|
|
}
|
|
if result.Content[1].ToolUseID != result.Content[0].ID {
|
|
t.Errorf("tool_use_id mismatch: %q != %q", result.Content[1].ToolUseID, result.Content[0].ID)
|
|
}
|
|
|
|
if result.Content[2].Type != "text" {
|
|
t.Errorf("expected third block type 'text', got %q", result.Content[2].Type)
|
|
}
|
|
if result.Content[2].Text == nil || *result.Content[2].Text == "" {
|
|
t.Error("expected non-empty text in third block")
|
|
}
|
|
|
|
if result.StopReason != "end_turn" {
|
|
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
|
|
}
|
|
}
|
|
|
|
// TestWebSearchToolPresent_ModelCallsIt_Streaming tests the streaming SSE output
|
|
// when the model calls web_search with mocked search and followup endpoints.
|
|
func TestWebSearchToolPresent_ModelCallsIt_Streaming(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
enableCloudForTest(t)
|
|
|
|
// Mock followup /api/chat server
|
|
followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant", Content: "Here are the latest news."},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 40, EvalCount: 15},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer followupServer.Close()
|
|
t.Setenv("OLLAMA_HOST", followupServer.URL)
|
|
|
|
// Mock web search API
|
|
searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := anthropic.OllamaWebSearchResponse{
|
|
Results: []anthropic.OllamaWebSearchResult{
|
|
{Title: "News Result", URL: "https://example.com/news", Content: "Breaking news"},
|
|
},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer searchServer.Close()
|
|
originalEndpoint := anthropic.WebSearchEndpoint
|
|
anthropic.WebSearchEndpoint = searchServer.URL
|
|
defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
|
|
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
// Simulate buffered streaming: non-final chunk then final with tool call
|
|
chunks := []api.ChatResponse{
|
|
{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant"},
|
|
Done: false,
|
|
},
|
|
{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_ws_002",
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("query", "latest news"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 2},
|
|
},
|
|
}
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
for _, chunk := range chunks {
|
|
data, _ := json.Marshal(chunk)
|
|
_, _ = c.Writer.Write(data)
|
|
}
|
|
})
|
|
|
|
body := `{
|
|
"model":"test-model:cloud",
|
|
"max_tokens":100,
|
|
"stream":true,
|
|
"messages":[{"role":"user","content":"What is the latest news?"}],
|
|
"tools":[{"type":"web_search_20250305","name":"web_search"}]
|
|
}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
events := parseSSEEvents(t, resp.Body.String())
|
|
|
|
// Success path: 10 events (3 blocks: server_tool_use, web_search_tool_result, text with delta)
|
|
expectedEventTypes := []string{
|
|
"message_start",
|
|
"content_block_start", // server_tool_use
|
|
"content_block_stop",
|
|
"content_block_start", // web_search_tool_result
|
|
"content_block_stop",
|
|
"content_block_start", // text (empty)
|
|
"content_block_delta", // text_delta with actual content
|
|
"content_block_stop",
|
|
"message_delta",
|
|
"message_stop",
|
|
}
|
|
|
|
if len(events) != len(expectedEventTypes) {
|
|
t.Fatalf("expected %d events, got %d.\nEvents: %v", len(expectedEventTypes), len(events), eventNames(events))
|
|
}
|
|
|
|
for i, expected := range expectedEventTypes {
|
|
if events[i].event != expected {
|
|
t.Errorf("event[%d]: expected %q, got %q", i, expected, events[i].event)
|
|
}
|
|
}
|
|
|
|
// Verify text delta has the followup model's content
|
|
var textDelta anthropic.ContentBlockDeltaEvent
|
|
if err := json.Unmarshal([]byte(events[6].data), &textDelta); err != nil {
|
|
t.Fatalf("failed to parse text delta: %v", err)
|
|
}
|
|
if textDelta.Delta.Type != "text_delta" {
|
|
t.Errorf("expected delta type 'text_delta', got %q", textDelta.Delta.Type)
|
|
}
|
|
if textDelta.Delta.Text != "Here are the latest news." {
|
|
t.Errorf("expected followup text, got %q", textDelta.Delta.Text)
|
|
}
|
|
}
|
|
|
|
// TestWebSearchStreamResponse tests the streamResponse method directly by constructing
|
|
// a WebSearchAnthropicWriter and calling streamResponse with a known response.
|
|
func TestWebSearchStreamResponse(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
text := "Here is the answer."
|
|
|
|
response := anthropic.MessagesResponse{
|
|
ID: "msg_test123",
|
|
Type: "message",
|
|
Role: "assistant",
|
|
Model: "test-model",
|
|
Content: []anthropic.ContentBlock{
|
|
{
|
|
Type: "server_tool_use",
|
|
ID: "srvtoolu_test123",
|
|
Name: "web_search",
|
|
Input: map[string]any{"query": "test query"},
|
|
},
|
|
{
|
|
Type: "web_search_tool_result",
|
|
ToolUseID: "srvtoolu_test123",
|
|
Content: []anthropic.WebSearchResult{
|
|
{Type: "web_search_result", URL: "https://example.com", Title: "Example"},
|
|
},
|
|
},
|
|
{
|
|
Type: "text",
|
|
Text: &text,
|
|
},
|
|
},
|
|
StopReason: "end_turn",
|
|
Usage: anthropic.Usage{InputTokens: 20, OutputTokens: 10},
|
|
}
|
|
|
|
rec := httptest.NewRecorder()
|
|
ginCtx, _ := gin.CreateTestContext(rec)
|
|
|
|
innerWriter := &AnthropicWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer},
|
|
stream: true,
|
|
id: "msg_test123",
|
|
}
|
|
wsWriter := &WebSearchAnthropicWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer},
|
|
inner: innerWriter,
|
|
stream: true,
|
|
req: anthropic.MessagesRequest{Model: "test-model"},
|
|
}
|
|
|
|
if err := wsWriter.streamResponse(response); err != nil {
|
|
t.Fatalf("streamResponse error: %v", err)
|
|
}
|
|
|
|
events := parseSSEEvents(t, rec.Body.String())
|
|
|
|
// Verify full event sequence
|
|
expectedEventTypes := []string{
|
|
"message_start",
|
|
"content_block_start", // server_tool_use (index 0)
|
|
"content_block_stop", // index 0
|
|
"content_block_start", // web_search_tool_result (index 1)
|
|
"content_block_stop", // index 1
|
|
"content_block_start", // text (index 2)
|
|
"content_block_delta", // text_delta
|
|
"content_block_stop", // index 2
|
|
"message_delta",
|
|
"message_stop",
|
|
}
|
|
|
|
if len(events) != len(expectedEventTypes) {
|
|
t.Fatalf("expected %d events, got %d.\nEvents: %v", len(expectedEventTypes), len(events), eventNames(events))
|
|
}
|
|
|
|
for i, expected := range expectedEventTypes {
|
|
if events[i].event != expected {
|
|
t.Errorf("event[%d]: expected %q, got %q", i, expected, events[i].event)
|
|
}
|
|
}
|
|
|
|
// Verify message_start content
|
|
var msgStart anthropic.MessageStartEvent
|
|
if err := json.Unmarshal([]byte(events[0].data), &msgStart); err != nil {
|
|
t.Fatalf("failed to parse message_start: %v", err)
|
|
}
|
|
if msgStart.Message.ID != "msg_test123" {
|
|
t.Errorf("expected message ID 'msg_test123', got %q", msgStart.Message.ID)
|
|
}
|
|
if msgStart.Message.Role != "assistant" {
|
|
t.Errorf("expected role 'assistant', got %q", msgStart.Message.Role)
|
|
}
|
|
if len(msgStart.Message.Content) != 0 {
|
|
t.Errorf("expected empty content in message_start, got %d blocks", len(msgStart.Message.Content))
|
|
}
|
|
|
|
// Verify content_block_start for server_tool_use (event index 1)
|
|
var toolStart anthropic.ContentBlockStartEvent
|
|
if err := json.Unmarshal([]byte(events[1].data), &toolStart); err != nil {
|
|
t.Fatalf("failed to parse server_tool_use start: %v", err)
|
|
}
|
|
if toolStart.Index != 0 {
|
|
t.Errorf("expected index 0, got %d", toolStart.Index)
|
|
}
|
|
if toolStart.ContentBlock.Type != "server_tool_use" {
|
|
t.Errorf("expected type 'server_tool_use', got %q", toolStart.ContentBlock.Type)
|
|
}
|
|
if toolStart.ContentBlock.ID != "srvtoolu_test123" {
|
|
t.Errorf("expected ID 'srvtoolu_test123', got %q", toolStart.ContentBlock.ID)
|
|
}
|
|
|
|
// Verify content_block_start for web_search_tool_result (event index 3)
|
|
var searchStart anthropic.ContentBlockStartEvent
|
|
if err := json.Unmarshal([]byte(events[3].data), &searchStart); err != nil {
|
|
t.Fatalf("failed to parse web_search_tool_result start: %v", err)
|
|
}
|
|
if searchStart.Index != 1 {
|
|
t.Errorf("expected index 1, got %d", searchStart.Index)
|
|
}
|
|
if searchStart.ContentBlock.Type != "web_search_tool_result" {
|
|
t.Errorf("expected type 'web_search_tool_result', got %q", searchStart.ContentBlock.Type)
|
|
}
|
|
|
|
// Verify text block: content_block_start (event index 5)
|
|
var textStart anthropic.ContentBlockStartEvent
|
|
if err := json.Unmarshal([]byte(events[5].data), &textStart); err != nil {
|
|
t.Fatalf("failed to parse text start: %v", err)
|
|
}
|
|
if textStart.Index != 2 {
|
|
t.Errorf("expected index 2, got %d", textStart.Index)
|
|
}
|
|
if textStart.ContentBlock.Type != "text" {
|
|
t.Errorf("expected type 'text', got %q", textStart.ContentBlock.Type)
|
|
}
|
|
// Text in start should be empty
|
|
if textStart.ContentBlock.Text == nil || *textStart.ContentBlock.Text != "" {
|
|
t.Errorf("expected empty text in content_block_start, got %v", textStart.ContentBlock.Text)
|
|
}
|
|
|
|
// Verify text delta (event index 6)
|
|
var textDelta anthropic.ContentBlockDeltaEvent
|
|
if err := json.Unmarshal([]byte(events[6].data), &textDelta); err != nil {
|
|
t.Fatalf("failed to parse text delta: %v", err)
|
|
}
|
|
if textDelta.Index != 2 {
|
|
t.Errorf("expected index 2, got %d", textDelta.Index)
|
|
}
|
|
if textDelta.Delta.Type != "text_delta" {
|
|
t.Errorf("expected delta type 'text_delta', got %q", textDelta.Delta.Type)
|
|
}
|
|
if textDelta.Delta.Text != "Here is the answer." {
|
|
t.Errorf("expected delta text 'Here is the answer.', got %q", textDelta.Delta.Text)
|
|
}
|
|
|
|
// Verify message_delta (event index 8)
|
|
var msgDelta anthropic.MessageDeltaEvent
|
|
if err := json.Unmarshal([]byte(events[8].data), &msgDelta); err != nil {
|
|
t.Fatalf("failed to parse message_delta: %v", err)
|
|
}
|
|
if msgDelta.Delta.StopReason != "end_turn" {
|
|
t.Errorf("expected stop_reason 'end_turn', got %q", msgDelta.Delta.StopReason)
|
|
}
|
|
if msgDelta.Usage.InputTokens != 20 {
|
|
t.Errorf("expected input_tokens 20, got %d", msgDelta.Usage.InputTokens)
|
|
}
|
|
if msgDelta.Usage.OutputTokens != 10 {
|
|
t.Errorf("expected output_tokens 10, got %d", msgDelta.Usage.OutputTokens)
|
|
}
|
|
}
|
|
|
|
// TestWebSearchSendError_NonStreaming tests sendError produces correct response shape.
|
|
func TestWebSearchSendError_NonStreaming(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
rec := httptest.NewRecorder()
|
|
ginCtx, _ := gin.CreateTestContext(rec)
|
|
|
|
innerWriter := &AnthropicWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer},
|
|
stream: false,
|
|
id: "msg_err001",
|
|
}
|
|
wsWriter := &WebSearchAnthropicWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer},
|
|
inner: innerWriter,
|
|
stream: false,
|
|
req: anthropic.MessagesRequest{Model: "test-model"},
|
|
}
|
|
|
|
errorUsage := anthropic.Usage{InputTokens: 7, OutputTokens: 2}
|
|
if err := wsWriter.sendError("unavailable", "test query", errorUsage); err != nil {
|
|
t.Fatalf("sendError error: %v", err)
|
|
}
|
|
|
|
var result anthropic.MessagesResponse
|
|
if err := json.Unmarshal(rec.Body.Bytes(), &result); err != nil {
|
|
t.Fatalf("unmarshal error: %v\nbody: %s", err, rec.Body.String())
|
|
}
|
|
|
|
if result.Type != "message" {
|
|
t.Errorf("expected type 'message', got %q", result.Type)
|
|
}
|
|
if result.ID != "msg_err001" {
|
|
t.Errorf("expected ID 'msg_err001', got %q", result.ID)
|
|
}
|
|
|
|
// Should have exactly 2 blocks: server_tool_use + web_search_tool_result
|
|
if len(result.Content) != 2 {
|
|
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
|
|
}
|
|
|
|
// Block 0: server_tool_use
|
|
if result.Content[0].Type != "server_tool_use" {
|
|
t.Errorf("expected 'server_tool_use', got %q", result.Content[0].Type)
|
|
}
|
|
expectedToolID := "srvtoolu_err001"
|
|
if result.Content[0].ID != expectedToolID {
|
|
t.Errorf("expected ID %q, got %q", expectedToolID, result.Content[0].ID)
|
|
}
|
|
if result.Content[0].Name != "web_search" {
|
|
t.Errorf("expected name 'web_search', got %q", result.Content[0].Name)
|
|
}
|
|
// Verify input contains the query
|
|
inputMap, ok := result.Content[0].Input.(map[string]any)
|
|
if !ok {
|
|
t.Fatalf("expected Input to be map, got %T", result.Content[0].Input)
|
|
}
|
|
if inputMap["query"] != "test query" {
|
|
t.Errorf("expected query 'test query', got %v", inputMap["query"])
|
|
}
|
|
|
|
// Block 1: web_search_tool_result with error
|
|
if result.Content[1].Type != "web_search_tool_result" {
|
|
t.Errorf("expected 'web_search_tool_result', got %q", result.Content[1].Type)
|
|
}
|
|
if result.Content[1].ToolUseID != expectedToolID {
|
|
t.Errorf("expected tool_use_id %q, got %q", expectedToolID, result.Content[1].ToolUseID)
|
|
}
|
|
|
|
// The Content field should be a WebSearchToolResultError
|
|
contentJSON, _ := json.Marshal(result.Content[1].Content)
|
|
var errContent anthropic.WebSearchToolResultError
|
|
if err := json.Unmarshal(contentJSON, &errContent); err != nil {
|
|
t.Fatalf("failed to parse error content: %v\nraw: %s", err, string(contentJSON))
|
|
}
|
|
if errContent.Type != "web_search_tool_result_error" {
|
|
t.Errorf("expected error type 'web_search_tool_result_error', got %q", errContent.Type)
|
|
}
|
|
if errContent.ErrorCode != "unavailable" {
|
|
t.Errorf("expected error_code 'unavailable', got %q", errContent.ErrorCode)
|
|
}
|
|
|
|
if result.StopReason != "end_turn" {
|
|
t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
|
|
}
|
|
if result.Usage != errorUsage {
|
|
t.Errorf("expected usage %+v, got %+v", errorUsage, result.Usage)
|
|
}
|
|
}
|
|
|
|
// TestWebSearchSendError_Streaming tests sendError in streaming mode produces proper SSE.
|
|
func TestWebSearchSendError_Streaming(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
rec := httptest.NewRecorder()
|
|
ginCtx, _ := gin.CreateTestContext(rec)
|
|
|
|
innerWriter := &AnthropicWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer},
|
|
stream: true,
|
|
id: "msg_err002",
|
|
}
|
|
wsWriter := &WebSearchAnthropicWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer},
|
|
inner: innerWriter,
|
|
stream: true,
|
|
req: anthropic.MessagesRequest{Model: "test-model"},
|
|
}
|
|
|
|
errorUsage := anthropic.Usage{InputTokens: 9, OutputTokens: 4}
|
|
if err := wsWriter.sendError("invalid_request", "bad query", errorUsage); err != nil {
|
|
t.Fatalf("sendError error: %v", err)
|
|
}
|
|
|
|
events := parseSSEEvents(t, rec.Body.String())
|
|
|
|
// Error response has 2 blocks: server_tool_use + web_search_tool_result
|
|
// Expected events: message_start,
|
|
// content_block_start(server_tool_use), content_block_stop,
|
|
// content_block_start(web_search_tool_result), content_block_stop,
|
|
// message_delta, message_stop
|
|
expectedEventTypes := []string{
|
|
"message_start",
|
|
"content_block_start",
|
|
"content_block_stop",
|
|
"content_block_start",
|
|
"content_block_stop",
|
|
"message_delta",
|
|
"message_stop",
|
|
}
|
|
|
|
if len(events) != len(expectedEventTypes) {
|
|
t.Fatalf("expected %d events, got %d.\nEvents: %v", len(expectedEventTypes), len(events), eventNames(events))
|
|
}
|
|
|
|
for i, expected := range expectedEventTypes {
|
|
if events[i].event != expected {
|
|
t.Errorf("event[%d]: expected %q, got %q", i, expected, events[i].event)
|
|
}
|
|
}
|
|
|
|
// Verify the server_tool_use block
|
|
var toolStart anthropic.ContentBlockStartEvent
|
|
if err := json.Unmarshal([]byte(events[1].data), &toolStart); err != nil {
|
|
t.Fatalf("failed to parse server_tool_use start: %v", err)
|
|
}
|
|
if toolStart.ContentBlock.Type != "server_tool_use" {
|
|
t.Errorf("expected 'server_tool_use', got %q", toolStart.ContentBlock.Type)
|
|
}
|
|
|
|
// Verify the web_search_tool_result block
|
|
var resultStart anthropic.ContentBlockStartEvent
|
|
if err := json.Unmarshal([]byte(events[3].data), &resultStart); err != nil {
|
|
t.Fatalf("failed to parse web_search_tool_result start: %v", err)
|
|
}
|
|
if resultStart.ContentBlock.Type != "web_search_tool_result" {
|
|
t.Errorf("expected 'web_search_tool_result', got %q", resultStart.ContentBlock.Type)
|
|
}
|
|
|
|
var msgDelta anthropic.MessageDeltaEvent
|
|
if err := json.Unmarshal([]byte(events[5].data), &msgDelta); err != nil {
|
|
t.Fatalf("failed to parse message_delta: %v", err)
|
|
}
|
|
if msgDelta.Usage.InputTokens != errorUsage.InputTokens || msgDelta.Usage.OutputTokens != errorUsage.OutputTokens {
|
|
t.Fatalf("expected usage %+v in message_delta, got %+v", errorUsage, msgDelta.Usage)
|
|
}
|
|
}
|
|
|
|
// TestWebSearchSendError_EmptyQuery tests sendError with an empty query.
|
|
func TestWebSearchSendError_EmptyQuery(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
rec := httptest.NewRecorder()
|
|
ginCtx, _ := gin.CreateTestContext(rec)
|
|
|
|
innerWriter := &AnthropicWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer},
|
|
stream: false,
|
|
id: "msg_empty001",
|
|
}
|
|
wsWriter := &WebSearchAnthropicWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer},
|
|
inner: innerWriter,
|
|
stream: false,
|
|
req: anthropic.MessagesRequest{Model: "test-model"},
|
|
}
|
|
|
|
if err := wsWriter.sendError("invalid_request", "", anthropic.Usage{}); err != nil {
|
|
t.Fatalf("sendError error: %v", err)
|
|
}
|
|
|
|
var result anthropic.MessagesResponse
|
|
if err := json.Unmarshal(rec.Body.Bytes(), &result); err != nil {
|
|
t.Fatalf("unmarshal error: %v", err)
|
|
}
|
|
|
|
if len(result.Content) != 2 {
|
|
t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
|
|
}
|
|
|
|
// Verify the input has empty query
|
|
inputMap, ok := result.Content[0].Input.(map[string]any)
|
|
if !ok {
|
|
t.Fatalf("expected Input to be map, got %T", result.Content[0].Input)
|
|
}
|
|
if inputMap["query"] != "" {
|
|
t.Errorf("expected empty query, got %v", inputMap["query"])
|
|
}
|
|
}
|
|
|
|
// --- SSE parsing helpers ---
|
|
|
|
type sseEvent struct {
|
|
event string
|
|
data string
|
|
}
|
|
|
|
// parseSSEEvents parses Server-Sent Events from a string.
|
|
func parseSSEEvents(t *testing.T, body string) []sseEvent {
|
|
t.Helper()
|
|
var events []sseEvent
|
|
var currentEvent string
|
|
var currentData strings.Builder
|
|
|
|
for _, line := range strings.Split(body, "\n") {
|
|
if strings.HasPrefix(line, "event: ") {
|
|
currentEvent = strings.TrimPrefix(line, "event: ")
|
|
} else if strings.HasPrefix(line, "data: ") {
|
|
currentData.WriteString(strings.TrimPrefix(line, "data: "))
|
|
} else if line == "" && currentEvent != "" {
|
|
events = append(events, sseEvent{event: currentEvent, data: currentData.String()})
|
|
currentEvent = ""
|
|
currentData.Reset()
|
|
}
|
|
}
|
|
return events
|
|
}
|
|
|
|
// eventNames returns a list of event type names for debugging.
|
|
func eventNames(events []sseEvent) []string {
|
|
names := make([]string, len(events))
|
|
for i, e := range events {
|
|
names[i] = e.event
|
|
}
|
|
return names
|
|
}
|
|
|
|
// TestWebSearchCloudModelGating tests web_search behavior across model types.
|
|
func TestWebSearchCloudModelGating(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
enableCloudForTest(t)
|
|
|
|
t.Run("local model allowed when web_search is not called", func(t *testing.T) {
|
|
handlerCalled := false
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
handlerCalled = true
|
|
resp := api.ChatResponse{
|
|
Model: "llama3.2",
|
|
Message: api.Message{Role: "assistant", Content: "hello"},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{"model":"llama3.2","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
if !handlerCalled {
|
|
t.Error("handler should be called for local model when web_search is not called")
|
|
}
|
|
})
|
|
|
|
t.Run("local model emits web_search and gets structured error", func(t *testing.T) {
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
resp := api.ChatResponse{
|
|
Model: "llama3.2",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_local_ws",
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("query", "hello"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 8, EvalCount: 2},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{"model":"llama3.2","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
var result anthropic.MessagesResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
|
t.Fatalf("unmarshal error: %v", err)
|
|
}
|
|
if len(result.Content) != 2 {
|
|
t.Fatalf("expected 2 content blocks for local model web_search error, got %d", len(result.Content))
|
|
}
|
|
contentJSON, _ := json.Marshal(result.Content[1].Content)
|
|
var errContent anthropic.WebSearchToolResultError
|
|
if err := json.Unmarshal(contentJSON, &errContent); err != nil {
|
|
t.Fatalf("failed to parse web_search error content: %v", err)
|
|
}
|
|
if errContent.ErrorCode != "web_search_not_supported_for_local_models" {
|
|
t.Fatalf("expected web_search_not_supported_for_local_models, got %q", errContent.ErrorCode)
|
|
}
|
|
})
|
|
|
|
t.Run("model ending in cloud without cloud suffix treated as local", func(t *testing.T) {
|
|
handlerCalled := false
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
handlerCalled = true
|
|
resp := api.ChatResponse{
|
|
Model: "notreallycloud",
|
|
Message: api.Message{Role: "assistant", Content: "hello"},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{"model":"notreallycloud","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if !handlerCalled {
|
|
t.Error("handler should be called for non-cloud model when web_search is not called")
|
|
}
|
|
if resp.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
})
|
|
|
|
t.Run("cloud model with size tag allowed", func(t *testing.T) {
|
|
handlerCalled := false
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
handlerCalled = true
|
|
resp := api.ChatResponse{
|
|
Model: "gpt-oss:120b",
|
|
Message: api.Message{Role: "assistant", Content: "hello"},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{"model":"gpt-oss:120b-cloud","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if !handlerCalled {
|
|
t.Error("handler should be called for cloud model")
|
|
}
|
|
if resp.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
})
|
|
|
|
t.Run("cloud model allowed", func(t *testing.T) {
|
|
handlerCalled := false
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
handlerCalled = true
|
|
resp := api.ChatResponse{
|
|
Model: "kimi-k2.5",
|
|
Message: api.Message{Role: "assistant", Content: "hello"},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{"model":"kimi-k2.5:cloud","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if !handlerCalled {
|
|
t.Error("handler should be called for cloud model")
|
|
}
|
|
if resp.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
})
|
|
|
|
t.Run("cloud disabled blocks web search for cloud model", func(t *testing.T) {
|
|
t.Setenv("OLLAMA_NO_CLOUD", "1")
|
|
|
|
handlerCalled := false
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
handlerCalled = true
|
|
})
|
|
|
|
body := `{"model":"kimi-k2.5:cloud","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusForbidden {
|
|
t.Fatalf("expected 403, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
if handlerCalled {
|
|
t.Fatal("handler should not be called when cloud is disabled")
|
|
}
|
|
|
|
var errResp anthropic.ErrorResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
|
t.Fatalf("failed to parse error response: %v", err)
|
|
}
|
|
if !strings.Contains(errResp.Error.Message, "ollama cloud is disabled") {
|
|
t.Fatalf("expected cloud disabled error, got: %q", errResp.Error.Message)
|
|
}
|
|
})
|
|
|
|
t.Run("cloud disabled does not block local model if web_search is not called", func(t *testing.T) {
|
|
t.Setenv("OLLAMA_NO_CLOUD", "1")
|
|
|
|
handlerCalled := false
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
handlerCalled = true
|
|
resp := api.ChatResponse{
|
|
Model: "llama3.2",
|
|
Message: api.Message{Role: "assistant", Content: "hello"},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{"model":"llama3.2","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
if !handlerCalled {
|
|
t.Fatal("handler should be called for local model when web_search is not called")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestWebSearchDoesNotRequireAuthorizationHeaderForMockEndpoint(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
enableCloudForTest(t)
|
|
|
|
var authHeader string
|
|
searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
authHeader = r.Header.Get("Authorization")
|
|
resp := anthropic.OllamaWebSearchResponse{
|
|
Results: []anthropic.OllamaWebSearchResult{
|
|
{Title: "Result", URL: "https://example.com", Content: "content"},
|
|
},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer searchServer.Close()
|
|
originalEndpoint := anthropic.WebSearchEndpoint
|
|
anthropic.WebSearchEndpoint = searchServer.URL
|
|
defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
|
|
|
|
followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant", Content: "done"},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 5, EvalCount: 2},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer followupServer.Close()
|
|
t.Setenv("OLLAMA_HOST", followupServer.URL)
|
|
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_auth",
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("query", "auth test"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 4, EvalCount: 1},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{
|
|
"model":"test-model:cloud",
|
|
"max_tokens":100,
|
|
"messages":[{"role":"user","content":"test auth"}],
|
|
"tools":[{"type":"web_search_20250305","name":"web_search"}]
|
|
}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
if authHeader != "" {
|
|
t.Fatalf("expected no Authorization header for mock web search endpoint, got %q", authHeader)
|
|
}
|
|
}
|
|
|
|
// TestWebSearchSearchAPIError tests that a failing search API returns a proper error response.
|
|
func TestWebSearchSearchAPIError(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
enableCloudForTest(t)
|
|
|
|
// Mock search server that returns 500
|
|
searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
|
}))
|
|
defer searchServer.Close()
|
|
originalEndpoint := anthropic.WebSearchEndpoint
|
|
anthropic.WebSearchEndpoint = searchServer.URL
|
|
defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
|
|
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_err",
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("query", "test"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 2},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{
|
|
"model":"test-model:cloud",
|
|
"max_tokens":100,
|
|
"messages":[{"role":"user","content":"test"}],
|
|
"tools":[{"type":"web_search_20250305","name":"web_search"}]
|
|
}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
var result anthropic.MessagesResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
|
t.Fatalf("unmarshal error: %v", err)
|
|
}
|
|
|
|
// Error response: server_tool_use + web_search_tool_result with error
|
|
if len(result.Content) != 2 {
|
|
t.Fatalf("expected 2 content blocks for error, got %d", len(result.Content))
|
|
}
|
|
if result.Content[0].Type != "server_tool_use" {
|
|
t.Errorf("expected 'server_tool_use', got %q", result.Content[0].Type)
|
|
}
|
|
if result.Content[1].Type != "web_search_tool_result" {
|
|
t.Errorf("expected 'web_search_tool_result', got %q", result.Content[1].Type)
|
|
}
|
|
if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 2 {
|
|
t.Fatalf("expected usage input=10 output=2, got %+v", result.Usage)
|
|
}
|
|
}
|
|
|
|
func TestWebSearchStreamingImmediateTakeover(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
enableCloudForTest(t)
|
|
|
|
followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant", Content: "After search."},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 20, EvalCount: 10},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer followupServer.Close()
|
|
t.Setenv("OLLAMA_HOST", followupServer.URL)
|
|
|
|
searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := anthropic.OllamaWebSearchResponse{
|
|
Results: []anthropic.OllamaWebSearchResult{
|
|
{Title: "Result", URL: "https://example.com", Content: "content"},
|
|
},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer searchServer.Close()
|
|
originalEndpoint := anthropic.WebSearchEndpoint
|
|
anthropic.WebSearchEndpoint = searchServer.URL
|
|
defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
|
|
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
chunks := []api.ChatResponse{
|
|
{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant", Content: "Preface "},
|
|
Done: false,
|
|
},
|
|
{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_ws_stream_1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("query", "latest updates"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: false,
|
|
},
|
|
{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant", Content: "ignored chunk"},
|
|
Done: false,
|
|
},
|
|
{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant"},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 9, EvalCount: 4},
|
|
},
|
|
}
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
for _, chunk := range chunks {
|
|
data, _ := json.Marshal(chunk)
|
|
_, _ = c.Writer.Write(data)
|
|
}
|
|
})
|
|
|
|
body := `{
|
|
"model":"test-model:cloud",
|
|
"max_tokens":100,
|
|
"stream":true,
|
|
"messages":[{"role":"user","content":"Find updates"}],
|
|
"tools":[{"type":"web_search_20250305","name":"web_search"}]
|
|
}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
events := parseSSEEvents(t, resp.Body.String())
|
|
if countEventsByName(events, "message_start") != 1 {
|
|
t.Fatalf("expected exactly one message_start, got %d", countEventsByName(events, "message_start"))
|
|
}
|
|
if countEventsByName(events, "message_stop") != 1 {
|
|
t.Fatalf("expected exactly one message_stop, got %d", countEventsByName(events, "message_stop"))
|
|
}
|
|
|
|
textDeltas := collectTextDeltas(t, events)
|
|
if !containsString(textDeltas, "Preface ") {
|
|
t.Fatalf("expected passthrough text delta, got %v", textDeltas)
|
|
}
|
|
if !containsString(textDeltas, "After search.") {
|
|
t.Fatalf("expected post-search text delta, got %v", textDeltas)
|
|
}
|
|
if containsString(textDeltas, "ignored chunk") {
|
|
t.Fatalf("unexpected text from chunks after takeover: %v", textDeltas)
|
|
}
|
|
}
|
|
|
|
func TestWebSearchStreamingUsageUsesObservedChunkMetrics(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
enableCloudForTest(t)
|
|
|
|
followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant", Content: "After search."},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 20, EvalCount: 7},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer followupServer.Close()
|
|
t.Setenv("OLLAMA_HOST", followupServer.URL)
|
|
|
|
searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := anthropic.OllamaWebSearchResponse{
|
|
Results: []anthropic.OllamaWebSearchResult{
|
|
{Title: "Result", URL: "https://example.com", Content: "content"},
|
|
},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer searchServer.Close()
|
|
originalEndpoint := anthropic.WebSearchEndpoint
|
|
anthropic.WebSearchEndpoint = searchServer.URL
|
|
defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
|
|
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
chunks := []api.ChatResponse{
|
|
{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant", Content: "Preface "},
|
|
Done: false,
|
|
Metrics: api.Metrics{PromptEvalCount: 12, EvalCount: 4},
|
|
},
|
|
{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_ws_stream_usage",
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("query", "latest updates"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: false,
|
|
Metrics: api.Metrics{PromptEvalCount: 0, EvalCount: 0},
|
|
},
|
|
{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant"},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 12, EvalCount: 4},
|
|
},
|
|
}
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
for _, chunk := range chunks {
|
|
data, _ := json.Marshal(chunk)
|
|
_, _ = c.Writer.Write(data)
|
|
}
|
|
})
|
|
|
|
body := `{
|
|
"model":"test-model:cloud",
|
|
"max_tokens":100,
|
|
"stream":true,
|
|
"messages":[{"role":"user","content":"Find updates"}],
|
|
"tools":[{"type":"web_search_20250305","name":"web_search"}]
|
|
}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
events := parseSSEEvents(t, resp.Body.String())
|
|
var messageDelta anthropic.MessageDeltaEvent
|
|
found := false
|
|
for _, event := range events {
|
|
if event.event != "message_delta" {
|
|
continue
|
|
}
|
|
if err := json.Unmarshal([]byte(event.data), &messageDelta); err != nil {
|
|
t.Fatalf("failed to unmarshal message_delta: %v", err)
|
|
}
|
|
found = true
|
|
break
|
|
}
|
|
if !found {
|
|
t.Fatal("expected message_delta event")
|
|
}
|
|
if messageDelta.Usage.InputTokens != 32 {
|
|
t.Fatalf("expected aggregated input tokens 32 (12 passthrough + 20 followup), got %d", messageDelta.Usage.InputTokens)
|
|
}
|
|
if messageDelta.Usage.OutputTokens != 11 {
|
|
t.Fatalf("expected aggregated output tokens 11 (4 passthrough + 7 followup), got %d", messageDelta.Usage.OutputTokens)
|
|
}
|
|
}
|
|
|
|
func TestWebSearchMixedToolCallsPreferWebSearch(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
enableCloudForTest(t)
|
|
|
|
followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant", Content: "Search answer."},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 11, EvalCount: 6},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer followupServer.Close()
|
|
t.Setenv("OLLAMA_HOST", followupServer.URL)
|
|
|
|
searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := anthropic.OllamaWebSearchResponse{
|
|
Results: []anthropic.OllamaWebSearchResult{
|
|
{Title: "Result", URL: "https://example.com", Content: "content"},
|
|
},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer searchServer.Close()
|
|
originalEndpoint := anthropic.WebSearchEndpoint
|
|
anthropic.WebSearchEndpoint = searchServer.URL
|
|
defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
|
|
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_other",
|
|
Function: api.ToolCallFunction{
|
|
Name: "get_weather",
|
|
Arguments: makeArgs("location", "SF"),
|
|
},
|
|
},
|
|
{
|
|
ID: "call_ws_mixed",
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("query", "latest weather"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 2},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{
|
|
"model":"test-model:cloud",
|
|
"max_tokens":100,
|
|
"messages":[{"role":"user","content":"Weather?"}],
|
|
"tools":[
|
|
{"type":"web_search_20250305","name":"web_search"},
|
|
{"type":"custom","name":"get_weather","input_schema":{"type":"object"}}
|
|
]
|
|
}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
var result anthropic.MessagesResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
|
t.Fatalf("unmarshal error: %v", err)
|
|
}
|
|
|
|
if len(result.Content) < 3 {
|
|
t.Fatalf("expected at least 3 blocks, got %d", len(result.Content))
|
|
}
|
|
if result.Content[0].Type != "server_tool_use" {
|
|
t.Fatalf("expected server_tool_use first, got %q", result.Content[0].Type)
|
|
}
|
|
if result.Content[1].Type != "web_search_tool_result" {
|
|
t.Fatalf("expected web_search_tool_result second, got %q", result.Content[1].Type)
|
|
}
|
|
|
|
for _, block := range result.Content {
|
|
if block.Type == "tool_use" && block.Name == "get_weather" {
|
|
t.Fatalf("did not expect get_weather tool_use in mixed web_search-preferred path: %+v", result.Content)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWebSearchFollowupClientToolStopReasonToolUse(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
enableCloudForTest(t)
|
|
|
|
followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_weather_final",
|
|
Function: api.ToolCallFunction{
|
|
Name: "get_weather",
|
|
Arguments: makeArgs("location", "New York"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 25, EvalCount: 7},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer followupServer.Close()
|
|
t.Setenv("OLLAMA_HOST", followupServer.URL)
|
|
|
|
searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := anthropic.OllamaWebSearchResponse{
|
|
Results: []anthropic.OllamaWebSearchResult{
|
|
{Title: "Result", URL: "https://example.com", Content: "content"},
|
|
},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer searchServer.Close()
|
|
originalEndpoint := anthropic.WebSearchEndpoint
|
|
anthropic.WebSearchEndpoint = searchServer.URL
|
|
defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
|
|
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_ws_tool_use",
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("query", "forecast"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 15, EvalCount: 3},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{
|
|
"model":"test-model:cloud",
|
|
"max_tokens":100,
|
|
"messages":[{"role":"user","content":"Do I need an umbrella?"}],
|
|
"tools":[
|
|
{"type":"web_search_20250305","name":"web_search"},
|
|
{"type":"custom","name":"get_weather","input_schema":{"type":"object"}}
|
|
]
|
|
}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
var result anthropic.MessagesResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
|
t.Fatalf("unmarshal error: %v", err)
|
|
}
|
|
|
|
if result.StopReason != "tool_use" {
|
|
t.Fatalf("expected stop_reason tool_use, got %q", result.StopReason)
|
|
}
|
|
if len(result.Content) < 3 {
|
|
t.Fatalf("expected server blocks + tool_use, got %d blocks", len(result.Content))
|
|
}
|
|
last := result.Content[len(result.Content)-1]
|
|
if last.Type != "tool_use" {
|
|
t.Fatalf("expected final block tool_use, got %q", last.Type)
|
|
}
|
|
if last.Name != "get_weather" {
|
|
t.Fatalf("expected final tool name get_weather, got %q", last.Name)
|
|
}
|
|
if result.Usage.InputTokens != 40 || result.Usage.OutputTokens != 10 {
|
|
t.Fatalf("unexpected aggregated usage: %+v", result.Usage)
|
|
}
|
|
}
|
|
|
|
func TestWebSearchMultiIterationLoop(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
enableCloudForTest(t)
|
|
|
|
followupCall := 0
|
|
followupDecodeErr := false
|
|
missingWebSearchTool := false
|
|
followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var followupReq api.ChatRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&followupReq); err != nil {
|
|
followupDecodeErr = true
|
|
http.Error(w, "bad request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
hasWebSearchTool := false
|
|
for _, tool := range followupReq.Tools {
|
|
if tool.Function.Name == "web_search" {
|
|
hasWebSearchTool = true
|
|
break
|
|
}
|
|
}
|
|
if !hasWebSearchTool {
|
|
missingWebSearchTool = true
|
|
}
|
|
|
|
followupCall++
|
|
switch followupCall {
|
|
case 1:
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_ws_2",
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("query", "loop query 2"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 20, EvalCount: 2},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
case 2:
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant", Content: "Final answer after 2 searches."},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 30, EvalCount: 3},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
default:
|
|
t.Fatalf("unexpected extra followup call: %d", followupCall)
|
|
}
|
|
}))
|
|
defer followupServer.Close()
|
|
t.Setenv("OLLAMA_HOST", followupServer.URL)
|
|
|
|
searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := anthropic.OllamaWebSearchResponse{
|
|
Results: []anthropic.OllamaWebSearchResult{
|
|
{Title: "Result", URL: "https://example.com", Content: "content"},
|
|
},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer searchServer.Close()
|
|
originalEndpoint := anthropic.WebSearchEndpoint
|
|
anthropic.WebSearchEndpoint = searchServer.URL
|
|
defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
|
|
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_ws_1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("query", "loop query 1"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 1},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{
|
|
"model":"test-model:cloud",
|
|
"max_tokens":100,
|
|
"messages":[{"role":"user","content":"do multiple searches"}],
|
|
"tools":[{"type":"web_search_20250305","name":"web_search"}]
|
|
}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
if followupCall != 2 {
|
|
t.Fatalf("expected 2 followup calls, got %d", followupCall)
|
|
}
|
|
if followupDecodeErr {
|
|
t.Fatal("failed to decode followup request body")
|
|
}
|
|
if missingWebSearchTool {
|
|
t.Fatal("expected followup requests to retain web_search tool definition")
|
|
}
|
|
|
|
var result anthropic.MessagesResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
|
t.Fatalf("unmarshal error: %v", err)
|
|
}
|
|
|
|
serverToolUses := 0
|
|
webResults := 0
|
|
for _, block := range result.Content {
|
|
if block.Type == "server_tool_use" {
|
|
serverToolUses++
|
|
}
|
|
if block.Type == "web_search_tool_result" {
|
|
webResults++
|
|
}
|
|
}
|
|
if serverToolUses != 2 || webResults != 2 {
|
|
t.Fatalf("expected two search iterations, got server_tool_use=%d web_search_tool_result=%d", serverToolUses, webResults)
|
|
}
|
|
|
|
if result.Usage.InputTokens != 60 || result.Usage.OutputTokens != 6 {
|
|
t.Fatalf("unexpected aggregated usage: %+v", result.Usage)
|
|
}
|
|
}
|
|
|
|
func TestWebSearchLoopMaxLimit(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
enableCloudForTest(t)
|
|
|
|
followupCall := 0
|
|
followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
followupCall++
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_ws_loop_limit",
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("query", "loop query next"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 7, EvalCount: 2},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer followupServer.Close()
|
|
t.Setenv("OLLAMA_HOST", followupServer.URL)
|
|
|
|
searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := anthropic.OllamaWebSearchResponse{
|
|
Results: []anthropic.OllamaWebSearchResult{
|
|
{Title: "Result", URL: "https://example.com", Content: "content"},
|
|
},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer searchServer.Close()
|
|
originalEndpoint := anthropic.WebSearchEndpoint
|
|
anthropic.WebSearchEndpoint = searchServer.URL
|
|
defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
|
|
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_ws_initial",
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("query", "loop query 1"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 5, EvalCount: 1},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{
|
|
"model":"test-model:cloud",
|
|
"max_tokens":100,
|
|
"messages":[{"role":"user","content":"keep searching"}],
|
|
"tools":[{"type":"web_search_20250305","name":"web_search"}]
|
|
}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
if followupCall != 3 {
|
|
t.Fatalf("expected 3 followup calls before max loop error, got %d", followupCall)
|
|
}
|
|
|
|
var result anthropic.MessagesResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
|
t.Fatalf("unmarshal error: %v", err)
|
|
}
|
|
|
|
last := result.Content[len(result.Content)-1]
|
|
if last.Type != "web_search_tool_result" {
|
|
t.Fatalf("expected last block web_search_tool_result, got %q", last.Type)
|
|
}
|
|
contentJSON, _ := json.Marshal(last.Content)
|
|
var errContent anthropic.WebSearchToolResultError
|
|
if err := json.Unmarshal(contentJSON, &errContent); err != nil {
|
|
t.Fatalf("failed to parse web search error content: %v", err)
|
|
}
|
|
if errContent.ErrorCode != "max_uses_exceeded" {
|
|
t.Fatalf("expected max_uses_exceeded error, got %q", errContent.ErrorCode)
|
|
}
|
|
if result.StopReason != "end_turn" {
|
|
t.Fatalf("expected end_turn, got %q", result.StopReason)
|
|
}
|
|
}
|
|
|
|
func TestWebSearchStreamingFinalStopReasonToolUse(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
enableCloudForTest(t)
|
|
|
|
followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_weather_stream",
|
|
Function: api.ToolCallFunction{
|
|
Name: "get_weather",
|
|
Arguments: makeArgs("location", "Seattle"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 14, EvalCount: 5},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer followupServer.Close()
|
|
t.Setenv("OLLAMA_HOST", followupServer.URL)
|
|
|
|
searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := anthropic.OllamaWebSearchResponse{
|
|
Results: []anthropic.OllamaWebSearchResult{
|
|
{Title: "Result", URL: "https://example.com", Content: "content"},
|
|
},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer searchServer.Close()
|
|
originalEndpoint := anthropic.WebSearchEndpoint
|
|
anthropic.WebSearchEndpoint = searchServer.URL
|
|
defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
|
|
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
chunks := []api.ChatResponse{
|
|
{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant", Content: "Let me check. "},
|
|
Done: false,
|
|
},
|
|
{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_ws_stream_tool_use",
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("query", "weather seattle"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: false,
|
|
},
|
|
{
|
|
Model: "test-model",
|
|
Message: api.Message{Role: "assistant"},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 3},
|
|
},
|
|
}
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
for _, chunk := range chunks {
|
|
data, _ := json.Marshal(chunk)
|
|
_, _ = c.Writer.Write(data)
|
|
}
|
|
})
|
|
|
|
body := `{
|
|
"model":"test-model:cloud",
|
|
"max_tokens":100,
|
|
"stream":true,
|
|
"messages":[{"role":"user","content":"Should I take a jacket?"}],
|
|
"tools":[
|
|
{"type":"web_search_20250305","name":"web_search"},
|
|
{"type":"custom","name":"get_weather","input_schema":{"type":"object"}}
|
|
]
|
|
}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
events := parseSSEEvents(t, resp.Body.String())
|
|
if countEventsByName(events, "message_start") != 1 {
|
|
t.Fatalf("expected exactly one message_start, got %d", countEventsByName(events, "message_start"))
|
|
}
|
|
|
|
var messageDelta anthropic.MessageDeltaEvent
|
|
foundMessageDelta := false
|
|
foundToolUse := false
|
|
for _, event := range events {
|
|
if event.event == "message_delta" {
|
|
foundMessageDelta = true
|
|
if err := json.Unmarshal([]byte(event.data), &messageDelta); err != nil {
|
|
t.Fatalf("failed to unmarshal message_delta: %v", err)
|
|
}
|
|
}
|
|
if event.event == "content_block_start" {
|
|
var start anthropic.ContentBlockStartEvent
|
|
if err := json.Unmarshal([]byte(event.data), &start); err != nil {
|
|
t.Fatalf("failed to unmarshal content_block_start: %v", err)
|
|
}
|
|
if start.ContentBlock.Type == "tool_use" && start.ContentBlock.Name == "get_weather" {
|
|
foundToolUse = true
|
|
}
|
|
}
|
|
}
|
|
|
|
if !foundMessageDelta {
|
|
t.Fatal("expected message_delta event")
|
|
}
|
|
if messageDelta.Delta.StopReason != "tool_use" {
|
|
t.Fatalf("expected stop_reason tool_use, got %q", messageDelta.Delta.StopReason)
|
|
}
|
|
if !foundToolUse {
|
|
t.Fatal("expected tool_use content block for get_weather")
|
|
}
|
|
}
|
|
|
|
func TestWebSearchFollowupNon200ReturnsApiError(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
enableCloudForTest(t)
|
|
|
|
followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
http.Error(w, "boom", http.StatusInternalServerError)
|
|
}))
|
|
defer followupServer.Close()
|
|
t.Setenv("OLLAMA_HOST", followupServer.URL)
|
|
|
|
searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := anthropic.OllamaWebSearchResponse{
|
|
Results: []anthropic.OllamaWebSearchResult{
|
|
{Title: "Result", URL: "https://example.com", Content: "content"},
|
|
},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer searchServer.Close()
|
|
originalEndpoint := anthropic.WebSearchEndpoint
|
|
anthropic.WebSearchEndpoint = searchServer.URL
|
|
defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
|
|
|
|
router := gin.New()
|
|
router.Use(AnthropicMessagesMiddleware())
|
|
router.POST("/v1/messages", func(c *gin.Context) {
|
|
resp := api.ChatResponse{
|
|
Model: "test-model",
|
|
Message: api.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call_ws_non200",
|
|
Function: api.ToolCallFunction{
|
|
Name: "web_search",
|
|
Arguments: makeArgs("query", "test"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Done: true,
|
|
DoneReason: "stop",
|
|
Metrics: api.Metrics{PromptEvalCount: 9, EvalCount: 1},
|
|
}
|
|
data, _ := json.Marshal(resp)
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, _ = c.Writer.Write(data)
|
|
})
|
|
|
|
body := `{
|
|
"model":"test-model:cloud",
|
|
"max_tokens":100,
|
|
"messages":[{"role":"user","content":"test"}],
|
|
"tools":[{"type":"web_search_20250305","name":"web_search"}]
|
|
}`
|
|
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
|
|
}
|
|
|
|
var result anthropic.MessagesResponse
|
|
if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
|
|
t.Fatalf("unmarshal error: %v", err)
|
|
}
|
|
if len(result.Content) != 2 {
|
|
t.Fatalf("expected 2 blocks in error response, got %d", len(result.Content))
|
|
}
|
|
|
|
contentJSON, _ := json.Marshal(result.Content[1].Content)
|
|
var errContent anthropic.WebSearchToolResultError
|
|
if err := json.Unmarshal(contentJSON, &errContent); err != nil {
|
|
t.Fatalf("failed to parse error content: %v", err)
|
|
}
|
|
if errContent.ErrorCode != "api_error" {
|
|
t.Fatalf("expected api_error, got %q", errContent.ErrorCode)
|
|
}
|
|
if result.Usage.InputTokens != 9 || result.Usage.OutputTokens != 1 {
|
|
t.Fatalf("expected usage input=9 output=1, got %+v", result.Usage)
|
|
}
|
|
}
|
|
|
|
func countEventsByName(events []sseEvent, eventName string) int {
|
|
count := 0
|
|
for _, event := range events {
|
|
if event.event == eventName {
|
|
count++
|
|
}
|
|
}
|
|
return count
|
|
}
|
|
|
|
func collectTextDeltas(t *testing.T, events []sseEvent) []string {
|
|
t.Helper()
|
|
|
|
var deltas []string
|
|
for _, event := range events {
|
|
if event.event != "content_block_delta" {
|
|
continue
|
|
}
|
|
|
|
var delta anthropic.ContentBlockDeltaEvent
|
|
if err := json.Unmarshal([]byte(event.data), &delta); err != nil {
|
|
t.Fatalf("failed to unmarshal content_block_delta: %v", err)
|
|
}
|
|
if delta.Delta.Type == "text_delta" {
|
|
deltas = append(deltas, delta.Delta.Text)
|
|
}
|
|
}
|
|
|
|
return deltas
|
|
}
|
|
|
|
func containsString(values []string, target string) bool {
|
|
for _, value := range values {
|
|
if value == target {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|