model: add support for qwen3.5-27b model (#14415)

This commit is contained in:
Jeffrey Morgan
2026-02-25 01:09:58 -08:00
committed by GitHub
parent da70c3222e
commit 7f9efd53df
3 changed files with 45 additions and 26 deletions

View File

@@ -122,7 +122,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
// Keep beta layout consistent with qwen35 and llama.cpp:
// Keep beta layout consistent with qwen35.
// [1, numVHeads, nSeqTokens, nSeqs]
beta = b.Contiguous(ctx, 1, numVHeads, nSeqTokens, nSeqs)
alpha = a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
@@ -333,7 +333,6 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)
// Match llama.cpp delta-net-base layout:
// gate/beta: [1, numVHeads, nTokens, nSeqs] -> [1, nTokens, numVHeads, nSeqs]
gate = gate.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
beta = beta.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
@@ -437,60 +436,64 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
keyGDiff := k.Mul(ctx, gDiffExpReshaped)
keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Process chunks and update state
var coreAttnOut ml.Tensor
newState := state
// Process chunks and update state.
// Keep a transposed view of v and recurrent state across chunks so the
// chunk loop does not need extra transpose+contiguous nodes.
vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
for chunk := range nChunks {
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
vChunk := v.Slice(ctx, 2, chunk, chunk+1, 1)
vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
gExpChunk := gExp.Slice(ctx, 2, chunk, chunk+1, 1)
kCumdecayChunk := kCumdecay.Slice(ctx, 2, chunk, chunk+1, 1)
attnChunk := attnKQ.Slice(ctx, 2, chunk, chunk+1, 1) // Pre-computed!
// state^T - permute is needed but Contiguous creates a copy
stateT := newState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
// v'_t = k_cumdecay @ state_t
vTPrime := kCumdecayChunk.Mulmat(ctx, stateT)
// v_prime = k_cumdecay @ state
vPrime := stateT.Mulmat(ctx, kCumdecayChunk)
// v_new = v - v_prime
vNew := vChunk.Sub(ctx, vPrime)
vNewT := vNew.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// v_t_new = v_t - v'_t
vTNewChunk := vTChunk.Sub(ctx, vTPrime)
// attn_inter = (q * g_exp) @ state
qGExp := qChunk.Mul(ctx, gExpChunk)
attnInter := stateT.Mulmat(ctx, qGExp)
// core_attn_out = attn_inter + attn @ v_new
vAttn := vNewT.Mulmat(ctx, attnChunk)
vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
if coreAttnOut == nil {
coreAttnOut = coreAttnOutChunk
} else {
coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
}
v = v.SetInplace(
ctx,
coreAttnOutChunk,
v.Stride(1),
v.Stride(2),
v.Stride(3),
chunk*v.Stride(2),
)
// Update state for next chunk
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
// kgdmulvnew = key_gdiff_t @ v_new_t
kgdMulVNew := kGDiffChunkT.Mulmat(ctx, vTNewChunk)
// state = state * g_last + kgdmulvnew
gExpLastReshaped := gExpLastChunk.Contiguous(ctx).Reshape(ctx, 1, 1, numVHeads, nSeqs)
newState = newState.Mul(ctx, gExpLastReshaped)
newState = newState.Add(ctx, kgdMulVNew.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs))
// stateT = stateT * g_last + kgdmulvnew
stateT = stateT.Mul(ctx, gExpLastChunk)
stateT = stateT.Add(ctx, kgdMulVNew)
}
// Final reshape
coreAttnOut = coreAttnOut.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
// Slice to remove padding
if pad > 0 {
coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
}
// Convert stateT back to cache layout [S_v, S_v, H_v, nSeqs]
newState := stateT.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, numVHeads, nSeqs)
// Update delta state in cache
cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))