Files
ollama/x/internal/mlxthread/thread.go
Daniel Hiltgen 534342e7e2 Update MLX and MLX-C with threading fixes (#15845)
* Update MLX and MLX-C

* Run MLX CGO work on a locked OS thread

MLX now relies on OS-thread-local execution state for streams, encoders, and caches. Add an mlxthread executor backed by runtime.LockOSThread and route runner initialization, model load, inference, status memory reads, and cleanup through the worker so Go goroutine migration cannot split MLX state across native threads.

Also stop caching default MLX streams before the runner owns the thread and add worker/threaded MLX regression tests.

* mlx: use common status writer

* mlx: bundle missing libjaccl on arm64

Inspired by #15793

* review comments
2026-05-03 10:03:14 -07:00

184 lines
3.3 KiB
Go

package mlxthread
import (
"context"
"errors"
"runtime"
"sync/atomic"
)
// ErrStopped is returned when work is submitted to a Thread that is already
// stopping, or whose worker goroutine has exited and closed its done channel.
var ErrStopped = errors.New("mlx thread stopped")
// Thread is a handle to a worker goroutine that Start locks to a single OS
// thread. Do, Call, and Stop hand work to that goroutine through jobs.
type Thread struct {
// name is the label passed to Start. No code visible in this file reads it.
name string
// jobs carries work to the worker. Start creates it unbuffered, so a send
// completes only once the worker has received the job.
jobs chan job
// done is closed by the worker goroutine right before it returns, either
// because init failed or because a stop job finished.
done chan struct{}
// stopping is set by Stop. While it is true, enqueue rejects every job that
// is not the stop job with ErrStopped.
stopping atomic.Bool
}
// job is one unit of work handed to the worker goroutine.
type job struct {
// fn is the function to execute on the worker; run treats a nil fn as a no-op.
fn func() error
// result receives the single outcome of fn. enqueue allocates it with
// capacity 1, so the worker's send does not block.
result chan result
// stop tells the worker to close done and exit after this job has run.
stop bool
}
// result is the outcome of running one function on the worker goroutine.
type result struct {
// err is the error returned by the function, if it returned normally.
err error
// panicValue is the value recovered from a panic raised by the function; it
// stays nil when the function did not panic.
panicValue any
}
// Start launches a long-lived worker goroutine that locks itself to one OS
// thread, and blocks until init has finished running on that goroutine.
//
// If init returns an error, Start returns a nil Thread together with that
// error. If init panics, the recovered value is re-panicked on the caller's
// goroutine. In both cases the worker goroutine has already given up.
func Start(name string, init func() error) (*Thread, error) {
	worker := &Thread{
		name: name,
		jobs: make(chan job),
		done: make(chan struct{}),
	}

	// Capacity 1 lets the worker report the init outcome without waiting for
	// this goroutine to be ready to receive it.
	ready := make(chan result, 1)
	go worker.loop(init, ready)

	outcome := <-ready
	switch {
	case outcome.panicValue != nil:
		panic(outcome.panicValue)
	case outcome.err != nil:
		return nil, outcome.err
	default:
		return worker, nil
	}
}
// Do executes fn on the locked OS thread and returns fn's error.
//
// Cancelling ctx only has an effect while the job is still waiting to be
// picked up. Once the worker has accepted the job, fn runs until it returns or
// hits its own cancellation checks.
//
// Do returns the context error if ctx ends before the worker accepts the job,
// and ErrStopped if the thread is stopping or already stopped. A panic raised
// by fn is re-panicked on the caller's goroutine.
func (t *Thread) Do(ctx context.Context, fn func() error) error {
	outcome, err := t.enqueue(ctx, fn, false, false)
	switch {
	case err != nil:
		return err
	case outcome.panicValue != nil:
		panic(outcome.panicValue)
	}
	return outcome.err
}
// Call runs fn on t's locked OS thread via Do and hands back the value fn
// produced together with the error reported by Do.
//
// When the job never runs (for example because ctx ended or the thread is
// stopped) the returned value is the zero value of T.
func Call[T any](ctx context.Context, t *Thread, fn func() (T, error)) (T, error) {
	var out T
	// capture stores fn's value in out so it survives the trip through Do,
	// which only transports an error.
	capture := func() (err error) {
		out, err = fn()
		return err
	}
	err := t.Do(ctx, capture)
	return out, err
}
// Stop runs cleanup (when non-nil) on the locked OS thread as the worker's
// final job and then waits for the worker goroutine to exit.
//
// Only the first caller submits the stop job; a call made while another Stop
// is in progress, or after one has succeeded, just waits for the worker to
// exit. A nil ctx is treated as context.Background. Stop returns the context
// error if ctx ends first, and re-panics on the caller's goroutine if cleanup
// panicked.
func (t *Thread) Stop(ctx context.Context, cleanup func()) error {
	ctx = contextOrBackground(ctx)

	// waitForExit blocks until the worker has closed done or ctx has ended.
	waitForExit := func() error {
		select {
		case <-t.done:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	if alreadyStopping := !t.stopping.CompareAndSwap(false, true); alreadyStopping {
		return waitForExit()
	}

	shutdown := func() error {
		if cleanup != nil {
			cleanup()
		}
		return nil
	}
	outcome, err := t.enqueue(ctx, shutdown, true, true)
	if err != nil {
		// The stop job was never accepted. Unless the worker is already gone,
		// clear the flag again so later Do and Stop calls are not rejected.
		if !errors.Is(err, ErrStopped) {
			t.stopping.Store(false)
		}
		return err
	}
	if outcome.panicValue != nil {
		panic(outcome.panicValue)
	}
	if outcome.err != nil {
		return outcome.err
	}
	return waitForExit()
}
// loop is the body of the worker goroutine. It locks the goroutine to its OS
// thread, runs init and reports the outcome on initResult, then serves jobs
// one at a time until a job marked stop has run. done is closed right before
// the goroutine returns.
func (t *Thread) loop(init func() error, initResult chan<- result) {
	// There is intentionally no matching UnlockOSThread: MLX thread-local
	// state stays with this worker until shutdown, so the thread is never
	// handed back to arbitrary Go goroutines.
	runtime.LockOSThread()

	started := run(init)
	initResult <- started
	if failed := started.err != nil || started.panicValue != nil; failed {
		close(t.done)
		return
	}

	for {
		next := <-t.jobs
		// Deliver the outcome before closing done, so the submitter of a stop
		// job can read its result first.
		next.result <- run(next.fn)
		if !next.stop {
			continue
		}
		close(t.done)
		return
	}
}
// enqueue hands fn to the worker and waits for its outcome.
//
// stop marks the job as the worker's last one. allowStopping lets the job
// through even though the stopping flag is set; Stop uses it for the stop job
// itself. A nil ctx is treated as context.Background.
//
// The returned error is non-nil only when the job was never accepted: the
// context error if ctx has ended, or ErrStopped if the thread is stopping or
// the worker has exited. After acceptance enqueue waits for the result
// without watching ctx.
func (t *Thread) enqueue(ctx context.Context, fn func() error, stop, allowStopping bool) (result, error) {
	ctx = contextOrBackground(ctx)
	if cause := ctx.Err(); cause != nil {
		return result{}, cause
	}
	if rejected := !allowStopping && t.stopping.Load(); rejected {
		return result{}, ErrStopped
	}

	// Capacity 1 means the worker can always deposit the outcome.
	reply := make(chan result, 1)
	select {
	case t.jobs <- job{fn: fn, result: reply, stop: stop}:
		return <-reply, nil
	case <-t.done:
		return result{}, ErrStopped
	case <-ctx.Done():
		return result{}, ctx.Err()
	}
}
// run invokes fn and packages what happened into a result: fn's returned
// error, or the value of a panic raised by fn. A nil fn yields an empty result.
func run(fn func() error) (out result) {
	defer func() {
		// recover yields nil when nothing panicked, which leaves panicValue
		// at its zero value.
		out.panicValue = recover()
	}()
	if fn == nil {
		return result{}
	}
	return result{err: fn()}
}
// contextOrBackground substitutes context.Background for a nil ctx and returns
// any other ctx unchanged, so callers can use the result without nil checks.
func contextOrBackground(ctx context.Context) context.Context {
	if ctx == nil {
		return context.Background()
	}
	return ctx
}