kvcache: Use SetRows to store cache data

We currently copy data into the KV cache in contiguous buffers using
ggml_cpy(). ggml_set_rows() was introduced to allow scatter operation
so that contiguous buffers are no longer required. The direct primary
benefit of this is that we no longer need to perform defragmentation.

However, GGML recently removed an optimization for ggml_cpy() and
we picked it up in 544b673 "ggml update to b6840 (#12791)". This
caused a roughly 40% drop in token generation performance on CUDA
due to CUDA graphs no longer being used. By switching to
ggml_set_rows(), the original optimization is no longer necessary
and CUDA performance is restored.

Fixes #13112
This commit is contained in:
Jesse Gross
2025-08-18 10:45:58 -07:00
committed by Jesse Gross
parent b6e02cbbd2
commit 53985b3c4d
4 changed files with 164 additions and 235 deletions

View File

@@ -3,7 +3,6 @@ package kvcache
import (
"errors"
"fmt"
"log/slog"
"math"
"slices"
@@ -40,18 +39,18 @@ type Causal struct {
// ** current forward pass **
// the active layer for Get and Put
curLayer int
// starting location for data storage for this batch
curLoc int
// size of the current batch
curBatchSize int
// locations for data storage for this batch
curLoc ml.Tensor
// mask of the cache as used by this batch
curMask ml.Tensor
// the active layer for Get and Put
curLayer int
// locations in the cache that are needed for this batch
curCellRange cellRange
@@ -206,45 +205,47 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
c.curPositions = batch.Positions
c.opts.Except = nil
var locs []int32
if !reserve {
c.updateSlidingWindow()
var err error
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curLoc, err = c.findStartLoc()
}
locs, err = c.findLocs()
if err != nil {
return err
}
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
loc := int(locs[i])
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}
seqRange.min = min(seqRange.min, c.curLoc+i)
c.curCellRange.min = min(c.curCellRange.min, c.curLoc+i)
seqRange.min = min(seqRange.min, loc)
c.curCellRange.min = min(c.curCellRange.min, loc)
seqRange.max = max(seqRange.max, c.curLoc+i)
c.curCellRange.max = max(c.curCellRange.max, c.curLoc+i)
seqRange.max = max(seqRange.max, loc)
c.curCellRange.max = max(c.curCellRange.max, loc)
c.cellRanges[seq] = seqRange
}
} else {
// If we are reserving memory, don't update any of the cache metadata but set the size
// to the worst case.
c.curLoc = 0
locs = make([]int32, c.curBatchSize)
for i := range locs {
locs[i] = int32(i)
}
c.curCellRange.min = 0
c.curCellRange.max = len(c.cells) - 1
}
c.curLoc = ctx.Input().FromInts(locs, len(locs))
c.curMask = c.buildMask(ctx)
return nil
@@ -257,22 +258,20 @@ func newRange() cellRange {
}
}
// Find the first contiguous block of at least curBatchSize
func (c *Causal) findStartLoc() (int, error) {
var start, count int
// Returns a slice of locations where each token in the batch should be stored
func (c *Causal) findLocs() ([]int32, error) {
loc := make([]int32, 0, c.curBatchSize)
for i := range c.cells {
if len(c.cells[i].sequences) == 0 {
count++
if count >= c.curBatchSize {
return start, nil
loc = append(loc, int32(i))
if len(loc) >= c.curBatchSize {
return loc, nil
}
} else {
start = i + 1
count = 0
}
}
return 0, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
}
func (c *Causal) updateSlidingWindow() {
@@ -402,145 +401,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
return maskTensor
}
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
for i, key := range c.keys {
if key == nil {
continue
}
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
value := c.values[i]
var vSrcView, vDstView ml.Tensor
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
}
ctx.Forward(
kSrcView.Copy(ctx, kDstView),
vSrcView.Copy(ctx, vDstView),
)
}
}
func (c *Causal) defrag() {
slog.Debug("defragmenting kv cache")
// Defrag strategy:
// - Search for empty holes at the beginning of the cache,
// filling them with active data starting at the end
// - If there are contiguous elements that need to be moved,
// combine them into a single operation by holding new moves
// until we see that the next one is non-contiguous
// - Fill up the context with the maximum number of operations it
// can hold then compute that and continue with a new context
//
// We could try to optimize placement by grouping blocks from
// the same sequences together but most likely the next forward
// pass will disrupt this anyways, so the real world benefit
// seems limited as this time.
ctx := c.backend.NewContext()
// For every move, 6 tensors are required per layer (2 views and a
// copy for each of k and v). We also need to refer to the original
// k and v cache tensors - once per layer, not per move.
layers := 0
for _, key := range c.keys {
if key == nil {
continue
}
layers++
}
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
moves := 0
var pendingSrc, pendingDst, pendingLen int
src := len(c.cells) - 1
for dst := 0; dst < src; dst++ {
if len(c.cells[dst].sequences) == 0 {
for ; src > dst; src-- {
if len(c.cells[src].sequences) != 0 {
c.cells[dst] = c.cells[src]
c.cells[src] = cacheCell{}
if pendingLen > 0 {
if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
pendingSrc = src
pendingLen++
break
} else {
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
}
pendingSrc = src
pendingDst = dst
pendingLen = 1
break
}
}
}
if moves >= maxMoves {
ctx.Compute()
ctx.Close()
ctx = c.backend.NewContext()
moves = 0
}
}
if pendingLen > 0 {
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
if moves > 0 {
ctx.Compute()
}
ctx.Close()
// Reset range metadata
for seq := range c.cellRanges {
seqRange := newRange()
for i, cell := range c.cells {
if slices.Contains(cell.sequences, seq) {
if i < seqRange.min {
seqRange.min = i
}
if i > seqRange.max {
seqRange.max = i
}
}
}
c.cellRanges[seq] = seqRange
}
c.updateSlidingWindow()
}
func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
@@ -625,18 +485,25 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
}
}
rowSize := c.keys[c.curLayer].Stride(2)
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
keyCache := c.keys[c.curLayer]
keyCache = keyCache.Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))
if c.config.PermutedV {
elemSize := c.values[c.curLayer].Stride(0)
value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
value = value.Permute(ctx, 2, 0, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
valueCache := c.values[c.curLayer]
valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
} else {
rowSize := c.values[c.curLayer].Stride(2)
value = value.Reshape(ctx, vHeadDim*numKVHeads, batchSize)
valueCache := c.values[c.curLayer]
valueCache = valueCache.Reshape(ctx, vHeadDim*numKVHeads, len(c.cells))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
}
}

