vulkan: temporary carry of vulkan fixes (#12971)

This should be reverted once we update ggml past b6897
This commit is contained in:
Daniel Hiltgen
2025-11-12 08:31:40 -08:00
committed by GitHub
parent cb1cb06478
commit 3a9e8e9fd4
32 changed files with 5838 additions and 667 deletions

View File

@@ -0,0 +1,32 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Wed, 29 Oct 2025 03:53:04 -0500
Subject: [PATCH] vulkan: Call ggml_vk_buffer_write_2d from ggml_vk_buffer_copy
(#16793)
This lets the copy to the destination device use the host-visible
vidmem optimization.
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 221e29509..18b7cbccf 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -5654,14 +5654,11 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
// Copy device to device
ggml_vk_ensure_sync_staging_buffer(src->device, size);
- ggml_vk_ensure_sync_staging_buffer(dst->device, size);
// Copy to src staging buffer
ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
- // memcpy to dst staging buffer
- memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size);
// Copy to dst buffer
- ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size);
+ ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,657 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Wed, 29 Oct 2025 08:44:29 -0500
Subject: [PATCH] vulkan: Update topk_moe fusion to handle gpt's late softmax
(#16656)
* vulkan: Update topk_moe fusion to handle gpt's late softmax
Based on #16649.
* Add ggml_check_edges
* Add sync logging to show fusion effects
* handle clamp added in #16655
* Update ggml/src/ggml-impl.h
Co-authored-by: Diego Devesa <slarengh@gmail.com>
---
ggml/src/ggml-impl.h | 16 +
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 304 +++++++++++-------
.../ggml-vulkan/vulkan-shaders/topk_moe.comp | 90 ++++--
3 files changed, 272 insertions(+), 138 deletions(-)
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 639d551a2..e5c446d1d 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -693,6 +693,7 @@ GGML_API void ggml_dxgi_pdh_release();
#endif
#ifdef __cplusplus
+#include <array>
#include <initializer_list>
#include <vector>
@@ -708,6 +709,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
}
+// Return true if the edges in the graph match expectations.
+inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
+ int start_idx,
+ std::initializer_list<std::array<int, 3>> edges) {
+ for (const auto & edge : edges) {
+ int dst_node = edge[0];
+ int src_idx = edge[1];
+ int src_node = edge[2];
+ if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
+ return false;
+ }
+ }
+ return true;
+}
+
// expose GGUF internals for test code
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 53b57c179..b2855b078 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -387,12 +387,76 @@ static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t num_topk_moe_pipelines = 10;
-static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
- GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
- GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
-static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
- GGML_OP_VIEW, GGML_OP_GET_ROWS };
+static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+ GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+ GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
+ GGML_OP_RESHAPE };
+static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+ GGML_OP_VIEW, GGML_OP_GET_ROWS };
+static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
+ GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
+//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
+//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
+//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ]
+//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ]
+//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ]
+//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ]
+//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ]
+//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ]
+//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
+ { 1, 0, 0 }, // reshape->src[0] == softmax
+ { 2, 0, 0 }, // argsort->src[0] == softmax
+ { 3, 0, 2 }, // view->src[0] == argsort
+ { 4, 0, 1 }, // get_rows->src[0] == reshape
+ { 4, 1, 3 }, // get_rows->src[1] == view
+ { 5, 0, 4 }, // reshape->src[0] == get_rows
+ { 6, 0, 5 }, // sum_rows->src[0] == reshape
+ { 7, 0, 6 }, // clamp->src[0] == sum_rows
+ { 8, 0, 5 }, // div->src[0] == reshape
+ { 8, 1, 7 }, // div->src[1] == clamp
+ { 9, 0, 8 }, // reshape->src[0] == div
+};
+
+// same as early_softmax_norm but ending after the get_rows
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
+ { 1, 0, 0 }, // reshape->src[0] == softmax
+ { 2, 0, 0 }, // argsort->src[0] == softmax
+ { 3, 0, 2 }, // view->src[0] == argsort
+ { 4, 0, 1 }, // get_rows->src[0] == reshape
+ { 4, 1, 3 }, // get_rows->src[1] == view
+};
+//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
+//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
+//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
+//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
+//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
+//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
+ { 1, 0, 0 }, // view->src[0] == argsort
+ { 2, 1, 1 }, // get_rows->src[1] == view
+ { 3, 0, 2 }, // reshape->src[0] == get_rows
+ { 4, 0, 3 }, // soft_max->src[0] == reshape
+ { 5, 0, 4 }, // reshape->src[0] == soft_max
+};
+
+enum topk_moe_mode {
+ TOPK_MOE_EARLY_SOFTMAX,
+ TOPK_MOE_EARLY_SOFTMAX_NORM,
+ TOPK_MOE_LATE_SOFTMAX,
+ TOPK_MOE_COUNT,
+};
+
+static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
+ topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
+ num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
+ TOPK_MOE_LATE_SOFTMAX;
+ return mode;
+}
struct vk_device_struct {
std::recursive_mutex mutex;
@@ -607,8 +671,7 @@ struct vk_device_struct {
vk_pipeline pipeline_flash_attn_split_k_reduce;
- // [2] is {!norm, norm}
- vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
+ vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
std::vector<vk_pipeline_ref> all_pipelines;
@@ -956,6 +1019,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
struct vk_op_topk_moe_push_constants {
uint32_t n_rows;
uint32_t n_expert_used;
+ float clamp_min;
+ float clamp_max;
};
struct vk_op_add_id_push_constants {
@@ -3806,8 +3871,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
}
for (auto &c : compiles) {
@@ -8085,8 +8151,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
- bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
- return ctx->device->pipeline_topk_moe[idx][with_norm];
+ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+ return ctx->device->pipeline_topk_moe[idx][mode];
}
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -8141,6 +8207,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
}
case GGML_OP_ARGSORT:
+ if (ctx->num_additional_fused_ops) {
+ uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
+ GGML_ASSERT(idx < num_topk_moe_pipelines);
+ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+ return ctx->device->pipeline_topk_moe[idx][mode];
+ }
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
return ctx->device->pipeline_argsort_f32[idx];
@@ -9676,10 +9749,12 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
- bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
+ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
- ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
- ggml_tensor * ids = cgraph->nodes[node_idx + 3];
+ ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
+ (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] :
+ cgraph->nodes[node_idx + 5];
+ ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
@@ -9738,9 +9813,14 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
GGML_ASSERT(d_ids != nullptr);
}
- vk_op_topk_moe_push_constants pc;
+ vk_op_topk_moe_push_constants pc {};
pc.n_rows = n_rows;
pc.n_expert_used = n_expert_used;
+ if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
+ ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
+ pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
+ pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
+ }
GGML_ASSERT(n_expert_used <= n_experts);
@@ -11335,7 +11415,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
}
+
+#define ENABLE_SYNC_LOGGING 0
+
if (need_sync) {
+#if ENABLE_SYNC_LOGGING
+ std::cerr << "sync" << std::endl;
+#endif
ctx->unsynced_nodes_written.clear();
ctx->unsynced_nodes_read.clear();
ggml_vk_sync_buffers(ctx, compute_ctx);
@@ -11353,6 +11439,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
}
+#if ENABLE_SYNC_LOGGING
+ if (!dryrun) {
+ for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
+ auto *n = cgraph->nodes[node_idx + i];
+ std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name;
+ if (n->op == GGML_OP_GLU) {
+ std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
+ }
+ std::cerr << std::endl;
+ }
+ }
+#endif
switch (node->op) {
case GGML_OP_REPEAT:
@@ -11531,7 +11629,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_ARGSORT:
- ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
+ if (ctx->num_additional_fused_ops) {
+ ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
+ } else {
+ ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
+ }
break;
case GGML_OP_SUM:
@@ -12329,30 +12431,27 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
}
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
- int node_idx, bool with_norm) {
+ int node_idx, topk_moe_mode mode) {
- if (with_norm) {
- if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
- return false;
- }
- for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
- if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
- return false;
- }
- }
- } else {
- if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
- return false;
- }
- for (size_t i = 0; i < topk_moe.size(); ++i) {
- if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
- return false;
- }
- }
- }
+ const ggml_tensor * softmax;
+ const ggml_tensor * weights;
- const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
- const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
+ switch (mode) {
+ case TOPK_MOE_EARLY_SOFTMAX_NORM:
+ softmax = cgraph->nodes[node_idx + 0];
+ weights = cgraph->nodes[node_idx + 9];
+ break;
+ case TOPK_MOE_EARLY_SOFTMAX:
+ softmax = cgraph->nodes[node_idx + 0];
+ weights = cgraph->nodes[node_idx + 4];
+ break;
+ case TOPK_MOE_LATE_SOFTMAX:
+ softmax = cgraph->nodes[node_idx + 4];
+ weights = cgraph->nodes[node_idx + 5];
+ break;
+ default:
+ return false;
+ }
const float * op_params = (const float *)softmax->op_params;
@@ -12378,60 +12477,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
return false;
}
- // Check that the nodes don't have any unexpected uses
- const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
- const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
- const ggml_tensor * view = cgraph->nodes[node_idx + 3];
- const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
- const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
- const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
- const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
- const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
-
- // softmax is used by reshape and argsort
- if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
- reshape1->src[0] != softmax ||
- argsort->src[0] != softmax) {
- return false;
- }
- // reshape is used by get_rows
- if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
- get_rows->src[0] != reshape1) {
- return false;
- }
- // argsort is used by view
- if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
- view->src[0] != argsort) {
- return false;
- }
- // view is written (via argsort), we can skip checking it
-
- if (with_norm) {
- // get_rows is used by reshape
- if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
- reshape5->src[0] != get_rows) {
- return false;
- }
-
- // reshape is used by sum_rows and div
- if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
- sum_rows->src[0] != reshape5 ||
- div->src[0] != reshape5) {
- return false;
- }
-
- // sum_rows is used by div
- if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
- div->src[1] != sum_rows) {
- return false;
- }
-
- // div/reshape are written
- if (reshape8->src[0] != div) {
- return false;
- }
- }
-
if (!ctx->device->subgroup_arithmetic ||
!ctx->device->subgroup_shuffle ||
!ctx->device->subgroup_require_full_support ||
@@ -12517,10 +12562,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
- ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
- ctx->num_additional_fused_ops = topk_moe.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
}
}
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -12618,10 +12671,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
- ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
- ctx->num_additional_fused_ops = topk_moe.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
}
}
@@ -12754,25 +12815,44 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
while (first_unused < graph->n_nodes) {
std::vector<int> current_set;
- // Avoid reordering topk_moe_norm
- if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
- bool is_topk_moe_norm = true;
- for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
- if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
- is_topk_moe_norm = false;
+ // Check for fusion patterns and avoid reordering them
+ auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool {
+ if (start + (int)pattern.size() <= graph->n_nodes) {
+ bool is_pattern = true;
+ for (size_t j = 0; j < pattern.size(); ++j) {
+ if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {
+ is_pattern = false;
+ }
}
+ return is_pattern;
}
- if (is_topk_moe_norm) {
- for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
+ return false;
+ };
+
+ auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool {
+ if (match_pattern(pattern, first_unused)) {
+ for (size_t j = 0; j < pattern.size(); ++j) {
new_order.push_back(graph->nodes[first_unused + j]);
used[first_unused + j] = true;
}
while (first_unused < graph->n_nodes && used[first_unused]) {
first_unused++;
}
- continue;
+ return true;
}
+ return false;
+ };
+
+ if (keep_pattern(topk_moe_early_softmax_norm)) {
+ continue;
+ }
+ if (keep_pattern(topk_moe_early_softmax)) {
+ continue;
}
+ if (keep_pattern(topk_moe_late_softmax)) {
+ continue;
+ }
+
// First, grab the next unused node.
current_set.push_back(first_unused);
@@ -12790,6 +12870,12 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
if (is_empty(graph->nodes[j])) {
continue;
}
+ // Don't pull forward nodes from fusion patterns
+ if (match_pattern(topk_moe_early_softmax_norm, j) ||
+ match_pattern(topk_moe_early_softmax, j) ||
+ match_pattern(topk_moe_late_softmax, j)) {
+ continue;
+ }
bool ok = true;
for (int c = first_unused; c < j; ++c) {
if (!used[c] &&
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
index 9e56d5f8a..bc1c278bf 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
@@ -11,6 +11,8 @@ layout (push_constant) uniform parameter
{
uint n_rows;
uint n_expert_used;
+ float clamp_min;
+ float clamp_max;
};
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
@@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 2) const bool with_norm = true;
+layout(constant_id = 3) const bool late_softmax = false;
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
@@ -25,53 +28,72 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
-void main() {
- const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
- if (row >= n_rows) {
- return;
- }
+const float INFINITY = 1.0 / 0.0;
- const uint logits_offset = n_experts * row;
- const uint weights_offset = n_expert_used * row;
- const uint ids_offset = n_experts * row;
-
- float logits_r[experts_per_thread];
-
- const float INFINITY = 1.0 / 0.0;
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
+ float max_val = -INFINITY;
[[unroll]]
- for (uint i = 0; i < n_experts; i += WARP_SIZE) {
- const uint expert = i + gl_LocalInvocationID.x;
- logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
+ for (int i = 0; i < experts_per_thread; i++) {
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ max_val = max(max_val, vals[i]);
+ }
}
- float max_val = logits_r[0];
+ max_val = subgroupMax(max_val);
+
+ float sum = 0.f;
[[unroll]]
- for (int i = 1; i < experts_per_thread; i++) {
- const float val = logits_r[i];
- max_val = max(val, max_val);
+ for (int i = 0; i < experts_per_thread; i++) {
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ const float val = exp(vals[i] - max_val);
+ vals[i] = val;
+ sum += val;
+ } else {
+ vals[i] = 0.f;
+ }
}
- max_val = subgroupMax(max_val);
+ sum = subgroupAdd(sum);
- float wt[experts_per_thread];
- float tmp = 0.f;
+ const float inv_sum = 1.0f / sum;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
- const float val = logits_r[i];
- wt[i] = exp(val - max_val);
- tmp += wt[i];
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ vals[i] *= inv_sum;
+ }
}
+}
- tmp = subgroupAdd(tmp);
+void main() {
+ const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
+ if (row >= n_rows) {
+ return;
+ }
- const float inv_sum = 1.0f / tmp;
+ const uint logits_offset = n_experts * row;
+ const uint weights_offset = n_expert_used * row;
+ const uint ids_offset = n_experts * row;
+
+ float wt[experts_per_thread];
[[unroll]]
- for (int i = 0; i < experts_per_thread; i++) {
- wt[i] = wt[i] * inv_sum;
+ for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+ const uint expert = i + gl_LocalInvocationID.x;
+ wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
+ }
+
+ if (!late_softmax) {
+ softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
}
// at this point, each thread holds a portion of softmax,
@@ -82,6 +104,11 @@ void main() {
float output_weights[experts_per_thread];
+ [[unroll]]
+ for (int i = 0; i < experts_per_thread; i++) {
+ output_weights[i] = 0.f;
+ }
+
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
uint max_expert = gl_LocalInvocationID.x;
@@ -121,6 +148,7 @@ void main() {
if (with_norm) {
wt_sum = subgroupAdd(wt_sum);
+ wt_sum = clamp(wt_sum, clamp_min, clamp_max);
const float inv_sum = 1.0f / wt_sum;
[[unroll]]
@@ -129,6 +157,10 @@ void main() {
}
}
+ if (late_softmax) {
+ softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
+ }
+
[[unroll]]
for (uint i = 0; i < experts_per_thread; ++i) {
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,85 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Thu, 30 Oct 2025 01:27:41 -0500
Subject: [PATCH] vulkan: Handle argsort with a large number of rows (#16851)
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 ++++
ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp | 16 ++++++++++++----
2 files changed, 16 insertions(+), 4 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index aaf4334b5..3604ceb04 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1084,6 +1084,7 @@ struct vk_op_soft_max_push_constants {
struct vk_op_argsort_push_constants {
uint32_t ncols;
+ uint32_t nrows;
int32_t order;
};
@@ -8710,6 +8711,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
break;
case GGML_OP_ARGSORT:
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
+ elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
break;
case GGML_OP_IM2COL:
{
@@ -9952,9 +9954,11 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
int32_t * op_params = (int32_t *)dst->op_params;
uint32_t ncols = src0->ne[0];
+ uint32_t nrows = ggml_nrows(src0);
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
ncols,
+ nrows,
op_params[0],
}, dryrun);
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
index c81b84452..c4e68bc02 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
@@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
+ uint nrows;
uint order;
} p;
@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
dst_row[idx1] = tmp;
}
-void argsort(bool needs_bounds_check) {
+void argsort(bool needs_bounds_check, const uint row) {
// bitonic sort
const int col = int(gl_LocalInvocationID.x);
- const uint row = gl_WorkGroupID.y;
const uint row_offset = row * p.ncols;
@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {
void main() {
if (p.ncols == BLOCK_SIZE) {
- argsort(false);
+ uint row = gl_WorkGroupID.y;
+ while (row < p.nrows) {
+ argsort(false, row);
+ row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+ }
} else {
- argsort(true);
+ uint row = gl_WorkGroupID.y;
+ while (row < p.nrows) {
+ argsort(true, row);
+ row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+ }
}
}

View File

@@ -0,0 +1,77 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Ruben Ortlam <picard12@live.de>
Date: Fri, 31 Oct 2025 08:14:49 +0100
Subject: [PATCH] vulkan: fix shmem overrun in mmq id shader (#16873)
* vulkan: fix shmem overrun in mmq id shader
* metal : fix mul_mm_id
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
---
ggml/src/ggml-metal/ggml-metal-device.cpp | 2 +-
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp | 4 ++++
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl | 2 +-
tests/test-backend-ops.cpp | 3 +++
4 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 758116342..c78082ac3 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
char name[256];
snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
- snprintf(name, 256, "%s", base);
+ snprintf(name, 256, "%s_ne02=%d", base, ne02);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index 8b238ac4b..d955b4fc7 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -82,9 +82,13 @@ layout (constant_id = 10) const uint WARP = 32;
#include "mul_mmq_shmem_types.glsl"
+#ifdef MUL_MAT_ID
+#define BK_STEP 1
+#else
#ifndef BK_STEP
#define BK_STEP 4
#endif
+#endif
// Shared memory cache
shared block_a_cache buf_a[BM * BK_STEP];
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
index 72fec4404..1c0f5306f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
@@ -27,7 +27,7 @@ struct block_a_cache {
#elif defined(DATA_A_Q8_0)
#define QUANT_R_MMQ 1
// AMD likes 4, Intel likes 1 and Nvidia likes 2
-#define BK_STEP 1
+// #define BK_STEP 1
struct block_a_cache {
int32_t qs[32/4];
FLOAT_TYPE dm;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 657b6cc2f..1f8dda383 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -6722,6 +6722,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
+ // gpt-oss issue with Vulkan mmq_id
+ test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
+
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
for (int n_mats : {4, 8}) {

View File

@@ -0,0 +1,80 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Masato Nakasaka <masato.nakasaka@intel.com>
Date: Fri, 31 Oct 2025 16:18:59 +0900
Subject: [PATCH] vulkan: Fix crash when FP16 mul_mat accumulation is not
supported (#16796)
* Experimenting crash fix
* added assert for aborting and fixed comment
* changed to check if a pipeline is empty or not
* Moved function in class definition
* replaced with is_empty
* Modified is_empty to check only unaligned pipelines
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 +++++++++++++-------
1 file changed, 13 insertions(+), 7 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 3604ceb04..80185d9f0 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -146,8 +146,13 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
struct vk_matmul_pipeline_struct {
vk_pipeline l, m, s;
vk_pipeline a_l, a_m, a_s;
+ // Returns true when all unaligned pipelines are null.
+ // We only check for unaligned variants since one of the unaligned pipelines must exist
+ // while aligned pipelines are optional
+ bool is_empty() const {
+ return l == nullptr && m == nullptr && s == nullptr;
+ }
};
-
typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
struct vk_matmul_pipeline2 {
@@ -5080,7 +5085,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
- if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+ if (pipelines->is_empty()) {
return nullptr;
}
@@ -5229,7 +5234,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
- if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+ if (pipelines->is_empty()) {
return nullptr;
}
@@ -5264,16 +5269,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
return nullptr;
}
+ vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
// XXX TODO 'prec' is not actually allowed in mul_mat_id.
bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
- bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr;
- bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr;
+ bool support_fp16acc = !mmp.f16acc->is_empty();
+ bool support_fp32acc = !mmp.f32acc->is_empty();
if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc;
+ return mmp.f16acc;
} else {
GGML_ASSERT(support_fp32acc);
- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
+ return mmp.f32acc;
}
}