Files
ollama/llama/patches/0034-ggml-metal-guard-mul_mat_id-map0-and-add-ne20-22-spe.patch
2026-02-22 15:09:14 -08:00

38 lines
2.3 KiB
Diff

From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Sun, 22 Feb 2026 14:12:30 -0800
Subject: [PATCH] ggml-metal: guard mul_mat_id map0 and add ne20=22
specialization
---
ggml/src/ggml-metal/ggml-metal-ops.cpp | 3 ++-
ggml/src/ggml-metal/ggml-metal.metal | 1 +
2 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 4ac135603..ac5ad53db 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1961,7 +1961,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
// ne21 = n_rows (batch size)
const int ne21_mm_id_min = 32;
- if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
+ if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min) &&
+ (ne20 == 1 || ne20 == 2 || ne20 == 4 || ne20 == 6 || ne20 == 8 || ne20 == 10 || ne20 == 16 || ne20 == 22)) {
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
//switch (op->src[0]->type) {
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index c37447a10..4f338aa13 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
+template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(