mirror of
https://github.com/ollama/ollama.git
synced 2026-05-06 08:02:14 -05:00
mlxrunner: wrap model forward inputs in a Batch struct
Gives a single extension point for per-call context (positions, sequence IDs, masks) as multi-sequence batching grows, without having to churn every model's Forward signature again.
This commit is contained in:
9
x/mlxrunner/batch/batch.go
Normal file
9
x/mlxrunner/batch/batch.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package batch
|
||||
|
||||
import "github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
|
||||
// Batch is the per-forward-pass input handed to a model.
|
||||
type Batch struct {
|
||||
// InputIDs is the input token IDs for this forward pass, shape (B, L).
|
||||
InputIDs *mlx.Array
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
|
||||
// Model is the interface that model implementations must satisfy.
|
||||
type Model interface {
|
||||
Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
|
||||
Forward(b *batch.Batch, cache []cache.Cache) *mlx.Array
|
||||
Unembed(x *mlx.Array) *mlx.Array
|
||||
NumLayers() int
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
@@ -122,7 +123,7 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
|
||||
}
|
||||
}
|
||||
|
||||
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], 1, n), caches)
|
||||
r.Model.Forward(&batch.Batch{InputIDs: mlx.FromValues(tokens[processed:processed+n], 1, n)}, caches)
|
||||
mlx.Sweep()
|
||||
materializeCaches()
|
||||
processed += n
|
||||
@@ -143,7 +144,7 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
|
||||
r.Sampler.Add(pipelineSlot, request.SamplerOpts, inputs)
|
||||
|
||||
step := func(token *mlx.Array) sampler.Result {
|
||||
fwd := r.Model.Forward(token, caches)
|
||||
fwd := r.Model.Forward(&batch.Batch{InputIDs: token}, caches)
|
||||
logits := r.Model.Unembed(fwd)
|
||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
@@ -402,11 +403,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
dims := tokens.Dims()
|
||||
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
||||
dims := b.InputIDs.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
@@ -1013,16 +1014,16 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
dims := tokens.Dims()
|
||||
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
||||
dims := b.InputIDs.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
h = mlx.MulScalar(h, m.EmbedScale)
|
||||
|
||||
// Compute PLE inputs if configured.
|
||||
var perLayerInputs *mlx.Array
|
||||
if m.HiddenSizePerLayer > 0 && m.EmbedTokensPerLayer != nil {
|
||||
perLayerInputs = m.computePLEInputs(tokens, h)
|
||||
perLayerInputs = m.computePLEInputs(b.InputIDs, h)
|
||||
}
|
||||
|
||||
var sharedKV map[int32]sharedKVEntry
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
@@ -698,11 +699,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
}
|
||||
|
||||
// Forward computes the forward pass of the model
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
dims := tokens.Dims()
|
||||
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
||||
dims := b.InputIDs.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
@@ -236,11 +237,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
dims := tokens.Dims()
|
||||
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
||||
dims := b.InputIDs.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil && i < len(caches) {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
@@ -253,11 +254,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
dims := tokens.Dims()
|
||||
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
||||
dims := b.InputIDs.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil && i < len(caches) {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
@@ -1345,11 +1346,11 @@ func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *m
|
||||
return mlx.Add(h, r)
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
dims := tokens.Dims()
|
||||
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
||||
dims := b.InputIDs.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil && i < len(caches) {
|
||||
|
||||
Reference in New Issue
Block a user