mirror of
https://github.com/ollama/ollama.git
synced 2026-03-09 07:16:38 -05:00
289 lines
7.0 KiB
Go
289 lines
7.0 KiB
Go
package kvcache
|
|
|
|
import (
|
|
"errors"
|
|
"math"
|
|
"slices"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/ml"
|
|
)
|
|
|
|
func newTestCache() *Recurrent {
|
|
return NewRecurrentCache(RecurrentConfig{ConvDim: 1, ConvChannels: 2, RecurrentStateSize: 2})
|
|
}
|
|
|
|
func TestSlotCheckpointStoreBestIndex(t *testing.T) {
|
|
store := newSlotCheckpointStore(2)
|
|
store.record(10)
|
|
store.record(20)
|
|
|
|
_, pos, ok := store.bestIndex(15)
|
|
if !ok || pos != 10 {
|
|
t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
|
|
}
|
|
|
|
store.record(30) // overwrite oldest (10)
|
|
|
|
if _, _, ok := store.bestIndex(15); ok {
|
|
t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
|
|
}
|
|
|
|
_, pos, ok = store.bestIndex(40)
|
|
if !ok || pos != 30 {
|
|
t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
|
|
}
|
|
}
|
|
|
|
func TestCachePrepareRestore(t *testing.T) {
|
|
cache := newTestCache()
|
|
cache.checkpointCount = 3
|
|
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
|
cache.pendingRestore = make(map[int]checkpointRestore)
|
|
|
|
cache.slotForSeq[1] = 0
|
|
store := cache.checkpointStore(0)
|
|
store.record(5)
|
|
store.record(9)
|
|
store.record(15)
|
|
|
|
restorePos, ok := cache.PrepareRestore(1, 12)
|
|
if !ok {
|
|
t.Fatalf("expected restore ok")
|
|
}
|
|
if restorePos != 10 {
|
|
t.Fatalf("expected restorePos 10, got %d", restorePos)
|
|
}
|
|
rest, ok := cache.pendingRestore[1]
|
|
if !ok {
|
|
t.Fatalf("expected pending restore entry")
|
|
}
|
|
if rest.pos != 9 {
|
|
t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
|
|
}
|
|
}
|
|
|
|
func TestSlotCheckpointStorePruneAfter(t *testing.T) {
|
|
store := newSlotCheckpointStore(3)
|
|
store.record(10)
|
|
store.record(20)
|
|
store.record(30)
|
|
|
|
store.pruneAfter(20)
|
|
|
|
if store.lastPos != 20 {
|
|
t.Fatalf("expected lastPos 20, got %d", store.lastPos)
|
|
}
|
|
|
|
_, pos, ok := store.bestIndex(25)
|
|
if !ok || pos != 20 {
|
|
t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
|
|
}
|
|
|
|
_, pos, ok = store.bestIndex(35)
|
|
if !ok || pos != 20 {
|
|
t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
|
|
}
|
|
}
|
|
|
|
func TestCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
|
|
cache := newTestCache()
|
|
cache.checkpointCount = 3
|
|
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
|
cache.pendingRestore = make(map[int]checkpointRestore)
|
|
|
|
cache.slotForSeq[1] = 0
|
|
cache.refCount = []int{1}
|
|
cache.freeSlots = nil
|
|
|
|
// Simulate layer 0 requires both conv and recurrent checkpoints.
|
|
cache.convStates[0] = nil
|
|
cache.recurrentStates[0] = nil
|
|
|
|
store := cache.checkpointStore(0)
|
|
idx := store.record(9)
|
|
entry := &store.entries[idx]
|
|
entry.conv = map[int]ml.Tensor{0: nil}
|
|
// entry.recurrent intentionally missing
|
|
|
|
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
|
|
|
err := cache.Remove(1, 10, math.MaxInt32)
|
|
if !errors.Is(err, ErrNotSupported) {
|
|
t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
|
|
cache := newTestCache()
|
|
cache.checkpointCount = 3
|
|
cache.checkpoints = make(map[int]*slotCheckpointStore)
|
|
cache.pendingRestore = make(map[int]checkpointRestore)
|
|
|
|
cache.slotForSeq[1] = 0
|
|
cache.refCount = []int{1}
|
|
cache.freeSlots = nil
|
|
|
|
store := cache.checkpointStore(0)
|
|
idx := store.record(9)
|
|
|
|
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
|
|
|
|
restore := cache.pendingRestore[1]
|
|
if !cache.restoreComplete(restore) {
|
|
t.Fatalf("expected restoreComplete to return true for complete checkpoint")
|
|
}
|
|
}
|
|
|
|
func TestCacheRecurrentStateShapeValidation(t *testing.T) {
|
|
cache := newTestCache()
|
|
_, err := cache.RecurrentState(nil, 0, 3)
|
|
if !errors.Is(err, ErrInvalidRecurrentShape) {
|
|
t.Fatalf("expected ErrInvalidRecurrentShape, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSlotCheckpointStoreShiftRange(t *testing.T) {
|
|
store := newSlotCheckpointStore(5)
|
|
store.record(1)
|
|
store.record(4)
|
|
store.record(7)
|
|
store.record(10)
|
|
|
|
store.shiftRange(2, 6)
|
|
|
|
var positions []int32
|
|
for i := range store.entries {
|
|
if store.entries[i].pos >= 0 {
|
|
positions = append(positions, store.entries[i].pos)
|
|
}
|
|
}
|
|
slices.Sort(positions)
|
|
|
|
want := []int32{1, 3, 6}
|
|
if !slices.Equal(positions, want) {
|
|
t.Fatalf("unexpected shifted positions: got=%v want=%v", positions, want)
|
|
}
|
|
if store.lastPos != 6 {
|
|
t.Fatalf("expected lastPos 6, got %d", store.lastPos)
|
|
}
|
|
}
|
|
|
|
func TestCacheRemoveMiddleShiftsCheckpoints(t *testing.T) {
|
|
cache := newTestCache()
|
|
cache.slotForSeq[1] = 0
|
|
cache.refCount = []int{1}
|
|
cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: 0, pos: 1}
|
|
|
|
store := cache.checkpointStore(0)
|
|
store.record(1)
|
|
store.record(4)
|
|
store.record(7)
|
|
store.record(10)
|
|
|
|
if err := cache.Remove(1, 2, 6); err != nil {
|
|
t.Fatalf("expected middle remove to succeed, got %v", err)
|
|
}
|
|
|
|
if _, ok := cache.pendingRestore[1]; ok {
|
|
t.Fatalf("expected pending restore to be cleared after middle remove")
|
|
}
|
|
|
|
var positions []int32
|
|
for i := range store.entries {
|
|
if store.entries[i].pos >= 0 {
|
|
positions = append(positions, store.entries[i].pos)
|
|
}
|
|
}
|
|
slices.Sort(positions)
|
|
|
|
want := []int32{1, 3, 6}
|
|
if !slices.Equal(positions, want) {
|
|
t.Fatalf("unexpected checkpoint positions after remove: got=%v want=%v", positions, want)
|
|
}
|
|
}
|
|
|
|
func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
|
|
store := newSlotCheckpointStore(3)
|
|
|
|
store.record(10)
|
|
store.record(20)
|
|
store.record(30)
|
|
|
|
store.entries[0].conv = make(map[int]ml.Tensor)
|
|
store.entries[0].conv[0] = nil
|
|
store.entries[0].recurrent = make(map[int]ml.Tensor)
|
|
store.entries[0].recurrent[0] = nil
|
|
|
|
store.record(40)
|
|
|
|
if store.entries[0].conv == nil {
|
|
t.Fatalf("expected conv map to be preserved on reuse")
|
|
}
|
|
if store.entries[0].recurrent == nil {
|
|
t.Fatalf("expected recurrent map to be preserved on reuse")
|
|
}
|
|
if store.entries[0].pos != 40 {
|
|
t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
|
|
}
|
|
}
|
|
|
|
func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
|
|
store := newSlotCheckpointStore(2)
|
|
|
|
idx1 := store.record(10)
|
|
idx2 := store.record(20)
|
|
|
|
if idx1 != 0 || idx2 != 1 {
|
|
t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
|
|
}
|
|
if store.size != 2 {
|
|
t.Fatalf("expected size 2, got %d", store.size)
|
|
}
|
|
|
|
_, pos1, ok1 := store.bestIndex(15)
|
|
_, pos2, ok2 := store.bestIndex(25)
|
|
|
|
if !ok1 || pos1 != 10 {
|
|
t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
|
|
}
|
|
if !ok2 || pos2 != 20 {
|
|
t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
|
|
}
|
|
}
|
|
|
|
func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
|
|
store := newSlotCheckpointStore(0)
|
|
|
|
idx := store.record(10)
|
|
if idx != -1 {
|
|
t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
|
|
}
|
|
|
|
_, _, ok := store.bestIndex(15)
|
|
if ok {
|
|
t.Fatalf("expected no checkpoint for empty buffer")
|
|
}
|
|
}
|
|
|
|
func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
|
|
store := newSlotCheckpointStore(3)
|
|
store.record(10)
|
|
store.record(20)
|
|
store.record(30)
|
|
|
|
store.pruneAfter(5)
|
|
|
|
if store.size != 0 {
|
|
t.Fatalf("expected size 0 after pruning all, got %d", store.size)
|
|
}
|
|
if store.lastPos != -1 {
|
|
t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
|
|
}
|
|
|
|
_, _, ok := store.bestIndex(100)
|
|
if ok {
|
|
t.Fatalf("expected no checkpoint after pruning all")
|
|
}
|
|
}
|