bugfix: better mlx model scheduling (#14290)

This fixes a bug where MLX-based models are not loaded/unloaded correctly. Currently the first model is loaded, and every subsequent model start is then routed to that first runner, which results in the wrong model being run.
This commit is contained in:
Patrick Devine
2026-02-17 13:57:05 -08:00
committed by GitHub
parent 8a257ec00a
commit d07e4a1dd3
3 changed files with 150 additions and 39 deletions

View File

@@ -2371,30 +2371,6 @@ func TestImageGenerateStreamFalse(t *testing.T) {
return nil
}
opts := api.DefaultOptions()
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: map[string]*runnerRef{
"": {
llama: &mock,
Options: &opts,
model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
isImagegen: true,
numParallel: 1,
},
},
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
},
}
go s.sched.Run(t.Context())
// Create model manifest with image capability
n := model.ParseName("test-image")
cfg := model.ConfigV2{Capabilities: []string{"image"}}
@@ -2410,6 +2386,35 @@ func TestImageGenerateStreamFalse(t *testing.T) {
t.Fatal(err)
}
loadedModel, err := GetModel("test-image")
if err != nil {
t.Fatal(err)
}
opts := api.DefaultOptions()
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: map[string]*runnerRef{
schedulerModelKey(loadedModel): {
llama: &mock,
Options: &opts,
model: loadedModel,
isImagegen: true,
numParallel: 1,
},
},
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
},
}
go s.sched.Run(t.Context())
streamFalse := false
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-image",

View File

