From df70249520fda991a83f607d485fbf4e64cfe1fd Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan
Date: Wed, 4 Feb 2026 01:21:31 -0800
Subject: [PATCH] server: optimize chatPrompt to reduce tokenization calls (#14040)

Change the truncation algorithm to start with all messages and remove
from the front until it fits, rather than adding messages one at a time
from the back. This reduces tokenization calls from O(n) to O(1) in the
common case where all messages fit in context.
---
 server/prompt.go      | 32 ++++++++++++---------
 server/prompt_test.go | 66 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 84 insertions(+), 14 deletions(-)

diff --git a/server/prompt.go b/server/prompt.go
index 217591982..bc12f4d5d 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -27,14 +27,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	// Clip images are represented as 768 tokens, each an embedding
 	imageNumTokens := 768
 
-	n := len(msgs) - 1
-	// in reverse, find all messages that fit into context window
-	for i := n; i >= 0; i-- {
-		// always include the last message
-		if i == n {
-			continue
-		}
+	lastMsgIdx := len(msgs) - 1
+	currMsgIdx := 0
 
+	// Start with all messages and remove from the front until it fits in context
+	for i := 0; i <= lastMsgIdx; i++ {
+		// Collect system messages from the portion we're about to skip
 		system = make([]api.Message, 0)
 		for j := range i {
 			if msgs[j].Role == "system" {
@@ -54,20 +52,26 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 
 		ctxLen := len(s)
 		if m.ProjectorPaths != nil {
-			for _, m := range msgs[i:] {
-				ctxLen += imageNumTokens * len(m.Images)
+			for _, msg := range msgs[i:] {
+				ctxLen += imageNumTokens * len(msg.Images)
 			}
 		}
 
-		if truncate && ctxLen > opts.NumCtx {
-			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
+		if !truncate || ctxLen <= opts.NumCtx {
+			currMsgIdx = i
+			break
+		}
+
+		// Must always include at least the last message
+		if i == lastMsgIdx {
+			currMsgIdx = lastMsgIdx
 			break
-		} else {
-			n = i
 		}
 	}
 
-	currMsgIdx := n
+	if currMsgIdx > 0 {
+		slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
+	}
 
 	for cnt, msg := range msgs[currMsgIdx:] {
 		if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
diff --git a/server/prompt_test.go b/server/prompt_test.go
index 3bd621152..082667b83 100644
--- a/server/prompt_test.go
+++ b/server/prompt_test.go
@@ -2,6 +2,7 @@ package server
 
 import (
 	"bytes"
+	"context"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
@@ -264,3 +265,68 @@ func TestChatPrompt(t *testing.T) {
 		})
 	}
 }
+
+func TestChatPromptTokenizeCalls(t *testing.T) {
+	tmpl, err := template.Parse(`
+{{- if .System }}{{ .System }} {{ end }}
+{{- if .Prompt }}{{ .Prompt }} {{ end }}
+{{- if .Response }}{{ .Response }} {{ end }}`)
+	if err != nil {
+		t.Fatal(err)
+	}
+	model := Model{Template: tmpl}
+
+	cases := []struct {
+		name         string
+		limit        int
+		msgs         []api.Message
+		maxTokenizes int
+	}{
+		{
+			name:  "all messages fit",
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "message 1"},
+				{Role: "assistant", Content: "response 1"},
+				{Role: "user", Content: "message 2"},
+				{Role: "assistant", Content: "response 2"},
+				{Role: "user", Content: "message 3"},
+			},
+			maxTokenizes: 1,
+		},
+		{
+			name:  "truncate to last message",
+			limit: 5,
+			msgs: []api.Message{
+				{Role: "user", Content: "message 1"},
+				{Role: "assistant", Content: "response 1"},
+				{Role: "user", Content: "message 2"},
+				{Role: "assistant", Content: "response 2"},
+				{Role: "user", Content: "message 3"},
+			},
+			maxTokenizes: 5,
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			tokenizeCount := 0
+			countingTokenize := func(ctx context.Context, s string) ([]int, error) {
+				tokenizeCount++
+				tokens, err := mockRunner{}.Tokenize(ctx, s)
+				return tokens, err
+			}
+
+			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
+			think := false
+			_, _, err := chatPrompt(t.Context(), &model, countingTokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, true)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			if tokenizeCount > tt.maxTokenizes {
+				t.Errorf("tokenize called %d times, expected at most %d", tokenizeCount, tt.maxTokenizes)
+			}
+		})
+	}
+}