Skip to content

Commit bd9eee7

Browse files
committed
CANN: improve ACL graph matching
Record `ne` and `nb` information for src tensors and include them in the graph matching check. This enhances the robustness of ACL graph matching by preventing incorrect matches when src tensors share the same data address but differ in shape or stride.
1 parent 1d0125b commit bd9eee7

File tree

2 files changed

+43
-9
lines changed

2 files changed

+43
-9
lines changed

ggml/src/ggml-cann/common.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,11 +341,18 @@ class cann_task_queue {
341341

342342
#ifdef USE_ACL_GRAPH
343343
struct ggml_graph_node_properties {
344+
// dst tensor
344345
void * node_address;
345-
ggml_op node_op;
346346
int64_t ne[GGML_MAX_DIMS];
347347
size_t nb[GGML_MAX_DIMS];
348+
349+
// src tensor
348350
void * src_address[GGML_MAX_SRC];
351+
int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
352+
size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];
353+
354+
// op
355+
ggml_op node_op;
349356
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
350357
};
351358

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2186,7 +2186,15 @@ static void add_lru_matched_graph_node_properties(
21862186
std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
21872187

21882188
for (int src = 0; src < GGML_MAX_SRC; ++src) {
2189-
prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
2189+
if (node->src[src]) {
2190+
prop.src_address[src] = node->src[src]->data;
2191+
std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
2192+
std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
2193+
} else {
2194+
prop.src_address[src] = nullptr;
2195+
std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
2196+
std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
2197+
}
21902198
}
21912199

21922200
memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
@@ -2206,14 +2214,18 @@ static void add_lru_matched_graph_node_properties(
22062214
* @param graph_node_properties The stored properties of a CANN graph node.
22072215
* @return true if all compared fields match (data-address comparisons are skipped for GGML_OP_VIEW nodes); false otherwise.
22082216
*/
2209-
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2217+
static bool ggml_graph_node_has_matching_properties(
2218+
ggml_tensor * node,
2219+
ggml_graph_node_properties * graph_node_properties) {
22102220
if (node->data != graph_node_properties->node_address &&
2211-
node->op != GGML_OP_VIEW) {
2221+
node->op != GGML_OP_VIEW) {
22122222
return false;
22132223
}
2224+
22142225
if (node->op != graph_node_properties->node_op) {
22152226
return false;
22162227
}
2228+
22172229
for (int i = 0; i < GGML_MAX_DIMS; i++) {
22182230
if (node->ne[i] != graph_node_properties->ne[i]) {
22192231
return false;
@@ -2222,14 +2234,29 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
22222234
return false;
22232235
}
22242236
}
2237+
22252238
for (int i = 0; i < GGML_MAX_SRC; i++) {
2226-
if (node->src[i] &&
2227-
node->src[i]->data != graph_node_properties->src_address[i] &&
2228-
node->op != GGML_OP_VIEW
2229-
) {
2230-
return false;
2239+
if (node->src[i]) {
2240+
if (node->src[i]->data != graph_node_properties->src_address[i] &&
2241+
node->op != GGML_OP_VIEW) {
2242+
return false;
2243+
}
2244+
2245+
for (int d = 0; d < GGML_MAX_DIMS; d++) {
2246+
if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) {
2247+
return false;
2248+
}
2249+
if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) {
2250+
return false;
2251+
}
2252+
}
2253+
} else {
2254+
if (graph_node_properties->src_address[i] != nullptr) {
2255+
return false;
2256+
}
22312257
}
22322258
}
2259+
22332260
if (node->op == GGML_OP_SCALE &&
22342261
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
22352262
return false;

0 commit comments

Comments
 (0)