Files
ollama/x/mlxrunner/mlx/thread_test.go
Daniel Hiltgen 534342e7e2 Update MLX and MLX-C with threading fixes (#15845)
* Update MLX and MLX-C

* Run MLX CGO work on a locked OS thread

MLX now relies on OS-thread-local execution state for streams, encoders, and caches. Add an mlxthread executor backed by runtime.LockOSThread and route runner initialization, model load, inference, status memory reads, and cleanup through the worker so Go goroutine migration cannot split MLX state across native threads.

Also stop caching default MLX streams before the runner owns the thread and add worker/threaded MLX regression tests.

* mlx: use common status writer

* mlx: bundle missing libjaccl on arm64

Inspired by #15793

* review comments
2026-05-03 10:03:14 -07:00

105 lines
1.7 KiB
Go

package mlx
import (
"context"
"runtime"
"sync"
"testing"
"github.com/ollama/ollama/x/internal/mlxthread"
)
// skipIfNoMLX skips the calling test when CheckInit reports that the MLX
// runtime cannot be initialized; otherwise it returns and the test proceeds.
func skipIfNoMLX(t *testing.T) {
	t.Helper()
	initErr := CheckInit()
	if initErr == nil {
		return
	}
	t.Skipf("MLX not available: %v", initErr)
}
// startMLXThread starts a dedicated mlxthread worker named "mlx-test".
// The worker's init hook runs CheckInit and, when a GPU is available, makes
// the GPU the default MLX device. If the worker cannot be started, the
// calling test is skipped rather than failed.
func startMLXThread(t *testing.T) *mlxthread.Thread {
	t.Helper()

	initOnThread := func() error {
		if err := CheckInit(); err != nil {
			return err
		}
		if !GPUIsAvailable() {
			return nil
		}
		SetDefaultDeviceGPU()
		return nil
	}

	worker, startErr := mlxthread.Start("mlx-test", initOnThread)
	if startErr != nil {
		t.Skipf("MLX not available: %v", startErr)
	}
	return worker
}
// stopMLXThread shuts down the given worker. Before stopping, it hands the
// worker a final cleanup step that sweeps MLX arrays, clears the MLX cache,
// and resets the default-stream cache. Any error from Stop fails the test.
func stopMLXThread(t *testing.T, thread *mlxthread.Thread) {
	t.Helper()

	cleanup := func() {
		Sweep()
		ClearCache()
		resetDefaultStreamCache()
	}

	err := thread.Stop(context.Background(), cleanup)
	if err != nil {
		t.Fatal(err)
	}
}
// withMLXThread runs fn once via Do on a freshly started MLX worker, then
// stops the worker (see stopMLXThread). An error returned by Do fails the test.
//
// NOTE(review): fn presumably executes on the worker's goroutine rather than
// the test goroutine. The testing package requires t.Fatal/t.FailNow to be
// called from the test goroutine, so fn should report problems with
// t.Error/t.Errorf instead — confirm against mlxthread.Thread.Do.
func withMLXThread(t *testing.T, fn func()) {
	t.Helper()

	thread := startMLXThread(t)
	defer stopMLXThread(t, thread)

	run := func() error {
		fn()
		return nil
	}
	if err := thread.Do(context.Background(), run); err != nil {
		t.Fatal(err)
	}
}
// TestThreadedMLXOperations submits MLX work to a single MLX worker from
// several goroutines at once. Every MLX call (array creation, matmul,
// evaluation, sweep, and cache clear) is routed through thread.Do, so no
// submitting goroutine invokes MLX directly. GOMAXPROCS is raised for the
// duration of the test so the submitting goroutines can run in parallel.
func TestThreadedMLXOperations(t *testing.T) {
	thread := startMLXThread(t)
	defer stopMLXThread(t, thread)

	oldProcs := runtime.GOMAXPROCS(8)
	defer runtime.GOMAXPROCS(oldProcs)

	const goroutines = 8
	const iterations = 8

	var wg sync.WaitGroup
	// Each goroutine sends at most one error and then returns, so a buffer of
	// size goroutines guarantees no sender ever blocks.
	errCh := make(chan error, goroutines)
	for range goroutines {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for range iterations {
				if err := thread.Do(context.Background(), func() error {
					a := FromValues([]float32{1, 2, 3, 4}, 2, 2)
					b := Matmul(a, a)
					AsyncEval(b)
					Eval(b)
					Sweep()
					ClearCache()
					return nil
				}); err != nil {
					errCh <- err
					return
				}
			}
		}()
	}
	wg.Wait()
	close(errCh)

	// Report every failure instead of stopping at the first one: t.Fatal here
	// would exit on the first error and drop any others still in the channel.
	for err := range errCh {
		t.Error(err)
	}
}