Files
ollama-ollama/openai/openai.go
Baptiste Jamin 59241c5bee server: add logprobs and top_logprobs support to Ollama's API (#12899)
Adds logprobs support to Ollama's API including support for Ollama's
OpenAI-compatible API. By specifying the new 'logprobs' boolean parameter
in the API, Ollama will return the log probabilities for each token generated.
'top_logprobs', an integer value of up to 20, can also be specified.
When specified, the API will also return that number of most likely tokens
at each token position.

Co-authored-by: Baptiste Jamin <baptiste@crisp.chat>
2025-11-11 08:49:50 -08:00

727 lines
20 KiB
Go

// Package openai provides core transformation logic for partial compatibility with the OpenAI REST API.
package openai
import (
"bytes"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"slices"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
// finishReasonToolCalls is the finish_reason value reported when a streamed
// response involved tool calls. It is a variable rather than a constant so
// that ToChunk can return its address.
var finishReasonToolCalls = "tool_calls"
// Error is the error object carried inside an ErrorResponse.
// Param and Code have no omitempty tag, so they serialize as JSON null
// when left unset.
type Error struct {
Message string `json:"message"`
Type string `json:"type"`
Param any `json:"param"`
Code *string `json:"code"`
}
// ErrorResponse is the error envelope: a single Error serialized under the
// top-level "error" JSON key.
type ErrorResponse struct {
Error Error `json:"error"`
}
// Message is a single chat message, used both in requests and in responses.
// Content is typed as any because FromChatRequest accepts either a plain
// string or a list of content parts (decoded from JSON as []any).
type Message struct {
Role string `json:"role"`
Content any `json:"content"`
Reasoning string `json:"reasoning,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Name string `json:"name,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
// ChoiceLogprobs holds the api.Logprob entries attached to a choice,
// serialized under the "content" JSON key.
type ChoiceLogprobs struct {
Content []api.Logprob `json:"content"`
}
// Choice is one entry of ChatCompletion.Choices (non-streaming response).
// FinishReason is a pointer without omitempty, so it serializes as null when
// unset; Logprobs is omitted from the JSON when nil.
type Choice struct {
Index int `json:"index"`
Message Message `json:"message"`
FinishReason *string `json:"finish_reason"`
Logprobs *ChoiceLogprobs `json:"logprobs,omitempty"`
}
// ChunkChoice is one entry of ChatCompletionChunk.Choices (streaming
// response). It mirrors Choice but carries the message under "delta".
type ChunkChoice struct {
Index int `json:"index"`
Delta Message `json:"delta"`
FinishReason *string `json:"finish_reason"`
Logprobs *ChoiceLogprobs `json:"logprobs,omitempty"`
}
// CompleteChunkChoice is the choice entry used by both Completion and
// CompletionChunk; the generated text is carried in Text.
type CompleteChunkChoice struct {
Text string `json:"text"`
Index int `json:"index"`
FinishReason *string `json:"finish_reason"`
Logprobs *ChoiceLogprobs `json:"logprobs,omitempty"`
}
// Usage reports token counts for a request. ToUsage and ToUsageGenerate fill
// TotalTokens with the sum of the prompt and completion counts.
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
// ResponseFormat is the "response_format" request field. FromChatRequest
// recognizes the Type values "json_object" and "json_schema"; JsonSchema is
// only consulted for the latter.
type ResponseFormat struct {
Type string `json:"type"`
JsonSchema *JsonSchema `json:"json_schema,omitempty"`
}
// JsonSchema carries the schema for a "json_schema" response format. Schema
// is kept as raw JSON; FromChatRequest passes it through unmodified as the
// request format.
type JsonSchema struct {
Schema json.RawMessage `json:"schema"`
}
// EmbedRequest is the request body for the embeddings endpoint. Input is
// typed as any; its accepted shapes are decided by code outside this block.
type EmbedRequest struct {
Input any `json:"input"`
Model string `json:"model"`
Dimensions int `json:"dimensions,omitempty"`
EncodingFormat string `json:"encoding_format,omitempty"` // "float" or "base64"
}
// StreamOptions is the "stream_options" request field; IncludeUsage is
// decoded from its "include_usage" key.
type StreamOptions struct {
IncludeUsage bool `json:"include_usage"`
}
// Reasoning is the "reasoning" request field. FromChatRequest accepts the
// Effort values "high", "medium", "low" and "none"; an empty Effort is
// treated as not set.
type Reasoning struct {
Effort string `json:"effort,omitempty"`
}
// ChatCompletionRequest is the request body for the chat completions
// endpoint. Pointer-typed fields are nil when the client omitted them, which
// lets FromChatRequest distinguish "not provided" from a zero value.
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
StreamOptions *StreamOptions `json:"stream_options"`
MaxTokens *int `json:"max_tokens"`
Seed *int `json:"seed"`
// Stop is either a single string or a list of strings (decoded as []any).
Stop any `json:"stop"`
Temperature *float64 `json:"temperature"`
FrequencyPenalty *float64 `json:"frequency_penalty"`
PresencePenalty *float64 `json:"presence_penalty"`
TopP *float64 `json:"top_p"`
ResponseFormat *ResponseFormat `json:"response_format"`
Tools []api.Tool `json:"tools"`
// Reasoning takes precedence over ReasoningEffort in FromChatRequest.
Reasoning *Reasoning `json:"reasoning,omitempty"`
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
Logprobs *bool `json:"logprobs"`
TopLogprobs int `json:"top_logprobs"`
DebugRenderOnly bool `json:"_debug_render_only"`
}
// ChatCompletion is the non-streaming chat completions response.
// Usage is a struct value, and encoding/json's omitempty never omits struct
// values, so "usage" is always emitted. DebugInfo is omitted when nil.
type ChatCompletion struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []Choice `json:"choices"`
Usage Usage `json:"usage,omitempty"`
DebugInfo *api.DebugInfo `json:"_debug_info,omitempty"`
}
// ChatCompletionChunk is one streamed piece of a chat completions response.
// Usage is a pointer here, so it is omitted from the JSON when nil; ToChunk
// leaves it unset.
type ChatCompletionChunk struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []ChunkChoice `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
}
// CompletionRequest is the request body for the text completions endpoint.
// Unlike ChatCompletionRequest, several sampling fields here are plain
// values, so an omitted field is indistinguishable from an explicit zero.
// Logprobs is an integer count: FromCompleteRequest enables log
// probabilities only when it is greater than zero and reuses the value as
// the top-logprobs count.
//
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
type CompletionRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
FrequencyPenalty float32 `json:"frequency_penalty"`
MaxTokens *int `json:"max_tokens"`
PresencePenalty float32 `json:"presence_penalty"`
Seed *int `json:"seed"`
// Stop is either a single string or a list of strings (decoded as []any).
Stop any `json:"stop"`
Stream bool `json:"stream"`
StreamOptions *StreamOptions `json:"stream_options"`
Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"`
Suffix string `json:"suffix"`
Logprobs *int `json:"logprobs"`
DebugRenderOnly bool `json:"_debug_render_only"`
}
// Completion is the non-streaming text completions response.
// Usage is a struct value, and encoding/json's omitempty never omits struct
// values, so "usage" is always emitted.
type Completion struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []CompleteChunkChoice `json:"choices"`
Usage Usage `json:"usage,omitempty"`
}
// CompletionChunk is one streamed piece of a text completions response.
// Usage is a pointer, so it is omitted from the JSON when nil;
// ToCompleteChunk leaves it unset.
type CompletionChunk struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []CompleteChunkChoice `json:"choices"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Usage *Usage `json:"usage,omitempty"`
}
// ToolCall is a tool invocation in OpenAI wire format. Function.Arguments is
// a JSON-encoded string rather than a JSON object: ToToolCalls produces it
// with json.Marshal and FromCompletionToolCall parses it with json.Unmarshal.
type ToolCall struct {
ID string `json:"id"`
Index int `json:"index"`
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
}
// Model describes one model in OpenAI wire format. Created is a Unix
// timestamp in seconds, as produced by ToModel and ToListCompletion.
type Model struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
}
// Embedding is one entry of EmbeddingList.Data. Index is the position of
// the embedding within the response, as assigned by ToEmbeddingList.
type Embedding struct {
Object string `json:"object"`
Embedding any `json:"embedding"` // Can be []float32 (float format) or string (base64 format)
Index int `json:"index"`
}
// ListCompletion is the response body for listing models: an object tag
// plus the Model entries under "data".
type ListCompletion struct {
Object string `json:"object"`
Data []Model `json:"data"`
}
// EmbeddingList is the response body for the embeddings endpoint.
// Usage is a struct value, and encoding/json's omitempty never omits struct
// values, so "usage" is always emitted.
type EmbeddingList struct {
Object string `json:"object"`
Data []Embedding `json:"data"`
Model string `json:"model"`
Usage EmbeddingUsage `json:"usage,omitempty"`
}
// EmbeddingUsage reports token counts for an embeddings request.
// ToEmbeddingList sets both fields to the prompt evaluation count.
type EmbeddingUsage struct {
PromptTokens int `json:"prompt_tokens"`
TotalTokens int `json:"total_tokens"`
}
// NewError builds an ErrorResponse for the given HTTP status code and
// message. Status 400 maps to the type "invalid_request_error", 404 to
// "not_found_error", and every other status to "api_error". Param and Code
// are left unset.
func NewError(code int, message string) ErrorResponse {
	kind := "api_error"
	if code == http.StatusBadRequest {
		kind = "invalid_request_error"
	} else if code == http.StatusNotFound {
		kind = "not_found_error"
	}
	return ErrorResponse{Error: Error{Message: message, Type: kind}}
}
// ToUsage converts the metrics of an api.ChatResponse to Usage. The total is
// the sum of the prompt and completion token counts.
func ToUsage(r api.ChatResponse) Usage {
	prompt, completion := r.Metrics.PromptEvalCount, r.Metrics.EvalCount
	return Usage{
		PromptTokens:     prompt,
		CompletionTokens: completion,
		TotalTokens:      prompt + completion,
	}
}
// ToToolCalls converts api.ToolCall values to the OpenAI ToolCall format.
//
// The result always has the same length as tc and is never nil. Each call's
// arguments are JSON-encoded into a string; if encoding fails the error is
// logged and that entry keeps an empty Arguments string.
func ToToolCalls(tc []api.ToolCall) []ToolCall {
	converted := make([]ToolCall, len(tc))
	for i := range tc {
		src, dst := tc[i], &converted[i]

		dst.ID = src.ID
		dst.Type = "function"
		dst.Function.Name = src.Function.Name
		dst.Index = src.Function.Index

		encoded, err := json.Marshal(src.Function.Arguments)
		if err != nil {
			// Best effort: keep the call but leave its arguments empty.
			slog.Error("could not marshall function arguments to json", "error", err)
			continue
		}
		dst.Function.Arguments = string(encoded)
	}
	return converted
}
// ToChatCompletion converts an api.ChatResponse to a non-streaming
// ChatCompletion with a single choice at index 0.
//
// The finish reason is "tool_calls" whenever the response contains tool
// calls; otherwise it is the response's done reason, or nil when that is
// empty. Logprobs are attached only when the response carries any.
func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
	toolCalls := ToToolCalls(r.Message.ToolCalls)

	// Tool calls take priority over whatever done reason was reported.
	var finishReason *string
	switch {
	case len(toolCalls) > 0:
		reason := "tool_calls"
		finishReason = &reason
	case len(r.DoneReason) > 0:
		reason := r.DoneReason
		finishReason = &reason
	}

	choice := Choice{
		Index: 0,
		Message: Message{
			Role:      r.Message.Role,
			Content:   r.Message.Content,
			ToolCalls: toolCalls,
			Reasoning: r.Message.Thinking,
		},
		FinishReason: finishReason,
	}
	if len(r.Logprobs) > 0 {
		choice.Logprobs = &ChoiceLogprobs{Content: r.Logprobs}
	}

	return ChatCompletion{
		Id:                id,
		Object:            "chat.completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices:           []Choice{choice},
		Usage:             ToUsage(r),
		DebugInfo:         r.DebugInfo,
	}
}
// ToChunk converts an api.ChatResponse to a streaming ChatCompletionChunk
// with a single choice at index 0. The delta role is always "assistant" and
// Created is the current time.
//
// A finish reason is emitted only when the response has a done reason. In
// that case it is "tool_calls" if this chunk has tool calls or toolCallSent
// reports that an earlier chunk did; otherwise it is the done reason itself.
func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
	toolCalls := ToToolCalls(r.Message.ToolCalls)

	var finishReason *string
	if doneReason := r.DoneReason; len(doneReason) > 0 {
		if toolCallSent || len(toolCalls) > 0 {
			finishReason = &finishReasonToolCalls
		} else {
			finishReason = &doneReason
		}
	}

	choice := ChunkChoice{
		Index: 0,
		Delta: Message{
			Role:      "assistant",
			Content:   r.Message.Content,
			ToolCalls: toolCalls,
			Reasoning: r.Message.Thinking,
		},
		FinishReason: finishReason,
	}
	if len(r.Logprobs) > 0 {
		choice.Logprobs = &ChoiceLogprobs{Content: r.Logprobs}
	}

	return ChatCompletionChunk{
		Id:                id,
		Object:            "chat.completion.chunk",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices:           []ChunkChoice{choice},
	}
}
// ToUsageGenerate converts the metrics of an api.GenerateResponse to Usage.
// The total is the sum of the prompt and completion token counts.
func ToUsageGenerate(r api.GenerateResponse) Usage {
	prompt, completion := r.Metrics.PromptEvalCount, r.Metrics.EvalCount
	return Usage{
		PromptTokens:     prompt,
		CompletionTokens: completion,
		TotalTokens:      prompt + completion,
	}
}
// ToCompletion converts an api.GenerateResponse to a non-streaming
// Completion with a single choice at index 0. The finish reason is the
// response's done reason, or nil when that is empty.
func ToCompletion(id string, r api.GenerateResponse) Completion {
	var finishReason *string
	if doneReason := r.DoneReason; len(doneReason) > 0 {
		finishReason = &doneReason
	}

	choice := CompleteChunkChoice{
		Text:         r.Response,
		Index:        0,
		FinishReason: finishReason,
	}

	return Completion{
		Id:                id,
		Object:            "text_completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices:           []CompleteChunkChoice{choice},
		Usage:             ToUsageGenerate(r),
	}
}
// ToCompleteChunk converts an api.GenerateResponse to a streaming
// CompletionChunk with a single choice at index 0. Created is the current
// time, and the finish reason is the response's done reason, or nil when
// that is empty.
func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
	var finishReason *string
	if doneReason := r.DoneReason; len(doneReason) > 0 {
		finishReason = &doneReason
	}

	choice := CompleteChunkChoice{
		Text:         r.Response,
		Index:        0,
		FinishReason: finishReason,
	}

	return CompletionChunk{
		Id:                id,
		Object:            "text_completion",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices:           []CompleteChunkChoice{choice},
	}
}
// ToListCompletion converts an api.ListResponse to ListCompletion.
//
// Each model's ID is its name, its creation time is the model's modification
// time as a Unix timestamp, and its owner is the namespace parsed from the
// name.
//
// Data is always a non-nil slice so that an empty model list serializes as
// "data": [] rather than "data": null; clients that iterate over the list
// expect an array.
func ToListCompletion(r api.ListResponse) ListCompletion {
	data := make([]Model, 0, len(r.Models))
	for _, m := range r.Models {
		data = append(data, Model{
			Id:      m.Name,
			Object:  "model",
			Created: m.ModifiedAt.Unix(),
			OwnedBy: model.ParseName(m.Name).Namespace,
		})
	}

	return ListCompletion{
		Object: "list",
		Data:   data,
	}
}
// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList.
//
// encodingFormat can be "float", "base64", or empty (defaults to "float");
// the comparison with "base64" is case-insensitive and any other value is
// treated as "float". When the response has no embeddings slice at all, a
// zero-value EmbeddingList is returned.
func ToEmbeddingList(model string, r api.EmbedResponse, encodingFormat string) EmbeddingList {
	if r.Embeddings == nil {
		return EmbeddingList{}
	}

	// The format is the same for every vector, so decide once up front.
	asBase64 := strings.EqualFold(encodingFormat, "base64")

	var items []Embedding
	for idx, vec := range r.Embeddings {
		item := Embedding{Object: "embedding", Index: idx}
		if asBase64 {
			item.Embedding = floatsToBase64(vec)
		} else {
			item.Embedding = vec
		}
		items = append(items, item)
	}

	tokens := r.PromptEvalCount
	return EmbeddingList{
		Object: "list",
		Data:   items,
		Model:  model,
		Usage: EmbeddingUsage{
			PromptTokens: tokens,
			TotalTokens:  tokens,
		},
	}
}
// floatsToBase64 serializes floats as consecutive little-endian 32-bit
// values and returns the standard (padded) base64 encoding of those bytes.
// A nil or empty slice yields the empty string.
func floatsToBase64(floats []float32) string {
	raw := new(bytes.Buffer)
	// The write error is deliberately discarded: the destination is an
	// in-memory buffer and the data is a slice of fixed-size values.
	_ = binary.Write(raw, binary.LittleEndian, floats)
	return base64.StdEncoding.EncodeToString(raw.Bytes())
}
// ToModel converts an api.ShowResponse to Model. The ID is m as given, the
// creation time is the response's modification time as a Unix timestamp, and
// the owner is the namespace parsed from m.
func ToModel(r api.ShowResponse, m string) Model {
	created := r.ModifiedAt.Unix()
	owner := model.ParseName(m).Namespace

	return Model{
		Id:      m,
		Object:  "model",
		Created: created,
		OwnedBy: owner,
	}
}
// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest.
//
// Messages are translated as follows:
//   - string content becomes one api.Message;
//   - list content becomes one api.Message per part ("text" or "image_url"),
//     with any tool calls, tool name, tool call ID and reasoning attached to
//     the last message produced for that request message;
//   - missing content is accepted only when tool calls are present.
//
// For messages with role "tool", the tool name is taken from the message's
// name or, failing that, looked up from the tool call ID in the request's
// messages.
//
// Sampling options are copied into the options map; temperature and top_p
// default to 1.0 when not provided. An invalid message shape, image input,
// tool call argument payload, or reasoning effort results in an error.
func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	var messages []api.Message
	for _, msg := range r.Messages {
		toolName := ""
		if strings.EqualFold(msg.Role, "tool") {
			toolName = msg.Name
			if toolName == "" && msg.ToolCallID != "" {
				toolName = nameFromToolCallID(r.Messages, msg.ToolCallID)
			}
		}

		switch content := msg.Content.(type) {
		case string:
			toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
			if err != nil {
				return nil, err
			}
			messages = append(messages, api.Message{Role: msg.Role, Content: content, Thinking: msg.Reasoning, ToolCalls: toolCalls, ToolName: toolName, ToolCallID: msg.ToolCallID})
		case []any:
			// Remember where this request message's parts begin so tool calls
			// are never attached to a message produced by an earlier one.
			first := len(messages)

			for _, c := range content {
				data, ok := c.(map[string]any)
				if !ok {
					return nil, errors.New("invalid message format")
				}
				switch data["type"] {
				case "text":
					text, ok := data["text"].(string)
					if !ok {
						return nil, errors.New("invalid message format")
					}
					messages = append(messages, api.Message{Role: msg.Role, Content: text})
				case "image_url":
					// image_url is either {"url": "..."} or a bare string.
					var url string
					if urlMap, ok := data["image_url"].(map[string]any); ok {
						if url, ok = urlMap["url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					} else {
						if url, ok = data["image_url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					}

					types := []string{"jpeg", "jpg", "png", "webp"}
					valid := false
					// support blank mime type to match api/chat taking just unadorned base64
					if strings.HasPrefix(url, "data:;base64,") {
						url = strings.TrimPrefix(url, "data:;base64,")
						valid = true
					}
					for _, t := range types {
						prefix := "data:image/" + t + ";base64,"
						if strings.HasPrefix(url, prefix) {
							url = strings.TrimPrefix(url, prefix)
							valid = true
							break
						}
					}

					if !valid {
						return nil, errors.New("invalid image input")
					}

					img, err := base64.StdEncoding.DecodeString(url)
					if err != nil {
						return nil, errors.New("invalid message format")
					}

					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
				default:
					return nil, errors.New("invalid message format")
				}
			}

			// since we might have added multiple messages above, if we have tools
			// calls we'll add them to the last message
			if len(msg.ToolCalls) > 0 {
				toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
				if err != nil {
					return nil, err
				}
				// An empty parts list produced no message; add one to carry the
				// tool calls instead of overwriting the previous request
				// message's fields (or dropping the tool calls entirely).
				if len(messages) == first {
					messages = append(messages, api.Message{Role: msg.Role})
				}
				last := &messages[len(messages)-1]
				last.ToolCalls = toolCalls
				last.ToolName = toolName
				last.ToolCallID = msg.ToolCallID
				last.Thinking = msg.Reasoning
			}
		default:
			// content is only optional if tool calls are present
			if msg.ToolCalls == nil {
				return nil, fmt.Errorf("invalid message content type: %T", content)
			}

			toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
			if err != nil {
				return nil, err
			}
			messages = append(messages, api.Message{Role: msg.Role, Thinking: msg.Reasoning, ToolCalls: toolCalls, ToolCallID: msg.ToolCallID})
		}
	}

	options := make(map[string]any)

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		// Non-string entries are skipped rather than rejected.
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	if r.FrequencyPenalty != nil {
		options["frequency_penalty"] = *r.FrequencyPenalty
	}

	if r.PresencePenalty != nil {
		options["presence_penalty"] = *r.PresencePenalty
	}

	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}

	var format json.RawMessage
	if r.ResponseFormat != nil {
		switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
		// Support the old "json_object" type for OpenAI compatibility
		case "json_object":
			format = json.RawMessage(`"json"`)
		case "json_schema":
			if r.ResponseFormat.JsonSchema != nil {
				format = r.ResponseFormat.JsonSchema.Schema
			}
		}
	}

	// The "reasoning" object takes precedence over "reasoning_effort".
	var think *api.ThinkValue
	var effort string

	if r.Reasoning != nil {
		effort = r.Reasoning.Effort
	} else if r.ReasoningEffort != nil {
		effort = *r.ReasoningEffort
	}

	if effort != "" {
		if !slices.Contains([]string{"high", "medium", "low", "none"}, effort) {
			return nil, fmt.Errorf("invalid reasoning value: '%s' (must be \"high\", \"medium\", \"low\", or \"none\")", effort)
		}

		// "none" disables thinking; every other level is passed through.
		if effort == "none" {
			think = &api.ThinkValue{Value: false}
		} else {
			think = &api.ThinkValue{Value: effort}
		}
	}

	return &api.ChatRequest{
		Model:           r.Model,
		Messages:        messages,
		Format:          format,
		Options:         options,
		Stream:          &r.Stream,
		Tools:           r.Tools,
		Think:           think,
		Logprobs:        r.Logprobs != nil && *r.Logprobs,
		TopLogprobs:     r.TopLogprobs,
		DebugRenderOnly: r.DebugRenderOnly,
	}, nil
}
// nameFromToolCallID returns the function name of the tool call whose ID is
// toolCallID, or "" when no message contains such a call.
//
// Messages are scanned from last to first so that, if a tool call ID is
// duplicated across messages, the most recent message wins.
func nameFromToolCallID(messages []Message, toolCallID string) string {
	for idx := len(messages); idx > 0; idx-- {
		calls := messages[idx-1].ToolCalls
		for j := range calls {
			if calls[j].ID == toolCallID {
				return calls[j].Function.Name
			}
		}
	}
	return ""
}
// FromCompletionToolCall converts OpenAI-format tool calls to api.ToolCall.
//
// The ID and function name are copied, and the JSON-encoded arguments string
// is decoded into the api arguments value. If any call's arguments fail to
// decode, the whole conversion fails with "invalid tool call arguments".
// The returned slice is never nil on success.
func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
	converted := make([]api.ToolCall, len(toolCalls))
	for i := range toolCalls {
		src, dst := toolCalls[i], &converted[i]

		dst.ID = src.ID
		dst.Function.Name = src.Function.Name
		if err := json.Unmarshal([]byte(src.Function.Arguments), &dst.Function.Arguments); err != nil {
			return nil, errors.New("invalid tool call arguments")
		}
	}
	return converted, nil
}
// FromCompleteRequest converts a CompletionRequest to api.GenerateRequest.
//
// Sampling options are copied into the options map: the frequency and
// presence penalties are always set, temperature defaults to 1.0 when not
// provided, and top_p defaults to 1.0 when it is zero. A stop list that
// contains a non-string entry is rejected with an error. Log probabilities
// are enabled only when the request's logprobs value is greater than zero,
// and that value is also used as the top-logprobs count.
func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
	options := map[string]any{
		"frequency_penalty": r.FrequencyPenalty,
		"presence_penalty":  r.PresencePenalty,
	}

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, entry := range stop {
			str, ok := entry.(string)
			if !ok {
				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", entry)
			}
			stops = append(stops, str)
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	// Start from the defaults and override them with caller-supplied values.
	options["temperature"] = 1.0
	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	}

	options["top_p"] = 1.0
	if r.TopP != 0.0 {
		options["top_p"] = r.TopP
	}

	req := api.GenerateRequest{
		Model:           r.Model,
		Prompt:          r.Prompt,
		Options:         options,
		Stream:          &r.Stream,
		Suffix:          r.Suffix,
		DebugRenderOnly: r.DebugRenderOnly,
	}
	if r.Logprobs != nil && *r.Logprobs > 0 {
		req.Logprobs = true
		req.TopLogprobs = *r.Logprobs
	}
	return req, nil
}