diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go index c4f8c77ce..f1a0e4cca 100644 --- a/x/mlxrunner/client.go +++ b/x/mlxrunner/client.go @@ -186,11 +186,13 @@ type completionRequest struct { } type completionOpts struct { - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - MinP float32 `json:"min_p,omitempty"` - TopK int `json:"top_k,omitempty"` - NumPredict int `json:"num_predict,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + MinP float32 `json:"min_p,omitempty"` + TopK int `json:"top_k,omitempty"` + RepeatLastN int `json:"repeat_last_n,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + NumPredict int `json:"num_predict,omitempty"` } type CompletionResponse struct { @@ -232,11 +234,13 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f } if req.Options != nil { creq.Options = &completionOpts{ - Temperature: req.Options.Temperature, - TopP: req.Options.TopP, - MinP: req.Options.MinP, - TopK: req.Options.TopK, - NumPredict: req.Options.NumPredict, + Temperature: req.Options.Temperature, + TopP: req.Options.TopP, + MinP: req.Options.MinP, + TopK: req.Options.TopK, + RepeatLastN: req.Options.RepeatLastN, + PresencePenalty: req.Options.PresencePenalty, + NumPredict: req.Options.NumPredict, } } diff --git a/x/mlxrunner/mlx/ops.go b/x/mlxrunner/mlx/ops.go index 011a42319..2f97ba8d2 100644 --- a/x/mlxrunner/mlx/ops.go +++ b/x/mlxrunner/mlx/ops.go @@ -87,6 +87,12 @@ func (t *Array) Concatenate(axis int, others ...*Array) *Array { return out } +func (t *Array) Cumsum(axis int, reverse, inclusive bool) *Array { + out := New("CUMSUM") + C.mlx_cumsum(&out.ctx, t.ctx, C.int(axis), C.bool(reverse), C.bool(inclusive), DefaultStream().ctx) + return out +} + func (t *Array) Divide(other *Array) *Array { out := New("DIVIDE") C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) @@ -129,6 +135,12 @@ func (t *Array) Logsumexp(keepDims bool) *Array { return out } +func (t *Array) Less(other *Array) *Array { + out := New("LESS") + C.mlx_less(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) + return out +} + func (t *Array) Matmul(other *Array) *Array { out := New("MATMUL") C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index dbf1e182d..852b04dcc 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -42,6 +42,9 @@ func (r *Runner) TextGenerationPipeline(request Request) error { ) defer func() { + if request.Sampler != nil { + request.Sampler.Free() + } mlx.Unpin(sample, logprobs) mlx.Unpin(nextSample, nextLogprobs) mlx.Sweep() @@ -74,6 +77,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error { request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate) } + request.Sampler.ResetHistory(inputs) + session := r.cache.begin(r.Model, inputs) defer session.close() caches := session.caches @@ -113,7 +118,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error { logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1) logprobs := logits.Subtract(logits.Logsumexp(true)) - sample := request.Sample(logprobs) + sample := request.Sampler.Sample(logprobs) mlx.Pin(sample, logprobs) mlx.Sweep() @@ -132,6 +137,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error { return err } + request.Sampler.AppendToken(sample) nextSample, nextLogprobs = step(sample) if i == 0 { diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index 5fe06bcd5..acaef79bf 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -27,17 +27,19 @@ type Request struct { Ctx context.Context - sample.Sampler + Sampler *sample.Sampler } type TextCompletionsRequest struct { Prompt string `json:"prompt"` Options struct { - Temperature float32 `json:"temperature"` - TopP float32 `json:"top_p"` - MinP float32 `json:"min_p"` - TopK int `json:"top_k"` - MaxTokens int `json:"max_tokens"` + Temperature float32 `json:"temperature"` + TopP float32 `json:"top_p"` + MinP float32 `json:"min_p"` + TopK int `json:"top_k"` + RepeatLastN int `json:"repeat_last_n"` + PresencePenalty float32 `json:"presence_penalty"` + MaxTokens int `json:"max_tokens"` // Deprecated: use MaxTokens instead NumPredict int `json:"num_predict"` diff --git a/x/mlxrunner/sample/sample.go b/x/mlxrunner/sample/sample.go index b0656973f..a25b23d03 100644 --- a/x/mlxrunner/sample/sample.go +++ b/x/mlxrunner/sample/sample.go @@ -8,70 +8,184 @@ import ( "github.com/ollama/ollama/x/mlxrunner/mlx" ) -type Sampler interface { - Sample(*mlx.Array) *mlx.Array +type Transform func(*Sampler, *mlx.Array) *mlx.Array + +type Sampler struct { + Temperature float32 + TopP float32 + MinP float32 + TopK int + RepeatLastN int + PresencePenalty float32 + + history *mlx.Array + historyLen int + transforms []Transform } -func New(temp, top_p, min_p float32, top_k int) Sampler { - if temp == 0 { - return greedy{} +func New(temp, top_p, min_p float32, top_k, repeatLastN int, presencePenalty float32) *Sampler { + s := &Sampler{ + Temperature: temp, + TopP: top_p, + MinP: min_p, + TopK: top_k, + RepeatLastN: repeatLastN, + PresencePenalty: presencePenalty, + } + + var transforms []Transform + if presencePenalty != 0 { + transforms = append(transforms, penalty) } - var samplers []Sampler if top_p > 0 && top_p < 1 { - samplers = append(samplers, TopP(top_p)) + transforms = append(transforms, topP) } if min_p != 0 { - samplers = append(samplers, MinP(min_p)) + transforms = append(transforms, minP) } if top_k > 0 { - samplers = append(samplers, TopK(top_k)) + transforms = append(transforms, topK) } - samplers = append(samplers, Temperature(temp)) - return chain(samplers) + if temp == 0 { + transforms = append(transforms, greedy) + } else { + transforms = append(transforms, temperature) + } + + s.transforms = transforms + return s } -type greedy struct{} - -func (greedy) Sample(logits *mlx.Array) *mlx.Array { - return logits.Argmax(-1, false) +func (s *Sampler) usesHistory() bool { + return s.PresencePenalty != 0 } -type chain []Sampler +func (s *Sampler) setHistory(history *mlx.Array, historyLen int) { + if history != nil { + mlx.Pin(history) + } + if s.history != nil { + mlx.Unpin(s.history) + } + s.history = history + s.historyLen = historyLen +} -func (c chain) Sample(logits *mlx.Array) *mlx.Array { - for _, sampler := range c { - logits = sampler.Sample(logits) +func (s *Sampler) ResetHistory(history []int32) { + if !s.usesHistory() { + return + } + if s.RepeatLastN > 0 && len(history) > s.RepeatLastN { + history = history[len(history)-s.RepeatLastN:] + } + if len(history) == 0 { + s.setHistory(nil, 0) + return + } + + tokens := append([]int32(nil), history...) + s.setHistory(mlx.NewArrayInt32(tokens, []int32{int32(len(tokens))}), len(tokens)) +} + +func (s *Sampler) AppendToken(token *mlx.Array) { + if !s.usesHistory() || token == nil { + return + } + + next := token.AsType(mlx.DTypeInt32) + nextLen := next.Size() + + if s.history != nil && s.historyLen > 0 { + next = s.history.Concatenate(0, next) + nextLen += s.historyLen + } + + if s.RepeatLastN > 0 && nextLen > s.RepeatLastN { + trim := nextLen - s.RepeatLastN + next = next.Slice(mlx.Slice(trim, nextLen)) + nextLen = s.RepeatLastN + } + + s.setHistory(next, nextLen) +} + +func (s *Sampler) Free() { + s.setHistory(nil, 0) +} + +func (s *Sampler) Sample(logits *mlx.Array) *mlx.Array { + for _, transform := range s.transforms { + logits = transform(s, logits) } return logits } -type Temperature float32 - -func (t Temperature) Sample(logits *mlx.Array) *mlx.Array { - return mlx.DivScalar(logits, float32(t)).Categorical(-1) +func greedy(_ *Sampler, logits *mlx.Array) *mlx.Array { + return logits.Argmax(-1, false) } -type TopP float32 - -func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array { - // TODO: implement - return logprobs +func temperature(s *Sampler, logits *mlx.Array) *mlx.Array { + return mlx.DivScalar(logits, s.Temperature).Categorical(-1) } -type MinP float32 +func topP(s *Sampler, logprobs *mlx.Array) *mlx.Array { + if s.TopP <= 0 || s.TopP >= 1 { + return logprobs + } -func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array { - // TODO: implement - return logprobs + order := logprobs.Negative().ArgsortAxis(-1) + sortedLogprobs := logprobs.TakeAlongAxis(order, -1) + sortedProbs := mlx.SoftmaxAxis(sortedLogprobs, -1, true) + prevCumProbs := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs) + keep := prevCumProbs.Less(mlx.FromValue(s.TopP)) + filtered := mlx.Where(keep, sortedLogprobs, mlx.FromValue(float32(math.Inf(-1)))) + return logprobs.PutAlongAxis(order, filtered, -1) } -type TopK int +func minP(s *Sampler, logprobs *mlx.Array) *mlx.Array { + if s.MinP <= 0 || s.MinP > 1 { + return logprobs + } -func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array { - mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0)) + maxLogprobs := logprobs.TakeAlongAxis(logprobs.Argmax(-1, true), -1) + minLogprobs := mlx.AddScalar(maxLogprobs, float32(math.Log(float64(s.MinP)))) + + return mlx.Where( + logprobs.Less(minLogprobs), + mlx.FromValue(float32(math.Inf(-1))), + logprobs, + ) +} + +func topK(s *Sampler, logprobs *mlx.Array) *mlx.Array { + if s.TopK <= 0 { + return logprobs + } + + vocab := logprobs.Dim(logprobs.NumDims() - 1) + if s.TopK >= vocab { + return logprobs + } + + mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, 0)) return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1) } + +func penalty(s *Sampler, logprobs *mlx.Array) *mlx.Array { + if s.history == nil || s.historyLen == 0 || s.PresencePenalty == 0 { + return logprobs + } + + tokenIndices := s.history + if logprobs.NumDims() > 1 { + tokenIndices = tokenIndices.ExpandDims(0) + } + + selected := logprobs.TakeAlongAxis(tokenIndices, -1) + adjusted := mlx.AddScalar(selected, -s.PresencePenalty) + return logprobs.PutAlongAxis(tokenIndices, adjusted, -1) +} diff --git a/x/mlxrunner/sample/sample_test.go b/x/mlxrunner/sample/sample_test.go new file mode 100644 index 000000000..de9fe2276 --- /dev/null +++ b/x/mlxrunner/sample/sample_test.go @@ -0,0 +1,62 @@ +//go:build mlx + +package sample + +import ( + "math" + "testing" + + "github.com/ollama/ollama/x/mlxrunner/mlx" +) + +func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) { + // RepeatLastN = 1, PresencePenalty = 6 + s := New(0, 0, 0, 0, 1, 6) + defer func() { + s.Free() + mlx.Sweep() + }() + + s.ResetHistory([]int32{0}) + s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1})) + + logprobs := mlx.FromValues([]float32{0, 5, 4}, 3) + got := s.Sample(logprobs) + mlx.Eval(got) + + // logprobs will be [0, -1, 4] after the penalty + // and then (index) 2 after the greedy sampler + gotInt := got.Int() + if gotInt != 2 { + t.Fatalf("got %d, want 2", gotInt) + } +} + +func TestMinPMasksTokensBelowThreshold(t *testing.T) { + s := New(0, 0, 0.5, 0, 0, 0) + defer func() { + s.Free() + mlx.Sweep() + }() + + logprobs := mlx.FromValues([]float32{ + float32(math.Log(0.5)), + float32(math.Log(0.3)), + float32(math.Log(0.2)), + }, 3) + got := minP(s, logprobs) + mlx.Eval(got) + + gotFloats := got.Floats() + if len(gotFloats) != 3 { + t.Fatalf("got %d scores, want 3", len(gotFloats)) + } + + if math.IsInf(float64(gotFloats[0]), -1) || math.IsInf(float64(gotFloats[1]), -1) { + t.Fatalf("kept tokens were masked: %v", gotFloats) + } + + if !math.IsInf(float64(gotFloats[2]), -1) { + t.Fatalf("lowest-probability token should be masked, got %v", gotFloats) + } +} diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go index 436b47e59..9c7d7e775 100644 --- a/x/mlxrunner/server.go +++ b/x/mlxrunner/server.go @@ -96,6 +96,8 @@ func Execute(args []string) error { request.Options.TopP, request.Options.MinP, request.Options.TopK, + request.Options.RepeatLastN, + request.Options.PresencePenalty, ) var cancel context.CancelFunc