mirror of
https://github.com/ollama/ollama.git
synced 2026-05-07 00:22:43 -05:00
kvcache: wire turboquant cache into new engine path
This commit is contained in:
@@ -60,12 +60,15 @@ type TurboQuantWrapper struct {
|
||||
inner Cache
|
||||
bitWidth int
|
||||
backend ml.Backend
|
||||
layer int
|
||||
|
||||
mu sync.Mutex
|
||||
rotations map[int]*rotationPair
|
||||
rotCtxs map[int]ml.Context
|
||||
}
|
||||
|
||||
// Compile-time check that *TurboQuantWrapper satisfies CheckpointCache.
var _ CheckpointCache = (*TurboQuantWrapper)(nil)
|
||||
|
||||
type rotationPair struct {
|
||||
// forward stores Pi^T in GGML format.
|
||||
// GGML Mulmat(A, B) computes A^T @ B.
|
||||
@@ -108,7 +111,10 @@ func (w *TurboQuantWrapper) Close() {
|
||||
w.inner.Close()
|
||||
}
|
||||
|
||||
func (w *TurboQuantWrapper) SetLayer(layer int) { w.inner.SetLayer(layer) }
|
||||
func (w *TurboQuantWrapper) SetLayer(layer int) {
|
||||
w.layer = layer
|
||||
w.inner.SetLayer(layer)
|
||||
}
|
||||
// SetConfig hands the cache configuration to the wrapped cache unchanged.
func (w *TurboQuantWrapper) SetConfig(config ml.CacheConfig) {
	w.inner.SetConfig(config)
}
|
||||
// CopyPrefix delegates the prefix copy from src to dst to the wrapped cache.
func (w *TurboQuantWrapper) CopyPrefix(src, dst int, l int32) {
	w.inner.CopyPrefix(src, dst, l)
}
|
||||
// CanResume returns whatever the wrapped cache reports for seq at pos.
func (w *TurboQuantWrapper) CanResume(seq int, pos int32) bool {
	resumable := w.inner.CanResume(seq, pos)
	return resumable
}
|
||||
@@ -121,6 +127,20 @@ func (w *TurboQuantWrapper) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
return w.inner.Remove(seq, beginIndex, endIndex)
|
||||
}
|
||||
|
||||
func (w *TurboQuantWrapper) PrepareRestore(seq int, targetPos int32) (int32, bool) {
|
||||
if cc, ok := w.inner.(CheckpointCache); ok {
|
||||
return cc.PrepareRestore(seq, targetPos)
|
||||
}
|
||||
|
||||
// Preserve non-checkpoint cache behavior used by ollamarunner:
|
||||
// keep targetPos when the cache can resume, otherwise signal reprocess.
|
||||
if w.inner.CanResume(seq, targetPos) {
|
||||
return targetPos, true
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (w *TurboQuantWrapper) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
headDim := key.Dim(0)
|
||||
rot := w.getOrCreateRotation(headDim)
|
||||
@@ -139,6 +159,11 @@ func (w *TurboQuantWrapper) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor
|
||||
rot := w.getOrCreateRotation(headDim)
|
||||
|
||||
if rot != nil {
|
||||
// Metal does not provide all f32 x q4_0 matmul variants needed for
|
||||
// rotation. Cast cache output to f32 before applying inverse rotation.
|
||||
if key.DType() != ml.DTypeF32 {
|
||||
key = key.Cast(ctx, ml.DTypeF32)
|
||||
}
|
||||
key = rot.inverse.Mulmat(ctx, key)
|
||||
}
|
||||
|
||||
@@ -165,7 +190,7 @@ func (w *TurboQuantWrapper) getOrCreateRotation(headDim int) *rotationPair {
|
||||
piData := turboquant.GenerateRotation(headDim, seed)
|
||||
piTData := turboquant.GenerateRotationTranspose(headDim, seed)
|
||||
|
||||
rotCtx := w.backend.NewContextSize(2)
|
||||
rotCtx := w.backend.NewContextSize(2).Layer(w.layer)
|
||||
piTensor := rotCtx.FromFloats(piData, headDim, headDim)
|
||||
piTTensor := rotCtx.FromFloats(piTData, headDim, headDim)
|
||||
|
||||
|
||||
118
kvcache/turboquant_test.go
Normal file
118
kvcache/turboquant_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
// turboQuantTestBaseCache is a minimal cache test double. Its CanResume
// method returns a configurable answer and records how it was called; every
// other method is a no-op.
type turboQuantTestBaseCache struct {
	// canResume is the value CanResume returns.
	canResume bool

	// canResumeCalls counts CanResume invocations.
	canResumeCalls int

	// canResumeLastSeq and canResumeLastPos hold the arguments of the most
	// recent CanResume call.
	canResumeLastSeq int
	canResumeLastPos int32
}
|
||||
|
||||
// SetLayer is a no-op; the test double does not track layers.
func (c *turboQuantTestBaseCache) SetLayer(layer int) {
}
|
||||
|
||||
func (c *turboQuantTestBaseCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
// Put is a no-op; key and value are discarded.
func (c *turboQuantTestBaseCache) Put(ctx ml.Context, key, value ml.Tensor) {
}
|
||||
|
||||
// SetConfig is a no-op; the configuration is ignored.
func (c *turboQuantTestBaseCache) SetConfig(config ml.CacheConfig) {
}
|
||||
|
||||
func (c *turboQuantTestBaseCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
}
|
||||
|
||||
// Close is a no-op; there is nothing to release.
func (c *turboQuantTestBaseCache) Close() {
}
|
||||
|
||||
func (c *turboQuantTestBaseCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CopyPrefix is a no-op; the test double holds no sequence data to copy.
func (c *turboQuantTestBaseCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
}
|
||||
|
||||
func (c *turboQuantTestBaseCache) CanResume(seq int, pos int32) bool {
|
||||
c.canResumeCalls++
|
||||
c.canResumeLastSeq = seq
|
||||
c.canResumeLastPos = pos
|
||||
return c.canResume
|
||||
}
|
||||
|
||||
func (c *turboQuantTestBaseCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type turboQuantTestCheckpointCache struct {
|
||||
turboQuantTestBaseCache
|
||||
restorePos int32
|
||||
restoreOK bool
|
||||
prepareCalls int
|
||||
prepareLastSeq int
|
||||
prepareLastPos int32
|
||||
}
|
||||
|
||||
func (c *turboQuantTestCheckpointCache) PrepareRestore(seq int, targetPos int32) (int32, bool) {
|
||||
c.prepareCalls++
|
||||
c.prepareLastSeq = seq
|
||||
c.prepareLastPos = targetPos
|
||||
return c.restorePos, c.restoreOK
|
||||
}
|
||||
|
||||
func TestTurboQuantPrepareRestorePassthrough(t *testing.T) {
|
||||
inner := &turboQuantTestCheckpointCache{
|
||||
restorePos: 7,
|
||||
restoreOK: true,
|
||||
}
|
||||
|
||||
w := NewTurboQuantWrapper(inner, ml.DTypeTQ4)
|
||||
|
||||
gotPos, gotOK := w.PrepareRestore(2, 11)
|
||||
if !gotOK || gotPos != 7 {
|
||||
t.Fatalf("PrepareRestore() = (%d, %v), want (7, true)", gotPos, gotOK)
|
||||
}
|
||||
|
||||
if inner.prepareCalls != 1 {
|
||||
t.Fatalf("inner PrepareRestore calls = %d, want 1", inner.prepareCalls)
|
||||
}
|
||||
|
||||
if inner.prepareLastSeq != 2 || inner.prepareLastPos != 11 {
|
||||
t.Fatalf("inner PrepareRestore args = (%d, %d), want (2, 11)", inner.prepareLastSeq, inner.prepareLastPos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTurboQuantPrepareRestoreFallsBackToCanResume(t *testing.T) {
|
||||
inner := &turboQuantTestBaseCache{canResume: true}
|
||||
w := NewTurboQuantWrapper(inner, ml.DTypeTQ3)
|
||||
|
||||
gotPos, gotOK := w.PrepareRestore(3, 9)
|
||||
if !gotOK || gotPos != 9 {
|
||||
t.Fatalf("PrepareRestore() = (%d, %v), want (9, true)", gotPos, gotOK)
|
||||
}
|
||||
|
||||
if inner.canResumeCalls != 1 {
|
||||
t.Fatalf("inner CanResume calls = %d, want 1", inner.canResumeCalls)
|
||||
}
|
||||
|
||||
if inner.canResumeLastSeq != 3 || inner.canResumeLastPos != 9 {
|
||||
t.Fatalf("inner CanResume args = (%d, %d), want (3, 9)", inner.canResumeLastSeq, inner.canResumeLastPos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTurboQuantPrepareRestoreFallbackFailure(t *testing.T) {
|
||||
inner := &turboQuantTestBaseCache{canResume: false}
|
||||
w := NewTurboQuantWrapper(inner, ml.DTypeTQ4)
|
||||
|
||||
gotPos, gotOK := w.PrepareRestore(4, 13)
|
||||
if gotOK || gotPos != 0 {
|
||||
t.Fatalf("PrepareRestore() = (%d, %v), want (0, false)", gotPos, gotOK)
|
||||
}
|
||||
|
||||
if inner.canResumeCalls != 1 {
|
||||
t.Fatalf("inner CanResume calls = %d, want 1", inner.canResumeCalls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,6 +98,11 @@ func (m *Base) Config() config {
|
||||
return m.config
|
||||
}
|
||||
|
||||
// SetCache replaces the model cache implementation at runtime.
|
||||
func (m *Base) SetCache(cache kvcache.Cache) {
|
||||
m.config.Cache = cache
|
||||
}
|
||||
|
||||
// models holds model constructor functions keyed by a string; presumably the
// key is the architecture name passed to Register — confirm against Register.
var models = make(map[string]func(fs.Config) (Model, error))
|
||||
|
||||
// Register registers a model constructor for the given architecture
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/backend/ggml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
func TestParseTags(t *testing.T) {
|
||||
@@ -56,6 +57,21 @@ type fakeTensor struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
// fakeCache is a no-op cache test double.
//
// It deliberately has non-zero size: the Go spec allows pointers to distinct
// zero-size variables to compare equal, which would make identity checks
// between two &fakeCache{} values (as in TestBaseSetCache) unable to tell
// them apart.
type fakeCache struct {
	_ byte // padding only; gives every instance a distinct address
}
|
||||
|
||||
// Init is a no-op.
func (f *fakeCache) Init(ml.Backend, ml.DType, int, int, int) {
}
|
||||
// SetConfig is a no-op.
func (f *fakeCache) SetConfig(ml.CacheConfig) {
}
|
||||
// SetLayer is a no-op.
func (f *fakeCache) SetLayer(int) {
}
|
||||
func (f *fakeCache) Get(ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
// Put is a no-op.
func (f *fakeCache) Put(ml.Context, ml.Tensor, ml.Tensor) {
}
|
||||
// CopyPrefix is a no-op.
func (f *fakeCache) CopyPrefix(int, int, int32) {
}
|
||||
// Remove does nothing and always returns a nil error.
func (f *fakeCache) Remove(int, int32, int32) error {
	return nil
}
|
||||
// CanResume always reports true.
func (f *fakeCache) CanResume(int, int32) bool {
	return true
}
|
||||
// StartForward does nothing and always returns a nil error.
func (f *fakeCache) StartForward(ml.Context, input.Batch, bool) error {
	return nil
}
|
||||
// Close is a no-op.
func (f *fakeCache) Close() {
}
|
||||
|
||||
// Stub methods to satisfy ml.Tensor interface
|
||||
// Exp is a stub that returns the receiver unchanged.
func (f *fakeTensor) Exp(ctx ml.Context) ml.Tensor {
	return f
}
|
||||
// Neg is a stub that returns the receiver unchanged.
func (f *fakeTensor) Neg(ctx ml.Context) ml.Tensor {
	return f
}
|
||||
@@ -269,3 +285,20 @@ func TestModelForArch(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseSetCache(t *testing.T) {
|
||||
initial := &fakeCache{}
|
||||
replacement := &fakeCache{}
|
||||
|
||||
base := &Base{
|
||||
config: config{
|
||||
Cache: initial,
|
||||
},
|
||||
}
|
||||
|
||||
base.SetCache(replacement)
|
||||
|
||||
if base.Config().Cache != replacement {
|
||||
t.Fatal("expected Config().Cache to be replaced")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,10 @@ type InputCache struct {
|
||||
cache kvcache.Cache
|
||||
}
|
||||
|
||||
type cacheSetter interface {
|
||||
SetCache(kvcache.Cache)
|
||||
}
|
||||
|
||||
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
|
||||
numCtx := kvSize / int32(numSlots)
|
||||
|
||||
@@ -47,9 +51,21 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
|
||||
cache := model.Config().Cache
|
||||
if cache != nil {
|
||||
dtype := kvCacheTypeFromStr(kvCacheType)
|
||||
wrapped := false
|
||||
|
||||
if dtype == ml.DTypeTQ3 || dtype == ml.DTypeTQ4 {
|
||||
cache = kvcache.NewTurboQuantWrapper(cache, dtype)
|
||||
wrapped = true
|
||||
}
|
||||
|
||||
// Model.Forward reads from the model's cache field, so replace it when we
|
||||
// wrap the cache for TurboQuant.
|
||||
if wrapped {
|
||||
setter, ok := model.(cacheSetter)
|
||||
if !ok {
|
||||
return nil, errors.New("model does not support cache replacement for turboquant")
|
||||
}
|
||||
setter.SetCache(cache)
|
||||
}
|
||||
|
||||
cache.Init(model.Backend(), dtype, numSlots, int(numCtx), batchSize)
|
||||
|
||||
Reference in New Issue
Block a user