From d07e4a1dd39c6334184c05e1c3b8192865e114c2 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Tue, 17 Feb 2026 13:57:05 -0800 Subject: [PATCH] bugfix: better mlx model scheduling (#14290) This fixes a bug where MLX-based models are not loaded/unloaded correctly. Currently the first model is loaded, and every subsequent model start is then shunted to that first runner, which results in the wrong model being run. --- server/routes_generate_test.go | 53 +++++++++++++------------ server/sched.go | 71 +++++++++++++++++++++++++++------- server/sched_test.go | 65 +++++++++++++++++++++++++++++++ 3 files changed, 150 insertions(+), 39 deletions(-) diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index 677fef369..0679b4262 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -2371,30 +2371,6 @@ func TestImageGenerateStreamFalse(t *testing.T) { return nil } - opts := api.DefaultOptions() - s := Server{ - sched: &Scheduler{ - pendingReqCh: make(chan *LlmRequest, 1), - finishedReqCh: make(chan *LlmRequest, 1), - expiredCh: make(chan *runnerRef, 1), - unloadedCh: make(chan any, 1), - loaded: map[string]*runnerRef{ - "": { - llama: &mock, - Options: &opts, - model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}}, - isImagegen: true, - numParallel: 1, - }, - }, - newServerFn: newMockServer(&mock), - getGpuFn: getGpuFn, - getSystemInfoFn: getSystemInfoFn, - }, - } - - go s.sched.Run(t.Context()) - // Create model manifest with image capability n := model.ParseName("test-image") cfg := model.ConfigV2{Capabilities: []string{"image"}} @@ -2410,6 +2386,35 @@ func TestImageGenerateStreamFalse(t *testing.T) { t.Fatal(err) } + loadedModel, err := GetModel("test-image") + if err != nil { + t.Fatal(err) + } + + opts := api.DefaultOptions() + s := Server{ + sched: &Scheduler{ + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), +
unloadedCh: make(chan any, 1), + loaded: map[string]*runnerRef{ + schedulerModelKey(loadedModel): { + llama: &mock, + Options: &opts, + model: loadedModel, + isImagegen: true, + numParallel: 1, + }, + }, + newServerFn: newMockServer(&mock), + getGpuFn: getGpuFn, + getSystemInfoFn: getSystemInfoFn, + }, + } + + go s.sched.Run(t.Context()) + streamFalse := false w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ Model: "test-image", diff --git a/server/sched.go b/server/sched.go index 728ec47b6..0049b87ce 100644 --- a/server/sched.go +++ b/server/sched.go @@ -83,6 +83,28 @@ func InitScheduler(ctx context.Context) *Scheduler { return sched } +// schedulerModelKey returns the scheduler map key for a model. +// GGUF-backed models use ModelPath; safetensors/image models without a +// ModelPath use manifest digest so distinct models don't collide. +func schedulerModelKey(m *Model) string { + if m == nil { + return "" + } + if m.ModelPath != "" { + return m.ModelPath + } + if m.Digest != "" { + return "digest:" + m.Digest + } + if m.Name != "" { + return "name:" + m.Name + } + if m.ShortName != "" { + return "short:" + m.ShortName + } + return "" +} + // context must be canceled to decrement ref count and release the runner func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) { if opts.NumCtx < 4 { @@ -104,8 +126,9 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses useImagegen: useImagegen, } + key := schedulerModelKey(req.model) s.loadedMu.Lock() - runner := s.loaded[req.model.ModelPath] + runner := s.loaded[key] s.loadedMu.Unlock() if runner != nil && !runner.needsReload(c, req) { req.useLoadedRunner(runner, s.finishedReqCh) @@ -151,8 +174,9 @@ func (s *Scheduler) processPending(ctx context.Context) { for { var runnerToExpire *runnerRef + pendingKey := schedulerModelKey(pending.model) s.loadedMu.Lock() - runner := 
s.loaded[pending.model.ModelPath] + runner := s.loaded[pendingKey] loadedCount := len(s.loaded) runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded)) for _, r := range s.loaded { @@ -166,7 +190,7 @@ func (s *Scheduler) processPending(ctx context.Context) { runnerToExpire = runner } else { // Runner is usable, return it - logutil.Trace("using existing loaded runner", "model", pending.model.ModelPath) + logutil.Trace("using existing loaded runner", "model", pendingKey) pending.useLoadedRunner(runner, s.finishedReqCh) break } @@ -292,11 +316,12 @@ func (s *Scheduler) processCompleted(ctx context.Context) { slog.Debug("shutting down scheduler completed loop") return case finished := <-s.finishedReqCh: + finishedKey := schedulerModelKey(finished.model) s.loadedMu.Lock() - runner := s.loaded[finished.model.ModelPath] + runner := s.loaded[finishedKey] s.loadedMu.Unlock() if runner == nil { - slog.Error("finished request signal received after model unloaded", "modelPath", finished.model.ModelPath) + slog.Error("finished request signal received after model unloaded", "modelPath", finishedKey) continue } runner.refMu.Lock() @@ -347,7 +372,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) { s.loadedMu.Lock() slog.Debug("got lock to unload expired event", "runner", runner) - runnerToUnload := s.loaded[runner.modelPath] + runnerToUnload := s.loaded[runner.modelKey] if runnerToUnload == nil { // If runnerToUnload is nil, we already processed an event and // unloaded it. 
This double unload can happen if the initial @@ -376,7 +401,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) { } finished := s.waitForVRAMRecovery(runner, runnersSnapshot) runner.unload() - delete(s.loaded, runner.modelPath) + delete(s.loaded, runner.modelKey) s.loadedMu.Unlock() slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner) <-finished @@ -514,6 +539,7 @@ iGPUScan: runner := &runnerRef{ model: req.model, modelPath: req.model.ModelPath, + modelKey: schedulerModelKey(req.model), llama: llama, Options: &req.opts, sessionDuration: sessionDuration, @@ -528,7 +554,7 @@ iGPUScan: runner.refMu.Lock() // hold lock until running or aborted s.loadedMu.Lock() - if oldRunner, ok := s.loaded[req.model.ModelPath]; ok { + if oldRunner, ok := s.loaded[runner.modelKey]; ok { // Shouldn't happen, but safeguard against leaking a runner slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner) oldRunner.refMu.Lock() @@ -536,7 +562,7 @@ iGPUScan: oldRunner.refMu.Unlock() } s.activeLoading = nil - s.loaded[req.model.ModelPath] = runner + s.loaded[runner.modelKey] = runner slog.Info("loaded runners", "count", len(s.loaded)) s.loadedMu.Unlock() @@ -596,6 +622,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool { runner := &runnerRef{ model: req.model, modelPath: req.model.ModelPath, + modelKey: schedulerModelKey(req.model), llama: server, Options: &req.opts, loading: false, @@ -606,7 +633,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool { } s.loadedMu.Lock() - s.loaded[req.model.ModelPath] = runner + s.loaded[runner.modelKey] = runner s.loadedMu.Unlock() // Set up expiration timer @@ -684,6 +711,7 @@ type runnerRef struct { model *Model modelPath string + modelKey string numParallel int *api.Options } @@ -703,7 +731,7 @@ func (runner *runnerRef) unload() { } func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool { - slog.Debug("evaluating already loaded", "model", 
req.model.ModelPath) + slog.Debug("evaluating already loaded", "model", schedulerModelKey(req.model)) runner.refMu.Lock() defer runner.refMu.Unlock() @@ -814,6 +842,10 @@ func (runner *runnerRef) LogValue() slog.Value { if runner == nil { return slog.StringValue("nil") } + modelID := runner.modelPath + if modelID == "" { + modelID = runner.modelKey + } attrs := []slog.Attr{} if runner.model != nil { attrs = append(attrs, slog.String("name", runner.model.Name)) @@ -828,7 +860,7 @@ func (runner *runnerRef) LogValue() slog.Value { slog.String("vram", format.HumanBytes2(runner.vramSize)), slog.Int("parallel", runner.numParallel), slog.Int("pid", runner.pid), - slog.String("model", runner.modelPath), + slog.String("model", modelID), ) if runner.Options != nil { attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx)) @@ -873,8 +905,16 @@ func (a ByDurationAndName) Less(i, j int) bool { if d1 != d2 { return d1 < d2 } - // Secondary sort by model path lex order - return a[i].modelPath < a[j].modelPath + // Secondary sort by model key/path lex order + n1 := a[i].modelPath + if n1 == "" { + n1 = a[i].modelKey + } + n2 := a[j].modelPath + if n2 == "" { + n2 = a[j].modelKey + } + return n1 < n2 } // TODO - future consideration to pick runners based on size @@ -934,8 +974,9 @@ func (s *Scheduler) unloadAllRunners() { } func (s *Scheduler) expireRunner(model *Model) { + modelKey := schedulerModelKey(model) s.loadedMu.Lock() - runner, ok := s.loaded[model.ModelPath] + runner, ok := s.loaded[modelKey] s.loadedMu.Unlock() if ok { runner.refMu.Lock() diff --git a/server/sched_test.go b/server/sched_test.go index 732e5b3cc..4b1ed54f2 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -448,6 +448,71 @@ func TestSchedGetRunner(t *testing.T) { b.ctxDone() } +func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) { + ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond) + defer done() + + s := InitScheduler(ctx) + opts := 
api.DefaultOptions() + opts.NumCtx = 4 + + loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"} + loadedRunner := &runnerRef{ + model: loadedModel, + modelKey: schedulerModelKey(loadedModel), + llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}, + Options: &opts, + numParallel: 1, + } + + s.loadedMu.Lock() + s.loaded[loadedRunner.modelKey] = loadedRunner + s.loadedMu.Unlock() + + reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"} + successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false) + + require.Empty(t, successCh) + require.Empty(t, errCh) + require.Len(t, s.pendingReqCh, 1) +} + +func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) { + ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond) + defer done() + + s := InitScheduler(ctx) + opts := api.DefaultOptions() + opts.NumCtx = 4 + + loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"} + loadedRunner := &runnerRef{ + model: loadedModel, + modelKey: schedulerModelKey(loadedModel), + llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}, + Options: &opts, + numParallel: 1, + } + + s.loadedMu.Lock() + s.loaded[loadedRunner.modelKey] = loadedRunner + s.loadedMu.Unlock() + + reqCtx, cancelReq := context.WithCancel(ctx) + successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false) + cancelReq() + + select { + case runner := <-successCh: + require.Equal(t, loadedRunner, runner) + default: + t.Fatal("expected existing runner to be reused") + } + + require.Empty(t, errCh) + require.Empty(t, s.pendingReqCh) +} + func TestSchedExpireRunner(t *testing.T) { ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond) defer done()