Revert "model: add MLA absorption for glm4moelite (#13810)" (#13869)

This reverts commit 1044b0419a.
This commit is contained in:
Jeffrey Morgan
2026-01-23 17:14:15 -08:00
committed by GitHub
parent 66831dcf70
commit 2eda97f1c3
16 changed files with 23 additions and 522 deletions

View File

@@ -1,248 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: nobody <>
Date: Fri, 23 Jan 2026 12:42:53 -0800
Subject: [PATCH] ggml: enable MLA flash attention for GLM-4.7-flash
Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash
uses head size 576 with gqa_ratio 4, which was previously only supported
for gqa_ratio 16 (DeepSeek).
Metal changes:
- Enable head size 576 for flash attention
- Increase simdgroups to 8 for large heads (>=512)
- Add case 8 kernel dispatch for 8 simdgroups
CUDA changes:
- Add gqa_ratio 4 support for head 576/512
- Add tile configs for (576, 512, 4) and (576, 512, 8)
- Add MMA config cases for ncols 4
- Add template instances for ncols2=4
---
ggml/src/ggml-cuda/fattn-mma-f16.cuh | 15 ++++++++++++---
ggml/src/ggml-cuda/fattn-tile.cuh | 16 ++++++++++++++++
ggml/src/ggml-cuda/fattn.cu | 12 ++++++++----
.../fattn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 +
.../fattn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 +
.../fattn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 +
.../fattn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 +
ggml/src/ggml-metal/ggml-metal-device.m | 8 ++------
ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +-
ggml/src/ggml-metal/ggml-metal.metal | 1 +
10 files changed, 44 insertions(+), 14 deletions(-)
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 7bd1044c1..a627302f9 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
}
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
@@ -1585,3 +1588,9 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
+
+// GLM 4.7 Flash uses gqa_ratio 4:
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
index 7c4d6fe67..682fb366e 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
return 0;
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
return 0;
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 4, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
return;
}
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
+ return;
+ }
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
+ return;
+ }
}
if constexpr (DV <= 256) {
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 015540666..1693479cb 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
break;
case 576: {
- // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
+ // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
GGML_ASSERT(V->ne[0] == 512);
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
- GGML_ASSERT(gqa_ratio % 16 == 0);
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ GGML_ASSERT(gqa_ratio % 4 == 0);
+ if (gqa_ratio % 16 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ }
} break;
default:
GGML_ABORT("fatal error");
@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
if (V->ne[0] != 512) {
return BEST_FATTN_KERNEL_NONE;
}
- if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
+ if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
return BEST_FATTN_KERNEL_NONE;
}
break;
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
index 2074e954a..517993cb0 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
index 24c64cf00..97b19c67a 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
index 1ada657f1..989626dfa 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
index 86d4ffae2..173de7aac 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index f24270bb1..7b5ee968c 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
op->src[0]->ne[0] != 112 &&
op->src[0]->ne[0] != 128 &&
op->src[0]->ne[0] != 192 &&
- op->src[0]->ne[0] != 256) {
- return false;
- }
- if (op->src[0]->ne[0] == 576) {
- // DeepSeek sizes
- // TODO: disabled for now, until optmized
+ op->src[0]->ne[0] != 256 &&
+ op->src[0]->ne[0] != 576) {
return false;
}
if (op->src[1]->type != op->src[2]->type) {
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index e99c1763f..80864f303 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
// simdgroups per threadgroup (a.k.a. warps)
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
- int32_t nsg = 4;
+ int32_t nsg = ne00 >= 512 ? 8 : 4;
const size_t smem = FATTN_SMEM(nsg);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index c98d269d1..d33c16079 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
+ case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
}
#undef FWD_TMPL
#undef FWD_ARGS