qwen3next: fix issue in delta net (#14075)

gDiffExp was being broadcast across the wrong axis when multiplied with k. This fix reshapes gDiffExp to [1, chunkSize, nChunks, numVHeads*nSeqs] before the multiplication.
This commit is contained in:
Jeffrey Morgan
2026-02-04 13:40:38 -08:00
committed by GitHub
parent f7102ba826
commit 255579aaa7

View File

@@ -406,8 +406,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
gDiff := gCumsum.Neg(ctx).Add(ctx, gLast)
gDiffExp := gDiff.Exp(ctx)
// key_gdiff = k * exp(g_diff)
keyGDiff := k.Mul(ctx, gDiffExp)
// Reshape g_diff_exp to [1, chunkSize, nChunks, numVHeads*nSeqs] before multiplying with k
gDiffExpReshaped := gDiffExp.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
keyGDiff := k.Mul(ctx, gDiffExpReshaped)
keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Process chunks and update state
var coreAttnOut ml.Tensor
@@ -444,12 +446,9 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
}
// Update state for next chunk using pre-computed values
// Update state for next chunk
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
kGDiffChunk := keyGDiff.Slice(ctx, 2, chunk, chunk+1, 1)
// kgdmulvnew = key_gdiff^T @ v_new
kGDiffChunkT := kGDiffChunk.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
// state = state * g_last + kgdmulvnew