From 68e00c7c36714814a0a3642caca539f6794c0d35 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Mon, 19 Jan 2026 12:48:34 -0800 Subject: [PATCH] fix: prevent image generation models from loading during deletion (#13781) Move the unload check (empty prompt + KeepAlive=0) before the image generation model dispatch in GenerateHandler. This prevents models like flux from being loaded into memory just to be immediately unloaded when running `ollama rm`. Also fix a bug in DeleteHandler where `args[0]` was used instead of `arg` in the delete loop, causing only the first model to be unloaded when deleting multiple models. --- cmd/cmd.go | 4 +- server/routes.go | 14 +++--- server/routes_generate_test.go | 92 ++++++++++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 9 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 5139c05cb..c9c89af56 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -899,11 +899,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error { for _, arg := range args { // Unload the model if it's running before deletion if err := loadOrUnloadModel(cmd, &runOptions{ - Model: args[0], + Model: arg, KeepAlive: &api.Duration{Duration: 0}, }); err != nil { if !strings.Contains(strings.ToLower(err.Error()), "not found") { - fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0]) + fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", arg) } } diff --git a/server/routes.go b/server/routes.go index 5029046b6..bc848e5a3 100644 --- a/server/routes.go +++ b/server/routes.go @@ -220,12 +220,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - // Handle image generation models - if slices.Contains(m.Capabilities(), model.CapabilityImage) { - s.handleImageGenerate(c, req, name.String(), checkpointStart) - return - } - if req.TopLogprobs < 0 || req.TopLogprobs > 20 { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"}) return @@ -321,7 +315,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - // expire the runner + // expire the runner if unload is requested (empty prompt, keep alive is 0) if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 { s.sched.expireRunner(m) @@ -335,6 +329,12 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + // Handle image generation models + if slices.Contains(m.Capabilities(), model.CapabilityImage) { + s.handleImageGenerate(c, req, name.String(), checkpointStart) + return + } + if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"}) return diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index 111a9678a..cad721f8b 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -2101,3 +2101,95 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) { } }) } + +func TestGenerateUnload(t *testing.T) { + gin.SetMode(gin.TestMode) + + var loadFnCalled bool + + s := Server{ + sched: &Scheduler{ + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: make(map[string]*runnerRef), + newServerFn: newMockServer(&mockRunner{}), + getGpuFn: getGpuFn, + getSystemInfoFn: getSystemInfoFn, + loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool { + loadFnCalled = true + req.successCh <- &runnerRef{llama: &mockRunner{}} + return false + }, + }, + } + + go s.sched.Run(t.Context()) + + _, digest := createBinFile(t, ggml.KV{ + "general.architecture": "llama", + "llama.block_count": uint32(1), + "llama.context_length": uint32(8192), + "llama.embedding_length": uint32(4096), + "llama.attention.head_count": uint32(32), + "llama.attention.head_count_kv": uint32(8), + "tokenizer.ggml.tokens": []string{""}, + "tokenizer.ggml.scores": []float32{0}, + "tokenizer.ggml.token_type": []int32{0}, + }, []*ggml.Tensor{ + {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + }) + + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "test", + Files: map[string]string{"file.gguf": digest}, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + t.Run("unload with empty prompt and keepalive 0", func(t *testing.T) { + loadFnCalled = false + + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test", + Prompt: "", + KeepAlive: &api.Duration{Duration: 0}, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var resp api.GenerateResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp.DoneReason != "unload" { + t.Errorf("expected done_reason 'unload', got %q", resp.DoneReason) + } + + if !resp.Done { + t.Error("expected done to be true") + } + + if loadFnCalled { + t.Error("expected model NOT to be loaded for unload request, but loadFn was called") + } + }) +}