ollamarunner: Fix off-by-one error with numPredict

When numPredict is set, the user will receive one less token
than the requested limit. In addition, the stats will incorrectly
show the number of tokens returned as the limit. In cases where
numPredict is not set, the number of tokens is reported correctly.

This occurs because numPredict is checked when setting up the next
batch but hitting the limit will terminate the current batch as well.
Instead, it is better to check the limit as we actually predict tokens.
Jesse Gross
2026-02-04 15:36:11 -08:00
parent d25535c3f3
commit c61023f554
2 changed files with 53 additions and 8 deletions
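
As a toy model of the ordering described in the commit message (a minimal sketch with made-up names, not the runner's actual control flow): counting a token while scheduling the next batch and checking the limit there delivers one token too few, while counting and checking after each token is actually produced delivers exactly numPredict.

package main

import "fmt"

const numPredict = 10

// generate simulates the two orderings. With checkAtSetup, the token is
// counted while the next batch is being scheduled, and hitting the limit
// there also terminates that batch, so the counted token never reaches
// the user (the old behavior). Otherwise the count and check happen only
// after a token has actually been produced (the fixed behavior).
func generate(checkAtSetup bool) (delivered int) {
	numPredicted := 0
	for {
		if checkAtSetup {
			numPredicted++ // counted at batch setup
			if numPredicted >= numPredict {
				return delivered // in-flight batch discarded: one short
			}
			delivered++ // batch runs; token reaches the user
		} else {
			delivered++ // token actually produced
			numPredicted++
			if numPredicted >= numPredict {
				return delivered
			}
		}
	}
}

func main() {
	fmt.Println(generate(true), generate(false)) // prints: 9 10
}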


@@ -144,3 +144,47 @@ func TestUnicodeModelDir(t *testing.T) {
 	}
 	ChatTestHelper(ctx, t, req, blueSkyExpected)
 }
+
+// TestNumPredict verifies that when num_predict is set, the model generates
+// exactly that many tokens. It uses logprobs to count the actual tokens output.
+func TestNumPredict(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil {
+		t.Fatalf("failed to pull model: %v", err)
+	}
+
+	req := api.GenerateRequest{
+		Model:    "qwen3:0.6b",
+		Prompt:   "Write a long story.",
+		Stream:   &stream,
+		Logprobs: true,
+		Options: map[string]any{
+			"num_predict": 10,
+			"temperature": 0,
+			"seed":        123,
+		},
+	}
+
+	logprobCount := 0
+	var finalResponse api.GenerateResponse
+	err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
+		logprobCount += len(resp.Logprobs)
+		if resp.Done {
+			finalResponse = resp
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatalf("generate failed: %v", err)
+	}
+
+	if logprobCount != 10 {
+		t.Errorf("expected 10 tokens (logprobs), got %d (EvalCount=%d, DoneReason=%s)",
+			logprobCount, finalResponse.EvalCount, finalResponse.DoneReason)
+	}
+}
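
Assuming this test follows the repo's integration-test conventions, it would be run with the integration build tag, e.g. go test -tags=integration ./integration -run TestNumPredict (tag and path assumed); the stream variable passed via &stream is a helper defined elsewhere in the test package.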


@@ -514,13 +514,6 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, err error) {
 			continue
 		}
 
-		// if past the num predict limit
-		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-			s.removeSequence(seqIdx, llm.DoneReasonLength)
-			nextBatch.seqs[seqIdx] = nil
-			continue
-		}
-
 		if !s.cache.enabled {
			seq.inputs = append(seq.cache.Inputs, seq.inputs...)
			seq.cache.Inputs = []*input.Input{}
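
Removing the check here means forwardBatch no longer tears down a sequence while the batch carrying its final token is still being set up; termination now happens in computeBatch, after the token has actually been sampled (see below).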
@@ -709,7 +702,6 @@ func (s *Server) computeBatch(activeBatch batchState) {
 			continue
 		}
 
-		seq.numPredicted++
 		nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
 		seq.inputs = []*input.Input{nextToken}
 		nextBatchTokens[i] = nextToken
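
The early increment is dropped here because it counted a token as predicted when its placeholder was queued, before the model had produced anything; the counter is advanced later instead, once the batch's results are in.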
@@ -745,7 +737,9 @@ func (s *Server) computeBatch(activeBatch batchState) {
 			logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
 			continue
 		}
+
 		seq.lastUpdatedAt = t
+		seq.numPredicted++
 		if seq.numPredicted == 1 {
 			seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
 			seq.startedAt = seq.lastUpdatedAt
@@ -791,6 +785,13 @@ func (s *Server) computeBatch(activeBatch batchState) {
 		}
 
 		seq.pendingResponses = append(seq.pendingResponses, piece)
+
+		// if past the num predict limit
+		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
+			s.removeSequence(i, llm.DoneReasonLength)
+			continue
+		}
+
 		sequence := strings.Join(seq.pendingResponses, "")
 		if ok, stop := common.FindStop(sequence, seq.stop); ok {
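
Note where the relocated check lands: after the sampled piece has been appended to pendingResponses, so the final token is still delivered to the user before the sequence finishes with DoneReasonLength, and numPredicted now agrees with the token count reported in the stats.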