mirror of
https://github.com/ollama/ollama.git
synced 2026-05-05 23:53:43 -05:00
* Update MLX and MLX-C * Run MLX CGO work on a locked OS thread MLX now relies on OS-thread-local execution state for streams, encoders, and caches. Add an mlxthread executor backed by runtime.LockOSThread and route runner initialization, model load, inference, status memory reads, and cleanup through the worker so Go goroutine migration cannot split MLX state across native threads. Also stop caching default MLX streams before the runner owns the thread and add worker/threaded MLX regression tests. * mlx: use common status writer * mlx: bundle missing libjaccl on arm64 Inspired by #15793 * review comments
184 lines
3.3 KiB
Go
184 lines
3.3 KiB
Go
package mlxthread
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"runtime"
|
|
"sync/atomic"
|
|
)
|
|
|
|
// ErrStopped is returned when work is submitted to a Thread that is stopping
// or whose worker goroutine has already exited.
var ErrStopped = errors.New("mlx thread stopped")
|
|
|
|
type Thread struct {
|
|
name string
|
|
|
|
jobs chan job
|
|
done chan struct{}
|
|
stopping atomic.Bool
|
|
}
|
|
|
|
type job struct {
|
|
fn func() error
|
|
result chan result
|
|
stop bool
|
|
}
|
|
|
|
// result describes how a function run on the worker ended.
type result struct {
	// err is the error the function returned, if any.
	err error
	// panicValue is the recovered value when the function panicked, nil
	// otherwise.
	panicValue any
}
|
|
|
|
// Start creates a long-lived worker goroutine locked to one OS thread.
|
|
func Start(name string, init func() error) (*Thread, error) {
|
|
t := &Thread{
|
|
name: name,
|
|
jobs: make(chan job),
|
|
done: make(chan struct{}),
|
|
}
|
|
|
|
initResult := make(chan result, 1)
|
|
go t.loop(init, initResult)
|
|
|
|
res := <-initResult
|
|
if res.panicValue != nil {
|
|
panic(res.panicValue)
|
|
}
|
|
if res.err != nil {
|
|
return nil, res.err
|
|
}
|
|
|
|
return t, nil
|
|
}
|
|
|
|
// Do runs fn on the locked OS thread.
|
|
//
|
|
// Context cancellation only applies while the work is queued. Once the worker
|
|
// accepts a job, the job runs until fn returns or reaches its own cancellation
|
|
// checks.
|
|
func (t *Thread) Do(ctx context.Context, fn func() error) error {
|
|
res, err := t.enqueue(ctx, fn, false, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if res.panicValue != nil {
|
|
panic(res.panicValue)
|
|
}
|
|
return res.err
|
|
}
|
|
|
|
func Call[T any](ctx context.Context, t *Thread, fn func() (T, error)) (T, error) {
|
|
var value T
|
|
err := t.Do(ctx, func() error {
|
|
var err error
|
|
value, err = fn()
|
|
return err
|
|
})
|
|
return value, err
|
|
}
|
|
|
|
// Stop runs cleanup on the locked OS thread and then shuts the worker down.
|
|
func (t *Thread) Stop(ctx context.Context, cleanup func()) error {
|
|
ctx = contextOrBackground(ctx)
|
|
|
|
if !t.stopping.CompareAndSwap(false, true) {
|
|
select {
|
|
case <-t.done:
|
|
return nil
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
|
|
res, err := t.enqueue(ctx, func() error {
|
|
if cleanup != nil {
|
|
cleanup()
|
|
}
|
|
return nil
|
|
}, true, true)
|
|
if err != nil {
|
|
if !errors.Is(err, ErrStopped) {
|
|
t.stopping.Store(false)
|
|
}
|
|
return err
|
|
}
|
|
if res.panicValue != nil {
|
|
panic(res.panicValue)
|
|
}
|
|
if res.err != nil {
|
|
return res.err
|
|
}
|
|
|
|
select {
|
|
case <-t.done:
|
|
return nil
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
|
|
func (t *Thread) loop(init func() error, initResult chan<- result) {
|
|
runtime.LockOSThread()
|
|
// Deliberately do not unlock. MLX thread-local state belongs to this worker
|
|
// until shutdown so it cannot leak back to arbitrary Go goroutines.
|
|
|
|
res := run(init)
|
|
initResult <- res
|
|
if res.err != nil || res.panicValue != nil {
|
|
close(t.done)
|
|
return
|
|
}
|
|
|
|
for {
|
|
j := <-t.jobs
|
|
res := run(j.fn)
|
|
j.result <- res
|
|
if j.stop {
|
|
close(t.done)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (t *Thread) enqueue(ctx context.Context, fn func() error, stop, allowStopping bool) (result, error) {
|
|
ctx = contextOrBackground(ctx)
|
|
if err := ctx.Err(); err != nil {
|
|
return result{}, err
|
|
}
|
|
|
|
if !allowStopping && t.stopping.Load() {
|
|
return result{}, ErrStopped
|
|
}
|
|
|
|
resultCh := make(chan result, 1)
|
|
j := job{fn: fn, result: resultCh, stop: stop}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return result{}, ctx.Err()
|
|
case <-t.done:
|
|
return result{}, ErrStopped
|
|
case t.jobs <- j:
|
|
}
|
|
|
|
return <-resultCh, nil
|
|
}
|
|
|
|
func run(fn func() error) (res result) {
|
|
defer func() {
|
|
if v := recover(); v != nil {
|
|
res.panicValue = v
|
|
}
|
|
}()
|
|
if fn != nil {
|
|
res.err = fn()
|
|
}
|
|
return res
|
|
}
|
|
|
|
// contextOrBackground returns ctx unchanged, substituting context.Background
// when ctx is nil.
func contextOrBackground(ctx context.Context) context.Context {
	if ctx == nil {
		return context.Background()
	}
	return ctx
}
|