mlxrunner: Fix memory leaks with pin/sweep lifecycle management

The previous approach tracked array lifecycles through reference
counting, where each array recorded its inputs and a reference count
that was decremented as dependents were freed. This is not really
necessary, as MLX tracks references internally. It is also
error-prone, as it is easy to create new arrays and forget to free
them when the Go variable goes out of scope.

Instead, we can pin just the arrays we want to keep (typically outputs
and specific intermediates, like the cache). All other arrays are
freed by default when we run sweep. This avoids most causes of memory
leaks while still giving us the freedom to keep what we need.
This commit is contained in:
Jesse Gross
2026-02-19 15:05:35 -08:00
parent 0ade9205cc
commit 5daf59cc66
14 changed files with 159 additions and 151 deletions

View File

@@ -401,9 +401,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
if m.NormScaled == nil {
return fmt.Errorf("missing precomputed final norm weight")
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}

View File

@@ -702,9 +702,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
}
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}

View File

@@ -235,9 +235,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.Layers[i] = layer
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}

View File

@@ -252,9 +252,6 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.Layers[i] = layer
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}