mlxrunner: Simplify pipeline memory and cache management
Particularly in error cases, it can be difficult to ensure that all pinned memory is unpinned, MLX buffers are released, and cache state is left consistent. This change encapsulates those pieces and sets up proper deferrals so that cleanup happens automatically on exit.
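The fix is the standard Go acquire/defer-release idiom: begin returns a session object whose close method runs on every exit path, including early error returns. A minimal self-contained sketch of the pattern (the names and types below are illustrative stand-ins, not the actual mlxrunner API):

package main

import (
	"errors"
	"fmt"
)

// session stands in for the per-request state that must always be
// cleaned up: pinned buffers, cache bookkeeping, and so on.
type session struct {
	pinned []string
}

func begin() *session { return &session{} }

// close releases everything the session acquired. Because callers
// defer it immediately after begin, cleanup runs even when the
// pipeline bails out early with an error.
func (s *session) close() {
	for _, p := range s.pinned {
		fmt.Println("unpin:", p)
	}
	s.pinned = nil
}

func pipeline(fail bool) error {
	s := begin()
	defer s.close() // guaranteed on every return path below

	s.pinned = append(s.pinned, "sample", "logprobs")
	if fail {
		return errors.New("forward pass failed") // close still runs
	}
	return nil
}

func main() {
	fmt.Println(pipeline(true))  // cleanup happens despite the error
	fmt.Println(pipeline(false)) // and on the happy path too
}

The diff below gives cacheSession exactly this shape: the pipeline calls r.cache.begin(...) and immediately defers session.close().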
@@ -9,59 +9,99 @@ import (
 	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/x/mlxrunner/cache"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
+	"github.com/ollama/ollama/x/mlxrunner/model/base"
 )
 
-// CacheEntry stores a single sequence
-type CacheEntry struct {
-	Tokens []int32
-	Caches []cache.Cache
+type kvCache struct {
+	// For now we only support a single entry, so this is just one sequence
+	tokens []int32
+	caches []cache.Cache
 }
 
-// FindNearestCache finds the longest common prefix between tokens and the cached sequence
-func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
-	if r.cache == nil {
-		slog.Info("Cache miss", "left", len(tokens))
-		return nil, tokens
-	}
-
+// cacheSession manages caches for a single pipeline run.
+// Callers should append generated tokens to outputs and
+// defer close to save the cache state.
+type cacheSession struct {
+	cache   *kvCache
+	inputs  []int32
+	outputs []int32
+
+	caches    []cache.Cache
+	remaining []int32
+}
+
+// begin prepares caches for a new request. It finds the nearest
+// matching cache or creates new caches if none match.
+func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
+	if len(c.caches) == 0 {
+		if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
+			c.caches = cacheFactory.NewCaches()
+		} else {
+			c.caches = make([]cache.Cache, m.NumLayers())
+			for i := range c.caches {
+				c.caches[i] = cache.NewKVCache()
+			}
+		}
+	}
+
+	// Find longest common prefix
+	remaining := c.findRemaining(inputs)
+
+	return &cacheSession{
+		cache:     c,
+		inputs:    inputs,
+		caches:    c.caches,
+		remaining: remaining,
+	}
+}
+
+// close saves the token state if the forward pass ran.
+func (s *cacheSession) close() {
+	if offset := s.caches[0].Offset(); offset > 0 {
+		// Ensure that if we have run the forward pass and set the metadata
+		// that we also actually have the data
+		arrays := make([]*mlx.Array, 0, 2*len(s.caches))
+		for _, c := range s.caches {
+			k, v := c.State()
+			arrays = append(arrays, k, v)
+		}
+		mlx.AsyncEval(arrays...)
+
+		s.cache.tokens = append(s.inputs, s.outputs...)[:offset]
+	}
+}
+
+// findRemaining finds the longest common prefix between tokens and the cached
+// sequence, trims stale cache entries, and returns the remaining tokens.
+func (c *kvCache) findRemaining(tokens []int32) []int32 {
 	prefix := 0
-	for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
+	for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] {
 		prefix++
 	}
 
-	switch {
-	case prefix == 0:
-		for _, c := range r.cache.Caches {
-			c.Free()
+	if prefix < len(c.tokens) {
+		trim := len(c.tokens) - prefix
+		for _, kv := range c.caches {
+			kv.Trim(trim)
 		}
-		r.cache = nil
-
-		slog.Info("Cache miss", "left", len(tokens))
-		return nil, tokens
-	case prefix < len(r.cache.Tokens):
-		trim := len(r.cache.Tokens) - prefix
-		for _, c := range r.cache.Caches {
-			c.Trim(trim)
-		}
-		r.cache.Tokens = r.cache.Tokens[:prefix]
+		c.tokens = c.tokens[:prefix]
+	}
+
+	if prefix == 0 {
+		slog.Info("Cache miss", "left", len(tokens))
+	} else {
+		slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
 	}
 
-	slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
-	return r.cache.Caches, tokens[prefix:]
+	return tokens[prefix:]
 }
 
-func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
-	r.cache = &CacheEntry{
-		Tokens: tokens,
-		Caches: caches,
+func (c *kvCache) log() {
+	if len(c.caches) == 0 {
+		return
 	}
-}
 
-func (c *CacheEntry) LogCache() {
 	var totalBytes int
-	for _, kv := range c.Caches {
+	for _, kv := range c.caches {
 		k, v := kv.State()
 		totalBytes += k.NumBytes() + v.NumBytes()
 	}
-	logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
+	logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
 }
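The heart of findRemaining is a longest-common-prefix scan: count how many leading tokens match the cached sequence, trim whatever the cache holds past that point, and hand back only the unmatched tail for the model to process. Reduced to a standalone Go helper (hypothetical, for illustration only):

// longestCommonPrefix returns how many leading tokens the new
// prompt shares with the previously cached sequence.
func longestCommonPrefix(prompt, cached []int32) int {
	n := 0
	for n < len(prompt) && n < len(cached) && prompt[n] == cached[n] {
		n++
	}
	return n
}

A request that extends an earlier prompt therefore pays only for the new tokens, while a request that diverges at position k keeps the first k cached entries and trims the rest.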
@@ -10,7 +10,6 @@ import (
 	"time"
 
 	"github.com/ollama/ollama/logutil"
-	"github.com/ollama/ollama/x/mlxrunner/cache"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
 )
 
@@ -19,6 +18,23 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 		return errors.New("model not loaded")
 	}
 
+	var (
+		sample, logprobs         *mlx.Array
+		nextSample, nextLogprobs *mlx.Array
+	)
+
+	defer func() {
+		mlx.Unpin(sample, logprobs)
+		mlx.Unpin(nextSample, nextLogprobs)
+		mlx.Sweep()
+		mlx.ClearCache()
+
+		if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
+			mlx.LogArrays()
+			r.cache.log()
+		}
+	}()
+
 	enableCompile := true
 	if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
 		enableCompile = modelCompile.EnableCompile()
@@ -30,18 +46,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 	}
 
 	inputs := r.Tokenizer.Encode(request.Prompt, true)
+	session := r.cache.begin(r.Model, inputs)
+	defer session.close()
 
-	caches, tokens := r.FindNearestCache(inputs)
-	if len(caches) == 0 {
-		if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
-			caches = cacheFactory.NewCaches()
-		} else {
-			caches = make([]cache.Cache, r.Model.NumLayers())
-			for i := range caches {
-				caches[i] = cache.NewKVCache()
-			}
-		}
-	}
+	caches := session.caches
+	tokens := session.remaining
 
 	total, processed := len(tokens), 0
 	slog.Info("Prompt processing progress", "processed", processed, "total", total)
@@ -76,15 +85,14 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 		return sample, logprobs
 	}
 
-	sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
+	sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))
 
 	var b bytes.Buffer
 
 	now := time.Now()
 	final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
-	outputs := make([]int32, 0, request.Options.MaxTokens)
 	for i := range request.Options.MaxTokens {
-		nextSample, nextLogprobs := step(sample)
+		nextSample, nextLogprobs = step(sample)
 
 		if i == 0 {
 			slog.Info("Prompt processing progress", "processed", total, "total", total)
@@ -94,10 +102,9 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 		}
 
 		output := int32(sample.Int())
-		outputs = append(outputs, output)
+		session.outputs = append(session.outputs, output)
 
 		if r.Tokenizer.IsEOS(output) {
-			mlx.Unpin(nextSample, nextLogprobs)
 			final.Token = int(output)
 			final.DoneReason = 0
 			final.CompletionTokens = i
@@ -110,26 +117,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 		}
 
 		mlx.Unpin(sample, logprobs)
+		sample, logprobs = nextSample, nextLogprobs
+		nextSample, nextLogprobs = nil, nil
 
 		if i%256 == 0 {
 			mlx.ClearCache()
 		}
-
-		sample, logprobs = nextSample, nextLogprobs
 	}
 
-	mlx.Unpin(sample, logprobs)
 	final.CompletionTokensDuration = time.Since(now)
 	request.Responses <- final
-	r.InsertCache(append(inputs, outputs...), caches)
-	mlx.Sweep()
-
-	if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
-		mlx.LogArrays()
-		if r.cache != nil {
-			r.cache.LogCache()
-		}
-	}
 
 	return nil
 }
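Note that the := to = changes in the pipeline hunks above are load-bearing: the deferred cleanup closure captures the sample/logprobs variables themselves, so plain assignment keeps it pointed at the live arrays, whereas := inside the loop would declare fresh, shadowed variables the defer never sees. A small standalone illustration of that Go semantics:

package main

import "fmt"

func main() {
	var x *int
	defer func() {
		// The closure reads x when main returns, so it observes
		// the most recent assignment to the outer variable.
		if x != nil {
			fmt.Println("cleanup sees:", *x)
		}
	}()

	for i := 0; i < 3; i++ {
		v := i
		x = &v // plain assignment updates the x the defer captured
		// writing x := &v here would create a loop-local x instead,
		// and the deferred cleanup would never see it
	}
}

This prints "cleanup sees: 2": the defer observes the final assignment, which is exactly why the pipeline hoists its declarations into a single var block up front.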
@@ -61,7 +61,7 @@ type Runner struct {
 	Model     base.Model
 	Tokenizer *tokenizer.Tokenizer
 	Requests  chan Request
-	cache     *CacheEntry
+	cache     kvCache
 }
 
 func (r *Runner) Load(modelName string) error {