[GH-ISSUE #2547] Dynamically determine context window at runtime #1491

Open
opened 2026-04-12 11:24:00 -05:00 by GiteaMirror · 4 comments
Owner

Originally created by @jmorganca on GitHub (Feb 16, 2024).
Original GitHub issue: https://github.com/ollama/ollama/issues/2547

GiteaMirror added the feature request label 2026-04-12 11:24:00 -05:00

@MarkWard0110 commented on GitHub (Feb 9, 2026):

@jmorganca, I have been prototyping a dynamic context (dynamic KV) in Ollama. https://github.com/ollama/ollama/compare/main...MarkWard0110:fork.ollama:dynamic-kv-cache

The prototype is based on Ollama's new engine. I have not looked into the llama.cpp engine.

The dynamic KV cache works like this: a runner starts with a minimum context size. As it generates tokens, it dynamically increases the KV cache size, in blocks, up to the model's maximum context size. If an allocation runs out of memory, Ollama pauses the request and attempts to create a new runner with a slightly larger context size so the runner can offload the model to additional GPU or system RAM. It resumes generating tokens once the new runner is online.

Todo: I have not determined how the various `num_ctx` configurations should affect the dynamic KV. That is the next thing I will look into.


@MarkWard0110 commented on GitHub (Feb 9, 2026):

https://github.com/ollama/ollama/issues/1005


@MarkWard0110 commented on GitHub (Feb 9, 2026):

Dynamic allocation lets the model load with all layers on GPU using a small initial KV allocation, then grow the cache incrementally as the actual workload demands. If growth eventually exhausts VRAM, the system gracefully handles the OOM and attempts a best-effort resume with an adjusted layout.

## Configuration

Three new environment variables control the feature (all opt-in):

| Variable | Default | Description |
|----------|---------|-------------|
| `OLLAMA_KV_CACHE_DYNAMIC` | `false` | Enable dynamic KV cache allocation |
| `OLLAMA_KV_CACHE_INIT` | `2048` | Initial context length (tokens per sequence) to allocate |
| `OLLAMA_KV_CACHE_GROW` | `1024` | Growth block size (tokens per sequence) when more cache is needed |

Dynamic allocation is only supported for the unbounded causal cache. Sliding window and chunked attention models fall back to eager (full) allocation automatically.

## Changes

### Core: Dynamic KV Cache (`kvcache/`)

- **`causal.go`**: `Init()` allocates `OLLAMA_KV_CACHE_INIT × maxSequences` cells (rounded/padded) instead of the full `maxCells` when dynamic mode is enabled. `StartForward()` has a grow loop: when `findLocs()` returns `ErrKvCacheFull` and there is room to grow, it calls `grow()` before retrying.
- **`grow()`**: Atomically copies existing KV tensors into a larger allocation. All layer copies are batched into a single `Compute()` call to avoid tripping ggml-cuda's consecutive-update heuristic (which would permanently disable CUDA graph replay). Allocation failures (`ml.ErrNoMem`) are caught by a deferred panic recovery and wrapped in `ErrKvCacheGrow`.
- **`errors.go`** *(new)*: `ErrKvCacheGrow` structured error type with `FromCells`, `ToCells`, `MaxCells` fields and `Unwrap()` support for transparent error chaining.
- **`cache.go`**: `StatsProvider` interface and `Stats` struct for reporting `AllocatedCells`, `InitialCells`, `MaxCells`, `MaxSequences`.
- **`wrapper.go`**: `WrapperCache` implements `StatsProvider` by aggregating stats from inner caches.

### Runner: Error Handling (`runner/ollamarunner/`)

- **`runner.go`**: `failBatch()` replaces the previous `panic(err)` in the `run()` loop. OOM errors are routed to sequences as user-friendly messages and `DoneReasonError`, propagated through `CompletionResponse.Error` to the HTTP API.
- **`computeBatch()`**: Deferred panic recovery catches `ml.ErrNoMem` panics in the async compute goroutine and routes them through `failBatch()` instead of crashing the process.
- **`/stats` endpoint** *(new)*: Reports live context usage (`ContextMax`, `ContextUsed`, `ContextActive`, `ContextAllocated`, `ContextInitial`, `Slots`, `SlotsInUse`) using atomic reads for lock-free access from hot paths.
- **`cache.go`**: `InputCacheSlot` gains atomic mirrors (`inUseAtomic`, `inputsLenAtomic`) for safe concurrent reads from the `/stats` handler without holding the main mutex.

### Server: Best-Effort Resume (`server/routes.go`)

- Both `GenerateHandler` and `ChatHandler` wrap completion in a retry loop (max 1 retry).
- On a retryable error (OOM, runner crash), the handler:
  1. Computes a KV grow target via `kvGrowTargetNumCtx()`, a prompt-length-aware heuristic: `min(max(toCells×2, promptCharLen/3), maxCells) / slots`
  2. Releases the current runner
  3. Re-acquires a runner with `OLLAMA_KV_CACHE_INIT` set to the target, `SchedSpread=true`, and `ForceEvictIdle=true`
  4. Replays the original prompt plus any already-emitted content
- `isRetryableCompletionError()` conservatively identifies OOM/crash errors while never retrying user cancellations.
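Step 1's heuristic is plain arithmetic; a sketch (illustrative only; the real `kvGrowTargetNumCtx()` also handles rounding and edge cases):

```go
package main

import "fmt"

// kvGrowTarget mirrors min(max(toCells*2, promptCharLen/3), maxCells) / slots:
// double the failed allocation, but never less than a crude chars-to-tokens
// estimate of the prompt, capped at the model maximum and split across slots.
func kvGrowTarget(toCells, promptCharLen, maxCells, slots int) int {
	target := toCells * 2
	if est := promptCharLen / 3; est > target {
		target = est
	}
	if target > maxCells {
		target = maxCells
	}
	return target / slots
}

func main() {
	// A 30000-char prompt dominates a small failed allocation of 3072 cells.
	fmt.Println(kvGrowTarget(3072, 30000, 131072, 1)) // 10000
}
```

The `promptCharLen/3` term is a rough characters-per-token estimate, so a resume after OOM jumps straight to a size that can plausibly hold the whole prompt instead of growing block by block again.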

### Scheduler (`server/sched.go`)

- `ForceEvictIdle` option eagerly evicts idle runners before the fit-check to maximize available VRAM for OOM retries.
- `needsReload()` properly normalizes `SchedSpread`, `RunnerEnv`, and `ForceEvictIdle` so a resume-loaded runner satisfies subsequent baseline requests without oscillating reloads.

### Observability: `ollama ps --verbose`

- New `--verbose` / `-v` flag on `ollama ps` adds columns: **ACTIVE**, **VRAM**, **RAM**.
- **CONTEXT** column shows `allocated/max` when dynamic allocation is active (e.g. `4096/131072`).
- VRAM is live-scaled using the ratio `ContextAllocated / ContextInitial` applied to the load-time KV cache VRAM, reflecting actual usage rather than the load-time snapshot.
- `SizeVRAM` tracks the scaled value to prevent phantom CPU offload percentages in the PROCESSOR column.
- GPU total memory is cached on the `runnerRef` so display doesn't flicker when the runner is busy.
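The live VRAM figure is simple proportional scaling; a sketch (function name hypothetical):

```go
package main

import "fmt"

// scaledKvVRAM estimates current KV-cache VRAM by scaling the
// load-time measurement by how far the cache has grown since load.
func scaledKvVRAM(loadTimeKvBytes int64, allocated, initial int) int64 {
	if initial <= 0 {
		return loadTimeKvBytes // guard: never divide by zero
	}
	return loadTimeKvBytes * int64(allocated) / int64(initial)
}

func main() {
	// KV cache measured at 512 MiB for 2048 cells at load; now at 4096 cells.
	fmt.Println(scaledKvVRAM(512<<20, 4096, 2048) >> 20) // 1024 (MiB)
}
```

This works because KV cache size is linear in the number of allocated cells for a fixed model and batch configuration.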

### API Types (`api/types.go`)

- `ProcessModelResponse` extended with: `ContextUsed`, `ContextActive`, `ContextAllocated` (optional ints), `VRAMUsed`, `VRAMFree`, `VRAMTotal`, `RAMUsed` (int64s).
- `Runner` extended with: `SchedSpread`, `RunnerEnv`, `ForceEvictIdle` (JSON-hidden internal fields).
- `CompletionResponse` extended with: `Error` field (string, omitempty).
- `DoneReasonError` added to `DoneReason` enum.

### LLM Server (`llm/server.go`)

- `LlamaServer` interface gains `VRAMCacheSize()`, `CPUCacheSize()`, `InputWeightsSize()` for decomposing memory into static (weights + graph) and dynamic (KV cache) portions.
- `assignLayers()` accepts a `spread` parameter for per-request GPU spreading.

### Runner Stats (`ml/runner_stats.go`) *(new)*

- `RunnerStats` struct and `GetRunnerStatsFromRunner()` HTTP client for querying the runner's `/stats` endpoint with timeout handling.

## Tests

| Test | Package | Covers |
|------|---------|--------|
| `TestDynamicGrowth` | `kvcache` | End-to-end grow from init=8 to 40+ cells, data integrity, causal mask correctness, both PermutedV variants |
| `TestDynamicGrowthOOMIsWrapped` | `kvcache` | Synthetic OOM backend verifies `ErrKvCacheGrow` wraps `ml.ErrNoMem` |
| `TestFailBatchPrefersKvGrowCellsMessage` | `runner/ollamarunner` | `failBatch()` produces a human-readable cell-count message |
| `TestGenerateBestEffortResumeBumpsNumCtxNoDup` | `server` | Full HTTP integration: generate → fake OOM → resume with bumped num_ctx, no duplicate content |
| `TestGenerateBestEffortResumeWithThinkingEnabled` | `server` | Resume works with thinking-tag parsing active |
| `TestGenerateBestEffortResumeKeepsMaxNumCtxWhenTargetLower` | `server` | Resume doesn't reduce num_ctx below the existing configuration |
| `TestChatBestEffortResumeBumpsNumCtxNoDupWithThinking` | `server` | Chat endpoint resume with thinking tags |

## Known Limitations

1. **No shrink/compaction**: Once the KV cache grows, it does not shrink back when context usage decreases. VRAM stays claimed until the model is unloaded.
2. **Sliding window / chunked attention excluded**: Dynamic allocation silently falls back to full allocation for these architectures. This is intentional (their tighter invariants make dynamic resizing unsafe) but could benefit from a user-facing log message.
3. **Single retry**: The best-effort resume attempts exactly one retry. Progressive backoff (smaller context, more CPU offload) is not implemented.
4. **C-level SIGSEGV on extreme OOM**: If `ggml_backend_sched_graph_compute_async` fails its internal buffer allocation, it can dereference a null pointer at the C level, which Go's `recover()` cannot catch. The `computeBatch` panic recovery handles Go-level `ErrNoMem` panics from `Reserve()` but cannot intercept C signal crashes. A proper fix would require a null-safety check upstream in ggml.
5. **Error propagation uses string matching**: `isRetryableCompletionError()` and `kvGrowTargetNumCtx()` parse error message substrings. Structured error propagation through the runner→server boundary would be more robust.

## How to Test

```bash
# Enable the new engine, since the prototype is based on it
export OLLAMA_NEW_ENGINE=1

# Enable dynamic KV cache
export OLLAMA_KV_CACHE_DYNAMIC=1
export OLLAMA_KV_CACHE_INIT=2048
export OLLAMA_KV_CACHE_GROW=1024

# Start the server
./ollama serve

# In another terminal, watch live stats
watch ./ollama ps --verbose

# Run a model and observe CONTEXT growing from 2048/131072
./ollama run llama3.3:70b-instruct-q4_K_M --verbose

# Send a large prompt to trigger growth.
# The CONTEXT, VRAM, and RAM columns update in real time.
```

## Files Changed

| File | Type | Description |
|------|------|-------------|
| `kvcache/causal.go` | Modified | Dynamic init, grow loop, `grow()`, `Stats()` |
| `kvcache/cache.go` | Modified | `StatsProvider` interface, `Stats` struct |
| `kvcache/errors.go` | New | `ErrKvCacheGrow` error type |
| `kvcache/wrapper.go` | Modified | `Stats()` aggregation |
| `kvcache/causal_test.go` | Modified | `TestDynamicGrowth`, `TestDynamicGrowthOOMIsWrapped` |
| `runner/ollamarunner/runner.go` | Modified | `failBatch()`, `computeBatch` recovery, `/stats`, error plumbing |
| `runner/ollamarunner/runner_failbatch_test.go` | New | `TestFailBatchPrefersKvGrowCellsMessage` |
| `runner/ollamarunner/cache.go` | Modified | Atomic mirrors for stats |
| `server/routes.go` | Modified | Resume loop (generate+chat), `kvGrowTargetNumCtx()`, VRAM scaling in `PsHandler` |
| `server/routes_generate_test.go` | New | Resume integration tests |
| `server/sched.go` | Modified | `ForceEvictIdle`, `needsReload()` fixes |
| `llm/server.go` | Modified | `VRAMCacheSize()`, `CPUCacheSize()`, `InputWeightsSize()`, `DoneReasonError` |
| `ml/runner_stats.go` | New | `RunnerStats`, `GetRunnerStatsFromRunner()` |
| `api/types.go` | Modified | New fields on `ProcessModelResponse`, `Runner`, `CompletionResponse` |
| `envconfig/config.go` | Modified | `KvCacheDynamic`, `KvCacheInit`, `KvCacheGrow` |
| `cmd/cmd.go` | Modified | `ollama ps --verbose` |

@MarkWard0110 commented on GitHub (Feb 9, 2026):

I have not tested this with Vulkan, Apple, or AMD.

Reference: github-starred/ollama#1491