From 5f5ff606b505fe195d7c8476d4c5a27560d58af8 Mon Sep 17 00:00:00 2001
From: hanchao
Date: Thu, 28 Nov 2024 05:52:06 +0000
Subject: [PATCH] register xccl

---
 caffe2/CMakeLists.txt                |  4 ++++
 torch/csrc/distributed/c10d/init.cpp | 21 +++++++++++++++++++++
 torch/xpu/__init__.py                | 27 +++++++++++++++++++++++++++
 3 files changed, 52 insertions(+)

diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index dbd765ab44b13e..1f8acea1b6b3a1 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -1104,6 +1104,10 @@ if(USE_XPU)
       message(WARNING "Failed to include ATen XPU implementation target")
     else()
       target_link_libraries(torch_xpu PRIVATE torch_xpu_ops)
+      if(USE_C10D_XCCL)
+        message(WARNING "USE_C10D_XCCL")
+        target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL)
+      endif()
       if(MSVC)
         # Windows
         target_link_options(torch_xpu PRIVATE
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 01fc8cb45a3336..18cf2764312212 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -30,6 +30,10 @@
 #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
 #endif
 
+#ifdef USE_C10D_XCCL
+#include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>
+#endif
+
 #ifdef USE_C10D_MPI
 #include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>
 #endif
@@ -2946,6 +2950,23 @@ Example::
           py::call_guard<py::gil_scoped_release>());
 #endif
 
+#ifdef USE_C10D_XCCL
+  auto processGroupXCCL =
+      intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupXCCL>(
+          module, "ProcessGroupXCCL", backend)
+          .def(
+              py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
+                          int rank,
+                          int size) {
+                return c10::make_intrusive<::c10d::ProcessGroupXCCL>(
+                    store, rank, size);
+              }),
+              py::arg("store"),
+              py::arg("rank"),
+              py::arg("size"),
+              py::call_guard<py::gil_scoped_release>());
+#endif
+
   py::enum_<::c10d::OpType>(module, "OpType")
       .value("BROADCAST", ::c10d::OpType::BROADCAST)
       .value("ALLREDUCE", ::c10d::OpType::ALLREDUCE)
diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py
index 380c30bcc29794..a75d6781b54f28 100644
--- a/torch/xpu/__init__.py
+++ b/torch/xpu/__init__.py
@@ -459,6 +459,27 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int:
     default_generator = _get_generator(final_device)
     return default_generator.get_offset()
 
 
+def _is_xccl_available():
+    try:
+        from torch._C._distributed_c10d import ProcessGroupXCCL
+        return True
+    except ImportError:
+        return False
+
+_XCCL_AVAILABLE = _is_xccl_available()
+
+def _create_process_group_xccl(backend_opts, pg_opts):
+    if _XCCL_AVAILABLE:
+        from torch._C._distributed_c10d import ProcessGroupXCCL
+        return ProcessGroupXCCL(backend_opts.store, backend_opts.group_rank, backend_opts.group_size)
+    else:
+        return None
+
+def _init_xccl():
+    if _XCCL_AVAILABLE:
+        torch.distributed.Backend.register_backend(
+            "xccl",
+            _create_process_group_xccl,
+            devices=["xpu"],
+            extended_api=True
+        )
+
+_init_xccl()
+
 # import here to avoid circular import
 from .memory import (