mirror of
https://github.com/ollama/ollama.git
synced 2026-04-30 07:57:51 -05:00
qwen3next: fix issue in delta net (#14075)
gDiffExp was being broadcast across the wrong axis when multiplying by k. This fix reshapes gDiffExp to [1, chunkSize, nChunks, ...].
This commit is contained in:
@@ -406,8 +406,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
|||||||
gDiff := gCumsum.Neg(ctx).Add(ctx, gLast)
|
gDiff := gCumsum.Neg(ctx).Add(ctx, gLast)
|
||||||
gDiffExp := gDiff.Exp(ctx)
|
gDiffExp := gDiff.Exp(ctx)
|
||||||
|
|
||||||
// key_gdiff = k * exp(g_diff)
|
// Reshapes g_diff_exp to [1, chunkSize, nChunks, ...]
|
||||||
keyGDiff := k.Mul(ctx, gDiffExp)
|
gDiffExpReshaped := gDiffExp.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
|
||||||
|
keyGDiff := k.Mul(ctx, gDiffExpReshaped)
|
||||||
|
keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||||
|
|
||||||
// Process chunks and update state
|
// Process chunks and update state
|
||||||
var coreAttnOut ml.Tensor
|
var coreAttnOut ml.Tensor
|
||||||
@@ -444,12 +446,9 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
|||||||
coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
|
coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update state for next chunk using pre-computed values
|
// Update state for next chunk
|
||||||
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
kGDiffChunk := keyGDiff.Slice(ctx, 2, chunk, chunk+1, 1)
|
kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||||
|
|
||||||
// kgdmulvnew = key_gdiff^T @ v_new
|
|
||||||
kGDiffChunkT := kGDiffChunk.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
|
||||||
kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
|
kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
|
||||||
|
|
||||||
// state = state * g_last + kgdmulvnew
|
// state = state * g_last + kgdmulvnew
|
||||||
|
|||||||
Reference in New Issue
Block a user