From 5daf59cc6666dd036af8fab8c5df6b5571a9a9ba Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Thu, 19 Feb 2026 15:05:35 -0800 Subject: [PATCH] mlxrunner: Fix memory leaks with pin/sweep lifecycle management The previous approach tracked array lifecycles through reference counting, where each array recorded its inputs and a reference count that was decremented as dependents were freed. This is not really necessary as MLX tracks references internally. It is also error prone as it is easy to create new arrays and forget to free them when the Go variable goes out of scope. Instead, we can pin just the arrays we want (typically outputs and specific intermediates, like the cache). All other arrays are freed by default when we run sweep. This avoids most causes of memory leaks while still giving the freedom to save what we want. --- x/mlxrunner/cache/cache.go | 9 +- x/mlxrunner/mlx/array.go | 125 ++++++++++++------------ x/mlxrunner/mlx/fast.go | 8 +- x/mlxrunner/mlx/io.go | 4 +- x/mlxrunner/mlx/ops.go | 62 ++++++------ x/mlxrunner/mlx/ops_extra.go | 38 +++---- x/mlxrunner/mlx/random.go | 2 +- x/mlxrunner/mlx/slice.go | 4 +- x/mlxrunner/model/base/base.go | 19 +++- x/mlxrunner/pipeline.go | 27 +++-- x/models/gemma3/gemma3.go | 3 - x/models/glm4_moe_lite/glm4_moe_lite.go | 3 - x/models/llama/llama.go | 3 - x/models/qwen3/qwen3.go | 3 - 14 files changed, 159 insertions(+), 151 deletions(-) diff --git a/x/mlxrunner/cache/cache.go b/x/mlxrunner/cache/cache.go index 05cffbf5e..3196b9e2a 100644 --- a/x/mlxrunner/cache/cache.go +++ b/x/mlxrunner/cache/cache.go @@ -47,6 +47,7 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { c.values.Set(c.values.Concatenate(2, newValues)) } else { c.keys, c.values = newKeys, newValues + mlx.Pin(c.keys, c.values) } } @@ -73,12 +74,14 @@ func (c *KVCache) Trim(n int) int { } func (c *KVCache) Clone() Cache { - return &KVCache{ + clone := &KVCache{ keys: c.keys.Clone(), values: c.values.Clone(), offset: c.offset, step: c.step, } + mlx.Pin(clone.keys, clone.values) + return clone } func (c *KVCache) Offset() int { return c.offset } @@ -106,7 +109,8 @@ func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) { slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize) if c.keys == nil { - c.keys, c.values = keys, values + c.keys, c.values = keys.Clone(), values.Clone() + mlx.Pin(c.keys, c.values) } else { if c.idx < c.keys.Dim(2) { c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())) @@ -145,6 +149,7 @@ func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra c.values.Set(c.values.Concatenate(2, newValues)) } else { c.keys, c.values = newKeys, newValues + mlx.Pin(c.keys, c.values) } c.idx = prev } diff --git a/x/mlxrunner/mlx/array.go b/x/mlxrunner/mlx/array.go index 43254d230..07f3ff1c1 100644 --- a/x/mlxrunner/mlx/array.go +++ b/x/mlxrunner/mlx/array.go @@ -7,48 +7,29 @@ import "C" import ( "encoding/binary" + "fmt" "log/slog" "reflect" + "sort" "strings" - "time" "unsafe" "github.com/ollama/ollama/logutil" ) -type tensorDesc struct { - name string - inputs []*Array - numRefs int -} - -func (d tensorDesc) LogValue() slog.Value { - return slog.GroupValue( - slog.String("name", d.name), - slog.Int("inputs", len(d.inputs)), - slog.Int("num_refs", d.numRefs), - ) -} - type Array struct { - ctx C.mlx_array - desc tensorDesc + ctx C.mlx_array + name string + pinned bool } +var arrays []*Array + // constructor utilities -func New(name string, inputs ...*Array) *Array { - t := &Array{ - desc: tensorDesc{ - name: name, - inputs: inputs, - }, - } - - for _, input := range inputs { - input.desc.numRefs++ - } - logutil.Trace("New", "t", t) +func New(name string) *Array { + t := &Array{name: name} + arrays = append(arrays, t) return t } @@ -133,18 +114,51 @@ func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array { } func (t *Array) Set(other *Array) { - Free(t.desc.inputs...) - other.desc.numRefs++ - t.desc.inputs = []*Array{other} C.mlx_array_set(&t.ctx, other.ctx) } func (t *Array) Clone() *Array { - tt := New(t.desc.name, t.desc.inputs...) + tt := New(t.name) C.mlx_array_set(&tt.ctx, t.ctx) return tt } +// lifecycle utilities + +// Pin marks arrays as in-use so they are retained during Sweep. +func Pin(s ...*Array) { + for _, t := range s { + if t != nil { + t.pinned = true + } + } +} + +// Unpin marks arrays as no longer in-use, allowing Sweep to free them. +func Unpin(s ...*Array) { + for _, t := range s { + if t != nil { + t.pinned = false + } + } +} + +// Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly +// free them when there are no other references, including dependencies in the graph. +func Sweep() { + n := 0 + for _, t := range arrays { + if t.pinned && t.Valid() { + arrays[n] = t + n++ + } else if t.Valid() { + C.mlx_array_free(t.ctx) + t.ctx.ctx = nil + } + } + arrays = arrays[:n] +} + // misc. utilities func (t *Array) Valid() bool { @@ -159,7 +173,10 @@ func (t *Array) String() string { } func (t *Array) LogValue() slog.Value { - attrs := []slog.Attr{slog.Any("", t.desc)} + attrs := []slog.Attr{ + slog.String("name", t.name), + slog.Bool("pinned", t.pinned), + } if t.Valid() { attrs = append(attrs, slog.Any("dtype", t.DType()), @@ -238,37 +255,15 @@ func (t Array) Save(name string) error { return nil } -func Free(s ...*Array) (n int) { - now := time.Now() - defer func() { - if n > 0 { - logutil.Trace("Freed tensors", "num_bytes", PrettyBytes(n), "took", time.Since(now)) - } - }() +// LogArrays logs all live arrays, sorted by size +func LogArrays() { + sort.Slice(arrays, func(i, j int) bool { + return arrays[i].NumBytes() > arrays[j].NumBytes() + }) - free := make([]*Array, 0, 8192) - fn := func(t *Array) { - if t.Valid() { - t.desc.numRefs-- - if t.desc.numRefs <= 0 { - free = append(free, t.desc.inputs...) - logutil.Trace("Free", "t", t) - n += t.NumBytes() - C.mlx_array_free(t.ctx) - t.ctx.ctx = nil - } - } + for _, t := range arrays { + nb := t.NumBytes() + logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s %v", t.name, t.DType(), PrettyBytes(nb), t.Dims())) } - - for _, t := range s { - fn(t) - } - - for len(free) > 0 { - tail := free[len(free)-1] - free = free[:len(free)-1] - fn(tail) - } - - return n + logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory()))) } diff --git a/x/mlxrunner/mlx/fast.go b/x/mlxrunner/mlx/fast.go index 250d42dc8..0570840d6 100644 --- a/x/mlxrunner/mlx/fast.go +++ b/x/mlxrunner/mlx/fast.go @@ -20,7 +20,7 @@ func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *A cMode := C.CString(mode) defer C.free(unsafe.Pointer(cMode)) - out := New("FAST_SDPA", query, key, value, mask, sinks) + out := New("FAST_SDPA") C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx) return out } @@ -31,7 +31,7 @@ type LayerNorm struct { } func (r *LayerNorm) Forward(x *Array, eps float32) *Array { - out := New("FAST_LAYERNORM", x) + out := New("FAST_LAYERNORM") C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx) return out } @@ -41,7 +41,7 @@ type RMSNorm struct { } func (r RMSNorm) Forward(x *Array, eps float32) *Array { - out := New("FAST_RMSNORM", x) + out := New("FAST_RMSNORM") C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx) return out } @@ -55,7 +55,7 @@ type RoPE struct { func (r RoPE) Forward(t *Array, offset int) *Array { freqs := New("") - out := New("FAST_ROPE", t, freqs) + out := New("FAST_ROPE") C.mlx_fast_rope( &out.ctx, t.ctx, diff --git a/x/mlxrunner/mlx/io.go b/x/mlxrunner/mlx/io.go index 304cfcd2c..84868e005 100644 --- a/x/mlxrunner/mlx/io.go +++ b/x/mlxrunner/mlx/io.go @@ -37,7 +37,9 @@ func Load(path string) iter.Seq2[string, *Array] { } name := C.GoString(key) - if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) { + arr := New(name) + arr.ctx = value + if !yield(name, arr) { break } } diff --git a/x/mlxrunner/mlx/ops.go b/x/mlxrunner/mlx/ops.go index 01a7f4835..011a42319 100644 --- a/x/mlxrunner/mlx/ops.go +++ b/x/mlxrunner/mlx/ops.go @@ -10,43 +10,43 @@ import ( ) func (t *Array) Abs() *Array { - out := New("ABS", t) + out := New("ABS") C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx) return out } func (t *Array) Add(other *Array) *Array { - out := New("ADD", t, other) + out := New("ADD") C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) return out } func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array { - out := New("ADDMM", t, a, b) + out := New("ADDMM") C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx) return out } func (t *Array) Argmax(axis int, keepDims bool) *Array { - out := New("ARGMAX", t) + out := New("ARGMAX") C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx) return out } func (t *Array) ArgpartitionAxis(kth int, axis int) *Array { - out := New("ARGPARTITION", t) + out := New("ARGPARTITION") C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx) return out } func (t *Array) ArgsortAxis(axis int) *Array { - out := New("ARGSORT_AXIS", t) + out := New("ARGSORT_AXIS") C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx) return out } func (t *Array) AsType(dtype DType) *Array { - out := New("AS_TYPE", t) + out := New("AS_TYPE") C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx) return out } @@ -62,7 +62,7 @@ func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array { cStrides[i] = C.int64_t(s) } - out := New("AS_STRIDED", t) + out := New("AS_STRIDED") C.mlx_as_strided( &out.ctx, t.ctx, unsafe.SliceData(cShape), C.size_t(len(shape)), @@ -82,31 +82,31 @@ func (t *Array) Concatenate(axis int, others ...*Array) *Array { C.mlx_vector_array_append_value(vector, other.ctx) } - out := New("CONCATENATE", s...) + out := New("CONCATENATE") C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx) return out } func (t *Array) Divide(other *Array) *Array { - out := New("DIVIDE", t, other) + out := New("DIVIDE") C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) return out } func (t *Array) ExpandDims(axis int) *Array { - out := New("EXPAND_DIMS", t) + out := New("EXPAND_DIMS") C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx) return out } func (t *Array) Flatten(startAxis, endAxis int) *Array { - out := New("FLATTEN", t) + out := New("FLATTEN") C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx) return out } func (t *Array) FloorDivide(other *Array) *Array { - out := New("FLOOR_DIVIDE", t, other) + out := New("FLOOR_DIVIDE") C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) return out } @@ -118,43 +118,43 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array { if rhs == nil { rhs = New("") } - out := New("GATHER_MM", t, other, lhs, rhs) + out := New("GATHER_MM") C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx) return out } func (t *Array) Logsumexp(keepDims bool) *Array { - out := New("LOGSUMEXP", t) + out := New("LOGSUMEXP") C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx) return out } func (t *Array) Matmul(other *Array) *Array { - out := New("MATMUL", t, other) + out := New("MATMUL") C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) return out } func (t *Array) Multiply(other *Array) *Array { - out := New("MULTIPLY", t, other) + out := New("MULTIPLY") C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) return out } func (t *Array) Negative() *Array { - out := New("NEGATIVE", t) + out := New("NEGATIVE") C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx) return out } func (t *Array) Power(exponent *Array) *Array { - out := New("POWER", t, exponent) + out := New("POWER") C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx) return out } func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array { - out := New("PUT_ALONG_AXIS", t, indices, values) + out := New("PUT_ALONG_AXIS") C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx) return out } @@ -165,25 +165,25 @@ func (t *Array) Reshape(axes ...int) *Array { cAxes[i] = C.int(axes[i]) } - out := New("RESHAPE", t) + out := New("RESHAPE") C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx) return out } func (t *Array) Sigmoid() *Array { - out := New("SIGMOID", t) + out := New("SIGMOID") C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx) return out } func (t *Array) Sqrt() *Array { - out := New("SQRT", t) + out := New("SQRT") C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx) return out } func (t *Array) Squeeze(axis int) *Array { - out := New("SQUEEZE", t) + out := New("SQUEEZE") C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx) return out } @@ -198,37 +198,37 @@ func (t *Array) StackAxis(axis int, others ...*Array) *Array { vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData))) defer C.mlx_vector_array_free(vector) - out := New("STACK_AXIS", append(others, t)...) + out := New("STACK_AXIS") C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx) return out } func (t *Array) Subtract(other *Array) *Array { - out := New("SUBTRACT", t, other) + out := New("SUBTRACT") C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) return out } func (t *Array) SumAxis(axis int, keepDims bool) *Array { - out := New("SUM_AXIS", t) + out := New("SUM_AXIS") C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx) return out } func (t *Array) TakeAxis(indices *Array, axis int) *Array { - out := New("TAKE_AXIS", t, indices) + out := New("TAKE_AXIS") C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx) return out } func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array { - out := New("TAKE_ALONG_AXIS", t, indices) + out := New("TAKE_ALONG_AXIS") C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx) return out } func (t *Array) Tanh() *Array { - out := New("TANH", t) + out := New("TANH") C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx) return out } @@ -239,7 +239,7 @@ func (t *Array) Transpose(axes ...int) *Array { cAxes[i] = C.int(axis) } - out := New("TRANSPOSE", t) + out := New("TRANSPOSE") C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx) return out } diff --git a/x/mlxrunner/mlx/ops_extra.go b/x/mlxrunner/mlx/ops_extra.go index f2882e989..e83b77fb8 100644 --- a/x/mlxrunner/mlx/ops_extra.go +++ b/x/mlxrunner/mlx/ops_extra.go @@ -41,14 +41,12 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} optDtype := C.mlx_optional_dtype{has_value: false} - inputs := []*Array{w, scales} var b C.mlx_array if biases != nil { b = biases.ctx - inputs = append(inputs, biases) } - out := New("DEQUANTIZE", inputs...) + out := New("DEQUANTIZE") C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx) return out } @@ -59,14 +57,12 @@ func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bit optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} - inputs := []*Array{x, w, scales} var b C.mlx_array if biases != nil { b = biases.ctx - inputs = append(inputs, biases) } - out := New("QUANTIZED_MATMUL", inputs...) + out := New("QUANTIZED_MATMUL") C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx) return out } @@ -77,22 +73,18 @@ func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, trans optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} - inputs := []*Array{x, w, scales} var b, lhs, rhs C.mlx_array if biases != nil { b = biases.ctx - inputs = append(inputs, biases) } if lhsIndices != nil { lhs = lhsIndices.ctx - inputs = append(inputs, lhsIndices) } if rhsIndices != nil { rhs = rhsIndices.ctx - inputs = append(inputs, rhsIndices) } - out := New("GATHER_QMM", inputs...) + out := New("GATHER_QMM") C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx) return out } @@ -104,7 +96,7 @@ func Tile(a *Array, reps []int32) *Array { for i, r := range reps { cReps[i] = C.int(r) } - out := New("TILE", a) + out := New("TILE") C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx) return out } @@ -116,7 +108,7 @@ func Tri(n, m int32, k int) *Array { } func Where(condition, a, b *Array) *Array { - out := New("WHERE", condition, a, b) + out := New("WHERE") C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx) return out } @@ -131,7 +123,7 @@ func Stack(arrays []*Array, axis int) *Array { vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData))) defer C.mlx_vector_array_free(vector) - out := New("STACK", arrays...) + out := New("STACK") C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx) return out } @@ -153,13 +145,13 @@ func Take(a *Array, indices *Array, axis int) *Array { } func RSqrt(a *Array) *Array { - out := New("RSQRT", a) + out := New("RSQRT") C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx) return out } func Mean(a *Array, axis int, keepDims bool) *Array { - out := New("MEAN_AXIS", a) + out := New("MEAN_AXIS") C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx) return out } @@ -235,7 +227,7 @@ func SliceStartStop(a *Array, start, stop []int32) *Array { cStop[i] = C.int(stop[i]) cStrides[i] = 1 } - out := New("SLICE", a) + out := New("SLICE") C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx) return out } @@ -257,7 +249,7 @@ func SiLU(a *Array) *Array { func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array { freqs := New("") - out := New("FAST_ROPE", x, freqs) + out := New("FAST_ROPE") C.mlx_fast_rope( &out.ctx, x.ctx, @@ -289,13 +281,13 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b cMode := C.CString(mode) defer C.free(unsafe.Pointer(cMode)) - out := New("FAST_SDPA", q, k, v, mask, sinks) + out := New("FAST_SDPA") C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx) return out } func RMSNormFn(x, weight *Array, eps float32) *Array { - out := New("FAST_RMSNORM", x) + out := New("FAST_RMSNORM") C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx) return out } @@ -322,7 +314,7 @@ func scalarWithDtype(s float32, a *Array) C.mlx_array { func AddScalar(a *Array, s float32) *Array { scalar := scalarWithDtype(s, a) - out := New("ADD_SCALAR", a) + out := New("ADD_SCALAR") C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx) C.mlx_array_free(scalar) return out @@ -330,7 +322,7 @@ func AddScalar(a *Array, s float32) *Array { func MulScalar(a *Array, s float32) *Array { scalar := scalarWithDtype(s, a) - out := New("MUL_SCALAR", a) + out := New("MUL_SCALAR") C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx) C.mlx_array_free(scalar) return out @@ -338,7 +330,7 @@ func MulScalar(a *Array, s float32) *Array { func DivScalar(a *Array, s float32) *Array { scalar := scalarWithDtype(s, a) - out := New("DIV_SCALAR", a) + out := New("DIV_SCALAR") C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx) C.mlx_array_free(scalar) return out diff --git a/x/mlxrunner/mlx/random.go b/x/mlxrunner/mlx/random.go index 805308b4a..6afdbbab4 100644 --- a/x/mlxrunner/mlx/random.go +++ b/x/mlxrunner/mlx/random.go @@ -7,7 +7,7 @@ import "C" func (t *Array) Categorical(axis int) *Array { key := New("") - out := New("", t, key) + out := New("") C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx) return out } diff --git a/x/mlxrunner/mlx/slice.go b/x/mlxrunner/mlx/slice.go index 7ab7e2031..ab1324774 100644 --- a/x/mlxrunner/mlx/slice.go +++ b/x/mlxrunner/mlx/slice.go @@ -61,7 +61,7 @@ func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) { func (t *Array) Slice(slices ...slice) *Array { starts, stops, strides := makeSlices(t.Dims(), slices...) - out := New("SLICE", t) + out := New("SLICE") C.mlx_slice( &out.ctx, t.ctx, unsafe.SliceData(starts), C.size_t(len(starts)), @@ -74,7 +74,7 @@ func (t *Array) Slice(slices ...slice) *Array { func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array { starts, stops, strides := makeSlices(t.Dims(), slices...) - out := New("SLICE_UPDATE", t, other) + out := New("SLICE_UPDATE") C.mlx_slice_update( &out.ctx, t.ctx, other.ctx, unsafe.SliceData(starts), C.size_t(len(starts)), diff --git a/x/mlxrunner/model/base/base.go b/x/mlxrunner/model/base/base.go index 6d3a25798..e35eddd2d 100644 --- a/x/mlxrunner/model/base/base.go +++ b/x/mlxrunner/model/base/base.go @@ -78,8 +78,21 @@ func New(root *model.Root) (Model, error) { return fn(root) } -// Weights returns the model's LoadWeights method, which encapsulates all -// weight assignment and post-processing (MLA absorption, expert stacking). +// Weights returns a function that loads model weights, then pins all +// arrays reachable from the model struct and sweeps everything else. func Weights(m Model) func(map[string]*mlx.Array) error { - return m.LoadWeights + return func(tensors map[string]*mlx.Array) error { + if err := m.LoadWeights(tensors); err != nil { + return err + } + + collected := mlx.Collect(m) + for _, arr := range collected { + mlx.Pin(arr) + } + mlx.Sweep() + mlx.Eval(collected...) + + return nil + } } diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index 274fc9be6..618d7ec9e 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -4,10 +4,12 @@ package mlxrunner import ( "bytes" + "context" "errors" "log/slog" "time" + "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" ) @@ -45,8 +47,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error { slog.Info("Prompt processing progress", "processed", processed, "total", total) for total-processed > 1 { n := min(2<<10, total-processed-1) - temp := r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches) - defer mlx.Free(temp) + r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches) + mlx.Sweep() mlx.Eval(func() []*mlx.Array { s := make([]*mlx.Array, 2*len(caches)) for i, c := range caches { @@ -65,11 +67,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error { logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1) logprobs := logits.Subtract(logits.Logsumexp(true)) - return request.Sample(logprobs), logprobs + sample := request.Sample(logprobs) + + mlx.Pin(sample, logprobs) + mlx.Sweep() + mlx.AsyncEval(sample, logprobs) + + return sample, logprobs } sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed)) - mlx.AsyncEval(sample, logprobs) var b bytes.Buffer @@ -78,7 +85,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error { outputs := make([]int32, 0, request.Options.MaxTokens) for i := range request.Options.MaxTokens { nextSample, nextLogprobs := step(sample) - mlx.AsyncEval(nextSample, nextLogprobs) if i == 0 { slog.Info("Prompt processing progress", "processed", total, "total", total) @@ -91,6 +97,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error { outputs = append(outputs, output) if r.Tokenizer.IsEOS(output) { + mlx.Unpin(nextSample, nextLogprobs) final.Token = int(output) final.DoneReason = 0 final.CompletionTokens = i @@ -102,7 +109,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error { Token: int(output), } - mlx.Free(sample, logprobs) + mlx.Unpin(sample, logprobs) if i%256 == 0 { mlx.ClearCache() } @@ -110,10 +117,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error { sample, logprobs = nextSample, nextLogprobs } - mlx.Free(sample, logprobs) + mlx.Unpin(sample, logprobs) final.CompletionTokensDuration = time.Since(now) request.Responses <- final r.InsertCache(append(inputs, outputs...), caches) + mlx.Sweep() + + if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) { + mlx.LogArrays() + } + return nil } diff --git a/x/models/gemma3/gemma3.go b/x/models/gemma3/gemma3.go index 7ba24d294..3d2c398b7 100644 --- a/x/models/gemma3/gemma3.go +++ b/x/models/gemma3/gemma3.go @@ -401,9 +401,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { if m.NormScaled == nil { return fmt.Errorf("missing precomputed final norm weight") } - collected := mlx.Collect(m) - mlx.Eval(collected...) - return nil } diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go index a1ec55972..b79e245da 100644 --- a/x/models/glm4_moe_lite/glm4_moe_lite.go +++ b/x/models/glm4_moe_lite/glm4_moe_lite.go @@ -702,9 +702,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { } } - collected := mlx.Collect(m) - mlx.Eval(collected...) - return nil } diff --git a/x/models/llama/llama.go b/x/models/llama/llama.go index 61e51b35c..bef98fbb4 100644 --- a/x/models/llama/llama.go +++ b/x/models/llama/llama.go @@ -235,9 +235,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { m.Layers[i] = layer } - collected := mlx.Collect(m) - mlx.Eval(collected...) - return nil } diff --git a/x/models/qwen3/qwen3.go b/x/models/qwen3/qwen3.go index 76170881a..392f90755 100644 --- a/x/models/qwen3/qwen3.go +++ b/x/models/qwen3/qwen3.go @@ -252,9 +252,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { m.Layers[i] = layer } - collected := mlx.Collect(m) - mlx.Eval(collected...) - return nil }