mirror of
https://github.com/ollama/ollama.git
synced 2026-05-07 00:22:43 -05:00
* Update MLX and MLX-C * Run MLX CGO work on a locked OS thread MLX now relies on OS-thread-local execution state for streams, encoders, and caches. Add an mlxthread executor backed by runtime.LockOSThread and route runner initialization, model load, inference, status memory reads, and cleanup through the worker so Go goroutine migration cannot split MLX state across native threads. Also stop caching default MLX streams before the runner owns the thread and add worker/threaded MLX regression tests. * mlx: use common status writer * mlx: bundle missing libjaccl on arm64 Inspired by #15793 * review comments
352 lines
7.1 KiB
Go
352 lines
7.1 KiB
Go
package mlxthread
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"reflect"
|
|
"runtime"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestDoRunsInOrder(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer thread.Stop(context.Background(), nil)
|
|
|
|
var got []int
|
|
for i := 0; i < 5; i++ {
|
|
i := i
|
|
if err := thread.Do(context.Background(), func() error {
|
|
got = append(got, i)
|
|
return nil
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
if want := []int{0, 1, 2, 3, 4}; !reflect.DeepEqual(got, want) {
|
|
t.Fatalf("got %v, want %v", got, want)
|
|
}
|
|
}
|
|
|
|
func TestDoPropagatesPanicToCaller(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer thread.Stop(context.Background(), nil)
|
|
|
|
defer func() {
|
|
if got := recover(); got != "boom" {
|
|
t.Fatalf("got panic %v, want boom", got)
|
|
}
|
|
}()
|
|
|
|
_ = thread.Do(context.Background(), func() error {
|
|
panic("boom")
|
|
})
|
|
}
|
|
|
|
func TestDoCancelsBeforeJobStarts(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer thread.Stop(context.Background(), nil)
|
|
|
|
running := make(chan struct{})
|
|
release := make(chan struct{})
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
errCh <- thread.Do(context.Background(), func() error {
|
|
close(running)
|
|
<-release
|
|
return nil
|
|
})
|
|
}()
|
|
<-running
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
err = thread.Do(ctx, func() error {
|
|
t.Fatal("canceled job should not run")
|
|
return nil
|
|
})
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Fatalf("got %v, want %v", err, context.Canceled)
|
|
}
|
|
|
|
close(release)
|
|
if err := <-errCh; err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func TestAlreadyCanceledContextDoesNotEnqueue(t *testing.T) {
|
|
t.Run("Do", func(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer thread.Stop(context.Background(), nil)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
ran := false
|
|
err = thread.Do(ctx, func() error {
|
|
ran = true
|
|
return nil
|
|
})
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Fatalf("got %v, want %v", err, context.Canceled)
|
|
}
|
|
if ran {
|
|
t.Fatal("canceled job ran")
|
|
}
|
|
})
|
|
|
|
t.Run("Stop", func(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer thread.Stop(context.Background(), nil)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
cleaned := false
|
|
err = thread.Stop(ctx, func() {
|
|
cleaned = true
|
|
})
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Fatalf("got %v, want %v", err, context.Canceled)
|
|
}
|
|
if cleaned {
|
|
t.Fatal("cleanup ran for canceled stop")
|
|
}
|
|
if err := thread.Do(context.Background(), func() error { return nil }); err != nil {
|
|
t.Fatalf("thread did not accept work after canceled Stop: %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestCallReturnsValue(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer thread.Stop(context.Background(), nil)
|
|
|
|
got, err := Call(context.Background(), thread, func() (int, error) {
|
|
return 42, nil
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got != 42 {
|
|
t.Fatalf("got %d, want 42", got)
|
|
}
|
|
}
|
|
|
|
func TestDoRunsConcurrentlySubmittedWorkSerially(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer thread.Stop(context.Background(), nil)
|
|
|
|
oldProcs := runtime.GOMAXPROCS(8)
|
|
defer runtime.GOMAXPROCS(oldProcs)
|
|
|
|
const goroutines = 16
|
|
const iterations = 64
|
|
|
|
var active atomic.Int32
|
|
var count atomic.Int64
|
|
var wg sync.WaitGroup
|
|
errCh := make(chan error, goroutines)
|
|
|
|
for range goroutines {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
for range iterations {
|
|
if err := thread.Do(context.Background(), func() error {
|
|
if got := active.Add(1); got != 1 {
|
|
return errors.New("thread executed jobs concurrently")
|
|
}
|
|
runtime.Gosched()
|
|
count.Add(1)
|
|
if got := active.Add(-1); got != 0 {
|
|
return errors.New("thread active count did not return to zero")
|
|
}
|
|
return nil
|
|
}); err != nil {
|
|
errCh <- err
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errCh)
|
|
|
|
for err := range errCh {
|
|
t.Fatal(err)
|
|
}
|
|
if got, want := count.Load(), int64(goroutines*iterations); got != want {
|
|
t.Fatalf("got %d jobs, want %d", got, want)
|
|
}
|
|
}
|
|
|
|
func TestStopRunsCleanupAndRejectsWork(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
cleaned := 0
|
|
if err := thread.Stop(context.Background(), func() {
|
|
cleaned++
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if cleaned != 1 {
|
|
t.Fatalf("cleanup ran %d times, want 1", cleaned)
|
|
}
|
|
|
|
if err := thread.Stop(context.Background(), func() {
|
|
cleaned++
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if cleaned != 1 {
|
|
t.Fatalf("cleanup ran %d times after second Stop, want 1", cleaned)
|
|
}
|
|
|
|
err = thread.Do(context.Background(), func() error {
|
|
t.Fatal("job should not run after stop")
|
|
return nil
|
|
})
|
|
if !errors.Is(err, ErrStopped) {
|
|
t.Fatalf("got %v, want %v", err, ErrStopped)
|
|
}
|
|
}
|
|
|
|
func TestStopCanceledBeforeEnqueueCanBeRetried(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer thread.Stop(context.Background(), nil)
|
|
|
|
running := make(chan struct{})
|
|
release := make(chan struct{})
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
errCh <- thread.Do(context.Background(), func() error {
|
|
close(running)
|
|
<-release
|
|
return nil
|
|
})
|
|
}()
|
|
<-running
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
|
defer cancel()
|
|
|
|
cleanupRan := false
|
|
err = thread.Stop(ctx, func() {
|
|
cleanupRan = true
|
|
})
|
|
if !errors.Is(err, context.DeadlineExceeded) {
|
|
t.Fatalf("got %v, want %v", err, context.DeadlineExceeded)
|
|
}
|
|
if cleanupRan {
|
|
t.Fatal("cleanup ran even though stop was not enqueued")
|
|
}
|
|
|
|
close(release)
|
|
if err := <-errCh; err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if err := thread.Do(context.Background(), func() error { return nil }); err != nil {
|
|
t.Fatalf("thread did not accept work after canceled Stop: %v", err)
|
|
}
|
|
|
|
cleanupRan = false
|
|
if err := thread.Stop(context.Background(), func() {
|
|
cleanupRan = true
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !cleanupRan {
|
|
t.Fatal("cleanup did not run on retried Stop")
|
|
}
|
|
}
|
|
|
|
func TestStopWaitsForActiveWorkBeforeCleanup(t *testing.T) {
|
|
thread, err := Start("test", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
running := make(chan struct{})
|
|
release := make(chan struct{})
|
|
jobErr := make(chan error, 1)
|
|
go func() {
|
|
jobErr <- thread.Do(context.Background(), func() error {
|
|
close(running)
|
|
<-release
|
|
return nil
|
|
})
|
|
}()
|
|
<-running
|
|
|
|
cleaned := make(chan struct{})
|
|
stopErr := make(chan error, 1)
|
|
go func() {
|
|
stopErr <- thread.Stop(context.Background(), func() {
|
|
close(cleaned)
|
|
})
|
|
}()
|
|
|
|
select {
|
|
case <-cleaned:
|
|
t.Fatal("cleanup ran before active job completed")
|
|
case <-time.After(10 * time.Millisecond):
|
|
}
|
|
|
|
err = thread.Do(context.Background(), func() error {
|
|
return errors.New("work should be rejected once Stop starts")
|
|
})
|
|
if !errors.Is(err, ErrStopped) {
|
|
t.Fatalf("got %v, want %v", err, ErrStopped)
|
|
}
|
|
|
|
close(release)
|
|
if err := <-jobErr; err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := <-stopErr; err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
select {
|
|
case <-cleaned:
|
|
default:
|
|
t.Fatal("cleanup did not run")
|
|
}
|
|
}
|