diff --git a/integration/basic_test.go b/integration/basic_test.go index 414061479..351a1e388 100644 --- a/integration/basic_test.go +++ b/integration/basic_test.go @@ -144,3 +144,47 @@ func TestUnicodeModelDir(t *testing.T) { } ChatTestHelper(ctx, t, req, blueSkyExpected) } + +// TestNumPredict verifies that when num_predict is set, the model generates +// exactly that many tokens. It uses logprobs to count the actual tokens output. +func TestNumPredict(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil { + t.Fatalf("failed to pull model: %v", err) + } + + req := api.GenerateRequest{ + Model: "qwen3:0.6b", + Prompt: "Write a long story.", + Stream: &stream, + Logprobs: true, + Options: map[string]any{ + "num_predict": 10, + "temperature": 0, + "seed": 123, + }, + } + + logprobCount := 0 + var finalResponse api.GenerateResponse + err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error { + logprobCount += len(resp.Logprobs) + if resp.Done { + finalResponse = resp + } + return nil + }) + if err != nil { + t.Fatalf("generate failed: %v", err) + } + + if logprobCount != 10 { + t.Errorf("expected 10 tokens (logprobs), got %d (EvalCount=%d, DoneReason=%s)", + logprobCount, finalResponse.EvalCount, finalResponse.DoneReason) + } +} diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 048facde8..f4baf395b 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -514,13 +514,6 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er continue } - // if past the num predict limit - if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { - s.removeSequence(seqIdx, llm.DoneReasonLength) - nextBatch.seqs[seqIdx] = nil - continue - } - if !s.cache.enabled { seq.inputs = append(seq.cache.Inputs, seq.inputs...) seq.cache.Inputs = []*input.Input{} @@ -709,7 +702,6 @@ func (s *Server) computeBatch(activeBatch batchState) { continue } - seq.numPredicted++ nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats seq.inputs = []*input.Input{nextToken} nextBatchTokens[i] = nextToken @@ -745,7 +737,9 @@ func (s *Server) computeBatch(activeBatch batchState) { logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i) continue } + seq.lastUpdatedAt = t + seq.numPredicted++ if seq.numPredicted == 1 { seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt) seq.startedAt = seq.lastUpdatedAt @@ -791,6 +785,13 @@ func (s *Server) computeBatch(activeBatch batchState) { } seq.pendingResponses = append(seq.pendingResponses, piece) + + // if past the num predict limit + if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { + s.removeSequence(i, llm.DoneReasonLength) + continue + } + sequence := strings.Join(seq.pendingResponses, "") if ok, stop := common.FindStop(sequence, seq.stop); ok {