mirror of
https://github.com/ollama/ollama.git
synced 2026-03-09 07:16:38 -05:00
38 lines
2.3 KiB
Diff
38 lines
2.3 KiB
Diff
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
|
From: jmorganca <jmorganca@gmail.com>
|
|
Date: Sun, 22 Feb 2026 14:12:30 -0800
|
|
Subject: [PATCH] ggml-metal: guard mul_mat_id map0 and add ne20=22
|
|
specialization
|
|
|
|
---
|
|
ggml/src/ggml-metal/ggml-metal-ops.cpp | 3 ++-
|
|
ggml/src/ggml-metal/ggml-metal.metal | 1 +
|
|
2 files changed, 3 insertions(+), 1 deletion(-)
|
|
|
|
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
|
index 4ac135603..ac5ad53db 100644
|
|
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
|
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
|
@@ -1961,7 +1961,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
|
// ne21 = n_rows (batch size)
|
|
const int ne21_mm_id_min = 32;
|
|
|
|
- if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
|
|
+ if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min) &&
|
|
+ (ne20 == 1 || ne20 == 2 || ne20 == 4 || ne20 == 6 || ne20 == 8 || ne20 == 10 || ne20 == 16 || ne20 == 22)) {
|
|
// some Metal matrix data types require aligned pointers
|
|
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
|
//switch (op->src[0]->type) {
|
|
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
|
index c37447a10..4f338aa13 100644
|
|
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
|
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
|
@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
|
|
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
|
|
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
|
|
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
|
|
+template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
|
|
|
|
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
|
kernel void kernel_mul_mm_id(
|