qwen3next: add compatibility with imported GGUF models (#14517)
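In brief: the GGUF tag for SSMDT gains an alternate tensor name (ssm_dt.bias) so imported checkpoints resolve it; GatedDeltaNet.Forward and a new Model.Validate report missing SSM tensors with descriptive errors instead of panicking; recurrent/full-attention layout detection moves into inferRecurrentLayers, which falls back to a full_attention_interval-derived layout when attention.head_count_kv has no per-layer zeros; and ssm.v_head_reordered now defaults per architecture via defaultVHeadReordered. Two new test files cover the helpers.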
@@ -41,7 +41,7 @@ type GatedDeltaNet struct {
     SSMBeta   *nn.Linear  `gguf:"ssm_beta"`  // -> beta (qwen35)
     SSMAlpha  *nn.Linear  `gguf:"ssm_alpha"` // -> alpha (qwen35)
     SSMConv1D *convKernel `gguf:"ssm_conv1d"`
-    SSMDT     ml.Tensor   `gguf:"ssm_dt"` // alpha bias
+    SSMDT     ml.Tensor   `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias
     SSMA      ml.Tensor   `gguf:"ssm_a"` // -A_log.exp()
     SSMNorm   *nn.RMSNorm `gguf:"ssm_norm"`
     SSMOut    *nn.Linear  `gguf:"ssm_out"`
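The schema change above is the alt:ssm_dt.bias entry in the struct tag, which gives the loader a second tensor name to try when resolving SSMDT, since imported GGUF files store this tensor as ssm_dt.bias. A minimal sketch of the primary/alternate lookup such a tag expresses (the helper and map below are illustrative assumptions, not ollama's actual loader API):

    // resolveTensor is a hypothetical stand-in for the loader's tag handling:
    // it tries each candidate name in order and returns the first match.
    func resolveTensor(tensors map[string][]float32, names ...string) ([]float32, bool) {
        for _, name := range names {
            if t, ok := tensors[name]; ok {
                return t, true // e.g. try "ssm_dt" first, then "ssm_dt.bias"
            }
        }
        return nil, false
    }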
@@ -135,6 +135,18 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
     default:
         return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
     }
+    if gdn.SSMDT == nil {
+        return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
+    }
+    if gdn.SSMA == nil {
+        return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
+    }
+    if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
+        return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
+    }
+    if gdn.SSMNorm == nil || gdn.SSMOut == nil {
+        return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
+    }

     // Compute gate: softplus(alpha + dt_bias) * -A
     alphaBiased := alpha.Add(ctx, gdn.SSMDT)
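The new guards run before Forward first dereferences the SSM tensors, so a partially imported GGUF now fails with a named-tensor error rather than a nil-pointer panic. For reference, the gate the guarded code then computes is the formula in the comment; a scalar sketch in plain Go (the real code operates on ml.Tensor), recalling that SSMA already stores -exp(A_log):

    package main

    import (
        "fmt"
        "math"
    )

    // gate = softplus(alpha + dtBias) * negA, where negA is the SSMA value (-exp(A_log)).
    // softplus(x) = log(1 + exp(x)) >= 0 and negA <= 0, so the gate is always <= 0.
    func gate(alpha, dtBias, negA float64) float64 {
        return math.Log1p(math.Exp(alpha+dtBias)) * negA
    }

    func main() {
        fmt.Println(gate(0.5, 0.1, -1.2)) // softplus(0.6) * -1.2 ≈ -1.245
    }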
@@ -437,6 +437,46 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
     return m.Output.Forward(ctx, hiddenStates), nil
 }

+func (m *Model) Validate() error {
+    if m.Options == nil {
+        return fmt.Errorf("qwen3next: missing model options")
+    }
+    if len(m.Layers) != len(m.Options.isRecurrent) {
+        return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent))
+    }
+
+    for i, layer := range m.Layers {
+        if !m.Options.isRecurrent[i] {
+            continue
+        }
+
+        gdn, ok := layer.Operator.(*GatedDeltaNet)
+        if !ok || gdn == nil {
+            return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
+        }
+        if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
+            return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
+        }
+        if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
+            return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i)
+        }
+        if gdn.SSMDT == nil {
+            return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i)
+        }
+        if gdn.SSMA == nil {
+            return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i)
+        }
+        if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
+            return fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i)
+        }
+        if gdn.SSMNorm == nil || gdn.SSMOut == nil {
+            return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i)
+        }
+    }
+
+    return nil
+}
+
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
     m.positionCache = nil
     if len(m.mropeSections) > 0 {
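Validate performs the same per-layer checks as the Forward guards, but eagerly and across every recurrent layer, so a bad import can be rejected at load time. A hypothetical call site inside the package (the helper, its error wrapping, and the assumed fs/model/fmt imports are illustrative, not part of this commit):

    // loadAndValidate is an illustrative helper, not code from this change.
    func loadAndValidate(c fs.Config) (model.Model, error) {
        m, err := New(c)
        if err != nil {
            return nil, err
        }
        if qm, ok := m.(*Model); ok {
            if err := qm.Validate(); err != nil {
                return nil, fmt.Errorf("rejecting GGUF import: %w", err)
            }
        }
        return m, nil
    }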
@@ -450,6 +490,64 @@ var (
     _ model.MultimodalProcessor = (*Model)(nil)
 )

+func defaultVHeadReordered(arch string) bool {
+    return arch == "qwen35" || arch == "qwen35moe"
+}
+
+func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) {
+    isRecurrent := make([]bool, numLayers)
+
+    hasZero := false
+    hasFull := false
+    for i := range numLayers {
+        if i >= len(headCountKV) {
+            continue
+        }
+
+        if headCountKV[i] == 0 {
+            isRecurrent[i] = true
+            hasZero = true
+        } else {
+            hasFull = true
+        }
+    }
+    if hasZero && hasFull {
+        return isRecurrent, nil
+    }
+    if !hasFull {
+        return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
+    }
+
+    // Compatibility path: older imports store a scalar KV head count and omit
+    // per-layer recurrent flags. Derive the hybrid layout from the interval.
+    interval := int(fullAttentionInterval)
+    if interval == 0 {
+        interval = min(4, numLayers)
+    }
+    if interval <= 0 {
+        return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
+    }
+    if interval > numLayers {
+        return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers)
+    }
+
+    hasZero = false
+    hasFull = false
+    for i := range numLayers {
+        isRecurrent[i] = (i+1)%interval != 0
+        if isRecurrent[i] {
+            hasZero = true
+        } else {
+            hasFull = true
+        }
+    }
+    if !hasZero || !hasFull {
+        return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval)
+    }
+
+    return isRecurrent, nil
+}
+
 func New(c fs.Config) (model.Model, error) {
     numLayers := int(c.Uint("block_count"))
     layers := make([]Layer, numLayers)
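With the default interval of 4, the compatibility path marks every fourth layer as full attention and the rest as recurrent, matching the hybrid layout the per-layer flags would have encoded. A worked example mirroring the tests added below:

    // Scalar KV head counts (no per-layer zeros) and no configured interval:
    layout, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
    // err == nil
    // layout == []bool{true, true, true, false, true, true, true, false}
    // layers 3 and 7 ((i+1)%4 == 0) keep full attention; the others are recurrent.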
@@ -460,26 +558,14 @@ func New(c fs.Config) (model.Model, error) {
         HeadCountKV() []uint64
     }

-    var isRecurrent []bool
     var headCountKV []uint64
     if hc, ok := c.(headCounts); ok {
         headCountKV = hc.HeadCountKV()
     }

-    isRecurrent = make([]bool, numLayers)
-    hasZero := false
-    hasFull := false
-    for i := range numLayers {
-        // If KV head count is 0, it's a recurrent layer
-        if i < len(headCountKV) && headCountKV[i] == 0 {
-            isRecurrent[i] = true
-            hasZero = true
-        } else if i < len(headCountKV) && headCountKV[i] > 0 {
-            hasFull = true
-        }
-    }
-    if !hasZero || !hasFull {
-        return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
+    isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval"))
+    if err != nil {
+        return nil, err
     }

     // Determine if MoE
@@ -543,7 +629,7 @@ func New(c fs.Config) (model.Model, error) {
         ssmNGroup:      int(c.Uint("ssm.group_count")),
         ssmDtRank:      int(c.Uint("ssm.time_step_rank")),
         convKernelSize: int(c.Uint("ssm.conv_kernel")),
-        vHeadReordered: c.Bool("ssm.v_head_reordered", false),
+        vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())),
         isRecurrent:    isRecurrent,
         mropeSections: slices.Collect(func(yield func(int) bool) {
             for _, section := range mropeSections {
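Only the fallback changes here: qwen35 and qwen35moe imports that omit ssm.v_head_reordered now default to true, while qwen3next keeps false, and an explicit key in the GGUF still wins (assuming the usual c.Bool stored-value-or-default contract, which the replaced `false` fallback suggests):

    fallback := defaultVHeadReordered(c.Architecture()) // true only for qwen35/qwen35moe
    v := c.Bool("ssm.v_head_reordered", fallback)       // an explicit key overrides the fallback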
@@ -555,7 +641,7 @@ func New(c fs.Config) (model.Model, error) {
         mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
     }
     if opts.numKVHeads == 0 {
-        return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
+        return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
     }

     // Calculate cache dimensions
model/models/qwen3next/model_new_test.go (new file, 65 lines)
@@ -0,0 +1,65 @@
+package qwen3next
+
+import (
+    "slices"
+    "strings"
+    "testing"
+)
+
+func TestInferRecurrentLayersMixedKVArray(t *testing.T) {
+    got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0)
+    if err != nil {
+        t.Fatalf("inferRecurrentLayers() error = %v", err)
+    }
+
+    want := []bool{true, false, true, false}
+    if !slices.Equal(got, want) {
+        t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
+    }
+}
+
+func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) {
+    got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
+    if err != nil {
+        t.Fatalf("inferRecurrentLayers() error = %v", err)
+    }
+
+    want := []bool{true, true, true, false, true, true, true, false}
+    if !slices.Equal(got, want) {
+        t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
+    }
+}
+
+func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) {
+    got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3)
+    if err != nil {
+        t.Fatalf("inferRecurrentLayers() error = %v", err)
+    }
+
+    want := []bool{true, true, false, true, true, false}
+    if !slices.Equal(got, want) {
+        t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
+    }
+}
+
+func TestInferRecurrentLayersAllZeroRejects(t *testing.T) {
+    _, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0)
+    if err == nil {
+        t.Fatal("inferRecurrentLayers() expected error, got nil")
+    }
+    if !strings.Contains(err.Error(), "must include at least one non-zero value") {
+        t.Fatalf("unexpected error = %v", err)
+    }
+}
+
+func TestDefaultVHeadReordered(t *testing.T) {
+    if !defaultVHeadReordered("qwen35") {
+        t.Fatal("defaultVHeadReordered(qwen35) = false, want true")
+    }
+    if !defaultVHeadReordered("qwen35moe") {
+        t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true")
+    }
+    if defaultVHeadReordered("qwen3next") {
+        t.Fatal("defaultVHeadReordered(qwen3next) = true, want false")
+    }
+}
model/models/qwen3next/model_validate_test.go (new file, 45 lines)
@@ -0,0 +1,45 @@
+package qwen3next
+
+import (
+    "strings"
+    "testing"
+
+    "github.com/ollama/ollama/ml/nn"
+)
+
+func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
+    m := &Model{
+        Layers: []Layer{{
+            Operator: &GatedDeltaNet{
+                SSMQKV:     &nn.Linear{},
+                SSMQKVGate: &nn.Linear{},
+                SSMBeta:    &nn.Linear{},
+                SSMAlpha:   &nn.Linear{},
+            },
+        }},
+        Options: &Options{
+            isRecurrent: []bool{true},
+        },
+    }
+
+    err := m.Validate()
+    if err == nil {
+        t.Fatal("Validate() expected error, got nil")
+    }
+    if !strings.Contains(err.Error(), "missing ssm_dt") {
+        t.Fatalf("unexpected error = %v", err)
+    }
+}
+
+func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
+    m := &Model{
+        Layers: []Layer{{Operator: &FullAttention{}}},
+        Options: &Options{
+            isRecurrent: []bool{false},
+        },
+    }
+
+    if err := m.Validate(); err != nil {
+        t.Fatalf("Validate() error = %v", err)
+    }
+}