Skip to content

Commit 7fa0fef

Browse files
committed
update emu
1 parent 08819c7 commit 7fa0fef

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

src/xccl/ProcessGroupXCCL.cpp

+20-20
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,31 @@ namespace {
1515
#endif // oneCCL version >= 2021.15
1616

1717
const std::map<c10d::ReduceOp, onecclRedOp_t> xcclOps = {
18-
{ReduceOp::MIN, onecclRedOp_t::ONECCL_MIN},
19-
{ReduceOp::MAX, onecclRedOp_t::ONECCL_MAX},
20-
{ReduceOp::SUM, onecclRedOp_t::ONECCL_SUM},
21-
{ReduceOp::PRODUCT, onecclRedOp_t::ONECCL_PROD},
18+
{ReduceOp::MIN, onecclRedOp_t::onecclMin},
19+
{ReduceOp::MAX, onecclRedOp_t::onecclMax},
20+
{ReduceOp::SUM, onecclRedOp_t::onecclSum},
21+
{ReduceOp::PRODUCT, onecclRedOp_t::onecclProd},
2222
#ifdef XCCL_HAS_AVG
23-
{ReduceOp::AVG, onecclRedOp_t::ONECCL_AVG},
23+
{ReduceOp::AVG, onecclRedOp_t::onecclAvg},
2424
#endif // XCCL_HAS_AVG
2525

2626
};
2727

2828
const std::map<at::ScalarType, onecclDataType_t> xcclDatatypes = {
29-
{at::kByte, onecclDataType_t::ONECCL_UINT8},
30-
{at::kChar, onecclDataType_t::ONECCL_INT8},
31-
{at::kInt, onecclDataType_t::ONECCL_INT32},
32-
{at::kLong, onecclDataType_t::ONECCL_INT64},
33-
{at::kHalf, onecclDataType_t::ONECCL_FLOAT16},
34-
{at::kFloat, onecclDataType_t::ONECCL_FLOAT32},
35-
{at::kDouble, onecclDataType_t::ONECCL_FLOAT64},
36-
{at::kBFloat16, onecclDataType_t::ONECCL_BFLOAT16},
37-
{at::kBool, onecclDataType_t::ONECCL_UINT8},
29+
{at::kByte, onecclDataType_t::onecclUint8},
30+
{at::kChar, onecclDataType_t::onecclChar},
31+
{at::kInt, onecclDataType_t::onecclInt32},
32+
{at::kLong, onecclDataType_t::onecclInt64},
33+
{at::kHalf, onecclDataType_t::onecclFloat16},
34+
{at::kFloat, onecclDataType_t::onecclFloat32},
35+
{at::kDouble, onecclDataType_t::onecclFloat64},
36+
{at::kBFloat16, onecclDataType_t::onecclBfloat16},
37+
{at::kBool, onecclDataType_t::onecclUint8},
3838
// use for non-reducetion op like allgather
39-
{at::kFloat8_e5m2, onecclDataType_t::ONECCL_UINT8},
40-
{at::kFloat8_e4m3fn, onecclDataType_t::ONECCL_UINT8},
41-
{at::kFloat8_e4m3fnuz, onecclDataType_t::ONECCL_UINT8},
42-
{at::kFloat8_e5m2fnuz, onecclDataType_t::ONECCL_UINT8},
39+
{at::kFloat8_e5m2, onecclDataType_t::onecclUint8},
40+
{at::kFloat8_e4m3fn, onecclDataType_t::onecclUint8},
41+
{at::kFloat8_e4m3fnuz, onecclDataType_t::onecclUint8},
42+
{at::kFloat8_e5m2fnuz, onecclDataType_t::onecclUint8},
4343
};
4444

4545
bool checkSameSize(const std::vector<at::Tensor>& input_tensors) {
@@ -117,7 +117,7 @@ onecclRedOp_t getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
117117
if (input.scalar_type() == at::kBool) {
118118
if (reduceOp == ReduceOp::SUM) {
119119
// Map sum to max for bool tensors to avoid overflow issues with sum.
120-
return onecclRedOp_t::ONECCL_MAX;
120+
return onecclRedOp_t::onecclMax;
121121
}
122122
#ifdef XCCL_HAS_AVG
123123
if (reduceOp == ReduceOp::AVG) {
@@ -128,7 +128,7 @@ onecclRedOp_t getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
128128
}
129129
#if !defined(XCCL_HAS_AVG)
130130
if (reduceOp == ReduceOp::AVG) {
131-
return onecclRedOp_t::ONECCL_SUM;
131+
return onecclRedOp_t::onecclSum;
132132
}
133133
#endif
134134
return xcclOps.at(reduceOp);

0 commit comments

Comments
 (0)