@@ -15,31 +15,31 @@ namespace {
15
15
#endif // oneCCL version >= 2021.15
16
16
17
17
const std::map<c10d::ReduceOp, onecclRedOp_t> xcclOps = {
18
- {ReduceOp::MIN, onecclRedOp_t::ONECCL_MIN },
19
- {ReduceOp::MAX, onecclRedOp_t::ONECCL_MAX },
20
- {ReduceOp::SUM, onecclRedOp_t::ONECCL_SUM },
21
- {ReduceOp::PRODUCT, onecclRedOp_t::ONECCL_PROD },
18
+ {ReduceOp::MIN, onecclRedOp_t::onecclMin },
19
+ {ReduceOp::MAX, onecclRedOp_t::onecclMax },
20
+ {ReduceOp::SUM, onecclRedOp_t::onecclSum },
21
+ {ReduceOp::PRODUCT, onecclRedOp_t::onecclProd },
22
22
#ifdef XCCL_HAS_AVG
23
- {ReduceOp::AVG, onecclRedOp_t::ONECCL_AVG },
23
+ {ReduceOp::AVG, onecclRedOp_t::onecclAvg },
24
24
#endif // XCCL_HAS_AVG
25
25
26
26
};
27
27
28
28
const std::map<at::ScalarType, onecclDataType_t> xcclDatatypes = {
29
- {at::kByte , onecclDataType_t::ONECCL_UINT8 },
30
- {at::kChar , onecclDataType_t::ONECCL_INT8 },
31
- {at::kInt , onecclDataType_t::ONECCL_INT32 },
32
- {at::kLong , onecclDataType_t::ONECCL_INT64 },
33
- {at::kHalf , onecclDataType_t::ONECCL_FLOAT16 },
34
- {at::kFloat , onecclDataType_t::ONECCL_FLOAT32 },
35
- {at::kDouble , onecclDataType_t::ONECCL_FLOAT64 },
36
- {at::kBFloat16 , onecclDataType_t::ONECCL_BFLOAT16 },
37
- {at::kBool , onecclDataType_t::ONECCL_UINT8 },
29
+ {at::kByte , onecclDataType_t::onecclUint8 },
30
+ {at::kChar , onecclDataType_t::onecclChar },
31
+ {at::kInt , onecclDataType_t::onecclInt32 },
32
+ {at::kLong , onecclDataType_t::onecclInt64 },
33
+ {at::kHalf , onecclDataType_t::onecclFloat16 },
34
+ {at::kFloat , onecclDataType_t::onecclFloat32 },
35
+ {at::kDouble , onecclDataType_t::onecclFloat64 },
36
+ {at::kBFloat16 , onecclDataType_t::onecclBfloat16 },
37
+ {at::kBool , onecclDataType_t::onecclUint8 },
38
38
// use for non-reducetion op like allgather
39
- {at::kFloat8_e5m2 , onecclDataType_t::ONECCL_UINT8 },
40
- {at::kFloat8_e4m3fn , onecclDataType_t::ONECCL_UINT8 },
41
- {at::kFloat8_e4m3fnuz , onecclDataType_t::ONECCL_UINT8 },
42
- {at::kFloat8_e5m2fnuz , onecclDataType_t::ONECCL_UINT8 },
39
+ {at::kFloat8_e5m2 , onecclDataType_t::onecclUint8 },
40
+ {at::kFloat8_e4m3fn , onecclDataType_t::onecclUint8 },
41
+ {at::kFloat8_e4m3fnuz , onecclDataType_t::onecclUint8 },
42
+ {at::kFloat8_e5m2fnuz , onecclDataType_t::onecclUint8 },
43
43
};
44
44
45
45
bool checkSameSize (const std::vector<at::Tensor>& input_tensors) {
@@ -117,7 +117,7 @@ onecclRedOp_t getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
117
117
if (input.scalar_type () == at::kBool ) {
118
118
if (reduceOp == ReduceOp::SUM) {
119
119
// Map sum to max for bool tensors to avoid overflow issues with sum.
120
- return onecclRedOp_t::ONECCL_MAX ;
120
+ return onecclRedOp_t::onecclMax ;
121
121
}
122
122
#ifdef XCCL_HAS_AVG
123
123
if (reduceOp == ReduceOp::AVG) {
@@ -128,7 +128,7 @@ onecclRedOp_t getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
128
128
}
129
129
#if !defined(XCCL_HAS_AVG)
130
130
if (reduceOp == ReduceOp::AVG) {
131
- return onecclRedOp_t::ONECCL_SUM ;
131
+ return onecclRedOp_t::onecclSum ;
132
132
}
133
133
#endif
134
134
return xcclOps.at (reduceOp);
0 commit comments