mlxrunner: Simplify KV cache to single-entry prefix matching

The KV cache previously used a tree structure which could
store multiple divergent sequences, which is good for cache
reuse. However, this is typically used in conjunction with
paged attention so each node in the tree can store just a
chunk of the KV cache and they can be stitched together later.
We don't currently do this, so the cache was storing copies of
the full cache for each past sequence.

This redundancy plus the lack of resource limits, caused significant
memory use as a conversation grew. Instead, this changes to store
a single entry for the cache, which can be prefix matched. Although
it is less ideal for multiple users, it largely matches Ollama's
current behavior. It can be improved as additional pieces are fleshed
out.
This commit is contained in:
Jesse Gross
2026-02-19 16:50:18 -08:00
parent 5daf59cc66
commit 5c73c4e2ee
5 changed files with 60 additions and 81 deletions

View File

@@ -3,94 +3,65 @@
package mlxrunner
import (
"fmt"
"log/slog"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// CacheEntry stores a single sequence
type CacheEntry struct {
Caches []cache.Cache
Count int
Entries map[int32]*CacheEntry
Tokens []int32
Caches []cache.Cache
}
func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
current := &CacheEntry{Entries: s.CacheEntries}
index, cacheIndex := 0, -1
for _, token := range tokens {
if _, ok := current.Entries[token]; !ok {
break
}
current = current.Entries[token]
if len(current.Caches) > 0 {
cacheIndex = index
}
index += 1
// FindNearestCache finds the longest common prefix between tokens and the cached sequence
func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
if r.cache == nil {
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
}
if cacheIndex == len(tokens)-1 {
slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", len(tokens))
return current.Caches, []int32{}
} else if cacheIndex > 1 {
slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:]))
return current.Caches, tokens[cacheIndex+1:]
} else if index > 0 && cacheIndex < 0 {
type stackItem struct {
entry *CacheEntry
tokens []int32
}
var best, item stackItem
stack := []stackItem{{entry: current, tokens: []int32{}}}
for len(stack) > 0 {
item, stack = stack[len(stack)-1], stack[:len(stack)-1]
if len(item.entry.Caches) > 0 {
if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) {
best = item
}
} else {
for token, entry := range item.entry.Entries {
stack = append(stack, stackItem{
entry: entry,
tokens: append(item.tokens, token),
})
}
}
}
prefix := min(len(tokens)-1, index)
caches := make([]cache.Cache, len(best.entry.Caches))
trim := len(best.tokens)+1
for i := range caches {
caches[i] = best.entry.Caches[i].Clone()
caches[i].Trim(trim)
}
slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim)
return caches, tokens[prefix:]
// Find longest common prefix
prefix := 0
for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
prefix++
}
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
switch {
case prefix == 0:
for _, c := range r.cache.Caches {
c.Free()
}
r.cache = nil
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
case prefix < len(r.cache.Tokens):
trim := len(r.cache.Tokens) - prefix
for _, c := range r.cache.Caches {
c.Trim(trim)
}
r.cache.Tokens = r.cache.Tokens[:prefix]
}
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
return r.cache.Caches, tokens[prefix:]
}
func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
current := &CacheEntry{Entries: s.CacheEntries}
for _, token := range tokens {
if _, ok := current.Entries[token]; !ok {
current.Entries[token] = &CacheEntry{
Entries: make(map[int32]*CacheEntry),
}
}
current = current.Entries[token]
}
if len(current.Caches) > 0 {
current.Count += 1
} else {
current.Caches = caches
func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
r.cache = &CacheEntry{
Tokens: tokens,
Caches: caches,
}
}
func (c *CacheEntry) LogCache() {
var totalBytes int
for _, kv := range c.Caches {
k, v := kv.State()
totalBytes += k.NumBytes() + v.NumBytes()
}
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
}

View File

@@ -13,6 +13,7 @@ type Cache interface {
State() (keys, values *mlx.Array)
Trim(int) int
Clone() Cache
Free()
Offset() int
Len() int
}
@@ -84,6 +85,11 @@ func (c *KVCache) Clone() Cache {
return clone
}
func (c *KVCache) Free() {
mlx.Unpin(c.keys, c.values)
c.keys, c.values = nil, nil
}
func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int { return c.offset }

View File

@@ -125,6 +125,9 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
mlx.LogArrays()
if r.cache != nil {
r.cache.LogCache()
}
}
return nil

View File

@@ -58,10 +58,10 @@ type Response struct {
}
type Runner struct {
Model base.Model
Tokenizer *tokenizer.Tokenizer
Requests chan Request
CacheEntries map[int32]*CacheEntry
Model base.Model
Tokenizer *tokenizer.Tokenizer
Requests chan Request
cache *CacheEntry
}
func (r *Runner) Load(modelName string) error {

View File

@@ -40,8 +40,7 @@ func Execute(args []string) error {
flagSet.Parse(args)
runner := Runner{
Requests: make(chan Request),
CacheEntries: make(map[int32]*CacheEntry),
Requests: make(chan Request),
}
if err := runner.Load(modelName); err != nil {