mirror of
https://github.com/ollama/ollama.git
synced 2026-03-09 03:12:11 -05:00
753 lines
18 KiB
Go
753 lines
18 KiB
Go
package kvcache
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"slices"
|
|
|
|
"github.com/ollama/ollama/ml"
|
|
"github.com/ollama/ollama/model/input"
|
|
)
|
|
|
|
// Defaults for recurrent checkpointing; NewRecurrentCache copies them into
// every new cache.
const (
	// DefaultCheckpointCount is the default checkpointCount. Init sizes the
	// checkpoint contexts as checkpointCount*maxSequences (with a floor of 8).
	DefaultCheckpointCount = 24

	// DefaultCheckpointMinPos is the default checkpointMinPos (presumably the
	// earliest position eligible for a checkpoint — TODO confirm).
	DefaultCheckpointMinPos int32 = 16

	// DefaultCheckpointInterval is the default checkpointInterval (presumably
	// the position spacing between checkpoints — TODO confirm).
	DefaultCheckpointInterval int32 = 1664
)

// ErrInvalidRecurrentShape is returned by RecurrentState and RecurrentState4D
// when the requested dimensions are empty, non-positive, or do not multiply
// out to the configured recurrent state size (the last case wraps it with
// details via %w).
var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape")
|
|
|
|
// RecurrentConfig configures a shared hybrid recurrent cache.
type RecurrentConfig struct {
	// Shift is handed to the embedded causal KV cache via NewCausalCache.
	Shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)

	// ConvDim is the first dimension of the per-layer, per-sequence conv
	// state, whose shape is [ConvDim, ConvChannels].
	ConvDim int

	// ConvChannels is the second dimension of the conv state.
	ConvChannels int

	// RecurrentStateSize is the number of elements in the per-layer,
	// per-sequence recurrent state.
	RecurrentStateSize int

	// CheckpointLogPrefix is stored as the cache's logPrefix (presumably
	// prepended to checkpoint log messages — TODO confirm).
	CheckpointLogPrefix string
}
|
|
|
|
// Compile-time checks that *Recurrent implements Cache and CheckpointCache.
var (
	_ Cache           = (*Recurrent)(nil)
	_ CheckpointCache = (*Recurrent)(nil)
)
|
|
|
|
// Recurrent is a hybrid cache that stores:
//   - a standard causal KV cache
//   - per-sequence conv state for recurrent operators
//   - per-sequence recurrent state for recurrent operators
//
// Conv state shape (per layer, per sequence): [convDim, convChannels]
// Recurrent state shape (per layer, per sequence): [recurrentStateSize]
//
// Sequences are mapped to state slots through slotForSeq. CopyPrefix lets
// several sequences share one slot (counted in refCount); a shared slot is
// copied to a private one before it is written (EnsureWritable) or before its
// checkpoints are modified (Remove).
type Recurrent struct {
	kv *Causal

	backend      ml.Backend
	dtype        ml.DType
	maxSequences int

	// Conv state dimensions
	convDim      int
	convChannels int

	// Recurrent state dimensions
	recurrentStateSize int

	// logPrefix comes from RecurrentConfig.CheckpointLogPrefix.
	logPrefix string

	// slot mapping for recurrent state (copy-on-write)
	slotForSeq  map[int]int // sequence -> slot index
	refCount    []int       // number of sequences referencing each slot
	freeSlots   []int       // stack of unused slots; allocSlot pops from the end
	seqCounts   map[int]int // scratch: tokens per sequence, reused by StartForward
	slotScratch [1]int32    // scratch backing for single-slot index tensors (slotsInput)

	// per-layer conv state buffers (allocated lazily)
	convCtxs   map[int]ml.Context
	convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]

	// per-layer recurrent state buffers (allocated lazily)
	recurrentCtxs   map[int]ml.Context
	recurrentStates map[int]ml.Tensor // [recurrentStateSize, maxSlots]

	// recurrent checkpoints (per slot)
	// checkpointCtxSize is checkpointCount*maxSequences with a floor of 8 (Init).
	// reserveCheckpoints mirrors the reserve flag of the current StartForward.
	// The remaining fields are maintained by the checkpoint helpers
	// (planCheckpoints, copyCheckpoints, shiftCheckpoints, clearCheckpoints,
	// applyCheckpointRestore, ...) — TODO document their exact contents.
	checkpointCount     int
	checkpointMinPos    int32
	checkpointInterval  int32
	checkpointCtxSize   int
	checkpoints         map[int]*slotCheckpointStore
	pendingRestore      map[int]checkpointRestore
	curCheckpointPos    []int32
	curCheckpointSlots  map[int]int
	reserveCheckpoints  bool
	checkpointConvCtxs  map[int]ml.Context
	checkpointRecurCtxs map[int]ml.Context
	checkpointReserved  map[int]struct{}

	// current forward batch (derived in StartForward)
	curSeqs       []int     // distinct sequences, in order of first appearance
	curSlots      []int     // slot for each entry of curSeqs
	curSlotsInput ml.Tensor // curSlots as an index tensor (nil when empty)
	curSeqTokens  int       // tokens contributed by each sequence

	// track if EnsureWritable has been called for this forward pass
	writableEnsured bool
	writableError   error
}
|
|
|
|
func NewRecurrentCache(config RecurrentConfig) *Recurrent {
|
|
return &Recurrent{
|
|
kv: NewCausalCache(config.Shift),
|
|
convDim: config.ConvDim,
|
|
convChannels: config.ConvChannels,
|
|
recurrentStateSize: config.RecurrentStateSize,
|
|
logPrefix: config.CheckpointLogPrefix,
|
|
slotForSeq: make(map[int]int),
|
|
seqCounts: make(map[int]int),
|
|
convCtxs: make(map[int]ml.Context),
|
|
convStates: make(map[int]ml.Tensor),
|
|
recurrentCtxs: make(map[int]ml.Context),
|
|
recurrentStates: make(map[int]ml.Tensor),
|
|
checkpointCount: DefaultCheckpointCount,
|
|
checkpointMinPos: DefaultCheckpointMinPos,
|
|
checkpointInterval: DefaultCheckpointInterval,
|
|
checkpoints: make(map[int]*slotCheckpointStore),
|
|
pendingRestore: make(map[int]checkpointRestore),
|
|
curCheckpointSlots: make(map[int]int),
|
|
checkpointConvCtxs: make(map[int]ml.Context),
|
|
checkpointRecurCtxs: make(map[int]ml.Context),
|
|
checkpointReserved: make(map[int]struct{}),
|
|
}
|
|
}
|
|
|
|
func (c *Recurrent) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
|
c.backend = backend
|
|
c.dtype = dtype
|
|
c.maxSequences = maxSequences
|
|
c.checkpoints = make(map[int]*slotCheckpointStore)
|
|
c.pendingRestore = make(map[int]checkpointRestore)
|
|
c.curCheckpointPos = c.curCheckpointPos[:0]
|
|
c.curCheckpointSlots = make(map[int]int)
|
|
c.checkpointReserved = make(map[int]struct{})
|
|
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
|
|
if c.checkpointCtxSize < 8 {
|
|
c.checkpointCtxSize = 8
|
|
}
|
|
|
|
// initialize slot allocator
|
|
c.refCount = make([]int, maxSequences)
|
|
c.freeSlots = c.freeSlots[:0]
|
|
for i := maxSequences - 1; i >= 0; i-- {
|
|
c.freeSlots = append(c.freeSlots, i)
|
|
}
|
|
|
|
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
|
}
|
|
|
|
func (c *Recurrent) Close() {
|
|
for _, ctx := range c.convCtxs {
|
|
ctx.Close()
|
|
}
|
|
for _, ctx := range c.recurrentCtxs {
|
|
ctx.Close()
|
|
}
|
|
for _, ctx := range c.checkpointConvCtxs {
|
|
ctx.Close()
|
|
}
|
|
for _, ctx := range c.checkpointRecurCtxs {
|
|
ctx.Close()
|
|
}
|
|
c.kv.Close()
|
|
}
|
|
|
|
func (c *Recurrent) SetConfig(config ml.CacheConfig) {
|
|
c.kv.SetConfig(config)
|
|
}
|
|
|
|
func (c *Recurrent) SetLayer(layer int) {
|
|
c.kv.SetLayer(layer)
|
|
}
|
|
|
|
func (c *Recurrent) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
|
return c.kv.Get(ctx)
|
|
}
|
|
|
|
func (c *Recurrent) Put(ctx ml.Context, key, value ml.Tensor) {
|
|
c.kv.Put(ctx, key, value)
|
|
}
|
|
|
|
func (c *Recurrent) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
|
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
|
|
return err
|
|
}
|
|
|
|
nTokens := len(batch.Sequences)
|
|
if nTokens == 0 {
|
|
c.curSeqs = c.curSeqs[:0]
|
|
c.curSlots = c.curSlots[:0]
|
|
c.curSlotsInput = nil
|
|
c.curSeqTokens = 0
|
|
c.reserveCheckpoints = false
|
|
c.writableEnsured = false
|
|
c.writableError = nil
|
|
return nil
|
|
}
|
|
|
|
// Fast path for single-sequence batches (common during decode and prefill).
|
|
firstSeq := batch.Sequences[0]
|
|
singleSeq := true
|
|
for _, s := range batch.Sequences[1:] {
|
|
if s != firstSeq {
|
|
singleSeq = false
|
|
break
|
|
}
|
|
}
|
|
if singleSeq {
|
|
return c.startForwardSingleSeq(ctx, firstSeq, nTokens, batch, reserve)
|
|
}
|
|
|
|
// Derive equal-length sequence layout for recurrent layers.
|
|
seqCounts := c.seqCounts
|
|
for s := range seqCounts {
|
|
delete(seqCounts, s)
|
|
}
|
|
|
|
c.curSeqs = c.curSeqs[:0]
|
|
for _, s := range batch.Sequences {
|
|
if seqCounts[s] == 0 {
|
|
c.curSeqs = append(c.curSeqs, s)
|
|
}
|
|
seqCounts[s]++
|
|
}
|
|
|
|
nSeqs := len(c.curSeqs)
|
|
want := nTokens / nSeqs
|
|
for _, s := range c.curSeqs {
|
|
if seqCounts[s] != want {
|
|
return ErrNotSupported
|
|
}
|
|
}
|
|
|
|
c.curSeqTokens = want
|
|
|
|
if reserve {
|
|
c.curSlots = c.curSlots[:0]
|
|
for i := range nSeqs {
|
|
c.curSlots = append(c.curSlots, i)
|
|
}
|
|
c.finalizeStartForward(ctx, batch, true)
|
|
return nil
|
|
}
|
|
|
|
// Ensure slots exist for sequences in this batch.
|
|
c.curSlots = c.curSlots[:0]
|
|
var newSlots []int
|
|
for _, s := range c.curSeqs {
|
|
slot, ok := c.slotForSeq[s]
|
|
if !ok {
|
|
var err error
|
|
slot, err = c.allocSlot()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.slotForSeq[s] = slot
|
|
c.refCount[slot] = 1
|
|
newSlots = append(newSlots, slot)
|
|
}
|
|
c.curSlots = append(c.curSlots, slot)
|
|
}
|
|
|
|
if len(newSlots) > 0 {
|
|
c.zeroSlots(ctx, newSlots)
|
|
}
|
|
|
|
c.finalizeStartForward(ctx, batch, false)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Recurrent) startForwardSingleSeq(ctx ml.Context, seq, seqTokens int, batch input.Batch, reserve bool) error {
|
|
c.curSeqs = append(c.curSeqs[:0], seq)
|
|
c.curSeqTokens = seqTokens
|
|
|
|
if reserve {
|
|
c.curSlots = append(c.curSlots[:0], 0)
|
|
c.finalizeStartForward(ctx, batch, true)
|
|
return nil
|
|
}
|
|
|
|
slot, ok := c.slotForSeq[seq]
|
|
if !ok {
|
|
var err error
|
|
slot, err = c.allocSlot()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
c.slotForSeq[seq] = slot
|
|
c.refCount[slot] = 1
|
|
slotList := [1]int{slot}
|
|
c.zeroSlots(ctx, slotList[:])
|
|
}
|
|
|
|
c.curSlots = append(c.curSlots[:0], slot)
|
|
c.finalizeStartForward(ctx, batch, false)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Recurrent) finalizeStartForward(ctx ml.Context, batch input.Batch, reserve bool) {
|
|
c.setCurSlotsInput(ctx)
|
|
c.writableEnsured = false
|
|
c.writableError = nil
|
|
c.reserveCheckpoints = reserve
|
|
c.planCheckpoints(batch)
|
|
}
|
|
|
|
func (c *Recurrent) setCurSlotsInput(ctx ml.Context) {
|
|
c.curSlotsInput = c.slotsInput(ctx, c.curSlots)
|
|
}
|
|
|
|
func (c *Recurrent) slotsInput(ctx ml.Context, slots []int) ml.Tensor {
|
|
switch len(slots) {
|
|
case 0:
|
|
return nil
|
|
case 1:
|
|
c.slotScratch[0] = int32(slots[0])
|
|
return ctx.Input().FromInts(c.slotScratch[:], 1)
|
|
default:
|
|
slotIndices := make([]int32, len(slots))
|
|
for i, v := range slots {
|
|
slotIndices[i] = int32(v)
|
|
}
|
|
return ctx.Input().FromInts(slotIndices, len(slotIndices))
|
|
}
|
|
}
|
|
|
|
func (c *Recurrent) allocSlot() (int, error) {
|
|
if len(c.freeSlots) == 0 {
|
|
return 0, ErrKvCacheFull
|
|
}
|
|
slot := c.freeSlots[len(c.freeSlots)-1]
|
|
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
|
|
return slot, nil
|
|
}
|
|
|
|
func (c *Recurrent) freeSlot(slot int) {
|
|
if slot >= 0 && slot < c.maxSequences {
|
|
c.freeSlots = append(c.freeSlots, slot)
|
|
}
|
|
}
|
|
|
|
// zeroSlots zeros recurrent state for the given slots across all cached layers.
|
|
func (c *Recurrent) zeroSlots(ctx ml.Context, slots []int) {
|
|
if len(slots) == 0 {
|
|
return
|
|
}
|
|
|
|
inputCtx := ctx.Input()
|
|
slotsTensor := c.slotsInput(ctx, slots)
|
|
|
|
if len(c.convStates) > 0 {
|
|
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
|
|
for _, buf := range c.convStates {
|
|
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
|
}
|
|
}
|
|
|
|
if len(c.recurrentStates) > 0 {
|
|
zeros := inputCtx.Zeros(ml.DTypeF32, c.recurrentStateSize, len(slots))
|
|
for _, buf := range c.recurrentStates {
|
|
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
|
}
|
|
}
|
|
}
|
|
|
|
// EnsureWritable ensures sequences have private slots (copy-on-write).
|
|
func (c *Recurrent) EnsureWritable(ctx ml.Context) error {
|
|
for i, seq := range c.curSeqs {
|
|
slot, ok := c.slotForSeq[seq]
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
if slot < 0 || slot >= len(c.refCount) {
|
|
continue
|
|
}
|
|
|
|
if c.refCount[slot] <= 1 {
|
|
continue
|
|
}
|
|
|
|
newSlot, err := c.allocSlot()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.refCount[slot]--
|
|
c.refCount[newSlot] = 1
|
|
c.slotForSeq[seq] = newSlot
|
|
c.curSlots[i] = newSlot
|
|
|
|
c.copyRecurrentState(ctx, slot, newSlot)
|
|
c.copyCheckpoints(ctx, slot, newSlot)
|
|
}
|
|
|
|
c.setCurSlotsInput(ctx)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Recurrent) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
|
|
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
|
|
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
|
|
|
|
for _, buf := range c.convStates {
|
|
rows := buf.Rows(ctx, src)
|
|
if rows.DType() != ml.DTypeF32 {
|
|
rows = rows.Cast(ctx, ml.DTypeF32)
|
|
}
|
|
ctx.Forward(buf.SetRows(ctx, rows, dst))
|
|
}
|
|
|
|
for _, buf := range c.recurrentStates {
|
|
rows := buf.Rows(ctx, src)
|
|
if rows.DType() != ml.DTypeF32 {
|
|
rows = rows.Cast(ctx, ml.DTypeF32)
|
|
}
|
|
ctx.Forward(buf.SetRows(ctx, rows, dst))
|
|
}
|
|
}
|
|
|
|
func (c *Recurrent) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
|
|
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
|
|
|
|
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
|
|
if c.validSlot(dstSlot) {
|
|
c.refCount[dstSlot]--
|
|
if c.refCount[dstSlot] <= 0 {
|
|
c.refCount[dstSlot] = 0
|
|
c.freeSlot(dstSlot)
|
|
}
|
|
}
|
|
delete(c.slotForSeq, dstSeq)
|
|
}
|
|
|
|
srcSlot, ok := c.slotForSeq[srcSeq]
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
if c.validSlot(srcSlot) {
|
|
c.slotForSeq[dstSeq] = srcSlot
|
|
c.refCount[srcSlot]++
|
|
}
|
|
}
|
|
|
|
func (c *Recurrent) CanResume(seq int, pos int32) bool {
|
|
if !c.kv.CanResume(seq, pos) {
|
|
return false
|
|
}
|
|
if pos == 0 {
|
|
return true
|
|
}
|
|
return c.hasCheckpoint(seq, pos)
|
|
}
|
|
|
|
func (c *Recurrent) Remove(seq int, beginIndex, endIndex int32) error {
|
|
if beginIndex > 0 && endIndex != math.MaxInt32 {
|
|
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
|
return err
|
|
}
|
|
delete(c.pendingRestore, seq)
|
|
|
|
slot, ok := c.slotForSeq[seq]
|
|
if !ok || !c.validSlot(slot) {
|
|
return nil
|
|
}
|
|
|
|
// Detach shared recurrent state/checkpoints before mutating checkpoint positions.
|
|
if c.refCount[slot] > 1 {
|
|
newSlot, err := c.allocSlot()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ctx := c.backend.NewContext()
|
|
c.copyRecurrentState(ctx, slot, newSlot)
|
|
c.copyCheckpoints(ctx, slot, newSlot)
|
|
if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
|
|
ctx.Compute()
|
|
}
|
|
ctx.Close()
|
|
|
|
c.refCount[slot]--
|
|
c.refCount[newSlot] = 1
|
|
c.slotForSeq[seq] = newSlot
|
|
slot = newSlot
|
|
}
|
|
|
|
c.shiftCheckpoints(slot, beginIndex, endIndex)
|
|
return nil
|
|
}
|
|
|
|
if beginIndex > 0 {
|
|
restore, ok := c.pendingRestore[seq]
|
|
if !ok || restore.pos+1 != beginIndex {
|
|
return ErrNotSupported
|
|
}
|
|
if !c.restoreComplete(restore) {
|
|
return ErrNotSupported
|
|
}
|
|
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
|
|
newSlot, err := c.allocSlot()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ctx := c.backend.NewContext()
|
|
c.copyRecurrentState(ctx, slot, newSlot)
|
|
c.copyCheckpoints(ctx, slot, newSlot)
|
|
if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
|
|
ctx.Compute()
|
|
}
|
|
ctx.Close()
|
|
|
|
c.refCount[slot]--
|
|
c.refCount[newSlot] = 1
|
|
c.slotForSeq[seq] = newSlot
|
|
|
|
restore.slot = newSlot
|
|
c.pendingRestore[seq] = restore
|
|
}
|
|
}
|
|
|
|
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
|
return err
|
|
}
|
|
|
|
if beginIndex > 0 {
|
|
restore := c.pendingRestore[seq]
|
|
delete(c.pendingRestore, seq)
|
|
return c.applyCheckpointRestore(restore)
|
|
}
|
|
|
|
slot, ok := c.slotForSeq[seq]
|
|
delete(c.pendingRestore, seq)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
if !c.validSlot(slot) {
|
|
delete(c.slotForSeq, seq)
|
|
return nil
|
|
}
|
|
|
|
c.refCount[slot]--
|
|
if c.refCount[slot] <= 0 {
|
|
c.refCount[slot] = 0
|
|
c.clearCheckpoints(slot)
|
|
c.freeSlot(slot)
|
|
}
|
|
delete(c.slotForSeq, seq)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Recurrent) validSlot(slot int) bool {
|
|
return slot >= 0 && slot < len(c.refCount)
|
|
}
|
|
|
|
func (c *Recurrent) SlotsTensor() ml.Tensor {
|
|
return c.curSlotsInput
|
|
}
|
|
|
|
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
|
|
func (c *Recurrent) contiguousSlots() (int, bool) {
|
|
if len(c.curSlots) == 0 {
|
|
return 0, false
|
|
}
|
|
start := c.curSlots[0]
|
|
for i, s := range c.curSlots {
|
|
if s != start+i {
|
|
return 0, false
|
|
}
|
|
}
|
|
return start, true
|
|
}
|
|
|
|
func (c *Recurrent) SeqTokens() int {
|
|
return c.curSeqTokens
|
|
}
|
|
|
|
func (c *Recurrent) NumSeqs() int {
|
|
return len(c.curSeqs)
|
|
}
|
|
|
|
func (c *Recurrent) convBuffer(layer int) ml.Tensor {
|
|
if buf, ok := c.convStates[layer]; ok {
|
|
return buf
|
|
}
|
|
|
|
if _, ok := c.convCtxs[layer]; !ok {
|
|
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
|
}
|
|
|
|
buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
|
|
c.convStates[layer] = buf
|
|
return buf
|
|
}
|
|
|
|
func (c *Recurrent) recurrentBuffer(layer int) ml.Tensor {
|
|
if buf, ok := c.recurrentStates[layer]; ok {
|
|
return buf
|
|
}
|
|
|
|
if _, ok := c.recurrentCtxs[layer]; !ok {
|
|
c.recurrentCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
|
}
|
|
|
|
buf := c.recurrentCtxs[layer].Zeros(ml.DTypeF32, c.recurrentStateSize, c.maxSequences)
|
|
c.recurrentStates[layer] = buf
|
|
return buf
|
|
}
|
|
|
|
func (c *Recurrent) ensureWritable(ctx ml.Context) error {
|
|
c.ensureWritableOnce(ctx)
|
|
return c.writableError
|
|
}
|
|
|
|
func (c *Recurrent) currentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int) ml.Tensor {
|
|
if start, ok := c.contiguousSlots(); ok {
|
|
offset := start * buf.Stride(1)
|
|
return buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
|
|
}
|
|
|
|
return buf.Rows(ctx, c.SlotsTensor())
|
|
}
|
|
|
|
func (c *Recurrent) writeCurrentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int, src ml.Tensor) {
|
|
if start, ok := c.contiguousSlots(); ok {
|
|
offset := start * buf.Stride(1)
|
|
view := buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
|
|
ctx.Forward(src.Copy(ctx, view))
|
|
return
|
|
}
|
|
|
|
ctx.Forward(buf.SetRows(ctx, src, c.SlotsTensor()))
|
|
}
|
|
|
|
func (c *Recurrent) ensureWritableOnce(ctx ml.Context) {
|
|
if !c.writableEnsured {
|
|
needsWritable := false
|
|
for _, seq := range c.curSeqs {
|
|
slot, ok := c.slotForSeq[seq]
|
|
if !ok {
|
|
continue
|
|
}
|
|
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
|
|
needsWritable = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if needsWritable {
|
|
if err := c.EnsureWritable(ctx); err != nil {
|
|
c.writableError = err
|
|
}
|
|
}
|
|
c.writableEnsured = true
|
|
}
|
|
}
|
|
|
|
// ConvState returns conv state for current batch sequences as [convDim, convChannels, nSeqs].
|
|
func (c *Recurrent) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
|
|
if err := c.ensureWritable(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
buf := c.convBuffer(layer)
|
|
cur := c.currentSlotRows(ctx, buf, c.convDim*c.convChannels)
|
|
return cur.Reshape(ctx, c.convDim, c.convChannels, c.NumSeqs()), nil
|
|
}
|
|
|
|
// UpdateConvState writes new conv state for current batch sequences.
|
|
func (c *Recurrent) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
|
|
buf := c.convBuffer(layer)
|
|
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.NumSeqs())
|
|
srcF32 := src
|
|
if src.DType() != ml.DTypeF32 {
|
|
srcF32 = src.Cast(ctx, ml.DTypeF32)
|
|
}
|
|
c.writeCurrentSlotRows(ctx, buf, c.convDim*c.convChannels, srcF32)
|
|
|
|
c.captureConvCheckpoint(ctx, layer, srcF32)
|
|
}
|
|
|
|
// RecurrentState returns recurrent state for current batch sequences with shape [dims..., nSeqs].
|
|
func (c *Recurrent) RecurrentState(ctx ml.Context, layer int, dims ...int) (ml.Tensor, error) {
|
|
if err := c.ensureWritable(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
if len(dims) == 0 {
|
|
return nil, ErrInvalidRecurrentShape
|
|
}
|
|
|
|
size := 1
|
|
for _, d := range dims {
|
|
if d <= 0 {
|
|
return nil, ErrInvalidRecurrentShape
|
|
}
|
|
size *= d
|
|
}
|
|
if size != c.recurrentStateSize {
|
|
return nil, fmt.Errorf("%w: got %v (size %d), want size %d", ErrInvalidRecurrentShape, dims, size, c.recurrentStateSize)
|
|
}
|
|
|
|
buf := c.recurrentBuffer(layer)
|
|
cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
|
|
shape := make([]int, 0, len(dims)+1)
|
|
shape = append(shape, dims...)
|
|
shape = append(shape, c.NumSeqs())
|
|
return cur.Reshape(ctx, shape...), nil
|
|
}
|
|
|
|
// RecurrentState4D returns recurrent state as [dim0, dim1, dim2, nSeqs].
|
|
func (c *Recurrent) RecurrentState4D(ctx ml.Context, layer int, dim0, dim1, dim2 int) (ml.Tensor, error) {
|
|
if err := c.ensureWritable(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
if dim0 <= 0 || dim1 <= 0 || dim2 <= 0 {
|
|
return nil, ErrInvalidRecurrentShape
|
|
}
|
|
|
|
size := dim0 * dim1 * dim2
|
|
if size != c.recurrentStateSize {
|
|
return nil, fmt.Errorf("%w: got [%d %d %d] (size %d), want size %d", ErrInvalidRecurrentShape, dim0, dim1, dim2, size, c.recurrentStateSize)
|
|
}
|
|
|
|
buf := c.recurrentBuffer(layer)
|
|
cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
|
|
return cur.Reshape(ctx, dim0, dim1, dim2, c.NumSeqs()), nil
|
|
}
|
|
|
|
// UpdateRecurrentState writes new recurrent state for current batch sequences.
|
|
func (c *Recurrent) UpdateRecurrentState(ctx ml.Context, layer int, newState ml.Tensor) {
|
|
buf := c.recurrentBuffer(layer)
|
|
src := newState.Reshape(ctx, c.recurrentStateSize, c.NumSeqs())
|
|
srcF32 := src
|
|
if src.DType() != ml.DTypeF32 {
|
|
srcF32 = src.Cast(ctx, ml.DTypeF32)
|
|
}
|
|
c.writeCurrentSlotRows(ctx, buf, c.recurrentStateSize, srcF32)
|
|
|
|
c.captureRecurrentCheckpoint(ctx, layer, srcF32)
|
|
}
|
|
|
|
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
|
|
func (c *Recurrent) IsSupportedForBatch() bool {
|
|
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
|
|
}
|
|
|
|
// Seqs returns the ordered unique sequences for the current forward pass.
|
|
func (c *Recurrent) Seqs() []int {
|
|
return slices.Clone(c.curSeqs)
|
|
}
|