mirror of
https://github.com/ollama/ollama.git
synced 2026-03-08 23:04:13 -05:00
mlxrunner: Refcount pinned tensors
Otherwise, it is error prone to manage multiple components working with the same tensor.
This commit is contained in:
@@ -20,7 +20,7 @@ import (
|
||||
type Array struct {
|
||||
ctx C.mlx_array
|
||||
name string
|
||||
pinned bool
|
||||
pinned int
|
||||
}
|
||||
|
||||
var arrays []*Array
|
||||
@@ -129,7 +129,7 @@ func (t *Array) Clone() *Array {
|
||||
func Pin(s ...*Array) {
|
||||
for _, t := range s {
|
||||
if t != nil {
|
||||
t.pinned = true
|
||||
t.pinned++
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -138,7 +138,7 @@ func Pin(s ...*Array) {
|
||||
func Unpin(s ...*Array) {
|
||||
for _, t := range s {
|
||||
if t != nil {
|
||||
t.pinned = false
|
||||
t.pinned--
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -148,7 +148,7 @@ func Unpin(s ...*Array) {
|
||||
func Sweep() {
|
||||
n := 0
|
||||
for _, t := range arrays {
|
||||
if t.pinned && t.Valid() {
|
||||
if t.pinned > 0 && t.Valid() {
|
||||
arrays[n] = t
|
||||
n++
|
||||
} else if t.Valid() {
|
||||
@@ -175,7 +175,7 @@ func (t *Array) String() string {
|
||||
func (t *Array) LogValue() slog.Value {
|
||||
attrs := []slog.Attr{
|
||||
slog.String("name", t.name),
|
||||
slog.Bool("pinned", t.pinned),
|
||||
slog.Int("pinned", t.pinned),
|
||||
}
|
||||
if t.Valid() {
|
||||
attrs = append(attrs,
|
||||
|
||||
Reference in New Issue
Block a user