Files
ollama/x/internal/mlxthread/thread_test.go
Daniel Hiltgen 534342e7e2 Update MLX and MLX-C with threading fixes (#15845)
* Update MLX and MLX-C

* Run MLX CGO work on a locked OS thread

MLX now relies on OS-thread-local execution state for streams, encoders, and caches. Add an mlxthread executor backed by runtime.LockOSThread and route runner initialization, model load, inference, status memory reads, and cleanup through the worker so Go goroutine migration cannot split MLX state across native threads.

Also stop caching default MLX streams before the runner owns the thread and add worker/threaded MLX regression tests.

* mlx: use common status writer

* mlx: bundle missing libjaccl on arm64

Inspired by #15793

* review comments
2026-05-03 10:03:14 -07:00

352 lines
7.1 KiB
Go

package mlxthread
import (
"context"
"errors"
"reflect"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
)
// TestDoRunsInOrder verifies that jobs submitted one after another from a
// single goroutine execute in submission order.
func TestDoRunsInOrder(t *testing.T) {
	thread, err := Start("test", nil)
	if err != nil {
		t.Fatal(err)
	}
	defer thread.Stop(context.Background(), nil)

	var got []int
	// Since Go 1.22 every iteration has its own i, so the closure may capture
	// it directly without a shadowing copy. got is only read after the loop,
	// once every Do call has returned.
	for i := range 5 {
		if err := thread.Do(context.Background(), func() error {
			got = append(got, i)
			return nil
		}); err != nil {
			t.Fatal(err)
		}
	}
	if want := []int{0, 1, 2, 3, 4}; !reflect.DeepEqual(got, want) {
		t.Fatalf("got %v, want %v", got, want)
	}
}
// TestDoPropagatesPanicToCaller checks that a panic raised inside a job is
// re-raised in the goroutine that called Do rather than being swallowed.
func TestDoPropagatesPanicToCaller(t *testing.T) {
	thread, err := Start("test", nil)
	if err != nil {
		t.Fatal(err)
	}
	defer thread.Stop(context.Background(), nil)

	// Capture whatever escapes Do; a nil result means Do returned normally.
	recovered := func() (r any) {
		defer func() { r = recover() }()
		// The error is irrelevant here: Do is expected to panic, not return.
		_ = thread.Do(context.Background(), func() error {
			panic("boom")
		})
		return nil
	}()

	if recovered != "boom" {
		t.Fatalf("got panic %v, want boom", recovered)
	}
}
// TestDoCancelsBeforeJobStarts verifies that a Do call with a canceled
// context, issued while the worker is busy with another job, returns
// context.Canceled and never runs its job.
func TestDoCancelsBeforeJobStarts(t *testing.T) {
	thread, err := Start("test", nil)
	if err != nil {
		t.Fatal(err)
	}
	defer thread.Stop(context.Background(), nil)

	running := make(chan struct{})
	release := make(chan struct{})
	// releaseBlocker unblocks the occupying job exactly once. It is also
	// deferred (and so runs before the deferred Stop) so that an early
	// t.Fatal cannot leave the worker parked on release while Stop waits
	// for it.
	releaseBlocker := sync.OnceFunc(func() { close(release) })
	defer releaseBlocker()

	errCh := make(chan error, 1)
	go func() {
		errCh <- thread.Do(context.Background(), func() error {
			close(running)
			<-release
			return nil
		})
	}()
	<-running

	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	// Jobs execute on the worker goroutine, where t.Fatal must not be
	// called; record a violation instead and report it from here.
	var ran atomic.Bool
	err = thread.Do(ctx, func() error {
		ran.Store(true)
		return nil
	})
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("got %v, want %v", err, context.Canceled)
	}

	releaseBlocker()
	if err := <-errCh; err != nil {
		t.Fatal(err)
	}
	// Jobs run serially, so once this no-op completes anything that was
	// (wrongly) queued ahead of it has had its chance to run.
	if err := thread.Do(context.Background(), func() error { return nil }); err != nil {
		t.Fatal(err)
	}
	if ran.Load() {
		t.Fatal("canceled job should not run")
	}
}
// TestAlreadyCanceledContextDoesNotEnqueue verifies that Do and Stop called
// with an already-canceled context fail with context.Canceled without
// running their job or cleanup, and leave the thread usable.
func TestAlreadyCanceledContextDoesNotEnqueue(t *testing.T) {
	t.Run("Do", func(t *testing.T) {
		thread, err := Start("test", nil)
		if err != nil {
			t.Fatal(err)
		}
		defer thread.Stop(context.Background(), nil)

		ctx, cancel := context.WithCancel(context.Background())
		cancel()
		// Atomic because a wrongly enqueued job would write this from the
		// worker goroutine.
		var ran atomic.Bool
		err = thread.Do(ctx, func() error {
			ran.Store(true)
			return nil
		})
		if !errors.Is(err, context.Canceled) {
			t.Fatalf("got %v, want %v", err, context.Canceled)
		}
		// Jobs run serially; after this no-op completes, any job queued
		// before it has run too, so the check below cannot miss it.
		if err := thread.Do(context.Background(), func() error { return nil }); err != nil {
			t.Fatal(err)
		}
		if ran.Load() {
			t.Fatal("canceled job ran")
		}
	})
	t.Run("Stop", func(t *testing.T) {
		thread, err := Start("test", nil)
		if err != nil {
			t.Fatal(err)
		}
		defer thread.Stop(context.Background(), nil)

		ctx, cancel := context.WithCancel(context.Background())
		cancel()
		// Atomic for the same reason as in the Do subtest.
		var cleaned atomic.Bool
		err = thread.Stop(ctx, func() {
			cleaned.Store(true)
		})
		if !errors.Is(err, context.Canceled) {
			t.Fatalf("got %v, want %v", err, context.Canceled)
		}
		// Doubles as a queue flush before inspecting cleaned.
		if err := thread.Do(context.Background(), func() error { return nil }); err != nil {
			t.Fatalf("thread did not accept work after canceled Stop: %v", err)
		}
		if cleaned.Load() {
			t.Fatal("cleanup ran for canceled stop")
		}
	})
}
// TestCallReturnsValue checks that Call hands the job's result back to the
// caller together with a nil error.
func TestCallReturnsValue(t *testing.T) {
	thread, err := Start("test", nil)
	if err != nil {
		t.Fatal(err)
	}
	defer thread.Stop(context.Background(), nil)

	const want = 42
	got, err := Call(context.Background(), thread, func() (int, error) { return want, nil })
	switch {
	case err != nil:
		t.Fatal(err)
	case got != want:
		t.Fatalf("got %d, want 42", got)
	}
}
// TestDoRunsConcurrentlySubmittedWorkSerially submits jobs from many
// goroutines at once and checks that no two jobs ever overlap and that every
// submitted job runs exactly once.
func TestDoRunsConcurrentlySubmittedWorkSerially(t *testing.T) {
	thread, err := Start("test", nil)
	if err != nil {
		t.Fatal(err)
	}
	defer thread.Stop(context.Background(), nil)

	// Raise parallelism so overlapping jobs would actually be observable; the
	// inner call runs now and its previous value is restored on return.
	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(8))

	const (
		submitters   = 16
		perSubmitter = 64
	)
	var (
		inFlight  atomic.Int32
		completed atomic.Int64
		wg        sync.WaitGroup
	)
	failures := make(chan error, submitters)

	job := func() error {
		if inFlight.Add(1) != 1 {
			return errors.New("thread executed jobs concurrently")
		}
		// Yield mid-job to widen the window in which a concurrent job could
		// slip in if serialization were broken.
		runtime.Gosched()
		completed.Add(1)
		if inFlight.Add(-1) != 0 {
			return errors.New("thread active count did not return to zero")
		}
		return nil
	}

	wg.Add(submitters)
	for range submitters {
		go func() {
			defer wg.Done()
			for range perSubmitter {
				if err := thread.Do(context.Background(), job); err != nil {
					failures <- err
					return
				}
			}
		}()
	}
	wg.Wait()
	close(failures)

	if err, ok := <-failures; ok {
		t.Fatal(err)
	}
	if got, want := completed.Load(), int64(submitters*perSubmitter); got != want {
		t.Fatalf("got %d jobs, want %d", got, want)
	}
}
// TestStopRunsCleanupAndRejectsWork verifies that Stop runs its cleanup
// exactly once, that a second Stop succeeds without re-running cleanup, and
// that Do afterwards returns ErrStopped without executing the job.
func TestStopRunsCleanupAndRejectsWork(t *testing.T) {
	thread, err := Start("test", nil)
	if err != nil {
		t.Fatal(err)
	}
	cleaned := 0
	if err := thread.Stop(context.Background(), func() {
		cleaned++
	}); err != nil {
		t.Fatal(err)
	}
	if cleaned != 1 {
		t.Fatalf("cleanup ran %d times, want 1", cleaned)
	}
	if err := thread.Stop(context.Background(), func() {
		cleaned++
	}); err != nil {
		t.Fatal(err)
	}
	if cleaned != 1 {
		t.Fatalf("cleanup ran %d times after second Stop, want 1", cleaned)
	}

	// If the job were wrongly executed it would run on the worker goroutine,
	// where t.Fatal must not be called; record it and report from here.
	var ran atomic.Bool
	err = thread.Do(context.Background(), func() error {
		ran.Store(true)
		return nil
	})
	if !errors.Is(err, ErrStopped) {
		t.Fatalf("got %v, want %v", err, ErrStopped)
	}
	if ran.Load() {
		t.Fatal("job should not run after stop")
	}
}
// TestStopCanceledBeforeEnqueueCanBeRetried verifies that a Stop whose
// context expires while the worker is busy returns the context error without
// running cleanup, that the thread keeps accepting work afterwards, and that
// a later Stop succeeds and runs its cleanup.
func TestStopCanceledBeforeEnqueueCanBeRetried(t *testing.T) {
	thread, err := Start("test", nil)
	if err != nil {
		t.Fatal(err)
	}
	defer thread.Stop(context.Background(), nil)

	running := make(chan struct{})
	release := make(chan struct{})
	// Deferred as well (running before the deferred Stop) so an early
	// t.Fatal cannot leave the worker blocked and hang Stop.
	releaseBlocker := sync.OnceFunc(func() { close(release) })
	defer releaseBlocker()

	errCh := make(chan error, 1)
	go func() {
		errCh <- thread.Do(context.Background(), func() error {
			close(running)
			<-release
			return nil
		})
	}()
	<-running

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()
	// Atomic because a cleanup that was wrongly queued would write this from
	// the worker goroutine after the blocker is released.
	var cleanupRan atomic.Bool
	err = thread.Stop(ctx, func() {
		cleanupRan.Store(true)
	})
	if !errors.Is(err, context.DeadlineExceeded) {
		t.Fatalf("got %v, want %v", err, context.DeadlineExceeded)
	}
	if cleanupRan.Load() {
		t.Fatal("cleanup ran even though stop was not enqueued")
	}

	releaseBlocker()
	if err := <-errCh; err != nil {
		t.Fatal(err)
	}
	if err := thread.Do(context.Background(), func() error { return nil }); err != nil {
		t.Fatalf("thread did not accept work after canceled Stop: %v", err)
	}

	retryCleanupRan := false
	if err := thread.Stop(context.Background(), func() {
		retryCleanupRan = true
	}); err != nil {
		t.Fatal(err)
	}
	if !retryCleanupRan {
		t.Fatal("cleanup did not run on retried Stop")
	}
}
// TestStopWaitsForActiveWorkBeforeCleanup verifies that Stop defers its
// cleanup until the currently running job finishes, and that new work is
// rejected with ErrStopped as soon as Stop has begun.
func TestStopWaitsForActiveWorkBeforeCleanup(t *testing.T) {
	thread, err := Start("test", nil)
	if err != nil {
		t.Fatal(err)
	}

	running := make(chan struct{})
	release := make(chan struct{})
	// Deferred so that an early t.Fatal still lets the blocked job, and the
	// Stop waiting behind it, finish instead of leaking goroutines.
	releaseBlocker := sync.OnceFunc(func() { close(release) })
	defer releaseBlocker()

	jobErr := make(chan error, 1)
	go func() {
		jobErr <- thread.Do(context.Background(), func() error {
			close(running)
			<-release
			return nil
		})
	}()
	<-running

	cleaned := make(chan struct{})
	stopErr := make(chan error, 1)
	go func() {
		stopErr <- thread.Stop(context.Background(), func() {
			close(cleaned)
		})
	}()

	// Give Stop a moment to begin; cleanup must not run while the job is
	// still active.
	select {
	case <-cleaned:
		t.Fatal("cleanup ran before active job completed")
	case <-time.After(10 * time.Millisecond):
	}

	// Bound this call: if the Stop goroutine has not begun yet, Do could wait
	// behind the blocked job, which is only released after Do returns. That
	// would deadlock the test instead of failing it.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	err = thread.Do(ctx, func() error {
		return errors.New("work should be rejected once Stop starts")
	})
	if !errors.Is(err, ErrStopped) {
		t.Fatalf("got %v, want %v", err, ErrStopped)
	}

	releaseBlocker()
	if err := <-jobErr; err != nil {
		t.Fatal(err)
	}
	if err := <-stopErr; err != nil {
		t.Fatal(err)
	}
	select {
	case <-cleaned:
	default:
		t.Fatal("cleanup did not run")
	}
}