Mirror of https://github.com/ollama/ollama.git, synced 2026-03-09 07:16:38 -05:00
bugfix: better mlx model scheduling (#14290)
This fixes a bug in which current MLX-based models don't get loaded and unloaded correctly: the first model gets loaded, and subsequent model starts are shunted to the first runner, which results in the wrong model being run.
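To make the failure mode concrete, here is a minimal standalone Go sketch (illustrative only; the type and variable names are hypothetical stand-ins, not the scheduler's real types). When loaded runners are keyed by ModelPath and MLX/safetensors models have an empty path, every such model collides on the same "" key, so a lookup for the second model finds the first model's runner:

package main

import "fmt"

type model struct{ ModelPath, Name string }

func main() {
	// Stand-in for the scheduler's loaded-runner map.
	loaded := map[string]*model{}
	a := &model{Name: "mlx-model-a"} // ModelPath is empty for MLX models
	b := &model{Name: "mlx-model-b"} // ...and for this one too

	loaded[a.ModelPath] = a // a is stored under the key ""
	// b's lookup key is also "", so the lookup returns a's entry:
	if got := loaded[b.ModelPath]; got != nil {
		fmt.Println("request for", b.Name, "shunted to runner for", got.Name)
	}
}

This prints "request for mlx-model-b shunted to runner for mlx-model-a" — the wrong-model behavior described above. The diff below fixes it by deriving a unique key per model instead of keying on the raw path.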
@@ -2371,30 +2371,6 @@ func TestImageGenerateStreamFalse(t *testing.T) {
 		return nil
 	}
 
-	opts := api.DefaultOptions()
-	s := Server{
-		sched: &Scheduler{
-			pendingReqCh:    make(chan *LlmRequest, 1),
-			finishedReqCh:   make(chan *LlmRequest, 1),
-			expiredCh:       make(chan *runnerRef, 1),
-			unloadedCh:      make(chan any, 1),
-			loaded: map[string]*runnerRef{
-				"": {
-					llama:       &mock,
-					Options:     &opts,
-					model:       &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
-					isImagegen:  true,
-					numParallel: 1,
-				},
-			},
-			newServerFn:     newMockServer(&mock),
-			getGpuFn:        getGpuFn,
-			getSystemInfoFn: getSystemInfoFn,
-		},
-	}
-
-	go s.sched.Run(t.Context())
-
 	// Create model manifest with image capability
 	n := model.ParseName("test-image")
 	cfg := model.ConfigV2{Capabilities: []string{"image"}}
@@ -2410,6 +2386,35 @@ func TestImageGenerateStreamFalse(t *testing.T) {
 		t.Fatal(err)
 	}
 
+	loadedModel, err := GetModel("test-image")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	opts := api.DefaultOptions()
+	s := Server{
+		sched: &Scheduler{
+			pendingReqCh:    make(chan *LlmRequest, 1),
+			finishedReqCh:   make(chan *LlmRequest, 1),
+			expiredCh:       make(chan *runnerRef, 1),
+			unloadedCh:      make(chan any, 1),
+			loaded: map[string]*runnerRef{
+				schedulerModelKey(loadedModel): {
+					llama:       &mock,
+					Options:     &opts,
+					model:       loadedModel,
+					isImagegen:  true,
+					numParallel: 1,
+				},
+			},
+			newServerFn:     newMockServer(&mock),
+			getGpuFn:        getGpuFn,
+			getSystemInfoFn: getSystemInfoFn,
+		},
+	}
+
+	go s.sched.Run(t.Context())
+
 	streamFalse := false
 	w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
 		Model:  "test-image",
@@ -83,6 +83,28 @@ func InitScheduler(ctx context.Context) *Scheduler {
 	return sched
 }
 
+// schedulerModelKey returns the scheduler map key for a model.
+// GGUF-backed models use ModelPath; safetensors/image models without a
+// ModelPath use manifest digest so distinct models don't collide.
+func schedulerModelKey(m *Model) string {
+	if m == nil {
+		return ""
+	}
+	if m.ModelPath != "" {
+		return m.ModelPath
+	}
+	if m.Digest != "" {
+		return "digest:" + m.Digest
+	}
+	if m.Name != "" {
+		return "name:" + m.Name
+	}
+	if m.ShortName != "" {
+		return "short:" + m.ShortName
+	}
+	return ""
+}
+
 // context must be canceled to decrement ref count and release the runner
 func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
 	if opts.NumCtx < 4 {
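As a quick illustration of the key precedence above, here is a hypothetical same-package sketch (not part of the diff; it assumes fmt is imported and the Model fields shown in this commit):

// ExampleSchedulerModelKeyPrecedence shows which field wins for each shape of Model.
func ExampleSchedulerModelKeyPrecedence() {
	fmt.Println(schedulerModelKey(&Model{ModelPath: "/models/llama.gguf", Digest: "sha-x"})) // /models/llama.gguf (path always wins)
	fmt.Println(schedulerModelKey(&Model{Digest: "sha-x", Name: "m"}))                       // digest:sha-x
	fmt.Println(schedulerModelKey(&Model{Name: "m"}))                                        // name:m
	fmt.Println(schedulerModelKey(nil))                                                      // "" (empty)
}

The "digest:"/"name:"/"short:" prefixes keep the fallback keys distinct from one another and from any real filesystem path.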
@@ -104,8 +126,9 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
 		useImagegen:     useImagegen,
 	}
 
+	key := schedulerModelKey(req.model)
 	s.loadedMu.Lock()
-	runner := s.loaded[req.model.ModelPath]
+	runner := s.loaded[key]
 	s.loadedMu.Unlock()
 	if runner != nil && !runner.needsReload(c, req) {
 		req.useLoadedRunner(runner, s.finishedReqCh)
@@ -151,8 +174,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
 
 	for {
 		var runnerToExpire *runnerRef
+		pendingKey := schedulerModelKey(pending.model)
 		s.loadedMu.Lock()
-		runner := s.loaded[pending.model.ModelPath]
+		runner := s.loaded[pendingKey]
 		loadedCount := len(s.loaded)
 		runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
 		for _, r := range s.loaded {
@@ -166,7 +190,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 			runnerToExpire = runner
 		} else {
 			// Runner is usable, return it
-			logutil.Trace("using existing loaded runner", "model", pending.model.ModelPath)
+			logutil.Trace("using existing loaded runner", "model", pendingKey)
 			pending.useLoadedRunner(runner, s.finishedReqCh)
 			break
 		}
@@ -292,11 +316,12 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
 			slog.Debug("shutting down scheduler completed loop")
 			return
 		case finished := <-s.finishedReqCh:
+			finishedKey := schedulerModelKey(finished.model)
 			s.loadedMu.Lock()
-			runner := s.loaded[finished.model.ModelPath]
+			runner := s.loaded[finishedKey]
 			s.loadedMu.Unlock()
 			if runner == nil {
-				slog.Error("finished request signal received after model unloaded", "modelPath", finished.model.ModelPath)
+				slog.Error("finished request signal received after model unloaded", "modelPath", finishedKey)
 				continue
 			}
 			runner.refMu.Lock()
@@ -347,7 +372,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
 
 			s.loadedMu.Lock()
 			slog.Debug("got lock to unload expired event", "runner", runner)
-			runnerToUnload := s.loaded[runner.modelPath]
+			runnerToUnload := s.loaded[runner.modelKey]
 			if runnerToUnload == nil {
 				// If runnerToUnload is nil, we already processed an event and
 				// unloaded it. This double unload can happen if the initial
@@ -376,7 +401,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
 			}
 			finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
 			runner.unload()
-			delete(s.loaded, runner.modelPath)
+			delete(s.loaded, runner.modelKey)
 			s.loadedMu.Unlock()
 			slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
 			<-finished
@@ -514,6 +539,7 @@ iGPUScan:
 	runner := &runnerRef{
 		model:           req.model,
 		modelPath:       req.model.ModelPath,
+		modelKey:        schedulerModelKey(req.model),
 		llama:           llama,
 		Options:         &req.opts,
 		sessionDuration: sessionDuration,
@@ -528,7 +554,7 @@ iGPUScan:
 	runner.refMu.Lock() // hold lock until running or aborted
 
 	s.loadedMu.Lock()
-	if oldRunner, ok := s.loaded[req.model.ModelPath]; ok {
+	if oldRunner, ok := s.loaded[runner.modelKey]; ok {
 		// Shouldn't happen, but safeguard against leaking a runner
 		slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
 		oldRunner.refMu.Lock()
@@ -536,7 +562,7 @@ iGPUScan:
 		oldRunner.refMu.Unlock()
 	}
 	s.activeLoading = nil
-	s.loaded[req.model.ModelPath] = runner
+	s.loaded[runner.modelKey] = runner
 	slog.Info("loaded runners", "count", len(s.loaded))
 	s.loadedMu.Unlock()
 
@@ -596,6 +622,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
 	runner := &runnerRef{
 		model:           req.model,
 		modelPath:       req.model.ModelPath,
+		modelKey:        schedulerModelKey(req.model),
 		llama:           server,
 		Options:         &req.opts,
 		loading:         false,
@@ -606,7 +633,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
 	}
 
 	s.loadedMu.Lock()
-	s.loaded[req.model.ModelPath] = runner
+	s.loaded[runner.modelKey] = runner
 	s.loadedMu.Unlock()
 
 	// Set up expiration timer
@@ -684,6 +711,7 @@ type runnerRef struct {
 
 	model       *Model
 	modelPath   string
+	modelKey    string
 	numParallel int
 	*api.Options
 }
@@ -703,7 +731,7 @@ func (runner *runnerRef) unload() {
 }
 
 func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
-	slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
+	slog.Debug("evaluating already loaded", "model", schedulerModelKey(req.model))
 	runner.refMu.Lock()
 	defer runner.refMu.Unlock()
 
@@ -814,6 +842,10 @@ func (runner *runnerRef) LogValue() slog.Value {
 	if runner == nil {
 		return slog.StringValue("nil")
 	}
+	modelID := runner.modelPath
+	if modelID == "" {
+		modelID = runner.modelKey
+	}
 	attrs := []slog.Attr{}
 	if runner.model != nil {
 		attrs = append(attrs, slog.String("name", runner.model.Name))
@@ -828,7 +860,7 @@ func (runner *runnerRef) LogValue() slog.Value {
 		slog.String("vram", format.HumanBytes2(runner.vramSize)),
 		slog.Int("parallel", runner.numParallel),
 		slog.Int("pid", runner.pid),
-		slog.String("model", runner.modelPath),
+		slog.String("model", modelID),
 	)
 	if runner.Options != nil {
 		attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx))
@@ -873,8 +905,16 @@ func (a ByDurationAndName) Less(i, j int) bool {
 	if d1 != d2 {
 		return d1 < d2
 	}
-	// Secondary sort by model path lex order
-	return a[i].modelPath < a[j].modelPath
+	// Secondary sort by model key/path lex order
+	n1 := a[i].modelPath
+	if n1 == "" {
+		n1 = a[i].modelKey
+	}
+	n2 := a[j].modelPath
+	if n2 == "" {
+		n2 = a[j].modelKey
+	}
+	return n1 < n2
 }
 
 // TODO - future consideration to pick runners based on size
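Both LogValue and the Less fallback above repeat the same "path if present, else key" choice; a small helper could express it once (a refactoring sketch only, not part of this commit; the helper name is hypothetical):

// displayName returns a human-facing identifier for a runner: the model
// path when one exists, otherwise the synthetic scheduler key.
func displayName(r *runnerRef) string {
	if r.modelPath != "" {
		return r.modelPath
	}
	return r.modelKey
}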
@@ -934,8 +974,9 @@ func (s *Scheduler) unloadAllRunners() {
 }
 
 func (s *Scheduler) expireRunner(model *Model) {
+	modelKey := schedulerModelKey(model)
 	s.loadedMu.Lock()
-	runner, ok := s.loaded[model.ModelPath]
+	runner, ok := s.loaded[modelKey]
 	s.loadedMu.Unlock()
 	if ok {
 		runner.refMu.Lock()
@@ -448,6 +448,71 @@ func TestSchedGetRunner(t *testing.T) {
 	b.ctxDone()
 }
 
+func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
+	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
+	defer done()
+
+	s := InitScheduler(ctx)
+	opts := api.DefaultOptions()
+	opts.NumCtx = 4
+
+	loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
+	loadedRunner := &runnerRef{
+		model:       loadedModel,
+		modelKey:    schedulerModelKey(loadedModel),
+		llama:       &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
+		Options:     &opts,
+		numParallel: 1,
+	}
+
+	s.loadedMu.Lock()
+	s.loaded[loadedRunner.modelKey] = loadedRunner
+	s.loadedMu.Unlock()
+
+	reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
+	successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false)
+
+	require.Empty(t, successCh)
+	require.Empty(t, errCh)
+	require.Len(t, s.pendingReqCh, 1)
+}
+
+func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
+	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
+	defer done()
+
+	s := InitScheduler(ctx)
+	opts := api.DefaultOptions()
+	opts.NumCtx = 4
+
+	loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
+	loadedRunner := &runnerRef{
+		model:       loadedModel,
+		modelKey:    schedulerModelKey(loadedModel),
+		llama:       &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
+		Options:     &opts,
+		numParallel: 1,
+	}
+
+	s.loadedMu.Lock()
+	s.loaded[loadedRunner.modelKey] = loadedRunner
+	s.loadedMu.Unlock()
+
+	reqCtx, cancelReq := context.WithCancel(ctx)
+	successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false)
+	cancelReq()
+
+	select {
+	case runner := <-successCh:
+		require.Equal(t, loadedRunner, runner)
+	default:
+		t.Fatal("expected existing runner to be reused")
+	}
+
+	require.Empty(t, errCh)
+	require.Empty(t, s.pendingReqCh)
+}
+
 func TestSchedExpireRunner(t *testing.T) {
 	ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
 	defer done()