diff --git a/src/all_reduce.cu b/src/all_reduce.cu index 5302f86..8873d6d 100644 --- a/src/all_reduce.cu +++ b/src/all_reduce.cu @@ -65,7 +65,7 @@ testResult_t AllReduceRunTest(struct threadArgs* args, int root, ncclDataType_t ncclRedOp_t *run_ops; const char **run_typenames, **run_opnames; int type_count, op_count; - if((type == ncclFp8E4M3 || type == ncclFp8E5M2) && op == ncclProd) + if((type == ncclFloat8e4m3 || type == ncclFloat8e5m2) && op == ncclProd) return testSuccess; if ((int)type != -1) { @@ -90,7 +90,7 @@ testResult_t AllReduceRunTest(struct threadArgs* args, int root, ncclDataType_t for (int i=0; i(rank); break; #endif #if defined(RCCL_FLOAT8) - case ncclFp8E4M3: fp8_e4m3 = ncclVerifiablePremulScalar(rank); break; - case ncclFp8E5M2: fp8_e5m2 = ncclVerifiablePremulScalar(rank); break; + case ncclFloat8e4m3: fp8_e4m3 = ncclVerifiablePremulScalar(rank); break; + case ncclFloat8e5m2: fp8_e5m2 = ncclVerifiablePremulScalar(rank); break; #endif case ncclNumTypes: break; } diff --git a/src/common.h b/src/common.h index 2f2082c..ddbb4ea 100644 --- a/src/common.h +++ b/src/common.h @@ -250,8 +250,8 @@ static size_t wordSize(ncclDataType_t type) { //case ncclInt8: case ncclUint8: #if NCCL_MAJOR >= 2 && RCCL_FLOAT8 == 1 - case ncclFp8E4M3: - case ncclFp8E5M2: + case ncclFloat8e4m3: + case ncclFloat8e5m2: #endif #endif return 1; diff --git a/verifiable/verifiable.cu b/verifiable/verifiable.cu index 32c13b0..c72618e 100644 --- a/verifiable/verifiable.cu +++ b/verifiable/verifiable.cu @@ -890,8 +890,8 @@ void prepareInput1( case ncclBfloat16: CASE_TY(hip_bfloat16) #endif #if HAVE_ncclfp8 - case ncclFp8E4M3: CASE_TY(rccl_float8) - case ncclFp8E5M2: CASE_TY(rccl_bfloat8) + case ncclFloat8e4m3: CASE_TY(rccl_float8) + case ncclFloat8e5m2: CASE_TY(rccl_bfloat8) #endif case ncclFloat32: CASE_TY(float) case ncclFloat64: CASE_TY(double) @@ -970,8 +970,8 @@ void prepareExpected1( case ncclBfloat16: CASE_TY(hip_bfloat16) #endif #if HAVE_ncclfp8 - case ncclFp8E4M3: CASE_TY(rccl_float8) - case ncclFp8E5M2: CASE_TY(rccl_bfloat8) + case ncclFloat8e4m3: CASE_TY(rccl_float8) + case ncclFloat8e5m2: CASE_TY(rccl_bfloat8) #endif case ncclFloat32: CASE_TY(float) case ncclFloat64: CASE_TY(double) @@ -1044,8 +1044,8 @@ __host__ __device__ unsigned calcSumFloatTolerance(int rank_n, int elt_ty) { break; #endif #if HAVE_ncclfp8 - case ncclFp8E4M3: - case ncclFp8E5M2: + case ncclFloat8e4m3: + case ncclFloat8e5m2: power = .91f; coef = .66f; break; @@ -1175,8 +1175,8 @@ void ncclVerifiableVerify( floating |= elt_ty == ncclBfloat16; #endif #if HAVE_ncclfp8 - floating |= elt_ty == ncclFp8E4M3; - floating |= elt_ty == ncclFp8E5M2; + floating |= elt_ty == ncclFloat8e4m3; + floating |= elt_ty == ncclFloat8e5m2; #endif unsigned tolerance = 0; @@ -1207,8 +1207,8 @@ void ncclVerifiableVerify( case ncclBfloat16: CASE_TY(hip_bfloat16, uint16_t) #endif #if HAVE_ncclfp8 - case ncclFp8E4M3: CASE_TY(rccl_float8, uint8_t) - case ncclFp8E5M2: CASE_TY(rccl_bfloat8, uint8_t) + case ncclFloat8e4m3: CASE_TY(rccl_float8, uint8_t) + case ncclFloat8e5m2: CASE_TY(rccl_bfloat8, uint8_t) #endif case ncclFloat32: CASE_TY(float, uint32_t) case ncclFloat64: CASE_TY(double, uint64_t) @@ -1278,8 +1278,8 @@ __global__ void sweep() { sweep1(ncclBfloat16, "bfloat16"); #endif #if HAVE_ncclfp8 - sweep1(ncclFp8E4M3, "fp8_e4m3"); - sweep1(ncclFp8E5M2, "fp8_e5m2"); + sweep1(ncclFloat8e4m3, "fp8_e4m3"); + sweep1(ncclFloat8e5m2, "fp8_e5m2"); #endif sweep1(ncclFloat32, "float"); sweep1(ncclFloat64, "double");