Skip to content

Commit 207c22e

Browse files
authored
ggml: Re-enable CUDA graphs in presence of CONT and DUP nodes (#12970)
1 parent 7a395f6 commit 207c22e

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

ggml/src/ggml-cuda/cpy.cu

+5 -4
Original file line number | Diff line number | Diff line change
@@ -551,7 +551,7 @@ static void ggml_cpy_f16_f16_cuda(
551551
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
552552
}
553553

554-
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
554+
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
555555
const int64_t ne = ggml_nelements(src0);
556556
GGML_ASSERT(ne == ggml_nelements(src1));
557557

@@ -588,7 +588,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
588588
char ** dest_ptrs_d = nullptr;
589589
int graph_cpynode_index = -1;
590590
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
591-
if(ctx.cuda_graph->use_cpy_indirection) {
591+
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
592592
dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
593593
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
594594
}
@@ -636,7 +636,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
636636
ggml_type_name(src0->type), ggml_type_name(src1->type));
637637
}
638638
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
639-
if(ctx.cuda_graph->use_cpy_indirection) {
639+
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
640640
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
641641
}
642642
#endif
@@ -645,7 +645,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
645645

646646
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
647647
const ggml_tensor * src0 = dst->src[0];
648-
ggml_cuda_cpy(ctx, src0, dst);
648+
bool disable_indirection = true;
649+
ggml_cuda_cpy(ctx, src0, dst, disable_indirection);
649650
}
650651

651652
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {

ggml/src/ggml-cuda/cpy.cuh

+1 -1
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
#define CUDA_CPY_BLOCK_SIZE 64
44

5-
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
5+
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false);
66

77
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
88

ggml/src/ggml-cuda/ggml-cuda.cu

+1 -1
Original file line number | Diff line number | Diff line change
@@ -2489,7 +2489,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
24892489
#endif
24902490
}
24912491

2492-
if (node->op == GGML_OP_MUL_MAT_ID || node->op == GGML_OP_CONT || node->op == GGML_OP_DUP) {
2492+
if (node->op == GGML_OP_MUL_MAT_ID) {
24932493
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
24942494
#ifndef NDEBUG
24952495
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);

0 commit comments

Comments
 (0)