From cbe9e43c9db0f6ab0bc80e63c1ad940a0ad418e2 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Mon, 2 Mar 2026 15:27:34 -0800 Subject: [PATCH] sched: Model eviction for MLX MLX runners (image generation and LLM) previously bypassed the scheduler's standard load path via a separate loadMLX method. This meant they skipped VRAM fitting checks and couldn't participate in model eviction. Now all model types flow through the same load function. Model eviction for MLX is based on weights as KV cache and compute graph are dynamic. This means that eviction does not take into account the worst case memory and models can still compete for memory but it is a significant improvement. --- server/routes_debug_test.go | 4 +- server/routes_generate_renderer_test.go | 4 +- server/routes_generate_test.go | 14 +- server/routes_harmony_streaming_test.go | 6 +- server/sched.go | 131 ++++----------- server/sched_test.go | 75 ++++++--- x/imagegen/server.go | 106 ++++++------ x/mlxrunner/client.go | 213 +++++++++++++----------- 8 files changed, 276 insertions(+), 277 deletions(-) diff --git a/server/routes_debug_test.go b/server/routes_debug_test.go index a9d14b8cd..81ed4381c 100644 --- a/server/routes_debug_test.go +++ b/server/routes_debug_test.go @@ -40,7 +40,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) { getGpuFn: getGpuFn, getSystemInfoFn: getSystemInfoFn, waitForRecovery: 250 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { // add small delay to simulate loading time.Sleep(time.Millisecond) req.successCh <- &runnerRef{ @@ -234,7 +234,7 @@ func TestChatDebugRenderOnly(t *testing.T) { getGpuFn: getGpuFn, getSystemInfoFn: getSystemInfoFn, waitForRecovery: 250 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { // add small delay to simulate loading time.Sleep(time.Millisecond) req.successCh <- &runnerRef{ diff --git a/server/routes_generate_renderer_test.go b/server/routes_generate_renderer_test.go index d1e6bb56d..871486e5c 100644 --- a/server/routes_generate_renderer_test.go +++ b/server/routes_generate_renderer_test.go @@ -45,7 +45,7 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) { getGpuFn: getGpuFn, getSystemInfoFn: getSystemInfoFn, waitForRecovery: 250 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { time.Sleep(time.Millisecond) req.successCh <- &runnerRef{ llama: &mock, @@ -230,7 +230,7 @@ func TestGenerateWithDebugRenderOnly(t *testing.T) { getGpuFn: getGpuFn, getSystemInfoFn: getSystemInfoFn, waitForRecovery: 250 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { time.Sleep(time.Millisecond) req.successCh <- &runnerRef{ llama: &mock, diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index 0679b4262..458be7ddd 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -187,7 +187,7 @@ func TestGenerateChat(t *testing.T) { getGpuFn: getGpuFn, getSystemInfoFn: getSystemInfoFn, waitForRecovery: 250 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { // add small delay to simulate loading time.Sleep(time.Millisecond) req.successCh <- &runnerRef{ @@ -904,7 +904,7 @@ func TestGenerate(t *testing.T) { getGpuFn: getGpuFn, getSystemInfoFn: getSystemInfoFn, waitForRecovery: 250 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { // add small delay to simulate loading time.Sleep(time.Millisecond) req.successCh <- &runnerRef{ @@ -1388,7 +1388,7 @@ func TestGenerateLogprobs(t *testing.T) { getGpuFn: getGpuFn, getSystemInfoFn: getSystemInfoFn, waitForRecovery: 250 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { req.successCh <- &runnerRef{llama: mock} return false }, @@ -1568,7 +1568,7 @@ func TestChatLogprobs(t *testing.T) { getGpuFn: getGpuFn, getSystemInfoFn: getSystemInfoFn, waitForRecovery: 250 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { req.successCh <- &runnerRef{llama: mock} return false }, @@ -1678,7 +1678,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) { getGpuFn: getGpuFn, getSystemInfoFn: getSystemInfoFn, waitForRecovery: 250 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { time.Sleep(time.Millisecond) req.successCh <- &runnerRef{llama: mock} return false @@ -2123,7 +2123,7 @@ func TestGenerateUnload(t *testing.T) { newServerFn: newMockServer(&mockRunner{}), getGpuFn: getGpuFn, getSystemInfoFn: getSystemInfoFn, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { loadFnCalled = true req.successCh <- &runnerRef{llama: &mockRunner{}} return false @@ -2225,7 +2225,7 @@ func TestGenerateWithImages(t *testing.T) { getGpuFn: getGpuFn, getSystemInfoFn: getSystemInfoFn, waitForRecovery: 250 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { time.Sleep(time.Millisecond) req.successCh <- &runnerRef{ llama: &mock, diff --git a/server/routes_harmony_streaming_test.go b/server/routes_harmony_streaming_test.go index de130c8c8..0a5145d96 100644 --- a/server/routes_harmony_streaming_test.go +++ b/server/routes_harmony_streaming_test.go @@ -265,7 +265,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { getGpuFn: getGpuFn, getSystemInfoFn: getSystemInfoFn, waitForRecovery: 100 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { req.successCh <- &runnerRef{ llama: &mock, } @@ -416,7 +416,7 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) { getGpuFn: getGpuFn, getSystemInfoFn: getSystemInfoFn, waitForRecovery: 100 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { req.successCh <- &runnerRef{ llama: &mock, } @@ -598,7 +598,7 @@ func TestChatHarmonyParserStreaming(t *testing.T) { getGpuFn: getGpuFn, getSystemInfoFn: getSystemInfoFn, waitForRecovery: 250 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { req.successCh <- &runnerRef{ llama: &mock, } diff --git a/server/sched.go b/server/sched.go index 3d0dac863..f040e34f3 100644 --- a/server/sched.go +++ b/server/sched.go @@ -51,7 +51,7 @@ type Scheduler struct { activeLoading llm.LlamaServer loaded map[string]*runnerRef - loadFn func(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool + loadFn func(req *LlmRequest, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool newServerFn func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) getGpuFn func(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.DeviceInfo getSystemInfoFn func() ml.SystemInfo @@ -220,33 +220,6 @@ func (s *Scheduler) processPending(ctx context.Context) { slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus)) } - // Check for image generation models - all use MLX runner - if slices.Contains(pending.model.Config.Capabilities, "image") { - if s.loadMLX(pending) { - break - } - continue - } - - // Check for experimental safetensors LLM models - if pending.model.IsMLX() { - if slices.Contains(pending.model.Config.Capabilities, "completion") { - // LLM model with safetensors format - use MLX runner - if s.loadMLX(pending) { - break - } - continue - } - } - - // Load model for fitting - logutil.Trace("loading model metadata", "model", pending.model.ModelPath) - ggml, err := llm.LoadModel(pending.model.ModelPath, 1024) - if err != nil { - pending.errCh <- err - break - } - // Update free memory from currently loaded models logutil.Trace("updating free space", "gpu_count", len(gpus), "model", pending.model.ModelPath) s.updateFreeSpace(gpus) @@ -254,14 +227,14 @@ func (s *Scheduler) processPending(ctx context.Context) { if loadedCount == 0 { // No models loaded. Load the model but prefer the best fit. slog.Debug("loading first model", "model", pending.model.ModelPath) - s.loadFn(pending, ggml, systemInfo, gpus, false) + s.loadFn(pending, systemInfo, gpus, false) break } // More than one loaded model, so we have to see if the // new one fits logutil.Trace("loading additional model", "model", pending.model.ModelPath) - needEvict := s.loadFn(pending, ggml, systemInfo, gpus, true) + needEvict := s.loadFn(pending, systemInfo, gpus, true) if !needEvict { slog.Debug("new model fits with existing models, loading") break @@ -435,7 +408,7 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm // load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs // (if any). Returns whether the scheduler needs to evict a model to make this one fit. -func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool { +func (s *Scheduler) load(req *LlmRequest, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool { numParallel := max(int(envconfig.NumParallel()), 1) // Embedding models should always be loaded with parallel=1 @@ -460,15 +433,33 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo if llama == nil { var err error - llama, err = s.newServerFn(systemInfo, gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel) - if err != nil { - // some older models are not compatible with newer versions of llama.cpp - // show a generalized compatibility error until there is a better way to - // check for model compatibility - if errors.Is(err, ggml.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") { - err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName) + if !req.model.IsMLX() { + f, loadErr := llm.LoadModel(req.model.ModelPath, 1024) + if loadErr != nil { + slog.Info("failed to load model metadata", "model", req.model.ModelPath, "error", loadErr) + req.errCh <- loadErr + s.loadedMu.Unlock() + return false } - slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err) + llama, err = s.newServerFn(systemInfo, gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel) + if err != nil { + // some older models are not compatible with newer versions of llama.cpp + // show a generalized compatibility error until there is a better way to + // check for model compatibility + if errors.Is(err, ggml.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") { + err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName) + } + } + } else { + modelName := req.model.ShortName + if slices.Contains(req.model.Config.Capabilities, "image") { + llama, err = imagegen.NewServer(modelName) + } else { + llama, err = mlxrunner.NewClient(modelName) + } + } + if err != nil { + slog.Info("failed to create server", "model", req.model.ShortName, "error", err) req.errCh <- err s.loadedMu.Unlock() return false @@ -476,8 +467,12 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo s.activeLoading = llama } else { - if s.activeLoading.ModelPath() != req.model.ModelPath { - panic(fmt.Errorf("attempting to load different model after eviction (original %v new %v)", s.activeLoading.ModelPath(), req.model.ModelPath)) + wantPath := req.model.ModelPath + if wantPath == "" { + wantPath = req.model.ShortName + } + if s.activeLoading.ModelPath() != wantPath { + panic(fmt.Errorf("attempting to load different model after eviction (original %v new %v)", s.activeLoading.ModelPath(), wantPath)) } } @@ -544,6 +539,7 @@ iGPUScan: sessionDuration: sessionDuration, gpus: gpuIDs, discreteGPUs: discreteGPUs, + isImagegen: slices.Contains(req.model.Config.Capabilities, "image"), totalSize: totalSize, vramSize: vramSize, loading: true, @@ -591,59 +587,6 @@ iGPUScan: return false } -// loadMLX loads an experimental safetensors model using MLX runners. -// Image models use x/imagegen; LLM models use x/mlxrunner. -func (s *Scheduler) loadMLX(req *LlmRequest) bool { - modelName := req.model.ShortName - var server llm.LlamaServer - var err error - - if slices.Contains(req.model.Config.Capabilities, "image") { - server, err = imagegen.NewServer(modelName) - } else { - server, err = mlxrunner.NewClient(modelName) - } - if err != nil { - req.errCh <- err - return true - } - - sessionDuration := envconfig.KeepAlive() - if req.sessionDuration != nil { - sessionDuration = req.sessionDuration.Duration - } - - totalSize, vramSize := server.MemorySize() - runner := &runnerRef{ - model: req.model, - modelPath: req.model.ModelPath, - modelKey: schedulerModelKey(req.model), - llama: server, - Options: &req.opts, - loading: false, - isImagegen: slices.Contains(req.model.Config.Capabilities, "image"), - sessionDuration: sessionDuration, - totalSize: totalSize, - vramSize: vramSize, - } - - s.loadedMu.Lock() - s.loaded[runner.modelKey] = runner - s.loadedMu.Unlock() - - // Set up expiration timer - runner.refMu.Lock() - if sessionDuration > 0 { - runner.expireTimer = time.AfterFunc(sessionDuration, func() { - s.expiredCh <- runner - }) - } - runner.refMu.Unlock() - - req.useLoadedRunner(runner, s.finishedReqCh) - return true -} - func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) { if len(allGpus) == 0 { return diff --git a/server/sched_test.go b/server/sched_test.go index 0b79c7834..f40dc117f 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -39,10 +39,25 @@ func TestSchedLoad(t *testing.T) { defer done() s := InitScheduler(ctx) s.waitForRecovery = 10 * time.Millisecond - var f *ggml.GGML // value not used in tests + + modelPath, _ := createBinFile(t, ggml.KV{ + "general.architecture": "llama", + "llama.context_length": uint32(32), + "llama.embedding_length": uint32(4096), + "llama.block_count": uint32(1), + "llama.attention.head_count": uint32(32), + "llama.attention.head_count_kv": uint32(32), + "tokenizer.ggml.tokens": []string{" "}, + "tokenizer.ggml.scores": []float32{0}, + "tokenizer.ggml.token_type": []int32{0}, + }, []*ggml.Tensor{ + {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, + {Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, + }) + req := &LlmRequest{ ctx: ctx, - model: &Model{ModelPath: "foo"}, + model: &Model{ModelPath: modelPath}, opts: api.DefaultOptions(), successCh: make(chan *runnerRef, 1), errCh: make(chan error, 1), @@ -54,7 +69,7 @@ func TestSchedLoad(t *testing.T) { } gpus := []ml.DeviceInfo{} systemInfo := ml.SystemInfo{} - s.load(req, f, systemInfo, gpus, false) + s.load(req, systemInfo, gpus, false) require.Empty(t, req.successCh) require.Len(t, req.errCh, 1) s.loadedMu.Lock() @@ -68,7 +83,7 @@ func TestSchedLoad(t *testing.T) { server.modelPath = model return server, nil } - s.load(req, f, systemInfo, gpus, false) + s.load(req, systemInfo, gpus, false) select { case err := <-req.errCh: require.NoError(t, err) @@ -80,9 +95,24 @@ func TestSchedLoad(t *testing.T) { s.loadedMu.Unlock() } - req.model.ModelPath = "dummy_model_path" + modelPath2, _ := createBinFile(t, ggml.KV{ + "general.architecture": "llama", + "llama.context_length": uint32(32), + "llama.embedding_length": uint32(4096), + "llama.block_count": uint32(1), + "llama.attention.head_count": uint32(32), + "llama.attention.head_count_kv": uint32(32), + "tokenizer.ggml.tokens": []string{" "}, + "tokenizer.ggml.scores": []float32{0}, + "tokenizer.ggml.token_type": []int32{0}, + }, []*ggml.Tensor{ + {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, + {Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, + }) + + req.model.ModelPath = modelPath2 server.waitResp = errors.New("wait failure") - s.load(req, f, systemInfo, gpus, false) + s.load(req, systemInfo, gpus, false) select { case err := <-req.errCh: require.Contains(t, err.Error(), "wait failure") @@ -90,7 +120,7 @@ func TestSchedLoad(t *testing.T) { t.Fatalf("unexpected success %v", resp) } s.loadedMu.Lock() - runner := s.loaded["dummy_model_path"] + runner := s.loaded[modelPath2] s.loadedMu.Unlock() require.NotNil(t, runner) require.Equal(t, uint(0), runner.refCount) @@ -103,7 +133,6 @@ type reqBundle struct { ctxDone func() srv *mockLlm req *LlmRequest - f *ggml.GGML } func (scenario *reqBundle) newServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { @@ -132,11 +161,6 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vra }) model := &Model{Name: modelName, ModelPath: p} - f, err := llm.LoadModel(model.ModelPath, 0) - if err != nil { - t.Fatal(err) - } - b.f = f if duration == nil { duration = &api.Duration{Duration: 5 * time.Millisecond} } @@ -178,7 +202,6 @@ func TestSchedRequestsSameModelSameRequest(t *testing.T) { a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}, nil) b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0}, nil) b.req.model = a.req.model - b.f = a.f s.newServerFn = a.newServer slog.Info("a") @@ -223,7 +246,6 @@ func TestSchedRequestsSimpleReloadSameModel(t *testing.T) { b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond}, nil) tmpModel := *a.req.model b.req.model = &tmpModel - b.f = a.f s.newServerFn = a.newServer slog.Info("a") @@ -518,16 +540,31 @@ func TestSchedExpireRunner(t *testing.T) { defer done() s := InitScheduler(ctx) s.waitForRecovery = 10 * time.Millisecond + + modelPath, _ := createBinFile(t, ggml.KV{ + "general.architecture": "llama", + "llama.context_length": uint32(32), + "llama.embedding_length": uint32(4096), + "llama.block_count": uint32(1), + "llama.attention.head_count": uint32(32), + "llama.attention.head_count_kv": uint32(32), + "tokenizer.ggml.tokens": []string{" "}, + "tokenizer.ggml.scores": []float32{0}, + "tokenizer.ggml.token_type": []int32{0}, + }, []*ggml.Tensor{ + {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, + {Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, + }) + req := &LlmRequest{ ctx: ctx, - model: &Model{ModelPath: "foo"}, + model: &Model{ModelPath: modelPath}, opts: api.DefaultOptions(), successCh: make(chan *runnerRef, 1), errCh: make(chan error, 1), sessionDuration: &api.Duration{Duration: 2 * time.Minute}, } - var f *ggml.GGML gpus := []ml.DeviceInfo{} systemInfo := ml.SystemInfo{} server := &mockLlm{vramSize: 10, vramByGPU: map[ml.DeviceID]uint64{}} @@ -535,7 +572,7 @@ func TestSchedExpireRunner(t *testing.T) { server.modelPath = model return server, nil } - s.load(req, f, systemInfo, gpus, false) + s.load(req, systemInfo, gpus, false) select { case err := <-req.errCh: @@ -550,7 +587,7 @@ func TestSchedExpireRunner(t *testing.T) { s.loadedMu.Unlock() } - s.expireRunner(&Model{ModelPath: "foo"}) + s.expireRunner(&Model{ModelPath: modelPath}) s.finishedReqCh <- req s.processCompleted(ctx) diff --git a/x/imagegen/server.go b/x/imagegen/server.go index 102cb0c55..d3eccc4d7 100644 --- a/x/imagegen/server.go +++ b/x/imagegen/server.go @@ -22,6 +22,7 @@ import ( "time" "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/format" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/x/imagegen/manifest" @@ -43,13 +44,52 @@ type Server struct { lastErrLock sync.Mutex } -// NewServer spawns a new MLX runner subprocess and waits until it's ready. +// NewServer prepares a new MLX runner server for image generation models. +// The subprocess is not started until Load() is called. func NewServer(modelName string) (*Server, error) { // Validate platform support before attempting to start if err := CheckPlatformSupport(); err != nil { return nil, err } + return &Server{ + modelName: modelName, + done: make(chan error, 1), + client: &http.Client{Timeout: 10 * time.Minute}, + }, nil +} + +// ModelPath returns the path to the model. +func (s *Server) ModelPath() string { + return s.modelName +} + +// Load checks whether the model fits in GPU memory and starts the subprocess. +func (s *Server) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) { + // Estimate VRAM based on tensor size from manifest + if modelManifest, err := manifest.LoadManifest(s.modelName); err == nil { + s.vramSize = uint64(modelManifest.TotalTensorSize()) + } else { + s.vramSize = 8 * 1024 * 1024 * 1024 + } + + if len(gpus) > 0 { + available := gpus[0].FreeMemory + overhead := gpus[0].MinimumMemory() + envconfig.GpuOverhead() + if available > overhead { + available -= overhead + } else { + available = 0 + } + + if s.vramSize > available { + if requireFull { + return nil, llm.ErrLoadRequiredFull + } + return nil, fmt.Errorf("model requires %s but only %s are available (after %s overhead)", format.HumanBytes2(s.vramSize), format.HumanBytes2(available), format.HumanBytes2(overhead)) + } + } + // Find a free port port := 0 if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { @@ -61,6 +101,7 @@ func NewServer(modelName string) (*Server, error) { if port == 0 { port = rand.Intn(65535-49152) + 49152 } + s.port = port // Get the current executable path (we use the same binary with runner subcommand) exe, err := os.Executable() @@ -72,7 +113,7 @@ func NewServer(modelName string) (*Server, error) { } // Spawn subprocess: ollama runner --imagegen-engine --model --port - cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", modelName, "--port", strconv.Itoa(port)) + cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", s.modelName, "--port", strconv.Itoa(port)) cmd.Env = os.Environ() // On Linux, set LD_LIBRARY_PATH to include MLX library directories @@ -105,23 +146,7 @@ func NewServer(modelName string) (*Server, error) { slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal) } - // Estimate VRAM based on tensor size from manifest - var vramSize uint64 - if modelManifest, err := manifest.LoadManifest(modelName); err == nil { - vramSize = uint64(modelManifest.TotalTensorSize()) - } else { - // Fallback: default to 8GB if manifest can't be loaded - vramSize = 8 * 1024 * 1024 * 1024 - } - - s := &Server{ - cmd: cmd, - port: port, - modelName: modelName, - vramSize: vramSize, - done: make(chan error, 1), - client: &http.Client{Timeout: 10 * time.Minute}, - } + s.cmd = cmd // Forward subprocess stdout/stderr to server logs stdout, _ := cmd.StdoutPipe() @@ -143,7 +168,7 @@ func NewServer(modelName string) (*Server, error) { } }() - slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port) + slog.Info("starting mlx runner subprocess", "model", s.modelName, "port", s.port) if err := cmd.Start(); err != nil { return nil, fmt.Errorf("failed to start mlx runner: %w", err) } @@ -154,22 +179,6 @@ func NewServer(modelName string) (*Server, error) { s.done <- err }() - // Wait for subprocess to be ready - if err := s.waitUntilRunning(); err != nil { - s.Close() - return nil, err - } - - return s, nil -} - -// ModelPath returns the path to the model. -func (s *Server) ModelPath() string { - return s.modelName -} - -// Load satisfies the LlamaServer interface. MLX models don't need GPU layer assignment. -func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) { return nil, nil } @@ -191,9 +200,15 @@ func (s *Server) Ping(ctx context.Context) error { return nil } -// waitUntilRunning waits for the subprocess to be ready. -func (s *Server) waitUntilRunning() error { - ctx := context.Background() +// getLastErr returns the last stderr line. +func (s *Server) getLastErr() string { + s.lastErrLock.Lock() + defer s.lastErrLock.Unlock() + return s.lastErr +} + +// WaitUntilRunning waits for the subprocess to be ready. +func (s *Server) WaitUntilRunning(ctx context.Context) error { timeout := time.After(envconfig.LoadTimeout()) ticker := time.NewTicker(100 * time.Millisecond) defer ticker.Stop() @@ -201,7 +216,6 @@ func (s *Server) waitUntilRunning() error { for { select { case err := <-s.done: - // Include recent stderr lines for better error context errMsg := s.getLastErr() if errMsg != "" { return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err) @@ -222,18 +236,6 @@ func (s *Server) waitUntilRunning() error { } } -// getLastErr returns the last stderr line. -func (s *Server) getLastErr() string { - s.lastErrLock.Lock() - defer s.lastErrLock.Unlock() - return s.lastErr -} - -// WaitUntilRunning satisfies the LlamaServer interface. -func (s *Server) WaitUntilRunning(ctx context.Context) error { - return nil -} - // Completion handles both text and image generation requests. func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { seed := req.Seed diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go index c4f8c77ce..b5d0e6fb7 100644 --- a/x/mlxrunner/client.go +++ b/x/mlxrunner/client.go @@ -22,9 +22,12 @@ import ( "time" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/format" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/x/imagegen" + "github.com/ollama/ollama/x/imagegen/manifest" ) // Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models. @@ -41,105 +44,24 @@ type Client struct { cmd *exec.Cmd } -// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready. +// NewClient prepares a new MLX runner client for LLM models. +// The subprocess is not started until Load() is called. func NewClient(modelName string) (*Client, error) { if err := imagegen.CheckPlatformSupport(); err != nil { return nil, err } - // Find a free port - port := 0 - if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { - if l, err := net.ListenTCP("tcp", a); err == nil { - port = l.Addr().(*net.TCPAddr).Port - l.Close() - } - } - if port == 0 { - port = rand.Intn(65535-49152) + 49152 - } - - // Get the current executable path - exe, err := os.Executable() - if err != nil { - return nil, fmt.Errorf("unable to lookup executable path: %w", err) - } - if eval, err := filepath.EvalSymlinks(exe); err == nil { - exe = eval - } - - // Spawn subprocess: ollama runner --mlx-engine --model --port - cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port)) - cmd.Env = os.Environ() - - // On Linux, set LD_LIBRARY_PATH to include MLX library directories - if runtime.GOOS == "linux" { - libraryPaths := []string{ml.LibOllamaPath} - if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil { - libraryPaths = append(libraryPaths, mlxDirs...) - } - - if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok { - libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...) - } - - pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) - - found := false - for i := range cmd.Env { - if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") { - cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal - found = true - break - } - } - if !found { - cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal) - } - slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal) - } - c := &Client{ - port: port, modelName: modelName, done: make(chan error, 1), client: &http.Client{Timeout: 10 * time.Minute}, - cmd: cmd, } - // Forward subprocess stdout/stderr to server logs - stdout, _ := cmd.StdoutPipe() - stderr, _ := cmd.StderrPipe() - go func() { - io.Copy(os.Stderr, stdout) //nolint:errcheck - }() - go func() { - scanner := bufio.NewScanner(stderr) - for scanner.Scan() { - line := scanner.Text() - fmt.Fprintln(os.Stderr, line) - c.lastErrLock.Lock() - c.lastErr = line - c.lastErrLock.Unlock() - } - }() - - slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port) - if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("failed to start mlx runner: %w", err) - } - - // Reap subprocess when it exits - go func() { - err := cmd.Wait() - c.done <- err - }() - - // Wait for subprocess to be ready - if err := c.waitUntilRunning(); err != nil { - c.Close() + modelManifest, err := manifest.LoadManifest(modelName) + if err != nil { return nil, err } + c.memory.Store(uint64(modelManifest.TotalTensorSize())) return c, nil } @@ -150,8 +72,8 @@ func (c *Client) getLastErr() string { return c.lastErr } -func (c *Client) waitUntilRunning() error { - ctx := context.Background() +// WaitUntilRunning waits for the subprocess to be ready. +func (c *Client) WaitUntilRunning(ctx context.Context) error { timeout := time.After(2 * time.Minute) ticker := time.NewTicker(100 * time.Millisecond) defer ticker.Stop() @@ -328,8 +250,110 @@ func (c *Client) HasExited() bool { } } -// Load implements llm.LlamaServer. -func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) ([]ml.DeviceID, error) { +// Load checks whether the model fits in GPU memory and starts the subprocess. +func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) { + if len(gpus) > 0 { + modelSize := c.memory.Load() + // We currently only use the first GPU with MLX + available := gpus[0].FreeMemory + overhead := gpus[0].MinimumMemory() + envconfig.GpuOverhead() + if available > overhead { + available -= overhead + } else { + available = 0 + } + + if modelSize > available { + if requireFull { + return nil, llm.ErrLoadRequiredFull + } + return nil, fmt.Errorf("model requires %s but only %s are available (after %s overhead)", format.HumanBytes2(modelSize), format.HumanBytes2(available), format.HumanBytes2(overhead)) + } + } + + // Find a free port + port := 0 + if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { + if l, err := net.ListenTCP("tcp", a); err == nil { + port = l.Addr().(*net.TCPAddr).Port + l.Close() + } + } + if port == 0 { + port = rand.Intn(65535-49152) + 49152 + } + c.port = port + + // Get the current executable path + exe, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("unable to lookup executable path: %w", err) + } + if eval, err := filepath.EvalSymlinks(exe); err == nil { + exe = eval + } + + // Spawn subprocess: ollama runner --mlx-engine --model --port + cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", c.modelName, "--port", strconv.Itoa(port)) + cmd.Env = os.Environ() + + // On Linux, set LD_LIBRARY_PATH to include MLX library directories + if runtime.GOOS == "linux" { + libraryPaths := []string{ml.LibOllamaPath} + if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil { + libraryPaths = append(libraryPaths, mlxDirs...) + } + + if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok { + libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...) + } + + pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) + + found := false + for i := range cmd.Env { + if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") { + cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal + found = true + break + } + } + if !found { + cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal) + } + slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal) + } + + c.cmd = cmd + + // Forward subprocess stdout/stderr to server logs + stdout, _ := cmd.StdoutPipe() + stderr, _ := cmd.StderrPipe() + go func() { + io.Copy(os.Stderr, stdout) //nolint:errcheck + }() + go func() { + scanner := bufio.NewScanner(stderr) + for scanner.Scan() { + line := scanner.Text() + fmt.Fprintln(os.Stderr, line) + c.lastErrLock.Lock() + c.lastErr = line + c.lastErrLock.Unlock() + } + }() + + slog.Info("starting mlx runner subprocess", "model", c.modelName, "port", c.port) + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start mlx runner: %w", err) + } + + // Reap subprocess when it exits + go func() { + err := cmd.Wait() + c.done <- err + }() + return nil, nil } @@ -408,9 +432,7 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) { func (c *Client) currentMemory() uint64 { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - if err := c.Ping(ctx); err != nil { - slog.Warn("failed to get current memory", "error", err) - } + c.Ping(ctx) //nolint:errcheck return c.memory.Load() } @@ -425,9 +447,4 @@ func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 { return c.currentMemory() } -// WaitUntilRunning implements llm.LlamaServer. -func (c *Client) WaitUntilRunning(ctx context.Context) error { - return nil -} - var _ llm.LlamaServer = (*Client)(nil)