@@ -83,6 +83,28 @@ func InitScheduler(ctx context.Context) *Scheduler {
return sched
}
// schedulerModelKey derives the key under which a model is tracked in the
// scheduler's map of loaded runners.
//
// The first non-empty identifier wins, checked in this order:
//   - ModelPath, returned unchanged (GGUF-backed models)
//   - Digest, prefixed with "digest:" (safetensors/image models that have no
//     ModelPath, so that distinct models don't collide)
//   - Name, prefixed with "name:"
//   - ShortName, prefixed with "short:"
//
// A nil model, or one with none of these fields set, yields the empty string.
func schedulerModelKey(m *Model) string {
	switch {
	case m == nil:
		return ""
	case m.ModelPath != "":
		return m.ModelPath
	case m.Digest != "":
		return "digest:" + m.Digest
	case m.Name != "":
		return "name:" + m.Name
	case m.ShortName != "":
		return "short:" + m.ShortName
	default:
		return ""
	}
}
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 {
@@ -104,8 +126,9 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
useImagegen: useImagegen,
}
key := schedulerModelKey(req.model)
s.loadedMu.Lock()
runner := s.loaded[req.model.ModelPath]
runner := s.loaded[key]
s.loadedMu.Unlock()
if runner != nil && !runner.needsReload(c, req) {
req.useLoadedRunner(runner, s.finishedReqCh)
@@ -151,8 +174,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
for {
var runnerToExpire *runnerRef
pendingKey := schedulerModelKey(pending.model)
s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath]
runner := s.loaded[pendingKey]
loadedCount := len(s.loaded)
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
for _, r := range s.loaded {
@@ -166,7 +190,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
runnerToExpire = runner
} else {
// Runner is usable, return it
logutil.Trace("using existing loaded runner", "model", pending.model.ModelPath)
logutil.Trace("using existing loaded runner", "model", pendingKey)
pending.useLoadedRunner(runner, s.finishedReqCh)
break
}
@@ -292,11 +316,12 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
slog.Debug("shutting down scheduler completed loop")
return
case finished := <-s.finishedReqCh:
finishedKey := schedulerModelKey(finished.model)
s.loadedMu.Lock()
runner := s.loaded[finished.model.ModelPath]
runner := s.loaded[finishedKey]
s.loadedMu.Unlock()
if runner == nil {
slog.Error("finished request signal received after model unloaded", "modelPath", finished.model.ModelPath)
slog.Error("finished request signal received after model unloaded", "modelPath", finishedKey)
continue
}
runner.refMu.Lock()
@@ -347,7 +372,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
s.loadedMu.Lock()
slog.Debug("got lock to unload expired event", "runner", runner)
runnerToUnload := s.loaded[runner.modelPath]
runnerToUnload := s.loaded[runner.modelKey]
if runnerToUnload == nil {
// If runnerToUnload is nil, we already processed an event and
// unloaded it. This double unload can happen if the initial
@@ -376,7 +401,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
}
finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
runner.unload()
delete(s.loaded, runner.modelPath)
delete(s.loaded, runner.modelKey)
s.loadedMu.Unlock()
slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
<-finished
@@ -514,6 +539,7 @@ iGPUScan:
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
modelKey: schedulerModelKey(req.model),
llama: llama,
Options: &req.opts,
sessionDuration: sessionDuration,
@@ -528,7 +554,7 @@ iGPUScan:
runner.refMu.Lock() // hold lock until running or aborted
s.loadedMu.Lock()
if oldRunner, ok := s.loaded[req.model.ModelPath]; ok {
if oldRunner, ok := s.loaded[runner.modelKey]; ok {
// Shouldn't happen, but safeguard against leaking a runner
slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
oldRunner.refMu.Lock()
@@ -536,7 +562,7 @@ iGPUScan:
oldRunner.refMu.Unlock()
}
s.activeLoading = nil
s.loaded[req.model.ModelPath] = runner
s.loaded[runner.modelKey] = runner
slog.Info("loaded runners", "count", len(s.loaded))
s.loadedMu.Unlock()
@@ -596,6 +622,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
modelKey: schedulerModelKey(req.model),
llama: server,
Options: &req.opts,
loading: false,
@@ -606,7 +633,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
}
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
s.loaded[runner.modelKey] = runner
s.loadedMu.Unlock()
// Set up expiration timer
@@ -684,6 +711,7 @@ type runnerRef struct {
model *Model
modelPath string
modelKey string
numParallel int
*api.Options
}
@@ -703,7 +731,7 @@ func (runner *runnerRef) unload() {
}
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
slog.Debug("evaluating already loaded", "model", schedulerModelKey(req.model))
runner.refMu.Lock()
defer runner.refMu.Unlock()
@@ -814,6 +842,10 @@ func (runner *runnerRef) LogValue() slog.Value {
if runner == nil {
return slog.StringValue("nil")
}
modelID := runner.modelPath
if modelID == "" {
modelID = runner.modelKey
}
attrs := []slog.Attr{}
if runner.model != nil {
attrs = append(attrs, slog.String("name", runner.model.Name))
@@ -828,7 +860,7 @@ func (runner *runnerRef) LogValue() slog.Value {
slog.String("vram", format.HumanBytes2(runner.vramSize)),
slog.Int("parallel", runner.numParallel),
slog.Int("pid", runner.pid),
slog.String("model", runner.modelPath),
slog.String("model", modelID),
)
if runner.Options != nil {
attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx))
@@ -873,8 +905,16 @@ func (a ByDurationAndName) Less(i, j int) bool {
if d1 != d2 {
return d1 < d2
}
// Secondary sort by model path lex order
return a[i].modelPath < a[j].modelPath
// Secondary sort by model key/path lex order
n1 := a[i].modelPath
if n1 == "" {
n1 = a[i].modelKey
}
n2 := a[j].modelPath
if n2 == "" {
n2 = a[j].modelKey
}
return n1 < n2
}
// TODO - future consideration to pick runners based on size
@@ -934,8 +974,9 @@ func (s *Scheduler) unloadAllRunners() {
}
func (s *Scheduler) expireRunner(model *Model) {
modelKey := schedulerModelKey(model)
s.loadedMu.Lock()
runner, ok := s.loaded[model.ModelPath]
runner, ok := s.loaded[modelKey]
s.loadedMu.Unlock()
if ok {
runner.refMu.Lock()

View File

@@ -448,6 +448,71 @@ func TestSchedGetRunner(t *testing.T) {
b.ctxDone()
}
// TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty checks that a request for
// a model with no ModelPath is not served by an already-loaded runner whose
// model carries a different digest; the request lands on the pending queue
// instead.
func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
	ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond)
	defer cancel()

	sched := InitScheduler(ctx)
	options := api.DefaultOptions()
	options.NumCtx = 4

	// Register a runner for model A under the key GetRunner will compute.
	modelA := &Model{Name: "safetensors-a", Digest: "sha-a"}
	existing := &runnerRef{
		model:       modelA,
		modelKey:    schedulerModelKey(modelA),
		llama:       &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
		Options:     &options,
		numParallel: 1,
	}
	sched.loadedMu.Lock()
	sched.loaded[existing.modelKey] = existing
	sched.loadedMu.Unlock()

	// Model B has neither a ModelPath nor the same digest as model A.
	modelB := &Model{Name: "safetensors-b", Digest: "sha-b"}
	successCh, errCh := sched.GetRunner(ctx, modelB, options, nil, false)

	// Nothing is delivered immediately; the request is queued for loading.
	require.Empty(t, successCh)
	require.Empty(t, errCh)
	require.Len(t, sched.pendingReqCh, 1)
}
// TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty checks that a request
// for a model with no ModelPath is handed the already-loaded runner when the
// digests match, even though the model names differ.
func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
	ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond)
	defer cancel()

	sched := InitScheduler(ctx)
	options := api.DefaultOptions()
	options.NumCtx = 4

	// Register a runner for the model under the key GetRunner will compute.
	modelA := &Model{Name: "safetensors-a", Digest: "sha-a"}
	existing := &runnerRef{
		model:       modelA,
		modelKey:    schedulerModelKey(modelA),
		llama:       &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
		Options:     &options,
		numParallel: 1,
	}
	sched.loadedMu.Lock()
	sched.loaded[existing.modelKey] = existing
	sched.loadedMu.Unlock()

	// Same digest, different name: this should resolve to the same key.
	sameDigest := &Model{Name: "safetensors-a-copy", Digest: "sha-a"}
	reqCtx, cancelReq := context.WithCancel(ctx)
	successCh, errCh := sched.GetRunner(reqCtx, sameDigest, options, nil, false)
	// Cancel straight away; the runner should already be waiting on successCh.
	cancelReq()

	// Non-blocking receive: the runner must be available without any load step.
	select {
	case got := <-successCh:
		require.Equal(t, existing, got)
	default:
		t.Fatal("expected existing runner to be reused")
	}
	require.Empty(t, errCh)
	require.Empty(t, sched.pendingReqCh)
}
func TestSchedExpireRunner(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
defer done()