mirror of
https://github.com/ollama/ollama.git
synced 2026-04-30 16:08:07 -05:00
x/mlxrunner: replace sampler interface chain with single stateful Sampler (#14652)
- Collapse MLX sampling state into a single sample.Sampler struct (options + history).
- Replace interface-based sampler chain (TopP, TopK, penalty, etc.) with function-based transforms.
- Update request/pipeline wiring to use *sample.Sampler, seed history from prompt tokens, and append generated tokens each step.
- Implement top_p, min_p, repeat_penalty, and frequency_penalty.
This commit is contained in:
@@ -186,11 +186,13 @@ type completionRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type completionOpts struct {
|
type completionOpts struct {
|
||||||
Temperature float32 `json:"temperature,omitempty"`
|
Temperature float32 `json:"temperature,omitempty"`
|
||||||
TopP float32 `json:"top_p,omitempty"`
|
TopP float32 `json:"top_p,omitempty"`
|
||||||
MinP float32 `json:"min_p,omitempty"`
|
MinP float32 `json:"min_p,omitempty"`
|
||||||
TopK int `json:"top_k,omitempty"`
|
TopK int `json:"top_k,omitempty"`
|
||||||
NumPredict int `json:"num_predict,omitempty"`
|
RepeatLastN int `json:"repeat_last_n,omitempty"`
|
||||||
|
PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
||||||
|
NumPredict int `json:"num_predict,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CompletionResponse struct {
|
type CompletionResponse struct {
|
||||||
@@ -232,11 +234,13 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
}
|
}
|
||||||
if req.Options != nil {
|
if req.Options != nil {
|
||||||
creq.Options = &completionOpts{
|
creq.Options = &completionOpts{
|
||||||
Temperature: req.Options.Temperature,
|
Temperature: req.Options.Temperature,
|
||||||
TopP: req.Options.TopP,
|
TopP: req.Options.TopP,
|
||||||
MinP: req.Options.MinP,
|
MinP: req.Options.MinP,
|
||||||
TopK: req.Options.TopK,
|
TopK: req.Options.TopK,
|
||||||
NumPredict: req.Options.NumPredict,
|
RepeatLastN: req.Options.RepeatLastN,
|
||||||
|
PresencePenalty: req.Options.PresencePenalty,
|
||||||
|
NumPredict: req.Options.NumPredict,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -87,6 +87,12 @@ func (t *Array) Concatenate(axis int, others ...*Array) *Array {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Array) Cumsum(axis int, reverse, inclusive bool) *Array {
|
||||||
|
out := New("CUMSUM")
|
||||||
|
C.mlx_cumsum(&out.ctx, t.ctx, C.int(axis), C.bool(reverse), C.bool(inclusive), DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Array) Divide(other *Array) *Array {
|
func (t *Array) Divide(other *Array) *Array {
|
||||||
out := New("DIVIDE")
|
out := New("DIVIDE")
|
||||||
C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||||
@@ -129,6 +135,12 @@ func (t *Array) Logsumexp(keepDims bool) *Array {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Array) Less(other *Array) *Array {
|
||||||
|
out := New("LESS")
|
||||||
|
C.mlx_less(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Array) Matmul(other *Array) *Array {
|
func (t *Array) Matmul(other *Array) *Array {
|
||||||
out := New("MATMUL")
|
out := New("MATMUL")
|
||||||
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||||
|
|||||||
@@ -42,6 +42,9 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
)
|
)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
if request.Sampler != nil {
|
||||||
|
request.Sampler.Free()
|
||||||
|
}
|
||||||
mlx.Unpin(sample, logprobs)
|
mlx.Unpin(sample, logprobs)
|
||||||
mlx.Unpin(nextSample, nextLogprobs)
|
mlx.Unpin(nextSample, nextLogprobs)
|
||||||
mlx.Sweep()
|
mlx.Sweep()
|
||||||
@@ -74,6 +77,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
|
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
request.Sampler.ResetHistory(inputs)
|
||||||
|
|
||||||
session := r.cache.begin(r.Model, inputs)
|
session := r.cache.begin(r.Model, inputs)
|
||||||
defer session.close()
|
defer session.close()
|
||||||
caches := session.caches
|
caches := session.caches
|
||||||
@@ -113,7 +118,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||||
|
|
||||||
logprobs := logits.Subtract(logits.Logsumexp(true))
|
logprobs := logits.Subtract(logits.Logsumexp(true))
|
||||||
sample := request.Sample(logprobs)
|
sample := request.Sampler.Sample(logprobs)
|
||||||
|
|
||||||
mlx.Pin(sample, logprobs)
|
mlx.Pin(sample, logprobs)
|
||||||
mlx.Sweep()
|
mlx.Sweep()
|
||||||
@@ -132,6 +137,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
request.Sampler.AppendToken(sample)
|
||||||
nextSample, nextLogprobs = step(sample)
|
nextSample, nextLogprobs = step(sample)
|
||||||
|
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
|
|||||||
@@ -27,17 +27,19 @@ type Request struct {
|
|||||||
|
|
||||||
Ctx context.Context
|
Ctx context.Context
|
||||||
|
|
||||||
sample.Sampler
|
Sampler *sample.Sampler
|
||||||
}
|
}
|
||||||
|
|
||||||
type TextCompletionsRequest struct {
|
type TextCompletionsRequest struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Options struct {
|
Options struct {
|
||||||
Temperature float32 `json:"temperature"`
|
Temperature float32 `json:"temperature"`
|
||||||
TopP float32 `json:"top_p"`
|
TopP float32 `json:"top_p"`
|
||||||
MinP float32 `json:"min_p"`
|
MinP float32 `json:"min_p"`
|
||||||
TopK int `json:"top_k"`
|
TopK int `json:"top_k"`
|
||||||
MaxTokens int `json:"max_tokens"`
|
RepeatLastN int `json:"repeat_last_n"`
|
||||||
|
PresencePenalty float32 `json:"presence_penalty"`
|
||||||
|
MaxTokens int `json:"max_tokens"`
|
||||||
|
|
||||||
// Deprecated: use MaxTokens instead
|
// Deprecated: use MaxTokens instead
|
||||||
NumPredict int `json:"num_predict"`
|
NumPredict int `json:"num_predict"`
|
||||||
|
|||||||
@@ -8,70 +8,184 @@ import (
|
|||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Sampler interface {
|
type Transform func(*Sampler, *mlx.Array) *mlx.Array
|
||||||
Sample(*mlx.Array) *mlx.Array
|
|
||||||
|
type Sampler struct {
|
||||||
|
Temperature float32
|
||||||
|
TopP float32
|
||||||
|
MinP float32
|
||||||
|
TopK int
|
||||||
|
RepeatLastN int
|
||||||
|
PresencePenalty float32
|
||||||
|
|
||||||
|
history *mlx.Array
|
||||||
|
historyLen int
|
||||||
|
transforms []Transform
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(temp, top_p, min_p float32, top_k int) Sampler {
|
func New(temp, top_p, min_p float32, top_k, repeatLastN int, presencePenalty float32) *Sampler {
|
||||||
if temp == 0 {
|
s := &Sampler{
|
||||||
return greedy{}
|
Temperature: temp,
|
||||||
|
TopP: top_p,
|
||||||
|
MinP: min_p,
|
||||||
|
TopK: top_k,
|
||||||
|
RepeatLastN: repeatLastN,
|
||||||
|
PresencePenalty: presencePenalty,
|
||||||
|
}
|
||||||
|
|
||||||
|
var transforms []Transform
|
||||||
|
if presencePenalty != 0 {
|
||||||
|
transforms = append(transforms, penalty)
|
||||||
}
|
}
|
||||||
|
|
||||||
var samplers []Sampler
|
|
||||||
if top_p > 0 && top_p < 1 {
|
if top_p > 0 && top_p < 1 {
|
||||||
samplers = append(samplers, TopP(top_p))
|
transforms = append(transforms, topP)
|
||||||
}
|
}
|
||||||
|
|
||||||
if min_p != 0 {
|
if min_p != 0 {
|
||||||
samplers = append(samplers, MinP(min_p))
|
transforms = append(transforms, minP)
|
||||||
}
|
}
|
||||||
|
|
||||||
if top_k > 0 {
|
if top_k > 0 {
|
||||||
samplers = append(samplers, TopK(top_k))
|
transforms = append(transforms, topK)
|
||||||
}
|
}
|
||||||
|
|
||||||
samplers = append(samplers, Temperature(temp))
|
if temp == 0 {
|
||||||
return chain(samplers)
|
transforms = append(transforms, greedy)
|
||||||
|
} else {
|
||||||
|
transforms = append(transforms, temperature)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.transforms = transforms
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
type greedy struct{}
|
func (s *Sampler) usesHistory() bool {
|
||||||
|
return s.PresencePenalty != 0
|
||||||
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
|
|
||||||
return logits.Argmax(-1, false)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type chain []Sampler
|
func (s *Sampler) setHistory(history *mlx.Array, historyLen int) {
|
||||||
|
if history != nil {
|
||||||
|
mlx.Pin(history)
|
||||||
|
}
|
||||||
|
if s.history != nil {
|
||||||
|
mlx.Unpin(s.history)
|
||||||
|
}
|
||||||
|
s.history = history
|
||||||
|
s.historyLen = historyLen
|
||||||
|
}
|
||||||
|
|
||||||
func (c chain) Sample(logits *mlx.Array) *mlx.Array {
|
func (s *Sampler) ResetHistory(history []int32) {
|
||||||
for _, sampler := range c {
|
if !s.usesHistory() {
|
||||||
logits = sampler.Sample(logits)
|
return
|
||||||
|
}
|
||||||
|
if s.RepeatLastN > 0 && len(history) > s.RepeatLastN {
|
||||||
|
history = history[len(history)-s.RepeatLastN:]
|
||||||
|
}
|
||||||
|
if len(history) == 0 {
|
||||||
|
s.setHistory(nil, 0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := append([]int32(nil), history...)
|
||||||
|
s.setHistory(mlx.NewArrayInt32(tokens, []int32{int32(len(tokens))}), len(tokens))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sampler) AppendToken(token *mlx.Array) {
|
||||||
|
if !s.usesHistory() || token == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next := token.AsType(mlx.DTypeInt32)
|
||||||
|
nextLen := next.Size()
|
||||||
|
|
||||||
|
if s.history != nil && s.historyLen > 0 {
|
||||||
|
next = s.history.Concatenate(0, next)
|
||||||
|
nextLen += s.historyLen
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.RepeatLastN > 0 && nextLen > s.RepeatLastN {
|
||||||
|
trim := nextLen - s.RepeatLastN
|
||||||
|
next = next.Slice(mlx.Slice(trim, nextLen))
|
||||||
|
nextLen = s.RepeatLastN
|
||||||
|
}
|
||||||
|
|
||||||
|
s.setHistory(next, nextLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sampler) Free() {
|
||||||
|
s.setHistory(nil, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sampler) Sample(logits *mlx.Array) *mlx.Array {
|
||||||
|
for _, transform := range s.transforms {
|
||||||
|
logits = transform(s, logits)
|
||||||
}
|
}
|
||||||
return logits
|
return logits
|
||||||
}
|
}
|
||||||
|
|
||||||
type Temperature float32
|
func greedy(_ *Sampler, logits *mlx.Array) *mlx.Array {
|
||||||
|
return logits.Argmax(-1, false)
|
||||||
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
|
|
||||||
return mlx.DivScalar(logits, float32(t)).Categorical(-1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TopP float32
|
func temperature(s *Sampler, logits *mlx.Array) *mlx.Array {
|
||||||
|
return mlx.DivScalar(logits, s.Temperature).Categorical(-1)
|
||||||
func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array {
|
|
||||||
// TODO: implement
|
|
||||||
return logprobs
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type MinP float32
|
func topP(s *Sampler, logprobs *mlx.Array) *mlx.Array {
|
||||||
|
if s.TopP <= 0 || s.TopP >= 1 {
|
||||||
|
return logprobs
|
||||||
|
}
|
||||||
|
|
||||||
func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array {
|
order := logprobs.Negative().ArgsortAxis(-1)
|
||||||
// TODO: implement
|
sortedLogprobs := logprobs.TakeAlongAxis(order, -1)
|
||||||
return logprobs
|
sortedProbs := mlx.SoftmaxAxis(sortedLogprobs, -1, true)
|
||||||
|
prevCumProbs := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
|
||||||
|
keep := prevCumProbs.Less(mlx.FromValue(s.TopP))
|
||||||
|
filtered := mlx.Where(keep, sortedLogprobs, mlx.FromValue(float32(math.Inf(-1))))
|
||||||
|
return logprobs.PutAlongAxis(order, filtered, -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
type TopK int
|
func minP(s *Sampler, logprobs *mlx.Array) *mlx.Array {
|
||||||
|
if s.MinP <= 0 || s.MinP > 1 {
|
||||||
|
return logprobs
|
||||||
|
}
|
||||||
|
|
||||||
func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array {
|
maxLogprobs := logprobs.TakeAlongAxis(logprobs.Argmax(-1, true), -1)
|
||||||
mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0))
|
minLogprobs := mlx.AddScalar(maxLogprobs, float32(math.Log(float64(s.MinP))))
|
||||||
|
|
||||||
|
return mlx.Where(
|
||||||
|
logprobs.Less(minLogprobs),
|
||||||
|
mlx.FromValue(float32(math.Inf(-1))),
|
||||||
|
logprobs,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func topK(s *Sampler, logprobs *mlx.Array) *mlx.Array {
|
||||||
|
if s.TopK <= 0 {
|
||||||
|
return logprobs
|
||||||
|
}
|
||||||
|
|
||||||
|
vocab := logprobs.Dim(logprobs.NumDims() - 1)
|
||||||
|
if s.TopK >= vocab {
|
||||||
|
return logprobs
|
||||||
|
}
|
||||||
|
|
||||||
|
mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, 0))
|
||||||
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func penalty(s *Sampler, logprobs *mlx.Array) *mlx.Array {
|
||||||
|
if s.history == nil || s.historyLen == 0 || s.PresencePenalty == 0 {
|
||||||
|
return logprobs
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenIndices := s.history
|
||||||
|
if logprobs.NumDims() > 1 {
|
||||||
|
tokenIndices = tokenIndices.ExpandDims(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
selected := logprobs.TakeAlongAxis(tokenIndices, -1)
|
||||||
|
adjusted := mlx.AddScalar(selected, -s.PresencePenalty)
|
||||||
|
return logprobs.PutAlongAxis(tokenIndices, adjusted, -1)
|
||||||
|
}
|
||||||
|
|||||||
62
x/mlxrunner/sample/sample_test.go
Normal file
62
x/mlxrunner/sample/sample_test.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
|
||||||
|
// RepeatLastN = 1, PresencePenalty = 6
|
||||||
|
s := New(0, 0, 0, 0, 1, 6)
|
||||||
|
defer func() {
|
||||||
|
s.Free()
|
||||||
|
mlx.Sweep()
|
||||||
|
}()
|
||||||
|
|
||||||
|
s.ResetHistory([]int32{0})
|
||||||
|
s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1}))
|
||||||
|
|
||||||
|
logprobs := mlx.FromValues([]float32{0, 5, 4}, 3)
|
||||||
|
got := s.Sample(logprobs)
|
||||||
|
mlx.Eval(got)
|
||||||
|
|
||||||
|
// logprobs will be [0, -1, 4] after the penalty
|
||||||
|
// and then (index) 2 after the greedy sampler
|
||||||
|
gotInt := got.Int()
|
||||||
|
if gotInt != 2 {
|
||||||
|
t.Fatalf("got %d, want 2", gotInt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMinPMasksTokensBelowThreshold(t *testing.T) {
|
||||||
|
s := New(0, 0, 0.5, 0, 0, 0)
|
||||||
|
defer func() {
|
||||||
|
s.Free()
|
||||||
|
mlx.Sweep()
|
||||||
|
}()
|
||||||
|
|
||||||
|
logprobs := mlx.FromValues([]float32{
|
||||||
|
float32(math.Log(0.5)),
|
||||||
|
float32(math.Log(0.3)),
|
||||||
|
float32(math.Log(0.2)),
|
||||||
|
}, 3)
|
||||||
|
got := minP(s, logprobs)
|
||||||
|
mlx.Eval(got)
|
||||||
|
|
||||||
|
gotFloats := got.Floats()
|
||||||
|
if len(gotFloats) != 3 {
|
||||||
|
t.Fatalf("got %d scores, want 3", len(gotFloats))
|
||||||
|
}
|
||||||
|
|
||||||
|
if math.IsInf(float64(gotFloats[0]), -1) || math.IsInf(float64(gotFloats[1]), -1) {
|
||||||
|
t.Fatalf("kept tokens were masked: %v", gotFloats)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !math.IsInf(float64(gotFloats[2]), -1) {
|
||||||
|
t.Fatalf("lowest-probability token should be masked, got %v", gotFloats)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -96,6 +96,8 @@ func Execute(args []string) error {
|
|||||||
request.Options.TopP,
|
request.Options.TopP,
|
||||||
request.Options.MinP,
|
request.Options.MinP,
|
||||||
request.Options.TopK,
|
request.Options.TopK,
|
||||||
|
request.Options.RepeatLastN,
|
||||||
|
request.Options.PresencePenalty,
|
||||||
)
|
)
|
||||||
|
|
||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
|
|||||||
Reference in New Issue
Block a user