From 088dfd89a8b368d7f2f82ad51df723327424b21f Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Tue, 21 Apr 2026 15:29:07 -0700
Subject: [PATCH] mlxrunner: wrap model forward inputs in a Batch struct

Gives a single extension point for per-call context (positions, sequence
IDs, masks) as multi-sequence batching grows, without having to churn
every model's Forward signature again.
---
 x/mlxrunner/batch/batch.go              | 9 +++++++++
 x/mlxrunner/model/base/base.go          | 3 ++-
 x/mlxrunner/pipeline.go                 | 5 +++--
 x/models/gemma3/gemma3.go               | 7 ++++---
 x/models/gemma4/gemma4.go               | 9 +++++----
 x/models/glm4_moe_lite/glm4_moe_lite.go | 7 ++++---
 x/models/llama/llama.go                 | 7 ++++---
 x/models/qwen3/qwen3.go                 | 7 ++++---
 x/models/qwen3_5/qwen3_5.go             | 7 ++++---
 9 files changed, 39 insertions(+), 22 deletions(-)
 create mode 100644 x/mlxrunner/batch/batch.go

diff --git a/x/mlxrunner/batch/batch.go b/x/mlxrunner/batch/batch.go
new file mode 100644
index 000000000..c6cf47d93
--- /dev/null
+++ b/x/mlxrunner/batch/batch.go
@@ -0,0 +1,9 @@
+package batch
+
+import "github.com/ollama/ollama/x/mlxrunner/mlx"
+
+// Batch is the per-forward-pass input handed to a model.
+type Batch struct {
+	// InputIDs is the input token IDs for this forward pass, shape (B, L).
+	InputIDs *mlx.Array
+}
diff --git a/x/mlxrunner/model/base/base.go b/x/mlxrunner/model/base/base.go
index 3a85b6eb0..8fa5a2347 100644
--- a/x/mlxrunner/model/base/base.go
+++ b/x/mlxrunner/model/base/base.go
@@ -6,6 +6,7 @@ import (
 	"log/slog"
 	"sync"
 
+	"github.com/ollama/ollama/x/mlxrunner/batch"
 	"github.com/ollama/ollama/x/mlxrunner/cache"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
 	"github.com/ollama/ollama/x/mlxrunner/model"
@@ -14,7 +15,7 @@ import (
 
 // Model is the interface that model implementations must satisfy.
 type Model interface {
-	Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
+	Forward(b *batch.Batch, cache []cache.Cache) *mlx.Array
 	Unembed(x *mlx.Array) *mlx.Array
 	NumLayers() int
 	Tokenizer() *tokenizer.Tokenizer
diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go
index 34d3e3d13..0fe8bff92 100644
--- a/x/mlxrunner/pipeline.go
+++ b/x/mlxrunner/pipeline.go
@@ -11,6 +11,7 @@ import (
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/logutil"
 
+	"github.com/ollama/ollama/x/mlxrunner/batch"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
 	sampler "github.com/ollama/ollama/x/mlxrunner/sample"
 	"github.com/ollama/ollama/x/tokenizer"
@@ -122,7 +123,7 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
 		}
 	}
 
-	r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], 1, n), caches)
+	r.Model.Forward(&batch.Batch{InputIDs: mlx.FromValues(tokens[processed:processed+n], 1, n)}, caches)
 	mlx.Sweep()
 	materializeCaches()
 	processed += n
@@ -143,7 +144,7 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
 	r.Sampler.Add(pipelineSlot, request.SamplerOpts, inputs)
 
 	step := func(token *mlx.Array) sampler.Result {
-		fwd := r.Model.Forward(token, caches)
+		fwd := r.Model.Forward(&batch.Batch{InputIDs: token}, caches)
 		logits := r.Model.Unembed(fwd)
 		logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
 
diff --git a/x/models/gemma3/gemma3.go b/x/models/gemma3/gemma3.go
index 266222b69..24b0cee5a 100644
--- a/x/models/gemma3/gemma3.go
+++ b/x/models/gemma3/gemma3.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"math"
 
+	"github.com/ollama/ollama/x/mlxrunner/batch"
 	"github.com/ollama/ollama/x/mlxrunner/cache"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
 	"github.com/ollama/ollama/x/mlxrunner/model"
@@ -402,11 +403,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
 	return nil
 }
 
-func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
-	dims := tokens.Dims()
+func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
+	dims := b.InputIDs.Dims()
 	B, L := int32(dims[0]), int32(dims[1])
 
-	h := m.EmbedTokens.Forward(tokens)
+	h := m.EmbedTokens.Forward(b.InputIDs)
 	h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
 
 	for i, layer := range m.Layers {
diff --git a/x/models/gemma4/gemma4.go b/x/models/gemma4/gemma4.go
index c2d9a979a..19b92c58a 100644
--- a/x/models/gemma4/gemma4.go
+++ b/x/models/gemma4/gemma4.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"math"
 
+	"github.com/ollama/ollama/x/mlxrunner/batch"
 	"github.com/ollama/ollama/x/mlxrunner/cache"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
 	"github.com/ollama/ollama/x/mlxrunner/model"
@@ -1013,16 +1014,16 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
 	return nil
 }
 
-func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
-	dims := tokens.Dims()
+func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
+	dims := b.InputIDs.Dims()
 	B, L := int32(dims[0]), int32(dims[1])
-	h := m.EmbedTokens.Forward(tokens)
+	h := m.EmbedTokens.Forward(b.InputIDs)
 	h = mlx.MulScalar(h, m.EmbedScale)
 
 	// Compute PLE inputs if configured.
 	var perLayerInputs *mlx.Array
 	if m.HiddenSizePerLayer > 0 && m.EmbedTokensPerLayer != nil {
-		perLayerInputs = m.computePLEInputs(tokens, h)
+		perLayerInputs = m.computePLEInputs(b.InputIDs, h)
 	}
 
 	var sharedKV map[int32]sharedKVEntry
diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go
index aac320806..ed5a2e005 100644
--- a/x/models/glm4_moe_lite/glm4_moe_lite.go
+++ b/x/models/glm4_moe_lite/glm4_moe_lite.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"math"
 
+	"github.com/ollama/ollama/x/mlxrunner/batch"
 	"github.com/ollama/ollama/x/mlxrunner/cache"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
 	"github.com/ollama/ollama/x/mlxrunner/model"
@@ -698,11 +699,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
 }
 
 // Forward computes the forward pass of the model
-func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
-	dims := tokens.Dims()
+func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
+	dims := b.InputIDs.Dims()
 	B, L := int32(dims[0]), int32(dims[1])
 
-	h := m.EmbedTokens.Forward(tokens)
+	h := m.EmbedTokens.Forward(b.InputIDs)
 
 	for i, layer := range m.Layers {
 		var c cache.Cache
diff --git a/x/models/llama/llama.go b/x/models/llama/llama.go
index 4f4da05a7..0cad59e04 100644
--- a/x/models/llama/llama.go
+++ b/x/models/llama/llama.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"math"
 
+	"github.com/ollama/ollama/x/mlxrunner/batch"
 	"github.com/ollama/ollama/x/mlxrunner/cache"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
 	"github.com/ollama/ollama/x/mlxrunner/model"
@@ -236,11 +237,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
 	return nil
 }
 
-func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
-	dims := tokens.Dims()
+func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
+	dims := b.InputIDs.Dims()
 	B, L := int32(dims[0]), int32(dims[1])
 
-	h := m.EmbedTokens.Forward(tokens)
+	h := m.EmbedTokens.Forward(b.InputIDs)
 	for i, layer := range m.Layers {
 		var c cache.Cache
 		if caches != nil && i < len(caches) {
diff --git a/x/models/qwen3/qwen3.go b/x/models/qwen3/qwen3.go
index a1b31af0d..55dccd93f 100644
--- a/x/models/qwen3/qwen3.go
+++ b/x/models/qwen3/qwen3.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"math"
 
+	"github.com/ollama/ollama/x/mlxrunner/batch"
 	"github.com/ollama/ollama/x/mlxrunner/cache"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
 	"github.com/ollama/ollama/x/mlxrunner/model"
@@ -253,11 +254,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
 	return nil
 }
 
-func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
-	dims := tokens.Dims()
+func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
+	dims := b.InputIDs.Dims()
 	B, L := int32(dims[0]), int32(dims[1])
 
-	h := m.EmbedTokens.Forward(tokens)
+	h := m.EmbedTokens.Forward(b.InputIDs)
 	for i, layer := range m.Layers {
 		var c cache.Cache
 		if caches != nil && i < len(caches) {
diff --git a/x/models/qwen3_5/qwen3_5.go b/x/models/qwen3_5/qwen3_5.go
index f29563f88..cc246a577 100644
--- a/x/models/qwen3_5/qwen3_5.go
+++ b/x/models/qwen3_5/qwen3_5.go
@@ -7,6 +7,7 @@ import (
 	"math"
 	"strings"
 
+	"github.com/ollama/ollama/x/mlxrunner/batch"
 	"github.com/ollama/ollama/x/mlxrunner/cache"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
 	"github.com/ollama/ollama/x/mlxrunner/model"
@@ -1345,11 +1346,11 @@ func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *m
 	return mlx.Add(h, r)
 }
 
-func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
-	dims := tokens.Dims()
+func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
+	dims := b.InputIDs.Dims()
 	B, L := int32(dims[0]), int32(dims[1])
 
-	h := m.EmbedTokens.Forward(tokens)
+	h := m.EmbedTokens.Forward(b.InputIDs)
 	for i, layer := range m.Layers {
 		var c cache.Cache
 		if caches != nil && i < len(caches) {
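
Usage sketch (editorial note, not part of the patch): the call-site pattern after this change, assembled from the pipeline.go hunks above. It reuses only identifiers that appear in the diff (batch.Batch, InputIDs, mlx.FromValues, Forward, Unembed); the Positions field mentioned in the trailing comment is hypothetical and only illustrates the kind of per-call context Batch is meant to absorb later.

    // Prefill: wrap the prompt chunk exactly as pipeline.go now does.
    b := &batch.Batch{InputIDs: mlx.FromValues(tokens[processed:processed+n], 1, n)}
    r.Model.Forward(b, caches)

    // Decode: a single sampled token travels in the same struct.
    fwd := r.Model.Forward(&batch.Batch{InputIDs: token}, caches)
    logits := r.Model.Unembed(fwd)

    // Future per-call context (for example a hypothetical Positions *mlx.Array,
    // not in this patch) would be added as new Batch fields, so existing
    // Forward implementations keep compiling without another signature change.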