diff --git a/model/models/qwen3next/deltanet.go b/model/models/qwen3next/deltanet.go index 6ce315649..9c7aa02b5 100644 --- a/model/models/qwen3next/deltanet.go +++ b/model/models/qwen3next/deltanet.go @@ -454,6 +454,10 @@ func (gdn *GatedDeltaNet) deltaNetChunked( vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs) stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs) + // Collect chunk outputs and concatenate at the end. + // Avoids SET on buffer-less intermediates under partial offload. + chunks := make([]ml.Tensor, nChunks) + for chunk := range nChunks { qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1) vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1) @@ -475,14 +479,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked( vAttn := vTNewChunk.Mulmat(ctx, attnChunk) coreAttnOutChunk := attnInter.Add(ctx, vAttn) - v = v.SetInplace( - ctx, - coreAttnOutChunk, - v.Stride(1), - v.Stride(2), - v.Stride(3), - chunk*v.Stride(2), - ) + chunks[chunk] = coreAttnOutChunk // Update state for next chunk gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1) @@ -495,6 +492,20 @@ func (gdn *GatedDeltaNet) deltaNetChunked( stateT = stateT.Add(ctx, kgdMulVNew) } + // Use a balanced concat tree so concat work does not balloon on long prompts. + for len(chunks) > 1 { + merged := make([]ml.Tensor, 0, (len(chunks)+1)/2) + for i := 0; i < len(chunks); i += 2 { + if i+1 < len(chunks) { + merged = append(merged, chunks[i].Concat(ctx, chunks[i+1], 2)) + } else { + merged = append(merged, chunks[i]) + } + } + chunks = merged + } + v = chunks[0] + // Final reshape coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)