mlxrunner: Refcount pinned tensors

Otherwise, it is error prone to manage multiple components working
with the same tensor.
This commit is contained in:
Jesse Gross
2026-03-02 12:48:02 -08:00
parent a3093cd5e5
commit c1e3ef4bcc

View File

@@ -20,7 +20,7 @@ import (
type Array struct {
ctx C.mlx_array
name string
pinned bool
pinned int
}
var arrays []*Array
@@ -129,7 +129,7 @@ func (t *Array) Clone() *Array {
func Pin(s ...*Array) {
for _, t := range s {
if t != nil {
t.pinned = true
t.pinned++
}
}
}
@@ -138,7 +138,7 @@ func Pin(s ...*Array) {
func Unpin(s ...*Array) {
for _, t := range s {
if t != nil {
t.pinned = false
t.pinned--
}
}
}
@@ -148,7 +148,7 @@ func Unpin(s ...*Array) {
func Sweep() {
n := 0
for _, t := range arrays {
if t.pinned && t.Valid() {
if t.pinned > 0 && t.Valid() {
arrays[n] = t
n++
} else if t.Valid() {
@@ -175,7 +175,7 @@ func (t *Array) String() string {
func (t *Array) LogValue() slog.Value {
attrs := []slog.Attr{
slog.String("name", t.name),
slog.Bool("pinned", t.pinned),
slog.Int("pinned", t.pinned),
}
if t.Valid() {
attrs = append(attrs,