mlxrunner: Fix memory leaks with pin/sweep lifecycle management

The previous approach tracked array lifecycles through reference
counting, where each array recorded its inputs and a reference count
that was decremented as dependents were freed. This is not really
necessary as MLX tracks references internally. It is also error
prone as it is easy to create new arrays and forget to free them
when the Go variable goes out of scope.

Instead, we can pin just the arrays we want (typically outputs and
specific intermediates, like the cache). All other arrays are freed
by default when we run sweep. This eliminates most causes of memory leaks
while still retaining the flexibility to keep the arrays we need.
This commit is contained in:
Jesse Gross
2026-02-19 15:05:35 -08:00
parent 0ade9205cc
commit 5daf59cc66
14 changed files with 159 additions and 151 deletions

View File

@@ -47,6 +47,7 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
c.values.Set(c.values.Concatenate(2, newValues))
} else {
c.keys, c.values = newKeys, newValues
mlx.Pin(c.keys, c.values)
}
}
@@ -73,12 +74,14 @@ func (c *KVCache) Trim(n int) int {
}
func (c *KVCache) Clone() Cache {
return &KVCache{
clone := &KVCache{
keys: c.keys.Clone(),
values: c.values.Clone(),
offset: c.offset,
step: c.step,
}
mlx.Pin(clone.keys, clone.values)
return clone
}
func (c *KVCache) Offset() int { return c.offset }
@@ -106,7 +109,8 @@ func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
if c.keys == nil {
c.keys, c.values = keys, values
c.keys, c.values = keys.Clone(), values.Clone()
mlx.Pin(c.keys, c.values)
} else {
if c.idx < c.keys.Dim(2) {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
@@ -145,6 +149,7 @@ func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
c.values.Set(c.values.Concatenate(2, newValues))
} else {
c.keys, c.values = newKeys, newValues
mlx.Pin(c.keys, c.values)
}
c.idx = prev
}

View File

@@ -7,48 +7,29 @@ import "C"
import (
"encoding/binary"
"fmt"
"log/slog"
"reflect"
"sort"
"strings"
"time"
"unsafe"
"github.com/ollama/ollama/logutil"
)
type tensorDesc struct {
name string
inputs []*Array
numRefs int
}
func (d tensorDesc) LogValue() slog.Value {
return slog.GroupValue(
slog.String("name", d.name),
slog.Int("inputs", len(d.inputs)),
slog.Int("num_refs", d.numRefs),
)
}
type Array struct {
ctx C.mlx_array
desc tensorDesc
ctx C.mlx_array
name string
pinned bool
}
var arrays []*Array
// constructor utilities
func New(name string, inputs ...*Array) *Array {
t := &Array{
desc: tensorDesc{
name: name,
inputs: inputs,
},
}
for _, input := range inputs {
input.desc.numRefs++
}
logutil.Trace("New", "t", t)
// New allocates an Array with the given debug name and registers it in
// the global tracking list so that a later Sweep can reclaim it unless
// it has been pinned in the meantime.
func New(name string) *Array {
	arr := &Array{name: name}
	arrays = append(arrays, arr)
	return arr
}
@@ -133,18 +114,51 @@ func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
}
func (t *Array) Set(other *Array) {
Free(t.desc.inputs...)
other.desc.numRefs++
t.desc.inputs = []*Array{other}
C.mlx_array_set(&t.ctx, other.ctx)
}
func (t *Array) Clone() *Array {
tt := New(t.desc.name, t.desc.inputs...)
tt := New(t.name)
C.mlx_array_set(&tt.ctx, t.ctx)
return tt
}
// lifecycle utilities
// Pin marks arrays as in-use so they are retained during Sweep.
// Nil entries are tolerated and skipped.
func Pin(s ...*Array) {
	for _, arr := range s {
		if arr == nil {
			continue
		}
		arr.pinned = true
	}
}
// Unpin marks arrays as no longer in-use, allowing Sweep to free them.
// Nil entries are tolerated and skipped.
func Unpin(s ...*Array) {
	for _, arr := range s {
		if arr == nil {
			continue
		}
		arr.pinned = false
	}
}
// Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly
// free them when there are no other references, including dependencies in the graph.
func Sweep() {
	// Compact the tracking list in place: pinned live arrays are kept,
	// everything else that is still valid is released. Arrays whose
	// handle has already been cleared simply drop out of the list.
	kept := arrays[:0]
	for _, arr := range arrays {
		switch {
		case arr.pinned && arr.Valid():
			kept = append(kept, arr)
		case arr.Valid():
			C.mlx_array_free(arr.ctx)
			arr.ctx.ctx = nil
		}
	}
	arrays = kept
}
// misc. utilities
func (t *Array) Valid() bool {
@@ -159,7 +173,10 @@ func (t *Array) String() string {
}
func (t *Array) LogValue() slog.Value {
attrs := []slog.Attr{slog.Any("", t.desc)}
attrs := []slog.Attr{
slog.String("name", t.name),
slog.Bool("pinned", t.pinned),
}
if t.Valid() {
attrs = append(attrs,
slog.Any("dtype", t.DType()),
@@ -238,37 +255,15 @@ func (t Array) Save(name string) error {
return nil
}
func Free(s ...*Array) (n int) {
now := time.Now()
defer func() {
if n > 0 {
logutil.Trace("Freed tensors", "num_bytes", PrettyBytes(n), "took", time.Since(now))
}
}()
// LogArrays logs all live arrays, sorted by size
func LogArrays() {
sort.Slice(arrays, func(i, j int) bool {
return arrays[i].NumBytes() > arrays[j].NumBytes()
})
free := make([]*Array, 0, 8192)
fn := func(t *Array) {
if t.Valid() {
t.desc.numRefs--
if t.desc.numRefs <= 0 {
free = append(free, t.desc.inputs...)
logutil.Trace("Free", "t", t)
n += t.NumBytes()
C.mlx_array_free(t.ctx)
t.ctx.ctx = nil
}
}
for _, t := range arrays {
nb := t.NumBytes()
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s %v", t.name, t.DType(), PrettyBytes(nb), t.Dims()))
}
for _, t := range s {
fn(t)
}
for len(free) > 0 {
tail := free[len(free)-1]
free = free[:len(free)-1]
fn(tail)
}
return n
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
}

View File

@@ -20,7 +20,7 @@ func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *A
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
out := New("FAST_SDPA", query, key, value, mask, sinks)
out := New("FAST_SDPA")
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
return out
}
@@ -31,7 +31,7 @@ type LayerNorm struct {
}
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_LAYERNORM", x)
out := New("FAST_LAYERNORM")
C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
return out
}
@@ -41,7 +41,7 @@ type RMSNorm struct {
}
func (r RMSNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_RMSNORM", x)
out := New("FAST_RMSNORM")
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}
@@ -55,7 +55,7 @@ type RoPE struct {
func (r RoPE) Forward(t *Array, offset int) *Array {
freqs := New("")
out := New("FAST_ROPE", t, freqs)
out := New("FAST_ROPE")
C.mlx_fast_rope(
&out.ctx,
t.ctx,

View File

@@ -37,7 +37,9 @@ func Load(path string) iter.Seq2[string, *Array] {
}
name := C.GoString(key)
if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
arr := New(name)
arr.ctx = value
if !yield(name, arr) {
break
}
}

View File

@@ -10,43 +10,43 @@ import (
)
func (t *Array) Abs() *Array {
out := New("ABS", t)
out := New("ABS")
C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Add(other *Array) *Array {
out := New("ADD", t, other)
out := New("ADD")
C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
out := New("ADDMM", t, a, b)
out := New("ADDMM")
C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
return out
}
func (t *Array) Argmax(axis int, keepDims bool) *Array {
out := New("ARGMAX", t)
out := New("ARGMAX")
C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
out := New("ARGPARTITION", t)
out := New("ARGPARTITION")
C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) ArgsortAxis(axis int) *Array {
out := New("ARGSORT_AXIS", t)
out := New("ARGSORT_AXIS")
C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) AsType(dtype DType) *Array {
out := New("AS_TYPE", t)
out := New("AS_TYPE")
C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
return out
}
@@ -62,7 +62,7 @@ func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
cStrides[i] = C.int64_t(s)
}
out := New("AS_STRIDED", t)
out := New("AS_STRIDED")
C.mlx_as_strided(
&out.ctx, t.ctx,
unsafe.SliceData(cShape), C.size_t(len(shape)),
@@ -82,31 +82,31 @@ func (t *Array) Concatenate(axis int, others ...*Array) *Array {
C.mlx_vector_array_append_value(vector, other.ctx)
}
out := New("CONCATENATE", s...)
out := New("CONCATENATE")
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Divide(other *Array) *Array {
out := New("DIVIDE", t, other)
out := New("DIVIDE")
C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) ExpandDims(axis int) *Array {
out := New("EXPAND_DIMS", t)
out := New("EXPAND_DIMS")
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Flatten(startAxis, endAxis int) *Array {
out := New("FLATTEN", t)
out := New("FLATTEN")
C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
return out
}
func (t *Array) FloorDivide(other *Array) *Array {
out := New("FLOOR_DIVIDE", t, other)
out := New("FLOOR_DIVIDE")
C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
@@ -118,43 +118,43 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
if rhs == nil {
rhs = New("")
}
out := New("GATHER_MM", t, other, lhs, rhs)
out := New("GATHER_MM")
C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
return out
}
func (t *Array) Logsumexp(keepDims bool) *Array {
out := New("LOGSUMEXP", t)
out := New("LOGSUMEXP")
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL", t, other)
out := New("MATMUL")
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Multiply(other *Array) *Array {
out := New("MULTIPLY", t, other)
out := New("MULTIPLY")
C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Negative() *Array {
out := New("NEGATIVE", t)
out := New("NEGATIVE")
C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Power(exponent *Array) *Array {
out := New("POWER", t, exponent)
out := New("POWER")
C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
return out
}
func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
out := New("PUT_ALONG_AXIS", t, indices, values)
out := New("PUT_ALONG_AXIS")
C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
}
@@ -165,25 +165,25 @@ func (t *Array) Reshape(axes ...int) *Array {
cAxes[i] = C.int(axes[i])
}
out := New("RESHAPE", t)
out := New("RESHAPE")
C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}
func (t *Array) Sigmoid() *Array {
out := New("SIGMOID", t)
out := New("SIGMOID")
C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Sqrt() *Array {
out := New("SQRT", t)
out := New("SQRT")
C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Squeeze(axis int) *Array {
out := New("SQUEEZE", t)
out := New("SQUEEZE")
C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
@@ -198,37 +198,37 @@ func (t *Array) StackAxis(axis int, others ...*Array) *Array {
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
out := New("STACK_AXIS", append(others, t)...)
out := New("STACK_AXIS")
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Subtract(other *Array) *Array {
out := New("SUBTRACT", t, other)
out := New("SUBTRACT")
C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) SumAxis(axis int, keepDims bool) *Array {
out := New("SUM_AXIS", t)
out := New("SUM_AXIS")
C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) TakeAxis(indices *Array, axis int) *Array {
out := New("TAKE_AXIS", t, indices)
out := New("TAKE_AXIS")
C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
out := New("TAKE_ALONG_AXIS", t, indices)
out := New("TAKE_ALONG_AXIS")
C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Tanh() *Array {
out := New("TANH", t)
out := New("TANH")
C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
@@ -239,7 +239,7 @@ func (t *Array) Transpose(axes ...int) *Array {
cAxes[i] = C.int(axis)
}
out := New("TRANSPOSE", t)
out := New("TRANSPOSE")
C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}

View File

@@ -41,14 +41,12 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
optDtype := C.mlx_optional_dtype{has_value: false}
inputs := []*Array{w, scales}
var b C.mlx_array
if biases != nil {
b = biases.ctx
inputs = append(inputs, biases)
}
out := New("DEQUANTIZE", inputs...)
out := New("DEQUANTIZE")
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx)
return out
}
@@ -59,14 +57,12 @@ func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bit
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
inputs := []*Array{x, w, scales}
var b C.mlx_array
if biases != nil {
b = biases.ctx
inputs = append(inputs, biases)
}
out := New("QUANTIZED_MATMUL", inputs...)
out := New("QUANTIZED_MATMUL")
C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx)
return out
}
@@ -77,22 +73,18 @@ func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, trans
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
inputs := []*Array{x, w, scales}
var b, lhs, rhs C.mlx_array
if biases != nil {
b = biases.ctx
inputs = append(inputs, biases)
}
if lhsIndices != nil {
lhs = lhsIndices.ctx
inputs = append(inputs, lhsIndices)
}
if rhsIndices != nil {
rhs = rhsIndices.ctx
inputs = append(inputs, rhsIndices)
}
out := New("GATHER_QMM", inputs...)
out := New("GATHER_QMM")
C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx)
return out
}
@@ -104,7 +96,7 @@ func Tile(a *Array, reps []int32) *Array {
for i, r := range reps {
cReps[i] = C.int(r)
}
out := New("TILE", a)
out := New("TILE")
C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx)
return out
}
@@ -116,7 +108,7 @@ func Tri(n, m int32, k int) *Array {
}
func Where(condition, a, b *Array) *Array {
out := New("WHERE", condition, a, b)
out := New("WHERE")
C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
@@ -131,7 +123,7 @@ func Stack(arrays []*Array, axis int) *Array {
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
out := New("STACK", arrays...)
out := New("STACK")
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
@@ -153,13 +145,13 @@ func Take(a *Array, indices *Array, axis int) *Array {
}
func RSqrt(a *Array) *Array {
out := New("RSQRT", a)
out := New("RSQRT")
C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Mean(a *Array, axis int, keepDims bool) *Array {
out := New("MEAN_AXIS", a)
out := New("MEAN_AXIS")
C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
@@ -235,7 +227,7 @@ func SliceStartStop(a *Array, start, stop []int32) *Array {
cStop[i] = C.int(stop[i])
cStrides[i] = 1
}
out := New("SLICE", a)
out := New("SLICE")
C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx)
return out
}
@@ -257,7 +249,7 @@ func SiLU(a *Array) *Array {
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
freqs := New("")
out := New("FAST_ROPE", x, freqs)
out := New("FAST_ROPE")
C.mlx_fast_rope(
&out.ctx,
x.ctx,
@@ -289,13 +281,13 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
out := New("FAST_SDPA", q, k, v, mask, sinks)
out := New("FAST_SDPA")
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
return out
}
func RMSNormFn(x, weight *Array, eps float32) *Array {
out := New("FAST_RMSNORM", x)
out := New("FAST_RMSNORM")
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}
@@ -322,7 +314,7 @@ func scalarWithDtype(s float32, a *Array) C.mlx_array {
func AddScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("ADD_SCALAR", a)
out := New("ADD_SCALAR")
C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
@@ -330,7 +322,7 @@ func AddScalar(a *Array, s float32) *Array {
func MulScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("MUL_SCALAR", a)
out := New("MUL_SCALAR")
C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
@@ -338,7 +330,7 @@ func MulScalar(a *Array, s float32) *Array {
func DivScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("DIV_SCALAR", a)
out := New("DIV_SCALAR")
C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out

View File

@@ -7,7 +7,7 @@ import "C"
func (t *Array) Categorical(axis int) *Array {
key := New("")
out := New("", t, key)
out := New("")
C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
return out
}

View File

@@ -61,7 +61,7 @@ func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
func (t *Array) Slice(slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE", t)
out := New("SLICE")
C.mlx_slice(
&out.ctx, t.ctx,
unsafe.SliceData(starts), C.size_t(len(starts)),
@@ -74,7 +74,7 @@ func (t *Array) Slice(slices ...slice) *Array {
func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE_UPDATE", t, other)
out := New("SLICE_UPDATE")
C.mlx_slice_update(
&out.ctx, t.ctx, other.ctx,
unsafe.SliceData(starts), C.size_t(len(starts)),

View File

@@ -78,8 +78,21 @@ func New(root *model.Root) (Model, error) {
return fn(root)
}
// Weights returns the model's LoadWeights method, which encapsulates all
// weight assignment and post-processing (MLA absorption, expert stacking).
// Weights returns a function that loads model weights, then pins all
// arrays reachable from the model struct and sweeps everything else.
func Weights(m Model) func(map[string]*mlx.Array) error {
return m.LoadWeights
return func(tensors map[string]*mlx.Array) error {
if err := m.LoadWeights(tensors); err != nil {
return err
}
collected := mlx.Collect(m)
for _, arr := range collected {
mlx.Pin(arr)
}
mlx.Sweep()
mlx.Eval(collected...)
return nil
}
}

View File

@@ -4,10 +4,12 @@ package mlxrunner
import (
"bytes"
"context"
"errors"
"log/slog"
"time"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
@@ -45,8 +47,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
slog.Info("Prompt processing progress", "processed", processed, "total", total)
for total-processed > 1 {
n := min(2<<10, total-processed-1)
temp := r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
defer mlx.Free(temp)
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
mlx.Sweep()
mlx.Eval(func() []*mlx.Array {
s := make([]*mlx.Array, 2*len(caches))
for i, c := range caches {
@@ -65,11 +67,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
logprobs := logits.Subtract(logits.Logsumexp(true))
return request.Sample(logprobs), logprobs
sample := request.Sample(logprobs)
mlx.Pin(sample, logprobs)
mlx.Sweep()
mlx.AsyncEval(sample, logprobs)
return sample, logprobs
}
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
mlx.AsyncEval(sample, logprobs)
var b bytes.Buffer
@@ -78,7 +85,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
outputs := make([]int32, 0, request.Options.MaxTokens)
for i := range request.Options.MaxTokens {
nextSample, nextLogprobs := step(sample)
mlx.AsyncEval(nextSample, nextLogprobs)
if i == 0 {
slog.Info("Prompt processing progress", "processed", total, "total", total)
@@ -91,6 +97,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
outputs = append(outputs, output)
if r.Tokenizer.IsEOS(output) {
mlx.Unpin(nextSample, nextLogprobs)
final.Token = int(output)
final.DoneReason = 0
final.CompletionTokens = i
@@ -102,7 +109,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
Token: int(output),
}
mlx.Free(sample, logprobs)
mlx.Unpin(sample, logprobs)
if i%256 == 0 {
mlx.ClearCache()
}
@@ -110,10 +117,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
sample, logprobs = nextSample, nextLogprobs
}
mlx.Free(sample, logprobs)
mlx.Unpin(sample, logprobs)
final.CompletionTokensDuration = time.Since(now)
request.Responses <- final
r.InsertCache(append(inputs, outputs...), caches)
mlx.Sweep()
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
mlx.LogArrays()
}
return nil
}

View File

@@ -401,9 +401,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
if m.NormScaled == nil {
return fmt.Errorf("missing precomputed final norm weight")
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}

View File

@@ -702,9 +702,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
}
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}

View File

@@ -235,9 +235,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.Layers[i] = layer
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}

View File

@@ -252,9 +252,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.Layers[i] = layer
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}