mlxrunner: wrap model forward inputs in a Batch struct

Gives a single extension point for per-call context (positions,
sequence IDs, masks) as multi-sequence batching grows, without having
to churn every model's Forward signature again.
This commit is contained in:
Jesse Gross
2026-04-21 15:29:07 -07:00
parent 3cab8a7b02
commit 088dfd89a8
9 changed files with 39 additions and 22 deletions

View File

@@ -0,0 +1,9 @@
package batch
import "github.com/ollama/ollama/x/mlxrunner/mlx"
// Batch is the per-forward-pass input handed to a model.
//
// It exists as the single extension point for per-call context
// (positions, sequence IDs, masks) as multi-sequence batching grows,
// so that adding a new input does not require changing every model's
// Forward signature again.
type Batch struct {
	// InputIDs is the input token IDs for this forward pass, shape (B, L).
	InputIDs *mlx.Array
}

View File

@@ -6,6 +6,7 @@ import (
"log/slog"
"sync"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
@@ -14,7 +15,7 @@ import (
// Model is the interface that model implementations must satisfy.
type Model interface {
Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
Forward(b *batch.Batch, cache []cache.Cache) *mlx.Array
Unembed(x *mlx.Array) *mlx.Array
NumLayers() int
Tokenizer() *tokenizer.Tokenizer

View File

@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/tokenizer"
@@ -122,7 +123,7 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
}
}
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], 1, n), caches)
r.Model.Forward(&batch.Batch{InputIDs: mlx.FromValues(tokens[processed:processed+n], 1, n)}, caches)
mlx.Sweep()
materializeCaches()
processed += n
@@ -143,7 +144,7 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
r.Sampler.Add(pipelineSlot, request.SamplerOpts, inputs)
step := func(token *mlx.Array) sampler.Result {
fwd := r.Model.Forward(token, caches)
fwd := r.Model.Forward(&batch.Batch{InputIDs: token}, caches)
logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
@@ -402,11 +403,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
return nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := tokens.Dims()
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
dims := b.InputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(tokens)
h := m.EmbedTokens.Forward(b.InputIDs)
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
for i, layer := range m.Layers {

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
@@ -1013,16 +1014,16 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
return nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := tokens.Dims()
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
dims := b.InputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(tokens)
h := m.EmbedTokens.Forward(b.InputIDs)
h = mlx.MulScalar(h, m.EmbedScale)
// Compute PLE inputs if configured.
var perLayerInputs *mlx.Array
if m.HiddenSizePerLayer > 0 && m.EmbedTokensPerLayer != nil {
perLayerInputs = m.computePLEInputs(tokens, h)
perLayerInputs = m.computePLEInputs(b.InputIDs, h)
}
var sharedKV map[int32]sharedKVEntry

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
@@ -698,11 +699,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
}
// Forward computes the forward pass of the model
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := tokens.Dims()
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
dims := b.InputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(tokens)
h := m.EmbedTokens.Forward(b.InputIDs)
for i, layer := range m.Layers {
var c cache.Cache

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
@@ -236,11 +237,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
return nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := tokens.Dims()
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
dims := b.InputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(tokens)
h := m.EmbedTokens.Forward(b.InputIDs)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"math"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
@@ -253,11 +254,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
return nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := tokens.Dims()
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
dims := b.InputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(tokens)
h := m.EmbedTokens.Forward(b.InputIDs)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {

View File

@@ -7,6 +7,7 @@ import (
"math"
"strings"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
@@ -1345,11 +1346,11 @@ func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *m
return mlx.Add(h, r)
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := tokens.Dims()
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
dims := b.InputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(tokens)
h := m.EmbedTokens.Forward(b.InputIDs)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {