qwen3next: add compatibility with imported GGUF models (#14517)

This commit is contained in:
Jeffrey Morgan
2026-02-28 14:21:42 -08:00
committed by GitHub
parent a60b9adcce
commit 8da09b1e7e
4 changed files with 227 additions and 19 deletions

View File

@@ -41,8 +41,8 @@ type GatedDeltaNet struct {
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
SSMDT ml.Tensor `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
SSMOut *nn.Linear `gguf:"ssm_out"`
@@ -135,6 +135,18 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
default:
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
}
if gdn.SSMDT == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
}
if gdn.SSMA == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
}
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
}
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
}
// Compute gate: softplus(alpha + dt_bias) * -A
alphaBiased := alpha.Add(ctx, gdn.SSMDT)

View File

@@ -437,6 +437,46 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
return m.Output.Forward(ctx, hiddenStates), nil
}
func (m *Model) Validate() error {
if m.Options == nil {
return fmt.Errorf("qwen3next: missing model options")
}
if len(m.Layers) != len(m.Options.isRecurrent) {
return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent))
}
for i, layer := range m.Layers {
if !m.Options.isRecurrent[i] {
continue
}
gdn, ok := layer.Operator.(*GatedDeltaNet)
if !ok || gdn == nil {
return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
}
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
}
if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i)
}
if gdn.SSMDT == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i)
}
if gdn.SSMA == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i)
}
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i)
}
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i)
}
}
return nil
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
m.positionCache = nil
if len(m.mropeSections) > 0 {
@@ -450,6 +490,64 @@ var (
_ model.MultimodalProcessor = (*Model)(nil)
)
func defaultVHeadReordered(arch string) bool {
return arch == "qwen35" || arch == "qwen35moe"
}
func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) {
isRecurrent := make([]bool, numLayers)
hasZero := false
hasFull := false
for i := range numLayers {
if i >= len(headCountKV) {
continue
}
if headCountKV[i] == 0 {
isRecurrent[i] = true
hasZero = true
} else {
hasFull = true
}
}
if hasZero && hasFull {
return isRecurrent, nil
}
if !hasFull {
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
}
// Compatibility path: older imports store a scalar KV head count and omit
// per-layer recurrent flags. Derive the hybrid layout from the interval.
interval := int(fullAttentionInterval)
if interval == 0 {
interval = min(4, numLayers)
}
if interval <= 0 {
return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
}
if interval > numLayers {
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers)
}
hasZero = false
hasFull = false
for i := range numLayers {
isRecurrent[i] = (i+1)%interval != 0
if isRecurrent[i] {
hasZero = true
} else {
hasFull = true
}
}
if !hasZero || !hasFull {
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval)
}
return isRecurrent, nil
}
func New(c fs.Config) (model.Model, error) {
numLayers := int(c.Uint("block_count"))
layers := make([]Layer, numLayers)
@@ -460,26 +558,14 @@ func New(c fs.Config) (model.Model, error) {
HeadCountKV() []uint64
}
var isRecurrent []bool
var headCountKV []uint64
if hc, ok := c.(headCounts); ok {
headCountKV = hc.HeadCountKV()
}
isRecurrent = make([]bool, numLayers)
hasZero := false
hasFull := false
for i := range numLayers {
// If KV head count is 0, it's a recurrent layer
if i < len(headCountKV) && headCountKV[i] == 0 {
isRecurrent[i] = true
hasZero = true
} else if i < len(headCountKV) && headCountKV[i] > 0 {
hasFull = true
}
}
if !hasZero || !hasFull {
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval"))
if err != nil {
return nil, err
}
// Determine if MoE
@@ -543,7 +629,7 @@ func New(c fs.Config) (model.Model, error) {
ssmNGroup: int(c.Uint("ssm.group_count")),
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
convKernelSize: int(c.Uint("ssm.conv_kernel")),
vHeadReordered: c.Bool("ssm.v_head_reordered", false),
vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())),
isRecurrent: isRecurrent,
mropeSections: slices.Collect(func(yield func(int) bool) {
for _, section := range mropeSections {
@@ -555,7 +641,7 @@ func New(c fs.Config) (model.Model, error) {
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
}
if opts.numKVHeads == 0 {
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
}
// Calculate cache dimensions

View File

@@ -0,0 +1,65 @@
package qwen3next
import (
"slices"
"strings"
"testing"
)
func TestInferRecurrentLayersMixedKVArray(t *testing.T) {
got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0)
if err != nil {
t.Fatalf("inferRecurrentLayers() error = %v", err)
}
want := []bool{true, false, true, false}
if !slices.Equal(got, want) {
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
}
}
func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) {
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
if err != nil {
t.Fatalf("inferRecurrentLayers() error = %v", err)
}
want := []bool{true, true, true, false, true, true, true, false}
if !slices.Equal(got, want) {
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
}
}
func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) {
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3)
if err != nil {
t.Fatalf("inferRecurrentLayers() error = %v", err)
}
want := []bool{true, true, false, true, true, false}
if !slices.Equal(got, want) {
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
}
}
func TestInferRecurrentLayersAllZeroRejects(t *testing.T) {
_, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0)
if err == nil {
t.Fatal("inferRecurrentLayers() expected error, got nil")
}
if !strings.Contains(err.Error(), "must include at least one non-zero value") {
t.Fatalf("unexpected error = %v", err)
}
}
func TestDefaultVHeadReordered(t *testing.T) {
if !defaultVHeadReordered("qwen35") {
t.Fatal("defaultVHeadReordered(qwen35) = false, want true")
}
if !defaultVHeadReordered("qwen35moe") {
t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true")
}
if defaultVHeadReordered("qwen3next") {
t.Fatal("defaultVHeadReordered(qwen3next) = true, want false")
}
}

View File

@@ -0,0 +1,45 @@
package qwen3next
import (
"strings"
"testing"
"github.com/ollama/ollama/ml/nn"
)
func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
m := &Model{
Layers: []Layer{{
Operator: &GatedDeltaNet{
SSMQKV: &nn.Linear{},
SSMQKVGate: &nn.Linear{},
SSMBeta: &nn.Linear{},
SSMAlpha: &nn.Linear{},
},
}},
Options: &Options{
isRecurrent: []bool{true},
},
}
err := m.Validate()
if err == nil {
t.Fatal("Validate() expected error, got nil")
}
if !strings.Contains(err.Error(), "missing ssm_dt") {
t.Fatalf("unexpected error = %v", err)
}
}
func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
m := &Model{
Layers: []Layer{{Operator: &FullAttention{}}},
Options: &Options{
isRecurrent: []bool{false},
},
}
if err := m.Validate(); err != nil {
t.Fatalf("Validate() error = %v", err)
}
}