[PR #15332] ggml: add CUDA flash attention support for head dimension 512 for Gemma4 #25657

Open
opened 2026-04-19 18:20:11 -05:00 by GiteaMirror · 0 comments

📋 Pull Request Information

Original PR: https://github.com/ollama/ollama/pull/15332
Author: @mazphilip
Created: 4/5/2026
Status: 🔄 Open

Base: main ← Head: ggml/fattn-head-dim-512


📝 Commits (1)

  • 69cedff ggml: add CUDA flash attention support for head dimension 512 for Gemma4 support

📊 Changes

14 files changed (+80 additions, -10 deletions)

View changed files

📝 fs/ggml/ggml.go (+1 -0)
📝 ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh (+23 -1)
📝 ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cu (+4 -0)
📝 ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh (+29 -8)
📝 ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu (+10 -1)
📝 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu (+1 -0)
📝 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu (+1 -0)
📝 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu (+1 -0)
📝 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu (+1 -0)
📝 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu (+1 -0)
📝 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu (+1 -0)
📝 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu (+1 -0)
📝 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu (+1 -0)
➕ ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu (+5 -0)

📄 Description

Summary

Backport of ggml-org/llama.cpp#20998 into ollama's ggml backend.
I am not sure if there is a formal process for how this is done for ollama. The llama.cpp release that contains this fix is https://github.com/ggml-org/llama.cpp/releases/tag/b8609

Why this is needed:

Gemma4's global attention layers use head_dim=512, which has no CUDA flash attention kernel in the current llama.cpp snapshot. When FA is enabled, these ops silently fall back to the CPU during inference (a minimal sketch of this failure mode follows the bullet below).

  • ollama run with short prompts did not noticeably trigger the fallback, but ollama launch claude (and VS Code Copilot) did, presumably due to large system prompts with tool definitions.
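
To make the failure mode concrete, here is a minimal, self-contained C++ sketch of the idea: the CUDA backend only accepts flash attention for head dimensions it has kernels for, and anything outside that set stays on the CPU. The function name and the exact set of supported sizes are illustrative assumptions, not the real ggml-cuda API.

// Hypothetical illustration only -- not the actual ggml-cuda support check.
#include <cstdio>
#include <set>

// Head dimensions assumed to have CUDA flash attention kernels before this
// change; 512 (Gemma4's global attention layers) is not among them, so the
// op is not offloaded and runs on the CPU instead.
static bool cuda_flash_attn_supported(int head_dim) {
    static const std::set<int> supported = {64, 80, 96, 112, 128, 256};
    return supported.count(head_dim) > 0;
}

int main() {
    for (int d : {128, 256, 512}) {
        std::printf("head_dim=%3d -> %s\n", d,
                    cuda_flash_attn_supported(d) ? "CUDA flash attention"
                                                 : "CPU fallback");
    }
    return 0;
}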

Changes:

Follows ggml-org/llama.cpp#20998

  • Add case 512 to the MMA and tile kernel dispatch (sketched after this list)
  • Add kernel configs for Ampere, Turing, Volta, and RDNA architectures
  • Add template instances for D=512
  • Exclude D=512 from WMMA path and vector kernel (no D=512 vec templates)
  • Add gemma4 to flash attention default whitelist
    • this was added and then reverted in #15311; it is unclear why it was reverted, and it works locally, so I suggest re-adding it
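
For orientation, below is a hedged, self-contained C++ sketch of the dispatch change in the first bullet: a new case 512 in the switch over the head dimension. The identifiers (dispatch_fattn_tile, launch_fattn_tile_case, fattn_ctx) are placeholders, not the real names in fattn.cu / fattn-tile.cuh.

// Placeholder names; the real dispatch lives in fattn.cu / fattn-tile.cuh.
#include <cstdio>

struct fattn_ctx { int head_dim; };  // stand-in for the real tensor/context

template <int DKQ, int DV>
static void launch_fattn_tile_case(const fattn_ctx &ctx) {
    std::printf("tile kernel instance DKQ=%d DV=%d for head_dim=%d\n",
                DKQ, DV, ctx.head_dim);
}

static void dispatch_fattn_tile(const fattn_ctx &ctx) {
    switch (ctx.head_dim) {
        case  64: launch_fattn_tile_case< 64,  64>(ctx); break;
        case 128: launch_fattn_tile_case<128, 128>(ctx); break;
        case 256: launch_fattn_tile_case<256, 256>(ctx); break;
        case 512: launch_fattn_tile_case<512, 512>(ctx); break;  // new for Gemma4
        default:  std::printf("no tile kernel for head_dim=%d\n", ctx.head_dim);
    }
}

int main() {
    dispatch_fattn_tile({512});
    return 0;
}

In the real code each such case needs a matching explicit template instantiation, which appears to be why the PR adds the new fattn-tile-instance-dkq512-dv512.cu instance file alongside the dispatch change.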

Related issues:

Fixes #15237, #15350

Test plan

  • Verified on RTX 5090 + RTX 3090 Ti with gemma4:31b Q4_K_M (FA on, 128K context, 100% GPU)
    • verified that there is no CPU spike during ollama launch claude/vscode with long system prompts
    • verified no regression on other tool-enabled models: nemotron-cascade-2, qwen3.5:35b-a3b, gpt-oss:20b
  • go test ./fs/ggml/ ./ml/backend/ggml/ passes

Evaluation steps used:

# Don't have Vulkan locally; used PATH to point at the CUDA 13.0 nvcc compiler:
cmake -B build -DCMAKE_DISABLE_FIND_PACKAGE_Vulkan=TRUE
cmake --build build -j$(nproc) 
go build -o ./ollama .

# Deploy
sudo systemctl stop ollama
sudo cp ./ollama /usr/local/bin/ollama
sudo cp ./build/lib/ollama/libggml-cuda.so /usr/local/lib/ollama/cuda_v13/libggml-cuda.so
sudo systemctl daemon-reload
sudo systemctl start ollama

# Enable FA (not needed with whitelist)
# In /etc/systemd/system/ollama.service.d/override.conf:
#   Environment="OLLAMA_FLASH_ATTENTION=1"
# Then: sudo systemctl daemon-reload && sudo systemctl restart ollama

# Test
ollama launch claude
# select model
# "hi"

Checks:

  1. ollama ps  # if running
  2. nvidia-smi  # careful: memory will be filled, but GPU utilization ramps up when a prompt is triggered and then falls back to basically 0%
  3. perf top for CPU utilization - if FA does not work, you should see output like the following (I use a Q8 KV cache, but it will max out the CPU regardless):
    48.23%  ollama          libggml-base.so.0.0.0   [.] dequantize_row_q8_0
            |--11.46%--ggml_compute_forward_flash_attn_ext
    22.67%  ollama          libggml-cpu-haswell.so  [.] ggml_vec_dot_q8_0_q8_0
            |--5.89%--ggml_compute_forward_flash_attn_ext
    17.05%  ollama          libggml-cpu-haswell.so  [.] ggml_compute_forward_flash_attn_ext
            |--2.66%--ggml_compute_forward_flash_attn_ext
    
    • Note that Gemma4 has some vision modules on the CPU; these will still show up in perf top and are not a sign that FA is broken.

AI disclaimer: AI was used in the triaging and resolution of the issue.


🔄 This issue represents a GitHub Pull Request. It cannot be merged through Gitea due to API limitations.

GiteaMirror added the pull-request label 2026-04-19 18:20:11 -05:00

Reference: github-starred/ollama#25657