Skip to content

Commit 161db24

Browse files
majnemerThe ml_dtypes Authors
authored and
The ml_dtypes Authors
committed
Allow the copy constructor to use copy-by-value as float8 types are quite small
No functional change is intended. PiperOrigin-RevId: 576576088
1 parent 348fd37 commit 161db24

File tree

1 file changed

+15
-27
lines changed

1 file changed

+15
-27
lines changed

ml_dtypes/include/float8.h

+15-27
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ class float8_base {
5959
public:
6060
constexpr float8_base() : rep_(0) {}
6161

62-
template <typename T,
63-
typename EnableIf = std::enable_if<std::is_arithmetic_v<T>>>
64-
explicit EIGEN_DEVICE_FUNC float8_base(T f)
62+
template <typename T>
63+
explicit EIGEN_DEVICE_FUNC float8_base(
64+
T f, std::enable_if_t<std::is_arithmetic_v<T>, int> = 0)
6565
: float8_base(ConvertFrom(static_cast<float>(f)).rep(),
6666
ConstructFromRepTag{}) {}
6767
explicit EIGEN_DEVICE_FUNC float8_base(double f64)
@@ -239,6 +239,10 @@ class float8_base {
239239
uint8_t rep_;
240240
};
241241

242+
template <typename T>
243+
using RequiresIsDerivedFromFloat8Base =
244+
std::enable_if_t<std::is_base_of_v<float8_base<T>, T>, int>;
245+
242246
class float8_e4m3fn : public float8_base<float8_e4m3fn> {
243247
// Exponent: 4, Mantissa: 3, bias: 7.
244248
// Extended range: no inf, NaN represented by 0bS111'1111.
@@ -252,9 +256,8 @@ class float8_e4m3fn : public float8_base<float8_e4m3fn> {
252256
using Base::Base;
253257

254258
public:
255-
explicit EIGEN_DEVICE_FUNC float8_e4m3fn(const float8_e5m2& f8)
256-
: float8_e4m3fn(ConvertFrom(f8)) {}
257-
explicit EIGEN_DEVICE_FUNC float8_e4m3fn(const float8_e4m3b11fnuz& f8)
259+
template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0>
260+
explicit EIGEN_DEVICE_FUNC float8_e4m3fn(T f8)
258261
: float8_e4m3fn(ConvertFrom(f8)) {}
259262
};
260263

@@ -267,13 +270,8 @@ class float8_e4m3b11fnuz : public float8_base<float8_e4m3b11fnuz> {
267270
using Base::Base;
268271

269272
public:
270-
explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(const float8_e5m2& f8)
271-
: float8_e4m3b11fnuz(ConvertFrom(f8)) {}
272-
explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(const float8_e5m2fnuz& f8)
273-
: float8_e4m3b11fnuz(ConvertFrom(f8)) {}
274-
explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(const float8_e4m3fn& f8)
275-
: float8_e4m3b11fnuz(ConvertFrom(f8)) {}
276-
explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(const float8_e4m3fnuz& f8)
273+
template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0>
274+
explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(T f8)
277275
: float8_e4m3b11fnuz(ConvertFrom(f8)) {}
278276

279277
constexpr float8_e4m3b11fnuz operator-() const {
@@ -315,13 +313,8 @@ class float8_e4m3fnuz : public float8_base<float8_e4m3fnuz> {
315313
using Base::Base;
316314

317315
public:
318-
explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(const float8_e5m2& f8)
319-
: float8_e4m3fnuz(ConvertFrom(f8)) {}
320-
explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(const float8_e5m2fnuz& f8)
321-
: float8_e4m3fnuz(ConvertFrom(f8)) {}
322-
explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(const float8_e4m3b11fnuz& f8)
323-
: float8_e4m3fnuz(ConvertFrom(f8)) {}
324-
explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(const float8_e4m3fn& f8)
316+
template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0>
317+
explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(T f8)
325318
: float8_e4m3fnuz(ConvertFrom(f8)) {}
326319

327320
constexpr float8_e4m3fnuz operator-() const {
@@ -347,13 +340,8 @@ class float8_e5m2 : public float8_base<float8_e5m2> {
347340
using Base::Base;
348341

349342
public:
350-
explicit EIGEN_DEVICE_FUNC float8_e5m2(float8_e4m3fn f8)
351-
: float8_e5m2(ConvertFrom(f8)) {}
352-
explicit EIGEN_DEVICE_FUNC float8_e5m2(float8_e4m3fnuz f8)
353-
: float8_e5m2(ConvertFrom(f8)) {}
354-
explicit EIGEN_DEVICE_FUNC float8_e5m2(float8_e4m3b11fnuz f8)
355-
: float8_e5m2(ConvertFrom(f8)) {}
356-
explicit EIGEN_DEVICE_FUNC float8_e5m2(float8_e5m2fnuz& f8)
343+
template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0>
344+
explicit EIGEN_DEVICE_FUNC float8_e5m2(T f8)
357345
: float8_e5m2(ConvertFrom(f8)) {}
358346
};
359347

0 commit comments

Comments
 (0)