Adds logprobs support to Ollama's API, including Ollama's OpenAI-compatible API. When the new boolean parameter 'logprobs' is set, Ollama returns the log probability of each generated token. An integer parameter 'top_logprobs' (up to 20) can also be specified; when present, the API additionally returns that number of most likely tokens at each token position. Co-authored-by: Baptiste Jamin <baptiste@crisp.chat>
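For example, a client can exercise the new parameters through the OpenAI-compatible /v1/chat/completions endpoint. The sketch below is illustrative and not part of this file: it assumes a local Ollama server on the default port, and the model name is a placeholder.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// "logprobs" enables per-token log probabilities; "top_logprobs" asks for
	// up to 3 of the most likely alternatives at each token position.
	body, _ := json.Marshal(map[string]any{
		"model":        "llama3.2", // placeholder model name
		"messages":     []map[string]string{{"role": "user", "content": "Hello"}},
		"logprobs":     true,
		"top_logprobs": 3,
	})

	resp, err := http.Post("http://localhost:11434/v1/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out)) // each choice carries logprobs.content, one entry per generated token
}
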
// Package openai provides core transformation logic for partial compatibility with the OpenAI REST API.
package openai

import (
	"bytes"
	"encoding/base64"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"slices"
	"strings"
	"time"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/types/model"
)

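// finishReasonToolCalls is the finish_reason value reported when a response ends in tool calls.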
var finishReasonToolCalls = "tool_calls"

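// Error is the error body for OpenAI-compatible error responses.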
type Error struct {
	Message string  `json:"message"`
	Type    string  `json:"type"`
	Param   any     `json:"param"`
	Code    *string `json:"code"`
}

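// ErrorResponse wraps an Error in the envelope expected by OpenAI clients.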
type ErrorResponse struct {
	Error Error `json:"error"`
}

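// Message is a chat message in OpenAI wire format. Content is typed as any
// because it may be either a plain string or a list of typed content parts
// (text or image_url).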
type Message struct {
	Role       string     `json:"role"`
	Content    any        `json:"content"`
	Reasoning  string     `json:"reasoning,omitempty"`
	ToolCalls  []ToolCall `json:"tool_calls,omitempty"`
	Name       string     `json:"name,omitempty"`
	ToolCallID string     `json:"tool_call_id,omitempty"`
}

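// ChoiceLogprobs carries per-token log probabilities for a choice.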
type ChoiceLogprobs struct {
	Content []api.Logprob `json:"content"`
}

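// Choice is a single choice in a non-streaming chat completion response.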
type Choice struct {
	Index        int             `json:"index"`
	Message      Message         `json:"message"`
	FinishReason *string         `json:"finish_reason"`
	Logprobs     *ChoiceLogprobs `json:"logprobs,omitempty"`
}

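// ChunkChoice is a single delta choice in a streaming chat completion chunk.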
type ChunkChoice struct {
	Index        int             `json:"index"`
	Delta        Message         `json:"delta"`
	FinishReason *string         `json:"finish_reason"`
	Logprobs     *ChoiceLogprobs `json:"logprobs,omitempty"`
}

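// CompleteChunkChoice is a single choice in a text completion response or chunk.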
type CompleteChunkChoice struct {
	Text         string          `json:"text"`
	Index        int             `json:"index"`
	FinishReason *string         `json:"finish_reason"`
	Logprobs     *ChoiceLogprobs `json:"logprobs,omitempty"`
}

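// Usage reports prompt, completion, and total token counts.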
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

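// ResponseFormat selects the response format, either "json_object" or "json_schema".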
type ResponseFormat struct {
	Type       string      `json:"type"`
	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
}

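// JsonSchema carries the raw schema used with the "json_schema" response format.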
type JsonSchema struct {
	Schema json.RawMessage `json:"schema"`
}

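// EmbedRequest is an OpenAI-compatible embeddings request.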
type EmbedRequest struct {
	Input          any    `json:"input"`
	Model          string `json:"model"`
	Dimensions     int    `json:"dimensions,omitempty"`
	EncodingFormat string `json:"encoding_format,omitempty"` // "float" or "base64"
}

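// StreamOptions controls streaming responses, such as whether a final usage chunk is included.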
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}

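// Reasoning configures the reasoning (thinking) effort requested from the model.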
type Reasoning struct {
	Effort string `json:"effort,omitempty"`
}

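// ChatCompletionRequest is an OpenAI-compatible chat completion request.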
type ChatCompletionRequest struct {
	Model            string          `json:"model"`
	Messages         []Message       `json:"messages"`
	Stream           bool            `json:"stream"`
	StreamOptions    *StreamOptions  `json:"stream_options"`
	MaxTokens        *int            `json:"max_tokens"`
	Seed             *int            `json:"seed"`
	Stop             any             `json:"stop"`
	Temperature      *float64        `json:"temperature"`
	FrequencyPenalty *float64        `json:"frequency_penalty"`
	PresencePenalty  *float64        `json:"presence_penalty"`
	TopP             *float64        `json:"top_p"`
	ResponseFormat   *ResponseFormat `json:"response_format"`
	Tools            []api.Tool      `json:"tools"`
	Reasoning        *Reasoning      `json:"reasoning,omitempty"`
	ReasoningEffort  *string         `json:"reasoning_effort,omitempty"`
	Logprobs         *bool           `json:"logprobs"`
	TopLogprobs      int             `json:"top_logprobs"`
	DebugRenderOnly  bool            `json:"_debug_render_only"`
}

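// ChatCompletion is a non-streaming chat completion response.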
type ChatCompletion struct {
	Id                string         `json:"id"`
	Object            string         `json:"object"`
	Created           int64          `json:"created"`
	Model             string         `json:"model"`
	SystemFingerprint string         `json:"system_fingerprint"`
	Choices           []Choice       `json:"choices"`
	Usage             Usage          `json:"usage,omitempty"`
	DebugInfo         *api.DebugInfo `json:"_debug_info,omitempty"`
}

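// ChatCompletionChunk is a single chunk of a streaming chat completion response.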
type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
	Usage             *Usage        `json:"usage,omitempty"`
}

// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
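// CompletionRequest is an OpenAI-compatible text completion request.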
type CompletionRequest struct {
	Model            string         `json:"model"`
	Prompt           string         `json:"prompt"`
	FrequencyPenalty float32        `json:"frequency_penalty"`
	MaxTokens        *int           `json:"max_tokens"`
	PresencePenalty  float32        `json:"presence_penalty"`
	Seed             *int           `json:"seed"`
	Stop             any            `json:"stop"`
	Stream           bool           `json:"stream"`
	StreamOptions    *StreamOptions `json:"stream_options"`
	Temperature      *float32       `json:"temperature"`
	TopP             float32        `json:"top_p"`
	Suffix           string         `json:"suffix"`
	Logprobs         *int           `json:"logprobs"`
	DebugRenderOnly  bool           `json:"_debug_render_only"`
}

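// Completion is a non-streaming text completion response.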
type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Usage             Usage                 `json:"usage,omitempty"`
}

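// CompletionChunk is a single chunk of a streaming text completion response.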
type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Usage             *Usage                `json:"usage,omitempty"`
}

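// ToolCall is a tool call in OpenAI wire format; its arguments are encoded as a JSON string.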
type ToolCall struct {
	ID       string `json:"id"`
	Index    int    `json:"index"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

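// Model describes a single model in a model list response.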
type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

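// Embedding is a single embedding result.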
type Embedding struct {
	Object    string `json:"object"`
	Embedding any    `json:"embedding"` // Can be []float32 (float format) or string (base64 format)
	Index     int    `json:"index"`
}

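// ListCompletion is the response to a model list request.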
type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

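// EmbeddingList is the response to an embeddings request.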
type EmbeddingList struct {
	Object string         `json:"object"`
	Data   []Embedding    `json:"data"`
	Model  string         `json:"model"`
	Usage  EmbeddingUsage `json:"usage,omitempty"`
}

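// EmbeddingUsage reports token usage for an embeddings request.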
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

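// NewError builds an ErrorResponse with an OpenAI-style error type derived from the HTTP status code.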
func NewError(code int, message string) ErrorResponse {
	var etype string
	switch code {
	case http.StatusBadRequest:
		etype = "invalid_request_error"
	case http.StatusNotFound:
		etype = "not_found_error"
	default:
		etype = "api_error"
	}

	return ErrorResponse{Error{Type: etype, Message: message}}
}

// ToUsage converts an api.ChatResponse to Usage
func ToUsage(r api.ChatResponse) Usage {
	return Usage{
		PromptTokens:     r.Metrics.PromptEvalCount,
		CompletionTokens: r.Metrics.EvalCount,
		TotalTokens:      r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
	}
}

// ToToolCalls converts api.ToolCall to OpenAI ToolCall format
func ToToolCalls(tc []api.ToolCall) []ToolCall {
	toolCalls := make([]ToolCall, len(tc))
	for i, tc := range tc {
		toolCalls[i].ID = tc.ID
		toolCalls[i].Type = "function"
		toolCalls[i].Function.Name = tc.Function.Name
		toolCalls[i].Index = tc.Function.Index

		args, err := json.Marshal(tc.Function.Arguments)
		if err != nil {
			slog.Error("could not marshal function arguments to json", "error", err)
			continue
		}

		toolCalls[i].Function.Arguments = string(args)
	}
	return toolCalls
}

// ToChatCompletion converts an api.ChatResponse to ChatCompletion
func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
	toolCalls := ToToolCalls(r.Message.ToolCalls)

	var logprobs *ChoiceLogprobs
	if len(r.Logprobs) > 0 {
		logprobs = &ChoiceLogprobs{Content: r.Logprobs}
	}

	return ChatCompletion{
		Id:                id,
		Object:            "chat.completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []Choice{{
			Index:   0,
			Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
			FinishReason: func(reason string) *string {
				if len(toolCalls) > 0 {
					reason = "tool_calls"
				}
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
			Logprobs: logprobs,
		}},
		Usage:     ToUsage(r),
		DebugInfo: r.DebugInfo,
	}
}

// ToChunk converts an api.ChatResponse to ChatCompletionChunk
func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
	toolCalls := ToToolCalls(r.Message.ToolCalls)

	var logprobs *ChoiceLogprobs
	if len(r.Logprobs) > 0 {
		logprobs = &ChoiceLogprobs{Content: r.Logprobs}
	}

	return ChatCompletionChunk{
		Id:                id,
		Object:            "chat.completion.chunk",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []ChunkChoice{{
			Index: 0,
			Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					if toolCallSent || len(toolCalls) > 0 {
						return &finishReasonToolCalls
					}
					return &reason
				}
				return nil
			}(r.DoneReason),
			Logprobs: logprobs,
		}},
	}
}

// ToUsageGenerate converts an api.GenerateResponse to Usage
func ToUsageGenerate(r api.GenerateResponse) Usage {
	return Usage{
		PromptTokens:     r.Metrics.PromptEvalCount,
		CompletionTokens: r.Metrics.EvalCount,
		TotalTokens:      r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
	}
}

// ToCompletion converts an api.GenerateResponse to Completion
func ToCompletion(id string, r api.GenerateResponse) Completion {
	return Completion{
		Id:                id,
		Object:            "text_completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{{
			Text:  r.Response,
			Index: 0,
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
		Usage: ToUsageGenerate(r),
	}
}

// ToCompleteChunk converts an api.GenerateResponse to CompletionChunk
func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
	return CompletionChunk{
		Id:                id,
		Object:            "text_completion",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{{
			Text:  r.Response,
			Index: 0,
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
	}
}

// ToListCompletion converts an api.ListResponse to ListCompletion
func ToListCompletion(r api.ListResponse) ListCompletion {
	var data []Model
	for _, m := range r.Models {
		data = append(data, Model{
			Id:      m.Name,
			Object:  "model",
			Created: m.ModifiedAt.Unix(),
			OwnedBy: model.ParseName(m.Name).Namespace,
		})
	}

	return ListCompletion{
		Object: "list",
		Data:   data,
	}
}

// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
// encodingFormat can be "float", "base64", or empty (defaults to "float")
func ToEmbeddingList(model string, r api.EmbedResponse, encodingFormat string) EmbeddingList {
	if r.Embeddings != nil {
		var data []Embedding
		for i, e := range r.Embeddings {
			var embedding any
			if strings.EqualFold(encodingFormat, "base64") {
				embedding = floatsToBase64(e)
			} else {
				embedding = e
			}

			data = append(data, Embedding{
				Object:    "embedding",
				Embedding: embedding,
				Index:     i,
			})
		}

		return EmbeddingList{
			Object: "list",
			Data:   data,
			Model:  model,
			Usage: EmbeddingUsage{
				PromptTokens: r.PromptEvalCount,
				TotalTokens:  r.PromptEvalCount,
			},
		}
	}

	return EmbeddingList{}
}

// floatsToBase64 encodes a []float32 to a base64 string
func floatsToBase64(floats []float32) string {
	var buf bytes.Buffer
	binary.Write(&buf, binary.LittleEndian, floats)
	return base64.StdEncoding.EncodeToString(buf.Bytes())
}

// ToModel converts an api.ShowResponse to Model
func ToModel(r api.ShowResponse, m string) Model {
	return Model{
		Id:      m,
		Object:  "model",
		Created: r.ModifiedAt.Unix(),
		OwnedBy: model.ParseName(m).Namespace,
	}
}

// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest
func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	var messages []api.Message
	for _, msg := range r.Messages {
		toolName := ""
		if strings.ToLower(msg.Role) == "tool" {
			toolName = msg.Name
			if toolName == "" && msg.ToolCallID != "" {
				toolName = nameFromToolCallID(r.Messages, msg.ToolCallID)
			}
		}
		switch content := msg.Content.(type) {
		case string:
			toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
			if err != nil {
				return nil, err
			}
			messages = append(messages, api.Message{Role: msg.Role, Content: content, Thinking: msg.Reasoning, ToolCalls: toolCalls, ToolName: toolName, ToolCallID: msg.ToolCallID})
		case []any:
			for _, c := range content {
				data, ok := c.(map[string]any)
				if !ok {
					return nil, errors.New("invalid message format")
				}
				switch data["type"] {
				case "text":
					text, ok := data["text"].(string)
					if !ok {
						return nil, errors.New("invalid message format")
					}
					messages = append(messages, api.Message{Role: msg.Role, Content: text})
				case "image_url":
					var url string
					if urlMap, ok := data["image_url"].(map[string]any); ok {
						if url, ok = urlMap["url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					} else {
						if url, ok = data["image_url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					}

					types := []string{"jpeg", "jpg", "png", "webp"}
					valid := false
					// support blank mime type to match api/chat taking just unadorned base64
					if strings.HasPrefix(url, "data:;base64,") {
						url = strings.TrimPrefix(url, "data:;base64,")
						valid = true
					}
					for _, t := range types {
						prefix := "data:image/" + t + ";base64,"
						if strings.HasPrefix(url, prefix) {
							url = strings.TrimPrefix(url, prefix)
							valid = true
							break
						}
					}

					if !valid {
						return nil, errors.New("invalid image input")
					}

					img, err := base64.StdEncoding.DecodeString(url)
					if err != nil {
						return nil, errors.New("invalid message format")
					}

					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
				default:
					return nil, errors.New("invalid message format")
				}
			}
			// since we might have added multiple messages above, if we have tool
			// calls we'll add them to the last message
			if len(messages) > 0 && len(msg.ToolCalls) > 0 {
				toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
				if err != nil {
					return nil, err
				}
				messages[len(messages)-1].ToolCalls = toolCalls
				messages[len(messages)-1].ToolName = toolName
				messages[len(messages)-1].ToolCallID = msg.ToolCallID
				messages[len(messages)-1].Thinking = msg.Reasoning
			}
		default:
			// content is only optional if tool calls are present
			if msg.ToolCalls == nil {
				return nil, fmt.Errorf("invalid message content type: %T", content)
			}

			toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
			if err != nil {
				return nil, err
			}
			messages = append(messages, api.Message{Role: msg.Role, Thinking: msg.Reasoning, ToolCalls: toolCalls, ToolCallID: msg.ToolCallID})
		}
	}

	options := make(map[string]any)

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	if r.FrequencyPenalty != nil {
		options["frequency_penalty"] = *r.FrequencyPenalty
	}

	if r.PresencePenalty != nil {
		options["presence_penalty"] = *r.PresencePenalty
	}

	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}

	var format json.RawMessage
	if r.ResponseFormat != nil {
		switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
		// Support the old "json_object" type for OpenAI compatibility
		case "json_object":
			format = json.RawMessage(`"json"`)
		case "json_schema":
			if r.ResponseFormat.JsonSchema != nil {
				format = r.ResponseFormat.JsonSchema.Schema
			}
		}
	}

	var think *api.ThinkValue
	var effort string

	if r.Reasoning != nil {
		effort = r.Reasoning.Effort
	} else if r.ReasoningEffort != nil {
		effort = *r.ReasoningEffort
	}

	if effort != "" {
		if !slices.Contains([]string{"high", "medium", "low", "none"}, effort) {
			return nil, fmt.Errorf("invalid reasoning value: '%s' (must be \"high\", \"medium\", \"low\", or \"none\")", effort)
		}

		if effort == "none" {
			think = &api.ThinkValue{Value: false}
		} else {
			think = &api.ThinkValue{Value: effort}
		}
	}

	return &api.ChatRequest{
		Model:           r.Model,
		Messages:        messages,
		Format:          format,
		Options:         options,
		Stream:          &r.Stream,
		Tools:           r.Tools,
		Think:           think,
		Logprobs:        r.Logprobs != nil && *r.Logprobs,
		TopLogprobs:     r.TopLogprobs,
		DebugRenderOnly: r.DebugRenderOnly,
	}, nil
}

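// nameFromToolCallID resolves a tool name by searching earlier messages for a matching tool call ID.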
func nameFromToolCallID(messages []Message, toolCallID string) string {
	// iterate backwards to be more resilient to duplicate tool call IDs (this
	// follows "last one wins")
	for i := len(messages) - 1; i >= 0; i-- {
		msg := messages[i]
		for _, tc := range msg.ToolCalls {
			if tc.ID == toolCallID {
				return tc.Function.Name
			}
		}
	}
	return ""
}

// FromCompletionToolCall converts OpenAI ToolCall format to api.ToolCall
func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
	apiToolCalls := make([]api.ToolCall, len(toolCalls))
	for i, tc := range toolCalls {
		apiToolCalls[i].ID = tc.ID
		apiToolCalls[i].Function.Name = tc.Function.Name
		err := json.Unmarshal([]byte(tc.Function.Arguments), &apiToolCalls[i].Function.Arguments)
		if err != nil {
			return nil, errors.New("invalid tool call arguments")
		}
	}

	return apiToolCalls, nil
}

// FromCompleteRequest converts a CompletionRequest to api.GenerateRequest
func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
	options := make(map[string]any)

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			} else {
				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	options["frequency_penalty"] = r.FrequencyPenalty

	options["presence_penalty"] = r.PresencePenalty

	if r.TopP != 0.0 {
		options["top_p"] = r.TopP
	} else {
		options["top_p"] = 1.0
	}

	var logprobs bool
	var topLogprobs int
	if r.Logprobs != nil && *r.Logprobs > 0 {
		logprobs = true
		topLogprobs = *r.Logprobs
	}

	return api.GenerateRequest{
		Model:           r.Model,
		Prompt:          r.Prompt,
		Options:         options,
		Stream:          &r.Stream,
		Suffix:          r.Suffix,
		Logprobs:        logprobs,
		TopLogprobs:     topLogprobs,
		DebugRenderOnly: r.DebugRenderOnly,
	}, nil
}