diff --git a/verifiable/verifiable.cu b/verifiable/verifiable.cu index e875c32..37dba14 100644 --- a/verifiable/verifiable.cu +++ b/verifiable/verifiable.cu @@ -392,7 +392,7 @@ struct FloatLayout { }; #endif #if RCCL_FLOAT8 == 1 -#if __HIP_DEVICE_COMPILE__ || HIP_VERSION < 60300000 +#if __HIP_DEVICE_COMPILE__ || (HIP_VERSION >= 60200000 && HIP_VERSION < 60300000) template<> struct FloatLayout { static constexpr bool is_floating_point = true; @@ -993,7 +993,7 @@ cudaError_t prepareInput1( #if HAVE_ncclBfloat16 case ncclBfloat16: fn = (void const*)&prepareInput2; break; #endif - #if HAVE_ncclfp8_DEVICE || HIP_VERSION < 60300000 + #if HAVE_ncclfp8_DEVICE || (HIP_VERSION >= 60200000 && HIP_VERSION < 60300000) case ncclFloat8e4m3: fn = (void const*)&prepareInput2; break; case ncclFloat8e5m2: fn = (void const*)&prepareInput2; break; #elif HAVE_ncclfp8_HOST @@ -1083,7 +1083,7 @@ cudaError_t prepareExpected1( #if HAVE_ncclBfloat16 case ncclBfloat16: fn = (void const*)&prepareExpected2; break; #endif - #if HAVE_ncclfp8_DEVICE || HIP_VERSION < 60300000 //for backward compatibility + #if HAVE_ncclfp8_DEVICE || (HIP_VERSION >= 60200000 && HIP_VERSION < 60300000) case ncclFloat8e4m3: fn = (void const*)&prepareExpected2; break; case ncclFloat8e5m2: fn = (void const*)&prepareExpected2; break; #elif HAVE_ncclfp8_HOST @@ -1321,7 +1321,7 @@ hipError_t ncclVerifiableVerify( #if HAVE_ncclBfloat16 case ncclBfloat16: CASE_TY(hip_bfloat16, uint16_t) #endif - #if HAVE_ncclfp8_DEVICE || HIP_VERSION < 60300000 + #if HAVE_ncclfp8_DEVICE || (HIP_VERSION >= 60200000 && HIP_VERSION < 60300000) case ncclFloat8e4m3: CASE_TY(rccl_float8, uint8_t) case ncclFloat8e5m2: CASE_TY(rccl_bfloat8, uint8_t) #elif HAVE_ncclfp8_HOST