[GH-ISSUE #15939] ggml_cuda_cpy: unsupported type combination (q4_K to q4_K) on Blackwell (compute 12.0) — RTX 5070 Ti Laptop GPU #72209

Open
opened 2026-05-05 03:38:06 -05:00 by GiteaMirror · 1 comment

Originally created by @nickherahomes on GitHub (May 3, 2026).
Original GitHub issue: https://github.com/ollama/ollama/issues/15939

Summary

Loading any Q4_K_M GGUF model crashes the Ollama runner on RTX 50-series (Blackwell, compute capability 12.0). The CUDA backend hits an unsupported type combination in ggml_cuda_cpy and the runner exits with a 500 before producing any tokens. Forcing the Vulkan backend (OLLAMA_LLM_LIBRARY=vulkan) is a working workaround but introduces its own model-swap stability issue (separate from this report).

Environment

  • OS: Windows 11 Home 26200
  • Ollama: 0.22.1 (latest available via winget at time of report)
  • GPU: NVIDIA GeForce RTX 5070 Ti Laptop GPU (Blackwell, SM 12.0)
  • VRAM: 12 GB
  • NVIDIA Driver: 577.13
  • CUDA Version (driver-reported): 12.9
  • Models tested: 10 separate fulcrum-*-q4km.gguf files (Gemma 3 LoRA fine-tunes quantized to Q4_K_M; ~7.25 GB each)

Reproduce

  1. Install Ollama 0.22.1 on a Blackwell GPU box.
  2. Register any Q4_K_M GGUF model:
    ollama create fulcrum-test -f Modelfile
    
    where Modelfile is:
    FROM /path/to/fulcrum-command-q4km.gguf
    PARAMETER temperature 0.7
    PARAMETER num_ctx 4096
    
  3. Try to run it:
    ollama run fulcrum-test "hello"
    

Expected

Model loads, inference proceeds.

Actual

Returns: Error: 500 Internal Server Error: model failed to load, this may be due to resource limitations or an internal error, check ollama server logs for details

Server log (relevant tail)

time=... level=INFO source=runner.go:1290 msg=load request="{Operation:fit LoraPath:[] Parallel:1 BatchSize:512 FlashAttention:Enabled KvSize:4096 KvCacheType: NumThreads:8 GPULayers:49[ID:GPU-... Layers:49(0..48)] MultiUserCache:false ProjectorPath: MainGPU:0 UseMmap:false}"
time=... level=INFO source=ggml.go:136 msg="" architecture=gemma3 file_type=Q4_K_M name="" description="" num_tensors=1065 num_key_values=38
load_backend: loaded CPU backend from ...\ggml-cpu-alderlake.dll
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 5070 Ti Laptop GPU, compute capability 12.0, VMM: yes
load_backend: loaded CUDA backend from ...\cuda_v12\ggml-cuda.dll
time=... level=INFO source=ggml.go:104 msg=system ... CUDA.0.ARCHS=500,520,600,610,700,750,800,860,890,900,1200 CUDA.0.USE_GRAPHS=1 CUDA.0.PEER_MAX_BATCH_SIZE=128 compiler=cgo(clang)
C:\a\ollama\ollama\ml\backend\ggml\ggml\src\ggml-cuda\cpy.cu:574: ggml_cuda_cpy: unsupported type combination (q4_K to q4_K)

time=... level=ERROR source=server.go:1219 msg="do load request" error="Post \"http://127.0.0.1:.../load\": read tcp ...->...: wsarecv: An existing connection was forcibly closed by the remote host."

The CUDA.0.ARCHS line shows 1200 in the compiled-arch list, so the runner does take the CUDA path on Blackwell (SM 12.0) rather than skipping it. The crash is specifically in cpy.cu:574 for the q4_K to q4_K type combination.
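
For readers who don't want to dig into the source: ggml_cuda_cpy selects a copy kernel based on the (source type, destination type) pair and aborts when no kernel is registered for that pair. Below is a simplified, standalone C++ sketch of that dispatch shape (hypothetical names, not the upstream source) illustrating how a q4_K-to-q4_K copy falls through to the abort:

  // Simplified illustration of the dispatch shape in ggml-cuda/cpy.cu.
  // Hypothetical names; NOT the upstream source. The point: copy kernels
  // are registered per (src, dst) type pair, and an unmatched pair aborts.
  #include <cstdio>
  #include <cstdlib>

  enum tensor_type { TYPE_F32, TYPE_F16, TYPE_Q4_K };

  static const char *type_name(tensor_type t) {
      switch (t) {
          case TYPE_F32:  return "f32";
          case TYPE_F16:  return "f16";
          case TYPE_Q4_K: return "q4_K";
      }
      return "?";
  }

  static void copy_f32_to_f32() { /* would launch a CUDA copy kernel */ }
  static void copy_f32_to_f16() { /* would launch a CUDA copy kernel */ }

  static void cuda_cpy(tensor_type src, tensor_type dst) {
      if (src == TYPE_F32 && dst == TYPE_F32) {
          copy_f32_to_f32();
      } else if (src == TYPE_F32 && dst == TYPE_F16) {
          copy_f32_to_f16();
      // ... more supported pairs elided ...
      } else {
          // This is the branch that fires in the log above: no kernel
          // is registered for (q4_K, q4_K), so the runner aborts.
          fprintf(stderr, "cuda_cpy: unsupported type combination (%s to %s)\n",
                  type_name(src), type_name(dst));
          abort();
      }
  }

  int main() {
      cuda_cpy(TYPE_F32, TYPE_F16);   // registered pair: fine
      cuda_cpy(TYPE_Q4_K, TYPE_Q4_K); // unregistered pair: aborts
  }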

Other environment knobs tried (no effect on the CUDA path)

  • OLLAMA_FLASH_ATTENTION=0 — still crashes the same way
  • OLLAMA_NEW_ENGINE=1 — still crashes
  • OLLAMA_KV_CACHE_TYPE=q8_0 — still crashes
  • CUDA_VISIBLE_DEVICES="" — Ollama's scheduler still discovers and selects CUDA0; the CUDA path runs (see the runtime probe sketched after this list).
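
To verify what the CUDA runtime itself sees under CUDA_VISIBLE_DEVICES="", a minimal probe can be run in the same shell (assumes the CUDA toolkit is installed). If the probe reports zero devices but Ollama still selects CUDA0, the variable is most likely not reaching the runner subprocess, rather than being ignored by CUDA:

  // probe.cu: report what the CUDA runtime sees in this environment.
  // Build: nvcc probe.cu -o probe
  #include <cstdio>
  #include <cuda_runtime.h>

  int main() {
      int count = 0;
      cudaError_t err = cudaGetDeviceCount(&count);
      if (err != cudaSuccess) {
          // With CUDA_VISIBLE_DEVICES="" the runtime typically reports
          // no usable devices here.
          printf("cudaGetDeviceCount: %s\n", cudaGetErrorString(err));
          return 1;
      }
      printf("CUDA runtime sees %d device(s)\n", count);
      for (int i = 0; i < count; ++i) {
          cudaDeviceProp prop;
          if (cudaGetDeviceProperties(&prop, i) == cudaSuccess) {
              printf("  device %d: %s (compute %d.%d)\n",
                     i, prop.name, prop.major, prop.minor);
          }
      }
      return 0;
  }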

Workaround that works

Force Vulkan-only:

  • OLLAMA_LLM_LIBRARY=vulkan
  • OLLAMA_VULKAN=1

The Vulkan backend on the same GPU loads and serves the same Q4_K_M models without the kernel crash. Inference quality is good. Throughput is acceptable for our use case.

(Vulkan has its own separate model-swap stability issue under heavy load — happy to file that as a separate report if helpful.)

Why this matters

Blackwell consumer cards are now the default high-VRAM laptop GPUs. Anyone who sets up Ollama on an RTX 50-series machine and tries any Q4_K_M model gets a hard 500 with no obvious next step: the error message ("resource limitations or an internal error") doesn't mention the kernel mismatch, so the workaround is hard to find without reading the server log directly.

A graceful fallback (CUDA → Vulkan when the runtime detects an unsupported kernel on the active arch) would resolve this without requiring users to know about the env-var workaround.
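
To make that concrete, here is a hypothetical C++ sketch of probe-then-fallback selection. None of the helper names below are Ollama's actual API; the sketch only illustrates the control flow of validating required kernels before committing to a backend (ggml does expose a per-op capability check, ggml_backend_supports_op, that a real implementation could build on):

  // Hypothetical sketch of probe-then-fallback backend selection.
  // These helpers do not exist in Ollama; they illustrate the idea of
  // checking op support on the active arch before committing.
  #include <cstdio>
  #include <string>
  #include <vector>

  struct Backend { std::string name; };

  // Stand-in for asking a backend whether every op in the model graph
  // has a kernel for this device (per-op checks could use something
  // like ggml_backend_supports_op).
  static bool supports_all_model_ops(const Backend &b) {
      // Imagine this returning false for CUDA on SM 12.0 when the
      // graph needs a q4_K -> q4_K copy that cpy.cu has no kernel for.
      return b.name != "cuda";
  }

  static const Backend *pick_backend(const std::vector<Backend> &candidates) {
      for (const auto &b : candidates) {
          if (supports_all_model_ops(b)) {
              printf("selected backend: %s\n", b.name.c_str());
              return &b;
          }
          printf("backend %s lacks a required kernel, trying next\n",
                 b.name.c_str());
      }
      return nullptr; // fall back to CPU or report a clear error
  }

  int main() {
      std::vector<Backend> candidates = {{"cuda"}, {"vulkan"}, {"cpu"}};
      const Backend *chosen = pick_backend(candidates);
      return chosen ? 0 : 1;
  }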

Thanks for everything you build — happy to provide additional repros/logs if useful.


@nickherahomes commented on GitHub (May 3, 2026):

Update — driver upgrade does not help.

Updated NVIDIA driver from 577.13 (CUDA runtime 12.9) → 596.36 (CUDA runtime 13.2, GeForce Game Ready Driver, released 2026-04-28). Ollama 0.22.1 unchanged. Same exact crash:

load_backend: loaded CUDA backend from ...\cuda_v13\ggml-cuda.dll
C:\a\ollama\ollama\ml\backend\ggml\ggml\src\ggml-cuda\cpy.cu:574: ggml_cuda_cpy: unsupported type combination (q4_K to q4_K)

So the kernel mismatch is in ggml's CUDA backend itself, not something a driver update can fix, which confirms the right place for a fix is the upstream code, not user-side driver hygiene. (The cuda_v13 library is now selected to match the new runtime, but it has the same kernel gap.)

Vulkan workaround still works on the new driver — same model, same Q4_K_M GGUFs, no kernel error, normal inference. So users who arrive here can use that as a reliable interim path.
