From 041fb7763945d11e0484cb5c838b5fea67741c94 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Sun, 15 Feb 2026 22:47:59 -0800 Subject: [PATCH] model: add gemma3 to the mlxrunner (#14276) This change adds the gemma3 model to the mlxrunner and simplifies some of the quantization code for loading weights. --- x/mlxrunner/imports.go | 1 + x/mlxrunner/model/quant.go | 130 ++++++ x/mlxrunner/model/root.go | 213 ++++++++-- x/mlxrunner/pipeline.go | 10 +- x/models/gemma3/gemma3.go | 521 ++++++++++++++++++++++++ x/models/glm4_moe_lite/glm4_moe_lite.go | 115 +++--- 6 files changed, 895 insertions(+), 95 deletions(-) create mode 100644 x/mlxrunner/model/quant.go create mode 100644 x/models/gemma3/gemma3.go diff --git a/x/mlxrunner/imports.go b/x/mlxrunner/imports.go index e8950eff8..1daad3e4d 100644 --- a/x/mlxrunner/imports.go +++ b/x/mlxrunner/imports.go @@ -3,5 +3,6 @@ package mlxrunner import ( + _ "github.com/ollama/ollama/x/models/gemma3" _ "github.com/ollama/ollama/x/models/glm4_moe_lite" ) diff --git a/x/mlxrunner/model/quant.go b/x/mlxrunner/model/quant.go new file mode 100644 index 000000000..3a17ab485 --- /dev/null +++ b/x/mlxrunner/model/quant.go @@ -0,0 +1,130 @@ +//go:build mlx + +package model + +import ( + "strings" + + "github.com/ollama/ollama/x/mlxrunner/mlx" +) + +// QuantizationParams returns default groupSize, bits, and mode for a quantization type. +func QuantizationParams(quantization string) (groupSize, bits int, mode string) { + switch strings.ToUpper(quantization) { + case "NVFP4": + return 16, 4, "nvfp4" + case "FP4", "Q4", "INT4": + return 32, 4, "affine" + case "MXFP8": + return 32, 8, "mxfp8" + case "FP8", "Q8", "INT8", "": + return 64, 8, "affine" + default: + return 32, 8, "affine" + } +} + +// TensorQuantParams resolves quant params for a tensor using per-tensor metadata +// when available, otherwise falling back to the provided model defaults. 
+func TensorQuantParams( + defaultGroupSize, defaultBits int, + defaultMode string, + tensorQuant map[string]*TensorQuantInfo, + tensorName string, +) (groupSize, bits int, mode string, fromTensor bool) { + if tensorQuant != nil { + if tq := tensorQuant[tensorName]; tq != nil { + groupSize, bits, mode = QuantizationParams(tq.QuantType) + if tq.GroupSize > 0 { + groupSize = tq.GroupSize + } + return groupSize, bits, mode, true + } + } + return defaultGroupSize, defaultBits, defaultMode, false +} + +// ResolveLinearQuantParams resolves quantization params for a quantized linear +// tensor, preferring per-tensor metadata and falling back to shape-based +// inference for affine packed tensors. +func ResolveLinearQuantParams( + defaultGroupSize, defaultBits int, + defaultMode string, + tensorQuant map[string]*TensorQuantInfo, + tensorName string, + weight, scales *mlx.Array, +) (groupSize, bits int, mode string) { + groupSize, bits, mode, fromTensor := TensorQuantParams( + defaultGroupSize, + defaultBits, + defaultMode, + tensorQuant, + tensorName, + ) + + if mode == "affine" { + if inferredGroupSize, inferredBits, ok := InferAffineQuantParamsFromShapes(weight, scales, bits); ok { + if !fromTensor || groupSize == 0 || bits == 0 { + groupSize = inferredGroupSize + bits = inferredBits + } + } + } + + return groupSize, bits, mode +} + +// InferAffineQuantParamsFromShapes infers (groupSize,bits) for affine quantized +// tensors from packed weight and scale shapes. 
+func InferAffineQuantParamsFromShapes(weight, scales *mlx.Array, hintBits int) (groupSize, bits int, ok bool) { + if weight == nil || scales == nil { + return 0, 0, false + } + + weightShape := weight.Dims() + scaleShape := scales.Dims() + if len(weightShape) == 0 || len(scaleShape) == 0 { + return 0, 0, false + } + + weightCols := weightShape[len(weightShape)-1] + scalesCols := scaleShape[len(scaleShape)-1] + if weightCols <= 0 || scalesCols <= 0 { + return 0, 0, false + } + + groupSize4 := weightCols * 8 / scalesCols + groupSize8 := weightCols * 4 / scalesCols + + switch { + case groupSize4 == 32: + return 32, 4, true + case groupSize8 == 64: + return 64, 8, true + case groupSize4 == 64 && groupSize8 == 32: + if hintBits == 8 { + return 32, 8, true + } + if hintBits == 4 { + return 64, 4, true + } + } + + if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) { + return groupSize4, 4, true + } + if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) { + return groupSize8, 8, true + } + + return 0, 0, false +} + +func isCommonGroupSize(v int) bool { + switch v { + case 16, 32, 64, 128: + return true + default: + return false + } +} diff --git a/x/mlxrunner/model/root.go b/x/mlxrunner/model/root.go index 885647ab3..c912f7f4c 100644 --- a/x/mlxrunner/model/root.go +++ b/x/mlxrunner/model/root.go @@ -8,42 +8,63 @@ import ( "fmt" "io" "os" + "sort" + "strconv" "strings" "github.com/ollama/ollama/x/imagegen/manifest" ) -// Root wraps a ModelManifest with pre-scanned quantization metadata. -type Root struct { - Manifest *manifest.ModelManifest - quantType string - groupSize int +// TensorQuantInfo describes per-tensor quantization metadata. +type TensorQuantInfo struct { + QuantType string + GroupSize int } -// Open loads a manifest for the given model name and pre-scans the first -// tensor blob for quantization metadata (quant_type, group_size). +// Root wraps a ModelManifest with pre-scanned quantization metadata. 
+type Root struct { + Manifest *manifest.ModelManifest + + // Backwards-compatible model-level quant metadata (first tensor blob). + quantType string + groupSize int + + // Per-tensor quantization metadata. + tensorQuant map[string]*TensorQuantInfo +} + +// Open loads a manifest for the given model name and scans tensor blobs for +// quantization metadata. func Open(modelName string) (*Root, error) { m, err := manifest.LoadManifest(modelName) if err != nil { return nil, err } - root := &Root{Manifest: m} + root := &Root{ + Manifest: m, + tensorQuant: make(map[string]*TensorQuantInfo), + } - // Pre-scan first tensor blob for quantization metadata for _, layer := range m.GetTensorLayers("") { blobPath := m.BlobPath(layer.Digest) - meta, err := readBlobMetadata(blobPath) - if err != nil || meta == nil { + + infos, blobQuantType, blobGroupSize, err := readBlobTensorQuantInfo(blobPath) + if err != nil { continue } - if qt := meta["quant_type"]; qt != "" { - root.quantType = strings.ToUpper(qt) + + for name, info := range infos { + root.tensorQuant[name] = info } - if gs := meta["group_size"]; gs != "" { - fmt.Sscanf(gs, "%d", &root.groupSize) + + if root.quantType == "" && blobQuantType != "" { + root.quantType = strings.ToUpper(blobQuantType) + root.groupSize = blobGroupSize + if root.groupSize == 0 { + root.groupSize = defaultGroupSize(root.quantType) + } } - break // only check the first tensor blob } return root, nil @@ -52,46 +73,180 @@ func Open(modelName string) (*Root, error) { // Close is a no-op for now (future: release resources). func (r *Root) Close() {} -// QuantType returns the quantization type detected from tensor metadata. +// QuantType returns the quantization type detected from the first tensor blob metadata. func (r *Root) QuantType() string { return r.quantType } -// GroupSize returns the quantization group size detected from tensor metadata. +// GroupSize returns the quantization group size detected from the first tensor blob metadata. 
func (r *Root) GroupSize() int { return r.groupSize } -// readBlobMetadata reads the __metadata__ from a safetensors blob header. -func readBlobMetadata(path string) (map[string]string, error) { +// TensorQuant returns per-tensor quantization metadata if available. +func (r *Root) TensorQuant(name string) *TensorQuantInfo { + if r == nil { + return nil + } + return r.tensorQuant[name] +} + +// AllTensorQuant returns a copy of the per-tensor quantization metadata. +func (r *Root) AllTensorQuant() map[string]*TensorQuantInfo { + out := make(map[string]*TensorQuantInfo, len(r.tensorQuant)) + for k, v := range r.tensorQuant { + if v == nil { + continue + } + copy := *v + out[k] = &copy + } + return out +} + +func defaultGroupSize(quantType string) int { + groupSize, _, _ := QuantizationParams(quantType) + return groupSize +} + +func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string, int, error) { f, err := os.Open(path) if err != nil { - return nil, err + return nil, "", 0, err } defer f.Close() var headerSize uint64 if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil { - return nil, err + return nil, "", 0, err } - if headerSize > 1024*1024 { - return nil, fmt.Errorf("header too large: %d", headerSize) + if headerSize > 100*1024*1024 { + return nil, "", 0, fmt.Errorf("header too large: %d", headerSize) } data := make([]byte, headerSize) if _, err := io.ReadFull(f, data); err != nil { - return nil, err + return nil, "", 0, err } var header map[string]json.RawMessage if err := json.Unmarshal(data, &header); err != nil { - return nil, err + return nil, "", 0, err } + globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header) + globalQuantType = strings.ToUpper(globalQuantType) + + mainNames := mainTensorNames(header) + infos := make(map[string]*TensorQuantInfo) + for _, name := range mainNames { + if _, ok := header[name+".scale"]; !ok { + continue + } + + quantType := globalQuantType + groupSize := globalGroupSize + + 
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType) + if quantType == "" { + quantType = inferredType + } + if groupSize == 0 { + groupSize = inferredGroup + } + if quantType == "" { + continue + } + if groupSize == 0 { + groupSize = defaultGroupSize(quantType) + } + + infos[name] = &TensorQuantInfo{QuantType: quantType, GroupSize: groupSize} + } + + return infos, globalQuantType, globalGroupSize, nil +} + +func parseGlobalQuantMetadata(header map[string]json.RawMessage) (quantType string, groupSize int) { metaRaw, ok := header["__metadata__"] if !ok { - return nil, nil + return "", 0 } var meta map[string]string if err := json.Unmarshal(metaRaw, &meta); err != nil { - return nil, err + return "", 0 } - return meta, nil + + quantType = meta["quant_type"] + if gs := meta["group_size"]; gs != "" { + groupSize, _ = strconv.Atoi(gs) + } + return quantType, groupSize +} + +func mainTensorNames(header map[string]json.RawMessage) []string { + names := make([]string, 0, len(header)) + for name := range header { + if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") { + continue + } + names = append(names, name) + } + sort.Strings(names) + return names +} + +func inferQuantTypeFromShapes(header map[string]json.RawMessage, tensorName string, hintQuantType string) (string, int) { + type tensorShape struct { + Shape []int64 `json:"shape"` + } + + mainRaw, ok := header[tensorName] + if !ok { + return "", 0 + } + scaleRaw, ok := header[tensorName+".scale"] + if !ok { + return "", 0 + } + + var mainInfo tensorShape + if err := json.Unmarshal(mainRaw, &mainInfo); err != nil || len(mainInfo.Shape) == 0 { + return "", 0 + } + + var scaleInfo tensorShape + if err := json.Unmarshal(scaleRaw, &scaleInfo); err != nil || len(scaleInfo.Shape) == 0 { + return "", 0 + } + + weightCols := int(mainInfo.Shape[len(mainInfo.Shape)-1]) + scalesCols := int(scaleInfo.Shape[len(scaleInfo.Shape)-1]) + if weightCols <= 0 || 
scalesCols <= 0 { + return "", 0 + } + + groupSize4 := weightCols * 8 / scalesCols + groupSize8 := weightCols * 4 / scalesCols + + switch { + case groupSize4 == 32: + return "INT4", 32 + case groupSize8 == 64: + return "INT8", 64 + case groupSize4 == 64 && groupSize8 == 32: + h := strings.ToUpper(hintQuantType) + if strings.Contains(h, "8") { + return "INT8", 32 + } + if strings.Contains(h, "4") { + return "INT4", 64 + } + } + + if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) { + return "INT4", groupSize4 + } + if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) { + return "INT8", groupSize8 + } + + return "", 0 } diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index b7650b68d..cd4d78620 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -24,9 +24,13 @@ func (r *Runner) TextGenerationPipeline(request Request) error { caches, tokens := r.FindNearestCache(inputs) if len(caches) == 0 { - caches = make([]cache.Cache, r.Model.NumLayers()) - for i := range caches { - caches[i] = cache.NewKVCache() + if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok { + caches = cacheFactory.NewCaches() + } else { + caches = make([]cache.Cache, r.Model.NumLayers()) + for i := range caches { + caches[i] = cache.NewKVCache() + } } } diff --git a/x/models/gemma3/gemma3.go b/x/models/gemma3/gemma3.go new file mode 100644 index 000000000..f35ef2a75 --- /dev/null +++ b/x/models/gemma3/gemma3.go @@ -0,0 +1,521 @@ +//go:build mlx + +// Package gemma3 provides the Gemma 3 text model implementation for MLX. 
+package gemma3 + +import ( + "encoding/json" + "fmt" + "math" + + "github.com/ollama/ollama/x/imagegen/tokenizer" + "github.com/ollama/ollama/x/mlxrunner/cache" + "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/mlxrunner/model" + "github.com/ollama/ollama/x/mlxrunner/model/base" + "github.com/ollama/ollama/x/models/nn" +) + +func init() { + base.Register("Gemma3ForCausalLM", newModel) + base.Register("Gemma3ForConditionalGeneration", newModel) +} + +// TextConfig holds configuration for the Gemma 3 text model. +type TextConfig struct { + HiddenSize int32 `json:"hidden_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + IntermediateSize int32 `json:"intermediate_size"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + HeadDim int32 `json:"head_dim"` + VocabSize int32 `json:"vocab_size"` + RMSNormEps float32 `json:"rms_norm_eps"` + RopeTheta float32 `json:"rope_theta"` + RopeLocalBaseFreq float32 `json:"rope_local_base_freq"` + MaxPositionEmbeddings int32 `json:"max_position_embeddings"` + SlidingWindow int32 `json:"sliding_window"` + SlidingWindowPattern int32 `json:"sliding_window_pattern"` + LayerTypes []string `json:"layer_types"` + TieWordEmbeddings bool `json:"tie_word_embeddings"` + + // Quantization parameters (set during load based on model quantization). + QuantGroupSize int `json:"-"` + QuantBits int `json:"-"` + QuantMode string `json:"-"` + TensorQuant map[string]*model.TensorQuantInfo `json:"-"` + + // Computed fields. + Scale float32 `json:"-"` +} + +// Attention implements Gemma 3 attention with Q/K normalization. +type Attention struct { + QProj nn.LinearLayer + KProj nn.LinearLayer + VProj nn.LinearLayer + OProj nn.LinearLayer + + QNorm *nn.RMSNorm + KNorm *nn.RMSNorm + + // Precomputed (1 + weight) for Gemma-style RMSNorm. + QNormScaled *mlx.Array + KNormScaled *mlx.Array +} + +// MLP is the feed-forward network with GELU activation. 
+type MLP struct { + GateProj nn.LinearLayer + UpProj nn.LinearLayer + DownProj nn.LinearLayer +} + +// DecoderLayer is a single transformer block. +type DecoderLayer struct { + InputNorm *nn.RMSNorm + Attention *Attention + PostAttnNorm *nn.RMSNorm + PreFFNorm *nn.RMSNorm + MLP *MLP + PostFFNorm *nn.RMSNorm + + // Precomputed (1 + weight) for Gemma-style RMSNorm. + InputNormScaled *mlx.Array + PostAttnNormScaled *mlx.Array + PreFFNormScaled *mlx.Array + PostFFNormScaled *mlx.Array + + // Layer metadata. + IsSliding bool + LayerIdx int32 +} + +// Model is the Gemma 3 text-only model. +type Model struct { + EmbedTokens *nn.Embedding + Layers []*DecoderLayer + Norm *nn.RMSNorm + LMHead nn.LinearLayer + + // Precomputed (1 + weight) for Gemma-style RMSNorm. + NormScaled *mlx.Array + + tok *tokenizer.Tokenizer + *TextConfig + + weightPrefix string +} + +func defaultHeads(numLayers int32) (numHeads, numKVHeads int32) { + switch numLayers { + case 34: + return 8, 4 + case 48: + return 16, 8 + case 62: + return 32, 16 + default: + return 8, 4 + } +} + +func parseTextConfig(configData []byte) (TextConfig, bool, error) { + var cfg TextConfig + if err := json.Unmarshal(configData, &cfg); err != nil { + return TextConfig{}, false, fmt.Errorf("parse config: %w", err) + } + + var wrapped struct { + TextConfig *TextConfig `json:"text_config"` + } + if err := json.Unmarshal(configData, &wrapped); err != nil { + return TextConfig{}, false, fmt.Errorf("parse nested text config: %w", err) + } + + fromConditional := wrapped.TextConfig != nil + if fromConditional { + cfg = *wrapped.TextConfig + + if cfg.HeadDim == 0 { + cfg.HeadDim = 256 + } + if cfg.NumAttentionHeads == 0 { + cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers) + } + if cfg.NumKeyValueHeads == 0 { + _, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers) + } + if cfg.VocabSize == 0 { + cfg.VocabSize = 262208 + } + if cfg.SlidingWindowPattern == 0 && len(cfg.LayerTypes) == 0 { + 
cfg.SlidingWindowPattern = 6 + } + if cfg.MaxPositionEmbeddings == 0 { + cfg.MaxPositionEmbeddings = 131072 + } + } + + if cfg.HeadDim == 0 { + cfg.HeadDim = 256 + } + if cfg.NumAttentionHeads == 0 { + cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers) + } + if cfg.NumKeyValueHeads == 0 { + cfg.NumKeyValueHeads = max(1, cfg.NumAttentionHeads/2) + } + if cfg.RopeTheta == 0 { + cfg.RopeTheta = 1000000 + } + if cfg.RopeLocalBaseFreq == 0 { + cfg.RopeLocalBaseFreq = 10000 + } + if cfg.RMSNormEps == 0 { + cfg.RMSNormEps = 1e-6 + } + if cfg.VocabSize == 0 { + cfg.VocabSize = 262208 + } + + cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) + + return cfg, fromConditional, nil +} + +func resolveWeightPrefix(tensors map[string]*mlx.Array) string { + for _, prefix := range []string{"", "language_model."} { + if tensors[prefix+"model.embed_tokens.weight"] != nil { + return prefix + } + } + return "" +} + +func isLayerSliding(layerIdx int32, cfg *TextConfig) bool { + if len(cfg.LayerTypes) > 0 && int(layerIdx) < len(cfg.LayerTypes) { + return cfg.LayerTypes[layerIdx] == "sliding_attention" + } + if cfg.SlidingWindowPattern <= 0 { + return false + } + return (layerIdx+1)%cfg.SlidingWindowPattern != 0 +} + +func precomputeGemmaScaledWeights(m *Model) { + if m.Norm != nil { + m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0) + } + + var scaled []*mlx.Array + if m.NormScaled != nil { + scaled = append(scaled, m.NormScaled) + } + + for _, layer := range m.Layers { + if layer == nil || layer.Attention == nil { + continue + } + + if layer.InputNorm != nil { + layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0) + scaled = append(scaled, layer.InputNormScaled) + } + if layer.PostAttnNorm != nil { + layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0) + scaled = append(scaled, layer.PostAttnNormScaled) + } + if layer.PreFFNorm != nil { + layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0) + scaled = 
append(scaled, layer.PreFFNormScaled) + } + if layer.PostFFNorm != nil { + layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0) + scaled = append(scaled, layer.PostFFNormScaled) + } + + if layer.Attention.QNorm != nil { + layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0) + scaled = append(scaled, layer.Attention.QNormScaled) + } + if layer.Attention.KNorm != nil { + layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0) + scaled = append(scaled, layer.Attention.KNormScaled) + } + } + + if len(scaled) > 0 { + mlx.Eval(scaled...) + } +} + +func newModel(root *model.Root) (base.Model, error) { + configData, err := root.Manifest.ReadConfig("config.json") + if err != nil { + return nil, fmt.Errorf("load config: %w", err) + } + + cfg, _, err := parseTextConfig(configData) + if err != nil { + return nil, err + } + + if qt := root.QuantType(); qt != "" { + cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt) + if gs := root.GroupSize(); gs > 0 { + cfg.QuantGroupSize = gs + } + } else { + cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("") + } + cfg.TensorQuant = root.AllTensorQuant() + + tokData, err := root.Manifest.ReadConfig("tokenizer.json") + if err != nil { + return nil, fmt.Errorf("load tokenizer config: %w", err) + } + + tokConfig := &tokenizer.TokenizerConfig{ConfigJSON: configData} + if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil { + tokConfig.GenerationConfigJSON = genConfigData + } + if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil { + tokConfig.TokenizerConfigJSON = tokConfigData + } + + tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig) + if err != nil { + return nil, fmt.Errorf("parse tokenizer: %w", err) + } + + m := &Model{ + Layers: make([]*DecoderLayer, cfg.NumHiddenLayers), + TextConfig: &cfg, + tok: tok, + } + + for i := range m.Layers { 
+ m.Layers[i] = &DecoderLayer{ + LayerIdx: int32(i), + IsSliding: isLayerSliding(int32(i), m.TextConfig), + } + } + + return m, nil +} + +// LoadWeights receives all tensors loaded from the manifest and assigns them +// to model fields. +func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { + m.weightPrefix = resolveWeightPrefix(tensors) + prefix := m.weightPrefix + linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant) + + embedWeight := tensors[prefix+"model.embed_tokens.weight"] + if embedWeight == nil { + return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix) + } + m.EmbedTokens = nn.NewEmbedding(embedWeight) + + normWeight := tensors[prefix+"model.norm.weight"] + if normWeight == nil { + return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix) + } + m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps) + + if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil { + m.LMHead = lmHead + } else if lmHead := linears.Make("lm_head"); lmHead != nil { + m.LMHead = lmHead + } else { + // Gemma usually ties output projection to embeddings. 
+ m.LMHead = nn.NewLinear(embedWeight, nil) + } + + for i := int32(0); i < m.NumHiddenLayers; i++ { + layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i) + + layer := &DecoderLayer{ + LayerIdx: i, + IsSliding: isLayerSliding(i, m.TextConfig), + Attention: &Attention{}, + MLP: &MLP{}, + } + + if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil { + layer.InputNorm = nn.NewRMSNorm(w, m.RMSNormEps) + } + if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil { + layer.PostAttnNorm = nn.NewRMSNorm(w, m.RMSNormEps) + } + if w := tensors[layerPrefix+".pre_feedforward_layernorm.weight"]; w != nil { + layer.PreFFNorm = nn.NewRMSNorm(w, m.RMSNormEps) + } + if w := tensors[layerPrefix+".post_feedforward_layernorm.weight"]; w != nil { + layer.PostFFNorm = nn.NewRMSNorm(w, m.RMSNormEps) + } + + layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj") + layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj") + layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj") + layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj") + + if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil { + layer.Attention.QNorm = nn.NewRMSNorm(w, m.RMSNormEps) + } + if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil { + layer.Attention.KNorm = nn.NewRMSNorm(w, m.RMSNormEps) + } + + layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj") + layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj") + layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj") + + if layer.InputNorm == nil { + return fmt.Errorf("layer %d: missing input_layernorm", i) + } + if layer.PostAttnNorm == nil { + return fmt.Errorf("layer %d: missing post_attention_layernorm", i) + } + if layer.PreFFNorm == nil { + return fmt.Errorf("layer %d: missing pre_feedforward_layernorm", i) + } + if layer.PostFFNorm == nil { + return fmt.Errorf("layer %d: missing post_feedforward_layernorm", i) 
+ } + if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil { + return fmt.Errorf("layer %d: missing attention projections", i) + } + if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil { + return fmt.Errorf("layer %d: missing attention q/k norms", i) + } + if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil { + return fmt.Errorf("layer %d: missing mlp projections", i) + } + + m.Layers[i] = layer + } + + precomputeGemmaScaledWeights(m) + if m.NormScaled == nil { + return fmt.Errorf("missing precomputed final norm weight") + } + collected := mlx.Collect(m) + mlx.Eval(collected...) + + return nil +} + +func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { + dims := tokens.Dims() + B, L := int32(dims[0]), int32(dims[1]) + + h := m.EmbedTokens.Forward(tokens) + h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize)))) + + for i, layer := range m.Layers { + var c cache.Cache + if caches != nil && i < len(caches) { + c = caches[i] + } + h = layer.Forward(h, c, B, L, m.TextConfig) + } + + return mlx.RMSNormFn(h, m.NormScaled, m.RMSNormEps) +} + +func (m *Model) Unembed(x *mlx.Array) *mlx.Array { + return m.LMHead.Forward(x) +} + +func (m *Model) NumLayers() int { + return len(m.Layers) +} + +func (m *Model) Tokenizer() *tokenizer.Tokenizer { + return m.tok +} + +// NewCaches creates cache objects for all layers. +func (m *Model) NewCaches() []cache.Cache { + caches := make([]cache.Cache, len(m.Layers)) + for i, layer := range m.Layers { + if m.SlidingWindow > 0 && layer.IsSliding { + caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow)) + } else { + caches[i] = cache.NewKVCache() + } + } + return caches +} + +// FormatPrompt applies the Gemma 3 chat template. 
+func (m *Model) FormatPrompt(prompt string) string { + return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt) +} + +func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array { + normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps) + + attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg) + attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps) + h := mlx.Add(x, attnOut) + + normed = mlx.RMSNormFn(h, l.PreFFNormScaled, cfg.RMSNormEps) + + mlpOut := l.MLP.Forward(normed) + mlpOut = mlx.RMSNormFn(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps) + + return mlx.Add(h, mlpOut) +} + +func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array { + q := a.QProj.Forward(x) + k := a.KProj.Forward(x) + v := a.VProj.Forward(x) + + q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim) + q = mlx.Transpose(q, 0, 2, 1, 3) + + k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim) + k = mlx.Transpose(k, 0, 2, 1, 3) + + v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim) + v = mlx.Transpose(v, 0, 2, 1, 3) + + q = mlx.RMSNormFn(q, a.QNormScaled, cfg.RMSNormEps) + k = mlx.RMSNormFn(k, a.KNormScaled, cfg.RMSNormEps) + + ropeTheta := cfg.RopeTheta + if isSliding { + ropeTheta = cfg.RopeLocalBaseFreq + } + + offset := 0 + if c != nil { + offset = c.Offset() + } + q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, offset) + k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, offset) + + if c != nil { + k, v = c.Update(k, v) + } + + repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads + if repeatFactor > 1 { + k = nn.RepeatKV(k, repeatFactor) + v = nn.RepeatKV(v, repeatFactor) + } + + out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1) + out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) + return a.OProj.Forward(out) +} + +func (m *MLP) Forward(x *mlx.Array) 
*mlx.Array { + gate := mlx.GELUApprox(m.GateProj.Forward(x)) + up := m.UpProj.Forward(x) + return m.DownProj.Forward(mlx.Mul(gate, up)) +} diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go index 974213196..65e26244d 100644 --- a/x/models/glm4_moe_lite/glm4_moe_lite.go +++ b/x/models/glm4_moe_lite/glm4_moe_lite.go @@ -8,7 +8,6 @@ import ( "encoding/json" "fmt" "math" - "strings" "github.com/ollama/ollama/x/imagegen/tokenizer" "github.com/ollama/ollama/x/mlxrunner/cache" @@ -64,9 +63,10 @@ type Config struct { RopeScaling *RopeScaling `json:"rope_scaling"` // Quantization parameters (set during load based on model quantization) - QuantGroupSize int `json:"-"` // Group size for quantization (default 64) - QuantBits int `json:"-"` // Bits per weight (4 or 8) - QuantMode string `json:"-"` // Quantization mode ("affine", etc.) + QuantGroupSize int `json:"-"` // Group size for quantization (default 64) + QuantBits int `json:"-"` // Bits per weight (4 or 8) + QuantMode string `json:"-"` // Quantization mode ("affine", etc.) + TensorQuant map[string]*model.TensorQuantInfo `json:"-"` // Computed fields QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim @@ -372,22 +372,6 @@ func supportsGatherQMM(mode string, bits int) bool { return mode == "affine" && (bits == 4 || bits == 8) } -// quantizationParams returns groupSize, bits, mode for a quantization type string. -func quantizationParams(quantization string) (groupSize, bits int, mode string) { - switch strings.ToUpper(quantization) { - case "NVFP4": - return 16, 4, "nvfp4" - case "FP4", "Q4", "INT4": - return 32, 4, "affine" - case "MXFP8": - return 32, 8, "mxfp8" - case "FP8", "Q8", "INT8", "": - return 64, 8, "affine" - default: - return 32, 8, "affine" - } -} - // ExpertWeight holds a single expert's weight with optional quantization components. 
type ExpertWeight struct { Weight *mlx.Array @@ -408,7 +392,15 @@ func loadExpertWeight(tensors map[string]*mlx.Array, path string, useQuantized b if scales != nil { qbiases := tensors[path+".weight_qbias"] - groupSize, bits, mode := cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode + groupSize, bits, mode := model.ResolveLinearQuantParams( + cfg.QuantGroupSize, + cfg.QuantBits, + cfg.QuantMode, + cfg.TensorQuant, + path+".weight", + w, + scales, + ) if useQuantized && supportsGatherQMM(mode, bits) { return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize} @@ -492,7 +484,16 @@ func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Confi // Check if quantized and dequantize if scales := tensors[path+".weight_scale"]; scales != nil { qbiases := tensors[path+".weight_qbias"] - w = mlx.Dequantize(w, scales, qbiases, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode) + groupSize, bits, mode := model.ResolveLinearQuantParams( + cfg.QuantGroupSize, + cfg.QuantBits, + cfg.QuantMode, + cfg.TensorQuant, + path+".weight", + w, + scales, + ) + w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode) } headDim := cfg.QKNopeHeadDim + cfg.VHeadDim @@ -507,32 +508,6 @@ func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Confi return embedQ, unembedOut } -// makeLinear creates a Linear or QuantizedLinear layer from the tensor map. 
-func makeLinear(tensors map[string]*mlx.Array, path string, cfg *Config) nn.LinearLayer { - w := tensors[path+".weight"] - if w == nil { - return nil - } - - scales := tensors[path+".weight_scale"] - if scales != nil { - qbiases := tensors[path+".weight_qbias"] - bias := tensors[path+".bias"] - return &nn.QuantizedLinear{ - Weight: w, - Scales: scales, - QBiases: qbiases, - Bias: bias, - GroupSize: cfg.QuantGroupSize, - Bits: cfg.QuantBits, - Mode: cfg.QuantMode, - } - } - - bias := tensors[path+".bias"] - return nn.NewLinear(w, bias) -} - // newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer, // no weights loaded yet). Called by the registry via base.New(). func newModel(root *model.Root) (base.Model, error) { @@ -551,13 +526,14 @@ func newModel(root *model.Root) (base.Model, error) { // Set up quantization parameters from pre-scanned metadata if qt := root.QuantType(); qt != "" { - _, cfg.QuantBits, cfg.QuantMode = quantizationParams(qt) + cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt) if gs := root.GroupSize(); gs > 0 { cfg.QuantGroupSize = gs - } else { - cfg.QuantGroupSize, _, _ = quantizationParams(qt) } + } else { + cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("") } + cfg.TensorQuant = root.AllTensorQuant() // Load tokenizer tokData, err := root.Manifest.ReadConfig("tokenizer.json") @@ -596,7 +572,20 @@ func newModel(root *model.Root) (base.Model, error) { // layer creation. 
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { cfg := m.Config + linears := model.NewLinearFactory(tensors, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant) useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits) + if !useQuantized && cfg.TensorQuant != nil { + for _, tq := range cfg.TensorQuant { + if tq == nil { + continue + } + _, bits, mode := model.QuantizationParams(tq.QuantType) + if supportsGatherQMM(mode, bits) { + useQuantized = true + break + } + } + } // Load embedding if w := tensors["model.embed_tokens.weight"]; w != nil { @@ -609,7 +598,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { } // Load LM head - m.LMHead = makeLinear(tensors, "lm_head", cfg) + m.LMHead = linears.Make("lm_head") // Load layers for i := int32(0); i < cfg.NumHiddenLayers; i++ { @@ -617,16 +606,16 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { // Load attention (same for both block types) attn := &MLAAttention{} - attn.QAProj = makeLinear(tensors, prefix+".self_attn.q_a_proj", cfg) + attn.QAProj = linears.Make(prefix + ".self_attn.q_a_proj") if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil { attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps) } - attn.QBProj = makeLinear(tensors, prefix+".self_attn.q_b_proj", cfg) - attn.KVAProjWithMQA = makeLinear(tensors, prefix+".self_attn.kv_a_proj_with_mqa", cfg) + attn.QBProj = linears.Make(prefix + ".self_attn.q_b_proj") + attn.KVAProjWithMQA = linears.Make(prefix + ".self_attn.kv_a_proj_with_mqa") if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil { attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps) } - attn.OProj = makeLinear(tensors, prefix+".self_attn.o_proj", cfg) + attn.OProj = linears.Make(prefix + ".self_attn.o_proj") // Sanitize MLA weights for absorbed attention embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg) @@ -647,9 +636,9 @@ func (m *Model) LoadWeights(tensors 
map[string]*mlx.Array) error { } block.MLP = &DenseMLP{ - GateProj: makeLinear(tensors, prefix+".mlp.gate_proj", cfg), - UpProj: makeLinear(tensors, prefix+".mlp.up_proj", cfg), - DownProj: makeLinear(tensors, prefix+".mlp.down_proj", cfg), + GateProj: linears.Make(prefix + ".mlp.gate_proj"), + UpProj: linears.Make(prefix + ".mlp.up_proj"), + DownProj: linears.Make(prefix + ".mlp.down_proj"), } m.Layers[i] = block @@ -690,7 +679,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { } moeGate := &MoEGate{} - moeGate.Gate = makeLinear(tensors, prefix+".mlp.gate", cfg) + moeGate.Gate = linears.Make(prefix + ".mlp.gate") if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil { moeGate.EScoreCorrectionBias = bias } @@ -703,9 +692,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { // Load shared experts if present if cfg.NSharedExperts > 0 { block.MoE.SharedExperts = &SharedExperts{ - GateProj: makeLinear(tensors, prefix+".mlp.shared_experts.gate_proj", cfg), - UpProj: makeLinear(tensors, prefix+".mlp.shared_experts.up_proj", cfg), - DownProj: makeLinear(tensors, prefix+".mlp.shared_experts.down_proj", cfg), + GateProj: linears.Make(prefix + ".mlp.shared_experts.gate_proj"), + UpProj: linears.Make(prefix + ".mlp.shared_experts.up_proj"), + DownProj: linears.Make(prefix + ".mlp.shared_experts.down_proj"), } }