From 32cd74f6234e4bebe0c361cb40deb35bf8f3f319 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Mon, 17 Feb 2025 19:09:34 +0800 Subject: [PATCH 01/10] test new api --- src/xccl/ProcessGroupXCCL.cpp | 38 +++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index 2c4d6b9b0..252c9eaa1 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -890,14 +890,22 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); auto ccl_stream = ccl::create_stream(stream.queue()); - ccl::allreduce( + onecclAllReduce( input.data_ptr(), - output.data_ptr(), + coutput.data_ptr(), (size_t)input.numel(), xcclDataType, xcclReduceOp, comm, - ccl::create_stream(stream.queue())); + stream.queue()); + // ccl::allreduce( + // input.data_ptr(), + // output.data_ptr(), + // (size_t)input.numel(), + // xcclDataType, + // xcclReduceOp, + // comm, + // ccl::create_stream(stream.queue())); // Use SUM emu AVG due to oneCCL not support AVG // oneCCL is expected to support avg in basekit 2025.2 release. if (opts.reduceOp == ReduceOp::AVG) { @@ -944,14 +952,22 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::allreduce( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - xcclReduceOp, - comm, - ccl::create_stream(stream.queue())); + onecclAllReduce( + input.data_ptr(), + coutput.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + comm, + stream.queue()); + // ccl::allreduce( + // input.data_ptr(), + // output.data_ptr(), + // (size_t)input.numel(), + // xcclDataType, + // xcclReduceOp, + // comm, + // ccl::create_stream(stream.queue())); // Use SUM emu AVG due to oneCCL not support AVG // oneCCL is expected to support avg in basekit 2025.2 release. 
if (opts.reduceOp == ReduceOp::AVG) { From 65e750d6bc4766086e5cb18890974c70b4e7a6ea Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Tue, 18 Feb 2025 00:46:30 +0800 Subject: [PATCH 02/10] update --- src/xccl/ProcessGroupXCCL.cpp | 722 ++++++++++++++++------------------ src/xccl/ProcessGroupXCCL.hpp | 3 +- 2 files changed, 347 insertions(+), 378 deletions(-) diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index 252c9eaa1..81a4d142b 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -7,28 +7,29 @@ namespace c10d { namespace { -const std::map xcclOps = { - {ReduceOp::MIN, ccl::reduction::min}, - {ReduceOp::MAX, ccl::reduction::max}, - {ReduceOp::SUM, ccl::reduction::sum}, - {ReduceOp::PRODUCT, ccl::reduction::prod}, +const std::map xcclOps = { + {ReduceOp::MIN, onecclRedOp_t::ONECCL_MIN}, + {ReduceOp::MAX, onecclRedOp_t::ONECCL_MAX}, + {ReduceOp::SUM, onecclRedOp_t::ONECCL_SUM}, + {ReduceOp::PRODUCT, onecclRedOp_t::ONECCL_PROD}, + {ReduceOp::AVG, onecclRedOp_t::ONECCL_AVG}, }; -const std::map xcclDatatypes = { - {at::kByte, ccl::datatype::uint8}, - {at::kChar, ccl::datatype::int8}, - {at::kInt, ccl::datatype::int32}, - {at::kLong, ccl::datatype::int64}, - {at::kHalf, ccl::datatype::float16}, - {at::kFloat, ccl::datatype::float32}, - {at::kDouble, ccl::datatype::float64}, - {at::kBFloat16, ccl::datatype::bfloat16}, - {at::kBool, ccl::datatype::uint8}, +const std::map xcclDatatypes = { + {at::kByte, onecclDataType_t::ONECCL_UINT8}, + {at::kChar, onecclDataType_t::ONECCL_INT8}, + {at::kInt, onecclDataType_t::ONECCL_INT32}, + {at::kLong, onecclDataType_t::ONECCL_INT64}, + {at::kHalf, onecclDataType_t::ONECCL_FLOAT16}, + {at::kFloat, onecclDataType_t::ONECCL_FLOAT32}, + {at::kDouble, onecclDataType_t::ONECCL_FLOAT64}, + {at::kBFloat16, onecclDataType_t::ONECCL_BFLOAT16}, + {at::kBool, onecclDataType_t::ONECCL_UINT8}, // use for non-reducetion op like allgather - {at::kFloat8_e5m2, ccl::datatype::uint8}, - {at::kFloat8_e4m3fn, ccl::datatype::uint8}, - {at::kFloat8_e4m3fnuz, ccl::datatype::uint8}, - {at::kFloat8_e5m2fnuz, ccl::datatype::uint8}, + {at::kFloat8_e5m2, onecclDataType_t::ONECCL_UINT8}, + {at::kFloat8_e4m3fn, onecclDataType_t::ONECCL_UINT8}, + {at::kFloat8_e4m3fnuz, onecclDataType_t::ONECCL_UINT8}, + {at::kFloat8_e5m2fnuz, onecclDataType_t::ONECCL_UINT8}, }; bool computeLengthsAndCheckAndGetFlat( @@ -125,7 +126,7 @@ int64_t checkTensorOnSameDevice(const std::vector& tensors) { return total_numel; } -ccl::datatype getXcclDataType( +onecclDataType_t getXcclDataType( at::ScalarType type, bool is_reduction_op = false) { if (is_reduction_op) @@ -141,16 +142,11 @@ ccl::datatype getXcclDataType( return it->second; } -ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { +onecclRedOp_t getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { try { if (input.scalar_type() == at::kBool && reduceOp == ReduceOp::SUM) { // Map sum to max for bool tensors to avoid overflow issues with sum. - return ccl::reduction::max; - } - // Use SUM emu AVG due to oneCCL not support AVG. - // oneCCL is expected to support avg in basekit 2025.2 release. 
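A note on the reduce-op mapping being rewritten in this hunk: patch 02 swaps the ccl:: reduction and datatype tables for the oneCCL C-style enums (onecclRedOp_t, onecclDataType_t) and maps ReduceOp::AVG straight to ONECCL_AVG, which is why the SUM-plus-div_ emulation is deleted here and in the collectives further down. The sketch below condenses that lookup logic; the Red and Scalar enums are stand-ins for c10d::ReduceOp and at::ScalarType, and onecclRedOp_t is redeclared locally only so the fragment compiles on its own — the real definition comes from the oneCCL header, whose name is elided in this diff.

// Condensed sketch of the new reduce-op mapping (stand-in enums, self-contained).
#include <map>
#include <stdexcept>

enum class Red { SUM, AVG, MIN, MAX, PRODUCT };  // stand-in for c10d::ReduceOp
enum class Scalar { Bool, Float, BFloat16 };     // stand-in for at::ScalarType

// Redeclared locally for the sketch; the real enum comes from the oneCCL header.
enum class onecclRedOp_t { ONECCL_SUM, ONECCL_AVG, ONECCL_MIN, ONECCL_MAX, ONECCL_PROD };

onecclRedOp_t toOnecclRedOp(Red op, Scalar dtype) {
  // Bool tensors reduced with SUM are mapped to MAX to avoid overflow,
  // the same special case getXcclReduceOp() keeps above.
  if (dtype == Scalar::Bool && op == Red::SUM) {
    return onecclRedOp_t::ONECCL_MAX;
  }
  static const std::map<Red, onecclRedOp_t> table = {
      {Red::MIN, onecclRedOp_t::ONECCL_MIN},
      {Red::MAX, onecclRedOp_t::ONECCL_MAX},
      {Red::SUM, onecclRedOp_t::ONECCL_SUM},
      {Red::PRODUCT, onecclRedOp_t::ONECCL_PROD},
      {Red::AVG, onecclRedOp_t::ONECCL_AVG},  // native AVG, no SUM + div_ workaround
  };
  auto it = table.find(op);
  if (it == table.end()) {
    throw std::out_of_range("Cannot find such reduce op.");
  }
  return it->second;
}

Keeping the bool+SUM-to-MAX case outside the table makes the overflow workaround explicit instead of hiding it behind a datatype remap.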
- if (reduceOp == ReduceOp::AVG) { - return ccl::reduction::sum; + return onecclRedOp_t::max; } return xcclOps.at(reduceOp); } catch (const std::out_of_range&) { @@ -274,6 +270,60 @@ c10::intrusive_ptr ProcessGroupXCCL::initWork( return r; } +void ProcessGroupNCCL::broadcastUniqueXCCLID( + onecclUniqueId* xcclID, + bool isSingleP2POp, + const std::string& p2pKey, + int p2pRank) { + + std::string storeKey; + if (!isSingleP2POp) { + storeKey = std::to_string(xcclCommCounter_++); + } else { + storeKey = p2pKey; + } + if (rank_ == 0 || (isSingleP2POp && p2pRank == 0)) { + auto vec = std::vector( + reinterpret_cast(xcclID), + reinterpret_cast(xcclID) + XCCL_UNIQUE_ID_BYTES); + store_->set(storeKey, vec); + } else { + try { + auto vec = store_->get(storeKey); + TORCH_CHECK_WITH( + DistBackendError, + vec.size() == XCCL_UNIQUE_ID_BYTES, + "Invalid size for xcclUniqueId"); + std::memcpy(xcclID, vec.data(), vec.size()); + } catch (const std::exception& e) { + std::string exceptionMsg = c10::str( + "[", + rank_, + "] is setting up XCCL communicator and " + "retrieving xcclUniqueId from [0] via c10d key-value store by key '", + storeKey, + "', but store->get('", + storeKey, + "') got error: "); + C10_THROW_ERROR( + DistBackendError, + exceptionMsg + e.what() + + ". This may indicate a possible application crash on rank 0 or a network set up issue."); + } catch (...) { + C10_THROW_ERROR( + DistBackendError, + c10::str( + "Unknown exception while [", + rank_, + "] is setting up XCCL communicator and " + "retrieving xcclUniqueId from [0] via c10d key-value store by key '", + storeKey, + "'", + ". This may indicate a possible application crash on rank 0 or a network set up issue.")); + } + } +} + std::shared_ptr ProcessGroupXCCL::getXCCLComm( const std::string& deviceKey, at::Device& device, @@ -297,6 +347,7 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( } std::shared_ptr XCCLComm; + xcclUniqueId xcclID; bool batchP2P = xcclActiveGroupCounter_ > 0; bool singleP2POp = isP2POp(opType, batchP2P); @@ -320,13 +371,15 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false); sycl::queue& q = c10::xpu::XPUStream(stream).queue(); - auto ctx = ccl::create_context(q.get_context()); - ccl::vector_class> devs_rank; - devs_rank.emplace_back(rank, ccl::create_device(q.get_device())); - auto xccl_kvs = get_kvs(rank_, *store_, singleP2POp, deviceKey, p2pRank); - auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, xccl_kvs); - XCCLComm = std::make_shared(std::move(comms[0])); + if (rank_ == 0 || (singleP2POp && p2pRank == 0)) { + onecclGetUniqueId(&uid); + } + broadcastUniqueXCCLID(&ncclID, singleP2POp, deviceKey, p2pRank); + + xcclComm_t comm; + onecclCommInitRank(&comm, numRanks, uid, rank); + XCCLComm = std::make_shared(std::move(comm)); RECORD_PARAM_COMMS( 0, // seq @@ -351,12 +404,12 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( } void ProcessGroupXCCL::groupStart() { - ccl::group_start(); + onecclGroupStart(); ++xcclActiveGroupCounter_; } void ProcessGroupXCCL::groupEnd() { - ccl::group_end(); + onecclGroupEnd; --xcclActiveGroupCounter_; } @@ -587,13 +640,13 @@ c10::intrusive_ptr ProcessGroupXCCL::send( at::xpu::XPUStream& stream, int dst) { auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::send( + onecclSend( input.data_ptr(), (size_t)input.numel(), xcclDataType, dst, comm, - ccl::create_stream(stream.queue())); + stream.queue()); return; }, dstRank, @@ -635,13 +688,13 @@ c10::intrusive_ptr ProcessGroupXCCL::recv( 
at::xpu::XPUStream& stream, int src) { auto xcclDataType = getXcclDataType(output.scalar_type()); - ccl::recv( + onecclRecv( output.data_ptr(), (size_t)output.numel(), xcclDataType, src, comm, - ccl::create_stream(stream.queue())); + stream.queue()); return; }, srcRank, @@ -733,13 +786,13 @@ c10::intrusive_ptr ProcessGroupXCCL::gather( for (const auto r : c10::irange(size_)) { if (r != root) { // do receive - ccl::recv( + onecclRecv( outputs[r].data_ptr(), (size_t)inputTensor.numel(), xcclDataType, r, comm, - ccl::create_stream(stream.queue())); + stream.queue()); } else { // on its own rank, simply copy from the input outputs[r].copy_(inputTensor); @@ -747,13 +800,13 @@ c10::intrusive_ptr ProcessGroupXCCL::gather( } } else { // do send - ccl::send( + onecclSend( inputTensor.data_ptr(), (size_t)inputTensor.numel(), xcclDataType, root, comm, - ccl::create_stream(stream.queue())); + stream.queue()); } return; } @@ -846,13 +899,13 @@ c10::intrusive_ptr ProcessGroupXCCL::scatter( // do send size_t send_count = inputs[r].numel(); auto send_type = getXcclDataType(inputs[r].scalar_type()); - ccl::send( + onecclSend( inputs[r].data_ptr(), send_count, send_type, r, comm, - ccl::create_stream(stream.queue())); + (stream.queue()); } else { // on its own rank, simply copy from the input outputTensor.copy_(inputs[r]); @@ -862,13 +915,13 @@ c10::intrusive_ptr ProcessGroupXCCL::scatter( // do receive size_t recv_count = outputTensor.numel(); auto recv_type = getXcclDataType(outputTensor.scalar_type()); - ccl::recv( + onecclRecv( outputTensor.data_ptr(), recv_count, recv_type, root, comm, - ccl::create_stream(stream.queue())); + stream.queue()); } return; @@ -889,29 +942,14 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - auto ccl_stream = ccl::create_stream(stream.queue()); onecclAllReduce( input.data_ptr(), - coutput.data_ptr(), + output.data_ptr(), (size_t)input.numel(), xcclDataType, xcclReduceOp, comm, stream.queue()); - // ccl::allreduce( - // input.data_ptr(), - // output.data_ptr(), - // (size_t)input.numel(), - // xcclDataType, - // xcclReduceOp, - // comm, - // ccl::create_stream(stream.queue())); - // Use SUM emu AVG due to oneCCL not support AVG - // oneCCL is expected to support avg in basekit 2025.2 release. - if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - output.div_(divisor); - } return; }, OpType::ALLREDUCE, @@ -953,30 +991,13 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); onecclAllReduce( - input.data_ptr(), - coutput.data_ptr(), - (size_t)input.numel(), - xcclDataType, - xcclReduceOp, - comm, - stream.queue()); - // ccl::allreduce( - // input.data_ptr(), - // output.data_ptr(), - // (size_t)input.numel(), - // xcclDataType, - // xcclReduceOp, - // comm, - // ccl::create_stream(stream.queue())); - // Use SUM emu AVG due to oneCCL not support AVG - // oneCCL is expected to support avg in basekit 2025.2 release. 
- if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - c10::StreamGuard guard(stream); - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - output.div_(divisor); - } + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + comm, + stream.queue()); return; }, OpType::ALLREDUCE, @@ -1015,23 +1036,14 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_coalesced( at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::allreduce( + onecclAllReduce( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), xcclDataType, xcclReduceOp, comm, - ccl::create_stream(stream.queue())); - // Use SUM emu AVG due to oneCCL not support AVG - // oneCCL is expected to support avg in basekit 2025.2 release. - if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - c10::StreamGuard guard(stream); - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - output.div_(divisor); - } + stream.queue()); return; }, OpType::COALESCED, @@ -1073,13 +1085,13 @@ c10::intrusive_ptr ProcessGroupXCCL::broadcast( xcclComm_t& comm, at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::broadcast( + onecclBroadcast( input.data_ptr(), (size_t)input.numel(), xcclDataType, root, comm, - ccl::create_stream(stream.queue())); + stream.queue()); return; }, OpType::BROADCAST, @@ -1104,14 +1116,14 @@ c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( xcclComm_t& comm, at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::broadcast( + onecclBroadcast( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), xcclDataType, root, comm, - ccl::create_stream(stream.queue())); + stream.queue()); return; }, OpType::BROADCAST, @@ -1152,7 +1164,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce( const int root = opts.rootRank + opts.rootTensor; const auto xcclDataType = getXcclDataType(input.scalar_type(), true); const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::reduce( + onecclReduce( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), @@ -1160,15 +1172,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce( xcclReduceOp, root, comm, - ccl::create_stream(stream.queue())); - // WA due to oneCCL not support AVG - if (opts.reduceOp == ReduceOp::AVG && getRank() == root) { - auto divisor = getSize(); - c10::StreamGuard guard(stream); - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - output.div_(divisor); - } + stream.queue()); return; }, OpType::REDUCE, @@ -1193,24 +1197,15 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_oop( const int root = opts.rootRank + opts.rootTensor; const auto xcclDataType = getXcclDataType(input.scalar_type(), true); const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::reduce( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - xcclReduceOp, - root, - comm, - ccl::create_stream(stream.queue())); - // Use SUM emu AVG due to oneCCL not support AVG - // oneCCL is expected to support avg in basekit 2025.2 release. 
- if (opts.reduceOp == ReduceOp::AVG && getRank() == root) { - auto divisor = getSize(); - c10::StreamGuard guard(stream); - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - output.div_(divisor); - } + onecclReduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + root, + comm, + stream.queue()); return; }, OpType::REDUCE, @@ -1261,13 +1256,13 @@ c10::intrusive_ptr ProcessGroupXCCL::allgather( c10::xpu::XPUCachingAllocator::recordStream( output.storage().data_ptr(), stream); auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::allgather( + onecclAllGather( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), xcclDataType, comm, - ccl::create_stream(stream.queue())); + stream.queue()); return; }, [](at::xpu::XPUStream&, @@ -1342,13 +1337,13 @@ c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( c10::xpu::XPUCachingAllocator::recordStream( output.storage().data_ptr(), stream); auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::allgather( + onecclAllGather( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), xcclDataType, comm, - ccl::create_stream(stream.queue())); + stream.queue()); return; }, OpType::_ALLGATHER_BASE, @@ -1367,13 +1362,13 @@ c10::intrusive_ptr ProcessGroupXCCL::allgather_into_tensor_coalesced( xcclComm_t& comm, at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::allgather( + onecclAllGather( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), xcclDataType, comm, - ccl::create_stream(stream.queue())); + stream.queue()); return; }, OpType::COALESCED, @@ -1423,23 +1418,14 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter( output.storage().data_ptr(), stream); auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::reduce_scatter( + onecclReduceScatter( input.data_ptr(), output.data_ptr(), (size_t)output.numel(), xcclDataType, xcclReduceOp, comm, - ccl::create_stream(stream.queue())); - // Use SUM emu AVG due to oneCCL not support AVG - // oneCCL is expected to support avg in basekit 2025.2 release. - if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - c10::StreamGuard guard(stream); - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - output.div_(divisor); - } + stream.queue()); return; }, [&](at::xpu::XPUStream& Stream, @@ -1515,23 +1501,14 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_scatter_base( output.storage().data_ptr(), stream); auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::reduce_scatter( + onecclReduceScatter( input.data_ptr(), output.data_ptr(), (size_t)output.numel(), xcclDataType, xcclReduceOp, comm, - ccl::create_stream(stream.queue())); - // Use SUM emu AVG due to oneCCL not support AVG - // oneCCL is expected to support avg in basekit 2025.2 release. 
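One recurring detail around these allgather/reduce_scatter hunks: the collectives are enqueued on a dedicated XCCL stream, so the patches keep the c10::xpu::XPUCachingAllocator::recordStream() calls even while the AVG workaround around them is removed. recordStream() marks the output storage as in use on that side stream, so the caching allocator does not recycle the buffer for another allocation before the asynchronous collective finishes. A minimal fragment of the pattern (relies on the headers this file already includes; not a standalone program):

// Pattern used around the collectives above (fragment only).
void enqueueOnXcclStream(
    at::Tensor& output,
    at::xpu::XPUStream& xcclStream) {
  // `output` was allocated by the caching allocator against the current
  // (compute) stream, but the collective is issued on xcclStream. Recording
  // the side stream keeps the allocator from handing the buffer to a new
  // allocation while the asynchronous collective may still be writing it.
  c10::xpu::XPUCachingAllocator::recordStream(
      output.storage().data_ptr(), xcclStream);
  // ... followed by onecclReduceScatter(..., comm, &(xcclStream.queue()));
}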
- if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - c10::StreamGuard guard(stream); - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - output.div_(divisor); - } + stream.queue()); return; }, OpType::_REDUCE_SCATTER_BASE, @@ -1553,23 +1530,14 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter_tensor_coalesced( output.storage().data_ptr(), stream); auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::reduce_scatter( + onecclReduceScatter( input.data_ptr(), output.data_ptr(), (size_t)output.numel(), xcclDataType, xcclReduceOp, comm, - ccl::create_stream(stream.queue())); - // Use SUM emu AVG due to oneCCL not support AVG - // oneCCL is expected to support avg in basekit 2025.2 release. - if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - c10::StreamGuard guard(stream); - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - output.div_(divisor); - } + stream.queue()); return; }, OpType::COALESCED, @@ -1623,214 +1591,214 @@ c10::intrusive_ptr ProcessGroupXCCL::barrier(const BarrierOptions& opts) { return work; } -c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - std::vector& outputSplitSizes, - std::vector& inputSplitSizes, - const AllToAllOptions& /* unused */) { - checkSingleTensor(outputTensor, true); - checkSingleTensor(inputTensor, true); - if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensor, // inputTensor - outputTensor, // outputTensor - rank_, // rank - "all_to_all", // collective name - inputTensor.numel(), // inNelems - outputTensor.numel(), // outNelems - inputTensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - TORCH_CHECK( - outputTensor.numel() == inputTensor.numel() && - outputTensor.scalar_type() == inputTensor.scalar_type(), - "xpu_alltoall_base: tensors are not equal in size or data type"); - TORCH_CHECK( - outputTensor.size(0) % size_ == 0, - "xpu_alltoall_base: tensor's dim 0 does not divide equally across group size"); - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(output.scalar_type()); - ccl::alltoall( - input.data_ptr(), - output.data_ptr(), - (size_t)output.numel() / comm.size(), - xcclDataType, - comm, - ccl::create_stream(stream.queue())); - return; - }, - OpType::ALLTOALL_BASE, - "xccl:all_to_all"); - } else { - c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); - c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensor, // inputTensor - outputTensor, // outputTensor - rank_, // rank - "all_to_allv", // collective name - inputTensor.numel(), // inNelems - outputTensor.numel(), // outNelems - inputTensor.scalar_type(), // dType - inputSplitSizes, // inSplitSizes - 
outputSplitSizes, // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - std::vector sendCounts(size_); - std::vector recvCounts(size_); - bool inputSplitsEqual = inputSplitSizes.size() == 0; - bool outputSplitsEqual = outputSplitSizes.size() == 0; - - size_t inLen = input.numel(); - size_t outLen = output.numel(); - if (inLen) - inLen /= (inputSplitsEqual ? size_ : input.size(0)); - if (outLen) - outLen /= (outputSplitsEqual ? size_ : output.size(0)); - - for (int i = 0; i < size_; i++) { - sendCounts[i] = - (inputSplitsEqual ? inLen : inputSplitSizes[i] * inLen); - recvCounts[i] = - (outputSplitsEqual ? outLen : outputSplitSizes[i] * outLen); - } - auto xcclDataType = getXcclDataType(output.scalar_type()); - ccl::alltoallv( - input.data_ptr(), - sendCounts, - output.data_ptr(), - recvCounts, - xcclDataType, - comm, - ccl::create_stream(stream.queue())); - return; - }, - OpType::ALLTOALL_BASE, - "xccl:all_to_all"); - } -} - -c10::intrusive_ptr ProcessGroupXCCL::alltoall( - std::vector& outputTensors, - std::vector& inputTensors, - const AllToAllOptions& /* unused */) { - auto device = outputTensors[0].device(); - int64_t total_numel = 0; - for (const auto r : c10::irange(outputTensors.size())) { - checkSingleTensor(outputTensors[r], true); - checkSingleTensor(inputTensors[r], true); - TORCH_CHECK( - device == outputTensors[r].device() && - device == inputTensors[r].device(), - "Tensors must be on the same device") - total_numel += inputTensors[r].numel(); - } - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensors, // inputTensors - outputTensors, // outputTensors - rank_, // rank - "all_to_all", // collective name - total_numel, // inNelems - total_numel, // outNelems - inputTensors.front().scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - return collective( - inputTensors, - outputTensors, - [&](at::Tensor& /* unused */, - at::Tensor& /* unused */, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - c10::OptionalStreamGuard stream_guard(stream.unwrap()); - at::Tensor flatInput; - at::Tensor flatOutput; - - std::vector sendCounts(size_); - std::vector recvCounts(size_); - - int64_t flatSendCount; - int64_t flatRecvCount; - - bool isInputFlat = computeLengthsAndCheckAndGetFlat( - inputTensors, sendCounts, flatInput, flatSendCount); - bool isOutputFlat = computeLengthsAndCheckAndGetFlat( - outputTensors, recvCounts, flatOutput, flatRecvCount); - if (!isInputFlat) { - auto flatInputSplits = flatInput.split_with_sizes( - c10::IntArrayRef((int64_t*)sendCounts.data(), sendCounts.size()), - 0); - - for (int i = 0; i < size_; i++) { - flatInputSplits[i].copy_(inputTensors[i].view({-1})); - } - } - - auto xcclDataType = getXcclDataType(flatOutput.scalar_type()); - ccl::event ret_evt; - ret_evt = ccl::alltoallv( - flatInput.data_ptr(), - sendCounts, - flatOutput.data_ptr(), - recvCounts, - xcclDataType, - comm, - ccl::create_stream(stream.queue())); - - if (!isOutputFlat) { - ret_evt.wait(); - auto flatOutputSplits = flatOutput.split_with_sizes( - c10::IntArrayRef((int64_t*)recvCounts.data(), recvCounts.size()), - 0); - - for (int i = 0; i < 
size_; i++) { - outputTensors[i].view({-1}).copy_(flatOutputSplits[i]); - } - } - stream.synchronize(); - return; - }, - OpType::ALLTOALL, - "xccl:all_to_all"); -} +// c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( +// at::Tensor& outputTensor, +// at::Tensor& inputTensor, +// std::vector& outputSplitSizes, +// std::vector& inputSplitSizes, +// const AllToAllOptions& /* unused */) { +// checkSingleTensor(outputTensor, true); +// checkSingleTensor(inputTensor, true); +// if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { +// RECORD_PARAM_COMMS_DATA( +// static_cast( +// this->getSequenceNumberForGroup() + +// 1), // seq + 1 to match collective +// std::make_tuple(pg_uid_, pg_desc_), // PG name tuple +// inputTensor, // inputTensor +// outputTensor, // outputTensor +// rank_, // rank +// "all_to_all", // collective name +// inputTensor.numel(), // inNelems +// outputTensor.numel(), // outNelems +// inputTensor.scalar_type(), // dType +// std::vector(), // inSplitSizes +// std::vector(), // outSplitSizes +// -1, // globalRankStart +// -1, // globalRankStride +// this->getSize()); // worldSize +// TORCH_CHECK( +// outputTensor.numel() == inputTensor.numel() && +// outputTensor.scalar_type() == inputTensor.scalar_type(), +// "xpu_alltoall_base: tensors are not equal in size or data type"); +// TORCH_CHECK( +// outputTensor.size(0) % size_ == 0, +// "xpu_alltoall_base: tensor's dim 0 does not divide equally across group size"); +// return collective( +// inputTensor, +// outputTensor, +// [&](at::Tensor& input, +// at::Tensor& output, +// xcclComm_t& comm, +// at::xpu::XPUStream& stream) { +// c10::xpu::XPUCachingAllocator::recordStream( +// output.storage().data_ptr(), stream); +// auto xcclDataType = getXcclDataType(output.scalar_type()); +// ccl::alltoall( +// input.data_ptr(), +// output.data_ptr(), +// (size_t)output.numel() / comm.size(), +// xcclDataType, +// comm, +// ccl::create_stream(stream.queue())); +// return; +// }, +// OpType::ALLTOALL_BASE, +// "xccl:all_to_all"); +// } else { +// c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); +// c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); + +// RECORD_PARAM_COMMS_DATA( +// static_cast( +// this->getSequenceNumberForGroup() + +// 1), // seq + 1 to match collective +// std::make_tuple(pg_uid_, pg_desc_), // PG name tuple +// inputTensor, // inputTensor +// outputTensor, // outputTensor +// rank_, // rank +// "all_to_allv", // collective name +// inputTensor.numel(), // inNelems +// outputTensor.numel(), // outNelems +// inputTensor.scalar_type(), // dType +// inputSplitSizes, // inSplitSizes +// outputSplitSizes, // outSplitSizes +// -1, // globalRankStart +// -1, // globalRankStride +// this->getSize()); // worldSize + +// return collective( +// inputTensor, +// outputTensor, +// [&](at::Tensor& input, +// at::Tensor& output, +// xcclComm_t& comm, +// at::xpu::XPUStream& stream) { +// std::vector sendCounts(size_); +// std::vector recvCounts(size_); +// bool inputSplitsEqual = inputSplitSizes.size() == 0; +// bool outputSplitsEqual = outputSplitSizes.size() == 0; + +// size_t inLen = input.numel(); +// size_t outLen = output.numel(); +// if (inLen) +// inLen /= (inputSplitsEqual ? size_ : input.size(0)); +// if (outLen) +// outLen /= (outputSplitsEqual ? size_ : output.size(0)); + +// for (int i = 0; i < size_; i++) { +// sendCounts[i] = +// (inputSplitsEqual ? inLen : inputSplitSizes[i] * inLen); +// recvCounts[i] = +// (outputSplitsEqual ? 
outLen : outputSplitSizes[i] * outLen); +// } +// auto xcclDataType = getXcclDataType(output.scalar_type()); +// ccl::alltoallv( +// input.data_ptr(), +// sendCounts, +// output.data_ptr(), +// recvCounts, +// xcclDataType, +// comm, +// ccl::create_stream(stream.queue())); +// return; +// }, +// OpType::ALLTOALL_BASE, +// "xccl:all_to_all"); +// } +// } + +// c10::intrusive_ptr ProcessGroupXCCL::alltoall( +// std::vector& outputTensors, +// std::vector& inputTensors, +// const AllToAllOptions& /* unused */) { +// auto device = outputTensors[0].device(); +// int64_t total_numel = 0; +// for (const auto r : c10::irange(outputTensors.size())) { +// checkSingleTensor(outputTensors[r], true); +// checkSingleTensor(inputTensors[r], true); +// TORCH_CHECK( +// device == outputTensors[r].device() && +// device == inputTensors[r].device(), +// "Tensors must be on the same device") +// total_numel += inputTensors[r].numel(); +// } + +// RECORD_PARAM_COMMS_DATA( +// static_cast( +// this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective +// std::make_tuple(pg_uid_, pg_desc_), // PG name tuple +// inputTensors, // inputTensors +// outputTensors, // outputTensors +// rank_, // rank +// "all_to_all", // collective name +// total_numel, // inNelems +// total_numel, // outNelems +// inputTensors.front().scalar_type(), // dType +// std::vector(), // inSplitSizes +// std::vector(), // outSplitSizes +// -1, // globalRankStart +// -1, // globalRankStride +// this->getSize()); // worldSize + +// return collective( +// inputTensors, +// outputTensors, +// [&](at::Tensor& /* unused */, +// at::Tensor& /* unused */, +// xcclComm_t& comm, +// at::xpu::XPUStream& stream) { +// c10::OptionalStreamGuard stream_guard(stream.unwrap()); +// at::Tensor flatInput; +// at::Tensor flatOutput; + +// std::vector sendCounts(size_); +// std::vector recvCounts(size_); + +// int64_t flatSendCount; +// int64_t flatRecvCount; + +// bool isInputFlat = computeLengthsAndCheckAndGetFlat( +// inputTensors, sendCounts, flatInput, flatSendCount); +// bool isOutputFlat = computeLengthsAndCheckAndGetFlat( +// outputTensors, recvCounts, flatOutput, flatRecvCount); +// if (!isInputFlat) { +// auto flatInputSplits = flatInput.split_with_sizes( +// c10::IntArrayRef((int64_t*)sendCounts.data(), sendCounts.size()), +// 0); + +// for (int i = 0; i < size_; i++) { +// flatInputSplits[i].copy_(inputTensors[i].view({-1})); +// } +// } + +// auto xcclDataType = getXcclDataType(flatOutput.scalar_type()); +// ccl::event ret_evt; +// ret_evt = ccl::alltoallv( +// flatInput.data_ptr(), +// sendCounts, +// flatOutput.data_ptr(), +// recvCounts, +// xcclDataType, +// comm, +// ccl::create_stream(stream.queue())); + +// if (!isOutputFlat) { +// ret_evt.wait(); +// auto flatOutputSplits = flatOutput.split_with_sizes( +// c10::IntArrayRef((int64_t*)recvCounts.data(), recvCounts.size()), +// 0); + +// for (int i = 0; i < size_; i++) { +// outputTensors[i].view({-1}).copy_(flatOutputSplits[i]); +// } +// } +// stream.synchronize(); +// return; +// }, +// OpType::ALLTOALL, +// "xccl:all_to_all"); +// } } // namespace c10d diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp index dbb4e936a..29b2d7284 100644 --- a/src/xccl/ProcessGroupXCCL.hpp +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -21,11 +21,12 @@ #include namespace c10d { +#define XCCL_UNIQUE_ID_BYTES 128 static std::vector TORCH_XCCL_BLOCKING_WAIT = { "TORCH_XCCL_BLOCKING_WAIT", "XCCL_BLOCKING_WAIT"}; -using xcclComm_t = ccl::communicator; +using xcclComm_t = onecclComm_t; 
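With xcclComm_t now aliased to onecclComm_t, the KVS-based bootstrap on the .cpp side is replaced by an NCCL-style rendezvous: rank 0 calls onecclGetUniqueId(), publishes the raw XCCL_UNIQUE_ID_BYTES through the c10d store (broadcastUniqueXCCLID above), every other rank reads the key and memcpys the bytes back, and all ranks then call onecclCommInitRank(). A condensed sketch of that flow, using the call shapes from this patch; error handling and the single-P2P key path are trimmed, and the fragment assumes the usual c10d/oneCCL headers:

// Condensed rendezvous sketch (fragment; call shapes as used in this patch).
void bootstrapXcclComm(
    c10d::Store& store,
    const std::string& storeKey,
    int rank,
    int numRanks,
    xcclComm_t& comm) {
  onecclUniqueId xcclID{};
  if (rank == 0) {
    // One rank generates the id and publishes its raw bytes via the store.
    onecclGetUniqueId(&xcclID);
    std::vector<uint8_t> vec(
        reinterpret_cast<uint8_t*>(&xcclID),
        reinterpret_cast<uint8_t*>(&xcclID) + XCCL_UNIQUE_ID_BYTES);
    store.set(storeKey, vec);
  } else {
    // Everyone else blocks on the key and copies the bytes back.
    auto vec = store.get(storeKey);
    TORCH_CHECK(vec.size() == XCCL_UNIQUE_ID_BYTES, "Invalid size for xcclUniqueId");
    std::memcpy(&xcclID, vec.data(), vec.size());
  }
  // Collective call: every rank must reach this with the same id.
  onecclCommInitRank(&comm, numRanks, xcclID, rank);
}

The store round-trip replaces the ccl::create_main_kvs()/create_kvs() address exchange that the old get_kvs() helper (commented out later in this patch) used to perform.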
constexpr const char* XCCL_BACKEND_NAME = "xccl"; class TORCH_API ProcessGroupXCCL : public Backend { From 070d42c843840f93d37d4fdf9597760d48255b0e Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Fri, 21 Feb 2025 18:10:13 +0800 Subject: [PATCH 03/10] update --- src/xccl/ProcessGroupXCCL.cpp | 38 +++++++++--------- src/xccl/ProcessGroupXCCL.hpp | 74 ++++++++++++++++++----------------- 2 files changed, 57 insertions(+), 55 deletions(-) diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index 81a4d142b..4f3b30520 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -146,7 +146,7 @@ onecclRedOp_t getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { try { if (input.scalar_type() == at::kBool && reduceOp == ReduceOp::SUM) { // Map sum to max for bool tensors to avoid overflow issues with sum. - return onecclRedOp_t::max; + return onecclRedOp_t::ONECCL_MAX; } return xcclOps.at(reduceOp); } catch (const std::out_of_range&) { @@ -271,11 +271,10 @@ c10::intrusive_ptr ProcessGroupXCCL::initWork( } void ProcessGroupNCCL::broadcastUniqueXCCLID( - onecclUniqueId* xcclID, + onecclUniqueId* xcclID, bool isSingleP2POp, const std::string& p2pKey, int p2pRank) { - std::string storeKey; if (!isSingleP2POp) { storeKey = std::to_string(xcclCommCounter_++); @@ -371,7 +370,6 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false); sycl::queue& q = c10::xpu::XPUStream(stream).queue(); - if (rank_ == 0 || (singleP2POp && p2pRank == 0)) { onecclGetUniqueId(&uid); } @@ -409,7 +407,7 @@ void ProcessGroupXCCL::groupStart() { } void ProcessGroupXCCL::groupEnd() { - onecclGroupEnd; + onecclGroupEnd(); --xcclActiveGroupCounter_; } @@ -1198,14 +1196,14 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_oop( const auto xcclDataType = getXcclDataType(input.scalar_type(), true); const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); onecclReduce( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - xcclReduceOp, - root, - comm, - stream.queue()); + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + root, + comm, + stream.queue()); return; }, OpType::REDUCE, @@ -1623,7 +1621,8 @@ c10::intrusive_ptr ProcessGroupXCCL::barrier(const BarrierOptions& opts) { // "xpu_alltoall_base: tensors are not equal in size or data type"); // TORCH_CHECK( // outputTensor.size(0) % size_ == 0, -// "xpu_alltoall_base: tensor's dim 0 does not divide equally across group size"); +// "xpu_alltoall_base: tensor's dim 0 does not divide equally across +// group size"); // return collective( // inputTensor, // outputTensor, @@ -1726,7 +1725,8 @@ c10::intrusive_ptr ProcessGroupXCCL::barrier(const BarrierOptions& opts) { // RECORD_PARAM_COMMS_DATA( // static_cast( -// this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective +// this->getSequenceNumberForGroup() + 1), // seq + 1 to match +// collective // std::make_tuple(pg_uid_, pg_desc_), // PG name tuple // inputTensors, // inputTensors // outputTensors, // outputTensors @@ -1764,8 +1764,8 @@ c10::intrusive_ptr ProcessGroupXCCL::barrier(const BarrierOptions& opts) { // outputTensors, recvCounts, flatOutput, flatRecvCount); // if (!isInputFlat) { // auto flatInputSplits = flatInput.split_with_sizes( -// c10::IntArrayRef((int64_t*)sendCounts.data(), sendCounts.size()), -// 0); +// c10::IntArrayRef((int64_t*)sendCounts.data(), +// sendCounts.size()), 0); // for (int i = 0; i < size_; i++) { 
// flatInputSplits[i].copy_(inputTensors[i].view({-1})); @@ -1786,8 +1786,8 @@ c10::intrusive_ptr ProcessGroupXCCL::barrier(const BarrierOptions& opts) { // if (!isOutputFlat) { // ret_evt.wait(); // auto flatOutputSplits = flatOutput.split_with_sizes( -// c10::IntArrayRef((int64_t*)recvCounts.data(), recvCounts.size()), -// 0); +// c10::IntArrayRef((int64_t*)recvCounts.data(), +// recvCounts.size()), 0); // for (int i = 0; i < size_; i++) { // outputTensors[i].view({-1}).copy_(flatOutputSplits[i]); diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp index 29b2d7284..541d9921a 100644 --- a/src/xccl/ProcessGroupXCCL.hpp +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -6,7 +6,7 @@ #define CCL_ENABLE_ZE #define CCL_ENABLE_SYCL -#include +#include #include #include #include @@ -201,11 +201,11 @@ class TORCH_API ProcessGroupXCCL : public Backend { // `xcclActiveGroupCounter_` is introduced to track group calls made // in the frontend. In this scenario, the `groupStart` wrap API is // used. - ccl::group_start(); + onecclGroupStart(); }, [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - ccl::group_end(); + onecclGroupEnd(); }, opType, profilingTitle); @@ -342,39 +342,41 @@ class TORCH_API ProcessGroupXCCL : public Backend { private: std::mutex kvs_mutex; - ccl::shared_ptr_class get_kvs( - int rank, - c10d::Store& store, - bool singleP2POp = false, - const std::string& p2pKey = "", - int p2pRank = 0) { - std::lock_guard lock(kvs_mutex); - ccl::shared_ptr_class kvs; - std::string storeKey; - if (!singleP2POp) { - storeKey = std::to_string(xcclCommCounter_++); - } else { - storeKey = p2pKey; - } - // Rank 0 broadcast the bootstrap network information to other ranks - if (rank == 0 || (singleP2POp && p2pRank == 0)) { - kvs = ccl::create_main_kvs(); - ccl::kvs::address_type main_addr = kvs->get_address(); - auto ccl_kvs_addr = - std::vector(main_addr.begin(), main_addr.end()); - store.set(storeKey, ccl_kvs_addr); - } else { - auto ccl_kvs_addr = store.get(storeKey); - if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { - throw std::runtime_error("Unexpected ccl kvs addr from the store\n"); - } - ccl::kvs::address_type main_addr; - std::copy_n( - ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin()); - kvs = ccl::create_kvs(main_addr); - } - return kvs; - } + // ccl::shared_ptr_class get_kvs( + // int rank, + // c10d::Store& store, + // bool singleP2POp = false, + // const std::string& p2pKey = "", + // int p2pRank = 0) { + // std::lock_guard lock(kvs_mutex); + // ccl::shared_ptr_class kvs; + // std::string storeKey; + // if (!singleP2POp) { + // storeKey = std::to_string(xcclCommCounter_++); + // } else { + // storeKey = p2pKey; + // } + // // Rank 0 broadcast the bootstrap network information to other ranks + // if (rank == 0 || (singleP2POp && p2pRank == 0)) { + // kvs = ccl::create_main_kvs(); + // ccl::kvs::address_type main_addr = kvs->get_address(); + // auto ccl_kvs_addr = + // std::vector(main_addr.begin(), main_addr.end()); + // store.set(storeKey, ccl_kvs_addr); + // } else { + // auto ccl_kvs_addr = store.get(storeKey); + // if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { + // throw std::runtime_error("Unexpected ccl kvs addr from the + // store\n"); + // } + // ccl::kvs::address_type main_addr; + // std::copy_n( + // ccl_kvs_addr.begin(), ccl::kvs::address_max_size, + // main_addr.begin()); + // kvs = ccl::create_kvs(main_addr); + // } + // return kvs; + // } }; } // namespace c10d From 842547a68c3ed32189577c3c30d5d591816a50e1 Mon Sep 17 00:00:00 
2001 From: "Han, Chao1" Date: Mon, 24 Feb 2025 17:50:49 +0800 Subject: [PATCH 04/10] update --- src/xccl/ProcessGroupXCCL.cpp | 51 ++++++++++++++++++----------------- src/xccl/ProcessGroupXCCL.hpp | 30 ++++++++++++--------- 2 files changed, 44 insertions(+), 37 deletions(-) diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index 4f3b30520..8dc444e4b 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -270,7 +270,7 @@ c10::intrusive_ptr ProcessGroupXCCL::initWork( return r; } -void ProcessGroupNCCL::broadcastUniqueXCCLID( +void ProcessGroupXCCL::broadcastUniqueXCCLID( onecclUniqueId* xcclID, bool isSingleP2POp, const std::string& p2pKey, @@ -346,7 +346,7 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( } std::shared_ptr XCCLComm; - xcclUniqueId xcclID; + onecclUniqueId xcclID; bool batchP2P = xcclActiveGroupCounter_ > 0; bool singleP2POp = isP2POp(opType, batchP2P); @@ -371,12 +371,12 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( sycl::queue& q = c10::xpu::XPUStream(stream).queue(); if (rank_ == 0 || (singleP2POp && p2pRank == 0)) { - onecclGetUniqueId(&uid); + onecclGetUniqueId(&xcclID); } - broadcastUniqueXCCLID(&ncclID, singleP2POp, deviceKey, p2pRank); + broadcastUniqueXCCLID(&xcclID, singleP2POp, deviceKey, p2pRank); - xcclComm_t comm; - onecclCommInitRank(&comm, numRanks, uid, rank); + xcclComm_t comm = nullptr; + onecclCommInitRank(&comm, numRanks, xcclID, rank); XCCLComm = std::make_shared(std::move(comm)); RECORD_PARAM_COMMS( @@ -644,7 +644,7 @@ c10::intrusive_ptr ProcessGroupXCCL::send( xcclDataType, dst, comm, - stream.queue()); + &(stream.queue())); return; }, dstRank, @@ -692,7 +692,7 @@ c10::intrusive_ptr ProcessGroupXCCL::recv( xcclDataType, src, comm, - stream.queue()); + &(stream.queue())); return; }, srcRank, @@ -790,7 +790,7 @@ c10::intrusive_ptr ProcessGroupXCCL::gather( xcclDataType, r, comm, - stream.queue()); + &(stream.queue())); } else { // on its own rank, simply copy from the input outputs[r].copy_(inputTensor); @@ -804,7 +804,7 @@ c10::intrusive_ptr ProcessGroupXCCL::gather( xcclDataType, root, comm, - stream.queue()); + &(stream.queue())); } return; } @@ -903,7 +903,7 @@ c10::intrusive_ptr ProcessGroupXCCL::scatter( send_type, r, comm, - (stream.queue()); + &(stream.queue())); } else { // on its own rank, simply copy from the input outputTensor.copy_(inputs[r]); @@ -919,7 +919,7 @@ c10::intrusive_ptr ProcessGroupXCCL::scatter( recv_type, root, comm, - stream.queue()); + &(stream.queue())); } return; @@ -947,7 +947,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( xcclDataType, xcclReduceOp, comm, - stream.queue()); + &(stream.queue())); return; }, OpType::ALLREDUCE, @@ -995,7 +995,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( xcclDataType, xcclReduceOp, comm, - stream.queue()); + &(stream.queue())); return; }, OpType::ALLREDUCE, @@ -1041,7 +1041,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_coalesced( xcclDataType, xcclReduceOp, comm, - stream.queue()); + &(stream.queue())); return; }, OpType::COALESCED, @@ -1085,11 +1085,12 @@ c10::intrusive_ptr ProcessGroupXCCL::broadcast( auto xcclDataType = getXcclDataType(input.scalar_type()); onecclBroadcast( input.data_ptr(), + output.data_ptr(), // ? 
(size_t)input.numel(), xcclDataType, root, comm, - stream.queue()); + &(stream.queue())); return; }, OpType::BROADCAST, @@ -1121,7 +1122,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( xcclDataType, root, comm, - stream.queue()); + &(stream.queue())); return; }, OpType::BROADCAST, @@ -1170,7 +1171,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce( xcclReduceOp, root, comm, - stream.queue()); + &(stream.queue())); return; }, OpType::REDUCE, @@ -1203,7 +1204,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_oop( xcclReduceOp, root, comm, - stream.queue()); + &(stream.queue())); return; }, OpType::REDUCE, @@ -1260,7 +1261,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allgather( (size_t)input.numel(), xcclDataType, comm, - stream.queue()); + &(stream.queue())); return; }, [](at::xpu::XPUStream&, @@ -1341,7 +1342,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( (size_t)input.numel(), xcclDataType, comm, - stream.queue()); + &(stream.queue())); return; }, OpType::_ALLGATHER_BASE, @@ -1366,7 +1367,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allgather_into_tensor_coalesced( (size_t)input.numel(), xcclDataType, comm, - stream.queue()); + &(stream.queue())); return; }, OpType::COALESCED, @@ -1423,7 +1424,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter( xcclDataType, xcclReduceOp, comm, - stream.queue()); + &(stream.queue())); return; }, [&](at::xpu::XPUStream& Stream, @@ -1506,7 +1507,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_scatter_base( xcclDataType, xcclReduceOp, comm, - stream.queue()); + &(stream.queue())); return; }, OpType::_REDUCE_SCATTER_BASE, @@ -1535,7 +1536,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter_tensor_coalesced( xcclDataType, xcclReduceOp, comm, - stream.queue()); + &(stream.queue())); return; }, OpType::COALESCED, diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp index 541d9921a..ed259b859 100644 --- a/src/xccl/ProcessGroupXCCL.hpp +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -116,6 +116,12 @@ class TORCH_API ProcessGroupXCCL : public Backend { const std::vector& inputs = {}, const std::vector& outputs = {}); + void broadcastUniqueXCCLID( + onecclUniqueId* xcclID, + bool isSingleP2POp, + const std::string& p2pKey, + int p2pRank); + template c10::intrusive_ptr collective( at::Tensor& input, @@ -283,17 +289,17 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; - c10::intrusive_ptr alltoall_base( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - std::vector& outputSplitSizes, - std::vector& inputSplitSizes, - const AllToAllOptions& opts = AllToAllOptions()) override; + // c10::intrusive_ptr alltoall_base( + // at::Tensor& outputTensor, + // at::Tensor& inputTensor, + // std::vector& outputSplitSizes, + // std::vector& inputSplitSizes, + // const AllToAllOptions& opts = AllToAllOptions()) override; - c10::intrusive_ptr alltoall( - std::vector& outputTensors, - std::vector& inputTensors, - const AllToAllOptions& opts = AllToAllOptions()) override; + // c10::intrusive_ptr alltoall( + // std::vector& outputTensors, + // std::vector& inputTensors, + // const AllToAllOptions& opts = AllToAllOptions()) override; c10::intrusive_ptr send( std::vector& tensors, @@ -339,8 +345,8 @@ class TORCH_API ProcessGroupXCCL : public Backend { uint64_t seqCollective_{0}; uint64_t seqP2P_{0}; - private: - std::mutex kvs_mutex; + // private: + // std::mutex kvs_mutex; // ccl::shared_ptr_class get_kvs( // int rank, From 
0bc5b51bc63fafd35830d4a2e3018c7c258ff757 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Mon, 24 Feb 2025 21:49:22 +0800 Subject: [PATCH 05/10] All new api changed --- src/xccl/ProcessGroupXCCL.cpp | 470 ++++++++++++++++------------------ src/xccl/ProcessGroupXCCL.hpp | 60 +---- 2 files changed, 230 insertions(+), 300 deletions(-) diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index 8dc444e4b..ce0326b53 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -32,44 +32,6 @@ const std::map xcclDatatypes = { {at::kFloat8_e5m2fnuz, onecclDataType_t::ONECCL_UINT8}, }; -bool computeLengthsAndCheckAndGetFlat( - const std::vector& tensors, - std::vector& lengths, - at::Tensor& flatTensor, - int64_t& flatLength) { - int64_t groupSize = tensors.size(); - auto firstTensor = tensors[0]; - int64_t totalSize = 0; - bool isFlat = true; - - auto storage = firstTensor.storage(); - int64_t firstStorageOffset = firstTensor.storage_offset(); - - for (int i = 0; i < groupSize; i++) { - auto& curTensor = tensors[i]; - int64_t length = curTensor.numel(); - lengths[i] = length; - totalSize += length; - - if (isFlat && - (!storage.is_alias_of(curTensor.storage()) || - curTensor.storage_offset() != - firstStorageOffset + totalSize - length)) { - isFlat = false; - } - } - - flatLength = totalSize; - - if (isFlat) { - flatTensor = firstTensor; - } else { - flatTensor = at::empty({totalSize}, firstTensor.options()); - } - - return isFlat; -} - bool checkSameSize(const std::vector& input_tensors) { for (const auto& input_tensor : input_tensors) { if (!input_tensors[0].is_same_size(input_tensor)) { @@ -284,14 +246,14 @@ void ProcessGroupXCCL::broadcastUniqueXCCLID( if (rank_ == 0 || (isSingleP2POp && p2pRank == 0)) { auto vec = std::vector( reinterpret_cast(xcclID), - reinterpret_cast(xcclID) + XCCL_UNIQUE_ID_BYTES); + reinterpret_cast(xcclID) + ONECCL_UNIQUE_ID_BYTES); store_->set(storeKey, vec); } else { try { auto vec = store_->get(storeKey); TORCH_CHECK_WITH( DistBackendError, - vec.size() == XCCL_UNIQUE_ID_BYTES, + vec.size() == ONECCL_UNIQUE_ID_BYTES, "Invalid size for xcclUniqueId"); std::memcpy(xcclID, vec.data(), vec.size()); } catch (const std::exception& e) { @@ -1590,216 +1552,224 @@ c10::intrusive_ptr ProcessGroupXCCL::barrier(const BarrierOptions& opts) { return work; } -// c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( -// at::Tensor& outputTensor, -// at::Tensor& inputTensor, -// std::vector& outputSplitSizes, -// std::vector& inputSplitSizes, -// const AllToAllOptions& /* unused */) { -// checkSingleTensor(outputTensor, true); -// checkSingleTensor(inputTensor, true); -// if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { -// RECORD_PARAM_COMMS_DATA( -// static_cast( -// this->getSequenceNumberForGroup() + -// 1), // seq + 1 to match collective -// std::make_tuple(pg_uid_, pg_desc_), // PG name tuple -// inputTensor, // inputTensor -// outputTensor, // outputTensor -// rank_, // rank -// "all_to_all", // collective name -// inputTensor.numel(), // inNelems -// outputTensor.numel(), // outNelems -// inputTensor.scalar_type(), // dType -// std::vector(), // inSplitSizes -// std::vector(), // outSplitSizes -// -1, // globalRankStart -// -1, // globalRankStride -// this->getSize()); // worldSize -// TORCH_CHECK( -// outputTensor.numel() == inputTensor.numel() && -// outputTensor.scalar_type() == inputTensor.scalar_type(), -// "xpu_alltoall_base: tensors are not equal in size or data type"); -// TORCH_CHECK( -// 
outputTensor.size(0) % size_ == 0, -// "xpu_alltoall_base: tensor's dim 0 does not divide equally across -// group size"); -// return collective( -// inputTensor, -// outputTensor, -// [&](at::Tensor& input, -// at::Tensor& output, -// xcclComm_t& comm, -// at::xpu::XPUStream& stream) { -// c10::xpu::XPUCachingAllocator::recordStream( -// output.storage().data_ptr(), stream); -// auto xcclDataType = getXcclDataType(output.scalar_type()); -// ccl::alltoall( -// input.data_ptr(), -// output.data_ptr(), -// (size_t)output.numel() / comm.size(), -// xcclDataType, -// comm, -// ccl::create_stream(stream.queue())); -// return; -// }, -// OpType::ALLTOALL_BASE, -// "xccl:all_to_all"); -// } else { -// c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); -// c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); - -// RECORD_PARAM_COMMS_DATA( -// static_cast( -// this->getSequenceNumberForGroup() + -// 1), // seq + 1 to match collective -// std::make_tuple(pg_uid_, pg_desc_), // PG name tuple -// inputTensor, // inputTensor -// outputTensor, // outputTensor -// rank_, // rank -// "all_to_allv", // collective name -// inputTensor.numel(), // inNelems -// outputTensor.numel(), // outNelems -// inputTensor.scalar_type(), // dType -// inputSplitSizes, // inSplitSizes -// outputSplitSizes, // outSplitSizes -// -1, // globalRankStart -// -1, // globalRankStride -// this->getSize()); // worldSize - -// return collective( -// inputTensor, -// outputTensor, -// [&](at::Tensor& input, -// at::Tensor& output, -// xcclComm_t& comm, -// at::xpu::XPUStream& stream) { -// std::vector sendCounts(size_); -// std::vector recvCounts(size_); -// bool inputSplitsEqual = inputSplitSizes.size() == 0; -// bool outputSplitsEqual = outputSplitSizes.size() == 0; - -// size_t inLen = input.numel(); -// size_t outLen = output.numel(); -// if (inLen) -// inLen /= (inputSplitsEqual ? size_ : input.size(0)); -// if (outLen) -// outLen /= (outputSplitsEqual ? size_ : output.size(0)); - -// for (int i = 0; i < size_; i++) { -// sendCounts[i] = -// (inputSplitsEqual ? inLen : inputSplitSizes[i] * inLen); -// recvCounts[i] = -// (outputSplitsEqual ? 
outLen : outputSplitSizes[i] * outLen); -// } -// auto xcclDataType = getXcclDataType(output.scalar_type()); -// ccl::alltoallv( -// input.data_ptr(), -// sendCounts, -// output.data_ptr(), -// recvCounts, -// xcclDataType, -// comm, -// ccl::create_stream(stream.queue())); -// return; -// }, -// OpType::ALLTOALL_BASE, -// "xccl:all_to_all"); -// } -// } - -// c10::intrusive_ptr ProcessGroupXCCL::alltoall( -// std::vector& outputTensors, -// std::vector& inputTensors, -// const AllToAllOptions& /* unused */) { -// auto device = outputTensors[0].device(); -// int64_t total_numel = 0; -// for (const auto r : c10::irange(outputTensors.size())) { -// checkSingleTensor(outputTensors[r], true); -// checkSingleTensor(inputTensors[r], true); -// TORCH_CHECK( -// device == outputTensors[r].device() && -// device == inputTensors[r].device(), -// "Tensors must be on the same device") -// total_numel += inputTensors[r].numel(); -// } - -// RECORD_PARAM_COMMS_DATA( -// static_cast( -// this->getSequenceNumberForGroup() + 1), // seq + 1 to match -// collective -// std::make_tuple(pg_uid_, pg_desc_), // PG name tuple -// inputTensors, // inputTensors -// outputTensors, // outputTensors -// rank_, // rank -// "all_to_all", // collective name -// total_numel, // inNelems -// total_numel, // outNelems -// inputTensors.front().scalar_type(), // dType -// std::vector(), // inSplitSizes -// std::vector(), // outSplitSizes -// -1, // globalRankStart -// -1, // globalRankStride -// this->getSize()); // worldSize - -// return collective( -// inputTensors, -// outputTensors, -// [&](at::Tensor& /* unused */, -// at::Tensor& /* unused */, -// xcclComm_t& comm, -// at::xpu::XPUStream& stream) { -// c10::OptionalStreamGuard stream_guard(stream.unwrap()); -// at::Tensor flatInput; -// at::Tensor flatOutput; - -// std::vector sendCounts(size_); -// std::vector recvCounts(size_); - -// int64_t flatSendCount; -// int64_t flatRecvCount; - -// bool isInputFlat = computeLengthsAndCheckAndGetFlat( -// inputTensors, sendCounts, flatInput, flatSendCount); -// bool isOutputFlat = computeLengthsAndCheckAndGetFlat( -// outputTensors, recvCounts, flatOutput, flatRecvCount); -// if (!isInputFlat) { -// auto flatInputSplits = flatInput.split_with_sizes( -// c10::IntArrayRef((int64_t*)sendCounts.data(), -// sendCounts.size()), 0); - -// for (int i = 0; i < size_; i++) { -// flatInputSplits[i].copy_(inputTensors[i].view({-1})); -// } -// } - -// auto xcclDataType = getXcclDataType(flatOutput.scalar_type()); -// ccl::event ret_evt; -// ret_evt = ccl::alltoallv( -// flatInput.data_ptr(), -// sendCounts, -// flatOutput.data_ptr(), -// recvCounts, -// xcclDataType, -// comm, -// ccl::create_stream(stream.queue())); - -// if (!isOutputFlat) { -// ret_evt.wait(); -// auto flatOutputSplits = flatOutput.split_with_sizes( -// c10::IntArrayRef((int64_t*)recvCounts.data(), -// recvCounts.size()), 0); - -// for (int i = 0; i < size_; i++) { -// outputTensors[i].view({-1}).copy_(flatOutputSplits[i]); -// } -// } -// stream.synchronize(); -// return; -// }, -// OpType::ALLTOALL, -// "xccl:all_to_all"); -// } +c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& /* unused */) { + checkSingleTensor(outputTensor, true); + checkSingleTensor(inputTensor, true); + if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + + 
1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "all_to_all", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + TORCH_CHECK( + outputTensor.numel() == inputTensor.numel() && + outputTensor.scalar_type() == inputTensor.scalar_type(), + "xpu_alltoall_base: tensors are not equal in size or data type"); + TORCH_CHECK( + outputTensor.size(0) % size_ == 0, + "xpu_alltoall_base: tensor's dim 0 does not divide equally across group size"); + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(output.scalar_type()); + size_t count = input.numel() / size_; + size_t rankdiff = input.nbytes() / size_; + + onecclGroupStart(); + for (const auto r : c10::irange(rank_)) { + if (count != 0) { + onecclSend( + ((char*)input.data_ptr()) + r * rankdiff, + count, + xcclDataType, + r, + comm, + &(stream.queue())); + onecclRecv( + ((char*)output.data_ptr()) + r * rankdiff, + count, + xcclDataType, + r, + comm, + &(stream.queue())); + } + } + onecclGroupEnd(); + + return; + }, + OpType::ALLTOALL_BASE, + "xccl:all_to_all"); + } else { + c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); + c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "all_to_allv", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + inputTensor.scalar_type(), // dType + inputSplitSizes, // inSplitSizes + outputSplitSizes, // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + std::vector send_lengths(size_); + std::vector recv_lengths(size_); + std::vector send_offsets(size_); + std::vector recv_offsets(size_); + c10d::computeLengthsAndOffsets( + inputSplitSizes, input, &send_lengths, &send_offsets); + c10d::computeLengthsAndOffsets( + outputSplitSizes, output, &recv_lengths, &recv_offsets); + + size_t size = input.element_size(); + auto xcclDataType = getXcclDataType(input.scalar_type()); + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + + auto send_offsets_data = send_offsets.data(); + auto recv_offsets_data = recv_offsets.data(); + + onecclGroupStart(); + for (const auto r : c10::irange(size_)) { + if (send_lengths[r] != 0) { + onecclSend( + ((char*)input.data_ptr()) + send_offsets_data[r] * size, + send_lengths[r], + xcclDataType, + r, + comm, + &(stream.queue())); + } + if (recv_lengths[r] != 0) { + onecclRecv( + ((char*)output.data_ptr()) + recv_offsets_data[r] * size, + recv_lengths[r], + xcclDataType, + r, + comm, + &(stream.queue())); + } + } + onecclGroupEnd(); + + return; + }, + OpType::ALLTOALL_BASE, + "xccl:all_to_all"); 
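Patch 05 expresses both all-to-all variants as pairwise onecclSend/onecclRecv calls fused inside onecclGroupStart()/onecclGroupEnd(), instead of the old ccl::alltoall(v). One thing to double-check in the equal-split branch above is the loop bound: it iterates c10::irange(rank_), which only visits peers below the local rank, while an all-to-all has to exchange a chunk with every one of the size_ peers (the split-size branch does loop over size_). A stripped-down sketch of the equal-split exchange, using the call shapes from this patch (fragment, relying on the surrounding file's headers):

// Equal-split all-to-all as grouped point-to-point (sketch only).
void allToAllEqualSplit(
    at::Tensor& input,
    at::Tensor& output,
    int worldSize,
    xcclComm_t& comm,
    sycl::queue& q) {
  const size_t count = input.numel() / worldSize;    // elements sent to each peer
  const size_t stride = input.nbytes() / worldSize;  // byte offset of each peer's chunk
  const auto dtype = getXcclDataType(input.scalar_type());

  onecclGroupStart();
  for (int r = 0; r < worldSize; ++r) {  // every peer, including ranks above us
    if (count != 0) {
      onecclSend(
          static_cast<char*>(input.data_ptr()) + r * stride,
          count, dtype, r, comm, &q);
      onecclRecv(
          static_cast<char*>(output.data_ptr()) + r * stride,
          count, dtype, r, comm, &q);
    }
  }
  onecclGroupEnd();  // fusing the group lets sends and receives pair up without deadlock
}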
+ } +} + +c10::intrusive_ptr ProcessGroupXCCL::alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& /* unused */) { + auto device = outputTensors[0].device(); + int64_t total_numel = 0; + for (const auto r : c10::irange(outputTensors.size())) { + checkSingleTensor(outputTensors[r], true); + checkSingleTensor(inputTensors[r], true); + TORCH_CHECK( + device == outputTensors[r].device() && + device == inputTensors[r].device(), + "Tensors must be on the same device") + total_numel += inputTensors[r].numel(); + } + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "all_to_all", // collective name + total_numel, // inNelems + total_numel, // outNelems + inputTensors.front().scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + inputTensors, + outputTensors, + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + onecclGroupStart(); + for (const int r : + c10::irange(static_cast(outputTensors.size()))) { + at::Tensor& input = inputTensors[r]; + at::Tensor& output = outputTensors[r]; + if (input.numel() != 0) { + onecclSend( + input.data_ptr(), + input.numel(), + getXcclDataType(input.scalar_type()), + r, + comm, + &(stream.queue())); + } + if (output.numel() != 0) { + onecclRecv( + output.data_ptr(), + output.numel(), + getXcclDataType(output.scalar_type()), + r, + comm, + &(stream.queue())); + } + } + onecclGroupEnd(); + + return; + }, + OpType::ALLTOALL, + "xccl:all_to_all"); +} } // namespace c10d diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp index ed259b859..05306e9ee 100644 --- a/src/xccl/ProcessGroupXCCL.hpp +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -21,7 +21,6 @@ #include namespace c10d { -#define XCCL_UNIQUE_ID_BYTES 128 static std::vector TORCH_XCCL_BLOCKING_WAIT = { "TORCH_XCCL_BLOCKING_WAIT", "XCCL_BLOCKING_WAIT"}; @@ -289,17 +288,17 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; - // c10::intrusive_ptr alltoall_base( - // at::Tensor& outputTensor, - // at::Tensor& inputTensor, - // std::vector& outputSplitSizes, - // std::vector& inputSplitSizes, - // const AllToAllOptions& opts = AllToAllOptions()) override; + c10::intrusive_ptr alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& opts = AllToAllOptions()) override; - // c10::intrusive_ptr alltoall( - // std::vector& outputTensors, - // std::vector& inputTensors, - // const AllToAllOptions& opts = AllToAllOptions()) override; + c10::intrusive_ptr alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& opts = AllToAllOptions()) override; c10::intrusive_ptr send( std::vector& tensors, @@ -344,45 +343,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { static thread_local uint64_t xcclActiveGroupCounter_; uint64_t seqCollective_{0}; uint64_t seqP2P_{0}; - - // private: - // std::mutex kvs_mutex; - - // ccl::shared_ptr_class get_kvs( - // int rank, - // c10d::Store& store, - // bool singleP2POp = false, - // const std::string& 
p2pKey = "", - // int p2pRank = 0) { - // std::lock_guard lock(kvs_mutex); - // ccl::shared_ptr_class kvs; - // std::string storeKey; - // if (!singleP2POp) { - // storeKey = std::to_string(xcclCommCounter_++); - // } else { - // storeKey = p2pKey; - // } - // // Rank 0 broadcast the bootstrap network information to other ranks - // if (rank == 0 || (singleP2POp && p2pRank == 0)) { - // kvs = ccl::create_main_kvs(); - // ccl::kvs::address_type main_addr = kvs->get_address(); - // auto ccl_kvs_addr = - // std::vector(main_addr.begin(), main_addr.end()); - // store.set(storeKey, ccl_kvs_addr); - // } else { - // auto ccl_kvs_addr = store.get(storeKey); - // if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { - // throw std::runtime_error("Unexpected ccl kvs addr from the - // store\n"); - // } - // ccl::kvs::address_type main_addr; - // std::copy_n( - // ccl_kvs_addr.begin(), ccl::kvs::address_max_size, - // main_addr.begin()); - // kvs = ccl::create_kvs(main_addr); - // } - // return kvs; - // } }; } // namespace c10d From c2220fc212b9f6a00c42b88089da12c509532324 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Wed, 26 Feb 2025 17:06:50 +0800 Subject: [PATCH 06/10] add CCL_PROCESS_LAUNCHER check and set CCL_LOCAL_RANK\CCL_LOCAL_SIZE --- src/xccl/ProcessGroupXCCL.cpp | 24 +++++++++++++++++++++--- src/xccl/ProcessGroupXCCL.hpp | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index ce0326b53..d38a9e03c 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -205,6 +205,17 @@ ProcessGroupXCCL::ProcessGroupXCCL( : Backend(rank, size), store_(store), xcclCommCounter_(0) { blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false); init(); + if (!with_mpirun()) { + int local_rank = getXCCLEnvVar("LOCAL_RANK"); + int local_world_size = getXCCLEnvVar("LOCAL_WORLD_SIZE"); + if (local_rank == -1 || local_world_size == -1) { + local_rank = rank; + local_world_size = size; + } + setXCCLEnvVar("CCL_PROCESS_LAUNCHER", "none"); + setXCCLEnvVar("CCL_LOCAL_RANK", local_rank); + setXCCLEnvVar("CCL_LOCAL_SIZE", local_world_size); + } } ProcessGroupXCCL::~ProcessGroupXCCL() = default; @@ -330,7 +341,6 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( c10::impl::VirtualGuardImpl impl(device.type()); c10::Stream stream = impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false); - sycl::queue& q = c10::xpu::XPUStream(stream).queue(); if (rank_ == 0 || (singleP2POp && p2pRank == 0)) { onecclGetUniqueId(&xcclID); @@ -338,8 +348,16 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( broadcastUniqueXCCLID(&xcclID, singleP2POp, deviceKey, p2pRank); xcclComm_t comm = nullptr; - onecclCommInitRank(&comm, numRanks, xcclID, rank); - XCCLComm = std::make_shared(std::move(comm)); + onecclResult_t result = ONECCL_SUCCESS; + result = onecclSetDevice(rank); + if (result != ONECCL_SUCCESS) { + std::cerr << "Failed to set device.\n"; + } + result = onecclCommInitRank(&comm, numRanks, xcclID, rank); + if (result != ONECCL_SUCCESS) { + std::cerr << "Failed to initialize communicator.\n"; + } + XCCLComm = std::make_shared(comm); RECORD_PARAM_COMMS( 0, // seq diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp index 05306e9ee..5b6831415 100644 --- a/src/xccl/ProcessGroupXCCL.hpp +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -21,6 +21,40 @@ #include namespace c10d { +namespace { +int getXCCLEnvVar(std::string envVarName) { + char* stringValue = 
std::getenv(envVarName.c_str()); + if (stringValue != nullptr) { + try { + int val = std::stoi(stringValue); + return val; + } catch (std::exception& e) { + TORCH_CHECK( + false, + "Invalid value for environment variable: " + std::string(envVarName)); + } + } else { + return -1; + } +} + +template +void setXCCLEnvVar(const std::string& envVarName, T val) { + if constexpr (std::is_same_v) { + setenv(envVarName.c_str(), std::to_string(val).c_str(), 1); + } else if constexpr (std::is_same_v) { + setenv(envVarName.c_str(), val.c_str(), 1); + } +} + +bool with_mpirun() { + return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") || + getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK")) + ? true + : false; +} +} // namespace + static std::vector TORCH_XCCL_BLOCKING_WAIT = { "TORCH_XCCL_BLOCKING_WAIT", "XCCL_BLOCKING_WAIT"}; From 11500636e8412168d071207b1279a33e3eb00dc6 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Tue, 25 Mar 2025 19:10:54 +0800 Subject: [PATCH 07/10] update --- src/xccl/ProcessGroupXCCL.cpp | 2157 ++++++++++++++++----------------- src/xccl/ProcessGroupXCCL.hpp | 8 +- 2 files changed, 1081 insertions(+), 1084 deletions(-) diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index 3948b18b7..2b26e8c69 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -363,7 +363,7 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( for (const auto i : c10::irange(xcclActiveGroupCounter_)) { (void)i; - ccl::group_end(); + onecclGroupEnd(); } int numRanks, rank; @@ -415,22 +415,21 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( for (const auto i : c10::irange(xcclActiveGroupCounter_)) { (void)i; - ccl::group_start(); + onecclGroupStart(); } - // The oneCCL group API requires retaining the SYCL queue (xcclstream) object + // The oneCCL group API requires retaining the SYCL queue (SyclQueue) object // within the lifecycle of the communicator. If the XPU stream is created // within the collective operation, it would be destroyed earlier than the // communicator after the operation ends. Therefore, the XPU stream is stored // in a map alongside the communicator. Similarly, oneCCLv2 also requires // retaining the SYCL queue pointer for collective operations, so this change // will be necessary in oneCCLv2 as well. 
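The comment above is the rationale for the bookkeeping change in the next hunk: instead of creating a ccl::stream per call, the SYCL queue is kept next to the XPU stream for the communicator's whole lifetime. A compact sketch of that shape follows; the exact member declaration lives in ProcessGroupXCCL.hpp and is an assumption here. SYCL objects have common reference semantics, so the copy kept in the map refers to the same underlying queue that the collectives later pass to oneCCL as &SyclQueue.

    // Assumed shape of the per-communicator bookkeeping (illustrative only):
    std::unordered_map<std::string, std::pair<at::xpu::XPUStream, sycl::queue>>
        xcclStreamsMap_;

    // ...inside getXCCLComm(), once the communicator exists:
    sycl::queue& q = c10::xpu::XPUStream(stream).queue();
    xcclStreamsMap_.emplace(
        deviceKey, std::make_pair(at::xpu::XPUStream(stream), q));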
- ccl::stream xccl_stream = ccl::create_stream(q); std::lock_guard lock(mutex_); + sycl::queue& q = c10::xpu::XPUStream(stream).queue(); devXCCLCommMap_.emplace(deviceKey, XCCLComm); xcclStreamsMap_.emplace( - deviceKey, - std::make_pair(at::xpu::XPUStream(stream), std::move(xccl_stream))); + deviceKey, std::make_pair(at::xpu::XPUStream(stream), q)); xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent()); return XCCLComm; @@ -673,7 +672,7 @@ c10::intrusive_ptr ProcessGroupXCCL::send( [&](at::Tensor& input, xcclComm_t& comm, at::xpu::XPUStream& stream, - ccl::stream& xcclStream, + sycl::queue& SyclQueue, int dst) { auto xcclDataType = getXcclDataType(input.scalar_type()); onecclSend( @@ -682,7 +681,7 @@ c10::intrusive_ptr ProcessGroupXCCL::send( xcclDataType, dst, comm, - xcclStream); + &SyclQueue); return; }, dstRank, @@ -722,7 +721,7 @@ c10::intrusive_ptr ProcessGroupXCCL::recv( [&](at::Tensor& output, xcclComm_t& comm, at::xpu::XPUStream& stream, - ccl::stream& xcclStream, + sycl::queue& SyclQueue, int src) { auto xcclDataType = getXcclDataType(output.scalar_type()); onecclRecv( @@ -731,7 +730,7 @@ c10::intrusive_ptr ProcessGroupXCCL::recv( xcclDataType, src, comm, - xcclStream); + &SyclQueue); return; }, srcRank, @@ -810,27 +809,32 @@ c10::intrusive_ptr ProcessGroupXCCL::gather( at::Tensor& /* unused */, xcclComm_t& comm, at::xpu::XPUStream& stream, - ccl::stream& xcclStream) { - const auto root = opts.rootRank; - if (getRank() == root) { - for (auto output : outputs) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } - } - { - auto xcclDataType = getXcclDataType(inputTensor.scalar_type()); - if (rank_ == root) { - for (const auto r : c10::irange(size_)) { - if (r != root) { - // do receive - onecclRecv( - outputs[r].data_ptr(), - (size_t)inputTensor.numel(), - xcclDataType, - r, - comm, - xcclStream); + sycl::queue& SyclQueue) { + const auto root = opts.rootRank; + if (getRank() == root) { + for (auto output : outputs) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + } + { + auto xcclDataType = getXcclDataType(inputTensor.scalar_type()); + if (rank_ == root) { + for (const auto r : c10::irange(size_)) { + if (r != root) { + // do receive + onecclRecv( + outputs[r].data_ptr(), + (size_t)inputTensor.numel(), + xcclDataType, + r, + comm, + &SyclQueue); + } else { + // on its own rank, simply copy from the input + outputs[r].copy_(inputTensor); + } + } } else { // do send onecclSend( @@ -839,81 +843,80 @@ c10::intrusive_ptr ProcessGroupXCCL::gather( xcclDataType, root, comm, - &(stream.queue())); + &SyclQueue); } + return; } - return; - } - }, - OpType::GATHER); + }, + OpType::GATHER, + "xccl:gather"); } c10::intrusive_ptr ProcessGroupXCCL::scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts) { - static auto invalidArgument = [](const std::string& msg) { - C10_THROW_ERROR(ValueError, "ProcessGroupXCCL::scatter: " + msg); - }; - - assertRootRank(invalidArgument, opts.rootRank, size_); - - TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - auto outputTensor = outputTensors.back(); - - std::vector inputs; - - if (getRank() == opts.rootRank) { - if (inputTensors.size() != 1) { - std::stringstream ss; - ss << "requires a single-element input list containing a list with " - << getSize() << " tensors."; - invalidArgument(ss.str()); - } else if (inputTensors[0].size() != static_cast(getSize())) { - std::stringstream ss; - ss << "Incorrect input list size " 
<< inputTensors[0].size() - << ". Input list size should be " << getSize() - << ", same as size of the process group."; - invalidArgument(ss.str()); - } + static auto invalidArgument = [](const std::string& msg) { + C10_THROW_ERROR(ValueError, "ProcessGroupXCCL::scatter: " + msg); + }; - const auto& options = outputTensor.options(); - const auto& sizes = outputTensor.sizes(); - assertTypeAndSizesMatch(invalidArgument, inputTensors[0], options, sizes); - inputs = inputTensors[0]; - } else { - // if not in the root rank, initialize inputTensors as empty place holder - // with an empty list - if (inputTensors.size() != 0) { - invalidArgument("requires empty input on non-root"); - } - inputs = {}; - // append a empty tensor to the list, we don't use it but the - // `collective` template function requires it to invoke its function - inputs.emplace_back(); + assertRootRank(invalidArgument, opts.rootRank, size_); + + TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto outputTensor = outputTensors.back(); + + std::vector inputs; + + if (getRank() == opts.rootRank) { + if (inputTensors.size() != 1) { + std::stringstream ss; + ss << "requires a single-element input list containing a list with " + << getSize() << " tensors."; + invalidArgument(ss.str()); + } else if (inputTensors[0].size() != static_cast(getSize())) { + std::stringstream ss; + ss << "Incorrect input list size " << inputTensors[0].size() + << ". Input list size should be " << getSize() + << ", same as size of the process group."; + invalidArgument(ss.str()); } - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensors, // inputTensors - outputTensors, // outputTensors - opts.rootRank, // root rank - "scatter", // collective name - outputTensor.numel() * this->getSize(), // inNelems - outputTensor.numel(), // outNelems - outputTensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSize - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize + const auto& options = outputTensor.options(); + const auto& sizes = outputTensor.sizes(); + assertTypeAndSizesMatch(invalidArgument, inputTensors[0], options, sizes); + inputs = inputTensors[0]; + } else { + // if not in the root rank, initialize inputTensors as empty place + // holder with an empty list + if (inputTensors.size() != 0) { + invalidArgument("requires empty input on non-root"); + } + inputs = {}; + // append a empty tensor to the list, we don't use it but the + // `collective` template function requires it to invoke its function + inputs.emplace_back(); + } + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + opts.rootRank, // root rank + "scatter", // collective name + outputTensor.numel() * this->getSize(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize - const auto root = opts.rootRank; + const auto root = opts.rootRank; - auto outputs = std::vector{outputTensor}; + auto outputs = std::vector{outputTensor}; return collective( outputs, inputs, // just to fit the collective interface @@ -921,1015 
+924,1009 @@ c10::intrusive_ptr ProcessGroupXCCL::scatter( at::Tensor& /* unused */, xcclComm_t& comm, at::xpu::XPUStream& stream, - ccl::stream& xcclStream) { - if (getRank() == root) { - for (auto input : inputs) { - c10::xpu::XPUCachingAllocator::recordStream( - input.storage().data_ptr(), stream); + sycl::queue& SyclQueue) { + if (getRank() == root) { + for (auto input : inputs) { + c10::xpu::XPUCachingAllocator::recordStream( + input.storage().data_ptr(), stream); + } } - } - { - if (rank_ == root) { - for (const auto r : c10::irange(size_)) { - if (r != root) { - // do send - size_t send_count = inputs[r].numel(); - auto send_type = getXcclDataType(inputs[r].scalar_type()); - onecclSend( - inputs[r].data_ptr(), - send_count, - send_type, - r, - comm, - - xcclStream); - } else { - // on its own rank, simply copy from the input - outputTensor.copy_(inputs[r]); - // do receive - auto recv_type = getXcclDataType(outputTensor.scalar_type()); - onecclRecv( - outputTensor.data_ptr(), - recv_count, - recv_type, - root, - comm, - xcclStream); + { + if (rank_ == root) { + for (const auto r : c10::irange(size_)) { + if (r != root) { + // do send + size_t send_count = inputs[r].numel(); + auto send_type = getXcclDataType(inputs[r].scalar_type()); + onecclSend( + inputs[r].data_ptr(), + send_count, + send_type, + r, + comm, + &SyclQueue); + } else { + // on its own rank, simply copy from the input + outputTensor.copy_(inputs[r]); + } } - return; + } else { + // do receive + size_t recv_count = outputTensor.numel(); + auto recv_type = getXcclDataType(outputTensor.scalar_type()); + onecclRecv( + outputTensor.data_ptr(), + recv_count, + recv_type, + root, + comm, + &SyclQueue); } - }, - OpType::SCATTER); - } - - c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( - at::Tensor & tensor, - const char* profilingTitle, - const AllreduceOptions& opts) { - return collective( - tensor, - tensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - ccl::stream& xcclStream) { - auto xcclDataType = getXcclDataType(input.scalar_type(), true); - auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - onecclAllReduce( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - xcclReduceOp, - comm, - xcclStream); -#if !defined(XCCL_HAS_AVG) - if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - c10::StreamGuard guard(stream); - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - output.div_(divisor); - } -#endif - return; - }, - OpType::ALLREDUCE, - profilingTitle); - } - c10::intrusive_ptr ProcessGroupXCCL::allreduce( - std::vector & tensors, const AllreduceOptions& opts) { - TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - auto tensor = tensors.back(); - if (tensor.is_complex()) { - TORCH_CHECK( - complexViewAsRealAllowed(opts.reduceOp), - "all_reduce does not support", - opts.reduceOp, - "on complex tensors"); - tensor = at::view_as_real(tensor); + return; } - checkSingleTensor(tensor); - - // @lint-ignore CLANGTIDY - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - rank_, // rank - "allreduce", // collective name - tensor.numel(), // inNelems - tensor.numel(), // outNelems - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - 
-1, // globalRankStride - size_); // worldSize - - return allreduce_impl(tensor, "xccl:all_reduce", opts); - } + }, + OpType::SCATTER, + "xccl:scatter"); +} - c10::intrusive_ptr ProcessGroupXCCL::allreduce_coalesced( - std::vector & tensors, - const AllreduceCoalescedOptions& opts) { - auto total_numel = checkTensorOnSameDevice(tensors); - - // @lint-ignore CLANGTIDY - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - rank_, // rank - "allreduce_coalesced", // collective name - total_numel, // inNelems - total_numel, // outNelems - tensors[0].scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - return collectiveCoalesced( - tensors, - tensors, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - ccl::stream& xcclStream) { - auto xcclDataType = getXcclDataType(input.scalar_type(), true); - auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - onecclAllReduce( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - xcclReduceOp, - comm, - xcclStream); +c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( + at::Tensor& tensor, + const char* profilingTitle, + const AllreduceOptions& opts) { + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + sycl::queue& SyclQueue) { + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + onecclAllReduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + comm, + &SyclQueue); #if !defined(XCCL_HAS_AVG) - if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - c10::StreamGuard guard(stream); - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - output.div_(divisor); - } -#endif - return; - }, - OpType::COALESCED, - "xccl:allreduce_coalesced"); - } - - c10::intrusive_ptr ProcessGroupXCCL::broadcast( - std::vector & tensors, const BroadcastOptions& opts) { - TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - auto tensor = tensors.back(); - if (tensor.is_complex()) { - tensor = at::view_as_real(tensor); - } - checkSingleTensor(tensor); - - // @lint-ignore CLANGTIDY - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - opts.rootRank, // root rank - "broadcast", // collective name - tensor.numel(), // inNelems - tensor.numel(), // outNelems - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - const auto root = opts.rootRank + opts.rootTensor; - - return collective( - tensor, - tensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - ccl::stream& xcclStream) { - auto xcclDataType = getXcclDataType(input.scalar_type()); - onecclBroadcast( - input.data_ptr(), - output.data_ptr(), // ? 
- (size_t)input.numel(), - xcclDataType, - root, - comm, - xcclStream); - return; - }, - OpType::BROADCAST, - "nccl:broadcast"); - } - - c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( - at::Tensor & outputTensor, - at::Tensor & inputTensor, - const BroadcastOptions& opts) { - if (outputTensor.numel() != inputTensor.numel()) { - C10_THROW_ERROR( - ValueError, - "Tensor input and output of _broadcast_oop must have the same number of elements "); - } - const auto root = opts.rootRank + opts.rootTensor; - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - ccl::stream& xcclStream) { - auto xcclDataType = getXcclDataType(input.scalar_type()); - onecclBroadcast( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - root, - comm, - xcclStream); - return; - }, - OpType::BROADCAST, - "xccl:_broadcast_oop"); - } - - c10::intrusive_ptr ProcessGroupXCCL::reduce( - std::vector & tensors, const ReduceOptions& opts) { - TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - auto tensor = tensors.back(); - if (tensor.is_complex()) { - TORCH_CHECK( - complexViewAsRealAllowed(opts.reduceOp), - "reduce does not support", - opts.reduceOp, - "on complex tensors"); - tensor = at::view_as_real(tensor); + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + c10::StreamGuard guard(stream); + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + output.div_(divisor); } - checkSingleTensor(tensor); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - opts.rootRank, // root rank - "reduce", // collective name - tensor.numel(), // inNelems - tensor.numel(), // outNelems - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - return collective( - tensor, - tensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - ccl::stream& xcclStream) { - const int root = opts.rootRank + opts.rootTensor; - const auto xcclDataType = - getXcclDataType(input.scalar_type(), true); - const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - onecclReduce( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - xcclReduceOp, - root, - comm, - xcclStream); -#if !defined(XCCL_HAS_AVG) - if (opts.reduceOp == ReduceOp::AVG && getRank() == root) { - auto divisor = getSize(); - c10::StreamGuard guard(stream); - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - output.div_(divisor); - } -#endif - return; - }, - OpType::REDUCE, - "xccl:reduce"); - } - - c10::intrusive_ptr ProcessGroupXCCL::_reduce_oop( - at::Tensor & outputTensor, - at::Tensor & inputTensor, - const ReduceOptions& opts) { - TORCH_CHECK_WITH( - ValueError, - outputTensor.numel() == inputTensor.numel(), - "Tensor input and output of _reduce_oop must have the same number of elements"); - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - ccl::stream& xcclStream) { - const int root = opts.rootRank + opts.rootTensor; - const auto xcclDataType = - getXcclDataType(input.scalar_type(), 
true); - const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - onecclReduce( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - xcclReduceOp, - root, - comm, - xcclStream); -#if !defined(XCCL_HAS_AVG) - if (opts.reduceOp == ReduceOp::AVG && getRank() == root) { - auto divisor = getSize(); - c10::StreamGuard guard(stream); - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - output.div_(divisor); - } #endif - return; - }, - OpType::REDUCE, - "xccl:_reduce_oop"); - } + return; + }, + OpType::ALLREDUCE, + profilingTitle); +} - c10::intrusive_ptr ProcessGroupXCCL::allgather( - std::vector> & outputTensors, - std::vector & inputTensors, - const AllgatherOptions& opts) { - TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto inputTensor = inputTensors.back(); - checkSingleTensor(inputTensor); - // @lint-ignore CLANGTIDY - std::vector& outputTensors_ = outputTensors.back(); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensors, // inputTensors - outputTensors, // outputTensors - rank_, // rank - "all_gather", // collective name - inputTensor.numel(), // inNelems - inputTensor.numel() * // outNelems - this->getSize(), - inputTensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSize - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - bool same_size = checkSameSize(outputTensors_); - if (same_size) { - // Flatten a vector of tensors into a single, stacked tensor. - at::Tensor outputFlattened = newLikeFlat(outputTensors_); - - return collective( - inputTensor, - outputFlattened, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - ccl::stream& xcclStream) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type()); - onecclAllGather( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - comm, - xcclStream); - return; - }, - [](at::xpu::XPUStream&, - c10::intrusive_ptr& work) {}, - [&](at::xpu::XPUStream& Stream, - c10::intrusive_ptr& work) { - // Copy the flattened output tensors to the outputs. - c10::StreamGuard guard(Stream); - for (const auto j : c10::irange(outputTensors_.size())) { - c10::xpu::XPUCachingAllocator::recordStream( - outputTensors_[j].storage().data_ptr(), Stream); - outputTensors_[j].copy_(outputFlattened[j], true); - } - }, - OpType::ALLGATHER, - "xccl:all_gather"); - } else { - const auto num_reduces = outputTensors_.size(); - startCoalescing(); - for (const int i : c10::irange(num_reduces)) { - auto& output = outputTensors_[i]; - auto& input = (i == rank_) ? 
inputTensor : output; - auto broadcastOpts = BroadcastOptions{ - static_cast(i), static_cast(0), opts.timeout}; - _broadcast_oop(output, input, broadcastOpts); - } - auto work = endCoalescing(OpType::ALLGATHER); - return work; - } - } +c10::intrusive_ptr ProcessGroupXCCL::allreduce( + std::vector& tensors, + const AllreduceOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + if (tensor.is_complex()) { + TORCH_CHECK( + complexViewAsRealAllowed(opts.reduceOp), + "all_reduce does not support", + opts.reduceOp, + "on complex tensors"); + tensor = at::view_as_real(tensor); + } + checkSingleTensor(tensor); - c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( - at::Tensor & output_tensor, - at::Tensor & input_tensor, - const AllgatherOptions& opts) { - checkSingleTensor(input_tensor); - checkSingleTensor(output_tensor); - - TORCH_CHECK_WITH( - TypeError, - input_tensor.dtype() == output_tensor.dtype(), - "output tensor must have the same type as input tensor"); - TORCH_CHECK_WITH( - ValueError, - input_tensor.numel() * size_ == output_tensor.numel(), - "output tensor size must be equal to world_size times input tensor size"); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - input_tensor, // inputTensors - output_tensor, // outputTensors - rank_, // rank - "_allgather_base", // collective name - input_tensor.numel(), // inNelems - output_tensor.numel(), // outNelems - output_tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSize - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - return collective( - input_tensor, - output_tensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - ccl::stream& xcclStream) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type()); - onecclAllGather( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - comm, - xcclStream); - return; - }, - OpType::_ALLGATHER_BASE, - "xccl:_all_gather_base"); - } + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + rank_, // rank + "allreduce", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + size_); // worldSize - c10::intrusive_ptr - ProcessGroupXCCL::allgather_into_tensor_coalesced( - std::vector & outputs, - std::vector & inputs, - const AllgatherOptions& opts) { - return collectiveCoalesced( - inputs, - outputs, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - ccl::stream& xcclStream) { - auto xcclDataType = getXcclDataType(input.scalar_type()); - onecclAllGather( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - comm, - xcclStream); - return; - }, - OpType::COALESCED, - "xccl:all_gather_into_tensor_coalesced"); - } + return allreduce_impl(tensor, "xccl:all_reduce", opts); +} - c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter( - 
std::vector & outputTensors, - std::vector> & inputTensors, - const ReduceScatterOptions& opts) { - TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto outputTensor = outputTensors.back(); - checkSingleTensor(outputTensor); - // @lint-ignore CLANGTIDY - auto inputTensors_ = inputTensors.back(); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensors, // inputTensors - outputTensors, // outputTensors - rank_, // rank - "reduce_scatter", // collective name - outputTensor.numel() * this->getSize(), // inNelems - outputTensor.numel(), // outNelems - outputTensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - bool same_size = checkSameSize(inputTensors_); - if (same_size) { - // Flatten a vector of tensors into a single, stacked tensor. - at::Tensor inputFlattened = newLikeFlat(inputTensors_); - return collective( - inputFlattened, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - ccl::stream& xcclStream) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type(), true); - auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - onecclReduceScatter( - input.data_ptr(), - output.data_ptr(), - (size_t)output.numel(), - xcclDataType, - xcclReduceOp, - comm, - xcclStream); -#if !defined(XCCL_HAS_AVG) - if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - c10::StreamGuard guard(stream); - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - output.div_(divisor); - } -#endif - return; - }, - [&](at::xpu::XPUStream& Stream, - c10::intrusive_ptr& work) { - // Copy the input tensors to the flattened inputs. - c10::StreamGuard guard(Stream); - for (const auto j : c10::irange(inputTensors_.size())) { - c10::xpu::XPUCachingAllocator::recordStream( - inputTensors_[j].storage().data_ptr(), Stream); - inputFlattened[j].copy_(inputTensors_[j], true); - } - }, - [&](at::xpu::XPUStream&, - c10::intrusive_ptr&) {}, - OpType::REDUCE_SCATTER, - "xccl:reduce_scatter"); - } else { - const auto num_reduces = inputTensors_.size(); - startCoalescing(); - for (const int i : c10::irange(num_reduces)) { - auto& input = inputTensors_[i]; - auto& output = (i == rank_) ? 
outputTensor : input; - auto reduceOpts = ReduceOptions{ - opts.reduceOp, - static_cast(i), - static_cast(0), - opts.timeout}; - _reduce_oop(output, input, reduceOpts); - } - auto work = endCoalescing(OpType::REDUCE_SCATTER); - return work; - } - } +c10::intrusive_ptr ProcessGroupXCCL::allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts) { + auto total_numel = checkTensorOnSameDevice(tensors); - c10::intrusive_ptr ProcessGroupXCCL::_reduce_scatter_base( - at::Tensor & outputTensor, - at::Tensor & inputTensor, - const ReduceScatterOptions& opts) { - TORCH_CHECK_WITH( - TypeError, - inputTensor.dtype() == outputTensor.dtype(), - "input tensor must be the same type as the output tensor."); - TORCH_CHECK_WITH( - ValueError, - inputTensor.numel() == outputTensor.numel() * size_, - "input tensor must be the same size as output size times world size"); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensor, // inputTensor - outputTensor, // outputTensor - rank_, // rank - "_reduce_scatter_base", // collective name - inputTensor.numel(), // inNelems - outputTensor.numel(), // outNelems - outputTensor.scalar_type(), // dtype - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - ccl::stream& xcclStream) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type(), true); - auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - onecclReduceScatter( - input.data_ptr(), - output.data_ptr(), - (size_t)output.numel(), - xcclDataType, - xcclReduceOp, - comm, - xcclStream); -#if !defined(XCCL_HAS_AVG) - if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - c10::StreamGuard guard(stream); - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - output.div_(divisor); - } -#endif - return; - }, - OpType::_REDUCE_SCATTER_BASE, - "xccl:_reduce_scatter_base"); - } + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + rank_, // rank + "allreduce_coalesced", // collective name + total_numel, // inNelems + total_numel, // outNelems + tensors[0].scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize - c10::intrusive_ptr - ProcessGroupXCCL::reduce_scatter_tensor_coalesced( - std::vector & outputs, - std::vector & inputs, - const ReduceScatterOptions& opts) { - return collectiveCoalesced( - inputs, - outputs, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - ccl::stream& xcclStream) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type(), true); - auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - onecclReduceScatter( - input.data_ptr(), - output.data_ptr(), - (size_t)output.numel(), - xcclDataType, - 
xcclReduceOp, - comm, - xcclStream); + return collectiveCoalesced( + tensors, + tensors, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + sycl::queue& SyclQueue) { + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + onecclAllReduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + comm, + &SyclQueue); #if !defined(XCCL_HAS_AVG) - if (opts.reduceOp == ReduceOp::AVG) { - auto divisor = getSize(); - c10::StreamGuard guard(stream); - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - output.div_(divisor); - } -#endif - return; - }, - OpType::COALESCED, - "xccl:reduce_scatter_tensor_coalesced"); - } - - c10::intrusive_ptr ProcessGroupXCCL::barrier( - const BarrierOptions& opts) { - RECORD_PARAM_COMMS( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - rank_, // rank - "barrier", // collective name - 0, // inNelems - 0, // outNelems - at::kByte, // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - // Device to use for barrier - int barDevIdx = -1; - - // See nccl barrier comments - if (!opts.device_ids.empty()) { - barDevIdx = opts.device_ids[0]; - } else if (getBoundDeviceId()) { - barDevIdx = (*getBoundDeviceId()).index(); - } else if (!usedDeviceIdxs_.empty()) { - barDevIdx = *usedDeviceIdxs_.begin(); - } else { - barDevIdx = static_cast( - rank_ % at::detail::getXPUHooks().getNumGPUs()); + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + c10::StreamGuard guard(stream); + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + output.div_(divisor); } +#endif + return; + }, + OpType::COALESCED, + "xccl:allreduce_coalesced"); +} - TORCH_CHECK_WITH( - ValueError, - barDevIdx >= 0, - "Failed to infer a GPU device id to perform barrier. 
"); - auto barDevice = at::Device(at::DeviceType::XPU, barDevIdx); - - at::Tensor barrierTensor = at::zeros( - {1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); - - auto work = allreduce_impl(barrierTensor, "xccl:all_reduce_barrier"); - - auto xcclWork = dynamic_cast(work.get()); - TORCH_CHECK(xcclWork); - xcclWork->barrierTensor_ = std::move(barrierTensor); - return work; - } +c10::intrusive_ptr ProcessGroupXCCL::broadcast( + std::vector& tensors, + const BroadcastOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + if (tensor.is_complex()) { + tensor = at::view_as_real(tensor); + } + checkSingleTensor(tensor); - c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( - at::Tensor & outputTensor, - at::Tensor & inputTensor, - std::vector & outputSplitSizes, - std::vector & inputSplitSizes, - const AllToAllOptions& /* unused */) { - checkSingleTensor(outputTensor, true); - checkSingleTensor(inputTensor, true); - if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensor, // inputTensor - outputTensor, // outputTensor - rank_, // rank - "all_to_all", // collective name - inputTensor.numel(), // inNelems - outputTensor.numel(), // outNelems - inputTensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - TORCH_CHECK( - outputTensor.numel() == inputTensor.numel() && - outputTensor.scalar_type() == inputTensor.scalar_type(), - "xpu_alltoall_base: tensors are not equal in size or data type"); - TORCH_CHECK( - outputTensor.size(0) % size_ == 0, - "xpu_alltoall_base: tensor's dim 0 does not divide equally across group size"); - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - ccl::stream& xcclStream) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(output.scalar_type()); - size_t count = input.numel() / size_; - size_t rankdiff = input.nbytes() / size_; - - onecclGroupStart(); - for (const auto r : c10::irange(rank_)) { - if (count != 0) { - onecclSend( - ((char*)input.data_ptr()) + r * rankdiff, - count, - xcclDataType, - r, - comm, - &(stream.queue())); - onecclRecv( - ((char*)output.data_ptr()) + r * rankdiff, - count, - xcclDataType, - r, - comm, - &(stream.queue())); - } - } - onecclGroupEnd(); - return; - }, - OpType::ALLTOALL_BASE, - "xccl:all_to_all"); - } else { - c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); - c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensor, // inputTensor - outputTensor, // outputTensor - rank_, // rank - "all_to_allv", // collective name - inputTensor.numel(), // inNelems - outputTensor.numel(), // outNelems - inputTensor.scalar_type(), // dType - inputSplitSizes, // inSplitSizes - outputSplitSizes, // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - 
xcclComm_t& comm, - at::xpu::XPUStream& stream) { - std::vector send_lengths(size_); - std::vector recv_lengths(size_); - std::vector send_offsets(size_); - std::vector recv_offsets(size_); - c10d::computeLengthsAndOffsets( - inputSplitSizes, input, &send_lengths, &send_offsets); - c10d::computeLengthsAndOffsets( - outputSplitSizes, output, &recv_lengths, &recv_offsets); - - size_t size = input.element_size(); - auto xcclDataType = getXcclDataType(input.scalar_type()); - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - - auto send_offsets_data = send_offsets.data(); - auto recv_offsets_data = recv_offsets.data(); - - onecclGroupStart(); - for (const auto r : c10::irange(size_)) { - if (send_lengths[r] != 0) { - onecclSend( - ((char*)input.data_ptr()) + send_offsets_data[r] * size, - send_lengths[r], - xcclDataType, - r, - comm, - &(stream.queue())); - } - if (recv_lengths[r] != 0) { - onecclRecv( - ((char*)output.data_ptr()) + - recv_offsets_data[r] * size, - recv_lengths[r], - xcclDataType, - r, - comm, - &(stream.queue())); - } - } - onecclGroupEnd(); - - return; - }, - OpType::ALLTOALL_BASE, - "xccl:all_to_all"); - } - } + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + opts.rootRank, // root rank + "broadcast", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize - c10::intrusive_ptr ProcessGroupXCCL::alltoall( - std::vector & outputTensors, - std::vector & inputTensors, - const AllToAllOptions& /* unused */) { - auto device = outputTensors[0].device(); - int64_t total_numel = 0; - for (const auto r : c10::irange(outputTensors.size())) { - checkSingleTensor(outputTensors[r], true); - checkSingleTensor(inputTensors[r], true); - TORCH_CHECK( - device == outputTensors[r].device() && - device == inputTensors[r].device(), - "Tensors must be on the same device") - total_numel += inputTensors[r].numel(); - } + const auto root = opts.rootRank + opts.rootTensor; - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensors, // inputTensors - outputTensors, // outputTensors - rank_, // rank - "all_to_all", // collective name - total_numel, // inNelems - total_numel, // outNelems - inputTensors.front().scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - return collective( - inputTensors, - outputTensors, - [&](at::Tensor& /* unused */, - at::Tensor& /* unused */, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - onecclGroupStart(); - for (const int r : - c10::irange(static_cast(outputTensors.size()))) { - at::Tensor& input = inputTensors[r]; - at::Tensor& output = outputTensors[r]; - if (input.numel() != 0) { - onecclSend( - input.data_ptr(), - input.numel(), - getXcclDataType(input.scalar_type()), - r, - comm, - &(stream.queue())); - } - if (output.numel() != 0) { - onecclRecv( - output.data_ptr(), - output.numel(), - getXcclDataType(output.scalar_type()), - r, - comm, - &(stream.queue())); - } - } 
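Both the old ccl:: path being removed here and its oneCCL replacement build all-to-all out of the same grouped point-to-point pattern: every per-peer send/receive is fenced by onecclGroupStart()/onecclGroupEnd(), mirroring NCCL's group semantics so the whole batch is submitted together rather than each pairwise exchange blocking the next. Stripped to its shape (buffer, count, and dtype names below are placeholders, not identifiers from the patch):

    // Minimal shape of the grouped exchange; sendBuf/recvBuf/count/dtype are
    // placeholders, not names from the patch.
    onecclGroupStart();
    for (int r = 0; r < worldSize; ++r) {
      onecclSend(sendBuf[r], count, dtype, /*peer=*/r, comm, &queue);
      onecclRecv(recvBuf[r], count, dtype, /*peer=*/r, comm, &queue);
    }
    onecclGroupEnd(); // all point-to-points are issued as one batch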
- onecclGroupEnd(); + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + sycl::queue& SyclQueue) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + onecclBroadcast( + input.data_ptr(), + output.data_ptr(), // ? + (size_t)input.numel(), + xcclDataType, + root, + comm, + &SyclQueue); + return; + }, + OpType::BROADCAST, + "xccl:broadcast"); +} - return; - }, - OpType::ALLTOALL, - "xccl:all_to_all"); - } +c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const BroadcastOptions& opts) { + if (outputTensor.numel() != inputTensor.numel()) { + C10_THROW_ERROR( + ValueError, + "Tensor input and output of _broadcast_oop must have the same number of elements "); + } + const auto root = opts.rootRank + opts.rootTensor; + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + sycl::queue& SyclQueue) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + onecclBroadcast( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + root, + comm, + &SyclQueue); + return; + }, + OpType::BROADCAST, + "xccl:_broadcast_oop"); +} + +c10::intrusive_ptr ProcessGroupXCCL::reduce( + std::vector& tensors, + const ReduceOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + if (tensor.is_complex()) { + TORCH_CHECK( + complexViewAsRealAllowed(opts.reduceOp), + "reduce does not support", + opts.reduceOp, + "on complex tensors"); + tensor = at::view_as_real(tensor); + } + checkSingleTensor(tensor); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + opts.rootRank, // root rank + "reduce", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + sycl::queue& SyclQueue) { + const int root = opts.rootRank + opts.rootTensor; + const auto xcclDataType = getXcclDataType(input.scalar_type(), true); + const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + onecclReduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + root, + comm, + &SyclQueue); +#if !defined(XCCL_HAS_AVG) + if (opts.reduceOp == ReduceOp::AVG && getRank() == root) { + auto divisor = getSize(); + c10::StreamGuard guard(stream); + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + output.div_(divisor); + } +#endif + return; + }, + OpType::REDUCE, + "xccl:reduce"); +} + +c10::intrusive_ptr ProcessGroupXCCL::_reduce_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceOptions& opts) { + TORCH_CHECK_WITH( + ValueError, + outputTensor.numel() == inputTensor.numel(), + "Tensor input and output of _reduce_oop must have the same number of elements"); + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + 
sycl::queue& SyclQueue) { + const int root = opts.rootRank + opts.rootTensor; + const auto xcclDataType = getXcclDataType(input.scalar_type(), true); + const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + onecclReduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + root, + comm, + &SyclQueue); +#if !defined(XCCL_HAS_AVG) + if (opts.reduceOp == ReduceOp::AVG && getRank() == root) { + auto divisor = getSize(); + c10::StreamGuard guard(stream); + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + output.div_(divisor); + } +#endif + return; + }, + OpType::REDUCE, + "xccl:_reduce_oop"); +} + +c10::intrusive_ptr ProcessGroupXCCL::allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts) { + TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto inputTensor = inputTensors.back(); + checkSingleTensor(inputTensor); + // @lint-ignore CLANGTIDY + std::vector& outputTensors_ = outputTensors.back(); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "all_gather", // collective name + inputTensor.numel(), // inNelems + inputTensor.numel() * // outNelems + this->getSize(), + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + bool same_size = checkSameSize(outputTensors_); + if (same_size) { + // Flatten a vector of tensors into a single, stacked tensor. + at::Tensor outputFlattened = newLikeFlat(outputTensors_); + + return collective( + inputTensor, + outputFlattened, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + sycl::queue& SyclQueue) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type()); + onecclAllGather( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + comm, + &SyclQueue); + return; + }, + [](at::xpu::XPUStream&, + c10::intrusive_ptr& work) {}, + [&](at::xpu::XPUStream& Stream, + c10::intrusive_ptr& work) { + // Copy the flattened output tensors to the outputs. + c10::StreamGuard guard(Stream); + for (const auto j : c10::irange(outputTensors_.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + outputTensors_[j].storage().data_ptr(), Stream); + outputTensors_[j].copy_(outputFlattened[j], true); + } + }, + OpType::ALLGATHER, + "xccl:all_gather"); + } else { + const auto num_reduces = outputTensors_.size(); + startCoalescing(); + for (const int i : c10::irange(num_reduces)) { + auto& output = outputTensors_[i]; + auto& input = (i == rank_) ? 
inputTensor : output; + auto broadcastOpts = BroadcastOptions{ + static_cast(i), static_cast(0), opts.timeout}; + _broadcast_oop(output, input, broadcastOpts); + } + auto work = endCoalescing(OpType::ALLGATHER); + return work; + } +} + +c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const AllgatherOptions& opts) { + checkSingleTensor(input_tensor); + checkSingleTensor(output_tensor); + + TORCH_CHECK_WITH( + TypeError, + input_tensor.dtype() == output_tensor.dtype(), + "output tensor must have the same type as input tensor"); + TORCH_CHECK_WITH( + ValueError, + input_tensor.numel() * size_ == output_tensor.numel(), + "output tensor size must be equal to world_size times input tensor size"); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + input_tensor, // inputTensors + output_tensor, // outputTensors + rank_, // rank + "_allgather_base", // collective name + input_tensor.numel(), // inNelems + output_tensor.numel(), // outNelems + output_tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + input_tensor, + output_tensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + sycl::queue& SyclQueue) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type()); + onecclAllGather( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + comm, + &SyclQueue); + return; + }, + OpType::_ALLGATHER_BASE, + "xccl:_all_gather_base"); +} + +c10::intrusive_ptr ProcessGroupXCCL::allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts) { + return collectiveCoalesced( + inputs, + outputs, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + sycl::queue& SyclQueue) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + onecclAllGather( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + comm, + &SyclQueue); + return; + }, + OpType::COALESCED, + "xccl:all_gather_into_tensor_coalesced"); +} + +c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts) { + TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto outputTensor = outputTensors.back(); + checkSingleTensor(outputTensor); + // @lint-ignore CLANGTIDY + auto inputTensors_ = inputTensors.back(); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "reduce_scatter", // collective name + outputTensor.numel() * this->getSize(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + bool same_size = checkSameSize(inputTensors_); + if (same_size) { + // Flatten a vector of tensors into a single, stacked tensor. 
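The same-size fast path here mirrors the one in allgather above: newLikeFlat stages the per-rank tensors in one contiguous [world_size, ...] buffer, the collective runs on that buffer, and a hook copies the per-rank slices in (reduce_scatter, just below) or out (all_gather, earlier) on the collective's stream. A rough sketch of the staging step, assuming all tensors share shape, dtype, and device:

    // Approximate equivalent of newLikeFlat() plus the copy-in hook (names assumed).
    std::vector<int64_t> shape{static_cast<int64_t>(tensors.size())};
    shape.insert(shape.end(), tensors[0].sizes().begin(), tensors[0].sizes().end());
    at::Tensor flat = at::empty(shape, tensors[0].options());
    for (size_t j = 0; j < tensors.size(); ++j) {
      flat[j].copy_(tensors[j], /*non_blocking=*/true); // reduce_scatter copies in;
    }                                                   // all_gather copies out instead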
+ at::Tensor inputFlattened = newLikeFlat(inputTensors_); + return collective( + inputFlattened, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + sycl::queue& SyclQueue) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + onecclReduceScatter( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + xcclReduceOp, + comm, + &SyclQueue); +#if !defined(XCCL_HAS_AVG) + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + c10::StreamGuard guard(stream); + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + output.div_(divisor); + } +#endif + return; + }, + [&](at::xpu::XPUStream& Stream, + c10::intrusive_ptr& work) { + // Copy the input tensors to the flattened inputs. + c10::StreamGuard guard(Stream); + for (const auto j : c10::irange(inputTensors_.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + inputTensors_[j].storage().data_ptr(), Stream); + inputFlattened[j].copy_(inputTensors_[j], true); + } + }, + [&](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + OpType::REDUCE_SCATTER, + "xccl:reduce_scatter"); + } else { + const auto num_reduces = inputTensors_.size(); + startCoalescing(); + for (const int i : c10::irange(num_reduces)) { + auto& input = inputTensors_[i]; + auto& output = (i == rank_) ? outputTensor : input; + auto reduceOpts = ReduceOptions{ + opts.reduceOp, + static_cast(i), + static_cast(0), + opts.timeout}; + _reduce_oop(output, input, reduceOpts); + } + auto work = endCoalescing(OpType::REDUCE_SCATTER); + return work; + } +} + +c10::intrusive_ptr ProcessGroupXCCL::_reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts) { + TORCH_CHECK_WITH( + TypeError, + inputTensor.dtype() == outputTensor.dtype(), + "input tensor must be the same type as the output tensor."); + TORCH_CHECK_WITH( + ValueError, + inputTensor.numel() == outputTensor.numel() * size_, + "input tensor must be the same size as output size times world size"); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "_reduce_scatter_base", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dtype + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + sycl::queue& SyclQueue) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + onecclReduceScatter( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + xcclReduceOp, + comm, + &SyclQueue); +#if !defined(XCCL_HAS_AVG) + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + c10::StreamGuard guard(stream); + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); 
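        // Without native AVG support (XCCL_HAS_AVG undefined), the reduction
        // is issued as SUM and the result is divided by the world size here,
        // on the same stream the collective was enqueued on.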
+ output.div_(divisor); + } +#endif + return; + }, + OpType::_REDUCE_SCATTER_BASE, + "xccl:_reduce_scatter_base"); +} + +c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const ReduceScatterOptions& opts) { + return collectiveCoalesced( + inputs, + outputs, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + sycl::queue& SyclQueue) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + onecclReduceScatter( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + xcclReduceOp, + comm, + &SyclQueue); +#if !defined(XCCL_HAS_AVG) + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + c10::StreamGuard guard(stream); + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + output.div_(divisor); + } +#endif + return; + }, + OpType::COALESCED, + "xccl:reduce_scatter_tensor_coalesced"); +} + +c10::intrusive_ptr ProcessGroupXCCL::barrier(const BarrierOptions& opts) { + RECORD_PARAM_COMMS( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + rank_, // rank + "barrier", // collective name + 0, // inNelems + 0, // outNelems + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + // Device to use for barrier + int barDevIdx = -1; + + // See nccl barrier comments + if (!opts.device_ids.empty()) { + barDevIdx = opts.device_ids[0]; + } else if (getBoundDeviceId()) { + barDevIdx = (*getBoundDeviceId()).index(); + } else if (!usedDeviceIdxs_.empty()) { + barDevIdx = *usedDeviceIdxs_.begin(); + } else { + barDevIdx = + static_cast(rank_ % at::detail::getXPUHooks().getNumGPUs()); + } + + TORCH_CHECK_WITH( + ValueError, + barDevIdx >= 0, + "Failed to infer a GPU device id to perform barrier. 
"); + auto barDevice = at::Device(at::DeviceType::XPU, barDevIdx); + + at::Tensor barrierTensor = + at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); + + auto work = allreduce_impl(barrierTensor, "xccl:all_reduce_barrier"); + + auto xcclWork = dynamic_cast(work.get()); + TORCH_CHECK(xcclWork); + xcclWork->barrierTensor_ = std::move(barrierTensor); + return work; +} + +c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& /* unused */) { + checkSingleTensor(outputTensor, true); + checkSingleTensor(inputTensor, true); + if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "all_to_all", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + TORCH_CHECK( + outputTensor.numel() == inputTensor.numel() && + outputTensor.scalar_type() == inputTensor.scalar_type(), + "xpu_alltoall_base: tensors are not equal in size or data type"); + TORCH_CHECK( + outputTensor.size(0) % size_ == 0, + "xpu_alltoall_base: tensor's dim 0 does not divide equally across group size"); + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + sycl::queue& SyclQueue) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(output.scalar_type()); + size_t count = input.numel() / size_; + size_t rankdiff = input.nbytes() / size_; + + onecclGroupStart(); + for (const auto r : c10::irange(rank_)) { + if (count != 0) { + onecclSend( + ((char*)input.data_ptr()) + r * rankdiff, + count, + xcclDataType, + r, + comm, + &SyclQueue); + onecclRecv( + ((char*)output.data_ptr()) + r * rankdiff, + count, + xcclDataType, + r, + comm, + &SyclQueue); + } + } + onecclGroupEnd(); + return; + }, + OpType::ALLTOALL_BASE, + "xccl:all_to_all"); + } else { + c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); + c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "all_to_allv", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + inputTensor.scalar_type(), // dType + inputSplitSizes, // inSplitSizes + outputSplitSizes, // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + sycl::queue& SyclQueue) { + std::vector send_lengths(size_); + std::vector recv_lengths(size_); + std::vector send_offsets(size_); + std::vector recv_offsets(size_); + c10d::computeLengthsAndOffsets( + inputSplitSizes, input, &send_lengths, &send_offsets); + 
c10d::computeLengthsAndOffsets( + outputSplitSizes, output, &recv_lengths, &recv_offsets); + + size_t size = input.element_size(); + auto xcclDataType = getXcclDataType(input.scalar_type()); + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + + auto send_offsets_data = send_offsets.data(); + auto recv_offsets_data = recv_offsets.data(); + + onecclGroupStart(); + for (const auto r : c10::irange(size_)) { + if (send_lengths[r] != 0) { + onecclSend( + ((char*)input.data_ptr()) + send_offsets_data[r] * size, + send_lengths[r], + xcclDataType, + r, + comm, + &SyclQueue); + } + if (recv_lengths[r] != 0) { + onecclRecv( + ((char*)output.data_ptr()) + recv_offsets_data[r] * size, + recv_lengths[r], + xcclDataType, + r, + comm, + &SyclQueue); + } + } + onecclGroupEnd(); + + return; + }, + OpType::ALLTOALL_BASE, + "xccl:all_to_all"); + } +} + +c10::intrusive_ptr ProcessGroupXCCL::alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& /* unused */) { + auto device = outputTensors[0].device(); + int64_t total_numel = 0; + for (const auto r : c10::irange(outputTensors.size())) { + checkSingleTensor(outputTensors[r], true); + checkSingleTensor(inputTensors[r], true); + TORCH_CHECK( + device == outputTensors[r].device() && + device == inputTensors[r].device(), + "Tensors must be on the same device") + total_numel += inputTensors[r].numel(); + } + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "all_to_all", // collective name + total_numel, // inNelems + total_numel, // outNelems + inputTensors.front().scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + inputTensors, + outputTensors, + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + sycl::queue& SyclQueue) { + onecclGroupStart(); + for (const int r : + c10::irange(static_cast(outputTensors.size()))) { + at::Tensor& input = inputTensors[r]; + at::Tensor& output = outputTensors[r]; + if (input.numel() != 0) { + onecclSend( + input.data_ptr(), + input.numel(), + getXcclDataType(input.scalar_type()), + r, + comm, + &SyclQueue); + } + if (output.numel() != 0) { + onecclRecv( + output.data_ptr(), + output.numel(), + getXcclDataType(output.scalar_type()), + r, + comm, + &SyclQueue); + } + } + onecclGroupEnd(); + + return; + }, + OpType::ALLTOALL, + "xccl:all_to_all"); +} } // namespace c10d diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp index a69d3f21c..3a7c5ea9a 100644 --- a/src/xccl/ProcessGroupXCCL.hpp +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -259,9 +259,9 @@ class TORCH_API ProcessGroupXCCL : public Backend { const char* profilingTitle = nullptr); c10::intrusive_ptr allreduce_impl( - at::Tensor& tensor, - const char* profilingTitle = "xccl:all_reduce", - const AllreduceOptions& opts = AllreduceOptions()); + at::Tensor& tensor, + const char* profilingTitle = "xccl:all_reduce", + const AllreduceOptions& opts = AllreduceOptions()); c10::intrusive_ptr allreduce( std::vector& tensors, @@ -364,7 +364,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { uint64_t getSequenceNumberForGroup() override; protected: - std::unordered_map> + 
std::unordered_map> xcclStreamsMap_; std::unordered_map xcclEventsMap_; std::unordered_map> devXCCLCommMap_; From 82e4f96364f4d798e4fc8e7f658616c1f55fda1f Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Mon, 31 Mar 2025 17:31:48 +0800 Subject: [PATCH 08/10] update emu --- src/xccl/ProcessGroupXCCL.cpp | 46 +++++++++++++++++------------------ 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index 2b26e8c69..d54aab5c3 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -15,31 +15,31 @@ namespace { #endif // oneCCL version >= 2021.15 const std::map xcclOps = { - {ReduceOp::MIN, onecclRedOp_t::ONECCL_MIN}, - {ReduceOp::MAX, onecclRedOp_t::ONECCL_MAX}, - {ReduceOp::SUM, onecclRedOp_t::ONECCL_SUM}, - {ReduceOp::PRODUCT, onecclRedOp_t::ONECCL_PROD}, + {ReduceOp::MIN, onecclRedOp_t::onecclMin}, + {ReduceOp::MAX, onecclRedOp_t::onecclMax}, + {ReduceOp::SUM, onecclRedOp_t::onecclSum}, + {ReduceOp::PRODUCT, onecclRedOp_t::onecclProd}, #ifdef XCCL_HAS_AVG - {ReduceOp::AVG, onecclRedOp_t::ONECCL_AVG}, + {ReduceOp::AVG, onecclRedOp_t::onecclAvg}, #endif // XCCL_HAS_AVG }; const std::map xcclDatatypes = { - {at::kByte, onecclDataType_t::ONECCL_UINT8}, - {at::kChar, onecclDataType_t::ONECCL_INT8}, - {at::kInt, onecclDataType_t::ONECCL_INT32}, - {at::kLong, onecclDataType_t::ONECCL_INT64}, - {at::kHalf, onecclDataType_t::ONECCL_FLOAT16}, - {at::kFloat, onecclDataType_t::ONECCL_FLOAT32}, - {at::kDouble, onecclDataType_t::ONECCL_FLOAT64}, - {at::kBFloat16, onecclDataType_t::ONECCL_BFLOAT16}, - {at::kBool, onecclDataType_t::ONECCL_UINT8}, + {at::kByte, onecclDataType_t::onecclUint8}, + {at::kChar, onecclDataType_t::onecclChar}, + {at::kInt, onecclDataType_t::onecclInt32}, + {at::kLong, onecclDataType_t::onecclInt64}, + {at::kHalf, onecclDataType_t::onecclFloat16}, + {at::kFloat, onecclDataType_t::onecclFloat32}, + {at::kDouble, onecclDataType_t::onecclFloat64}, + {at::kBFloat16, onecclDataType_t::onecclBfloat16}, + {at::kBool, onecclDataType_t::onecclUint8}, // use for non-reducetion op like allgather - {at::kFloat8_e5m2, onecclDataType_t::ONECCL_UINT8}, - {at::kFloat8_e4m3fn, onecclDataType_t::ONECCL_UINT8}, - {at::kFloat8_e4m3fnuz, onecclDataType_t::ONECCL_UINT8}, - {at::kFloat8_e5m2fnuz, onecclDataType_t::ONECCL_UINT8}, + {at::kFloat8_e5m2, onecclDataType_t::onecclUint8}, + {at::kFloat8_e4m3fn, onecclDataType_t::onecclUint8}, + {at::kFloat8_e4m3fnuz, onecclDataType_t::onecclUint8}, + {at::kFloat8_e5m2fnuz, onecclDataType_t::onecclUint8}, }; bool checkSameSize(const std::vector& input_tensors) { @@ -117,7 +117,7 @@ onecclRedOp_t getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { if (input.scalar_type() == at::kBool) { if (reduceOp == ReduceOp::SUM) { // Map sum to max for bool tensors to avoid overflow issues with sum. 
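   // For bool tensors SUM behaves as a logical OR, and MAX over {0, 1}
   // produces the same result without risking uint8 overflow.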
- return onecclRedOp_t::ONECCL_MAX; + return onecclRedOp_t::onecclMax; } #ifdef XCCL_HAS_AVG if (reduceOp == ReduceOp::AVG) { @@ -128,7 +128,7 @@ onecclRedOp_t getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { } #if !defined(XCCL_HAS_AVG) if (reduceOp == ReduceOp::AVG) { - return onecclRedOp_t::ONECCL_SUM; + return onecclRedOp_t::onecclSum; } #endif return xcclOps.at(reduceOp); @@ -388,13 +388,13 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( broadcastUniqueXCCLID(&xcclID, singleP2POp, deviceKey, p2pRank); xcclComm_t comm = nullptr; - onecclResult_t result = ONECCL_SUCCESS; + onecclResult_t result = onecclSuccess; result = onecclSetDevice(rank); - if (result != ONECCL_SUCCESS) { + if (result != onecclSuccess) { std::cerr << "Failed to set device.\n"; } result = onecclCommInitRank(&comm, numRanks, xcclID, rank); - if (result != ONECCL_SUCCESS) { + if (result != onecclSuccess) { std::cerr << "Failed to initialize communicator.\n"; } XCCLComm = std::make_shared(comm); From a276961909f75aef1555500bc1037129c483fd4b Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Mon, 14 Apr 2025 17:57:19 +0800 Subject: [PATCH 09/10] update version --- src/xccl/ProcessGroupXCCL.hpp | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp index 144b73f09..3a0cb90c4 100644 --- a/src/xccl/ProcessGroupXCCL.hpp +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -362,10 +362,22 @@ class TORCH_API ProcessGroupXCCL : public Backend { namespace { inline std::string getXcclVersion() { - auto xccl_version = ccl::get_library_version(); - std::string versionString = std::to_string(xccl_version.major) + "." + - std::to_string(xccl_version.minor) + "." + - std::to_string(xccl_version.update); + static std::string versionString = []() { + int version = 0; + std::string versionString; + onecclGetVersion(&version); + + const int majorBase = 10000; + const int minorBase = 100; + auto xcclMajor = version / majorBase; + auto xcclMinor = (version % majorBase) / minorBase; + auto xcclPatch = version % (xcclMajor * majorBase + xcclMinor * minorBase); + versionString = std::to_string(xcclMajor) + "." + + std::to_string(xcclMinor) + "." 
+ std::to_string(xcclPatch); + + return versionString; + }(); + return versionString; } From f4fb1f051fb0efcfc42ae1eb9c7c26dbb1a63517 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Tue, 15 Apr 2025 00:28:36 +0800 Subject: [PATCH 10/10] refine all2all --- src/xccl/ProcessGroupXCCL.cpp | 141 ++++++++++++---------------------- 1 file changed, 51 insertions(+), 90 deletions(-) diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index 94919e230..8170665d7 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -1791,44 +1791,6 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( TORCH_CHECK( outputTensor.size(0) % size_ == 0, "xpu_alltoall_base: tensor's dim 0 does not divide equally across group size"); - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - sycl::queue& SyclQueue) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(output.scalar_type()); - size_t count = input.numel() / size_; - size_t rankdiff = input.nbytes() / size_; - - onecclGroupStart(); - for (const auto r : c10::irange(rank_)) { - if (count != 0) { - onecclSend( - ((char*)input.data_ptr()) + r * rankdiff, - count, - xcclDataType, - r, - comm, - &SyclQueue); - onecclRecv( - ((char*)output.data_ptr()) + r * rankdiff, - count, - xcclDataType, - r, - comm, - &SyclQueue); - } - } - onecclGroupEnd(); - return; - }, - OpType::ALLTOALL_BASE, - "xccl:all_to_all"); } else { c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); @@ -1850,60 +1812,59 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( -1, // globalRankStart -1, // globalRankStride this->getSize()); // worldSize + } + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + sycl::queue& SyclQueue) { + std::vector send_lengths(size_); + std::vector recv_lengths(size_); + std::vector send_offsets(size_); + std::vector recv_offsets(size_); + c10d::computeLengthsAndOffsets( + inputSplitSizes, input, &send_lengths, &send_offsets); + c10d::computeLengthsAndOffsets( + outputSplitSizes, output, &recv_lengths, &recv_offsets); + + size_t size = input.element_size(); + auto xcclDataType = getXcclDataType(input.scalar_type()); + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream, - sycl::queue& SyclQueue) { - std::vector send_lengths(size_); - std::vector recv_lengths(size_); - std::vector send_offsets(size_); - std::vector recv_offsets(size_); - c10d::computeLengthsAndOffsets( - inputSplitSizes, input, &send_lengths, &send_offsets); - c10d::computeLengthsAndOffsets( - outputSplitSizes, output, &recv_lengths, &recv_offsets); - - size_t size = input.element_size(); - auto xcclDataType = getXcclDataType(input.scalar_type()); - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); + auto send_offsets_data = send_offsets.data(); + auto recv_offsets_data = recv_offsets.data(); - auto send_offsets_data = send_offsets.data(); - auto recv_offsets_data = recv_offsets.data(); - - onecclGroupStart(); - for (const auto r : c10::irange(size_)) { - if (send_lengths[r] != 0) { - onecclSend( - 
((char*)input.data_ptr()) + send_offsets_data[r] * size, - send_lengths[r], - xcclDataType, - r, - comm, - &SyclQueue); - } - if (recv_lengths[r] != 0) { - onecclRecv( - ((char*)output.data_ptr()) + recv_offsets_data[r] * size, - recv_lengths[r], - xcclDataType, - r, - comm, - &SyclQueue); - } + onecclGroupStart(); + for (const auto r : c10::irange(size_)) { + if (send_lengths[r] != 0) { + onecclSend( + ((char*)input.data_ptr()) + send_offsets_data[r] * size, + send_lengths[r], + xcclDataType, + r, + comm, + &SyclQueue); } - onecclGroupEnd(); + if (recv_lengths[r] != 0) { + onecclRecv( + ((char*)output.data_ptr()) + recv_offsets_data[r] * size, + recv_lengths[r], + xcclDataType, + r, + comm, + &SyclQueue); + } + } + onecclGroupEnd(); - return; - }, - OpType::ALLTOALL_BASE, - "xccl:all_to_all"); - } + return; + }, + OpType::ALLTOALL_BASE, + "xccl:all_to_all"); } c10::intrusive_ptr ProcessGroupXCCL::alltoall( @@ -1940,8 +1901,8 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall( this->getSize()); // worldSize return collective( - inputTensors, - outputTensors, + inputTensors.front(), + outputTensors.front(), [&](at::Tensor& /* unused */, at::Tensor& /* unused */, xcclComm_t& comm,
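Editorial sketch, not part of the patch: every all-to-all path above is built on the same grouped send/receive pattern. The fragment below illustrates that pattern only; the helper name alltoallEqualChunks and the buffer/count/queue parameters are placeholders, and the oneccl* entry points are assumed to have the argument order used throughout this series.

  // Sketch: all-to-all over equally sized per-rank chunks. `comm` and `queue`
  // are assumed to be an initialized xcclComm_t and the collective's
  // sycl::queue, as in the lambdas above.
  void alltoallEqualChunks(void* sendBuf, void* recvBuf, size_t count,
                           size_t bytesPerRank, onecclDataType_t dtype,
                           int worldSize, xcclComm_t comm, sycl::queue& queue) {
    onecclGroupStart();
    for (int r = 0; r < worldSize; ++r) {
      // Post both directions for every peer inside one group start/end pair,
      // so the backend can schedule the transfers together instead of
      // completing each call before issuing the next.
      onecclSend(static_cast<char*>(sendBuf) + r * bytesPerRank,
                 count, dtype, r, comm, &queue);
      onecclRecv(static_cast<char*>(recvBuf) + r * bytesPerRank,
                 count, dtype, r, comm, &queue);
    }
    onecclGroupEnd();
  }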