mirror of
https://github.com/ollama/ollama.git
synced 2026-03-08 23:04:13 -05:00
- Collapse MLX sampling state into a single sample.Sampler struct (options + history).
- Replace the interface-based sampler chain (TopP, TopK, penalty, etc.) with function-based transforms.
- Update request/pipeline wiring to use *sample.Sampler, seed history from prompt tokens, and append generated tokens each step.
- Implement top_p, min_p, repeat_penalty, and frequency_penalty
63 lines
1.3 KiB
Go
63 lines
1.3 KiB
Go
//go:build mlx
|
|
|
|
package sample
|
|
|
|
import (
|
|
"math"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
)
|
|
|
|
func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
|
|
// RepeatLastN = 1, PresencePenalty = 6
|
|
s := New(0, 0, 0, 0, 1, 6)
|
|
defer func() {
|
|
s.Free()
|
|
mlx.Sweep()
|
|
}()
|
|
|
|
s.ResetHistory([]int32{0})
|
|
s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1}))
|
|
|
|
logprobs := mlx.FromValues([]float32{0, 5, 4}, 3)
|
|
got := s.Sample(logprobs)
|
|
mlx.Eval(got)
|
|
|
|
// logprobs will be [0, -1, 4] after the penalty
|
|
// and then (index) 2 after the greedy sampler
|
|
gotInt := got.Int()
|
|
if gotInt != 2 {
|
|
t.Fatalf("got %d, want 2", gotInt)
|
|
}
|
|
}
|
|
|
|
func TestMinPMasksTokensBelowThreshold(t *testing.T) {
|
|
s := New(0, 0, 0.5, 0, 0, 0)
|
|
defer func() {
|
|
s.Free()
|
|
mlx.Sweep()
|
|
}()
|
|
|
|
logprobs := mlx.FromValues([]float32{
|
|
float32(math.Log(0.5)),
|
|
float32(math.Log(0.3)),
|
|
float32(math.Log(0.2)),
|
|
}, 3)
|
|
got := minP(s, logprobs)
|
|
mlx.Eval(got)
|
|
|
|
gotFloats := got.Floats()
|
|
if len(gotFloats) != 3 {
|
|
t.Fatalf("got %d scores, want 3", len(gotFloats))
|
|
}
|
|
|
|
if math.IsInf(float64(gotFloats[0]), -1) || math.IsInf(float64(gotFloats[1]), -1) {
|
|
t.Fatalf("kept tokens were masked: %v", gotFloats)
|
|
}
|
|
|
|
if !math.IsInf(float64(gotFloats[2]), -1) {
|
|
t.Fatalf("lowest-probability token should be masked, got %v", gotFloats)
|
|
}
|
|
}
|