Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/all_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ testResult_t AllReduceRunTest(struct threadArgs* args, int root, ncclDataType_t
ncclRedOp_t *run_ops;
const char **run_typenames, **run_opnames;
int type_count, op_count;
if((type == ncclFp8E4M3 || type == ncclFp8E5M2) && op == ncclProd)
if((type == ncclFloat8e4m3 || type == ncclFloat8e5m2) && op == ncclProd)
return testSuccess;

if ((int)type != -1) {
Expand All @@ -90,7 +90,7 @@ testResult_t AllReduceRunTest(struct threadArgs* args, int root, ncclDataType_t

for (int i=0; i<type_count; i++) {
for (int j=0; j<op_count; j++) {
if((i == ncclFp8E4M3 || i == ncclFp8E5M2) && j == ncclProd)
if((i == ncclFloat8e4m3 || i == ncclFloat8e5m2) && j == ncclProd)
continue;
TESTCHECK(TimeTest(args, run_types[i], run_typenames[i], run_ops[j], run_opnames[j], -1));
}
Expand Down
20 changes: 17 additions & 3 deletions src/common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ size_t cache_bytes = 192 * 1024 * 1024; // Use 192MB
, ncclBfloat16
#endif
#if RCCL_FLOAT8 == 1
, ncclFp8E4M3, ncclFp8E5M2
, ncclFloat8e4m3, ncclFloat8e5m2
#endif
};
const char *test_typenames[ncclNumTypes] = {
Expand Down Expand Up @@ -115,6 +115,13 @@ static int enable_rotating_tensor = 0;
static int local_register = 0;
#endif

// RCCL_FLOAT8 support
bool rccl_float8_useFnuz = false;
bool IsArchMatch(char const* arch, char const* target) {
// helper function to reduce clutter in code elsewhere. Returns true on match.
return (strncmp(arch, target, strlen(target)) == 0);
}

Reporter::Reporter(std::string fileName, std::string outputFormat) : _outputFormat(outputFormat) {
if (!fileName.empty()) {
if (isMainThread()) {
Expand Down Expand Up @@ -557,8 +564,8 @@ testResult_t startColl(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
case ncclBfloat16: bf16 = ncclVerifiablePremulScalar<hip_bfloat16>(rank); break;
#endif
#if defined(RCCL_FLOAT8)
case ncclFp8E4M3: fp8_e4m3 = ncclVerifiablePremulScalar<rccl_float8>(rank); break;
case ncclFp8E5M2: fp8_e5m2 = ncclVerifiablePremulScalar<rccl_bfloat8>(rank); break;
case ncclFloat8e4m3: fp8_e4m3 = ncclVerifiablePremulScalar<rccl_float8>(rank); break;
case ncclFloat8e5m2: fp8_e5m2 = ncclVerifiablePremulScalar<rccl_bfloat8>(rank); break;
Comment on lines +567 to +568
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing if/else. Likely will need to add fp8_e4m3_fnuz and fp8_e5m2_fnuz fields to the union.

#endif
case ncclNumTypes: break;
}
Expand Down Expand Up @@ -1290,6 +1297,13 @@ testResult_t run() {
char hostname[1024];
getHostName(hostname, 1024);

hipDeviceProp_t devProp;
CUDACHECK(hipGetDeviceProperties(&devProp, 0));
if (IsArchMatch(devProp.gcnArchName, "gfx942")) {
PRINT("On gfx942 architecture, using FNUZ FP8 types");
rccl_float8_useFnuz = true;
}

#ifdef MPI_SUPPORT
MPI_Comm_size(MPI_COMM_WORLD, &totalProcs);
MPI_Comm_rank(MPI_COMM_WORLD, &proc);
Expand Down
4 changes: 2 additions & 2 deletions src/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,8 @@ static size_t wordSize(ncclDataType_t type) {
//case ncclInt8:
case ncclUint8:
#if NCCL_MAJOR >= 2 && RCCL_FLOAT8 == 1
case ncclFp8E4M3:
case ncclFp8E5M2:
case ncclFloat8e4m3:
case ncclFloat8e5m2:
#endif
#endif
return 1;
Expand Down
Loading