From 9319f13ff521d5fcb45f9157f953b60e15b00c00 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Fri, 6 Feb 2026 15:53:36 -0800
Subject: [PATCH] compile expert select

---
 x/mlxrunner/model/glm/4/moe/lite/model.go | 77 +++++++++++++++++++-----
 1 file changed, 62 insertions(+), 15 deletions(-)

diff --git a/x/mlxrunner/model/glm/4/moe/lite/model.go b/x/mlxrunner/model/glm/4/moe/lite/model.go
index e211c16ea..a4a138345 100644
--- a/x/mlxrunner/model/glm/4/moe/lite/model.go
+++ b/x/mlxrunner/model/glm/4/moe/lite/model.go
@@ -194,21 +194,68 @@ type Gate struct {
 	CorrectionBias mlx.Array `weight:"gate.e_score_correction_bias"`
 }
 
-func (m Gate) Forward(h *mlx.Array, opts Options) (indices, scores *mlx.Array) {
-	scores = m.Gate.Forward(h).AsType(mlx.DTypeFloat32).Sigmoid()
-	original := scores
-	scores = scores.Add(&m.CorrectionBias)
-
-	indices = scores.Negative().ArgpartitionAxis(opts.NumExpertsPerTok-1, -1)
-	indices = indices.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.NumExpertsPerTok))
-
-	scores = original.TakeAlongAxis(indices, -1)
-	if opts.NumExpertsPerTok > 1 && opts.NormTopKProb {
-		scores = scores.Divide(scores.SumAxis(-1, true).Add(mlx.FromValue[float32](1e-20)))
-	}
-
-	scores = scores.Multiply(mlx.FromValue(opts.RoutedScalingFactor))
-	return indices, scores
+// expertSelectEntry pairs a compiled expert-selection closure with the
+// options that were captured when it was compiled.
+type expertSelectEntry struct {
+	opts    Options
+	closure *mlx.Closure
+}
+
+// expertSelectCache holds one compiled closure per distinct routing
+// configuration. It is not synchronized.
+var expertSelectCache []expertSelectEntry
+
+// ExpertSelect returns a compiled closure that takes [gate scores,
+// correction bias] and returns [expert indices, expert scores].
+//
+// The closure captures opts.NumExpertsPerTok, opts.NormTopKProb and
+// opts.RoutedScalingFactor, so it is cached per combination of those values;
+// a single process-wide closure would silently keep using the options of
+// whichever caller happened to compile it first.
+func ExpertSelect(opts Options) *mlx.Closure {
+	for i := range expertSelectCache {
+		cached := &expertSelectCache[i]
+		if cached.opts.NumExpertsPerTok == opts.NumExpertsPerTok &&
+			cached.opts.NormTopKProb == opts.NormTopKProb &&
+			cached.opts.RoutedScalingFactor == opts.RoutedScalingFactor {
+			return cached.closure
+		}
+	}
+
+	closure := mlx.Compile(func(inputs []*mlx.Array) []*mlx.Array {
+		scores, correctionBias := inputs[0], inputs[1]
+
+		scores = scores.Sigmoid()
+		original := scores
+		scores = scores.Add(correctionBias)
+
+		indices := scores.Negative().ArgpartitionAxis(opts.NumExpertsPerTok-1, -1)
+		indices = indices.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.NumExpertsPerTok))
+
+		// Scores are gathered from the pre-bias values; the bias only
+		// influences which experts are selected.
+		scores = original.TakeAlongAxis(indices, -1)
+		if opts.NumExpertsPerTok > 1 && opts.NormTopKProb {
+			scores = scores.Divide(scores.SumAxis(-1, true).Add(mlx.FromValue[float32](1e-20)))
+		}
+
+		scores = scores.Multiply(mlx.FromValue(opts.RoutedScalingFactor))
+		return []*mlx.Array{indices, scores}
+	}, false)
+
+	expertSelectCache = append(expertSelectCache, expertSelectEntry{opts: opts, closure: closure})
+	return closure
+}
+
+// Forward applies the gate projection to h, casts the result to float32 and
+// passes it, with the correction bias, to the compiled expert-selection
+// closure. It returns the closure's two outputs: indices and scores.
+func (m Gate) Forward(h *mlx.Array, opts Options) (indices, scores *mlx.Array) {
+	outputs := ExpertSelect(opts).Call([]*mlx.Array{
+		m.Gate.Forward(h).AsType(mlx.DTypeFloat32),
+		&m.CorrectionBias,
+	})
+	return outputs[0], outputs[1]
 }
 
 type sparse struct {