kvcache: wire turboquant cache into new engine path

This commit is contained in:
dankguy17
2026-03-28 21:59:31 -07:00
parent 8195786db6
commit aafb65a957
5 changed files with 199 additions and 2 deletions

View File

@@ -60,12 +60,15 @@ type TurboQuantWrapper struct {
inner Cache
bitWidth int
backend ml.Backend
layer int
mu sync.Mutex
rotations map[int]*rotationPair
rotCtxs map[int]ml.Context
}
// Compile-time assertion that TurboQuantWrapper implements CheckpointCache.
var _ CheckpointCache = (*TurboQuantWrapper)(nil)
type rotationPair struct {
// forward stores Pi^T in GGML format.
// GGML Mulmat(A, B) computes A^T @ B.
@@ -108,7 +111,10 @@ func (w *TurboQuantWrapper) Close() {
w.inner.Close()
}
func (w *TurboQuantWrapper) SetLayer(layer int) { w.inner.SetLayer(layer) }
// SetLayer records the active layer locally (it is later read when creating
// per-layer rotation contexts in getOrCreateRotation) and forwards the call
// to the wrapped cache.
func (w *TurboQuantWrapper) SetLayer(layer int) {
w.layer = layer
w.inner.SetLayer(layer)
}
// SetConfig forwards the backend cache configuration to the wrapped cache.
func (w *TurboQuantWrapper) SetConfig(config ml.CacheConfig) { w.inner.SetConfig(config) }

// CopyPrefix forwards prefix copying between sequences to the wrapped cache.
func (w *TurboQuantWrapper) CopyPrefix(src, dst int, l int32) { w.inner.CopyPrefix(src, dst, l) }

// CanResume reports whether the wrapped cache can resume seq at pos.
func (w *TurboQuantWrapper) CanResume(seq int, pos int32) bool { return w.inner.CanResume(seq, pos) }
@@ -121,6 +127,20 @@ func (w *TurboQuantWrapper) Remove(seq int, beginIndex, endIndex int32) error {
return w.inner.Remove(seq, beginIndex, endIndex)
}
// PrepareRestore reports where decoding can resume for seq when the caller
// wants to continue from targetPos.
//
// A checkpoint-aware inner cache gets the call forwarded verbatim. For a
// plain cache we emulate the non-checkpoint behavior ollamarunner relies on:
// resume exactly at targetPos when the cache can, otherwise signal that a
// full reprocess is required.
func (w *TurboQuantWrapper) PrepareRestore(seq int, targetPos int32) (int32, bool) {
	checkpointer, hasCheckpoints := w.inner.(CheckpointCache)
	if hasCheckpoints {
		return checkpointer.PrepareRestore(seq, targetPos)
	}
	if !w.inner.CanResume(seq, targetPos) {
		return 0, false
	}
	return targetPos, true
}
func (w *TurboQuantWrapper) Put(ctx ml.Context, key, value ml.Tensor) {
headDim := key.Dim(0)
rot := w.getOrCreateRotation(headDim)
@@ -139,6 +159,11 @@ func (w *TurboQuantWrapper) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor
rot := w.getOrCreateRotation(headDim)
if rot != nil {
// Metal does not provide all f32 x q4_0 matmul variants needed for
// rotation. Cast cache output to f32 before applying inverse rotation.
if key.DType() != ml.DTypeF32 {
key = key.Cast(ctx, ml.DTypeF32)
}
key = rot.inverse.Mulmat(ctx, key)
}
@@ -165,7 +190,7 @@ func (w *TurboQuantWrapper) getOrCreateRotation(headDim int) *rotationPair {
piData := turboquant.GenerateRotation(headDim, seed)
piTData := turboquant.GenerateRotationTranspose(headDim, seed)
rotCtx := w.backend.NewContextSize(2)
rotCtx := w.backend.NewContextSize(2).Layer(w.layer)
piTensor := rotCtx.FromFloats(piData, headDim, headDim)
piTTensor := rotCtx.FromFloats(piTData, headDim, headDim)

118
kvcache/turboquant_test.go Normal file
View File

@@ -0,0 +1,118 @@
package kvcache
import (
"testing"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
// turboQuantTestBaseCache is a minimal Cache stub (not checkpoint-capable)
// that records how CanResume was invoked so tests can assert on the
// TurboQuantWrapper fallback path.
type turboQuantTestBaseCache struct {
canResume bool // canned return value for CanResume
canResumeCalls int // number of times CanResume was called
canResumeLastSeq int // seq argument of the most recent CanResume call
canResumeLastPos int32 // pos argument of the most recent CanResume call
}
// The methods below are no-op implementations that satisfy the Cache
// interface; only CanResume (defined separately) records any state.
func (c *turboQuantTestBaseCache) SetLayer(layer int) {}
func (c *turboQuantTestBaseCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return nil, nil, nil
}
func (c *turboQuantTestBaseCache) Put(ctx ml.Context, key, value ml.Tensor) {}
func (c *turboQuantTestBaseCache) SetConfig(config ml.CacheConfig) {}
func (c *turboQuantTestBaseCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
}
func (c *turboQuantTestBaseCache) Close() {}
func (c *turboQuantTestBaseCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
return nil
}
func (c *turboQuantTestBaseCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
// CanResume records the arguments it was invoked with, bumps the call
// counter, and returns the canned canResume value.
func (c *turboQuantTestBaseCache) CanResume(seq int, pos int32) bool {
	c.canResumeLastSeq, c.canResumeLastPos = seq, pos
	c.canResumeCalls++
	return c.canResume
}
// Remove is a no-op; the stub never reports a removal failure.
func (c *turboQuantTestBaseCache) Remove(seq int, beginIndex, endIndex int32) error {
return nil
}
// turboQuantTestCheckpointCache extends the base stub with a recording
// PrepareRestore implementation, making it satisfy CheckpointCache so tests
// can assert on the wrapper's pass-through path.
type turboQuantTestCheckpointCache struct {
turboQuantTestBaseCache
restorePos int32 // canned position returned by PrepareRestore
restoreOK bool // canned ok flag returned by PrepareRestore
prepareCalls int // number of times PrepareRestore was called
prepareLastSeq int // seq argument of the most recent call
prepareLastPos int32 // targetPos argument of the most recent call
}
// PrepareRestore captures its arguments and call count, then returns the
// configured (restorePos, restoreOK) pair.
func (c *turboQuantTestCheckpointCache) PrepareRestore(seq int, targetPos int32) (int32, bool) {
	c.prepareLastSeq, c.prepareLastPos = seq, targetPos
	c.prepareCalls++
	return c.restorePos, c.restoreOK
}
// TestTurboQuantPrepareRestorePassthrough verifies that a checkpoint-capable
// inner cache receives PrepareRestore directly, exactly once, with the
// caller's arguments, and that its answer is returned unchanged.
func TestTurboQuantPrepareRestorePassthrough(t *testing.T) {
	inner := &turboQuantTestCheckpointCache{restorePos: 7, restoreOK: true}
	w := NewTurboQuantWrapper(inner, ml.DTypeTQ4)

	pos, ok := w.PrepareRestore(2, 11)
	if !ok || pos != 7 {
		t.Fatalf("PrepareRestore() = (%d, %v), want (7, true)", pos, ok)
	}
	switch {
	case inner.prepareCalls != 1:
		t.Fatalf("inner PrepareRestore calls = %d, want 1", inner.prepareCalls)
	case inner.prepareLastSeq != 2 || inner.prepareLastPos != 11:
		t.Fatalf("inner PrepareRestore args = (%d, %d), want (2, 11)", inner.prepareLastSeq, inner.prepareLastPos)
	}
}
// TestTurboQuantPrepareRestoreFallsBackToCanResume verifies that wrapping a
// non-checkpoint cache falls back to a single CanResume probe and, on
// success, resumes at the requested position.
func TestTurboQuantPrepareRestoreFallsBackToCanResume(t *testing.T) {
	inner := &turboQuantTestBaseCache{canResume: true}
	w := NewTurboQuantWrapper(inner, ml.DTypeTQ3)

	pos, ok := w.PrepareRestore(3, 9)
	if !ok || pos != 9 {
		t.Fatalf("PrepareRestore() = (%d, %v), want (9, true)", pos, ok)
	}
	switch {
	case inner.canResumeCalls != 1:
		t.Fatalf("inner CanResume calls = %d, want 1", inner.canResumeCalls)
	case inner.canResumeLastSeq != 3 || inner.canResumeLastPos != 9:
		t.Fatalf("inner CanResume args = (%d, %d), want (3, 9)", inner.canResumeLastSeq, inner.canResumeLastPos)
	}
}
// TestTurboQuantPrepareRestoreFallbackFailure verifies that when the
// non-checkpoint inner cache cannot resume, the wrapper reports (0, false)
// after exactly one CanResume probe.
func TestTurboQuantPrepareRestoreFallbackFailure(t *testing.T) {
	inner := &turboQuantTestBaseCache{canResume: false}
	w := NewTurboQuantWrapper(inner, ml.DTypeTQ4)

	pos, ok := w.PrepareRestore(4, 13)
	if ok || pos != 0 {
		t.Fatalf("PrepareRestore() = (%d, %v), want (0, false)", pos, ok)
	}
	if inner.canResumeCalls != 1 {
		t.Fatalf("inner CanResume calls = %d, want 1", inner.canResumeCalls)
	}
}

View File

@@ -98,6 +98,11 @@ func (m *Base) Config() config {
return m.config
}
// SetCache replaces the model cache implementation at runtime.
// NOTE(review): callers (e.g. the runner's NewInputCache) use this to swap in
// a wrapper cache such as TurboQuant after model construction; it performs no
// cleanup of the previous cache.
func (m *Base) SetCache(cache kvcache.Cache) {
m.config.Cache = cache
}
var models = make(map[string]func(fs.Config) (Model, error))
// Register registers a model constructor for the given architecture

View File

@@ -12,6 +12,7 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model/input"
)
func TestParseTags(t *testing.T) {
@@ -56,6 +57,21 @@ type fakeTensor struct {
Name string
}
// fakeCache is a no-op kvcache.Cache implementation for tests that only need
// a distinct cache identity (e.g. asserting SetCache replaced the cache).
type fakeCache struct{}

func (f *fakeCache) Init(ml.Backend, ml.DType, int, int, int) {}
func (f *fakeCache) SetConfig(ml.CacheConfig) {}
func (f *fakeCache) SetLayer(int) {}
func (f *fakeCache) Get(ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return nil, nil, nil
}
func (f *fakeCache) Put(ml.Context, ml.Tensor, ml.Tensor) {}
func (f *fakeCache) CopyPrefix(int, int, int32) {}
func (f *fakeCache) Remove(int, int32, int32) error { return nil }
func (f *fakeCache) CanResume(int, int32) bool { return true }
func (f *fakeCache) StartForward(ml.Context, input.Batch, bool) error { return nil }
func (f *fakeCache) Close() {}
// Stub methods to satisfy ml.Tensor interface
func (f *fakeTensor) Exp(ctx ml.Context) ml.Tensor { return f }
func (f *fakeTensor) Neg(ctx ml.Context) ml.Tensor { return f }
@@ -269,3 +285,20 @@ func TestModelForArch(t *testing.T) {
})
}
}
// TestBaseSetCache confirms that SetCache swaps the cache reported by
// Config() for a freshly supplied implementation.
func TestBaseSetCache(t *testing.T) {
	first := &fakeCache{}
	second := &fakeCache{}

	base := &Base{config: config{Cache: first}}
	base.SetCache(second)

	if base.Config().Cache != second {
		t.Fatal("expected Config().Cache to be replaced")
	}
}

View File

@@ -31,6 +31,10 @@ type InputCache struct {
cache kvcache.Cache
}
// cacheSetter is implemented by models whose cache can be replaced after
// construction; used below to install the TurboQuant wrapper so Forward
// reads from the wrapped cache.
type cacheSetter interface {
SetCache(kvcache.Cache)
}
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
numCtx := kvSize / int32(numSlots)
@@ -47,9 +51,21 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
cache := model.Config().Cache
if cache != nil {
dtype := kvCacheTypeFromStr(kvCacheType)
wrapped := false
if dtype == ml.DTypeTQ3 || dtype == ml.DTypeTQ4 {
cache = kvcache.NewTurboQuantWrapper(cache, dtype)
wrapped = true
}
// Model.Forward reads from the model's cache field, so replace it when we
// wrap the cache for TurboQuant.
if wrapped {
setter, ok := model.(cacheSetter)
if !ok {
return nil, errors.New("model does not support cache replacement for turboquant")
}
setter.SetCache(cache)
}
cache.Init(model.Backend(), dtype, numSlots, int(numCtx), batchSize)