From c330ea33edd8bddac4988cdb33eb3cdcda224f4f Mon Sep 17 00:00:00 2001
From: jmorganca
Date: Thu, 5 Feb 2026 11:47:27 -0800
Subject: [PATCH] qwen3next: handle mixed recurrent batches

Allow batches that mix per-sequence token counts. Track each sequence's
token indices in the cache and fall back to per-sequence recurrent
processing when the uniform [tokens, seqs] layout does not hold. Add
per-slot conv/delta state accessors with checkpoint capture, relax the
full-attention layout check, and compute the SSM projections once per
batch so the mixed path reuses them.
---
 model/models/qwen3next/attention.go |  11 +-
 model/models/qwen3next/cache.go     | 127 +++++++++++++++++++-
 model/models/qwen3next/deltanet.go  | 175 ++++++++++++++++++++++------
 3 files changed, 273 insertions(+), 40 deletions(-)

diff --git a/model/models/qwen3next/attention.go b/model/models/qwen3next/attention.go
index ee4a06bea..b9d8b1bc6 100644
--- a/model/models/qwen3next/attention.go
+++ b/model/models/qwen3next/attention.go
@@ -39,12 +39,15 @@ func (sa *FullAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens
 		if nSeqs > 0 {
 			// 3D tensor: [hiddenDim, seqTokens, nSeqs]
 			if batchSize != seqTokens || nSeqs != seqs {
-				return nil, ErrUnsupportedBatchLayout
+				// Layout differs from [hiddenDim, seqTokens, nSeqs]: flatten and continue.
+				hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, batchSize*nSeqs)
+				batchSize *= nSeqs
+			} else {
+				hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
+				batchSize = seqTokens * seqs
 			}
-			hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
-			batchSize = seqTokens * seqs
 		} else if batchSize != seqTokens*seqs {
-			return nil, ErrUnsupportedBatchLayout
+			// Already flat; a token-count mismatch is tolerated for mixed batches.
 		}
 	}
 }
diff --git a/model/models/qwen3next/cache.go b/model/models/qwen3next/cache.go
index 86ee2b58d..3e24666ae 100644
--- a/model/models/qwen3next/cache.go
+++ b/model/models/qwen3next/cache.go
@@ -64,6 +64,8 @@ type HybridCache struct {
 	curSlots      []int
 	curSlotsInput ml.Tensor
 	curSeqTokens  int
+	// token indices for each sequence, in batch order
+	curSeqTokenIdxs [][]int32

 	// track if EnsureWritable has been called for this forward pass
 	writableEnsured bool
@@ -168,19 +170,44 @@ func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bo
 	}

 	if len(c.curSeqs) == 0 {
+		c.curSeqTokenIdxs = c.curSeqTokenIdxs[:0]
 		return nil
 	}

+	if cap(c.curSeqTokenIdxs) < len(c.curSeqs) {
+		c.curSeqTokenIdxs = make([][]int32, len(c.curSeqs))
+	} else {
+		c.curSeqTokenIdxs = c.curSeqTokenIdxs[:len(c.curSeqs)]
+	}
+	for i := range c.curSeqTokenIdxs {
+		c.curSeqTokenIdxs[i] = c.curSeqTokenIdxs[i][:0]
+	}
+
+	seqIndex := make(map[int]int, len(c.curSeqs))
+	for i, s := range c.curSeqs {
+		seqIndex[s] = i
+	}
+	for i, s := range batch.Sequences {
+		c.curSeqTokenIdxs[seqIndex[s]] = append(c.curSeqTokenIdxs[seqIndex[s]], int32(i))
+	}
+
 	nTokens := len(batch.Sequences)
 	nSeqs := len(c.curSeqs)
 	want := nTokens / nSeqs
+	uniform := true
 	for _, s := range c.curSeqs {
 		if seqCounts[s] != want {
-			return kvcache.ErrNotSupported
+			uniform = false
+			break
 		}
 	}
-	c.curSeqTokens = want
+	if uniform {
+		c.curSeqTokens = want
+	} else {
+		// Mixed batch: recurrent layers will process sequences independently.
+		c.curSeqTokens = 0
+	}

 	// When reserving memory for estimation, use fake slot assignments
 	if reserve {
@@ -585,7 +612,101 @@ func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Te
 	c.captureDeltaCheckpoint(ctx, layer, srcF32)
 }

-// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
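+// Per-slot state access for mixed batches: the helpers below gather or
+// scatter a single row of the conv/delta buffers (rather than the whole
+// slot batch) and record checkpoints against the owning sequence.
+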
+// convStateForSlot returns the conv state for a single slot as [convDim, convChannels, 1]. +func (c *HybridCache) convStateForSlot(ctx ml.Context, layer int, slot int) (ml.Tensor, error) { + c.ensureWritableOnce(ctx) + if c.writableError != nil { + return nil, c.writableError + } + buf := c.convBuffer(ctx, layer) + slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1) + cur := buf.Rows(ctx, slotIdx) + return cur.Reshape(ctx, c.convDim, c.convChannels, 1), nil +} + +// updateConvStateForSlot writes a new conv state for a single slot. +func (c *HybridCache) updateConvStateForSlot(ctx ml.Context, layer int, slot int, seqIndex int, newState ml.Tensor) { + buf := c.convBuffer(ctx, layer) + src := newState.Reshape(ctx, c.convDim*c.convChannels, 1) + srcF32 := src.Cast(ctx, ml.DTypeF32) + slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1) + ctx.Forward(buf.SetRows(ctx, srcF32, slotIdx)) + c.captureConvCheckpointForSeq(ctx, layer, seqIndex, srcF32) +} + +// deltaStateForSlot returns the delta state for a single slot as [headVDim, headVDim*numVHeads, 1]. +func (c *HybridCache) deltaStateForSlot(ctx ml.Context, layer int, slot int, headVDim, numVHeads int) (ml.Tensor, error) { + c.ensureWritableOnce(ctx) + if c.writableError != nil { + return nil, c.writableError + } + buf := c.deltaBuffer(ctx, layer) + slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1) + cur := buf.Rows(ctx, slotIdx) + return cur.Reshape(ctx, headVDim, headVDim*numVHeads, 1), nil +} + +// updateDeltaStateForSlot writes a new delta state for a single slot. +func (c *HybridCache) updateDeltaStateForSlot(ctx ml.Context, layer int, slot int, seqIndex int, newState ml.Tensor) { + buf := c.deltaBuffer(ctx, layer) + src := newState.Reshape(ctx, c.deltaStateSize, 1) + srcF32 := src.Cast(ctx, ml.DTypeF32) + slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1) + ctx.Forward(buf.SetRows(ctx, srcF32, slotIdx)) + c.captureDeltaCheckpointForSeq(ctx, layer, seqIndex, srcF32) +} + +func (c *HybridCache) captureConvCheckpointForSeq(ctx ml.Context, layer int, seqIndex int, src ml.Tensor) { + if c.checkpointCount == 0 { + return + } + if c.reserveCheckpoints { + c.reserveCheckpointConv(layer) + return + } + if seqIndex < 0 || seqIndex >= len(c.curCheckpointPos) { + return + } + pos := c.curCheckpointPos[seqIndex] + if pos < 0 { + return + } + slot := c.curSlots[seqIndex] + idx := c.checkpointIndexForSlot(slot, pos) + if idx < 0 { + return + } + entry := &c.checkpoints[slot].entries[idx] + dst := c.ensureCheckpointConv(layer, entry) + ctx.Forward(src.Copy(ctx, dst)) +} + +func (c *HybridCache) captureDeltaCheckpointForSeq(ctx ml.Context, layer int, seqIndex int, src ml.Tensor) { + if c.checkpointCount == 0 { + return + } + if c.reserveCheckpoints { + c.reserveCheckpointDelta(layer) + return + } + if seqIndex < 0 || seqIndex >= len(c.curCheckpointPos) { + return + } + pos := c.curCheckpointPos[seqIndex] + if pos < 0 { + return + } + slot := c.curSlots[seqIndex] + idx := c.checkpointIndexForSlot(slot, pos) + if idx < 0 { + return + } + entry := &c.checkpoints[slot].entries[idx] + dst := c.ensureCheckpointDelta(layer, entry) + ctx.Forward(src.Copy(ctx, dst)) +} + +// IsSupportedForBatch returns true if the current batch layout supports grid-style recurrent processing. 
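+// "Grid-style" means every sequence contributes the same number of tokens,
+// so the batch reshapes cleanly to [dim, seqTokens, nSeqs]. Mixed batches
+// leave curSeqTokens at 0 and are handled per sequence by forwardMixed.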
func (c *HybridCache) IsSupportedForBatch() bool { return c.curSeqTokens > 0 && len(c.curSeqs) > 0 } diff --git a/model/models/qwen3next/deltanet.go b/model/models/qwen3next/deltanet.go index e0a6f7b25..3a96850ff 100644 --- a/model/models/qwen3next/deltanet.go +++ b/model/models/qwen3next/deltanet.go @@ -48,6 +48,13 @@ type GatedDeltaNet struct { Layer int } +type stateAccessors struct { + convState func() (ml.Tensor, error) + updateConv func(ml.Tensor) + deltaState func() (ml.Tensor, error) + updateDelta func(ml.Tensor) +} + // createMasks builds the constant mask tensors (called once, reused for all chunks) func createMasks(ctx ml.Context) *Masks { ones := ctx.Input().Zeros(ml.DTypeF32, chunkSize, chunkSize) @@ -68,7 +75,6 @@ func createMasks(ctx ml.Context) *Masks { } func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) { - layer := gdn.Layer nSeqTokens := hiddenStates.Dim(1) nSeqs := hiddenStates.Dim(2) if cache != nil && cache.IsSupportedForBatch() { @@ -77,34 +83,140 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac if seqTokens > 0 && seqs > 0 { if nSeqs > 1 { if nSeqTokens != seqTokens || nSeqs != seqs { - return nil, ErrUnsupportedBatchLayout + return gdn.forwardMixed(ctx, hiddenStates, cache, opts) } } else { if nSeqTokens != seqTokens*seqs { - return nil, ErrUnsupportedBatchLayout + return gdn.forwardMixed(ctx, hiddenStates, cache, opts) } hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), seqTokens, seqs) - nSeqTokens = seqTokens - nSeqs = seqs } } + + numVHeads := opts.ssmDtRank + headVDim := opts.ssmDInner / numVHeads + layer := gdn.Layer + access := stateAccessors{ + convState: func() (ml.Tensor, error) { + return cache.ConvState(ctx, layer) + }, + updateConv: func(newState ml.Tensor) { + cache.UpdateConvState(ctx, layer, newState) + }, + deltaState: func() (ml.Tensor, error) { + return cache.DeltaState(ctx, layer, headVDim, numVHeads) + }, + updateDelta: func(newState ml.Tensor) { + cache.UpdateDeltaState(ctx, layer, newState) + }, + } + + return gdn.forwardWithAccessors(ctx, hiddenStates, opts, access) } + if cache == nil { + return nil, ErrUnsupportedBatchLayout + } + + return gdn.forwardMixed(ctx, hiddenStates, cache, opts) +} + +func (gdn *GatedDeltaNet) forwardMixed(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) { + if hiddenStates.Dim(2) > 0 { + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2)) + } + + if len(cache.curSeqs) == 0 { + return hiddenStates, nil + } + + // Ensure any shared slots are detached once for this forward pass. + cache.ensureWritableOnce(ctx) + + layer := gdn.Layer + numVHeads := opts.ssmDtRank + headVDim := opts.ssmDInner / numVHeads + + if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil { + return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)") + } + + // Precompute projections for the full batch and slice per sequence. 
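+	// Computing SSMBetaAlpha, SSMQKV, and SSMQKVGate over the flat batch and
+	// then gathering each sequence's rows avoids re-running these matmuls
+	// inside the per-sequence loop below.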
+ mixedBAFull := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates) + qkvMixedFull := gdn.SSMQKV.Forward(ctx, hiddenStates) + zFull := gdn.SSMQKVGate.Forward(ctx, hiddenStates) + + out := hiddenStates + for seqIndex := range cache.curSeqs { + idxs := cache.curSeqTokenIdxs[seqIndex] + if len(idxs) == 0 { + continue + } + idxTensor := ctx.Input().FromInts(idxs, len(idxs)) + + mixedBA := mixedBAFull.Rows(ctx, idxTensor) + qkvMixed := qkvMixedFull.Rows(ctx, idxTensor) + z := zFull.Rows(ctx, idxTensor) + + slot := cache.curSlots[seqIndex] + access := stateAccessors{ + convState: func() (ml.Tensor, error) { + return cache.convStateForSlot(ctx, layer, slot) + }, + updateConv: func(newState ml.Tensor) { + cache.updateConvStateForSlot(ctx, layer, slot, seqIndex, newState) + }, + deltaState: func() (ml.Tensor, error) { + return cache.deltaStateForSlot(ctx, layer, slot, headVDim, numVHeads) + }, + updateDelta: func(newState ml.Tensor) { + cache.updateDeltaStateForSlot(ctx, layer, slot, seqIndex, newState) + }, + } + + seqOut, err := gdn.forwardProjected(ctx, len(idxs), 1, mixedBA, qkvMixed, z, opts, access) + if err != nil { + return nil, err + } + out = out.SetRows(ctx, seqOut, idxTensor) + } + + return out, nil +} + +func (gdn *GatedDeltaNet) forwardWithAccessors(ctx ml.Context, hiddenStates ml.Tensor, opts *Options, access stateAccessors) (ml.Tensor, error) { + nSeqTokens := hiddenStates.Dim(1) + nSeqs := hiddenStates.Dim(2) + + mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates) + + if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil { + return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)") + } + // Optimized path: pre-split QKV and gate + qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates) + z := gdn.SSMQKVGate.Forward(ctx, hiddenStates) + + return gdn.forwardProjected(ctx, nSeqTokens, nSeqs, mixedBA, qkvMixed, z, opts, access) +} + +func (gdn *GatedDeltaNet) forwardProjected( + ctx ml.Context, + nSeqTokens, nSeqs int, + mixedBA, qkvMixed, z ml.Tensor, + opts *Options, + access stateAccessors, +) (ml.Tensor, error) { + layer := gdn.Layer + headKDim := opts.ssmDState numKHeads := opts.ssmNGroup numVHeads := opts.ssmDtRank headVDim := opts.ssmDInner / numVHeads convKernelSize := opts.convKernelSize - - mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates) qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads - if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil { - return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)") - } - // Optimized path: pre-split QKV and gate - qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs) - z := gdn.SSMQKVGate.Forward(ctx, hiddenStates) + qkvMixed = qkvMixed.Reshape(ctx, qkvDim, nSeqTokens, nSeqs) baNewDim := 2 * numVHeads / numKHeads mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs) @@ -127,7 +239,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3) // Get conv state from cache - convStates, err := cache.ConvState(ctx, layer) + convStates, err := access.convState() if err != nil { // Log this - if it happens, short-term context will be lost slog.Warn("qwen3next: failed to get conv state, using zeros", "layer", layer, "error", err) @@ -142,7 +254,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac // Save new conv state (last convKernelSize-1 tokens) lastConvStates := convInput.Slice(ctx, 0, 
nSeqTokens, nSeqTokens+convKernelSize-1, 1) - cache.UpdateConvState(ctx, layer, lastConvStates) + access.updateConv(lastConvStates) // Apply SSM convolution (kernel must be F32 for Metal) convOutput := convInput.SSMConv(ctx, gdn.SSMConv1D.Weight) @@ -162,7 +274,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac vConv = vConv.Contiguous(ctx, headVDim, numVHeads, nSeqTokens, nSeqs) // Get delta state from cache - state, err := cache.DeltaState(ctx, layer, headVDim, numVHeads) + state, err := access.deltaState() if err != nil { // Log this - if it happens frequently, context will degrade slog.Warn("qwen3next: failed to get delta state, using zeros", "layer", layer, "error", err) @@ -185,14 +297,19 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac } // Choose computation mode based on sequence length - var attnOut ml.Tensor + var ( + attnOut ml.Tensor + newState ml.Tensor + ) if nSeqTokens == 1 { - attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache) + attnOut, newState = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts) } else { // Use pre-computed masks from opts (created once in Model.Forward) - attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache) + attnOut, newState = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts) } + access.updateDelta(newState) + // Apply gated normalization attnOut2D := attnOut.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs) z2D := z.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs) @@ -215,9 +332,7 @@ func (gdn *GatedDeltaNet) deltaNetAutoregressive( ctx ml.Context, q, k, v, gate, beta, state ml.Tensor, opts *Options, - layer int, - cache *HybridCache, -) ml.Tensor { +) (ml.Tensor, ml.Tensor) { numVHeads := v.Dim(1) headVDim := v.Dim(0) nSeqs := q.Dim(3) @@ -273,10 +388,8 @@ func (gdn *GatedDeltaNet) deltaNetAutoregressive( coreAttnOut := stateQ.SumRows(ctx) coreAttnOut = coreAttnOut.Permute(ctx, 1, 0, 2, 3) - // Update delta state in cache - cache.UpdateDeltaState(ctx, layer, state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs)) - - return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs) + newState := state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs) + return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs), newState } // deltaNetChunked implements chunked computation for prefill. @@ -286,9 +399,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked( q, k, v, gate, beta, state ml.Tensor, masks *Masks, opts *Options, - layer int, - cache *HybridCache, -) ml.Tensor { +) (ml.Tensor, ml.Tensor) { headKDim := q.Dim(0) numVHeads := v.Dim(1) headVDim := v.Dim(0) @@ -465,8 +576,6 @@ func (gdn *GatedDeltaNet) deltaNetChunked( coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1) } - // Update delta state in cache - cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs)) - - return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs) + newStateFlat := newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs) + return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs), newStateFlat }