runner: add token history sampling parameters to ollama runner (#14537)

This commit is contained in:
Jeffrey Morgan
2026-03-01 19:16:07 -08:00
committed by GitHub
parent 3490e9590b
commit 86513cb697
8 changed files with 193 additions and 15 deletions

View File

@@ -16,24 +16,49 @@ type token struct {
value float32 // The raw logit or probability from the model
}
const DefaultPenaltyLookback = 64
type Sampler struct {
rng *rand.Rand
topK int
topP float32
minP float32
temperature float32
repeat float32
presence float32
frequency float32
history []int32
grammar *GrammarSampler
}
func (s *Sampler) Reset() {
s.history = s.history[:0]
}
func (s *Sampler) Accept(token int32) {
s.history = append(s.history, token)
if len(s.history) > DefaultPenaltyLookback {
copy(s.history, s.history[len(s.history)-DefaultPenaltyLookback:])
s.history = s.history[:DefaultPenaltyLookback]
}
}
func (s *Sampler) Sample(logits []float32) (int32, error) {
if len(logits) == 0 {
return -1, errors.New("sample: no logits provided to sample")
}
counts := tokenCounts(s.history, len(logits))
tokens := make([]token, len(logits))
for i := range logits {
value := logits[i]
if count := counts[int32(i)]; count > 0 {
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
}
tokens[i].id = int32(i)
tokens[i].value = logits[i]
tokens[i].value = value
}
t, err := s.sample(tokens)
@@ -55,8 +80,12 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
// we need to reset them before applying the grammar and
// sampling again
for i := range logits {
value := logits[i]
if count := counts[int32(i)]; count > 0 {
value = applyPenalty(value, count, s.repeat, s.presence, s.frequency)
}
tokens[i].id = int32(i)
tokens[i].value = logits[i]
tokens[i].value = value
}
s.grammar.Apply(tokens)
t, err = s.sample(tokens)
@@ -127,7 +156,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
}
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *GrammarSampler) Sampler {
func NewSampler(temperature float32, topK int, topP float32, minP float32, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32, seed int, grammar *GrammarSampler) Sampler {
var rng *rand.Rand
if seed != -1 {
// PCG requires two parameters: sequence and stream
@@ -154,12 +183,19 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
minP = 1.0
}
if repeatPenalty <= 0 {
repeatPenalty = 1.0
}
return Sampler{
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
repeat: repeatPenalty,
presence: presencePenalty,
frequency: frequencyPenalty,
grammar: grammar,
}
}

View File

@@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
sampler := NewSampler(0.8, 0, 0, 0, 1, 0, 0, 42, nil)
b.ResetTimer()
for b.Loop() {
sampler.Sample(logits)
@@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
for _, tc := range configs {
b.Run("Config"+tc.name, func(b *testing.B) {
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, 1, 0, 0, tc.seed, nil)
sampler.Sample(logits)
b.ResetTimer()
@@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
// Test with combined transforms separately - topK influences performance greatly
b.Run("TransformCombined", func(b *testing.B) {
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
sampler := NewSampler(0.8, 50, 0.9, 0.05, 1, 0, 0, 42, nil)
b.ResetTimer()
for b.Loop() {
@@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0, -1, 0, 0, -1, nil)
sampler := NewSampler(0, -1, 0, 0, 1, 0, 0, -1, nil)
b.ResetTimer()
for b.Loop() {

View File

@@ -13,7 +13,7 @@ import (
func TestWeighted(t *testing.T) {
logits := []float32{-10, 3, -10, -10}
sampler := NewSampler(0, 0, 0, 0, 0, nil)
sampler := NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
got, err := sampler.Sample(logits)
if err != nil {
t.Error(err)
@@ -25,7 +25,7 @@ func TestWeighted(t *testing.T) {
}
logits = []float32{-100, -10, 0, 10}
sampler = NewSampler(0, 0, 0, 0, 0, nil)
sampler = NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
@@ -39,7 +39,7 @@ func TestWeighted(t *testing.T) {
// Test very high p
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
// Use extremely small topP to filter out all tokens
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
sampler = NewSampler(1.0, 0, 1e-10, 0, 1, 0, 0, 0, nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
@@ -52,7 +52,7 @@ func TestWeighted(t *testing.T) {
}
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
sampler = NewSampler(1, 0, 0.95, 0.05, 1, 0, 0, 0, nil)
got, err = sampler.Sample(logits)
if err == nil {
t.Errorf("expected error, got %d", got)
@@ -151,8 +151,8 @@ func TestGrammar(t *testing.T) {
func BenchmarkSample(b *testing.B) {
samplers := map[string]Sampler{
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
"Greedy": NewSampler(0, 0, 0, 0, 1, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, 1, 0, 0, -1, nil),
}
// Generate random logits for benchmarking

View File

@@ -25,6 +25,48 @@ func (h *tokenHeap) Pop() any {
return x
}
func tokenCounts(history []int32, vocabSize int) map[int32]int {
if len(history) == 0 {
return nil
}
start := 0
if len(history) > DefaultPenaltyLookback {
start = len(history) - DefaultPenaltyLookback
}
counts := make(map[int32]int, len(history)-start)
for _, token := range history[start:] {
if token < 0 || int(token) >= vocabSize {
continue
}
counts[token]++
}
return counts
}
func applyPenalty(logit float32, count int, repeatPenalty float32, presencePenalty float32, frequencyPenalty float32) float32 {
if repeatPenalty != 1.0 {
// Preserve ordering for negative logits when applying repeat penalty.
if logit < 0 {
logit *= repeatPenalty
} else {
logit /= repeatPenalty
}
}
if frequencyPenalty != 0 {
logit -= float32(count) * frequencyPenalty
}
if presencePenalty != 0 {
logit -= presencePenalty
}
return logit
}
// temperature applies scaling to the logits
func temperature(ts []token, temp float32) {
// Ensure temperature clipping near 0 to avoid numerical instability

View File

@@ -295,6 +295,86 @@ func TestMinP(t *testing.T) {
}
}
func TestTokenCounts(t *testing.T) {
history := make([]int32, 70)
history[0] = 7
history[69] = 7
counts := tokenCounts(history, 8)
if got := counts[7]; got != 1 {
t.Fatalf("lookback mismatch: got %d want %d", got, 1)
}
}
func TestApplyPenalty(t *testing.T) {
logit := applyPenalty(5.0, 3, 1.0, 1.5, 0.5)
if math.Abs(float64(logit-2.0)) > 1e-6 {
t.Fatalf("unexpected penalty result: got %f want %f", logit, 2.0)
}
logit = applyPenalty(4.0, 1, 2.0, 0, 0)
if math.Abs(float64(logit-2.0)) > 1e-6 {
t.Fatalf("unexpected repeat penalty result for positive logits: got %f want %f", logit, 2.0)
}
logit = applyPenalty(-4.0, 1, 2.0, 0, 0)
if math.Abs(float64(logit-(-8.0))) > 1e-6 {
t.Fatalf("unexpected repeat penalty result for negative logits: got %f want %f", logit, -8.0)
}
}
func TestSamplerPresencePenalty(t *testing.T) {
logits := []float32{0.0, 5.0, 0.0}
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
baseline.Accept(1)
got, err := baseline.Sample(logits)
if err != nil {
t.Fatal(err)
}
if got != 1 {
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
}
presence := NewSampler(0, 0, 1, 0, 1, 6, 0, -1, nil)
presence.Accept(1)
got, err = presence.Sample(logits)
if err != nil {
t.Fatal(err)
}
if got == 1 {
t.Fatalf("presence penalty did not change repeated token selection")
}
}
func TestSamplerFrequencyPenalty(t *testing.T) {
logits := []float32{0.0, 5.0, 4.0}
baseline := NewSampler(0, 0, 1, 0, 1, 0, 0, -1, nil)
baseline.Accept(1)
baseline.Accept(1)
baseline.Accept(1)
got, err := baseline.Sample(logits)
if err != nil {
t.Fatal(err)
}
if got != 1 {
t.Fatalf("unexpected baseline token: got %d want %d", got, 1)
}
frequency := NewSampler(0, 0, 1, 0, 1, 0, 1.0, -1, nil)
frequency.Accept(1)
frequency.Accept(1)
frequency.Accept(1)
got, err = frequency.Sample(logits)
if err != nil {
t.Fatal(err)
}
if got != 2 {
t.Fatalf("frequency penalty did not demote repeated token as expected: got %d want %d", got, 2)
}
}
func BenchmarkTransforms(b *testing.B) {
// Generate random logits
tokens := make([]token, 1<<16)