mirror of
https://github.com/ollama/ollama.git
synced 2026-03-09 07:16:38 -05:00
This reverts commit 1044b0419a.
This commit is contained in:
@@ -1,248 +0,0 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: nobody <>
|
||||
Date: Fri, 23 Jan 2026 12:42:53 -0800
|
||||
Subject: [PATCH] ggml: enable MLA flash attention for GLM-4.7-flash
|
||||
|
||||
Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash
|
||||
uses head size 576 with gqa_ratio 4, which was previously only supported
|
||||
for gqa_ratio 16 (DeepSeek).
|
||||
|
||||
Metal changes:
|
||||
- Enable head size 576 for flash attention
|
||||
- Increase simdgroups to 8 for large heads (>=512)
|
||||
- Add case 8 kernel dispatch for 8 simdgroups
|
||||
|
||||
CUDA changes:
|
||||
- Add gqa_ratio 4 support for head 576/512
|
||||
- Add tile configs for (576, 512, 4) and (576, 512, 8)
|
||||
- Add MMA config cases for ncols 4
|
||||
- Add template instances for ncols2=4
|
||||
---
|
||||
ggml/src/ggml-cuda/fattn-mma-f16.cuh | 15 ++++++++++++---
|
||||
ggml/src/ggml-cuda/fattn-tile.cuh | 16 ++++++++++++++++
|
||||
ggml/src/ggml-cuda/fattn.cu | 12 ++++++++----
|
||||
.../fattn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 +
|
||||
.../fattn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 +
|
||||
.../fattn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 +
|
||||
.../fattn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 +
|
||||
ggml/src/ggml-metal/ggml-metal-device.m | 8 ++------
|
||||
ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +-
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 1 +
|
||||
10 files changed, 44 insertions(+), 14 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
index 7bd1044c1..a627302f9 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
|
||||
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
|
||||
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
|
||||
@@ -1585,3 +1588,9 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||
+
|
||||
+// GLM 4.7 Flash uses gqa_ratio 4:
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
index 7c4d6fe67..682fb366e 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
|
||||
return 0;
|
||||
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 32, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||
|
||||
return 0;
|
||||
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
||||
|
||||
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 4, 64, 64)
|
||||
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
|
||||
|
||||
@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
|
||||
+ return;
|
||||
+ }
|
||||
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
|
||||
+ return;
|
||||
+ }
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
|
||||
index 015540666..1693479cb 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn.cu
|
||||
+++ b/ggml/src/ggml-cuda/fattn.cu
|
||||
@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
||||
break;
|
||||
case 576: {
|
||||
- // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
+ // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
- GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
+ GGML_ASSERT(gqa_ratio % 4 == 0);
|
||||
+ if (gqa_ratio % 16 == 0) {
|
||||
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
+ } else {
|
||||
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
+ }
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
if (V->ne[0] != 512) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
- if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
||||
+ if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
|
||||
index 2074e954a..517993cb0 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
|
||||
index 24c64cf00..97b19c67a 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
|
||||
index 1ada657f1..989626dfa 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
|
||||
index 86d4ffae2..173de7aac 100644
|
||||
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
|
||||
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
||||
+DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
index f24270bb1..7b5ee968c 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
|
||||
@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
op->src[0]->ne[0] != 112 &&
|
||||
op->src[0]->ne[0] != 128 &&
|
||||
op->src[0]->ne[0] != 192 &&
|
||||
- op->src[0]->ne[0] != 256) {
|
||||
- return false;
|
||||
- }
|
||||
- if (op->src[0]->ne[0] == 576) {
|
||||
- // DeepSeek sizes
|
||||
- // TODO: disabled for now, until optmized
|
||||
+ op->src[0]->ne[0] != 256 &&
|
||||
+ op->src[0]->ne[0] != 576) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
index e99c1763f..80864f303 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
||||
- int32_t nsg = 4;
|
||||
+ int32_t nsg = ne00 >= 512 ? 8 : 4;
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index c98d269d1..d33c16079 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
|
||||
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
||||
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
||||
+ case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
||||
}
|
||||
#undef FWD_TMPL
|
||||
#undef FWD_ARGS
|
||||
Reference in New Issue
Block a user