@@ -2186,7 +2186,15 @@ static void add_lru_matched_graph_node_properties(
2186
2186
std::copy_n (node->nb , GGML_MAX_DIMS, prop.nb );
2187
2187
2188
2188
for (int src = 0 ; src < GGML_MAX_SRC; ++src) {
2189
- prop.src_address [src] = node->src [src] ? node->src [src]->data : nullptr ;
2189
+ if (node->src [src]) {
2190
+ prop.src_address [src] = node->src [src]->data ;
2191
+ std::copy_n (node->src [src]->ne , GGML_MAX_DIMS, prop.src_ne [src]);
2192
+ std::copy_n (node->src [src]->nb , GGML_MAX_DIMS, prop.src_nb [src]);
2193
+ } else {
2194
+ prop.src_address [src] = nullptr ;
2195
+ std::fill_n (prop.src_ne [src], GGML_MAX_DIMS, 0 );
2196
+ std::fill_n (prop.src_nb [src], GGML_MAX_DIMS, 0 );
2197
+ }
2190
2198
}
2191
2199
2192
2200
memcpy (prop.op_params , node->op_params , GGML_MAX_OP_PARAMS);
@@ -2206,14 +2214,18 @@ static void add_lru_matched_graph_node_properties(
2206
2214
* @param graph_node_properties The stored properties of a CANN graph node.
2207
2215
* @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
2208
2216
*/
2209
- static bool ggml_graph_node_has_matching_properties (ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2217
+ static bool ggml_graph_node_has_matching_properties (
2218
+ ggml_tensor * node,
2219
+ ggml_graph_node_properties * graph_node_properties) {
2210
2220
if (node->data != graph_node_properties->node_address &&
2211
- node->op != GGML_OP_VIEW) {
2221
+ node->op != GGML_OP_VIEW) {
2212
2222
return false ;
2213
2223
}
2224
+
2214
2225
if (node->op != graph_node_properties->node_op ) {
2215
2226
return false ;
2216
2227
}
2228
+
2217
2229
for (int i = 0 ; i < GGML_MAX_DIMS; i++) {
2218
2230
if (node->ne [i] != graph_node_properties->ne [i]) {
2219
2231
return false ;
@@ -2222,14 +2234,29 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
2222
2234
return false ;
2223
2235
}
2224
2236
}
2237
+
2225
2238
for (int i = 0 ; i < GGML_MAX_SRC; i++) {
2226
- if (node->src [i] &&
2227
- node->src [i]->data != graph_node_properties->src_address [i] &&
2228
- node->op != GGML_OP_VIEW
2229
- ) {
2230
- return false ;
2239
+ if (node->src [i]) {
2240
+ if (node->src [i]->data != graph_node_properties->src_address [i] &&
2241
+ node->op != GGML_OP_VIEW) {
2242
+ return false ;
2243
+ }
2244
+
2245
+ for (int d = 0 ; d < GGML_MAX_DIMS; d++) {
2246
+ if (node->src [i]->ne [d] != graph_node_properties->src_ne [i][d]) {
2247
+ return false ;
2248
+ }
2249
+ if (node->src [i]->nb [d] != graph_node_properties->src_nb [i][d]) {
2250
+ return false ;
2251
+ }
2252
+ }
2253
+ } else {
2254
+ if (graph_node_properties->src_address [i] != nullptr ) {
2255
+ return false ;
2256
+ }
2231
2257
}
2232
2258
}
2259
+
2233
2260
if (node->op == GGML_OP_SCALE &&
2234
2261
memcmp (graph_node_properties->op_params , node->op_params , GGML_MAX_OP_PARAMS) != 0 ) {
2235
2262
return false ;
0 commit comments