mlxrunner: Simplify pipeline memory and cache management

Particularly in error cases, it can be difficult to ensure that
all pinned memory is unpinned, MLX buffers are released and cache
state is consistent. This encapsulates those pieces and sets up
proper deferrals so that this happens automatically on exit.
This commit is contained in:
Jesse Gross
2026-02-24 14:19:12 -08:00
parent 7f9efd53df
commit 4e57d2094e
3 changed files with 103 additions and 66 deletions

View File

@@ -9,59 +9,99 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
// CacheEntry stores a single sequence
type CacheEntry struct {
Tokens []int32
Caches []cache.Cache
type kvCache struct {
// For now we only support a single entry, so this is just one sequence
tokens []int32
caches []cache.Cache
}
// FindNearestCache finds the longest common prefix between tokens and the cached sequence
func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
if r.cache == nil {
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
// cacheSession manages caches for a single pipeline run.
// Callers should append generated tokens to outputs and
// defer close to save the cache state.
type cacheSession struct {
cache *kvCache
inputs []int32
outputs []int32
caches []cache.Cache
remaining []int32
}
// begin prepares caches for a new request. It finds the nearest
// matching cache or creates new caches if none match.
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
if len(c.caches) == 0 {
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
c.caches = cacheFactory.NewCaches()
} else {
c.caches = make([]cache.Cache, m.NumLayers())
for i := range c.caches {
c.caches[i] = cache.NewKVCache()
}
}
}
// Find longest common prefix
remaining := c.findRemaining(inputs)
return &cacheSession{
cache: c,
inputs: inputs,
caches: c.caches,
remaining: remaining,
}
}
// close saves the token state if the forward pass ran.
func (s *cacheSession) close() {
if offset := s.caches[0].Offset(); offset > 0 {
// Ensure that if we have run the forward pass and set the metadata
// that we also actually have the data
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
for _, c := range s.caches {
k, v := c.State()
arrays = append(arrays, k, v)
}
mlx.AsyncEval(arrays...)
s.cache.tokens = append(s.inputs, s.outputs...)[:offset]
}
}
// findRemaining finds the longest common prefix between tokens and the cached
// sequence, trims stale cache entries, and returns the remaining tokens.
func (c *kvCache) findRemaining(tokens []int32) []int32 {
prefix := 0
for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] {
prefix++
}
switch {
case prefix == 0:
for _, c := range r.cache.Caches {
c.Free()
if prefix < len(c.tokens) {
trim := len(c.tokens) - prefix
for _, kv := range c.caches {
kv.Trim(trim)
}
r.cache = nil
c.tokens = c.tokens[:prefix]
}
if prefix == 0 {
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
case prefix < len(r.cache.Tokens):
trim := len(r.cache.Tokens) - prefix
for _, c := range r.cache.Caches {
c.Trim(trim)
}
r.cache.Tokens = r.cache.Tokens[:prefix]
} else {
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
}
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
return r.cache.Caches, tokens[prefix:]
return tokens[prefix:]
}
func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
r.cache = &CacheEntry{
Tokens: tokens,
Caches: caches,
func (c *kvCache) log() {
if len(c.caches) == 0 {
return
}
}
func (c *CacheEntry) LogCache() {
var totalBytes int
for _, kv := range c.Caches {
for _, kv := range c.caches {
k, v := kv.State()
totalBytes += k.NumBytes() + v.NumBytes()
}
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
}

View File

@@ -10,7 +10,6 @@ import (
"time"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
@@ -19,6 +18,23 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
return errors.New("model not loaded")
}
var (
sample, logprobs *mlx.Array
nextSample, nextLogprobs *mlx.Array
)
defer func() {
mlx.Unpin(sample, logprobs)
mlx.Unpin(nextSample, nextLogprobs)
mlx.Sweep()
mlx.ClearCache()
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
mlx.LogArrays()
r.cache.log()
}
}()
enableCompile := true
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
enableCompile = modelCompile.EnableCompile()
@@ -30,18 +46,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
inputs := r.Tokenizer.Encode(request.Prompt, true)
session := r.cache.begin(r.Model, inputs)
defer session.close()
caches, tokens := r.FindNearestCache(inputs)
if len(caches) == 0 {
if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
caches = cacheFactory.NewCaches()
} else {
caches = make([]cache.Cache, r.Model.NumLayers())
for i := range caches {
caches[i] = cache.NewKVCache()
}
}
}
caches := session.caches
tokens := session.remaining
total, processed := len(tokens), 0
slog.Info("Prompt processing progress", "processed", processed, "total", total)
@@ -76,15 +85,14 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
return sample, logprobs
}
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))
var b bytes.Buffer
now := time.Now()
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
outputs := make([]int32, 0, request.Options.MaxTokens)
for i := range request.Options.MaxTokens {
nextSample, nextLogprobs := step(sample)
nextSample, nextLogprobs = step(sample)
if i == 0 {
slog.Info("Prompt processing progress", "processed", total, "total", total)
@@ -94,10 +102,9 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
output := int32(sample.Int())
outputs = append(outputs, output)
session.outputs = append(session.outputs, output)
if r.Tokenizer.IsEOS(output) {
mlx.Unpin(nextSample, nextLogprobs)
final.Token = int(output)
final.DoneReason = 0
final.CompletionTokens = i
@@ -110,26 +117,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
mlx.Unpin(sample, logprobs)
sample, logprobs = nextSample, nextLogprobs
nextSample, nextLogprobs = nil, nil
if i%256 == 0 {
mlx.ClearCache()
}
sample, logprobs = nextSample, nextLogprobs
}
mlx.Unpin(sample, logprobs)
final.CompletionTokensDuration = time.Since(now)
request.Responses <- final
r.InsertCache(append(inputs, outputs...), caches)
mlx.Sweep()
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
mlx.LogArrays()
if r.cache != nil {
r.cache.LogCache()
}
}
return nil
}

View File

@@ -61,7 +61,7 @@ type Runner struct {
Model base.Model
Tokenizer *tokenizer.Tokenizer
Requests chan Request
cache *CacheEntry
cache kvCache
}
func (r *Runner) Load(modelName string) error {