From aafb65a95708f79288c676bb36e91d78657939eb Mon Sep 17 00:00:00 2001
From: dankguy17
Date: Sat, 28 Mar 2026 21:59:31 -0700
Subject: [PATCH] kvcache: wire turboquant cache into new engine path

---
 kvcache/turboquant.go        |  29 ++++++++-
 kvcache/turboquant_test.go   | 117 +++++++++++++++++++++++++++++++++++
 model/model.go               |   5 ++
 model/model_test.go          |  33 ++++++++++
 runner/ollamarunner/cache.go |  16 +++++
 5 files changed, 198 insertions(+), 2 deletions(-)
 create mode 100644 kvcache/turboquant_test.go

diff --git a/kvcache/turboquant.go b/kvcache/turboquant.go
index 88a35ac3d..526b383b8 100644
--- a/kvcache/turboquant.go
+++ b/kvcache/turboquant.go
@@ -60,12 +60,15 @@ type TurboQuantWrapper struct {
 	inner    Cache
 	bitWidth int
 	backend  ml.Backend
+	layer    int
 
 	mu        sync.Mutex
 	rotations map[int]*rotationPair
 	rotCtxs   map[int]ml.Context
 }
 
+var _ CheckpointCache = (*TurboQuantWrapper)(nil)
+
 type rotationPair struct {
 	// forward stores Pi^T in GGML format.
 	// GGML Mulmat(A, B) computes A^T @ B.
@@ -108,7 +111,10 @@ func (w *TurboQuantWrapper) Close() {
 	w.inner.Close()
 }
 
-func (w *TurboQuantWrapper) SetLayer(layer int)                    { w.inner.SetLayer(layer) }
+func (w *TurboQuantWrapper) SetLayer(layer int) {
+	w.layer = layer
+	w.inner.SetLayer(layer)
+}
 func (w *TurboQuantWrapper) SetConfig(config ml.CacheConfig)       { w.inner.SetConfig(config) }
 func (w *TurboQuantWrapper) CopyPrefix(src, dst int, l int32)      { w.inner.CopyPrefix(src, dst, l) }
 func (w *TurboQuantWrapper) CanResume(seq int, pos int32) bool     { return w.inner.CanResume(seq, pos) }
@@ -121,6 +127,20 @@ func (w *TurboQuantWrapper) Remove(seq int, beginIndex, endIndex int32) error {
 	return w.inner.Remove(seq, beginIndex, endIndex)
 }
 
+func (w *TurboQuantWrapper) PrepareRestore(seq int, targetPos int32) (int32, bool) {
+	if cc, ok := w.inner.(CheckpointCache); ok {
+		return cc.PrepareRestore(seq, targetPos)
+	}
+
+	// Preserve non-checkpoint cache behavior used by ollamarunner:
+	// keep targetPos when the cache can resume, otherwise signal reprocess.
+	if w.inner.CanResume(seq, targetPos) {
+		return targetPos, true
+	}
+
+	return 0, false
+}
+
 func (w *TurboQuantWrapper) Put(ctx ml.Context, key, value ml.Tensor) {
 	headDim := key.Dim(0)
 	rot := w.getOrCreateRotation(headDim)
@@ -139,6 +159,11 @@ func (w *TurboQuantWrapper) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor
 	rot := w.getOrCreateRotation(headDim)
 
 	if rot != nil {
+		// Metal does not provide all f32 x q4_0 matmul variants needed for
+		// rotation. Cast cache output to f32 before applying inverse rotation.
+		if key.DType() != ml.DTypeF32 {
+			key = key.Cast(ctx, ml.DTypeF32)
+		}
 		key = rot.inverse.Mulmat(ctx, key)
 	}
 
@@ -165,7 +190,7 @@ func (w *TurboQuantWrapper) getOrCreateRotation(headDim int) *rotationPair {
 	piData := turboquant.GenerateRotation(headDim, seed)
 	piTData := turboquant.GenerateRotationTranspose(headDim, seed)
 
-	rotCtx := w.backend.NewContextSize(2)
+	rotCtx := w.backend.NewContextSize(2).Layer(w.layer)
 	piTensor := rotCtx.FromFloats(piData, headDim, headDim)
 	piTTensor := rotCtx.FromFloats(piTData, headDim, headDim)
 
diff --git a/kvcache/turboquant_test.go b/kvcache/turboquant_test.go
new file mode 100644
index 000000000..1f0e619c5
--- /dev/null
+++ b/kvcache/turboquant_test.go
@@ -0,0 +1,117 @@
+package kvcache
+
+import (
+	"testing"
+
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
+)
+
+type turboQuantTestBaseCache struct {
+	canResume        bool
+	canResumeCalls   int
+	canResumeLastSeq int
+	canResumeLastPos int32
+}
+
+func (c *turboQuantTestBaseCache) SetLayer(layer int) {}
+
+func (c *turboQuantTestBaseCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
+	return nil, nil, nil
+}
+
+func (c *turboQuantTestBaseCache) Put(ctx ml.Context, key, value ml.Tensor) {}
+
+func (c *turboQuantTestBaseCache) SetConfig(config ml.CacheConfig) {}
+
+func (c *turboQuantTestBaseCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
+}
+
+func (c *turboQuantTestBaseCache) Close() {}
+
+func (c *turboQuantTestBaseCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
+	return nil
+}
+
+func (c *turboQuantTestBaseCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
+
+func (c *turboQuantTestBaseCache) CanResume(seq int, pos int32) bool {
+	c.canResumeCalls++
+	c.canResumeLastSeq = seq
+	c.canResumeLastPos = pos
+	return c.canResume
+}
+
+func (c *turboQuantTestBaseCache) Remove(seq int, beginIndex, endIndex int32) error {
+	return nil
+}
+
+type turboQuantTestCheckpointCache struct {
+	turboQuantTestBaseCache
+	restorePos     int32
+	restoreOK      bool
+	prepareCalls   int
+	prepareLastSeq int
+	prepareLastPos int32
+}
+
+func (c *turboQuantTestCheckpointCache) PrepareRestore(seq int, targetPos int32) (int32, bool) {
+	c.prepareCalls++
+	c.prepareLastSeq = seq
+	c.prepareLastPos = targetPos
+	return c.restorePos, c.restoreOK
+}
+
+func TestTurboQuantPrepareRestorePassthrough(t *testing.T) {
+	inner := &turboQuantTestCheckpointCache{
+		restorePos: 7,
+		restoreOK:  true,
+	}
+
+	w := NewTurboQuantWrapper(inner, ml.DTypeTQ4)
+
+	gotPos, gotOK := w.PrepareRestore(2, 11)
+	if !gotOK || gotPos != 7 {
+		t.Fatalf("PrepareRestore() = (%d, %v), want (7, true)", gotPos, gotOK)
+	}
+
+	if inner.prepareCalls != 1 {
+		t.Fatalf("inner PrepareRestore calls = %d, want 1", inner.prepareCalls)
+	}
+
+	if inner.prepareLastSeq != 2 || inner.prepareLastPos != 11 {
+		t.Fatalf("inner PrepareRestore args = (%d, %d), want (2, 11)", inner.prepareLastSeq, inner.prepareLastPos)
+	}
+}
+
+func TestTurboQuantPrepareRestoreFallsBackToCanResume(t *testing.T) {
+	inner := &turboQuantTestBaseCache{canResume: true}
+	w := NewTurboQuantWrapper(inner, ml.DTypeTQ3)
+
+	gotPos, gotOK := w.PrepareRestore(3, 9)
+	if !gotOK || gotPos != 9 {
+		t.Fatalf("PrepareRestore() = (%d, %v), want (9, true)", gotPos, gotOK)
+	}
+
+	if inner.canResumeCalls != 1 {
+		t.Fatalf("inner CanResume calls = %d, want 1", inner.canResumeCalls)
+	}
+
+	if inner.canResumeLastSeq != 3 || inner.canResumeLastPos != 9 {
+		t.Fatalf("inner CanResume args = (%d, %d), want (3, 9)", inner.canResumeLastSeq, inner.canResumeLastPos)
+	}
+}
+
+func TestTurboQuantPrepareRestoreFallbackFailure(t *testing.T) {
+	inner := &turboQuantTestBaseCache{canResume: false}
+	w := NewTurboQuantWrapper(inner, ml.DTypeTQ4)
+
+	gotPos, gotOK := w.PrepareRestore(4, 13)
+	if gotOK || gotPos != 0 {
+		t.Fatalf("PrepareRestore() = (%d, %v), want (0, false)", gotPos, gotOK)
+	}
+
+	if inner.canResumeCalls != 1 {
+		t.Fatalf("inner CanResume calls = %d, want 1", inner.canResumeCalls)
+	}
+}
diff --git a/model/model.go b/model/model.go
index 42fe7f25c..95fda3206 100644
--- a/model/model.go
+++ b/model/model.go
@@ -98,6 +98,11 @@ func (m *Base) Config() config {
 	return m.config
 }
 
+// SetCache replaces the model cache implementation at runtime.
+func (m *Base) SetCache(cache kvcache.Cache) {
+	m.config.Cache = cache
+}
+
 var models = make(map[string]func(fs.Config) (Model, error))
 
 // Register registers a model constructor for the given architecture
diff --git a/model/model_test.go b/model/model_test.go
index 03b9460d0..419e6e6e3 100644
--- a/model/model_test.go
+++ b/model/model_test.go
@@ -12,6 +12,7 @@ import (
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/backend/ggml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/model/input"
 )
 
 func TestParseTags(t *testing.T) {
@@ -56,6 +57,21 @@ type fakeTensor struct {
 	Name string
 }
 
+type fakeCache struct{}
+
+func (f *fakeCache) Init(ml.Backend, ml.DType, int, int, int) {}
+func (f *fakeCache) SetConfig(ml.CacheConfig)                 {}
+func (f *fakeCache) SetLayer(int)                             {}
+func (f *fakeCache) Get(ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
+	return nil, nil, nil
+}
+func (f *fakeCache) Put(ml.Context, ml.Tensor, ml.Tensor)             {}
+func (f *fakeCache) CopyPrefix(int, int, int32)                       {}
+func (f *fakeCache) Remove(int, int32, int32) error                   { return nil }
+func (f *fakeCache) CanResume(int, int32) bool                        { return true }
+func (f *fakeCache) StartForward(ml.Context, input.Batch, bool) error { return nil }
+func (f *fakeCache) Close()                                           {}
+
 // Stub methods to satisfy ml.Tensor interface
 func (f *fakeTensor) Exp(ctx ml.Context) ml.Tensor { return f }
 func (f *fakeTensor) Neg(ctx ml.Context) ml.Tensor { return f }
@@ -269,3 +285,20 @@ func TestModelForArch(t *testing.T) {
 		})
 	}
 }
+
+func TestBaseSetCache(t *testing.T) {
+	initial := &fakeCache{}
+	replacement := &fakeCache{}
+
+	base := &Base{
+		config: config{
+			Cache: initial,
+		},
+	}
+
+	base.SetCache(replacement)
+
+	if base.Config().Cache != replacement {
+		t.Fatal("expected Config().Cache to be replaced")
+	}
+}
diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go
index 94ad92b5a..5f309348a 100644
--- a/runner/ollamarunner/cache.go
+++ b/runner/ollamarunner/cache.go
@@ -31,6 +31,10 @@ type InputCache struct {
 	cache kvcache.Cache
 }
 
+type cacheSetter interface {
+	SetCache(kvcache.Cache)
+}
+
 func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
 	numCtx := kvSize / int32(numSlots)
 
@@ -47,9 +51,21 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
 	cache := model.Config().Cache
 	if cache != nil {
 		dtype := kvCacheTypeFromStr(kvCacheType)
+		wrapped := false
 
 		if dtype == ml.DTypeTQ3 || dtype == ml.DTypeTQ4 {
 			cache = kvcache.NewTurboQuantWrapper(cache, dtype)
+			wrapped = true
+		}
+
+		// Model.Forward reads from the model's cache field, so replace it when we
+		// wrap the cache for TurboQuant.
+		if wrapped {
+			setter, ok := model.(cacheSetter)
+			if !ok {
+				return nil, errors.New("model does not support cache replacement for turboquant")
+			}
+			setter.SetCache(cache)
 		}
 
 		cache.Init(model.Backend(), dtype, numSlots, int(numCtx), batchSize)