Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/all_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ testResult_t AllReduceRunTest(struct threadArgs* args, int root, ncclDataType_t
ncclRedOp_t *run_ops;
const char **run_typenames, **run_opnames;
int type_count, op_count;
if((type == ncclFp8E4M3 || type == ncclFp8E5M2) && op == ncclProd)
if((type == ncclFloat8e4m3 || type == ncclFloat8e5m2) && op == ncclProd)
return testSuccess;

if ((int)type != -1) {
Expand All @@ -90,7 +90,7 @@ testResult_t AllReduceRunTest(struct threadArgs* args, int root, ncclDataType_t

for (int i=0; i<type_count; i++) {
for (int j=0; j<op_count; j++) {
if((i == ncclFp8E4M3 || i == ncclFp8E5M2) && j == ncclProd)
if((i == ncclFloat8e4m3 || i == ncclFloat8e5m2) && j == ncclProd)
continue;
TESTCHECK(TimeTest(args, run_types[i], run_typenames[i], run_ops[j], run_opnames[j], -1));
}
Expand Down
20 changes: 17 additions & 3 deletions src/common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ size_t cache_bytes = 192 * 1024 * 1024; // Use 192MB
, ncclBfloat16
#endif
#if RCCL_FLOAT8 == 1
, ncclFp8E4M3, ncclFp8E5M2
, ncclFloat8e4m3, ncclFloat8e5m2
#endif
};
const char *test_typenames[ncclNumTypes] = {
Expand Down Expand Up @@ -115,6 +115,13 @@ static int enable_rotating_tensor = 0;
static int local_register = 0;
#endif

// RCCL_FLOAT8 support
bool rccl_float8_useFnuz = false;
bool IsArchMatch(char const* arch, char const* target) {
// helper function to reduce clutter in code elsewhere. Returns true on match.
return (strncmp(arch, target, strlen(target)) == 0);
}

Reporter::Reporter(std::string fileName, std::string outputFormat) : _outputFormat(outputFormat) {
if (!fileName.empty()) {
if (isMainThread()) {
Expand Down Expand Up @@ -557,8 +564,8 @@ testResult_t startColl(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
case ncclBfloat16: bf16 = ncclVerifiablePremulScalar<hip_bfloat16>(rank); break;
#endif
#if defined(RCCL_FLOAT8)
case ncclFp8E4M3: fp8_e4m3 = ncclVerifiablePremulScalar<rccl_float8>(rank); break;
case ncclFp8E5M2: fp8_e5m2 = ncclVerifiablePremulScalar<rccl_bfloat8>(rank); break;
case ncclFloat8e4m3: fp8_e4m3 = ncclVerifiablePremulScalar<rccl_float8>(rank); break;
case ncclFloat8e5m2: fp8_e5m2 = ncclVerifiablePremulScalar<rccl_bfloat8>(rank); break;
Comment on lines +567 to +568
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing if/else. Likely will need to add fp8_e4m3_fnuz and fp8_e5m2_fnuz fields to the union.

#endif
case ncclNumTypes: break;
}
Expand Down Expand Up @@ -1290,6 +1297,13 @@ testResult_t run() {
char hostname[1024];
getHostName(hostname, 1024);

hipDeviceProp_t devProp;
CUDACHECK(hipGetDeviceProperties(&devProp, 0));
if (IsArchMatch(devProp.gcnArchName, "gfx942")) {
PRINT("On gfx942 architecture, using FNUZ FP8 types");
rccl_float8_useFnuz = true;
}

#ifdef MPI_SUPPORT
MPI_Comm_size(MPI_COMM_WORLD, &totalProcs);
MPI_Comm_rank(MPI_COMM_WORLD, &proc);
Expand Down
4 changes: 2 additions & 2 deletions src/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,8 @@ static size_t wordSize(ncclDataType_t type) {
//case ncclInt8:
case ncclUint8:
#if NCCL_MAJOR >= 2 && RCCL_FLOAT8 == 1
case ncclFp8E4M3:
case ncclFp8E5M2:
case ncclFloat8e4m3:
case ncclFloat8e5m2:
#endif
#endif
return 1;
Expand Down
Loading