From 5c73c4e2eedeeb847743f43b4d2e85ff4b02be4b Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Thu, 19 Feb 2026 16:50:18 -0800 Subject: [PATCH] mlxrunner: Simplify KV cache to single-entry prefix matching The KV cache previously used a tree structure which could store multiple divergent sequences, which is good for cache reuse. However, this is typically used in conjunction with paged attention so each node in the tree can store just a chunk of the KV cache and they can be stitched together later. We don't currently do this, so the cache was storing copies of the full cache for each past sequence. This redundancy plus the lack of resource limits, caused significant memory use as a conversation grew. Instead, this changes to store a single entry for the cache, which can be prefix matched. Although it is less ideal for multiple users, it largely matches Ollama's current behavior. It can be improved as additional pieces are fleshed out. --- x/mlxrunner/cache.go | 121 ++++++++++++++----------------------- x/mlxrunner/cache/cache.go | 6 ++ x/mlxrunner/pipeline.go | 3 + x/mlxrunner/runner.go | 8 +-- x/mlxrunner/server.go | 3 +- 5 files changed, 60 insertions(+), 81 deletions(-) diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go index 49ddd04b6..750d556b4 100644 --- a/x/mlxrunner/cache.go +++ b/x/mlxrunner/cache.go @@ -3,94 +3,65 @@ package mlxrunner import ( + "fmt" "log/slog" + "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/x/mlxrunner/cache" + "github.com/ollama/ollama/x/mlxrunner/mlx" ) +// CacheEntry stores a single sequence type CacheEntry struct { - Caches []cache.Cache - Count int - Entries map[int32]*CacheEntry + Tokens []int32 + Caches []cache.Cache } -func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) { - current := &CacheEntry{Entries: s.CacheEntries} - index, cacheIndex := 0, -1 - for _, token := range tokens { - if _, ok := current.Entries[token]; !ok { - break - } - - current = current.Entries[token] - if len(current.Caches) > 0 { - cacheIndex = index - } - - index += 1 +// FindNearestCache finds the longest common prefix between tokens and the cached sequence +func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) { + if r.cache == nil { + slog.Info("Cache miss", "left", len(tokens)) + return nil, tokens } - if cacheIndex == len(tokens)-1 { - slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", len(tokens)) - return current.Caches, []int32{} - } else if cacheIndex > 1 { - slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:])) - return current.Caches, tokens[cacheIndex+1:] - } else if index > 0 && cacheIndex < 0 { - type stackItem struct { - entry *CacheEntry - tokens []int32 - } - - var best, item stackItem - stack := []stackItem{{entry: current, tokens: []int32{}}} - for len(stack) > 0 { - item, stack = stack[len(stack)-1], stack[:len(stack)-1] - if len(item.entry.Caches) > 0 { - if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) { - best = item - } - } else { - for token, entry := range item.entry.Entries { - stack = append(stack, stackItem{ - entry: entry, - tokens: append(item.tokens, token), - }) - } - } - } - - prefix := min(len(tokens)-1, index) - caches := make([]cache.Cache, len(best.entry.Caches)) - trim := len(best.tokens)+1 - for i := range caches { - caches[i] = best.entry.Caches[i].Clone() - caches[i].Trim(trim) - } - - slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim) - return caches, tokens[prefix:] + // Find longest common prefix + prefix := 0 + for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] { + prefix++ } - slog.Info("Cache miss", "left", len(tokens)) - return nil, tokens + switch { + case prefix == 0: + for _, c := range r.cache.Caches { + c.Free() + } + r.cache = nil + slog.Info("Cache miss", "left", len(tokens)) + return nil, tokens + case prefix < len(r.cache.Tokens): + trim := len(r.cache.Tokens) - prefix + for _, c := range r.cache.Caches { + c.Trim(trim) + } + r.cache.Tokens = r.cache.Tokens[:prefix] + } + + slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:])) + return r.cache.Caches, tokens[prefix:] } -func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) { - current := &CacheEntry{Entries: s.CacheEntries} - for _, token := range tokens { - if _, ok := current.Entries[token]; !ok { - current.Entries[token] = &CacheEntry{ - Entries: make(map[int32]*CacheEntry), - } - } - - current = current.Entries[token] - } - - if len(current.Caches) > 0 { - current.Count += 1 - } else { - current.Caches = caches +func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) { + r.cache = &CacheEntry{ + Tokens: tokens, + Caches: caches, } } + +func (c *CacheEntry) LogCache() { + var totalBytes int + for _, kv := range c.Caches { + k, v := kv.State() + totalBytes += k.NumBytes() + v.NumBytes() + } + logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes))) +} diff --git a/x/mlxrunner/cache/cache.go b/x/mlxrunner/cache/cache.go index 3196b9e2a..274bdffe1 100644 --- a/x/mlxrunner/cache/cache.go +++ b/x/mlxrunner/cache/cache.go @@ -13,6 +13,7 @@ type Cache interface { State() (keys, values *mlx.Array) Trim(int) int Clone() Cache + Free() Offset() int Len() int } @@ -84,6 +85,11 @@ func (c *KVCache) Clone() Cache { return clone } +func (c *KVCache) Free() { + mlx.Unpin(c.keys, c.values) + c.keys, c.values = nil, nil +} + func (c *KVCache) Offset() int { return c.offset } func (c *KVCache) Len() int { return c.offset } diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index 618d7ec9e..e16a6c9a6 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -125,6 +125,9 @@ func (r *Runner) TextGenerationPipeline(request Request) error { if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) { mlx.LogArrays() + if r.cache != nil { + r.cache.LogCache() + } } return nil diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index 0b24fdb3d..effaf0847 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -58,10 +58,10 @@ type Response struct { } type Runner struct { - Model base.Model - Tokenizer *tokenizer.Tokenizer - Requests chan Request - CacheEntries map[int32]*CacheEntry + Model base.Model + Tokenizer *tokenizer.Tokenizer + Requests chan Request + cache *CacheEntry } func (r *Runner) Load(modelName string) error { diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go index ef1e0dd1c..09b71f3c8 100644 --- a/x/mlxrunner/server.go +++ b/x/mlxrunner/server.go @@ -40,8 +40,7 @@ func Execute(args []string) error { flagSet.Parse(args) runner := Runner{ - Requests: make(chan Request), - CacheEntries: make(map[int32]*CacheEntry), + Requests: make(chan Request), } if err := runner.Load(modelName); err != nil {