mirror of
https://github.com/ollama/ollama.git
synced 2026-03-09 03:12:11 -05:00
compile expert select
This commit is contained in:
@@ -194,21 +194,39 @@ type Gate struct {
|
||||
CorrectionBias mlx.Array `weight:"gate.e_score_correction_bias"`
|
||||
}
|
||||
|
||||
func (m Gate) Forward(h *mlx.Array, opts Options) (indices, scores *mlx.Array) {
|
||||
scores = m.Gate.Forward(h).AsType(mlx.DTypeFloat32).Sigmoid()
|
||||
original := scores
|
||||
scores = scores.Add(&m.CorrectionBias)
|
||||
var expertSelect *mlx.Closure
|
||||
|
||||
indices = scores.Negative().ArgpartitionAxis(opts.NumExpertsPerTok-1, -1)
|
||||
indices = indices.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.NumExpertsPerTok))
|
||||
func ExpertSelect(opts Options) *mlx.Closure {
|
||||
if expertSelect == nil {
|
||||
expertSelect = mlx.Compile(func(inputs []*mlx.Array) []*mlx.Array {
|
||||
scores, correctionBias := inputs[0], inputs[1]
|
||||
|
||||
scores = original.TakeAlongAxis(indices, -1)
|
||||
if opts.NumExpertsPerTok > 1 && opts.NormTopKProb {
|
||||
scores = scores.Divide(scores.SumAxis(-1, true).Add(mlx.FromValue[float32](1e-20)))
|
||||
scores = scores.Sigmoid()
|
||||
original := scores
|
||||
scores = scores.Add(correctionBias)
|
||||
|
||||
indices := scores.Negative().ArgpartitionAxis(opts.NumExpertsPerTok-1, -1)
|
||||
indices = indices.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.NumExpertsPerTok))
|
||||
|
||||
scores = original.TakeAlongAxis(indices, -1)
|
||||
if opts.NumExpertsPerTok > 1 && opts.NormTopKProb {
|
||||
scores = scores.Divide(scores.SumAxis(-1, true).Add(mlx.FromValue[float32](1e-20)))
|
||||
}
|
||||
|
||||
scores = scores.Multiply(mlx.FromValue(opts.RoutedScalingFactor))
|
||||
return []*mlx.Array{indices, scores}
|
||||
}, false)
|
||||
}
|
||||
|
||||
scores = scores.Multiply(mlx.FromValue(opts.RoutedScalingFactor))
|
||||
return indices, scores
|
||||
return expertSelect
|
||||
}
|
||||
|
||||
func (m Gate) Forward(h *mlx.Array, opts Options) (indices, scores *mlx.Array) {
|
||||
outputs := ExpertSelect(opts).Call([]*mlx.Array{
|
||||
m.Gate.Forward(h).AsType(mlx.DTypeFloat32),
|
||||
&m.CorrectionBias,
|
||||
})
|
||||
return outputs[0], outputs[1]
|
||||
}
|
||||
|
||||
type sparse struct {
|
||||
|
||||
Reference in New Issue
Block a user