[PR #15505] kvcache: add TurboQuant compressed KV cache (tq2/tq3/tq2k/tq3k) #41052

Open
opened 2026-04-23 01:47:45 -05:00 by GiteaMirror · 0 comments

📋 Pull Request Information

Original PR: https://github.com/ollama/ollama/pull/15505
Author: @mverrilli
Created: 4/11/2026
Status: 🔄 Open

Base: main ← Head: turboquant


📝 Commits (9)

  • 8325ac3 turboquant: add block encoder and Lloyd-Max codebook primitives
  • cf0e2e5 ml/backend/ggml: add CUDA kernels and GGML ops for TurboQuant KV compression
  • b454db9 ml: add Go bindings for the TurboQuant compressed KV manager
  • a877fd2 ml/backend/ggml: gate TurboQuant CUDA kernels on wave32 AMD GPUs
  • f563d8e kvcache: add TurboQuant compressed KV cache with tq2/tq3/tq2k/tq3k presets
  • 76e5fc2 ml/backend/ggml: port the TurboQuant CUDA kernels to Metal
  • 0c1f7f1 ml/backend/ggml: add D=256 TurboQuant flash-attention kernels on Metal
  • fda849a kvcache: route TurboQuant prefill through DequantKV on Metal
  • c56d85e ml/backend/ggml: optimize the Metal TurboQuant dequant kernel

📊 Changes

66 files changed (+17879 additions, -73 deletions)


📝 CMakeLists.txt (+1 -0)
📝 docs/faq.mdx (+10 -0)
📝 fs/ggml/ggml.go (+27 -1)
📝 kvcache/cache.go (+9 -0)
📝 kvcache/causal.go (+136 -53)
📝 kvcache/recurrent.go (+14 -1)
➕ kvcache/turboquant.go (+619 -0)
➕ kvcache/turboquant_test.go (+144 -0)
➕ llama/patches/0037-turboquant-add-CUDA-kernels-and-GGML-ops-for-TurboQu.patch (+2473 -0)
➕ llama/patches/0038-turboquant-fix-HIP-compile-drop-cuda_fp16.h-pass-__s.patch (+98 -0)
➕ llama/patches/0039-ml-backend-ggml-port-the-TurboQuant-CUDA-kernels-to-.patch (+1943 -0)
➕ llama/patches/0040-ml-backend-ggml-add-D-256-TurboQuant-flash-attention.patch (+667 -0)
➕ llama/patches/0041-ml-backend-ggml-optimize-the-Metal-TurboQuant-dequan.patch (+135 -0)
📝 llm/server.go (+11 -7)
📝 ml/backend.go (+74 -1)
📝 ml/backend/ggml/ggml.go (+383 -1)
📝 ml/backend/ggml/ggml/include/ggml.h (+155 -0)
📝 ml/backend/ggml/ggml/src/ggml-backend.cpp (+6 -1)
📝 ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c (+23 -0)
📝 ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp (+9 -0)

...and 46 more files

📄 Description

Summary

Adds a GPU-resident compressed KV cache based on the TurboQuant paper
(arXiv 2504.19874), implementing
Algorithm 1 (Householder QR rotation + Lloyd-Max scalar quantization)
as four new OLLAMA_KV_CACHE_TYPE options.

| Name | K bits | V bits | Approx. VRAM vs f16 | Notes |
|------|--------|--------|---------------------|-------|
| tq3k | 3 | f16 | ~60% | quality-preserving K-only, no FA required |
| tq3  | 3 | 3   | ~20% | balanced tier |
| tq2k | 2 | f16 | ~57% | smallest K footprint with f16 V, no FA required |
| tq2  | 2 | 2   | ~14% | most aggressive |
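
As a back-of-envelope consistency check (not part of the PR; it assumes the K and V tensors are the same size and that the quoted tq3/tq2 fractions already fold in scale and codebook overhead), the K-only figures follow directly from the K+V ones, since the K-only presets leave V at f16:

```go
package main

import "fmt"

func main() {
	// Quoted K+V VRAM fractions (vs f16) for the full presets; taken from
	// the table above and assumed to include scale/codebook overhead.
	kvFraction := map[string]float64{"tq3": 0.20, "tq2": 0.14}

	// The K-only presets leave V at f16, so only the K half shrinks.
	// Assuming the K and V tensors are equally sized:
	for name, f := range kvFraction {
		kOnly := (f + 1.0) / 2.0
		fmt.Printf("%sk ≈ %.0f%% of f16\n", name, kOnly*100)
	}
	// Prints ~60% for tq3k and ~57% for tq2k, matching the table.
}
```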

The compressed K (and optionally V) tensors live in GPU memory with
their own Lloyd-Max codebook, rotation matrix, and per-cell RMS scales.
Encode, dequant, and an optional inline-decode fused flash attention
path are implemented as new ops (GGML_OP_TQ_ENCODE, GGML_OP_TQ_DEQUANT,
GGML_OP_TQ_FLASH_ATTN_EXT, plus variants for V and combined K+V).
The CPU backend rejects the new ops so the scheduler never routes them
off-GPU.
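
For readers unfamiliar with the Lloyd-Max step, the sketch below illustrates the core idea in Go: alternate between assigning samples to the nearest codebook level and moving each level to the centroid of its cell. This is only an illustration of the primitive named in the first commit, not the actual turboquant package code; the function and variable names are invented, and the real encoder additionally applies the Householder QR rotation and keeps per-cell RMS scales.

```go
package main

import (
	"fmt"
	"math"
)

// lloydMax trains a scalar codebook with the given number of levels
// (8 levels for a 3-bit code, 4 for a 2-bit code) on a slice of samples.
func lloydMax(samples []float64, levels, iters int) []float64 {
	// Spread the initial codebook uniformly over the sample range.
	lo, hi := samples[0], samples[0]
	for _, s := range samples {
		lo, hi = math.Min(lo, s), math.Max(hi, s)
	}
	code := make([]float64, levels)
	for i := range code {
		code[i] = lo + (hi-lo)*(float64(i)+0.5)/float64(levels)
	}

	for it := 0; it < iters; it++ {
		sum := make([]float64, levels)
		cnt := make([]float64, levels)
		// Assignment step: map each sample to its nearest level.
		for _, s := range samples {
			best, bestDist := 0, math.Inf(1)
			for i, c := range code {
				if d := math.Abs(s - c); d < bestDist {
					best, bestDist = i, d
				}
			}
			sum[best] += s
			cnt[best]++
		}
		// Update step: move each level to the centroid of its cell.
		for i := range code {
			if cnt[i] > 0 {
				code[i] = sum[i] / cnt[i]
			}
		}
	}
	return code
}

func main() {
	data := []float64{-1.2, -0.9, -0.1, 0.0, 0.2, 0.8, 1.1, 1.3}
	fmt.Println(lloydMax(data, 4, 20)) // a 2-bit (4-level) codebook
}
```

With levels=8 this corresponds to the 3-bit presets and with levels=4 to the 2-bit ones.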

Platform support

NVIDIA (CUDA). Pascal (cc 6.0+) and newer. Validated on Pascal (P40)
and Blackwell (RTX 5060).

AMD (ROCm). RDNA2+ (gfx1030+), validated on RX 7600 (RDNA3, gfx1102).
Wave64 AMD (Vega / GCN / CDNA) is explicitly blocked — the TQ
__shfl_sync(width=32) codebook lookup produces garbage on 64-wide
wavefronts. ROCm HIP compile fixes (drop cuda_fp16.h, explicit
__shfl_sync width) are included in llama/patches/0038. RDNA1
(gfx1010–gfx1012) shares ccMajor=16 with RDNA2 and is admitted by
the gate but untested.

Apple Silicon (Metal). All Apple GPUs that ship with the unified
Metal shader-compiler backend — SIMD groups are natively 32-wide, so
CUDA's __shfl_sync(width=32) maps 1-to-1 to simd_shuffle /
simd_shuffle_xor. Fused inline-decode flash attention is instantiated
at both headDim=128 (llama / qwen families) and headDim=256
(gemma3). Validated on M-series silicon with llama3.2:3b, gemma3:1b
and gemma3:4b.

Usage

Set the cache type via environment variable before starting the Ollama
server:

```sh
# K-only, no FA required
OLLAMA_KV_CACHE_TYPE=tq3k ollama serve

# K+V, requires flash attention
OLLAMA_FLASH_ATTENTION=1 OLLAMA_KV_CACHE_TYPE=tq3 ollama serve
```

K+V presets (tq3, tq2) require flash attention. K-only presets
(tq3k, tq2k) work with or without it. The setting is global and
applies only to models running through Ollama's native engine.
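
The rules above amount to a small preset table. The sketch below is hypothetical (the tqPreset type and tqPresets map are invented for illustration, not the PR's types); it only restates what is already stated: K and V bit widths per preset and which presets require flash attention.

```go
package main

import "fmt"

// tqPreset describes one OLLAMA_KV_CACHE_TYPE value; vBits == 0 is used
// here to mean "V stays f16". Illustrative only.
type tqPreset struct {
	kBits, vBits int
	requiresFA   bool
}

var tqPresets = map[string]tqPreset{
	"tq3k": {kBits: 3, vBits: 0, requiresFA: false}, // K-only
	"tq3":  {kBits: 3, vBits: 3, requiresFA: true},  // K+V
	"tq2k": {kBits: 2, vBits: 0, requiresFA: false}, // K-only
	"tq2":  {kBits: 2, vBits: 2, requiresFA: true},  // K+V
}

func main() {
	p := tqPresets["tq3"]
	fmt.Printf("tq3: K=%d bits, V=%d bits, needs FA: %v\n", p.kBits, p.vBits, p.requiresFA)
}
```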

Tested models

| Model | Backend coverage | Notes |
|-------|------------------|-------|
| llama3.2:3b | CUDA, ROCm, Metal | primary benchmark model |
| gemma3:1b / gemma3:4b | CUDA, Metal | headDim=256; sliding-window layers stay at f16, global layers use TQ (WrapperCache) |
| qwen3-coder:30b | CUDA | MoE; drove the GGML_SCHED_MAX_SPLIT_INPUTS change below |
| qwen2.5:7b | CUDA | see Known limitations — Qwen 2 is a known weak spot |

Benchmarks

Non-TQ users see no behavioural or performance changes — all TQ code is
gated on OLLAMA_KV_CACHE_TYPE=tq*.

NVIDIA Pascal (Tesla P40) — llama3.2:3b @ ctx=2048

| Mode | Prefill tok/s | Decode tok/s | PPL | KV MB |
|------|---------------|--------------|-----|-------|
| f16  | 1790 | 74.6 | 1.0157 | 224 |
| tq3k | 1415 (-21%) | 65.1 (-13%) | 1.0094 (-0.62%) | 135 (-40%) |
| tq3  | 1132 (-37%) | 57.9 (-22%) | 1.0130 (-0.27%) | 46 (-79%) |
| tq2k | 1417 (-21%) | 65.6 (-12%) | 1.0189 (+0.31%) | 128 (-43%) |
| tq2  | 1140 (-36%) | 58.5 (-22%) | 1.0355 (+1.95%) | 32 (-86%) |

On Ampere and newer — tensor cores for the rotation mul_mat, higher
HBM bandwidth, lower kernel launch latency — the decode regression vs
f16 is expected to shrink, roughly halving the measured Pascal delta.

NVIDIA Pascal (Tesla P40) — qwen3-coder:30b @ ctx=32768

| Mode | Decode tok/s | PPL | KV MB |
|------|--------------|-----|-------|
| f16  | 40.5 | 1.0000 | 3072 |
| tq3k | 25.3 (-37%) | 1.0010 (within noise) | 2156 (-30%) |
| tq3  | 17.6 (-57%) | 1.0015 (within noise) | 728 (-76%) |

AMD RX 7600 (RDNA3, gfx1102) — llama3.2:3b @ ctx=4096

| Mode | Decode tok/s | vs f16 | PPL vs f16 |
|------|--------------|--------|------------|
| f16  | 87.6 | — | baseline |
| tq3k | 82.5 | -5.8% | +9.4% |
| tq3  | 79.6 | -9.1% | +11.4% |
| tq2k | 84.1 | -3.9% | — |
| tq2  | 80.3 | -8.3% | — |

Apple Silicon (Metal) — llama3.2:3b @ ctx=32768

| Mode | Prefill tok/s | Decode tok/s |
|------|---------------|--------------|
| tq3  | ~338 | ~46 |

Prefill is routed through DequantKV + stock flash attention; decode
stays on the fused inline-decode kernel. The Metal dequant kernel is
vectorised (4 elements per iteration, one half4 store per inner step)
with per-SIMDgroup dispatch.

Known limitations

  1. Qwen 2 family (Qwen 2.5, Qwen 2-VL, etc.) — learned K bias produces
    a bias-dominated asymmetric K distribution that none of the symmetric
    quantizers here handle well. PPL degrades sharply. Documented in
    docs/faq.mdx. A follow-up adding per-vector asymmetric quantization
    (following the NVIDIA TensorRT-LLM hint in
    #4218) is
    planned and will plug into the same encode/dequant infrastructure.

  2. Outlier split (paper §4.3) and Algorithm 2 QJL residual are
    implemented but disabled in the shipped presets (OutlierCount=0):
    on the validated model set they regress decode throughput and PPL
    without improving heavy-tailed quality. Infrastructure remains for
    future dynamic dispatch.

  3. Fused inline-decode flash attention is restricted to
    k_bits ∈ {2, 3}. Head-dim coverage is {128, 256} on Metal and
    {128} on CUDA/ROCm. Configurations outside those limits fall through
    to the separate DequantKV + stock FA path, which is the faster path
    for prefill in any case (see the sketch after this list).

  4. RDNA1 (gfx1010–gfx1012) — wave32, structurally compatible, but
    ROCm support for these targets is incomplete in recent SDK releases
    and is untested here. RDNA2+ (gfx1030+) is the validated minimum.
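
To make limitation 3 concrete, a hypothetical version of the fused-kernel admission check could look like the sketch below. The function name and signature are invented; only the constraints (k_bits ∈ {2, 3}, headDim 128 on CUDA/ROCm, 128 or 256 on Metal) come from the PR text. Anything it rejects takes the DequantKV + stock flash-attention path.

```go
package main

import "fmt"

// supportsFusedTQFlashAttn restates the constraints from limitation 3;
// it is not the PR's actual gate.
func supportsFusedTQFlashAttn(backend string, kBits, headDim int) bool {
	if kBits != 2 && kBits != 3 {
		return false
	}
	switch backend {
	case "Metal":
		return headDim == 128 || headDim == 256
	case "CUDA", "ROCm":
		return headDim == 128
	default:
		return false
	}
}

func main() {
	fmt.Println(supportsFusedTQFlashAttn("Metal", 3, 256)) // true  (gemma3-class head dim on Metal)
	fmt.Println(supportsFusedTQFlashAttn("CUDA", 3, 256))  // false (falls back to DequantKV + stock FA)
}
```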

Notable non-TQ change

ml/backend/ggml/ggml/src/ggml-backend.cpp bumps
GGML_SCHED_MAX_SPLIT_INPUTS from 30 to 128. This is the only
global change in the PR that is not gated on a TQ cache type.

  • Large MoE models like qwen3-coder:30b (48 layers with expert routing)
    combined with TQ K+V encode push split input counts above the old
    30-input ceiling, causing GGML_ASSERT at graph build time.
  • Upstream GGML has a FIXME in the same area noting that the check
    only fires when the split is exactly full, so multi-input ops can
    already overshoot it — the old limit was artificially conservative.
  • Cost is CPU heap only; the OS lazy-faults pages on write, so RSS
    impact is bounded by what the scheduler actually populates. Typical
    non-MoE graphs never fill more than ~30 inputs per split and never
    touch the extra range.
  • Behavioural impact on non-TQ users: none. The change only unblocks
    previously-failing configurations.

Test plan

  • go test ./turboquant/... ./kvcache/... ./ml/backend/ggml/... ./runner/ollamarunner/...
  • Smoke generation on llama3.2:3b with f16, tq3k, tq3, tq2k, tq2
  • Smoke generation on gemma3:1b / gemma3:4b with tq3k (WrapperCache path; only global sub-cache wrapped)
  • Benchmark matrix across llama3.2:3b / gemma3 / qwen2.5:7b / qwen3-coder:30b at ctx 2048 / 8192 / 32768
  • ROCm Docker build (docker build --target rocm-7) clean
  • ROCm runtime validated on AMD RX 7600 (RDNA3, gfx1102)
  • Metal runtime validated on Apple Silicon (llama3.2:3b and gemma3)
  • Non-TQ cache paths (f16, q8_0, q4_0) unchanged — verified SkipK/SkipV gates and PresetFromDType returning false for non-tq types

🔄 This issue represents a GitHub Pull Request. It cannot be merged through Gitea due to API limitations.
