diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py index b6f784530a4..07e3000c835 100644 --- a/megatron/rl/rl_utils.py +++ b/megatron/rl/rl_utils.py @@ -936,7 +936,7 @@ def prepare_data_for_update( # Wrap forward_backward_func for Full iteration CUDA graph forward_backward_func = get_forward_backward_func() - if args.enable_cuda_graph and args.cuda_graph_scope == "full_iteration": + if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope: forward_backward_func = FullCudaGraphWrapper( forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps )