Mirror of https://github.com/ollama/ollama.git (synced 2026-03-11 17:34:04 -05:00)
server: optimize chatPrompt to reduce tokenization calls (#14040)
Change the truncation algorithm to start with all messages and remove from the front until it fits, rather than adding messages one at a time from the back. This reduces tokenization calls from O(n) to O(1) in the common case where all messages fit in context.
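For intuition, here is a minimal, self-contained Go sketch of the two strategies. This is not the server code: it joins strings instead of rendering the chat template, and countTokens, truncateFromBack, and truncateFromFront are invented names for illustration.

package main

import (
	"fmt"
	"strings"
)

// tokenizeCalls tracks how often the stand-in tokenizer runs; the sketch
// only cares about call counts, not about real token IDs.
var tokenizeCalls int

func countTokens(s string) int {
	tokenizeCalls++
	return len(strings.Fields(s))
}

// truncateFromBack mimics the old strategy: grow the window one message at
// a time from the back, re-tokenizing the candidate prompt at every step.
// Even when every message fits, this tokenizes O(n) times.
func truncateFromBack(msgs []string, limit int) []string {
	start := len(msgs) - 1
	for i := len(msgs) - 2; i >= 0; i-- {
		if countTokens(strings.Join(msgs[i:], " ")) > limit {
			break
		}
		start = i
	}
	return msgs[start:]
}

// truncateFromFront mimics the new strategy: try the full window first,
// then drop messages from the front until the prompt fits. When everything
// fits, the tokenizer runs exactly once.
func truncateFromFront(msgs []string, limit int) []string {
	for i := range msgs {
		if countTokens(strings.Join(msgs[i:], " ")) <= limit {
			return msgs[i:]
		}
	}
	// must always keep at least the last message, even over the limit
	return msgs[len(msgs)-1:]
}

func main() {
	msgs := []string{"message 1", "response 1", "message 2", "response 2", "message 3"}

	tokenizeCalls = 0
	truncateFromBack(msgs, 100)
	fmt.Println("back-to-front, all fit:", tokenizeCalls, "tokenize calls") // 4

	tokenizeCalls = 0
	truncateFromFront(msgs, 100)
	fmt.Println("front-first, all fit:", tokenizeCalls, "tokenize calls") // 1
}

On a five-message history that fits in context, the back-to-front version tokenizes four times while the front-first version tokenizes once.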
server/prompt.go

@@ -27,14 +27,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	// Clip images are represented as 768 tokens, each an embedding
 	imageNumTokens := 768
 
-	n := len(msgs) - 1
-	// in reverse, find all messages that fit into context window
-	for i := n; i >= 0; i-- {
-		// always include the last message
-		if i == n {
-			continue
-		}
-
+	lastMsgIdx := len(msgs) - 1
+	currMsgIdx := 0
+
+	// Start with all messages and remove from the front until it fits in context
+	for i := 0; i <= lastMsgIdx; i++ {
+		// Collect system messages from the portion we're about to skip
 		system = make([]api.Message, 0)
 		for j := range i {
 			if msgs[j].Role == "system" {
@@ -54,20 +52,26 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 
 		ctxLen := len(s)
 		if m.ProjectorPaths != nil {
-			for _, m := range msgs[i:] {
-				ctxLen += imageNumTokens * len(m.Images)
+			for _, msg := range msgs[i:] {
+				ctxLen += imageNumTokens * len(msg.Images)
 			}
 		}
 
-		if truncate && ctxLen > opts.NumCtx {
-			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
+		if !truncate || ctxLen <= opts.NumCtx {
+			currMsgIdx = i
 			break
-		} else {
-			n = i
 		}
+
+		// Must always include at least the last message
+		if i == lastMsgIdx {
+			currMsgIdx = lastMsgIdx
+			break
+		}
 	}
 
-	currMsgIdx := n
+	if currMsgIdx > 0 {
+		slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
+	}
 
 	for cnt, msg := range msgs[currMsgIdx:] {
 		if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
server/prompt_test.go

@@ -2,6 +2,7 @@ package server
 
 import (
 	"bytes"
+	"context"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
@@ -264,3 +265,68 @@ func TestChatPrompt(t *testing.T) {
 		})
 	}
 }
+
+func TestChatPromptTokenizeCalls(t *testing.T) {
+	tmpl, err := template.Parse(`
+{{- if .System }}{{ .System }} {{ end }}
+{{- if .Prompt }}{{ .Prompt }} {{ end }}
+{{- if .Response }}{{ .Response }} {{ end }}`)
+	if err != nil {
+		t.Fatal(err)
+	}
+	model := Model{Template: tmpl}
+
+	cases := []struct {
+		name         string
+		limit        int
+		msgs         []api.Message
+		maxTokenizes int
+	}{
+		{
+			name:  "all messages fit",
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "message 1"},
+				{Role: "assistant", Content: "response 1"},
+				{Role: "user", Content: "message 2"},
+				{Role: "assistant", Content: "response 2"},
+				{Role: "user", Content: "message 3"},
+			},
+			maxTokenizes: 1,
+		},
+		{
+			name:  "truncate to last message",
+			limit: 5,
+			msgs: []api.Message{
+				{Role: "user", Content: "message 1"},
+				{Role: "assistant", Content: "response 1"},
+				{Role: "user", Content: "message 2"},
+				{Role: "assistant", Content: "response 2"},
+				{Role: "user", Content: "message 3"},
+			},
+			maxTokenizes: 5,
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			tokenizeCount := 0
+			countingTokenize := func(ctx context.Context, s string) ([]int, error) {
+				tokenizeCount++
+				tokens, err := mockRunner{}.Tokenize(ctx, s)
+				return tokens, err
+			}
+
+			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
+			think := false
+			_, _, err := chatPrompt(t.Context(), &model, countingTokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, true)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			if tokenizeCount > tt.maxTokenizes {
+				t.Errorf("tokenize called %d times, expected at most %d", tokenizeCount, tt.maxTokenizes)
+			}
+		})
+	}
+}
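The test wraps tokenize in a closure that counts its own invocations: "all messages fit" asserts a single call, while "truncate to last message" allows one call per candidate window, since in the worst case the new algorithm still tokenizes once for each message dropped from the front. Assuming the usual repository layout with the server package under ./server, the new test can be run on its own with standard Go tooling (t.Context() requires Go 1.24 or newer):

go test ./server -run TestChatPromptTokenizeCalls -v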