From 0ade9205cce88006245dc54ea8884607822b103b Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Sun, 22 Feb 2026 15:09:14 -0800 Subject: [PATCH] models: add nemotronh architecture support (#14356) --- convert/convert.go | 7 +- convert/convert_nemotron_h.go | 385 +++++++++ convert/convert_nemotron_h_test.go | 230 ++++++ convert/json_compat.go | 97 +++ convert/json_compat_test.go | 46 ++ fs/ggml/ggml.go | 23 + kvcache/recurrent.go | 752 ++++++++++++++++++ kvcache/recurrent_checkpoints.go | 561 +++++++++++++ kvcache/recurrent_checkpoints_test.go | 288 +++++++ ...-mul_mat_id-map0-and-add-ne20-22-spe.patch | 37 + ml/backend.go | 1 + ml/backend/ggml/ggml.go | 7 + .../src/ggml-metal/ggml-metal-embed.metal | 1 + .../ggml/src/ggml-metal/ggml-metal-ops.cpp | 3 +- .../ggml/ggml/src/ggml-metal/ggml-metal.metal | 1 + model/model_test.go | 1 + model/models/models.go | 1 + model/models/nemotronh/attention.go | 88 ++ model/models/nemotronh/cache.go | 55 ++ model/models/nemotronh/mamba2.go | 197 +++++ model/models/nemotronh/model.go | 417 ++++++++++ server/sched.go | 2 +- 22 files changed, 3196 insertions(+), 4 deletions(-) create mode 100644 convert/convert_nemotron_h.go create mode 100644 convert/convert_nemotron_h_test.go create mode 100644 convert/json_compat.go create mode 100644 convert/json_compat_test.go create mode 100644 kvcache/recurrent.go create mode 100644 kvcache/recurrent_checkpoints.go create mode 100644 kvcache/recurrent_checkpoints_test.go create mode 100644 llama/patches/0034-ggml-metal-guard-mul_mat_id-map0-and-add-ne20-22-spe.patch create mode 100644 model/models/nemotronh/attention.go create mode 100644 model/models/nemotronh/cache.go create mode 100644 model/models/nemotronh/mamba2.go create mode 100644 model/models/nemotronh/model.go diff --git a/convert/convert.go b/convert/convert.go index 1f318be90..abb0bc336 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -257,10 +257,11 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) { if err != nil { return nil, nil, err } + bts = sanitizeNonFiniteJSON(bts) var p ModelParameters if err := json.Unmarshal(bts, &p); err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("parse config.json: %w", err) } if len(p.Architectures) < 1 { @@ -319,12 +320,14 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) { conv = &lfm2Model{} case "Qwen3NextForCausalLM": conv = &qwen3NextModel{} + case "NemotronHForCausalLM": + conv = &nemotronHModel{} default: return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0]) } if err := json.Unmarshal(bts, conv); err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("parse config.json for %q: %w", p.Architectures[0], err) } if t, ok := conv.(moreParser); ok { diff --git a/convert/convert_nemotron_h.go b/convert/convert_nemotron_h.go new file mode 100644 index 000000000..59ea461f1 --- /dev/null +++ b/convert/convert_nemotron_h.go @@ -0,0 +1,385 @@ +package convert + +import ( + "cmp" + "encoding/json" + "fmt" + "io/fs" + "math" + "slices" + "strings" + + "github.com/ollama/ollama/fs/ggml" +) + +type hybridPattern string + +func (p *hybridPattern) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + *p = "" + return nil + } + + var single string + if err := json.Unmarshal(data, &single); err == nil { + *p = hybridPattern(strings.TrimSpace(single)) + return nil + } + + var parts []string + if err := json.Unmarshal(data, &parts); err == nil { + *p = hybridPattern(strings.Join(parts, "")) + return nil + } + + return 
fmt.Errorf("hybrid_override_pattern must be a string or string array") +} + +type nemotronHModel struct { + ModelParameters + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + HiddenSize uint32 `json:"hidden_size"` + NumHiddenLayers uint32 `json:"num_hidden_layers"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + HeadDim uint32 `json:"head_dim"` + LayerNormEpsilon float32 `json:"layer_norm_epsilon"` + NormEpsilon float32 `json:"norm_eps"` + RopeTheta float32 `json:"rope_theta"` + PartialRotaryFactor float32 `json:"partial_rotary_factor"` + ConvKernel uint32 `json:"conv_kernel"` + SSMStateSize uint32 `json:"ssm_state_size"` + MambaNumHeads uint32 `json:"mamba_num_heads"` + MambaHeadDim uint32 `json:"mamba_head_dim"` + NGroups uint32 `json:"n_groups"` + IntermediateSize uint32 `json:"intermediate_size"` + HybridOverridePattern hybridPattern `json:"hybrid_override_pattern"` + + // MoE + NumExperts uint32 `json:"num_experts"` + NumSharedExperts uint32 `json:"num_shared_experts"` + NRoutedExperts uint32 `json:"n_routed_experts"` + NSharedExperts uint32 `json:"n_shared_experts"` + NumExpertsPerTok uint32 `json:"num_experts_per_tok"` + MoEIntermediateSize uint32 `json:"moe_intermediate_size"` + MoESharedExpertIntermediate uint32 `json:"moe_shared_expert_intermediate_size"` + NormTopKProb bool `json:"norm_topk_prob"` + RoutedScalingFactor float32 `json:"routed_scaling_factor"` + ExpertGroupCount uint32 `json:"n_group"` + ExpertGroupUsedCount uint32 `json:"topk_group"` +} + +var _ ModelConverter = (*nemotronHModel)(nil) + +func (n *nemotronHModel) parseMore(_ fs.FS) error { + if n.NumHiddenLayers == 0 { + return fmt.Errorf("nemotron_h: num_hidden_layers must be set") + } + if n.HiddenSize == 0 { + return fmt.Errorf("nemotron_h: hidden_size must be set") + } + if n.NumAttentionHeads == 0 { + return fmt.Errorf("nemotron_h: num_attention_heads must be set") + } + if n.HeadDim == 0 { + if n.HiddenSize%n.NumAttentionHeads != 0 { + return fmt.Errorf("nemotron_h: hidden_size (%d) must be divisible by num_attention_heads (%d)", n.HiddenSize, n.NumAttentionHeads) + } + n.HeadDim = n.HiddenSize / n.NumAttentionHeads + } + if n.NumKeyValueHeads == 0 { + n.NumKeyValueHeads = n.NumAttentionHeads + } + if n.ConvKernel == 0 { + return fmt.Errorf("nemotron_h: conv_kernel must be set") + } + if n.SSMStateSize == 0 { + return fmt.Errorf("nemotron_h: ssm_state_size must be set") + } + if n.ssmHeadCount() == 0 { + return fmt.Errorf("nemotron_h: mamba_num_heads must be set") + } + if n.MambaHeadDim == 0 { + return fmt.Errorf("nemotron_h: mamba_head_dim must be set") + } + if n.NGroups == 0 { + n.NGroups = 1 + } + + if _, _, err := n.layerArrays(); err != nil { + return err + } + + if n.isMoE() { + if n.routedExpertCount() == 0 { + return fmt.Errorf("nemotron_h: routed expert count must be set for MoE models") + } + if n.NumExpertsPerTok == 0 { + return fmt.Errorf("nemotron_h: num_experts_per_tok must be set for MoE models") + } + if n.NumExpertsPerTok > n.routedExpertCount() { + return fmt.Errorf("nemotron_h: num_experts_per_tok (%d) cannot exceed expert_count (%d)", n.NumExpertsPerTok, n.routedExpertCount()) + } + if n.moeIntermediateSize() == 0 { + return fmt.Errorf("nemotron_h: moe_intermediate_size must be set for MoE models") + } + } + + return nil +} + +func (n *nemotronHModel) isMoE() bool { + return cmp.Or(n.routedExpertCount(), n.NumExpertsPerTok, n.MoEIntermediateSize) > 0 +} + +func (n *nemotronHModel) routedExpertCount() uint32 { 
+ return cmp.Or(n.NRoutedExperts, n.NumExperts) +} + +func (n *nemotronHModel) sharedExpertCount() uint32 { + return cmp.Or(n.NSharedExperts, n.NumSharedExperts) +} + +func (n *nemotronHModel) ssmHeadCount() uint32 { + return n.MambaNumHeads +} + +func (n *nemotronHModel) ssmInnerSize() uint32 { + return n.MambaHeadDim * n.ssmHeadCount() +} + +func (n *nemotronHModel) epsilon() float32 { + return cmp.Or(n.NormEpsilon, n.LayerNormEpsilon, float32(1e-5)) +} + +func (n *nemotronHModel) moeIntermediateSize() uint32 { + return cmp.Or(n.MoEIntermediateSize, n.IntermediateSize) +} + +func (n *nemotronHModel) denseIntermediateSize() uint32 { + return cmp.Or(n.IntermediateSize, n.MoEIntermediateSize) +} + +func (n *nemotronHModel) layerArrays() (headCountKV []uint32, ffnLengths []uint32, err error) { + pattern := strings.TrimSpace(string(n.HybridOverridePattern)) + if pattern == "" { + return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern must be set") + } + + runes := []rune(pattern) + if len(runes) != int(n.NumHiddenLayers) { + return nil, nil, fmt.Errorf("nemotron_h: hybrid_override_pattern length (%d) must match num_hidden_layers (%d)", len(runes), n.NumHiddenLayers) + } + + headCountKV = make([]uint32, n.NumHiddenLayers) + ffnLengths = make([]uint32, n.NumHiddenLayers) + + attnKVHeads := cmp.Or(n.NumKeyValueHeads, n.NumAttentionHeads) + moeFFN := n.moeIntermediateSize() + denseFFN := n.denseIntermediateSize() + + for i, layerType := range runes { + switch layerType { + case 'M': + // Recurrent layer: no KV heads and no FFN. + case '*', 'A': + // Attention-only layer. + headCountKV[i] = attnKVHeads + case 'E': + // MoE layer. + if moeFFN == 0 { + return nil, nil, fmt.Errorf("nemotron_h: moe layer at index %d but moe_intermediate_size is zero", i) + } + ffnLengths[i] = moeFFN + case '-': + // Dense FFN layer. 
+ if denseFFN == 0 { + return nil, nil, fmt.Errorf("nemotron_h: dense FFN layer at index %d but intermediate_size is zero", i) + } + ffnLengths[i] = denseFFN + default: + return nil, nil, fmt.Errorf("nemotron_h: unsupported layer type %q in hybrid_override_pattern at index %d", layerType, i) + } + } + + return headCountKV, ffnLengths, nil +} + +func (n *nemotronHModel) KV(t *Tokenizer) KV { + kv := n.ModelParameters.KV(t) + + arch := "nemotron_h" + if n.isMoE() { + arch = "nemotron_h_moe" + } + kv["general.architecture"] = arch + kv["block_count"] = n.NumHiddenLayers + kv["context_length"] = n.MaxPositionEmbeddings + kv["embedding_length"] = n.HiddenSize + kv["attention.head_count"] = n.NumAttentionHeads + kv["attention.key_length"] = n.HeadDim + kv["attention.value_length"] = n.HeadDim + kv["attention.layer_norm_epsilon"] = n.epsilon() + kv["attention.layer_norm_rms_epsilon"] = n.epsilon() + kv["rope.freq_base"] = cmp.Or(n.RopeTheta, float32(10000)) + if n.PartialRotaryFactor > 0 && n.PartialRotaryFactor <= 1 { + kv["rope.dimension_count"] = uint32(float32(n.HeadDim) * n.PartialRotaryFactor) + } + + if headCountKV, ffnLengths, err := n.layerArrays(); err == nil { + kv["attention.head_count_kv"] = headCountKV + kv["feed_forward_length"] = ffnLengths + } + + kv["ssm.conv_kernel"] = n.ConvKernel + kv["ssm.inner_size"] = n.ssmInnerSize() + kv["ssm.state_size"] = n.SSMStateSize + kv["ssm.group_count"] = n.NGroups + kv["ssm.time_step_rank"] = n.ssmHeadCount() + + if n.isMoE() { + kv["expert_count"] = n.routedExpertCount() + kv["expert_used_count"] = n.NumExpertsPerTok + kv["expert_feed_forward_length"] = n.moeIntermediateSize() + if n.sharedExpertCount() > 0 { + kv["expert_shared_count"] = n.sharedExpertCount() + } + if n.MoESharedExpertIntermediate > 0 { + kv["expert_shared_feed_forward_length"] = n.MoESharedExpertIntermediate + } + kv["expert_weights_norm"] = n.NormTopKProb + kv["expert_weights_scale"] = n.RoutedScalingFactor + if n.ExpertGroupCount > 0 { + kv["expert_group_count"] = n.ExpertGroupCount + } + if n.ExpertGroupUsedCount > 0 { + kv["expert_group_used_count"] = n.ExpertGroupUsedCount + } + } + + return kv +} + +func normalizeVectorShapeToColumn(shape []uint64) []uint64 { + switch len(shape) { + case 1: + return []uint64{shape[0], 1} + case 2: + if shape[0] == 1 && shape[1] > 1 { + return []uint64{shape[1], 1} + } + if shape[1] == 1 && shape[0] > 1 { + return []uint64{shape[0], 1} + } + } + + return slices.Clone(shape) +} + +func (n *nemotronHModel) Tensors(ts []Tensor) []*ggml.Tensor { + var out []*ggml.Tensor + + remaining := ts + if n.isMoE() { + merges := make([]merge, 0, n.NumHiddenLayers*2) + for i := range n.NumHiddenLayers { + merges = append(merges, merge{ + fmt.Sprintf("blk.%d.mixer.experts.*.up_proj.weight", i), + fmt.Sprintf("blk.%d.ffn_up_exps.weight", i), + }, merge{ + fmt.Sprintf("blk.%d.mixer.experts.*.down_proj.weight", i), + fmt.Sprintf("blk.%d.ffn_down_exps.weight", i), + }) + } + + merged, rest := mergeTensors(ts, merges...) + out = append(out, merged...) 
+ remaining = rest + } + + nGroups := uint64(cmp.Or(n.NGroups, uint32(1))) + for _, t := range remaining { + name := t.Name() + shape := slices.Clone(t.Shape()) + + switch { + case strings.HasSuffix(name, ".ssm_a"): + shape = normalizeVectorShapeToColumn(shape) + t.SetRepacker(func(_ string, data []float32, _ []uint64) ([]float32, error) { + out := make([]float32, len(data)) + for i, v := range data { + out[i] = -float32(math.Exp(float64(v))) + } + return out, nil + }) + case strings.HasSuffix(name, ".ssm_d"): + shape = normalizeVectorShapeToColumn(shape) + case strings.HasSuffix(name, ".ssm_norm.weight"): + switch len(shape) { + case 1: + if nGroups > 0 && shape[0]%nGroups == 0 { + shape = []uint64{nGroups, shape[0] / nGroups} + } + case 2: + if shape[0] == 1 && nGroups > 0 && shape[1]%nGroups == 0 { + shape = []uint64{nGroups, shape[1] / nGroups} + } + } + case strings.HasSuffix(name, ".ssm_conv1d.weight"): + if len(shape) == 3 { + if shape[0] == 1 { + shape = []uint64{shape[1], shape[2]} + } else if shape[1] == 1 { + shape = []uint64{shape[0], shape[2]} + } + } + } + + out = append(out, &ggml.Tensor{ + Name: name, + Kind: t.Kind(), + Shape: shape, + WriterTo: t, + }) + } + + return out +} + +func (n *nemotronHModel) Replacements() []string { + return []string{ + // Embedding and output + "lm_head", "output", + "backbone.embeddings", "token_embd", + "backbone.norm_f", "output_norm", + "backbone.layers", "blk", + + // Recurrent (Mamba2) tensors + "mixer.in_proj", "ssm_in", + "mixer.out_proj", "ssm_out", + "mixer.dt_bias", "ssm_dt.bias", + "mixer.A_log", "ssm_a", + "mixer.D", "ssm_d", + "mixer.conv1d", "ssm_conv1d", + "mixer.norm.weight", "ssm_norm.weight", + + // Attention tensors + "mixer.q_proj", "attn_q", + "mixer.k_proj", "attn_k", + "mixer.v_proj", "attn_v", + "mixer.o_proj", "attn_output", + + // FFN / MoE tensors + "mixer.gate.e_score_correction_bias", "exp_probs_b.bias", + "mixer.gate", "ffn_gate_inp", + "mixer.fc1_latent_proj", "ffn_latent_in", + "mixer.fc2_latent_proj", "ffn_latent_out", + "mixer.shared_experts.up_proj", "ffn_up_shexp", + "mixer.shared_experts.down_proj", "ffn_down_shexp", + "mixer.up_proj", "ffn_up", + "mixer.down_proj", "ffn_down", + + // Per-layer pre-norm + ".norm.weight", ".attn_norm.weight", + } +} diff --git a/convert/convert_nemotron_h_test.go b/convert/convert_nemotron_h_test.go new file mode 100644 index 000000000..db6a675fc --- /dev/null +++ b/convert/convert_nemotron_h_test.go @@ -0,0 +1,230 @@ +package convert + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "io" + "os" + "path/filepath" + "slices" + "strings" + "testing" +) + +func TestHybridPatternUnmarshal(t *testing.T) { + t.Run("string", func(t *testing.T) { + var p hybridPattern + if err := json.Unmarshal([]byte(`"MEM*"`), &p); err != nil { + t.Fatal(err) + } + if got, want := string(p), "MEM*"; got != want { + t.Fatalf("unexpected pattern: got %q want %q", got, want) + } + }) + + t.Run("array", func(t *testing.T) { + var p hybridPattern + if err := json.Unmarshal([]byte(`["M","E","M","*"]`), &p); err != nil { + t.Fatal(err) + } + if got, want := string(p), "MEM*"; got != want { + t.Fatalf("unexpected pattern: got %q want %q", got, want) + } + }) +} + +func TestNemotronHLayerArrays(t *testing.T) { + m := &nemotronHModel{ + NumHiddenLayers: 5, + NumAttentionHeads: 32, + NumKeyValueHeads: 8, + HybridOverridePattern: "MEM*E", + NRoutedExperts: 128, + NumExpertsPerTok: 6, + MoEIntermediateSize: 1856, + } + + headsKV, ffn, err := m.layerArrays() + if err != nil { + t.Fatal(err) + } + + 
if got, want := headsKV, []uint32{0, 0, 0, 8, 0}; !slices.Equal(got, want) { + t.Fatalf("unexpected head_count_kv: got %v want %v", got, want) + } + if got, want := ffn, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) { + t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want) + } +} + +func TestNemotronHKV(t *testing.T) { + m := &nemotronHModel{ + MaxPositionEmbeddings: 1048576, + HiddenSize: 2688, + NumHiddenLayers: 5, + NumAttentionHeads: 32, + NumKeyValueHeads: 2, + HeadDim: 128, + LayerNormEpsilon: 1e-5, + RopeTheta: 10000, + PartialRotaryFactor: 0.5, + ConvKernel: 4, + SSMStateSize: 128, + MambaNumHeads: 64, + MambaHeadDim: 64, + NGroups: 8, + HybridOverridePattern: "MEM*E", + NRoutedExperts: 128, + NSharedExperts: 1, + NumExpertsPerTok: 6, + MoEIntermediateSize: 1856, + MoESharedExpertIntermediate: 3712, + NormTopKProb: true, + RoutedScalingFactor: 2.5, + } + if err := m.parseMore(nil); err != nil { + t.Fatal(err) + } + + kv := m.KV(&Tokenizer{Vocabulary: &Vocabulary{}}) + if got, want := kv["general.architecture"], "nemotron_h_moe"; got != want { + t.Fatalf("unexpected architecture: got %v want %v", got, want) + } + + headCountKV, ok := kv["attention.head_count_kv"].([]uint32) + if !ok { + t.Fatalf("attention.head_count_kv has unexpected type: %T", kv["attention.head_count_kv"]) + } + if got, want := headCountKV, []uint32{0, 0, 0, 2, 0}; !slices.Equal(got, want) { + t.Fatalf("unexpected attention.head_count_kv: got %v want %v", got, want) + } + + ffnLength, ok := kv["feed_forward_length"].([]uint32) + if !ok { + t.Fatalf("feed_forward_length has unexpected type: %T", kv["feed_forward_length"]) + } + if got, want := ffnLength, []uint32{0, 1856, 0, 0, 1856}; !slices.Equal(got, want) { + t.Fatalf("unexpected feed_forward_length: got %v want %v", got, want) + } +} + +func TestNemotronHTensorsTransforms(t *testing.T) { + m := &nemotronHModel{NGroups: 8} + in := []Tensor{ + &fakeTensor{ + name: "blk.0.ssm_a", + shape: []uint64{4}, + data: []float32{0, 1, 2, 3}, + }, + &fakeTensor{ + name: "blk.0.ssm_d", + shape: []uint64{4}, + data: []float32{0, 1, 2, 3}, + }, + &fakeTensor{ + name: "blk.0.ssm_norm.weight", + shape: []uint64{16}, + data: make([]float32, 16), + }, + &fakeTensor{ + name: "blk.0.ssm_conv1d.weight", + shape: []uint64{10, 1, 4}, + data: make([]float32, 40), + }, + } + + out := m.Tensors(in) + if len(out) != len(in) { + t.Fatalf("unexpected output tensor count: got %d want %d", len(out), len(in)) + } + + got := map[string]struct { + shape []uint64 + writer io.WriterTo + }{} + for _, t := range out { + got[t.Name] = struct { + shape []uint64 + writer io.WriterTo + }{shape: t.Shape, writer: t.WriterTo} + } + + if shape := got["blk.0.ssm_a"].shape; !slices.Equal(shape, []uint64{4, 1}) { + t.Fatalf("unexpected ssm_a shape: %v", shape) + } + if shape := got["blk.0.ssm_d"].shape; !slices.Equal(shape, []uint64{4, 1}) { + t.Fatalf("unexpected ssm_d shape: %v", shape) + } + if shape := got["blk.0.ssm_norm.weight"].shape; !slices.Equal(shape, []uint64{8, 2}) { + t.Fatalf("unexpected ssm_norm shape: %v", shape) + } + if shape := got["blk.0.ssm_conv1d.weight"].shape; !slices.Equal(shape, []uint64{10, 4}) { + t.Fatalf("unexpected ssm_conv1d shape: %v", shape) + } + + var b bytes.Buffer + if _, err := got["blk.0.ssm_a"].writer.WriteTo(&b); err != nil { + t.Fatal(err) + } + values := make([]float32, 4) + if err := binary.Read(&b, binary.LittleEndian, &values); err != nil { + t.Fatal(err) + } + // 0 -> -exp(0) == -1 + if values[0] != -1 { + t.Fatalf("unexpected 
transformed ssm_a[0]: got %v want -1", values[0]) + } +} + +func TestNemotronHLoadModelMetadata(t *testing.T) { + tempDir := t.TempDir() + + config := `{ + "architectures": ["NemotronHForCausalLM"], + "model_type": "nemotron_h", + "num_hidden_layers": 4, + "hidden_size": 512, + "max_position_embeddings": 32768, + "num_attention_heads": 8, + "num_key_value_heads": 2, + "head_dim": 64, + "layer_norm_epsilon": 1e-5, + "conv_kernel": 4, + "ssm_state_size": 128, + "mamba_num_heads": 16, + "mamba_head_dim": 32, + "n_groups": 8, + "hybrid_override_pattern": "ME*M", + "n_routed_experts": 16, + "num_experts_per_tok": 4, + "moe_intermediate_size": 256 + }` + + if err := os.WriteFile(filepath.Join(tempDir, "config.json"), []byte(config), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tempDir, "tokenizer.json"), []byte(`{}`), 0o644); err != nil { + t.Fatal(err) + } + + kv, _, err := LoadModelMetadata(os.DirFS(tempDir)) + if err != nil { + t.Fatal(err) + } + if _, ok := kv.(*nemotronHModel); !ok { + t.Fatalf("unexpected converter type: %T", kv) + } +} + +func TestNemotronHReplacementsLatentProjections(t *testing.T) { + m := &nemotronHModel{} + r := strings.NewReplacer(m.Replacements()...) + + if got, want := r.Replace("backbone.layers.1.mixer.fc1_latent_proj.weight"), "blk.1.ffn_latent_in.weight"; got != want { + t.Fatalf("unexpected fc1 replacement: got %q want %q", got, want) + } + if got, want := r.Replace("backbone.layers.1.mixer.fc2_latent_proj.weight"), "blk.1.ffn_latent_out.weight"; got != want { + t.Fatalf("unexpected fc2 replacement: got %q want %q", got, want) + } +} diff --git a/convert/json_compat.go b/convert/json_compat.go new file mode 100644 index 000000000..281c170b3 --- /dev/null +++ b/convert/json_compat.go @@ -0,0 +1,97 @@ +package convert + +// sanitizeNonFiniteJSON rewrites non-standard JSON numeric tokens that some +// HF configs emit (Infinity, -Infinity, NaN) into standard JSON numbers. +// +// This is intentionally conservative: +// - only runs outside quoted strings +// - only rewrites full tokens +// +// We map these values to 0 because encoding/json rejects non-finite values, +// and these fields are typically model-side metadata not consumed by the +// converter. 
+func sanitizeNonFiniteJSON(in []byte) []byte { + if len(in) == 0 { + return in + } + + out := make([]byte, 0, len(in)) + inString := false + escape := false + + for i := 0; i < len(in); { + c := in[i] + + if inString { + out = append(out, c) + if escape { + escape = false + } else if c == '\\' { + escape = true + } else if c == '"' { + inString = false + } + i++ + continue + } + + if c == '"' { + inString = true + out = append(out, c) + i++ + continue + } + + if hasToken(in, i, "-Infinity") { + out = append(out, '0') + i += len("-Infinity") + continue + } + + if hasToken(in, i, "Infinity") { + out = append(out, '0') + i += len("Infinity") + continue + } + + if hasToken(in, i, "NaN") { + out = append(out, '0') + i += len("NaN") + continue + } + + out = append(out, c) + i++ + } + + return out +} + +func hasToken(in []byte, at int, tok string) bool { + end := at + len(tok) + if at < 0 || end > len(in) { + return false + } + if string(in[at:end]) != tok { + return false + } + if at > 0 && !isJSONValuePrefixBoundary(in[at-1]) { + return false + } + if end < len(in) && !isJSONValueSuffixBoundary(in[end]) { + return false + } + return true +} + +func isJSONWhitespace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' +} + +func isJSONValuePrefixBoundary(b byte) bool { + return isJSONWhitespace(b) || b == ':' || b == ',' || b == '[' +} + +func isJSONValueSuffixBoundary(b byte) bool { + return isJSONWhitespace(b) || b == ',' || b == ']' || b == '}' +} diff --git a/convert/json_compat_test.go b/convert/json_compat_test.go new file mode 100644 index 000000000..05f1432b9 --- /dev/null +++ b/convert/json_compat_test.go @@ -0,0 +1,46 @@ +package convert + +import "testing" + +func TestSanitizeNonFiniteJSON(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + { + name: "infinity token", + in: `{"a":[0,Infinity,1]}`, + want: `{"a":[0,0,1]}`, + }, + { + name: "negative infinity token", + in: `{"a":-Infinity}`, + want: `{"a":0}`, + }, + { + name: "nan token", + in: `{"a":NaN}`, + want: `{"a":0}`, + }, + { + name: "tokens inside strings untouched", + in: `{"a":"Infinity -Infinity NaN","b":Infinity}`, + want: `{"a":"Infinity -Infinity NaN","b":0}`, + }, + { + name: "identifier-like token untouched", + in: `{"a":InfinityValue}`, + want: `{"a":InfinityValue}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := string(sanitizeNonFiniteJSON([]byte(tt.in))) + if got != tt.want { + t.Fatalf("sanitizeNonFiniteJSON() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 11bfb8c90..d75ac19fa 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -160,6 +160,27 @@ func (kv KV) SSMGroupCount() uint64 { return uint64(kv.Uint("ssm.group_count")) } +func (kv KV) FFNLength() []uint64 { + ffnLengthDefault := uint32(0) + ffnLength := kv.UintOrArrayValueAsArray("feed_forward_length", ffnLengthDefault) + if len(ffnLength) == 1 { + ffnLengthDefault = ffnLength[0] + } + nLayers := int(kv.BlockCount()) + if len(ffnLength) > nLayers { + slog.Warn("got more elements of feed_forward_length than layers", "len(ffnLength)", len(ffnLength), "layers", nLayers) + } + out := make([]uint64, nLayers) + for i := range nLayers { + if i >= len(ffnLength) { + out[i] = uint64(ffnLengthDefault) + } else { + out[i] = uint64(ffnLength[i]) + } + } + return out +} + // general types func (kv KV) String(key string, defaultValue ...string) string { @@ -264,6 +285,7 @@ func (kv KV) OllamaEngineRequired() bool { "llama4", 
"mistral3", "mllama", + "nemotron_h", "nemotron_h_moe", "nomic-bert", "olmo3", "qwen25vl", @@ -865,6 +887,7 @@ func (f GGML) FlashAttention() bool { "gptoss", "gpt-oss", "lfm2", "mistral3", + "nemotron_h", "nemotron_h_moe", "olmo3", "qwen3", "qwen3moe", "qwen3next", diff --git a/kvcache/recurrent.go b/kvcache/recurrent.go new file mode 100644 index 000000000..1b5765530 --- /dev/null +++ b/kvcache/recurrent.go @@ -0,0 +1,752 @@ +package kvcache + +import ( + "errors" + "fmt" + "math" + "slices" + + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" +) + +const ( + DefaultCheckpointCount = 32 + DefaultCheckpointMinPos = int32(16) + DefaultCheckpointInterval = int32(1280) +) + +var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape") + +// Config configures a shared hybrid recurrent cache. +type RecurrentConfig struct { + Shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) + ConvDim int + ConvChannels int + RecurrentStateSize int + CheckpointLogPrefix string +} + +var ( + _ Cache = (*Recurrent)(nil) + _ CheckpointCache = (*Recurrent)(nil) +) + +// Cache stores: +// - a standard causal KV cache +// - per-sequence conv state for recurrent operators +// - per-sequence recurrent state for recurrent operators +// +// Conv state shape (per layer, per sequence): [convDim, convChannels] +// Recurrent state shape (per layer, per sequence): [recurrentStateSize] +type Recurrent struct { + kv *Causal + + backend ml.Backend + dtype ml.DType + maxSequences int + + // Conv state dimensions + convDim int + convChannels int + + // Recurrent state dimensions + recurrentStateSize int + + logPrefix string + + // slot mapping for recurrent state (copy-on-write) + slotForSeq map[int]int + refCount []int + freeSlots []int + seqCounts map[int]int + slotScratch [1]int32 + + // per-layer conv state buffers (allocated lazily) + convCtxs map[int]ml.Context + convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots] + + // per-layer recurrent state buffers (allocated lazily) + recurrentCtxs map[int]ml.Context + recurrentStates map[int]ml.Tensor // [recurrentStateSize, maxSlots] + + // recurrent checkpoints (per slot) + checkpointCount int + checkpointMinPos int32 + checkpointInterval int32 + checkpointCtxSize int + checkpoints map[int]*slotCheckpointStore + pendingRestore map[int]checkpointRestore + curCheckpointPos []int32 + curCheckpointSlots map[int]int + reserveCheckpoints bool + checkpointConvCtxs map[int]ml.Context + checkpointRecurCtxs map[int]ml.Context + checkpointReserved map[int]struct{} + + // current forward batch (derived in StartForward) + curSeqs []int + curSlots []int + curSlotsInput ml.Tensor + curSeqTokens int + + // track if EnsureWritable has been called for this forward pass + writableEnsured bool + writableError error +} + +func NewRecurrentCache(config RecurrentConfig) *Recurrent { + return &Recurrent{ + kv: NewCausalCache(config.Shift), + convDim: config.ConvDim, + convChannels: config.ConvChannels, + recurrentStateSize: config.RecurrentStateSize, + logPrefix: config.CheckpointLogPrefix, + slotForSeq: make(map[int]int), + seqCounts: make(map[int]int), + convCtxs: make(map[int]ml.Context), + convStates: make(map[int]ml.Tensor), + recurrentCtxs: make(map[int]ml.Context), + recurrentStates: make(map[int]ml.Tensor), + checkpointCount: DefaultCheckpointCount, + checkpointMinPos: DefaultCheckpointMinPos, + checkpointInterval: DefaultCheckpointInterval, + checkpoints: make(map[int]*slotCheckpointStore), + pendingRestore: 
make(map[int]checkpointRestore), + curCheckpointSlots: make(map[int]int), + checkpointConvCtxs: make(map[int]ml.Context), + checkpointRecurCtxs: make(map[int]ml.Context), + checkpointReserved: make(map[int]struct{}), + } +} + +func (c *Recurrent) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { + c.backend = backend + c.dtype = dtype + c.maxSequences = maxSequences + c.checkpoints = make(map[int]*slotCheckpointStore) + c.pendingRestore = make(map[int]checkpointRestore) + c.curCheckpointPos = c.curCheckpointPos[:0] + c.curCheckpointSlots = make(map[int]int) + c.checkpointReserved = make(map[int]struct{}) + c.checkpointCtxSize = c.checkpointCount * c.maxSequences + if c.checkpointCtxSize < 8 { + c.checkpointCtxSize = 8 + } + + // initialize slot allocator + c.refCount = make([]int, maxSequences) + c.freeSlots = c.freeSlots[:0] + for i := maxSequences - 1; i >= 0; i-- { + c.freeSlots = append(c.freeSlots, i) + } + + c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch) +} + +func (c *Recurrent) Close() { + for _, ctx := range c.convCtxs { + ctx.Close() + } + for _, ctx := range c.recurrentCtxs { + ctx.Close() + } + for _, ctx := range c.checkpointConvCtxs { + ctx.Close() + } + for _, ctx := range c.checkpointRecurCtxs { + ctx.Close() + } + c.kv.Close() +} + +func (c *Recurrent) SetConfig(config ml.CacheConfig) { + c.kv.SetConfig(config) +} + +func (c *Recurrent) SetLayer(layer int) { + c.kv.SetLayer(layer) +} + +func (c *Recurrent) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { + return c.kv.Get(ctx) +} + +func (c *Recurrent) Put(ctx ml.Context, key, value ml.Tensor) { + c.kv.Put(ctx, key, value) +} + +func (c *Recurrent) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { + if err := c.kv.StartForward(ctx, batch, reserve); err != nil { + return err + } + + nTokens := len(batch.Sequences) + if nTokens == 0 { + c.curSeqs = c.curSeqs[:0] + c.curSlots = c.curSlots[:0] + c.curSlotsInput = nil + c.curSeqTokens = 0 + c.reserveCheckpoints = false + c.writableEnsured = false + c.writableError = nil + return nil + } + + // Fast path for single-sequence batches (common during decode and prefill). + firstSeq := batch.Sequences[0] + singleSeq := true + for _, s := range batch.Sequences[1:] { + if s != firstSeq { + singleSeq = false + break + } + } + if singleSeq { + return c.startForwardSingleSeq(ctx, firstSeq, nTokens, batch, reserve) + } + + // Derive equal-length sequence layout for recurrent layers. + seqCounts := c.seqCounts + for s := range seqCounts { + delete(seqCounts, s) + } + + c.curSeqs = c.curSeqs[:0] + for _, s := range batch.Sequences { + if seqCounts[s] == 0 { + c.curSeqs = append(c.curSeqs, s) + } + seqCounts[s]++ + } + + nSeqs := len(c.curSeqs) + want := nTokens / nSeqs + for _, s := range c.curSeqs { + if seqCounts[s] != want { + return ErrNotSupported + } + } + + c.curSeqTokens = want + + if reserve { + c.curSlots = c.curSlots[:0] + for i := range nSeqs { + c.curSlots = append(c.curSlots, i) + } + c.finalizeStartForward(ctx, batch, true) + return nil + } + + // Ensure slots exist for sequences in this batch. 
+ c.curSlots = c.curSlots[:0] + var newSlots []int + for _, s := range c.curSeqs { + slot, ok := c.slotForSeq[s] + if !ok { + var err error + slot, err = c.allocSlot() + if err != nil { + return err + } + c.slotForSeq[s] = slot + c.refCount[slot] = 1 + newSlots = append(newSlots, slot) + } + c.curSlots = append(c.curSlots, slot) + } + + if len(newSlots) > 0 { + c.zeroSlots(ctx, newSlots) + } + + c.finalizeStartForward(ctx, batch, false) + + return nil +} + +func (c *Recurrent) startForwardSingleSeq(ctx ml.Context, seq, seqTokens int, batch input.Batch, reserve bool) error { + c.curSeqs = append(c.curSeqs[:0], seq) + c.curSeqTokens = seqTokens + + if reserve { + c.curSlots = append(c.curSlots[:0], 0) + c.finalizeStartForward(ctx, batch, true) + return nil + } + + slot, ok := c.slotForSeq[seq] + if !ok { + var err error + slot, err = c.allocSlot() + if err != nil { + return err + } + + c.slotForSeq[seq] = slot + c.refCount[slot] = 1 + slotList := [1]int{slot} + c.zeroSlots(ctx, slotList[:]) + } + + c.curSlots = append(c.curSlots[:0], slot) + c.finalizeStartForward(ctx, batch, false) + + return nil +} + +func (c *Recurrent) finalizeStartForward(ctx ml.Context, batch input.Batch, reserve bool) { + c.setCurSlotsInput(ctx) + c.writableEnsured = false + c.writableError = nil + c.reserveCheckpoints = reserve + c.planCheckpoints(batch) +} + +func (c *Recurrent) setCurSlotsInput(ctx ml.Context) { + c.curSlotsInput = c.slotsInput(ctx, c.curSlots) +} + +func (c *Recurrent) slotsInput(ctx ml.Context, slots []int) ml.Tensor { + switch len(slots) { + case 0: + return nil + case 1: + c.slotScratch[0] = int32(slots[0]) + return ctx.Input().FromInts(c.slotScratch[:], 1) + default: + slotIndices := make([]int32, len(slots)) + for i, v := range slots { + slotIndices[i] = int32(v) + } + return ctx.Input().FromInts(slotIndices, len(slotIndices)) + } +} + +func (c *Recurrent) allocSlot() (int, error) { + if len(c.freeSlots) == 0 { + return 0, ErrKvCacheFull + } + slot := c.freeSlots[len(c.freeSlots)-1] + c.freeSlots = c.freeSlots[:len(c.freeSlots)-1] + return slot, nil +} + +func (c *Recurrent) freeSlot(slot int) { + if slot >= 0 && slot < c.maxSequences { + c.freeSlots = append(c.freeSlots, slot) + } +} + +// zeroSlots zeros recurrent state for the given slots across all cached layers. +func (c *Recurrent) zeroSlots(ctx ml.Context, slots []int) { + if len(slots) == 0 { + return + } + + inputCtx := ctx.Input() + slotsTensor := c.slotsInput(ctx, slots) + + if len(c.convStates) > 0 { + zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots)) + for _, buf := range c.convStates { + ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor)) + } + } + + if len(c.recurrentStates) > 0 { + zeros := inputCtx.Zeros(ml.DTypeF32, c.recurrentStateSize, len(slots)) + for _, buf := range c.recurrentStates { + ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor)) + } + } +} + +// EnsureWritable ensures sequences have private slots (copy-on-write). 
+func (c *Recurrent) EnsureWritable(ctx ml.Context) error { + for i, seq := range c.curSeqs { + slot, ok := c.slotForSeq[seq] + if !ok { + continue + } + + if slot < 0 || slot >= len(c.refCount) { + continue + } + + if c.refCount[slot] <= 1 { + continue + } + + newSlot, err := c.allocSlot() + if err != nil { + return err + } + c.refCount[slot]-- + c.refCount[newSlot] = 1 + c.slotForSeq[seq] = newSlot + c.curSlots[i] = newSlot + + c.copyRecurrentState(ctx, slot, newSlot) + c.copyCheckpoints(ctx, slot, newSlot) + } + + c.setCurSlotsInput(ctx) + + return nil +} + +func (c *Recurrent) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) { + src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1) + dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1) + + for _, buf := range c.convStates { + rows := buf.Rows(ctx, src) + if rows.DType() != ml.DTypeF32 { + rows = rows.Cast(ctx, ml.DTypeF32) + } + ctx.Forward(buf.SetRows(ctx, rows, dst)) + } + + for _, buf := range c.recurrentStates { + rows := buf.Rows(ctx, src) + if rows.DType() != ml.DTypeF32 { + rows = rows.Cast(ctx, ml.DTypeF32) + } + ctx.Forward(buf.SetRows(ctx, rows, dst)) + } +} + +func (c *Recurrent) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) { + c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen) + + if dstSlot, ok := c.slotForSeq[dstSeq]; ok { + if c.validSlot(dstSlot) { + c.refCount[dstSlot]-- + if c.refCount[dstSlot] <= 0 { + c.refCount[dstSlot] = 0 + c.freeSlot(dstSlot) + } + } + delete(c.slotForSeq, dstSeq) + } + + srcSlot, ok := c.slotForSeq[srcSeq] + if !ok { + return + } + + if c.validSlot(srcSlot) { + c.slotForSeq[dstSeq] = srcSlot + c.refCount[srcSlot]++ + } +} + +func (c *Recurrent) CanResume(seq int, pos int32) bool { + if !c.kv.CanResume(seq, pos) { + return false + } + if pos == 0 { + return true + } + return c.hasCheckpoint(seq, pos) +} + +func (c *Recurrent) Remove(seq int, beginIndex, endIndex int32) error { + if beginIndex > 0 && endIndex != math.MaxInt32 { + if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil { + return err + } + delete(c.pendingRestore, seq) + + slot, ok := c.slotForSeq[seq] + if !ok || !c.validSlot(slot) { + return nil + } + + // Detach shared recurrent state/checkpoints before mutating checkpoint positions. 
+ if c.refCount[slot] > 1 { + newSlot, err := c.allocSlot() + if err != nil { + return err + } + ctx := c.backend.NewContext() + c.copyRecurrentState(ctx, slot, newSlot) + c.copyCheckpoints(ctx, slot, newSlot) + if len(c.convStates) > 0 || len(c.recurrentStates) > 0 { + ctx.Compute() + } + ctx.Close() + + c.refCount[slot]-- + c.refCount[newSlot] = 1 + c.slotForSeq[seq] = newSlot + slot = newSlot + } + + c.shiftCheckpoints(slot, beginIndex, endIndex) + return nil + } + + if beginIndex > 0 { + restore, ok := c.pendingRestore[seq] + if !ok || restore.pos+1 != beginIndex { + return ErrNotSupported + } + if !c.restoreComplete(restore) { + return ErrNotSupported + } + if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 { + newSlot, err := c.allocSlot() + if err != nil { + return err + } + ctx := c.backend.NewContext() + c.copyRecurrentState(ctx, slot, newSlot) + c.copyCheckpoints(ctx, slot, newSlot) + if len(c.convStates) > 0 || len(c.recurrentStates) > 0 { + ctx.Compute() + } + ctx.Close() + + c.refCount[slot]-- + c.refCount[newSlot] = 1 + c.slotForSeq[seq] = newSlot + + restore.slot = newSlot + c.pendingRestore[seq] = restore + } + } + + if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil { + return err + } + + if beginIndex > 0 { + restore := c.pendingRestore[seq] + delete(c.pendingRestore, seq) + return c.applyCheckpointRestore(restore) + } + + slot, ok := c.slotForSeq[seq] + delete(c.pendingRestore, seq) + if !ok { + return nil + } + + if !c.validSlot(slot) { + delete(c.slotForSeq, seq) + return nil + } + + c.refCount[slot]-- + if c.refCount[slot] <= 0 { + c.refCount[slot] = 0 + c.clearCheckpoints(slot) + c.freeSlot(slot) + } + delete(c.slotForSeq, seq) + + return nil +} + +func (c *Recurrent) validSlot(slot int) bool { + return slot >= 0 && slot < len(c.refCount) +} + +func (c *Recurrent) SlotsTensor() ml.Tensor { + return c.curSlotsInput +} + +// contiguousSlots returns the starting slot if current slots are contiguous and ordered. 
+func (c *Recurrent) contiguousSlots() (int, bool) { + if len(c.curSlots) == 0 { + return 0, false + } + start := c.curSlots[0] + for i, s := range c.curSlots { + if s != start+i { + return 0, false + } + } + return start, true +} + +func (c *Recurrent) SeqTokens() int { + return c.curSeqTokens +} + +func (c *Recurrent) NumSeqs() int { + return len(c.curSeqs) +} + +func (c *Recurrent) convBuffer(layer int) ml.Tensor { + if buf, ok := c.convStates[layer]; ok { + return buf + } + + if _, ok := c.convCtxs[layer]; !ok { + c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer) + } + + buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences) + c.convStates[layer] = buf + return buf +} + +func (c *Recurrent) recurrentBuffer(layer int) ml.Tensor { + if buf, ok := c.recurrentStates[layer]; ok { + return buf + } + + if _, ok := c.recurrentCtxs[layer]; !ok { + c.recurrentCtxs[layer] = c.backend.NewContextSize(1).Layer(layer) + } + + buf := c.recurrentCtxs[layer].Zeros(ml.DTypeF32, c.recurrentStateSize, c.maxSequences) + c.recurrentStates[layer] = buf + return buf +} + +func (c *Recurrent) ensureWritable(ctx ml.Context) error { + c.ensureWritableOnce(ctx) + return c.writableError +} + +func (c *Recurrent) currentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int) ml.Tensor { + if start, ok := c.contiguousSlots(); ok { + offset := start * buf.Stride(1) + return buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs()) + } + + return buf.Rows(ctx, c.SlotsTensor()) +} + +func (c *Recurrent) writeCurrentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int, src ml.Tensor) { + if start, ok := c.contiguousSlots(); ok { + offset := start * buf.Stride(1) + view := buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs()) + ctx.Forward(src.Copy(ctx, view)) + return + } + + ctx.Forward(buf.SetRows(ctx, src, c.SlotsTensor())) +} + +func (c *Recurrent) ensureWritableOnce(ctx ml.Context) { + if !c.writableEnsured { + needsWritable := false + for _, seq := range c.curSeqs { + slot, ok := c.slotForSeq[seq] + if !ok { + continue + } + if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 { + needsWritable = true + break + } + } + + if needsWritable { + if err := c.EnsureWritable(ctx); err != nil { + c.writableError = err + } + } + c.writableEnsured = true + } +} + +// ConvState returns conv state for current batch sequences as [convDim, convChannels, nSeqs]. +func (c *Recurrent) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) { + if err := c.ensureWritable(ctx); err != nil { + return nil, err + } + + buf := c.convBuffer(layer) + cur := c.currentSlotRows(ctx, buf, c.convDim*c.convChannels) + return cur.Reshape(ctx, c.convDim, c.convChannels, c.NumSeqs()), nil +} + +// UpdateConvState writes new conv state for current batch sequences. +func (c *Recurrent) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) { + buf := c.convBuffer(layer) + src := newState.Reshape(ctx, c.convDim*c.convChannels, c.NumSeqs()) + srcF32 := src + if src.DType() != ml.DTypeF32 { + srcF32 = src.Cast(ctx, ml.DTypeF32) + } + c.writeCurrentSlotRows(ctx, buf, c.convDim*c.convChannels, srcF32) + + c.captureConvCheckpoint(ctx, layer, srcF32) +} + +// RecurrentState returns recurrent state for current batch sequences with shape [dims..., nSeqs]. 
+func (c *Recurrent) RecurrentState(ctx ml.Context, layer int, dims ...int) (ml.Tensor, error) { + if err := c.ensureWritable(ctx); err != nil { + return nil, err + } + if len(dims) == 0 { + return nil, ErrInvalidRecurrentShape + } + + size := 1 + for _, d := range dims { + if d <= 0 { + return nil, ErrInvalidRecurrentShape + } + size *= d + } + if size != c.recurrentStateSize { + return nil, fmt.Errorf("%w: got %v (size %d), want size %d", ErrInvalidRecurrentShape, dims, size, c.recurrentStateSize) + } + + buf := c.recurrentBuffer(layer) + cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize) + shape := make([]int, 0, len(dims)+1) + shape = append(shape, dims...) + shape = append(shape, c.NumSeqs()) + return cur.Reshape(ctx, shape...), nil +} + +// RecurrentState4D returns recurrent state as [dim0, dim1, dim2, nSeqs]. +func (c *Recurrent) RecurrentState4D(ctx ml.Context, layer int, dim0, dim1, dim2 int) (ml.Tensor, error) { + if err := c.ensureWritable(ctx); err != nil { + return nil, err + } + if dim0 <= 0 || dim1 <= 0 || dim2 <= 0 { + return nil, ErrInvalidRecurrentShape + } + + size := dim0 * dim1 * dim2 + if size != c.recurrentStateSize { + return nil, fmt.Errorf("%w: got [%d %d %d] (size %d), want size %d", ErrInvalidRecurrentShape, dim0, dim1, dim2, size, c.recurrentStateSize) + } + + buf := c.recurrentBuffer(layer) + cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize) + return cur.Reshape(ctx, dim0, dim1, dim2, c.NumSeqs()), nil +} + +// UpdateRecurrentState writes new recurrent state for current batch sequences. +func (c *Recurrent) UpdateRecurrentState(ctx ml.Context, layer int, newState ml.Tensor) { + buf := c.recurrentBuffer(layer) + src := newState.Reshape(ctx, c.recurrentStateSize, c.NumSeqs()) + srcF32 := src + if src.DType() != ml.DTypeF32 { + srcF32 = src.Cast(ctx, ml.DTypeF32) + } + c.writeCurrentSlotRows(ctx, buf, c.recurrentStateSize, srcF32) + + c.captureRecurrentCheckpoint(ctx, layer, srcF32) +} + +// IsSupportedForBatch returns true if the current batch layout supports recurrent layers. +func (c *Recurrent) IsSupportedForBatch() bool { + return c.curSeqTokens > 0 && len(c.curSeqs) > 0 +} + +// Seqs returns the ordered unique sequences for the current forward pass. +func (c *Recurrent) Seqs() []int { + return slices.Clone(c.curSeqs) +} diff --git a/kvcache/recurrent_checkpoints.go b/kvcache/recurrent_checkpoints.go new file mode 100644 index 000000000..1e029a5b3 --- /dev/null +++ b/kvcache/recurrent_checkpoints.go @@ -0,0 +1,561 @@ +package kvcache + +import ( + "log/slog" + "math" + + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" +) + +// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU +// memory usage while preserving prefix reuse for recurrent state. 
+ +type checkpointEntry struct { + pos int32 + conv map[int]ml.Tensor + recurrent map[int]ml.Tensor +} + +type slotCheckpointStore struct { + entries []checkpointEntry + size int + next int + lastPos int32 +} + +type checkpointRestore struct { + slot int + idx int + pos int32 +} + +func newSlotCheckpointStore(n int) *slotCheckpointStore { + entries := make([]checkpointEntry, n) + for i := range entries { + entries[i].pos = -1 + } + return &slotCheckpointStore{ + entries: entries, + lastPos: -1, + } +} + +func (s *slotCheckpointStore) reset() { + s.size = 0 + s.next = 0 + s.lastPos = -1 + for i := range s.entries { + s.entries[i].pos = -1 + } +} + +func (s *slotCheckpointStore) record(pos int32) int { + if len(s.entries) == 0 { + return -1 + } + idx := s.next + s.next = (s.next + 1) % len(s.entries) + if s.size < len(s.entries) { + s.size++ + } + s.entries[idx].pos = pos + s.lastPos = pos + return idx +} + +func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) { + bestIdx := -1 + bestPos := int32(-1) + for i := range s.entries { + pos := s.entries[i].pos + if pos < 0 || pos >= targetPos { + continue + } + if pos > bestPos { + bestPos = pos + bestIdx = i + } + } + if bestIdx < 0 { + return -1, -1, false + } + return bestIdx, bestPos, true +} + +func (s *slotCheckpointStore) pruneAfter(pos int32) { + if len(s.entries) == 0 { + s.size = 0 + s.next = 0 + s.lastPos = -1 + return + } + + size := 0 + next := -1 + minPos := int32(math.MaxInt32) + minIdx := 0 + for i := range s.entries { + if s.entries[i].pos > pos { + s.entries[i].pos = -1 + } + if s.entries[i].pos >= 0 { + size++ + if s.entries[i].pos < minPos { + minPos = s.entries[i].pos + minIdx = i + } + } else if next == -1 { + next = i + } + } + + s.size = size + if size == 0 { + s.next = 0 + s.lastPos = -1 + return + } + if next != -1 { + s.next = next + } else { + // Full ring: overwrite the oldest checkpoint next. + s.next = minIdx + } + s.lastPos = pos +} + +func (s *slotCheckpointStore) shiftRange(beginIndex, endIndex int32) { + if len(s.entries) == 0 { + s.size = 0 + s.next = 0 + s.lastPos = -1 + return + } + + offset := beginIndex - endIndex + + size := 0 + next := -1 + minPos := int32(math.MaxInt32) + maxPos := int32(-1) + minIdx := 0 + + for i := range s.entries { + pos := s.entries[i].pos + if pos >= 0 { + if pos >= beginIndex && pos < endIndex { + s.entries[i].pos = -1 + } else if pos >= endIndex { + s.entries[i].pos = pos + offset + } + } + + pos = s.entries[i].pos + if pos >= 0 { + size++ + if pos < minPos { + minPos = pos + minIdx = i + } + if pos > maxPos { + maxPos = pos + } + } else if next == -1 { + next = i + } + } + + s.size = size + if size == 0 { + s.next = 0 + s.lastPos = -1 + return + } + + if next != -1 { + s.next = next + } else { + // Full ring: overwrite the oldest checkpoint next. 
+ s.next = minIdx + } + s.lastPos = maxPos +} + +func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) { + minPos = int32(math.MaxInt32) + maxPos = int32(-1) + for i := range s.entries { + pos := s.entries[i].pos + if pos < 0 { + continue + } + size++ + if pos < minPos { + minPos = pos + } + if pos > maxPos { + maxPos = pos + } + } + if size == 0 { + minPos = -1 + maxPos = -1 + } + return size, minPos, maxPos, s.lastPos +} + +func (c *Recurrent) checkpointTag() string { + if c.logPrefix == "" { + return "kvcache.recurrent" + } + return c.logPrefix +} + +func (c *Recurrent) planCheckpoints(batch input.Batch) { + if c.checkpointCount == 0 || len(c.curSeqs) == 0 { + c.curCheckpointPos = c.curCheckpointPos[:0] + for k := range c.curCheckpointSlots { + delete(c.curCheckpointSlots, k) + } + return + } + + if cap(c.curCheckpointPos) < len(c.curSeqs) { + c.curCheckpointPos = make([]int32, len(c.curSeqs)) + } else { + c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)] + } + for i := range c.curCheckpointPos { + c.curCheckpointPos[i] = -1 + } + for k := range c.curCheckpointSlots { + delete(c.curCheckpointSlots, k) + } + + posMax := make(map[int]int32, len(c.curSeqs)) + for i, seq := range batch.Sequences { + pos := batch.Positions[i] + if cur, ok := posMax[seq]; !ok || pos > cur { + posMax[seq] = pos + } + } + + for i, seq := range c.curSeqs { + pos, ok := posMax[seq] + if !ok { + continue + } + if pos < c.checkpointMinPos { + continue + } + slot := c.curSlots[i] + store := c.checkpointStore(slot) + lastPos := store.lastPos + if lastPos < 0 || pos-lastPos >= c.checkpointInterval { + c.curCheckpointPos[i] = pos + } + } +} + +func (c *Recurrent) checkpointStore(slot int) *slotCheckpointStore { + store, ok := c.checkpoints[slot] + if ok { + return store + } + store = newSlotCheckpointStore(c.checkpointCount) + c.checkpoints[slot] = store + return store +} + +func (c *Recurrent) checkpointIndexForSlot(slot int, pos int32) int { + if c.checkpointCount == 0 { + return -1 + } + if idx, ok := c.curCheckpointSlots[slot]; ok { + return idx + } + store := c.checkpointStore(slot) + idx := store.record(pos) + if idx >= 0 { + c.curCheckpointSlots[slot] = idx + } + return idx +} + +func (c *Recurrent) hasCheckpoint(seq int, pos int32) bool { + if pos <= 0 { + return false + } + slot, ok := c.slotForSeq[seq] + if !ok { + return false + } + store, ok := c.checkpoints[slot] + if !ok { + return false + } + _, _, ok = store.bestIndex(pos) + return ok +} + +func (c *Recurrent) PrepareRestore(seq int, targetPos int32) (int32, bool) { + if targetPos <= 0 { + return 0, false + } + slot, ok := c.slotForSeq[seq] + if !ok { + return 0, false + } + store, ok := c.checkpoints[slot] + if !ok { + slog.Debug(c.checkpointTag()+": checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0) + return 0, false + } + idx, pos, ok := store.bestIndex(targetPos) + if !ok { + size, minPos, maxPos, lastPos := store.window() + slog.Debug(c.checkpointTag()+": checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size, + "min", minPos, "max", maxPos, "last", lastPos) + return 0, false + } + c.pendingRestore[seq] = checkpointRestore{ + slot: slot, + idx: idx, + pos: pos, + } + return pos + 1, true +} + +func (c *Recurrent) applyCheckpointRestore(restore checkpointRestore) error { + entry, ok := c.restoreEntry(restore) + if !ok { + return ErrNotSupported + } + + ctx := c.backend.NewContext() + defer ctx.Close() + + slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1) 
+ for layer, src := range entry.conv { + buf := c.convBuffer(layer) + ctx.Forward(buf.SetRows(ctx, src, slotIdx)) + } + for layer, src := range entry.recurrent { + buf := c.recurrentBuffer(layer) + ctx.Forward(buf.SetRows(ctx, src, slotIdx)) + } + + if len(entry.conv) > 0 || len(entry.recurrent) > 0 { + ctx.Compute() + } + store := c.checkpoints[restore.slot] + store.pruneAfter(restore.pos) + return nil +} + +func (c *Recurrent) restoreComplete(restore checkpointRestore) bool { + _, ok := c.restoreEntry(restore) + return ok +} + +func (c *Recurrent) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) { + store, ok := c.checkpoints[restore.slot] + if !ok || restore.idx < 0 || restore.idx >= len(store.entries) { + return nil, false + } + entry := &store.entries[restore.idx] + if entry.pos < 0 { + return nil, false + } + if !c.entryComplete(entry) { + return nil, false + } + return entry, true +} + +func (c *Recurrent) entryComplete(entry *checkpointEntry) bool { + for layer := range c.convStates { + if entry.conv == nil || entry.conv[layer] == nil { + return false + } + } + for layer := range c.recurrentStates { + if entry.recurrent == nil || entry.recurrent[layer] == nil { + return false + } + } + return true +} + +func (c *Recurrent) clearCheckpoints(slot int) { + if store, ok := c.checkpoints[slot]; ok { + store.reset() + } +} + +func (c *Recurrent) shiftCheckpoints(slot int, beginIndex, endIndex int32) { + if store, ok := c.checkpoints[slot]; ok { + store.shiftRange(beginIndex, endIndex) + } +} + +func (c *Recurrent) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) { + if c.checkpointCount == 0 { + return + } + srcStore, ok := c.checkpoints[srcSlot] + if !ok || srcStore.size == 0 { + return + } + dstStore := c.checkpointStore(dstSlot) + dstStore.size = srcStore.size + dstStore.next = srcStore.next + dstStore.lastPos = srcStore.lastPos + + for i := range srcStore.entries { + srcEntry := &srcStore.entries[i] + dstEntry := &dstStore.entries[i] + dstEntry.pos = srcEntry.pos + if srcEntry.conv != nil { + if dstEntry.conv == nil { + dstEntry.conv = make(map[int]ml.Tensor) + } + for layer, src := range srcEntry.conv { + dst := c.ensureCheckpointConv(layer, dstEntry) + ctx.Forward(src.Copy(ctx, dst)) + } + } + if srcEntry.recurrent != nil { + if dstEntry.recurrent == nil { + dstEntry.recurrent = make(map[int]ml.Tensor) + } + for layer, src := range srcEntry.recurrent { + dst := c.ensureCheckpointRecurrent(layer, dstEntry) + ctx.Forward(src.Copy(ctx, dst)) + } + } + } +} + +func (c *Recurrent) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) { + if c.checkpointCount == 0 { + return + } + if c.reserveCheckpoints { + c.reserveCheckpointConv(layer) + return + } + if len(c.curCheckpointPos) == 0 { + return + } + for i, pos := range c.curCheckpointPos { + if pos < 0 { + continue + } + slot := c.curSlots[i] + idx := c.checkpointIndexForSlot(slot, pos) + if idx < 0 { + continue + } + entry := &c.checkpoints[slot].entries[idx] + dst := c.ensureCheckpointConv(layer, entry) + seqSlice := src.Slice(ctx, 1, i, i+1, 1) + ctx.Forward(seqSlice.Copy(ctx, dst)) + } +} + +func (c *Recurrent) captureRecurrentCheckpoint(ctx ml.Context, layer int, src ml.Tensor) { + if c.checkpointCount == 0 { + return + } + if c.reserveCheckpoints { + c.reserveCheckpointRecurrent(layer) + return + } + if len(c.curCheckpointPos) == 0 { + return + } + for i, pos := range c.curCheckpointPos { + if pos < 0 { + continue + } + slot := c.curSlots[i] + idx := c.checkpointIndexForSlot(slot, pos) + if idx < 0 
{ + continue + } + entry := &c.checkpoints[slot].entries[idx] + dst := c.ensureCheckpointRecurrent(layer, entry) + seqSlice := src.Slice(ctx, 1, i, i+1, 1) + ctx.Forward(seqSlice.Copy(ctx, dst)) + } +} + +func (c *Recurrent) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor { + if entry.conv == nil { + entry.conv = make(map[int]ml.Tensor) + } + if t, ok := entry.conv[layer]; ok { + return t + } + ctx, ok := c.checkpointConvCtxs[layer] + if !ok { + ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer) + c.checkpointConvCtxs[layer] = ctx + } + t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1) + entry.conv[layer] = t + return t +} + +func (c *Recurrent) ensureCheckpointRecurrent(layer int, entry *checkpointEntry) ml.Tensor { + if entry.recurrent == nil { + entry.recurrent = make(map[int]ml.Tensor) + } + if t, ok := entry.recurrent[layer]; ok { + return t + } + ctx, ok := c.checkpointRecurCtxs[layer] + if !ok { + ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer) + c.checkpointRecurCtxs[layer] = ctx + } + t := ctx.Zeros(ml.DTypeF32, c.recurrentStateSize, 1) + entry.recurrent[layer] = t + return t +} + +func (c *Recurrent) reserveCheckpointConv(layer int) { + key := checkpointReserveKey(layer, 0) + if _, ok := c.checkpointReserved[key]; ok { + return + } + for slot := range c.maxSequences { + store := c.checkpointStore(slot) + for i := range store.entries { + entry := &store.entries[i] + _ = c.ensureCheckpointConv(layer, entry) + } + } + c.checkpointReserved[key] = struct{}{} +} + +func (c *Recurrent) reserveCheckpointRecurrent(layer int) { + key := checkpointReserveKey(layer, 1) + if _, ok := c.checkpointReserved[key]; ok { + return + } + for slot := range c.maxSequences { + store := c.checkpointStore(slot) + for i := range store.entries { + entry := &store.entries[i] + _ = c.ensureCheckpointRecurrent(layer, entry) + } + } + c.checkpointReserved[key] = struct{}{} +} + +func checkpointReserveKey(layer int, kind int) int { + return layer*2 + kind +} diff --git a/kvcache/recurrent_checkpoints_test.go b/kvcache/recurrent_checkpoints_test.go new file mode 100644 index 000000000..cf7a7b99a --- /dev/null +++ b/kvcache/recurrent_checkpoints_test.go @@ -0,0 +1,288 @@ +package kvcache + +import ( + "errors" + "math" + "slices" + "testing" + + "github.com/ollama/ollama/ml" +) + +func newTestCache() *Recurrent { + return NewRecurrentCache(RecurrentConfig{ConvDim: 1, ConvChannels: 2, RecurrentStateSize: 2}) +} + +func TestSlotCheckpointStoreBestIndex(t *testing.T) { + store := newSlotCheckpointStore(2) + store.record(10) + store.record(20) + + _, pos, ok := store.bestIndex(15) + if !ok || pos != 10 { + t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok) + } + + store.record(30) // overwrite oldest (10) + + if _, _, ok := store.bestIndex(15); ok { + t.Fatalf("expected no checkpoint for targetPos=15 after overwrite") + } + + _, pos, ok = store.bestIndex(40) + if !ok || pos != 30 { + t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok) + } +} + +func TestCachePrepareRestore(t *testing.T) { + cache := newTestCache() + cache.checkpointCount = 3 + cache.checkpoints = make(map[int]*slotCheckpointStore) + cache.pendingRestore = make(map[int]checkpointRestore) + + cache.slotForSeq[1] = 0 + store := cache.checkpointStore(0) + store.record(5) + store.record(9) + store.record(15) + + restorePos, ok := cache.PrepareRestore(1, 12) + if !ok { + t.Fatalf("expected restore ok") + } + if restorePos != 10 { + t.Fatalf("expected restorePos 10, got %d", restorePos) 
+ } + rest, ok := cache.pendingRestore[1] + if !ok { + t.Fatalf("expected pending restore entry") + } + if rest.pos != 9 { + t.Fatalf("expected pending restore pos 9, got %d", rest.pos) + } +} + +func TestSlotCheckpointStorePruneAfter(t *testing.T) { + store := newSlotCheckpointStore(3) + store.record(10) + store.record(20) + store.record(30) + + store.pruneAfter(20) + + if store.lastPos != 20 { + t.Fatalf("expected lastPos 20, got %d", store.lastPos) + } + + _, pos, ok := store.bestIndex(25) + if !ok || pos != 20 { + t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok) + } + + _, pos, ok = store.bestIndex(35) + if !ok || pos != 20 { + t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok) + } +} + +func TestCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) { + cache := newTestCache() + cache.checkpointCount = 3 + cache.checkpoints = make(map[int]*slotCheckpointStore) + cache.pendingRestore = make(map[int]checkpointRestore) + + cache.slotForSeq[1] = 0 + cache.refCount = []int{1} + cache.freeSlots = nil + + // Simulate layer 0 requires both conv and recurrent checkpoints. + cache.convStates[0] = nil + cache.recurrentStates[0] = nil + + store := cache.checkpointStore(0) + idx := store.record(9) + entry := &store.entries[idx] + entry.conv = map[int]ml.Tensor{0: nil} + // entry.recurrent intentionally missing + + cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9} + + err := cache.Remove(1, 10, math.MaxInt32) + if !errors.Is(err, ErrNotSupported) { + t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err) + } +} + +func TestCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) { + cache := newTestCache() + cache.checkpointCount = 3 + cache.checkpoints = make(map[int]*slotCheckpointStore) + cache.pendingRestore = make(map[int]checkpointRestore) + + cache.slotForSeq[1] = 0 + cache.refCount = []int{1} + cache.freeSlots = nil + + store := cache.checkpointStore(0) + idx := store.record(9) + + cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9} + + restore := cache.pendingRestore[1] + if !cache.restoreComplete(restore) { + t.Fatalf("expected restoreComplete to return true for complete checkpoint") + } +} + +func TestCacheRecurrentStateShapeValidation(t *testing.T) { + cache := newTestCache() + _, err := cache.RecurrentState(nil, 0, 3) + if !errors.Is(err, ErrInvalidRecurrentShape) { + t.Fatalf("expected ErrInvalidRecurrentShape, got %v", err) + } +} + +func TestSlotCheckpointStoreShiftRange(t *testing.T) { + store := newSlotCheckpointStore(5) + store.record(1) + store.record(4) + store.record(7) + store.record(10) + + store.shiftRange(2, 6) + + var positions []int32 + for i := range store.entries { + if store.entries[i].pos >= 0 { + positions = append(positions, store.entries[i].pos) + } + } + slices.Sort(positions) + + want := []int32{1, 3, 6} + if !slices.Equal(positions, want) { + t.Fatalf("unexpected shifted positions: got=%v want=%v", positions, want) + } + if store.lastPos != 6 { + t.Fatalf("expected lastPos 6, got %d", store.lastPos) + } +} + +func TestCacheRemoveMiddleShiftsCheckpoints(t *testing.T) { + cache := newTestCache() + cache.slotForSeq[1] = 0 + cache.refCount = []int{1} + cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: 0, pos: 1} + + store := cache.checkpointStore(0) + store.record(1) + store.record(4) + store.record(7) + store.record(10) + + if err := cache.Remove(1, 2, 6); err != nil { + t.Fatalf("expected middle remove to succeed, got %v", err) + } + + if _, ok 
:= cache.pendingRestore[1]; ok { + t.Fatalf("expected pending restore to be cleared after middle remove") + } + + var positions []int32 + for i := range store.entries { + if store.entries[i].pos >= 0 { + positions = append(positions, store.entries[i].pos) + } + } + slices.Sort(positions) + + want := []int32{1, 3, 6} + if !slices.Equal(positions, want) { + t.Fatalf("unexpected checkpoint positions after remove: got=%v want=%v", positions, want) + } +} + +func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) { + store := newSlotCheckpointStore(3) + + store.record(10) + store.record(20) + store.record(30) + + store.entries[0].conv = make(map[int]ml.Tensor) + store.entries[0].conv[0] = nil + store.entries[0].recurrent = make(map[int]ml.Tensor) + store.entries[0].recurrent[0] = nil + + store.record(40) + + if store.entries[0].conv == nil { + t.Fatalf("expected conv map to be preserved on reuse") + } + if store.entries[0].recurrent == nil { + t.Fatalf("expected recurrent map to be preserved on reuse") + } + if store.entries[0].pos != 40 { + t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos) + } +} + +func TestSlotCheckpointStoreFullCapacity(t *testing.T) { + store := newSlotCheckpointStore(2) + + idx1 := store.record(10) + idx2 := store.record(20) + + if idx1 != 0 || idx2 != 1 { + t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2) + } + if store.size != 2 { + t.Fatalf("expected size 2, got %d", store.size) + } + + _, pos1, ok1 := store.bestIndex(15) + _, pos2, ok2 := store.bestIndex(25) + + if !ok1 || pos1 != 10 { + t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1) + } + if !ok2 || pos2 != 20 { + t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2) + } +} + +func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) { + store := newSlotCheckpointStore(0) + + idx := store.record(10) + if idx != -1 { + t.Fatalf("expected record to return -1 for empty buffer, got %d", idx) + } + + _, _, ok := store.bestIndex(15) + if ok { + t.Fatalf("expected no checkpoint for empty buffer") + } +} + +func TestSlotCheckpointStorePruneAfterAll(t *testing.T) { + store := newSlotCheckpointStore(3) + store.record(10) + store.record(20) + store.record(30) + + store.pruneAfter(5) + + if store.size != 0 { + t.Fatalf("expected size 0 after pruning all, got %d", store.size) + } + if store.lastPos != -1 { + t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos) + } + + _, _, ok := store.bestIndex(100) + if ok { + t.Fatalf("expected no checkpoint after pruning all") + } +} diff --git a/llama/patches/0034-ggml-metal-guard-mul_mat_id-map0-and-add-ne20-22-spe.patch b/llama/patches/0034-ggml-metal-guard-mul_mat_id-map0-and-add-ne20-22-spe.patch new file mode 100644 index 000000000..ff0c8199d --- /dev/null +++ b/llama/patches/0034-ggml-metal-guard-mul_mat_id-map0-and-add-ne20-22-spe.patch @@ -0,0 +1,37 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: jmorganca +Date: Sun, 22 Feb 2026 14:12:30 -0800 +Subject: [PATCH] ggml-metal: guard mul_mat_id map0 and add ne20=22 + specialization + +--- + ggml/src/ggml-metal/ggml-metal-ops.cpp | 3 ++- + ggml/src/ggml-metal/ggml-metal.metal | 1 + + 2 files changed, 3 insertions(+), 1 deletion(-) + +diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp +index 4ac135603..ac5ad53db 100644 +--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp ++++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp +@@ -1961,7 +1961,8 @@ int 
ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { + // ne21 = n_rows (batch size) + const int ne21_mm_id_min = 32; + +- if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) { ++ if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min) && ++ (ne20 == 1 || ne20 == 2 || ne20 == 4 || ne20 == 6 || ne20 == 8 || ne20 == 10 || ne20 == 16 || ne20 == 22)) { + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + //switch (op->src[0]->type) { +diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal +index c37447a10..4f338aa13 100644 +--- a/ggml/src/ggml-metal/ggml-metal.metal ++++ b/ggml/src/ggml-metal/ggml-metal.metal +@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_ + template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; + template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>; + template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>; ++template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>; + + template + kernel void kernel_mul_mm_id( diff --git a/ml/backend.go b/ml/backend.go index 4e63d3399..3eb80b839 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -163,6 +163,7 @@ type Tensor interface { Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor SSMConv(ctx Context, kernel Tensor) Tensor + SSMScan(ctx Context, x, dt, A, B, C, ids Tensor) Tensor IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index b2f96c761..46a94d147 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -1662,6 +1662,13 @@ func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor { } } +func (t *Tensor) SSMScan(ctx ml.Context, x, dt, A, B, C, ids ml.Tensor) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_ssm_scan(ctx.(*Context).ctx, t.t, x.(*Tensor).t, dt.(*Tensor).t, A.(*Tensor).t, B.(*Tensor).t, C.(*Tensor).t, ids.(*Tensor).t), + } +} + func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor { return &Tensor{ b: t.b, diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index 9404c93ce..df243edcb 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -12249,6 +12249,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_ template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>; template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>; +template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>; template kernel void kernel_mul_mm_id( diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp index 
4ac135603..ac5ad53db 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1961,7 +1961,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { // ne21 = n_rows (batch size) const int ne21_mm_id_min = 32; - if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) { + if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min) && + (ne20 == 1 || ne20 == 2 || ne20 == 4 || ne20 == 6 || ne20 == 8 || ne20 == 10 || ne20 == 16 || ne20 == 22)) { // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) //switch (op->src[0]->type) { diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal index c37447a10..4f338aa13 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal @@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_ template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>; template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>; +template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>; template kernel void kernel_mul_mm_id( diff --git a/model/model_test.go b/model/model_test.go index ed2868ff3..03b9460d0 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -67,6 +67,7 @@ func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor { return f } func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f } func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f } +func (f *fakeTensor) SSMScan(ctx ml.Context, _, _, _, _, _, _ ml.Tensor) ml.Tensor { return f } func (m *fakeBackend) Get(name string) ml.Tensor { if slices.Contains(m.names, name) { diff --git a/model/models/models.go b/model/models/models.go index d4a8dc536..20d9c106c 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -15,6 +15,7 @@ import ( _ "github.com/ollama/ollama/model/models/llama4" _ "github.com/ollama/ollama/model/models/mistral3" _ "github.com/ollama/ollama/model/models/mllama" + _ "github.com/ollama/ollama/model/models/nemotronh" _ "github.com/ollama/ollama/model/models/nomicbert" _ "github.com/ollama/ollama/model/models/olmo3" _ "github.com/ollama/ollama/model/models/qwen2" diff --git a/model/models/nemotronh/attention.go b/model/models/nemotronh/attention.go new file mode 100644 index 000000000..311fb76d2 --- /dev/null +++ b/model/models/nemotronh/attention.go @@ -0,0 +1,88 @@ +package nemotronh + +import ( + "fmt" + "math" + + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" +) + +// Attention implements simple attention without RoPE for Nemotron-H. 
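+// Keys are cached without any position-dependent transform, so the cache's
+// Shift callback (see Shift in model.go) can return them unchanged.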
+// Unlike Qwen3Next, Nemotron-H attention has: +// - No RoPE (position info comes from Mamba2 layers) +// - Standard scaled dot-product attention +type Attention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` +} + +func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) { + hiddenDim := hiddenStates.Dim(0) + nSeqTokens := hiddenStates.Dim(1) + switch hiddenStates.Dim(2) { + case 0: + hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, nSeqTokens, 1) + case 1: + default: + return nil, ErrUnsupportedBatchLayout + } + + // Nemotron-H is currently clamped to num_parallel=1. + if cache != nil && cache.IsSupportedForBatch() { + if cache.numSeqs() != 1 { + return nil, ErrUnsupportedBatchLayout + } + if seqTokens := cache.seqTokens(); seqTokens > 0 && nSeqTokens != seqTokens { + return nil, ErrUnsupportedBatchLayout + } + } + batchSize := nSeqTokens + hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, batchSize) + + headDim := opts.getHeadDim() + if headDim <= 0 { + return nil, fmt.Errorf("nemotronh: invalid attention head dimension %d", headDim) + } + + // Q projection + query := a.Query.Forward(ctx, hiddenStates) + if query.Dim(0)%headDim != 0 { + return nil, fmt.Errorf("nemotronh: query dim %d not divisible by head dim %d", query.Dim(0), headDim) + } + numHeads := query.Dim(0) / headDim + query = query.Reshape(ctx, headDim, numHeads, batchSize) + + // K projection + key := a.Key.Forward(ctx, hiddenStates) + if key.Dim(0)%headDim != 0 { + return nil, fmt.Errorf("nemotronh: key dim %d not divisible by head dim %d", key.Dim(0), headDim) + } + numKVHeads := key.Dim(0) / headDim + key = key.Reshape(ctx, headDim, numKVHeads, batchSize) + + // V projection + value := a.Value.Forward(ctx, hiddenStates) + if value.Dim(0)%headDim != 0 { + return nil, fmt.Errorf("nemotronh: value dim %d not divisible by head dim %d", value.Dim(0), headDim) + } + if value.Dim(0)/headDim != numKVHeads { + return nil, fmt.Errorf("nemotronh: key heads %d and value heads %d do not match", numKVHeads, value.Dim(0)/headDim) + } + value = value.Reshape(ctx, headDim, numKVHeads, batchSize) + + // Standard attention computation (no RoPE) + scale := opts.attentionScale + if scale == 0 { + scale = 1.0 / math.Sqrt(float64(headDim)) + } + attention := nn.Attention(ctx, query, key, value, scale, cache) + + // Flatten heads + attention = attention.Reshape(ctx, headDim*numHeads, batchSize) + + // Output projection + return a.Output.Forward(ctx, attention), nil +} diff --git a/model/models/nemotronh/cache.go b/model/models/nemotronh/cache.go new file mode 100644 index 000000000..8381d0978 --- /dev/null +++ b/model/models/nemotronh/cache.go @@ -0,0 +1,55 @@ +package nemotronh + +import ( + "errors" + + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" +) + +// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible +// with the layer requirements. +var ErrUnsupportedBatchLayout = errors.New("nemotronh: unsupported batch layout") + +var ( + _ kvcache.Cache = (*HybridCache)(nil) + _ kvcache.CheckpointCache = (*HybridCache)(nil) +) + +// HybridCache adapts the shared recurrent cache base for Nemotron-H naming. 
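+// It embeds kvcache.Recurrent and only renames the per-layer accessors:
+// SSMState/UpdateSSMState delegate to the generic recurrent-state methods,
+// while conv-state handling, slot tracking, and checkpointing stay in the
+// shared implementation.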
+type HybridCache struct { + *kvcache.Recurrent +} + +func NewHybridCache(convDim, convChannels, ssmStateSize int) *HybridCache { + base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{ + Shift: Shift, + ConvDim: convDim, + ConvChannels: convChannels, + RecurrentStateSize: ssmStateSize, + CheckpointLogPrefix: "nemotronh", + }) + return &HybridCache{Recurrent: base} +} + +// SSMState returns the SSM state for current batch sequences. +func (c *HybridCache) SSMState(ctx ml.Context, layer int, dState, headDim, nHead int) (ml.Tensor, error) { + return c.RecurrentState4D(ctx, layer, dState, headDim, nHead) +} + +// UpdateSSMState writes a new SSM state for current batch sequences. +func (c *HybridCache) UpdateSSMState(ctx ml.Context, layer int, newState ml.Tensor) { + c.UpdateRecurrentState(ctx, layer, newState) +} + +func (c *HybridCache) slotsTensor() ml.Tensor { + return c.SlotsTensor() +} + +func (c *HybridCache) seqTokens() int { + return c.SeqTokens() +} + +func (c *HybridCache) numSeqs() int { + return c.NumSeqs() +} diff --git a/model/models/nemotronh/mamba2.go b/model/models/nemotronh/mamba2.go new file mode 100644 index 000000000..2a8c08606 --- /dev/null +++ b/model/models/nemotronh/mamba2.go @@ -0,0 +1,197 @@ +package nemotronh + +import ( + "log/slog" + + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" +) + +// convKernel wraps the 1D convolution kernel tensor +type convKernel struct { + Weight ml.Tensor `gguf:"weight"` +} + +// Mamba2 implements the Mamba2 SSM layer for Nemotron-H. +// The forward pass follows llama.cpp's build_mamba2_layer: +// 1. Input projection: zxBCdt = SSMIn @ hidden +// 2. Split: z, xBC, dt +// 3. Concat with conv state, apply SSMConv, save new conv state +// 4. Apply SiLU to convolved xBC +// 5. Split: x, B, C +// 6. Add dt bias +// 7. SSMScan: y = SSMScan(state, x, dt, A, B, C, ids) +// 8. D skip: y = y + x * D +// 9. Swiglu with z: y = z * silu(y) +// 10. Group RMSNorm +// 11. Output projection +type Mamba2 struct { + SSMIn *nn.Linear `gguf:"ssm_in"` // n_embd → d_in_proj (2*d_inner + 2*n_group*d_state + n_head) + SSMConv1D *convKernel `gguf:"ssm_conv1d"` // conv kernel + SSMConv1DB ml.Tensor `gguf:"ssm_conv1d.bias"` + SSMDtB ml.Tensor `gguf:"ssm_dt.bias"` // dt bias [n_head] + SSMA ml.Tensor `gguf:"ssm_a"` // A parameter [1, n_head] + SSMD ml.Tensor `gguf:"ssm_d"` // D skip connection [1, n_head] + SSMNorm *nn.RMSNorm `gguf:"ssm_norm"` // group norm + SSMOut *nn.Linear `gguf:"ssm_out"` // output projection + Layer int +} + +func (m *Mamba2) Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) { + layer := m.Layer + hiddenDim := hiddenStates.Dim(0) + nSeqTokens := hiddenStates.Dim(1) + switch hiddenStates.Dim(2) { + case 0: + hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, nSeqTokens, 1) + case 1: + default: + return nil, ErrUnsupportedBatchLayout + } + + // Nemotron-H is currently clamped to num_parallel=1. 
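+	// The clamp is enforced in server/sched.go, so a batch is expected to carry
+	// exactly one sequence here; any other layout is rejected below.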
+ if cache != nil && cache.IsSupportedForBatch() { + if cache.numSeqs() != 1 { + return nil, ErrUnsupportedBatchLayout + } + if seqTokens := cache.seqTokens(); seqTokens > 0 && nSeqTokens != seqTokens { + return nil, ErrUnsupportedBatchLayout + } + } + nSeqs := 1 + + dConv := opts.ssmDConv + dInner := opts.ssmDInner + dState := opts.ssmDState + nHead := opts.ssmNHead + headDim := dInner / nHead + nGroup := opts.ssmNGroup + + // {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} + // d_in_proj = 2*d_inner + 2*n_group*d_state + n_head + zxBCdt := m.SSMIn.Forward(ctx, hiddenStates) + + // Split into z, xBC, dt + // z: [head_dim, n_head, n_seq_tokens, n_seqs] + z := zxBCdt.Slice(ctx, 0, 0, dInner, 1) + z = z.Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs) + + // xBC: [d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs] + xBCSize := dInner + 2*nGroup*dState + xBC := zxBCdt.Slice(ctx, 0, dInner, dInner+xBCSize, 1) + if nSeqTokens == 1 { + xBC = xBC.Reshape(ctx, xBCSize, 1, nSeqs) + } + + // dt: [n_head, n_seq_tokens, n_seqs] + dt := zxBCdt.Slice(ctx, 0, 2*dInner+2*nGroup*dState, 2*dInner+2*nGroup*dState+nHead, 1) + if nSeqTokens == 1 { + dt = dt.Reshape(ctx, nHead, 1, nSeqs) + } else { + dt = dt.Contiguous(ctx, nHead, nSeqTokens, nSeqs) + } + + // Get conv state from cache + convStates, err := cache.ConvState(ctx, layer) + if err != nil { + slog.Warn("nemotronh: failed to get conv state, using zeros", "layer", layer, "error", err) + convStates = ctx.Input().Zeros(ml.DTypeF32, dConv-1, xBCSize, nSeqs) + } + + // Reshape conv states: [d_conv-1, xBCSize, n_seqs] + convStates = convStates.Reshape(ctx, dConv-1, xBCSize, nSeqs) + + // For decode (n_seq_tokens == 1), reshape avoids a transpose/contiguous pair. + var xBCT ml.Tensor + if nSeqTokens == 1 { + xBCT = xBC.Reshape(ctx, 1, xBCSize, nSeqs) + } else { + // Prefill path: [xBCSize, n_seq_tokens, n_seqs] -> [n_seq_tokens, xBCSize, n_seqs] + xBCT = xBC.Permute(ctx, 1, 0, 2, 3) + } + + // Concatenate with conv state: [d_conv-1 + n_seq_tokens, xBCSize, n_seqs] + convInput := convStates.Concat(ctx, xBCT, 0) + + // Save new conv state (last d_conv-1 columns) + lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+dConv-1, 1) + cache.UpdateConvState(ctx, layer, lastConvStates) + + // Apply SSM convolution + xBC = convInput.SSMConv(ctx, m.SSMConv1D.Weight) + + // Add conv bias + if m.SSMConv1DB != nil { + xBC = xBC.Add(ctx, m.SSMConv1DB) + } + + // Apply SiLU + xBC = xBC.SILU(ctx) + + // Split xBC into x, B, C + // x: [head_dim, n_head, n_seq_tokens, n_seqs] + x := xBC.Slice(ctx, 0, 0, dInner, 1) + x = x.Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs) + + // B: [d_state, n_group, n_seq_tokens, n_seqs] + B := xBC.Slice(ctx, 0, dInner, dInner+nGroup*dState, 1) + B = B.Reshape(ctx, dState, nGroup, nSeqTokens, nSeqs) + + // C: [d_state, n_group, n_seq_tokens, n_seqs] + C := xBC.Slice(ctx, 0, dInner+nGroup*dState, dInner+2*nGroup*dState, 1) + C = C.Reshape(ctx, dState, nGroup, nSeqTokens, nSeqs) + + // Add dt bias + dt = dt.Add(ctx, m.SSMDtB) + + // Get SSM state from cache + state, err := cache.SSMState(ctx, layer, dState, headDim, nHead) + if err != nil { + slog.Warn("nemotronh: failed to get SSM state, using zeros", "layer", layer, "error", err) + state = ctx.Input().Zeros(ml.DTypeF32, dState, headDim, nHead, nSeqs) + } + + // SSMScan + // state: [d_state, head_dim, n_head, n_seqs] + // returns: [head_dim, n_head, n_seq_tokens, n_seqs] concatenated with new state + ySsm := state.SSMScan(ctx, x, dt, m.SSMA, B, C, 
cache.slotsTensor()) + + // ySsm is a packed 1D buffer: [y (nSeqTokens*headDim*nHead*nSeqs), newState] + yElems := headDim * nHead * nSeqTokens * nSeqs + y := ySsm.View(ctx, 0, yElems).Reshape(ctx, headDim, nHead, nSeqTokens, nSeqs) + + stateOffsetBytes := yElems * x.Stride(0) + stateElems := dState * headDim * nHead * nSeqs + newState := ySsm.View(ctx, stateOffsetBytes, stateElems) + newState = newState.Reshape(ctx, dState, headDim, nHead, nSeqs) + + // Update SSM state in cache + cache.UpdateSSMState(ctx, layer, newState) + + // D skip connection: y = y + x * D + if m.SSMD != nil { + // SSMD shape: [1, n_head] -> broadcast to [head_dim, n_head, n_seq_tokens, n_seqs] + xD := x.Mul(ctx, m.SSMD) + y = y.Add(ctx, xD) + } + + // Swiglu with z: y = z * silu(y) + y = z.SILU(ctx, y) + + // Group RMSNorm + if m.SSMNorm != nil { + // Reshape for group norm: [d_inner/n_group, n_group, n_seq_tokens, n_seqs] + innerPerGroup := dInner / nGroup + y = y.Reshape(ctx, innerPerGroup, nGroup, nSeqTokens, nSeqs) + y = m.SSMNorm.Forward(ctx, y, opts.eps) + } + + // Reshape back to [d_inner, n_seq_tokens, n_seqs] + y = y.Reshape(ctx, dInner, nSeqTokens, nSeqs) + + // Output projection + out := m.SSMOut.Forward(ctx, y) + + // Reshape to 2D for consistency with attention output + return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil +} diff --git a/model/models/nemotronh/model.go b/model/models/nemotronh/model.go new file mode 100644 index 000000000..33220fe9b --- /dev/null +++ b/model/models/nemotronh/model.go @@ -0,0 +1,417 @@ +package nemotronh + +import ( + "fmt" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" + "github.com/ollama/ollama/tokenizer" +) + +// Options contains model configuration +type Options struct { + hiddenSize int + numHeads int // attention heads + numKVHeads int // KV heads for attention layers + headDim int + eps float32 + + // Mamba2 SSM config + ssmDConv int // conv kernel size + ssmDInner int // inner dimension (d_inner) + ssmDState int // state dimension + ssmNHead int // number of SSM heads (dt_rank) + ssmNGroup int // number of groups for B, C + + // Per-layer configuration + isRecurrent []bool // true = Mamba2, false = attention or FFN + nFF []int // n_ff per layer (0 = attention-only) + + // Attention scale + attentionScale float64 + + // MoE config + numExperts int + numExpertsUsed int + expertWeightsNorm bool + expertWeightsScale float32 + expertWeightsNormClip float32 +} + +func (o Options) getHeadDim() int { + if o.headDim > 0 { + return o.headDim + } + if o.numHeads <= 0 { + return 0 + } + return o.hiddenSize / o.numHeads +} + +// Operator is the interface for layer operators (Mamba2 or Attention) +type Operator interface { + Forward(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) +} + +// MLP is the interface for feedforward networks +type MLP interface { + Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor +} + +// Layer represents a single transformer layer +type Layer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + Operator Operator // Mamba2, Attention, or nil (for FFN-only layers) + MLP MLP // Dense or MoE FFN, or nil +} + +func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, outputs ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) { + residual := hiddenStates + + // Pre-layer norm + hiddenStates = 
l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + + // Layer operator (Mamba2, Attention, or FFN) + if l.Operator != nil { + var err error + hiddenStates, err = l.Operator.Forward(ctx, hiddenStates, cache, opts) + if err != nil { + return nil, err + } + } else if l.MLP != nil { + // FFN-only layer + hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts) + } + + // Output projection for last layer + if outputs != nil { + hiddenStates = hiddenStates.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + // Residual connection + return hiddenStates.Add(ctx, residual), nil +} + +// Model is the main Nemotron-H model +type Model struct { + model.Base + tokenizer.Tokenizer + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + Layers []Layer `gguf:"blk"` + + *Options +} + +// Shift is used for KV cache position shifting. +// Nemotron-H attention does not apply RoPE, so keys do not need to be transformed. +func Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return key, nil +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + + cache := m.Cache.(*HybridCache) + + for i, layer := range m.Layers { + cache.SetLayer(i) + + var outputs ml.Tensor + if i == len(m.Layers)-1 { + outputs = batch.Outputs + } + + var err error + hiddenStates, err = layer.Forward(ctx, i, hiddenStates, outputs, cache, m.Options) + if err != nil { + return nil, err + } + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) + return m.Output.Forward(ctx, hiddenStates), nil +} + +func New(c fs.Config) (model.Model, error) { + numLayers := int(c.Uint("block_count")) + layers := make([]Layer, numLayers) + + // Get per-layer configuration from GGUF metadata + // Use the same interface pattern as qwen3next + type perLayerConfig interface { + HeadCount() []uint64 + HeadCountKV() []uint64 + FFNLength() []uint64 + } + + var headCount []uint64 + var headCountKV []uint64 + var ffnLength []uint64 + + if plc, ok := c.(perLayerConfig); ok { + headCount = plc.HeadCount() + headCountKV = plc.HeadCountKV() + ffnLength = plc.FFNLength() + } + + // Build per-layer arrays with defaults + isRecurrent := make([]bool, numLayers) + nFF := make([]int, numLayers) + + for i := range numLayers { + // Get per-layer values + kvHeads := uint64(1) // Default non-zero + if i < len(headCountKV) { + kvHeads = headCountKV[i] + } + ff := uint64(0) + if i < len(ffnLength) { + ff = ffnLength[i] + } + nFF[i] = int(ff) + + // A layer is recurrent IFF n_head_kv == 0 AND n_ff == 0 + // This matches llama.cpp behavior for Nemotron-H + isRecurrent[i] = kvHeads == 0 && ff == 0 + } + + // Determine if MoE + isMoE := c.Uint("expert_count") > 0 + + for i := range layers { + if isRecurrent[i] { + // Mamba2 layer + layers[i].Operator = &Mamba2{Layer: i} + } else if nFF[i] == 0 { + // Attention-only layer (n_head_kv > 0, n_ff == 0) + layers[i].Operator = &Attention{} + } else { + // FFN layer (n_ff > 0) + if isMoE { + layers[i].MLP = &MoESparse{} + } else { + layers[i].MLP = &Dense{} + } + } + } + + // Get attention head configuration + numHeads := int(c.Uint("attention.head_count")) + if numHeads == 0 { + for i := range numLayers { + if i < len(headCount) && i < len(headCountKV) && headCount[i] > 0 && headCountKV[i] > 0 { + numHeads = int(headCount[i]) + break + } + } + } + numKVHeads := 
int(c.Uint("attention.head_count_kv")) + if numKVHeads == 0 { + for i := range numLayers { + if i < len(headCountKV) && i < len(ffnLength) && headCountKV[i] > 0 && ffnLength[i] == 0 { + numKVHeads = int(headCountKV[i]) + break + } + } + if numKVHeads == 0 { + numKVHeads = numHeads + } + } + + headDim := int(c.Uint("attention.head_dim")) + if headDim == 0 { + if keyLength := int(c.Uint("attention.key_length")); keyLength > 0 { + headDim = keyLength + } else if numHeads > 0 { + headDim = int(c.Uint("embedding_length")) / numHeads + } + } + if headDim <= 0 { + return nil, fmt.Errorf("nemotronh: invalid attention head dimension") + } + if numHeads <= 0 { + // Attention layers derive per-layer head counts from projection weights. + // Keep a non-zero default to avoid invalid option math. + numHeads = 1 + } + + numExperts := int(c.Uint("expert_count")) + numExpertsUsed := int(c.Uint("expert_used_count")) + if numExperts > 0 { + if numExpertsUsed <= 0 || numExpertsUsed > numExperts { + return nil, fmt.Errorf("nemotronh: invalid expert_used_count=%d for expert_count=%d", numExpertsUsed, numExperts) + } + } + + opts := &Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: numHeads, + numKVHeads: numKVHeads, + headDim: headDim, + eps: c.Float("attention.layer_norm_rms_epsilon"), + ssmDConv: int(c.Uint("ssm.conv_kernel")), + ssmDInner: int(c.Uint("ssm.inner_size")), + ssmDState: int(c.Uint("ssm.state_size")), + ssmNHead: int(c.Uint("ssm.time_step_rank")), + ssmNGroup: int(c.Uint("ssm.group_count")), + isRecurrent: isRecurrent, + nFF: nFF, + attentionScale: float64(c.Float("attention.scale")), + numExperts: numExperts, + numExpertsUsed: numExpertsUsed, + expertWeightsNorm: c.Bool("expert_weights_norm", false), + expertWeightsScale: c.Float("expert_weights_scale", 1.0), + expertWeightsNormClip: c.Float("expert_weights_norm_clip", 0), + } + + // Calculate cache dimensions + convDim := max(0, opts.ssmDConv-1) + convChannels := opts.ssmDInner + 2*opts.ssmNGroup*opts.ssmDState + ssmHeadDim := 0 + if opts.ssmNHead > 0 { + ssmHeadDim = opts.ssmDInner / opts.ssmNHead + } + ssmStateSize := opts.ssmDState * ssmHeadDim * opts.ssmNHead + + m := Model{ + Tokenizer: tokenizer.NewBytePairEncoding( + &tokenizer.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, + ), + Layers: layers, + Options: opts, + } + + m.Cache = NewHybridCache(convDim, convChannels, ssmStateSize) + return &m, nil +} + +func init() { + model.Register("nemotron_h", New) + model.Register("nemotron_h_moe", New) +} + +// Ensure Model implements model.Model +var _ model.Model = (*Model)(nil) + +// Dense implements standard feedforward with ReLU-squared activation +type Dense struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (d *Dense) Forward(ctx ml.Context, x ml.Tensor, opts *Options) ml.Tensor { + // up -> ReLU-squared -> down + up := d.Up.Forward(ctx, x) + up = up.RELU(ctx) + up = up.Mul(ctx, up) // Square + return d.Down.Forward(ctx, up) +} + +// MoESparse implements MoE with shared 
experts for Nemotron-H-MoE +type MoESparse struct { + Router *nn.Linear `gguf:"ffn_gate_inp"` + Up *nn.LinearBatch `gguf:"ffn_up_exps"` + Down *nn.LinearBatch `gguf:"ffn_down_exps"` + Bias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"` + + LatentIn *nn.Linear `gguf:"ffn_latent_in"` + LatentOut *nn.Linear `gguf:"ffn_latent_out"` + + // Shared experts + SharedUp *nn.Linear `gguf:"ffn_up_shexp"` + SharedDown *nn.Linear `gguf:"ffn_down_shexp"` +} + +func (m *MoESparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + hiddenDim := hiddenStates.Dim(0) + seqLen := hiddenStates.Dim(1) + batchSize := hiddenStates.Dim(2) + if batchSize == 0 { + batchSize = 1 + } + hiddenStates2D := hiddenStates.Reshape(ctx, hiddenDim, seqLen*batchSize) + + // Router logits with sigmoid gating + routerLogits := m.Router.Forward(ctx, hiddenStates2D) + + // Weights come from unbiased sigmoid probabilities. + probs := routerLogits.Sigmoid(ctx) + + // Selection uses optional bias. + selectionProbs := probs + if m.Bias != nil { + selectionProbs = selectionProbs.Add(ctx, m.Bias) + } + + // Select top-k experts + selectedExperts := selectionProbs.TopK(ctx, opts.numExpertsUsed) + routingWeights := probs.Reshape(ctx, 1, opts.numExperts, hiddenStates2D.Dim(1)).Rows(ctx, selectedExperts) + + // Normalize routing weights + if opts.expertWeightsNorm { + routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates2D.Dim(1)) + weightsSum := routingWeights.SumRows(ctx) + weightsSum = weightsSum.Clamp(ctx, 6.103515625e-5, float32(math.MaxFloat32)) + routingWeights = routingWeights.Div(ctx, weightsSum) + routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates2D.Dim(1)) + } + + // Scale routing weights + if opts.expertWeightsScale != 1.0 { + routingWeights = routingWeights.Scale(ctx, float64(opts.expertWeightsScale)) + } + + routedInput := hiddenStates2D + if m.LatentIn != nil { + routedInput = m.LatentIn.Forward(ctx, routedInput) + } + hiddenStates3D := routedInput.Reshape(ctx, routedInput.Dim(0), 1, routedInput.Dim(1)) + + // Expert computation with ReLU-squared activation + upOut := m.Up.Forward(ctx, hiddenStates3D, selectedExperts) + upOut = upOut.RELU(ctx) + upOut = upOut.Mul(ctx, upOut) // Square + experts := m.Down.Forward(ctx, upOut, selectedExperts) + experts = experts.Mul(ctx, routingWeights) + + // Sum over experts + moeOut := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) + for i := 1; i < opts.numExpertsUsed; i++ { + moeOut = moeOut.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2))) + } + if m.LatentOut != nil { + moeOut = m.LatentOut.Forward(ctx, moeOut) + } + + // Add shared experts if present + if m.SharedUp != nil { + sharedUp := m.SharedUp.Forward(ctx, hiddenStates2D) + sharedUp = sharedUp.RELU(ctx) + sharedUp = sharedUp.Mul(ctx, sharedUp) // Square + sharedOut := m.SharedDown.Forward(ctx, sharedUp) + moeOut = moeOut.Add(ctx, sharedOut) + } + + return moeOut +} diff --git a/server/sched.go b/server/sched.go index 0049b87ce..ebaeda699 100644 --- a/server/sched.go +++ b/server/sched.go @@ -447,7 +447,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo // Some architectures are not safe with num_parallel > 1. 
// ref: https://github.com/ollama/ollama/issues/4165 - if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe"}, req.model.Config.ModelFamily) && numParallel != 1 { + if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe"}, req.model.Config.ModelFamily) && numParallel != 1 { numParallel = 1 slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily) }