compile expert select

This commit is contained in:
Michael Yang
2026-02-06 15:53:36 -08:00
parent 26b08f889e
commit 9319f13ff5

View File

@@ -194,21 +194,39 @@ type Gate struct {
CorrectionBias mlx.Array `weight:"gate.e_score_correction_bias"`
}
func (m Gate) Forward(h *mlx.Array, opts Options) (indices, scores *mlx.Array) {
scores = m.Gate.Forward(h).AsType(mlx.DTypeFloat32).Sigmoid()
original := scores
scores = scores.Add(&m.CorrectionBias)
var expertSelect *mlx.Closure
indices = scores.Negative().ArgpartitionAxis(opts.NumExpertsPerTok-1, -1)
indices = indices.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.NumExpertsPerTok))
func ExpertSelect(opts Options) *mlx.Closure {
if expertSelect == nil {
expertSelect = mlx.Compile(func(inputs []*mlx.Array) []*mlx.Array {
scores, correctionBias := inputs[0], inputs[1]
scores = original.TakeAlongAxis(indices, -1)
if opts.NumExpertsPerTok > 1 && opts.NormTopKProb {
scores = scores.Divide(scores.SumAxis(-1, true).Add(mlx.FromValue[float32](1e-20)))
scores = scores.Sigmoid()
original := scores
scores = scores.Add(correctionBias)
indices := scores.Negative().ArgpartitionAxis(opts.NumExpertsPerTok-1, -1)
indices = indices.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.NumExpertsPerTok))
scores = original.TakeAlongAxis(indices, -1)
if opts.NumExpertsPerTok > 1 && opts.NormTopKProb {
scores = scores.Divide(scores.SumAxis(-1, true).Add(mlx.FromValue[float32](1e-20)))
}
scores = scores.Multiply(mlx.FromValue(opts.RoutedScalingFactor))
return []*mlx.Array{indices, scores}
}, false)
}
scores = scores.Multiply(mlx.FromValue(opts.RoutedScalingFactor))
return indices, scores
return expertSelect
}
// Forward feeds h through m.Gate, casts the result to float32, and hands it
// together with the gate's correction bias to the closure returned by
// ExpertSelect(opts). The closure's first output is returned as indices and
// its second as scores.
func (m Gate) Forward(h *mlx.Array, opts Options) (indices, scores *mlx.Array) {
	// Obtain the closure before building its inputs, matching the original
	// left-to-right evaluation order.
	selectFn := ExpertSelect(opts)

	gateOut := m.Gate.Forward(h).AsType(mlx.DTypeFloat32)
	selected := selectFn.Call([]*mlx.Array{gateOut, &m.CorrectionBias})

	indices, scores = selected[0], selected[1]
	return indices, scores
}
type sparse struct {