From 8da09b1e7e7394a818bc0f36d4244a927a91c126 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Sat, 28 Feb 2026 14:21:42 -0800 Subject: [PATCH] qwen3next: add compatibility with imported GGUF models (#14517) --- model/models/qwen3next/deltanet.go | 16 ++- model/models/qwen3next/model.go | 120 +++++++++++++++--- model/models/qwen3next/model_new_test.go | 65 ++++++++++ model/models/qwen3next/model_validate_test.go | 45 +++++++ 4 files changed, 227 insertions(+), 19 deletions(-) create mode 100644 model/models/qwen3next/model_new_test.go create mode 100644 model/models/qwen3next/model_validate_test.go diff --git a/model/models/qwen3next/deltanet.go b/model/models/qwen3next/deltanet.go index d928efc98..6ce315649 100644 --- a/model/models/qwen3next/deltanet.go +++ b/model/models/qwen3next/deltanet.go @@ -41,8 +41,8 @@ type GatedDeltaNet struct { SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35) SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35) SSMConv1D *convKernel `gguf:"ssm_conv1d"` - SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias - SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp() + SSMDT ml.Tensor `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias + SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp() SSMNorm *nn.RMSNorm `gguf:"ssm_norm"` SSMOut *nn.Linear `gguf:"ssm_out"` @@ -135,6 +135,18 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac default: return nil, errors.New("qwen3next: missing linear attention beta/alpha projections") } + if gdn.SSMDT == nil { + return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor") + } + if gdn.SSMA == nil { + return nil, errors.New("qwen3next: missing linear attention ssm_a tensor") + } + if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil { + return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor") + } + if gdn.SSMNorm == nil || gdn.SSMOut == nil { + return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections") + } // Compute gate: softplus(alpha + dt_bias) * -A alphaBiased := alpha.Add(ctx, gdn.SSMDT) diff --git a/model/models/qwen3next/model.go b/model/models/qwen3next/model.go index 7611362f1..9681efda3 100644 --- a/model/models/qwen3next/model.go +++ b/model/models/qwen3next/model.go @@ -437,6 +437,46 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { return m.Output.Forward(ctx, hiddenStates), nil } +func (m *Model) Validate() error { + if m.Options == nil { + return fmt.Errorf("qwen3next: missing model options") + } + if len(m.Layers) != len(m.Options.isRecurrent) { + return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent)) + } + + for i, layer := range m.Layers { + if !m.Options.isRecurrent[i] { + continue + } + + gdn, ok := layer.Operator.(*GatedDeltaNet) + if !ok || gdn == nil { + return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i) + } + if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil { + return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i) + } + if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) { + return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i) + } + if gdn.SSMDT == nil { + return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i) + } + if gdn.SSMA == nil { + return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i) + } + if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil { + return 
fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i) + } + if gdn.SSMNorm == nil || gdn.SSMOut == nil { + return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i) + } + } + + return nil +} + func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { m.positionCache = nil if len(m.mropeSections) > 0 { @@ -450,6 +490,64 @@ var ( _ model.MultimodalProcessor = (*Model)(nil) ) +func defaultVHeadReordered(arch string) bool { + return arch == "qwen35" || arch == "qwen35moe" +} + +func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) { + isRecurrent := make([]bool, numLayers) + + hasZero := false + hasFull := false + for i := range numLayers { + if i >= len(headCountKV) { + continue + } + + if headCountKV[i] == 0 { + isRecurrent[i] = true + hasZero = true + } else { + hasFull = true + } + } + if hasZero && hasFull { + return isRecurrent, nil + } + if !hasFull { + return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value") + } + + // Compatibility path: older imports store a scalar KV head count and omit + // per-layer recurrent flags. Derive the hybrid layout from the interval. + interval := int(fullAttentionInterval) + if interval == 0 { + interval = min(4, numLayers) + } + if interval <= 0 { + return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers) + } + if interval > numLayers { + return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers) + } + + hasZero = false + hasFull = false + for i := range numLayers { + isRecurrent[i] = (i+1)%interval != 0 + if isRecurrent[i] { + hasZero = true + } else { + hasFull = true + } + } + if !hasZero || !hasFull { + return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval) + } + + return isRecurrent, nil +} + func New(c fs.Config) (model.Model, error) { numLayers := int(c.Uint("block_count")) layers := make([]Layer, numLayers) @@ -460,26 +558,14 @@ func New(c fs.Config) (model.Model, error) { HeadCountKV() []uint64 } - var isRecurrent []bool var headCountKV []uint64 if hc, ok := c.(headCounts); ok { headCountKV = hc.HeadCountKV() } - isRecurrent = make([]bool, numLayers) - hasZero := false - hasFull := false - for i := range numLayers { - // If KV head count is 0, it's a recurrent layer - if i < len(headCountKV) && headCountKV[i] == 0 { - isRecurrent[i] = true - hasZero = true - } else if i < len(headCountKV) && headCountKV[i] > 0 { - hasFull = true - } - } - if !hasZero || !hasFull { - return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values") + isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval")) + if err != nil { + return nil, err } // Determine if MoE @@ -543,7 +629,7 @@ func New(c fs.Config) (model.Model, error) { ssmNGroup: int(c.Uint("ssm.group_count")), ssmDtRank: int(c.Uint("ssm.time_step_rank")), convKernelSize: int(c.Uint("ssm.conv_kernel")), - vHeadReordered: c.Bool("ssm.v_head_reordered", false), + vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())), isRecurrent: isRecurrent, mropeSections: slices.Collect(func(yield func(int) bool) { for _, section := range mropeSections { @@ -555,7 +641,7 @@ func New(c fs.Config) (model.Model, error) { mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", 
false)), } if opts.numKVHeads == 0 { - return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value") + return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value") } // Calculate cache dimensions diff --git a/model/models/qwen3next/model_new_test.go b/model/models/qwen3next/model_new_test.go new file mode 100644 index 000000000..409f01732 --- /dev/null +++ b/model/models/qwen3next/model_new_test.go @@ -0,0 +1,65 @@ +package qwen3next + +import ( + "slices" + "strings" + "testing" +) + +func TestInferRecurrentLayersMixedKVArray(t *testing.T) { + got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0) + if err != nil { + t.Fatalf("inferRecurrentLayers() error = %v", err) + } + + want := []bool{true, false, true, false} + if !slices.Equal(got, want) { + t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want) + } +} + +func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) { + got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0) + if err != nil { + t.Fatalf("inferRecurrentLayers() error = %v", err) + } + + want := []bool{true, true, true, false, true, true, true, false} + if !slices.Equal(got, want) { + t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want) + } +} + +func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) { + got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3) + if err != nil { + t.Fatalf("inferRecurrentLayers() error = %v", err) + } + + want := []bool{true, true, false, true, true, false} + if !slices.Equal(got, want) { + t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want) + } +} + +func TestInferRecurrentLayersAllZeroRejects(t *testing.T) { + _, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0) + if err == nil { + t.Fatal("inferRecurrentLayers() expected error, got nil") + } + if !strings.Contains(err.Error(), "must include at least one non-zero value") { + t.Fatalf("unexpected error = %v", err) + } +} + +func TestDefaultVHeadReordered(t *testing.T) { + if !defaultVHeadReordered("qwen35") { + t.Fatal("defaultVHeadReordered(qwen35) = false, want true") + } + if !defaultVHeadReordered("qwen35moe") { + t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true") + } + if defaultVHeadReordered("qwen3next") { + t.Fatal("defaultVHeadReordered(qwen3next) = true, want false") + } +} diff --git a/model/models/qwen3next/model_validate_test.go b/model/models/qwen3next/model_validate_test.go new file mode 100644 index 000000000..cf2c5f7e9 --- /dev/null +++ b/model/models/qwen3next/model_validate_test.go @@ -0,0 +1,45 @@ +package qwen3next + +import ( + "strings" + "testing" + + "github.com/ollama/ollama/ml/nn" +) + +func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) { + m := &Model{ + Layers: []Layer{{ + Operator: &GatedDeltaNet{ + SSMQKV: &nn.Linear{}, + SSMQKVGate: &nn.Linear{}, + SSMBeta: &nn.Linear{}, + SSMAlpha: &nn.Linear{}, + }, + }}, + Options: &Options{ + isRecurrent: []bool{true}, + }, + } + + err := m.Validate() + if err == nil { + t.Fatal("Validate() expected error, got nil") + } + if !strings.Contains(err.Error(), "missing ssm_dt") { + t.Fatalf("unexpected error = %v", err) + } +} + +func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) { + m := &Model{ + Layers: []Layer{{Operator: &FullAttention{}}}, + Options: &Options{ + isRecurrent: []bool{false}, + }, + } + + if err := m.Validate(); err != nil { + t.Fatalf("Validate() error = %v", err) + } +}
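
A note on the fallback at realistic scale: the tests above exercise inferRecurrentLayers only on small synthetic layouts. As a sketch (not part of this patch — the test name and inputs are illustrative assumptions), the same fallback applied to a 48-layer import that stores a scalar KV head count should reproduce the three-recurrent-to-one-full interleaving that published Qwen3-Next checkpoints reportedly use:

package qwen3next

import (
	"slices"
	"testing"
)

// Illustrative sketch, not part of the patch: a 48-layer import whose
// attention.head_count_kv is a repeated scalar (every entry non-zero) and
// whose full_attention_interval is unset should fall back to interval 4,
// marking every fourth layer as full attention and the rest as recurrent.
func TestInferRecurrentLayersScalarKV48LayerSketch(t *testing.T) {
	layout, err := inferRecurrentLayers(slices.Repeat([]uint64{2}, 48), 48, 0)
	if err != nil {
		t.Fatalf("inferRecurrentLayers() error = %v", err)
	}

	for i, recurrent := range layout {
		// inferRecurrentLayers sets isRecurrent[i] = (i+1)%interval != 0,
		// so with the default interval of 4 every fourth layer is full.
		wantRecurrent := (i+1)%4 != 0
		if recurrent != wantRecurrent {
			t.Fatalf("layer %d: recurrent = %v, want %v", i, recurrent, wantRecurrent)
		}
	}
}

Because the compatibility path defaults the interval to min(4, numLayers), an older import with no per-layer recurrent flags and no full_attention_interval metadata lands on this layout without any extra configuration.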