Files
ollama-ollama/openai/openai_test.go
Baptiste Jamin 59241c5bee server: add logprobs and top_logprobs support to Ollama's API (#12899)
Adds logprobs support to Ollama's API, including Ollama's
OpenAI-compatible API. When the new 'logprobs' boolean parameter is set
in the API, Ollama returns the log probability of each generated token.
'top_logprobs', an integer of at most 20, can also be specified.
When it is specified, the API also returns that many of the most likely
tokens at each token position.

Co-authored-by: Baptiste Jamin <baptiste@crisp.chat>
2025-11-11 08:49:50 -08:00

437 lines
9.8 KiB
Go

package openai
import (
"encoding/base64"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
// prefix is the data-URL header prepended to image when the tests build an
// image_url content part.
const prefix = `data:image/jpeg;base64,`

// image is the base64-encoded image payload used by the image tests.
// NOTE(review): the bytes are actually a 1x1 PNG (the leading "iVBORw0KGgo" is
// the base64 form of the PNG signature) even though prefix declares
// image/jpeg; the tests only rely on the decoded bytes round-tripping, so the
// media type presumably does not matter to the converter — confirm.
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
func TestFromChatRequest_Basic(t *testing.T) {
req := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{Role: "user", Content: "Hello"},
},
}
result, err := FromChatRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
t.Errorf("unexpected message: %+v", result.Messages[0])
}
}
// TestFromChatRequest_WithImage checks that a user message whose content mixes
// a text part and an image_url data-URL part is converted into two messages:
// the first carries the text, the second carries the decoded image bytes.
func TestFromChatRequest_WithImage(t *testing.T) {
	// Decode the fixture up front and fail loudly if it is malformed; otherwise
	// a broken fixture would only show up later as a misleading
	// "image data mismatch".
	imgData, err := base64.StdEncoding.DecodeString(image)
	if err != nil {
		t.Fatalf("decoding test image fixture: %v", err)
	}

	req := ChatCompletionRequest{
		Model: "test-model",
		Messages: []Message{
			{
				Role: "user",
				Content: []any{
					map[string]any{"type": "text", "text": "Hello"},
					map[string]any{
						"type":      "image_url",
						"image_url": map[string]any{"url": prefix + image},
					},
				},
			},
		},
	}

	result, err := FromChatRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(result.Messages))
	}
	if result.Messages[0].Content != "Hello" {
		t.Errorf("expected first message content 'Hello', got %q", result.Messages[0].Content)
	}
	if len(result.Messages[1].Images) != 1 {
		t.Fatalf("expected 1 image, got %d", len(result.Messages[1].Images))
	}
	if string(result.Messages[1].Images[0]) != string(imgData) {
		t.Error("image data mismatch")
	}
}
func TestFromCompleteRequest_Basic(t *testing.T) {
temp := float32(0.8)
req := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
}
result, err := FromCompleteRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if result.Prompt != "Hello" {
t.Errorf("expected prompt 'Hello', got %q", result.Prompt)
}
if tempVal, ok := result.Options["temperature"].(float32); !ok || tempVal != 0.8 {
t.Errorf("expected temperature 0.8, got %v", result.Options["temperature"])
}
}
// TestToUsage checks that prompt and completion token counts are copied from
// the response metrics and that the total is their sum.
func TestToUsage(t *testing.T) {
	var resp api.ChatResponse
	resp.Metrics.PromptEvalCount = 10
	resp.Metrics.EvalCount = 20

	u := ToUsage(resp)

	if u.PromptTokens != 10 {
		t.Errorf("expected PromptTokens 10, got %d", u.PromptTokens)
	}
	if u.CompletionTokens != 20 {
		t.Errorf("expected CompletionTokens 20, got %d", u.CompletionTokens)
	}
	if u.TotalTokens != 30 {
		t.Errorf("expected TotalTokens 30, got %d", u.TotalTokens)
	}
}
// TestNewError checks the HTTP-status-to-error-type mapping of NewError and
// that the caller's message is passed through verbatim.
func TestNewError(t *testing.T) {
	const msg = "test message"

	cases := []struct {
		status   int
		wantType string
	}{
		{status: 400, wantType: "invalid_request_error"},
		{status: 404, wantType: "not_found_error"},
		{status: 500, wantType: "api_error"},
	}

	for _, c := range cases {
		resp := NewError(c.status, msg)
		if got := resp.Error.Type; got != c.wantType {
			t.Errorf("NewError(%d) type = %q, want %q", c.status, got, c.wantType)
		}
		if got := resp.Error.Message; got != msg {
			t.Errorf("NewError(%d) message = %q, want %q", c.status, got, msg)
		}
	}
}
// TestToToolCallsPreservesIDs checks that ToToolCalls keeps each call's ID and
// function index, tags it as a "function" call, serializes the arguments to
// JSON, and leaves its input slice unmodified.
func TestToToolCallsPreservesIDs(t *testing.T) {
	input := []api.ToolCall{
		{
			ID: "call_abc123",
			Function: api.ToolCallFunction{
				Index:     2,
				Name:      "get_weather",
				Arguments: api.ToolCallFunctionArguments{"location": "Seattle"},
			},
		},
		{
			ID: "call_def456",
			Function: api.ToolCallFunction{
				Index:     7,
				Name:      "get_time",
				Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"},
			},
		},
	}

	// Hand ToToolCalls a separate slice so that any in-place change it makes
	// becomes visible as a difference from input.
	arg := append([]api.ToolCall(nil), input...)

	got := ToToolCalls(arg)
	if len(got) != len(input) {
		t.Fatalf("expected %d tool calls, got %d", len(input), len(got))
	}

	// newCall builds the expected OpenAI-shaped tool call.
	newCall := func(id string, index int, name, arguments string) ToolCall {
		var tc ToolCall
		tc.ID = id
		tc.Type = "function"
		tc.Index = index
		tc.Function.Name = name
		tc.Function.Arguments = arguments
		return tc
	}
	want := []ToolCall{
		newCall("call_abc123", 2, "get_weather", `{"location":"Seattle"}`),
		newCall("call_def456", 7, "get_time", `{"timezone":"UTC"}`),
	}

	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
	}
	if diff := cmp.Diff(input, arg); diff != "" {
		t.Errorf("input tool calls mutated (-want +got):\n%s", diff)
	}
}
func TestFromChatRequest_WithLogprobs(t *testing.T) {
trueVal := true
req := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{Role: "user", Content: "Hello"},
},
Logprobs: &trueVal,
TopLogprobs: 5,
}
result, err := FromChatRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !result.Logprobs {
t.Error("expected Logprobs to be true")
}
if result.TopLogprobs != 5 {
t.Errorf("expected TopLogprobs to be 5, got %d", result.TopLogprobs)
}
}
func TestFromChatRequest_LogprobsDefault(t *testing.T) {
req := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{Role: "user", Content: "Hello"},
},
}
result, err := FromChatRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Logprobs {
t.Error("expected Logprobs to be false by default")
}
if result.TopLogprobs != 0 {
t.Errorf("expected TopLogprobs to be 0 by default, got %d", result.TopLogprobs)
}
}
func TestFromCompleteRequest_WithLogprobs(t *testing.T) {
logprobsVal := 5
req := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Logprobs: &logprobsVal,
}
result, err := FromCompleteRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !result.Logprobs {
t.Error("expected Logprobs to be true")
}
if result.TopLogprobs != 5 {
t.Errorf("expected TopLogprobs to be 5, got %d", result.TopLogprobs)
}
}
// TestToChatCompletion_WithLogprobs checks that ToChatCompletion copies the
// response id, creation time and content into a single choice, and carries
// every per-token logprob, including each token's top alternatives, into that
// choice's Logprobs.
func TestToChatCompletion_WithLogprobs(t *testing.T) {
	createdAt := time.Unix(1234567890, 0)
	resp := api.ChatResponse{
		Model:     "test-model",
		CreatedAt: createdAt,
		Message:   api.Message{Role: "assistant", Content: "Hello there"},
		Logprobs: []api.Logprob{
			{
				TokenLogprob: api.TokenLogprob{
					Token:   "Hello",
					Logprob: -0.5,
				},
				TopLogprobs: []api.TokenLogprob{
					{Token: "Hello", Logprob: -0.5},
					{Token: "Hi", Logprob: -1.2},
				},
			},
			{
				TokenLogprob: api.TokenLogprob{
					Token:   " there",
					Logprob: -0.3,
				},
				TopLogprobs: []api.TokenLogprob{
					{Token: " there", Logprob: -0.3},
					{Token: " world", Logprob: -1.5},
				},
			},
		},
		Done: true,
		Metrics: api.Metrics{
			PromptEvalCount: 5,
			EvalCount:       2,
		},
	}

	id := "test-id"
	result := ToChatCompletion(id, resp)

	if result.Id != id {
		t.Errorf("expected Id %q, got %q", id, result.Id)
	}
	if result.Created != 1234567890 {
		t.Errorf("expected Created %d, got %d", int64(1234567890), result.Created)
	}
	if len(result.Choices) != 1 {
		t.Fatalf("expected 1 choice, got %d", len(result.Choices))
	}

	choice := result.Choices[0]
	if choice.Message.Content != "Hello there" {
		t.Errorf("expected content %q, got %q", "Hello there", choice.Message.Content)
	}
	if choice.Logprobs == nil {
		t.Fatal("expected Logprobs to be present")
	}
	if len(choice.Logprobs.Content) != len(resp.Logprobs) {
		t.Fatalf("expected %d logprobs, got %d", len(resp.Logprobs), len(choice.Logprobs.Content))
	}

	// Check every token and every top alternative rather than spot values, so
	// a mis-mapped entry anywhere in the output fails the test. The float64
	// conversions keep the comparison independent of the output field's
	// float width.
	for i, want := range resp.Logprobs {
		got := choice.Logprobs.Content[i]
		if got.Token != want.Token {
			t.Errorf("logprob[%d]: expected token %q, got %q", i, want.Token, got.Token)
		}
		if float64(got.Logprob) != want.Logprob {
			t.Errorf("logprob[%d]: expected logprob %f, got %f", i, want.Logprob, float64(got.Logprob))
		}
		if len(got.TopLogprobs) != len(want.TopLogprobs) {
			t.Errorf("logprob[%d]: expected %d top_logprobs, got %d", i, len(want.TopLogprobs), len(got.TopLogprobs))
			continue
		}
		for j, wantTop := range want.TopLogprobs {
			gotTop := got.TopLogprobs[j]
			if gotTop.Token != wantTop.Token {
				t.Errorf("logprob[%d].top_logprobs[%d]: expected token %q, got %q", i, j, wantTop.Token, gotTop.Token)
			}
			if float64(gotTop.Logprob) != wantTop.Logprob {
				t.Errorf("logprob[%d].top_logprobs[%d]: expected logprob %f, got %f", i, j, wantTop.Logprob, float64(gotTop.Logprob))
			}
		}
	}
}
// TestToChatCompletion_WithoutLogprobs checks that a response carrying no
// logprobs produces a choice whose Logprobs field is nil.
func TestToChatCompletion_WithoutLogprobs(t *testing.T) {
	resp := api.ChatResponse{
		Model:     "test-model",
		CreatedAt: time.Unix(1234567890, 0),
		Message:   api.Message{Role: "assistant", Content: "Hello"},
		Done:      true,
		Metrics:   api.Metrics{PromptEvalCount: 5, EvalCount: 1},
	}

	choices := ToChatCompletion("test-id", resp).Choices
	if len(choices) != 1 {
		t.Fatalf("expected 1 choice, got %d", len(choices))
	}

	// resp.Logprobs is empty, so the choice must not carry a Logprobs value
	// at all.
	if choices[0].Logprobs != nil {
		t.Error("expected Logprobs to be nil when not requested")
	}
}
// TestFromChatRequest_TopLogprobsRange checks that top_logprobs values across
// the supported 0–20 range are passed through to the Ollama request unchanged.
//
// NOTE(review): only in-range values are exercised. This test does not assert
// that out-of-range values are rejected; that validation presumably lives
// outside FromChatRequest — confirm before adding invalid cases here.
func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
	tests := []struct {
		name        string
		topLogprobs int
	}{
		{name: "valid: 0", topLogprobs: 0},
		{name: "valid: 1", topLogprobs: 1},
		{name: "valid: 10", topLogprobs: 10},
		{name: "valid: 20", topLogprobs: 20},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			trueVal := true
			req := ChatCompletionRequest{
				Model: "test-model",
				Messages: []Message{
					{Role: "user", Content: "Hello"},
				},
				Logprobs:    &trueVal,
				TopLogprobs: tt.topLogprobs,
			}

			result, err := FromChatRequest(req)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if result.TopLogprobs != tt.topLogprobs {
				t.Errorf("expected TopLogprobs %d, got %d", tt.topLogprobs, result.TopLogprobs)
			}
		})
	}
}