From c61023f5548f61651b7fd04393e2a93430f89a71 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Wed, 4 Feb 2026 15:36:11 -0800 Subject: [PATCH] ollamarunner: Fix off by one error with numPredict When numPredict is set, the user will receive one less token than the requested limit. In addition, the stats will incorrectly show the number of tokens returned as the limit. In cases where numPredict is not set, the number of tokens is reported correctly. This occurs because numPredict is checked when setting up the next batch but hitting the limit will terminate the current batch as well. Instead, it is better to check the limit as we actually predict them. --- integration/basic_test.go | 44 +++++++++++++++++++++++++++++++++++ runner/ollamarunner/runner.go | 17 +++++++------- 2 files changed, 53 insertions(+), 8 deletions(-) diff --git a/integration/basic_test.go b/integration/basic_test.go index 414061479..351a1e388 100644 --- a/integration/basic_test.go +++ b/integration/basic_test.go @@ -144,3 +144,47 @@ func TestUnicodeModelDir(t *testing.T) { } ChatTestHelper(ctx, t, req, blueSkyExpected) } + +// TestNumPredict verifies that when num_predict is set, the model generates +// exactly that many tokens. It uses logprobs to count the actual tokens output. 
+func TestNumPredict(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil { + t.Fatalf("failed to pull model: %v", err) + } + + req := api.GenerateRequest{ + Model: "qwen3:0.6b", + Prompt: "Write a long story.", + Stream: &stream, + Logprobs: true, + Options: map[string]any{ + "num_predict": 10, + "temperature": 0, + "seed": 123, + }, + } + + logprobCount := 0 + var finalResponse api.GenerateResponse + err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error { + logprobCount += len(resp.Logprobs) + if resp.Done { + finalResponse = resp + } + return nil + }) + if err != nil { + t.Fatalf("generate failed: %v", err) + } + + if logprobCount != 10 { + t.Errorf("expected 10 tokens (logprobs), got %d (EvalCount=%d, DoneReason=%s)", + logprobCount, finalResponse.EvalCount, finalResponse.DoneReason) + } +} diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 048facde8..f4baf395b 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -514,13 +514,6 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er continue } - // if past the num predict limit - if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { - s.removeSequence(seqIdx, llm.DoneReasonLength) - nextBatch.seqs[seqIdx] = nil - continue - } - if !s.cache.enabled { seq.inputs = append(seq.cache.Inputs, seq.inputs...) 
seq.cache.Inputs = []*input.Input{} @@ -709,7 +702,6 @@ func (s *Server) computeBatch(activeBatch batchState) { continue } - seq.numPredicted++ nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats seq.inputs = []*input.Input{nextToken} nextBatchTokens[i] = nextToken @@ -745,7 +737,9 @@ func (s *Server) computeBatch(activeBatch batchState) { logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i) continue } + seq.lastUpdatedAt = t + seq.numPredicted++ if seq.numPredicted == 1 { seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt) seq.startedAt = seq.lastUpdatedAt @@ -791,6 +785,13 @@ func (s *Server) computeBatch(activeBatch batchState) { } seq.pendingResponses = append(seq.pendingResponses, piece) + + // if past the num predict limit + if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { + s.removeSequence(i, llm.DoneReasonLength) + continue + } + sequence := strings.Join(seq.pendingResponses, "") if ok, stop := common.FindStop(sequence, seq.stop); ok {