kvcache: Use SetRows to store cache data
We currently copy data into the KV cache in contiguous buffers using
ggml_cpy(). ggml_set_rows() was introduced to allow scatter operations,
so contiguous buffers are no longer required. The most direct benefit of
this is that we no longer need to perform defragmentation.

However, GGML recently removed an optimization for ggml_cpy(), and we
picked that change up in 544b673 "ggml update to b6840 (#12791)". This
caused a roughly 40% drop in token generation performance on CUDA
because CUDA graphs were no longer being used. By switching to
ggml_set_rows(), the original optimization is no longer necessary and
CUDA performance is restored.
Fixes #13112
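As a rough illustration of the change, here is a minimal sketch (not the actual implementation) contrasting the two storage strategies. It assumes the ml.Context/ml.Tensor methods that appear in this diff (View, Copy, Reshape, SetRows, Input().FromInts, Forward); the helper names and the simplified row sizes are illustrative only.

// Old approach: the batch has to land in one contiguous run of cache cells,
// so the whole batch is copied into a single view starting at curLoc.
func putContiguous(ctx ml.Context, cache, key ml.Tensor, curLoc, rowSize, rowElems, batchSize int) {
	dst := cache.View(ctx, rowSize*curLoc, rowElems*batchSize)
	ctx.Forward(key.Copy(ctx, dst))
}

// New approach: each token carries its own destination cell, so free cells do
// not need to be contiguous and defragmentation is no longer required.
func putScattered(ctx ml.Context, cache, key ml.Tensor, locs []int32, rowElems, numCells, batchSize int) {
	idxs := ctx.Input().FromInts(locs, len(locs)) // one destination row per token
	key = key.Reshape(ctx, rowElems, batchSize)   // one row per token
	cache = cache.Reshape(ctx, rowElems, numCells) // one row per cache cell
	ctx.Forward(cache.SetRows(ctx, key, idxs))    // scatter rows by index
}

A plausible reading of the performance note above is that the destination now travels as tensor data (the locations tensor) rather than as per-step view offsets baked into the graph, so the graph layout stays stable across steps and CUDA graphs can be reused.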
@@ -3,7 +3,6 @@ package kvcache
import (
	"errors"
	"fmt"
	"log/slog"
	"math"
	"slices"

@@ -40,18 +39,18 @@ type Causal struct {

	// ** current forward pass **

	// the active layer for Get and Put
	curLayer int

	// starting location for data storage for this batch
	curLoc int

	// size of the current batch
	curBatchSize int

	// locations for data storage for this batch
	curLoc ml.Tensor

	// mask of the cache as used by this batch
	curMask ml.Tensor

	// the active layer for Get and Put
	curLayer int

	// locations in the cache that are needed for this batch
	curCellRange cellRange

@@ -206,45 +205,47 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
	c.curPositions = batch.Positions
	c.opts.Except = nil

	var locs []int32
	if !reserve {
		c.updateSlidingWindow()

		var err error
		c.curLoc, err = c.findStartLoc()
		if errors.Is(err, ErrKvCacheFull) {
			c.defrag()
			c.curLoc, err = c.findStartLoc()
		}
		locs, err = c.findLocs()
		if err != nil {
			return err
		}

		for i, pos := range batch.Positions {
			seq := batch.Sequences[i]
			loc := int(locs[i])

			c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
			c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}

			seqRange, ok := c.cellRanges[seq]
			if !ok {
				seqRange = newRange()
			}

			seqRange.min = min(seqRange.min, c.curLoc+i)
			c.curCellRange.min = min(c.curCellRange.min, c.curLoc+i)
			seqRange.min = min(seqRange.min, loc)
			c.curCellRange.min = min(c.curCellRange.min, loc)

			seqRange.max = max(seqRange.max, c.curLoc+i)
			c.curCellRange.max = max(c.curCellRange.max, c.curLoc+i)
			seqRange.max = max(seqRange.max, loc)
			c.curCellRange.max = max(c.curCellRange.max, loc)

			c.cellRanges[seq] = seqRange
		}
	} else {
		// If we are reserving memory, don't update any of the cache metadata but set the size
		// to the worst case.
		c.curLoc = 0
		locs = make([]int32, c.curBatchSize)
		for i := range locs {
			locs[i] = int32(i)
		}
		c.curCellRange.min = 0
		c.curCellRange.max = len(c.cells) - 1
	}

	c.curLoc = ctx.Input().FromInts(locs, len(locs))
	c.curMask = c.buildMask(ctx)

	return nil

@@ -257,22 +258,20 @@ func newRange() cellRange {
	}
}

// Find the first contiguous block of at least curBatchSize
func (c *Causal) findStartLoc() (int, error) {
	var start, count int
// Returns a slice of locations where each token in the batch should be stored
func (c *Causal) findLocs() ([]int32, error) {
	loc := make([]int32, 0, c.curBatchSize)

	for i := range c.cells {
		if len(c.cells[i].sequences) == 0 {
			count++
			if count >= c.curBatchSize {
				return start, nil
			loc = append(loc, int32(i))
			if len(loc) >= c.curBatchSize {
				return loc, nil
			}
		} else {
			start = i + 1
			count = 0
		}
	}

	return 0, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
	return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
}

func (c *Causal) updateSlidingWindow() {

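To make the new selection logic concrete, here is a toy, self-contained sketch of what findLocs does, with a plain bool slice standing in for "cell is already holding data for some sequence"; the values are made up for illustration.

package main

import "fmt"

func main() {
	// true means the cell is occupied; the free cells are 1 and 3.
	occupied := []bool{true, false, true, false}
	batchSize := 2

	locs := make([]int32, 0, batchSize)
	for i, used := range occupied {
		if !used {
			locs = append(locs, int32(i))
			if len(locs) >= batchSize {
				break
			}
		}
	}

	fmt.Println(locs) // [1 3]: non-contiguous cells are fine for SetRows
}

With findStartLoc, the same cache state would have required defragmentation before the batch of two tokens could be stored.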
@@ -402,145 +401,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
	return maskTensor
}

func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
	for i, key := range c.keys {
		if key == nil {
			continue
		}

		kHeadDim := key.Dim(0)
		numKVHeads := key.Dim(1)
		rowSize := key.Stride(2)

		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)

		value := c.values[i]
		var vSrcView, vDstView ml.Tensor
		if c.config.PermutedV {
			vHeadDim := value.Dim(1)
			elemSize := value.Stride(0)

			vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
			vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
		} else {
			vHeadDim := value.Dim(0)
			rowSize := value.Stride(2)

			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
		}

		ctx.Forward(
			kSrcView.Copy(ctx, kDstView),
			vSrcView.Copy(ctx, vDstView),
		)
	}
}

func (c *Causal) defrag() {
	slog.Debug("defragmenting kv cache")

	// Defrag strategy:
	// - Search for empty holes at the beginning of the cache,
	//   filling them with active data starting at the end
	// - If there are contiguous elements that need to be moved,
	//   combine them into a single operation by holding new moves
	//   until we see that the next one is non-contiguous
	// - Fill up the context with the maximum number of operations it
	//   can hold then compute that and continue with a new context
	//
	// We could try to optimize placement by grouping blocks from
	// the same sequences together but most likely the next forward
	// pass will disrupt this anyway, so the real world benefit
	// seems limited at this time.

	ctx := c.backend.NewContext()

	// For every move, 6 tensors are required per layer (2 views and a
	// copy for each of k and v). We also need to refer to the original
	// k and v cache tensors - once per layer, not per move.
	layers := 0
	for _, key := range c.keys {
		if key == nil {
			continue
		}
		layers++
	}

	maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
	moves := 0

	var pendingSrc, pendingDst, pendingLen int
	src := len(c.cells) - 1

	for dst := 0; dst < src; dst++ {
		if len(c.cells[dst].sequences) == 0 {
			for ; src > dst; src-- {
				if len(c.cells[src].sequences) != 0 {
					c.cells[dst] = c.cells[src]
					c.cells[src] = cacheCell{}

					if pendingLen > 0 {
						if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
							pendingSrc = src
							pendingLen++
							break
						} else {
							c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
							moves++
						}
					}

					pendingSrc = src
					pendingDst = dst
					pendingLen = 1

					break
				}
			}
		}

		if moves >= maxMoves {
			ctx.Compute()
			ctx.Close()
			ctx = c.backend.NewContext()

			moves = 0
		}
	}

	if pendingLen > 0 {
		c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
		moves++
	}

	if moves > 0 {
		ctx.Compute()
	}
	ctx.Close()

	// Reset range metadata
	for seq := range c.cellRanges {
		seqRange := newRange()

		for i, cell := range c.cells {
			if slices.Contains(cell.sequences, seq) {
				if i < seqRange.min {
					seqRange.min = i
				}
				if i > seqRange.max {
					seqRange.max = i
				}
			}
		}

		c.cellRanges[seq] = seqRange
	}

	c.updateSlidingWindow()
}

func (c *Causal) SetLayer(layer int) {
	c.curLayer = layer
}

@@ -625,18 +485,25 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
		}
	}

	rowSize := c.keys[c.curLayer].Stride(2)
	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
	key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
	keyCache := c.keys[c.curLayer]
	keyCache = keyCache.Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
	ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))

	if c.config.PermutedV {
		elemSize := c.values[c.curLayer].Stride(0)
		value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
		value = value.Permute(ctx, 2, 0, 1, 3)

		value = value.Permute(ctx, 1, 2, 0, 3)
		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
		valueCache := c.values[c.curLayer]
		valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)

		ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
	} else {
		rowSize := c.values[c.curLayer].Stride(2)
		value = value.Reshape(ctx, vHeadDim*numKVHeads, batchSize)
		valueCache := c.values[c.curLayer]
		valueCache = valueCache.Reshape(ctx, vHeadDim*numKVHeads, len(c.cells))

		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
		ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
	}
}

@@ -207,11 +207,11 @@ func TestSWAMem(t *testing.T) {
			inShape: []int{1, 1, 2},
			seqs: []int{0, 0},
			pos: []int32{4, 5},
			expected: []float32{4, 5, 6},
			expectedShape: []int{1, 1, 3},
			expected: []float32{5, 2, 3, 4, 6},
			expectedShape: []int{1, 1, 5},
			expectedMask: []float32{
				0, 0, x,
				x, 0, 0,
				0, x, x, 0, x,
				0, x, x, x, 0,
			},
		},
	}

@@ -319,6 +319,8 @@ func TestRemove(t *testing.T) {

	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

	x := float32(math.Inf(-1))

	tests := []testCase{
		{
			name: "FirstBatch",

@@ -328,7 +330,12 @@ func TestRemove(t *testing.T) {
			pos: []int32{0, 1, 0, 1},
			expected: []float32{1, 2, 3, 4},
			expectedShape: []int{1, 1, 4},
			expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
			expectedMask: []float32{
				0, x, x, x,
				0, 0, x, x,
				x, x, 0, x,
				x, x, 0, 0,
			},
		},
	}

@@ -346,9 +353,12 @@ func TestRemove(t *testing.T) {
			inShape: []int{1, 1, 2},
			seqs: []int{0, 1},
			pos: []int32{1, 2},
			expected: []float32{1, 2, 3, 4, 5, 6},
			expectedShape: []int{1, 1, 6},
			expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
			expected: []float32{1, 5, 3, 4, 6},
			expectedShape: []int{1, 1, 5},
			expectedMask: []float32{
				0, 0, x, x, x,
				x, x, 0, 0, 0,
			},
		},
	}

@@ -366,59 +376,12 @@ func TestRemove(t *testing.T) {
			inShape: []int{1, 1, 2},
			seqs: []int{0, 0},
			pos: []int32{1, 2},
			expected: []float32{7, 8, 3, 4, 4},
			expectedShape: []int{1, 1, 5},
			expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0},
		},
	}

	testCache(t, backend, cache, tests)
}

func TestDefrag(t *testing.T) {
	backend := &testBackend{}
	cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
		return key.Add(ctx, shift), nil
	})
	defer cache.Close()

	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

	tests := []testCase{
		{
			name: "FirstBatch",
			in: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			inShape: []int{1, 1, 16},
			seqs: []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
			pos: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
			expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
			expectedShape: []int{1, 1, 16},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
		},
	}

	testCache(t, backend, cache, tests)

	err := cache.Remove(0, 2, 4)
	if err != nil {
		panic(err)
	}

	err = cache.Remove(0, 13, math.MaxInt32)
	if err != nil {
		panic(err)
	}

	tests = []testCase{
		{
			name: "Defrag",
			in: []float32{17, 18, 19},
			inShape: []int{1, 1, 3},
			seqs: []int{0, 0, 0},
			pos: []int32{16, 17, 18},
			expected: []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19},
			expectedShape: []int{1, 1, 16},
expectedMask: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
			expected: []float32{7, 4, 3, 4, 6, 8},
			expectedShape: []int{1, 1, 6},
			expectedMask: []float32{
				0, 0, x, x, x, x,
				0, 0, x, x, x, 0,
			},
		},
	}

@@ -770,6 +733,15 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return out
}

func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
	return &testTensor{
		dtype: t.dtype,
		elementSize: t.elementSize,
		data: t.data,
		shape: shape,
	}
}

func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	offset /= t.elementSize

@@ -778,6 +750,8 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	switch len(shape) {
	case 1:
		s = []int{shape[0]}
	case 3:
		s = []int{shape[0], shape[2]}
	case 5:
		s = []int{shape[0], shape[2], shape[4]}
	default:

@@ -792,6 +766,86 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	return view
}

func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
	dst := t
	srcTensor := src.(*testTensor)
	idxTensor := idxs.(*testTensor)

	shapeTo4D := func(shape []int) [4]int {
		out := [4]int{1, 1, 1, 1}
		for i := 0; i < len(shape) && i < 4; i++ {
			out[i] = shape[i]
		}
		return out
	}

	computeStrides := func(shape [4]int) [4]int {
		out := [4]int{1, 1, 1, 1}
		for i := 1; i < 4; i++ {
			out[i] = out[i-1] * shape[i-1]
		}
		return out
	}

	dstShape4D := shapeTo4D(dst.shape)
	srcShape4D := shapeTo4D(srcTensor.shape)
	idxShape4D := shapeTo4D(idxTensor.shape)

	if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
		panic("SetRows requires matching tensor shapes")
	}

	if srcShape4D[1] != idxShape4D[0] {
		panic("SetRows rows/index mismatch")
	}

	if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
		panic("SetRows cannot broadcast indices")
	}

	if idxShape4D[3] != 1 {
		panic("SetRows expects 1D or 2D index tensors")
	}

	dstStride := computeStrides(dstShape4D)
	srcStride := computeStrides(srcShape4D)
	idxStride := computeStrides(idxShape4D)

	numColumns := srcShape4D[0]
	numRows := srcShape4D[1]

	for dim3Index := range dstShape4D[3] {
		for dim2Index := range dstShape4D[2] {
			idxDim2 := 0
			idxDim3 := 0
			if idxShape4D[1] > 0 {
				idxDim2 = dim2Index % idxShape4D[1]
			}
			if idxShape4D[2] > 0 {
				idxDim3 = dim3Index % idxShape4D[2]
			}

			idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
			srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
			dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]

			for row := range numRows {
				idx := int(idxTensor.data[idxBase+row*idxStride[0]])
				if idx < 0 || idx >= dstShape4D[1] {
					panic("SetRows index out of range")
				}

				srcOffset := srcBase + row*srcStride[1]
				dstOffset := dstBase + idx*dstStride[1]

				copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
			}
		}
	}

	return dst
}

func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	copy(t2.(*testTensor).data, t.data)
	return nil

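As a plain-Go illustration of the row/index semantics in the test implementation above, here is a small self-contained example for the simple 2-D case; the shapes and values are made up for illustration.

package main

import "fmt"

func main() {
	// dst has 4 rows of 3 columns; src supplies 2 replacement rows; idxs names
	// the destination row for each source row, as in SetRows(dst, src, idxs).
	dst := [][]float32{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}}
	src := [][]float32{{1, 2, 3}, {4, 5, 6}}
	idxs := []int32{1, 3}

	// Scatter: source row i overwrites destination row idxs[i].
	for row, idx := range idxs {
		copy(dst[idx], src[row])
	}

	fmt.Println(dst) // [[0 0 0] [1 2 3] [0 0 0] [4 5 6]]
}
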
@@ -194,6 +194,7 @@ type Tensor interface {
	Repeat(ctx Context, dim, n int) Tensor
	Concat(ctx Context, t2 Tensor, dim int) Tensor
	Rows(ctx Context, t2 Tensor) Tensor
	SetRows(ctx Context, src Tensor, idxs Tensor) Tensor
	Copy(ctx Context, t2 Tensor) Tensor
	Duplicate(ctx Context) Tensor

@@ -1338,6 +1338,13 @@ func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	}
}

func (t *Tensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_set_rows(ctx.(*Context).ctx, t.t, src.(*Tensor).t, idxs.(*Tensor).t),
	}
}

func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,