View File

@@ -207,11 +207,11 @@ func TestSWAMem(t *testing.T) {
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{4, 5},
expected: []float32{4, 5, 6},
expectedShape: []int{1, 1, 3},
expected: []float32{5, 2, 3, 4, 6},
expectedShape: []int{1, 1, 5},
expectedMask: []float32{
0, 0, x,
x, 0, 0,
0, x, x, 0, x,
0, x, x, x, 0,
},
},
}
@@ -319,6 +319,8 @@ func TestRemove(t *testing.T) {
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
x := float32(math.Inf(-1))
tests := []testCase{
{
name: "FirstBatch",
@@ -328,7 +330,12 @@ func TestRemove(t *testing.T) {
pos: []int32{0, 1, 0, 1},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
expectedMask: []float32{
0, x, x, x,
0, 0, x, x,
x, x, 0, x,
x, x, 0, 0,
},
},
}
@@ -346,9 +353,12 @@ func TestRemove(t *testing.T) {
inShape: []int{1, 1, 2},
seqs: []int{0, 1},
pos: []int32{1, 2},
expected: []float32{1, 2, 3, 4, 5, 6},
expectedShape: []int{1, 1, 6},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
expected: []float32{1, 5, 3, 4, 6},
expectedShape: []int{1, 1, 5},
expectedMask: []float32{
0, 0, x, x, x,
x, x, 0, 0, 0,
},
},
}
@@ -366,59 +376,12 @@ func TestRemove(t *testing.T) {
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{1, 2},
expected: []float32{7, 8, 3, 4, 4},
expectedShape: []int{1, 1, 5},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0},
},
}
testCache(t, backend, cache, tests)
}
func TestDefrag(t *testing.T) {
backend := &testBackend{}
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.Add(ctx, shift), nil
})
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
inShape: []int{1, 1, 16},
seqs: []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
pos: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
expectedShape: []int{1, 1, 16},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
}
testCache(t, backend, cache, tests)
err := cache.Remove(0, 2, 4)
if err != nil {
panic(err)
}
err = cache.Remove(0, 13, math.MaxInt32)
if err != nil {
panic(err)
}
tests = []testCase{
{
name: "Defrag",
in: []float32{17, 18, 19},
inShape: []int{1, 1, 3},
seqs: []int{0, 0, 0},
pos: []int32{16, 17, 18},
expected: []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19},
expectedShape: []int{1, 1, 16},
expectedMask: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
expected: []float32{7, 4, 3, 4, 6, 8},
expectedShape: []int{1, 1, 6},
expectedMask: []float32{
0, 0, x, x, x, x,
0, 0, x, x, x, 0,
},
},
}
@@ -770,6 +733,15 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return out
}
func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
return &testTensor{
dtype: t.dtype,
elementSize: t.elementSize,
data: t.data,
shape: shape,
}
}
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
offset /= t.elementSize
@@ -778,6 +750,8 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
switch len(shape) {
case 1:
s = []int{shape[0]}
case 3:
s = []int{shape[0], shape[2]}
case 5:
s = []int{shape[0], shape[2], shape[4]}
default:
@@ -792,6 +766,86 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
return view
}
func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
dst := t
srcTensor := src.(*testTensor)
idxTensor := idxs.(*testTensor)
shapeTo4D := func(shape []int) [4]int {
out := [4]int{1, 1, 1, 1}
for i := 0; i < len(shape) && i < 4; i++ {
out[i] = shape[i]
}
return out
}
computeStrides := func(shape [4]int) [4]int {
out := [4]int{1, 1, 1, 1}
for i := 1; i < 4; i++ {
out[i] = out[i-1] * shape[i-1]
}
return out
}
dstShape4D := shapeTo4D(dst.shape)
srcShape4D := shapeTo4D(srcTensor.shape)
idxShape4D := shapeTo4D(idxTensor.shape)
if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
panic("SetRows requires matching tensor shapes")
}
if srcShape4D[1] != idxShape4D[0] {
panic("SetRows rows/index mismatch")
}
if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
panic("SetRows cannot broadcast indices")
}
if idxShape4D[3] != 1 {
panic("SetRows expects 1D or 2D index tensors")
}
dstStride := computeStrides(dstShape4D)
srcStride := computeStrides(srcShape4D)
idxStride := computeStrides(idxShape4D)
numColumns := srcShape4D[0]
numRows := srcShape4D[1]
for dim3Index := range dstShape4D[3] {
for dim2Index := range dstShape4D[2] {
idxDim2 := 0
idxDim3 := 0
if idxShape4D[1] > 0 {
idxDim2 = dim2Index % idxShape4D[1]
}
if idxShape4D[2] > 0 {
idxDim3 = dim3Index % idxShape4D[2]
}
idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]
for row := range numRows {
idx := int(idxTensor.data[idxBase+row*idxStride[0]])
if idx < 0 || idx >= dstShape4D[1] {
panic("SetRows index out of range")
}
srcOffset := srcBase + row*srcStride[1]
dstOffset := dstBase + idx*dstStride[1]
copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
}
}
}
return dst
}
func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
copy(t2.(*testTensor).data, t.data)
return nil

View File

@@ -194,6 +194,7 @@ type Tensor interface {
Repeat(ctx Context, dim, n int) Tensor
Concat(ctx Context, t2 Tensor, dim int) Tensor
Rows(ctx Context, t2 Tensor) Tensor
SetRows(ctx Context, src Tensor, idxs Tensor) Tensor
Copy(ctx Context, t2 Tensor) Tensor
Duplicate(ctx Context) Tensor

View File

@@ -1338,6 +1338,13 @@ func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_set_rows(ctx.(*Context).ctx, t.t, src.(*Tensor).t, idxs.(*Tensor).t),
}
}
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,