mirror of
https://github.com/ollama/ollama.git
synced 2026-03-09 03:12:11 -05:00
qwen3next: fix issue in delta net (#14075)
gDiffExp was being broadcast across the wrong axis when multiplying with k. This fix reshapes gDiffExp to [1, chunkSize, nChunks, ...]
This commit is contained in:
@@ -406,8 +406,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
gDiff := gCumsum.Neg(ctx).Add(ctx, gLast)
|
||||
gDiffExp := gDiff.Exp(ctx)
|
||||
|
||||
// key_gdiff = k * exp(g_diff)
|
||||
keyGDiff := k.Mul(ctx, gDiffExp)
|
||||
// Reshapes g_diff_exp to [1, chunkSize, nChunks, ...]
|
||||
gDiffExpReshaped := gDiffExp.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
|
||||
keyGDiff := k.Mul(ctx, gDiffExpReshaped)
|
||||
keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// Process chunks and update state
|
||||
var coreAttnOut ml.Tensor
|
||||
@@ -444,12 +446,9 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
|
||||
}
|
||||
|
||||
// Update state for next chunk using pre-computed values
|
||||
// Update state for next chunk
|
||||
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
kGDiffChunk := keyGDiff.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
|
||||
// kgdmulvnew = key_gdiff^T @ v_new
|
||||
kGDiffChunkT := kGDiffChunk.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
|
||||
|
||||
// state = state * g_last + kgdmulvnew
|
||||
|
||||
Reference in New Issue
Block a user