@@ -59,9 +59,9 @@ class float8_base {
59
59
public:
60
60
constexpr float8_base () : rep_(0 ) {}
61
61
62
- template <typename T,
63
- typename EnableIf = std::enable_if<std::is_arithmetic_v<T>>>
64
- explicit EIGEN_DEVICE_FUNC float8_base ( T f)
62
+ template <typename T>
63
+ explicit EIGEN_DEVICE_FUNC float8_base (
64
+ T f, std:: enable_if_t <std::is_arithmetic_v<T>, int > = 0 )
65
65
: float8_base(ConvertFrom(static_cast <float >(f)).rep(),
66
66
ConstructFromRepTag{}) {}
67
67
explicit EIGEN_DEVICE_FUNC float8_base (double f64)
@@ -239,6 +239,10 @@ class float8_base {
239
239
uint8_t rep_;
240
240
};
241
241
242
+ template <typename T>
243
+ using RequiresIsDerivedFromFloat8Base =
244
+ std::enable_if_t <std::is_base_of_v<float8_base<T>, T>, int >;
245
+
242
246
class float8_e4m3fn : public float8_base <float8_e4m3fn> {
243
247
// Exponent: 4, Mantissa: 3, bias: 7.
244
248
// Extended range: no inf, NaN represented by 0bS111'1111.
@@ -252,9 +256,8 @@ class float8_e4m3fn : public float8_base<float8_e4m3fn> {
252
256
using Base::Base;
253
257
254
258
public:
255
- explicit EIGEN_DEVICE_FUNC float8_e4m3fn (const float8_e5m2& f8)
256
- : float8_e4m3fn(ConvertFrom(f8)) {}
257
- explicit EIGEN_DEVICE_FUNC float8_e4m3fn (const float8_e4m3b11fnuz& f8)
259
+ template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0 >
260
+ explicit EIGEN_DEVICE_FUNC float8_e4m3fn (T f8)
258
261
: float8_e4m3fn(ConvertFrom(f8)) {}
259
262
};
260
263
@@ -267,13 +270,8 @@ class float8_e4m3b11fnuz : public float8_base<float8_e4m3b11fnuz> {
267
270
using Base::Base;
268
271
269
272
public:
270
- explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz (const float8_e5m2& f8)
271
- : float8_e4m3b11fnuz(ConvertFrom(f8)) {}
272
- explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz (const float8_e5m2fnuz& f8)
273
- : float8_e4m3b11fnuz(ConvertFrom(f8)) {}
274
- explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz (const float8_e4m3fn& f8)
275
- : float8_e4m3b11fnuz(ConvertFrom(f8)) {}
276
- explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz (const float8_e4m3fnuz& f8)
273
+ template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0 >
274
+ explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz (T f8)
277
275
: float8_e4m3b11fnuz(ConvertFrom(f8)) {}
278
276
279
277
constexpr float8_e4m3b11fnuz operator -() const {
@@ -315,13 +313,8 @@ class float8_e4m3fnuz : public float8_base<float8_e4m3fnuz> {
315
313
using Base::Base;
316
314
317
315
public:
318
- explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz (const float8_e5m2& f8)
319
- : float8_e4m3fnuz(ConvertFrom(f8)) {}
320
- explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz (const float8_e5m2fnuz& f8)
321
- : float8_e4m3fnuz(ConvertFrom(f8)) {}
322
- explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz (const float8_e4m3b11fnuz& f8)
323
- : float8_e4m3fnuz(ConvertFrom(f8)) {}
324
- explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz (const float8_e4m3fn& f8)
316
+ template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0 >
317
+ explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz (T f8)
325
318
: float8_e4m3fnuz(ConvertFrom(f8)) {}
326
319
327
320
constexpr float8_e4m3fnuz operator -() const {
@@ -347,13 +340,8 @@ class float8_e5m2 : public float8_base<float8_e5m2> {
347
340
using Base::Base;
348
341
349
342
public:
350
- explicit EIGEN_DEVICE_FUNC float8_e5m2 (float8_e4m3fn f8)
351
- : float8_e5m2(ConvertFrom(f8)) {}
352
- explicit EIGEN_DEVICE_FUNC float8_e5m2 (float8_e4m3fnuz f8)
353
- : float8_e5m2(ConvertFrom(f8)) {}
354
- explicit EIGEN_DEVICE_FUNC float8_e5m2 (float8_e4m3b11fnuz f8)
355
- : float8_e5m2(ConvertFrom(f8)) {}
356
- explicit EIGEN_DEVICE_FUNC float8_e5m2 (float8_e5m2fnuz& f8)
343
+ template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0 >
344
+ explicit EIGEN_DEVICE_FUNC float8_e5m2 (T f8)
357
345
: float8_e5m2(ConvertFrom(f8)) {}
358
346
};
359
347
0 commit comments