diff --git a/src/all_reduce.cu b/src/all_reduce.cu index 5302f86..6f2fa35 100644 --- a/src/all_reduce.cu +++ b/src/all_reduce.cu @@ -65,7 +65,7 @@ testResult_t AllReduceRunTest(struct threadArgs* args, int root, ncclDataType_t ncclRedOp_t *run_ops; const char **run_typenames, **run_opnames; int type_count, op_count; - if((type == ncclFp8E4M3 || type == ncclFp8E5M2) && op == ncclProd) + if((type == ncclFloat8e4m3 || type == ncclFloat8e5m2) && op == ncclProd) return testSuccess; if ((int)type != -1) { @@ -90,7 +90,7 @@ testResult_t AllReduceRunTest(struct threadArgs* args, int root, ncclDataType_t for (int i=0; i(rank); break; #endif #if defined(RCCL_FLOAT8) - case ncclFp8E4M3: fp8_e4m3 = ncclVerifiablePremulScalar(rank); break; - case ncclFp8E5M2: fp8_e5m2 = ncclVerifiablePremulScalar(rank); break; + case ncclFloat8e4m3: fp8_e4m3 = ncclVerifiablePremulScalar(rank); break; + case ncclFloat8e5m2: fp8_e5m2 = ncclVerifiablePremulScalar(rank); break; #endif case ncclNumTypes: break; } @@ -1290,6 +1297,13 @@ testResult_t run() { char hostname[1024]; getHostName(hostname, 1024); + hipDeviceProp_t devProp; + CUDACHECK(hipGetDeviceProperties(&devProp, 0)); + if (IsArchMatch(devProp.gcnArchName, "gfx942")) { + PRINT("On gfx942 architecture, using FNUZ FP8 types"); + rccl_float8_useFnuz = true; + } + #ifdef MPI_SUPPORT MPI_Comm_size(MPI_COMM_WORLD, &totalProcs); MPI_Comm_rank(MPI_COMM_WORLD, &proc); diff --git a/src/common.h b/src/common.h index 2f2082c..ddbb4ea 100644 --- a/src/common.h +++ b/src/common.h @@ -250,8 +250,8 @@ static size_t wordSize(ncclDataType_t type) { //case ncclInt8: case ncclUint8: #if NCCL_MAJOR >= 2 && RCCL_FLOAT8 == 1 - case ncclFp8E4M3: - case ncclFp8E5M2: + case ncclFloat8e4m3: + case ncclFloat8e5m2: #endif #endif return 1; diff --git a/src/rccl_float8.h b/src/rccl_float8.h index 01cab41..88af37c 100644 --- a/src/rccl_float8.h +++ b/src/rccl_float8.h @@ -39,6 +39,12 @@ typedef struct } rccl_bfloat8; #else // __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__)) +#if HIP_VERSION >= 60200000 +#include +#else +#define HIP_FP8_TYPE_OCP 0 +#define HIP_FP8_TYPE_FNUZ 0 +#endif #define HIP_HOST_DEVICE __host__ __device__ #define HIP_HOST __host__ @@ -332,7 +338,8 @@ namespace rocblas_hip_f8_impl static __device__ bool rocblas_hip_f8_bias_mode_bit_device = true; static bool rocblas_hip_f8_bias_mode_bit_host = true; -struct rccl_float8 +template +struct rccl_float8_bc { uint8_t data; enum class rocblas_hip_f8_rounding_mode @@ -342,9 +349,9 @@ struct rccl_float8 }; // default constructor - HIP_HOST_DEVICE rccl_float8() = default; + HIP_HOST_DEVICE rccl_float8_bc() = default; -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) // device specific optimized F8 down-conversion code template @@ -381,13 +388,13 @@ struct rccl_float8 return i8data; } -#endif // __gfx940__ +#endif // __gfx942__ // constructor from float -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) // NOTE: ON-DEVICE... always optimal bias - explicit HIP_DEVICE rccl_float8(float v, + explicit HIP_DEVICE rccl_float8_bc(float v, rocblas_hip_f8_rounding_mode rm = rocblas_hip_f8_rounding_mode::standard, uint32_t rng = 0) @@ -402,10 +409,10 @@ struct rccl_float8 // Host only implementation using s/w simulation explicit HIP_HOST #else - // both Host and DEVICE for non-gfx940 using s/w simulation + // both Host and DEVICE for non-gfx942 using s/w simulation explicit HIP_HOST_DEVICE #endif - rccl_float8(float v, + rccl_float8_bc(float v, rocblas_hip_f8_rounding_mode rm = rocblas_hip_f8_rounding_mode::standard, uint32_t rng = 0) { @@ -421,32 +428,32 @@ struct rccl_float8 } // Constructor from half - explicit HIP_HOST_DEVICE rccl_float8(_Float16 v, + explicit HIP_HOST_DEVICE rccl_float8_bc(_Float16 v, rocblas_hip_f8_rounding_mode rm = rocblas_hip_f8_rounding_mode::standard, uint32_t rng = 0) - : rccl_float8((float)v, rm, rng) + : rccl_float8_bc((float)v, rm, rng) { } // constructor from int - explicit HIP_HOST_DEVICE rccl_float8(int v, + explicit HIP_HOST_DEVICE rccl_float8_bc(int v, rocblas_hip_f8_rounding_mode rm = rocblas_hip_f8_rounding_mode::standard, uint32_t rng = 0) - : rccl_float8((float)v, rm, rng) + : rccl_float8_bc((float)v, rm, rng) { } // constructor from double - explicit HIP_HOST_DEVICE rccl_float8(double v, + explicit HIP_HOST_DEVICE rccl_float8_bc(double v, rocblas_hip_f8_rounding_mode rm = rocblas_hip_f8_rounding_mode::standard, uint32_t rng = 0) - : rccl_float8((float)v, rm, rng) + : rccl_float8_bc((float)v, rm, rng) { } // convert to float -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) // upcast using device specific intrinsic explicit inline HIP_DEVICE operator float() const { @@ -460,7 +467,7 @@ struct rccl_float8 } explicit inline HIP_HOST operator float() const -#else // non gfx940 +#else // non gfx942 explicit inline HIP_HOST_DEVICE operator float() const #endif { @@ -492,14 +499,15 @@ struct rccl_float8 } // assignment overloading only from the same F8 types - inline __host__ __device__ rccl_float8& operator=(const rccl_float8& a) + inline __host__ __device__ rccl_float8_bc& operator=(const rccl_float8_bc& a) { data = a.data; return *this; } }; -struct rccl_bfloat8 +template +struct rccl_bfloat8_bc { uint8_t data; enum class rocblas_hip_f8_rounding_mode @@ -509,9 +517,9 @@ struct rccl_bfloat8 }; // default constructor - HIP_HOST_DEVICE rccl_bfloat8() = default; + HIP_HOST_DEVICE rccl_bfloat8_bc() = default; -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) // device specific optimized F8 down-conversion code template @@ -548,13 +556,13 @@ struct rccl_bfloat8 return i8data; } -#endif // __gfx940__ +#endif // __gfx942__ // constructor from float -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) // NOTE: ON-DEVICE... always optimal bias - explicit HIP_DEVICE rccl_bfloat8(float v, + explicit HIP_DEVICE rccl_bfloat8_bc(float v, rocblas_hip_f8_rounding_mode rm = rocblas_hip_f8_rounding_mode::standard, uint32_t rng = 0) @@ -569,10 +577,10 @@ struct rccl_bfloat8 // Host only implementation using s/w simulation explicit HIP_HOST #else - // both Host and DEVICE for non-gfx940 using s/w simulation + // both Host and DEVICE for non-gfx942 using s/w simulation explicit HIP_HOST_DEVICE #endif - rccl_bfloat8(float v, + rccl_bfloat8_bc(float v, rocblas_hip_f8_rounding_mode rm = rocblas_hip_f8_rounding_mode::standard, uint32_t rng = 0) { @@ -588,32 +596,32 @@ struct rccl_bfloat8 } // Constructor from half - explicit HIP_HOST_DEVICE rccl_bfloat8(_Float16 v, + explicit HIP_HOST_DEVICE rccl_bfloat8_bc(_Float16 v, rocblas_hip_f8_rounding_mode rm = rocblas_hip_f8_rounding_mode::standard, uint32_t rng = 0) - : rccl_bfloat8((float)v, rm, rng) + : rccl_bfloat8_bc((float)v, rm, rng) { } // constructor from int - explicit HIP_HOST_DEVICE rccl_bfloat8(int v, + explicit HIP_HOST_DEVICE rccl_bfloat8_bc(int v, rocblas_hip_f8_rounding_mode rm = rocblas_hip_f8_rounding_mode::standard, uint32_t rng = 0) - : rccl_bfloat8((float)v, rm, rng) + : rccl_bfloat8_bc((float)v, rm, rng) { } // constructor from double - explicit HIP_HOST_DEVICE rccl_bfloat8(double v, + explicit HIP_HOST_DEVICE rccl_bfloat8_bc(double v, rocblas_hip_f8_rounding_mode rm = rocblas_hip_f8_rounding_mode::standard, uint32_t rng = 0) - : rccl_bfloat8((float)v, rm, rng) + : rccl_bfloat8_bc((float)v, rm, rng) { } // convert to float -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) // upcast using device specific intrinsic explicit inline HIP_DEVICE operator float() const { @@ -627,7 +635,7 @@ struct rccl_bfloat8 } explicit inline HIP_HOST operator float() const -#else // non gfx940 +#else // non gfx942 explicit inline HIP_HOST_DEVICE operator float() const #endif { @@ -659,363 +667,202 @@ struct rccl_bfloat8 } // assignment overloading only from the same F8 types - inline __host__ __device__ rccl_bfloat8& operator=(const rccl_bfloat8& a) + inline __host__ __device__ rccl_bfloat8_bc& operator=(const rccl_bfloat8_bc& a) { data = a.data; return *this; } }; -namespace std -{ - inline rccl_float8 sin(rccl_float8 a) - { - return rccl_float8(sinf(float(a))); - } - inline rccl_float8 cos(rccl_float8 a) - { - return rccl_float8(cosf(float(a))); - } - inline rccl_bfloat8 sin(rccl_bfloat8 a) - { - return rccl_bfloat8(sinf(float(a))); - } - inline rccl_bfloat8 cos(rccl_bfloat8 a) - { - return rccl_bfloat8(cosf(float(a))); - } - __device__ __host__ constexpr rccl_float8 real(const rccl_float8& a) - { - return a; - } - __device__ __host__ constexpr rccl_bfloat8 real(const rccl_bfloat8& a) - { - return a; - } -} - -// Special operator overloading -inline std::ostream& operator<<(std::ostream& os, const rccl_float8& f8) -{ - return os << float(f8); -} - -inline std::ostream& operator<<(std::ostream& os, const rccl_bfloat8& bf8) -{ - return os << float(bf8); -} - -// all + operator overloading with mixed types -// mixed types, always converts to f32, does computation in f32, and returns float -inline __host__ __device__ float operator+(const float fa, rccl_float8 b) -{ - return (fa + float(b)); -} - -inline __host__ __device__ float operator+(const float fa, rccl_bfloat8 b) -{ - return (fa + float(b)); -} - -inline __host__ __device__ float operator+(rccl_float8 a, const float fb) -{ - return (float(a) + fb); -} - -inline __host__ __device__ float operator+(rccl_bfloat8 a, const float fb) -{ - return (float(a) + fb); -} - -inline __host__ __device__ float operator+(rccl_float8 a, rccl_bfloat8 b) -{ - return (float(a) + float(b)); -} - -inline __host__ __device__ float operator+(rccl_bfloat8 a, rccl_float8 b) -{ - return (float(a) + float(b)); -} - -inline __host__ __device__ rccl_float8 operator+(rccl_float8 a, rccl_float8 b) -{ - return rccl_float8(float(a) + float(b)); -} - -inline __host__ __device__ rccl_bfloat8 operator+(rccl_bfloat8 a, rccl_bfloat8 b) -{ - return rccl_bfloat8(float(a) + float(b)); -} - -inline __host__ __device__ rccl_float8& operator+=(rccl_float8& a, rccl_float8 b) -{ - return a = rccl_float8(float(a) + float(b)); -} - -inline __host__ __device__ rccl_bfloat8& operator+=(rccl_bfloat8& a, rccl_bfloat8 b) -{ - return a = rccl_bfloat8(float(a) + float(b)); -} - -// overloading multiplication, always returns float, -inline __host__ __device__ float operator*(rccl_float8 a, rccl_float8 b) -{ - return float(a) * float(b); -} - -inline __host__ __device__ float operator*(float a, rccl_float8 b) -{ - return (a * float(b)); -} - -inline __host__ __device__ float operator*(rccl_float8 a, float b) -{ - return (float(a) * b); -} - -inline __host__ __device__ float operator*(int32_t a, rccl_float8 b) -{ - return ((float)a * float(b)); -} - -inline __host__ __device__ float operator*(double a, rccl_float8 b) -{ - return ((float)a * float(b)); -} - -inline __host__ __device__ float operator*(rccl_bfloat8 a, rccl_bfloat8 b) -{ - return float(a) * float(b); -} - -inline __host__ __device__ float operator*(float a, rccl_bfloat8 b) -{ - return (a * float(b)); -} - -inline __host__ __device__ float operator*(rccl_bfloat8 a, float b) -{ - return (float(a) * b); -} - -inline __host__ __device__ float operator*(int32_t a, rccl_bfloat8 b) -{ - return ((float)a * float(b)); -} - -inline __host__ __device__ float operator*(double a, rccl_bfloat8 b) -{ - return ((float)a * float(b)); -} - -// overloading for mixed f8 and bf8 types -inline __host__ __device__ float operator*(rccl_float8 a, rccl_bfloat8 b) -{ - return float(a) * float(b); -} - -inline __host__ __device__ float operator*(rccl_bfloat8 a, rccl_float8 b) -{ - return float(a) * float(b); -} - -// all - operator overloading with mixed types -// mixed types, always converts to f32, does computation in f32, and returns float -inline __host__ __device__ float operator-(const float fa, rccl_float8 b) -{ - return (fa - float(b)); -} - -inline __host__ __device__ float operator-(const float fa, rccl_bfloat8 b) -{ - return (fa - float(b)); -} - -inline __host__ __device__ float operator-(rccl_float8 a, const float fb) -{ - return (float(a) - fb); -} - -inline __host__ __device__ float operator-(rccl_bfloat8 a, const float fb) -{ - return (float(a) - fb); -} - -inline __host__ __device__ float operator-(rccl_float8 a, rccl_bfloat8 b) -{ - return (float(a) - float(b)); -} - -inline __host__ __device__ float operator-(rccl_bfloat8 a, rccl_float8 b) -{ - return (float(a) - float(b)); -} - -inline __host__ __device__ rccl_float8 operator-(rccl_float8 a, rccl_float8 b) -{ - return rccl_float8(float(a) - float(b)); -} - -inline __host__ __device__ rccl_bfloat8 operator-(rccl_bfloat8 a, rccl_bfloat8 b) -{ - return rccl_bfloat8(float(a) - float(b)); -} - -inline __host__ __device__ rccl_float8& operator-=(rccl_float8& a, rccl_float8 b) -{ - return a = rccl_float8(float(a) - float(b)); -} - -inline __host__ __device__ rccl_bfloat8& operator-=(rccl_bfloat8& a, rccl_bfloat8 b) -{ - return a = rccl_bfloat8(float(a) - float(b)); -} - -// overloading division, always returns float, -inline __host__ __device__ float operator/(rccl_float8 a, rccl_float8 b) -{ - return float(a) / float(b); -} - -inline __host__ __device__ float operator/(float a, rccl_float8 b) -{ - return (a / float(b)); -} - -inline __host__ __device__ float operator/(rccl_float8 a, float b) -{ - return (float(a) / b); -} - -inline __host__ __device__ float operator/(int32_t a, rccl_float8 b) -{ - return ((float)a / float(b)); -} - -inline __host__ __device__ float operator/(double a, rccl_float8 b) -{ - return ((float)a / float(b)); -} - -inline __host__ __device__ float operator/(rccl_bfloat8 a, rccl_bfloat8 b) -{ - return float(a) / float(b); -} - -inline __host__ __device__ float operator/(float a, rccl_bfloat8 b) -{ - return (a / float(b)); -} - -inline __host__ __device__ float operator/(rccl_bfloat8 a, float b) -{ - return (float(a) / b); -} - -inline __host__ __device__ float operator/(int32_t a, rccl_bfloat8 b) -{ - return ((float)a / float(b)); -} - -inline __host__ __device__ float operator/(double a, rccl_bfloat8 b) -{ - return ((float)a / float(b)); -} - -// overloading for mixed f8 and bf8 types -inline __host__ __device__ float operator/(rccl_float8 a, rccl_bfloat8 b) -{ - return float(a) / float(b); -} - -inline __host__ __device__ float operator/(rccl_bfloat8 a, rccl_float8 b) -{ - return float(a) / float(b); -} - // overloading for compare -inline __host__ __device__ bool operator==(rccl_float8 a, rccl_float8 b) +template +inline __host__ __device__ bool operator==(rccl_float8_bc a, rccl_float8_bc b) { return (a.data == b.data); } -inline __host__ __device__ bool operator==(rccl_bfloat8 a, rccl_bfloat8 b) +template +inline __host__ __device__ bool operator==(rccl_bfloat8_bc a, rccl_bfloat8_bc b) { return (a.data == b.data); } -inline __host__ __device__ bool operator!=(rccl_float8 a, rccl_float8 b) +template +inline __host__ __device__ bool operator!=(rccl_float8_bc a, rccl_float8_bc b) { return (a.data != b.data); } -inline __host__ __device__ bool operator!=(rccl_bfloat8 a, rccl_bfloat8 b) +template +inline __host__ __device__ bool operator!=(rccl_bfloat8_bc a, rccl_bfloat8_bc b) { return (a.data != b.data); } -// ================ Explicit downcasting to support different rounding (RNE, SR) =============== -// NOTE: we going to remove all assignment operator overloading from other types and enforce -// this explicit_downcast function to make any roudning behavior default -// We have to explicitly call this function with SR flag - -template {}, int>::type = 0> -inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng = 0) -{ - // same type, no conversion - return a; -} +#if HIP_FP8_TYPE_OCP +typedef __hip_fp8_e4m3 rccl_float8; +typedef __hip_fp8_e5m2 rccl_bfloat8; +#else +typedef rccl_float8_bc rccl_float8; +typedef rccl_bfloat8_bc rccl_bfloat8; +#endif -// Use h/w intrinsic and optimized version when __gfx940__ -template < - typename T, - typename Ta, - bool stochastic_rounding, - typename std::enable_if<(!(std::is_same{}) - && (std::is_same{} || std::is_same{})), - int>::type - = 0> -inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng) -{ -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) - // NOTE: we are directly calling cast_to_f8_from_f32 instead of constructor to optimize away one runtime branch - T val; - if(std::is_same::value) - val.data = rccl_float8::cast_to_f8_from_f32(float(a), rng); - else - val.data = rccl_bfloat8::cast_to_bf8_from_f32(float(a), rng); - return val; -#else // non gfx940 - return T(float(a), - stochastic_rounding ? T::rocblas_hip_f8_rounding_mode::stochastic - : T::rocblas_hip_f8_rounding_mode::standard, - rng); -#endif // __gfx940__ -} +#if HIP_FP8_TYPE_FNUZ +typedef __hip_fp8_e4m3_fnuz rccl_float8_fnuz; +typedef __hip_fp8_e5m2_fnuz rccl_bfloat8_fnuz; +#else +typedef rccl_float8_bc rccl_float8_fnuz; +typedef rccl_bfloat8_bc rccl_bfloat8_fnuz; +#endif -// NOTE NOTE: The above code is good if we don't consider HIP-GEMM code and only consider the quantization -// However, if we need HIP-GEMM for fall-back, we would need explicit_cast handles Tacc=f32 to To=f16/bf16 conversion -template < - typename T, - typename Ta, - bool stochastic_rounding, - typename std::enable_if<(!(std::is_same{}) - && !(std::is_same{} || std::is_same{})), - int>::type - = 0> -inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng) -{ - // the return type is not a F8 types, no SR for those types - // not sure if we have direct conversion, so converting to float first - // no effect if the input type is float - return T(float(a)); -} +#define RCCL_FLOAT8_OPERATORS(TYPE) \ +/* Special operator overloading */ \ +inline std::ostream& operator<<(std::ostream& os, const TYPE& f8) \ +{ \ + return os << float(f8); \ +} \ +\ +/* all + operator overloading with mixed types */ \ +/* mixed types, always converts to f32, does computation in f32, and returns float */ \ +inline __host__ __device__ float operator+(const float fa, TYPE b) \ +{ \ + return (fa + float(b)); \ +} \ +\ +inline __host__ __device__ float operator+(TYPE a, const float fb) \ +{ \ + return (float(a) + fb); \ +} \ +\ +inline __host__ __device__ TYPE operator+(TYPE a, TYPE b) \ +{ \ + return TYPE(float(a) + float(b)); \ +} \ + \ +inline __host__ __device__ TYPE& operator+=(TYPE& a, TYPE b) \ +{ \ + return a = TYPE(float(a) + float(b)); \ +} \ + \ +/* overloading multiplication, always returns float, */ \ +inline __host__ __device__ float operator*(TYPE a, TYPE b) \ +{ \ + return float(a) * float(b); \ +} \ + \ +inline __host__ __device__ float operator*(float a, TYPE b) \ +{ \ + return (a * float(b)); \ +} \ + \ +inline __host__ __device__ float operator*(TYPE a, float b) \ +{ \ + return (float(a) * b); \ +} \ + \ +inline __host__ __device__ float operator*(int32_t a, TYPE b) \ +{ \ + return ((float)a * float(b)); \ +} \ + \ +inline __host__ __device__ float operator*(double a, TYPE b) \ +{ \ + return ((float)a * float(b)); \ +} \ + \ +/* all - operator overloading with mixed types */ \ +/* mixed types, always converts to f32, does computation in f32, and returns float */ \ +inline __host__ __device__ float operator-(const float fa, TYPE b) \ +{ \ + return (fa - float(b)); \ +} \ + \ +inline __host__ __device__ float operator-(TYPE a, const float fb) \ +{ \ + return (float(a) - fb); \ +} \ + \ +inline __host__ __device__ TYPE operator-(TYPE a, TYPE b) \ +{ \ + return TYPE(float(a) - float(b)); \ +} \ + \ +inline __host__ __device__ TYPE& operator-=(TYPE& a, TYPE b) \ +{ \ + return a = TYPE(float(a) - float(b)); \ +} \ + \ +/* overloading division, always returns float, */ \ +inline __host__ __device__ float operator/(TYPE a, TYPE b) \ +{ \ + return float(a) / float(b); \ +} \ + \ +inline __host__ __device__ float operator/(float a, TYPE b) \ +{ \ + return (a / float(b)); \ +} \ + \ +inline __host__ __device__ float operator/(TYPE a, float b) \ +{ \ + return (float(a) / b); \ +} \ + \ +inline __host__ __device__ float operator/(int32_t a, TYPE b) \ +{ \ + return ((float)a / float(b)); \ +} \ + \ +inline __host__ __device__ float operator/(double a, TYPE b) \ +{ \ + return ((float)a / float(b)); \ +} + +#define RCCL_FLOAT8_MIXED_OPERATORS_1(TYPE1, TYPE2) \ +/* overloading for mixed f8 and bf8 types */ \ +inline __host__ __device__ float operator*(TYPE1 a, TYPE2 b) \ +{ \ + return float(a) * float(b); \ +} \ + \ +inline __host__ __device__ float operator+(TYPE1 a, TYPE2 b) \ +{ \ + return (float(a) + float(b)); \ +} \ + \ +inline __host__ __device__ float operator-(TYPE1 a, TYPE2 b) \ +{ \ + return (float(a) - float(b)); \ +} \ + \ +inline __host__ __device__ float operator/(TYPE1 a, TYPE2 b) \ +{ \ + return float(a) / float(b); \ +} + +#define RCCL_FLOAT8_MIXED_OPERATORS(TYPE1, TYPE2) \ +RCCL_FLOAT8_MIXED_OPERATORS_1(TYPE1, TYPE2) \ +RCCL_FLOAT8_MIXED_OPERATORS_1(TYPE2, TYPE1) + +RCCL_FLOAT8_OPERATORS(rccl_float8) +RCCL_FLOAT8_OPERATORS(rccl_bfloat8) +RCCL_FLOAT8_OPERATORS(rccl_float8_fnuz) +RCCL_FLOAT8_OPERATORS(rccl_bfloat8_fnuz) +RCCL_FLOAT8_MIXED_OPERATORS(rccl_float8, rccl_bfloat8) +RCCL_FLOAT8_MIXED_OPERATORS(rccl_float8, rccl_float8_fnuz) +RCCL_FLOAT8_MIXED_OPERATORS(rccl_float8, rccl_bfloat8_fnuz) +RCCL_FLOAT8_MIXED_OPERATORS(rccl_bfloat8, rccl_float8_fnuz) +RCCL_FLOAT8_MIXED_OPERATORS(rccl_bfloat8, rccl_bfloat8_fnuz) +RCCL_FLOAT8_MIXED_OPERATORS(rccl_float8_fnuz, rccl_bfloat8_fnuz) + +#undef RCCL_FLOAT8_OPERATORS +#undef RCCL_FLOAT8_MIXED_OPERATORS +#undef RCCL_FLOAT8_MIXED_OPERATORS_1 // ================================================================================================= +extern bool rccl_float8_useFnuz; + #endif // __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__)) #endif // ROCBLAS_FLOAT8_H diff --git a/verifiable/verifiable.cu b/verifiable/verifiable.cu index 32c13b0..c842cc2 100644 --- a/verifiable/verifiable.cu +++ b/verifiable/verifiable.cu @@ -357,6 +357,17 @@ struct FloatLayout { static constexpr int exponent_bits = 5, mantissa_bits = 2; static constexpr int exponent_bias = (1<<(exponent_bits-1))-1; }; + +template<> +struct FloatLayout { + static constexpr int exponent_bits = 4, mantissa_bits = 3; + static constexpr int exponent_bias = (1<<(exponent_bits-1)); +}; +template<> +struct FloatLayout { + static constexpr int exponent_bits = 5, mantissa_bits = 2; + static constexpr int exponent_bias = (1<<(exponent_bits-1)); +}; #endif template @@ -890,8 +901,10 @@ void prepareInput1( case ncclBfloat16: CASE_TY(hip_bfloat16) #endif #if HAVE_ncclfp8 - case ncclFp8E4M3: CASE_TY(rccl_float8) - case ncclFp8E5M2: CASE_TY(rccl_bfloat8) + case ncclFloat8e4m3: if (rccl_float8_useFnuz) { CASE_TY(rccl_float8_fnuz);} + else { CASE_TY(rccl_float8);} + case ncclFloat8e5m2: if (rccl_float8_useFnuz) { CASE_TY(rccl_bfloat8_fnuz);} + else { CASE_TY(rccl_bfloat8);} #endif case ncclFloat32: CASE_TY(float) case ncclFloat64: CASE_TY(double) @@ -970,8 +983,10 @@ void prepareExpected1( case ncclBfloat16: CASE_TY(hip_bfloat16) #endif #if HAVE_ncclfp8 - case ncclFp8E4M3: CASE_TY(rccl_float8) - case ncclFp8E5M2: CASE_TY(rccl_bfloat8) + case ncclFloat8e4m3: if (rccl_float8_useFnuz) { CASE_TY(rccl_float8_fnuz);} + else { CASE_TY(rccl_float8);} + case ncclFloat8e5m2: if (rccl_float8_useFnuz) { CASE_TY(rccl_bfloat8_fnuz);} + else { CASE_TY(rccl_bfloat8);} #endif case ncclFloat32: CASE_TY(float) case ncclFloat64: CASE_TY(double) @@ -1044,8 +1059,8 @@ __host__ __device__ unsigned calcSumFloatTolerance(int rank_n, int elt_ty) { break; #endif #if HAVE_ncclfp8 - case ncclFp8E4M3: - case ncclFp8E5M2: + case ncclFloat8e4m3: + case ncclFloat8e5m2: power = .91f; coef = .66f; break; @@ -1175,8 +1190,8 @@ void ncclVerifiableVerify( floating |= elt_ty == ncclBfloat16; #endif #if HAVE_ncclfp8 - floating |= elt_ty == ncclFp8E4M3; - floating |= elt_ty == ncclFp8E5M2; + floating |= elt_ty == ncclFloat8e4m3; + floating |= elt_ty == ncclFloat8e5m2; #endif unsigned tolerance = 0; @@ -1207,8 +1222,10 @@ void ncclVerifiableVerify( case ncclBfloat16: CASE_TY(hip_bfloat16, uint16_t) #endif #if HAVE_ncclfp8 - case ncclFp8E4M3: CASE_TY(rccl_float8, uint8_t) - case ncclFp8E5M2: CASE_TY(rccl_bfloat8, uint8_t) + case ncclFloat8e4m3: if (rccl_float8_useFnuz) { CASE_TY(rccl_float8_fnuz, uint8_t);} + else { CASE_TY(rccl_float8, uint8_t);} + case ncclFloat8e5m2: if (rccl_float8_useFnuz) { CASE_TY(rccl_bfloat8_fnuz, uint8_t);} + else { CASE_TY(rccl_bfloat8, uint8_t);} #endif case ncclFloat32: CASE_TY(float, uint32_t) case ncclFloat64: CASE_TY(double, uint64_t)