Skip to content

Commit

Permalink
Add float8_e4m3
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov committed Aug 1, 2024
1 parent b157c19 commit ab2fbdc
Show file tree
Hide file tree
Showing 8 changed files with 361 additions and 9 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
an alternative to the standard [`float16`](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) format
- `float8_*`: several experimental 8-bit floating point representations
including:
* `float8_e4m3`
* `float8_e4m3b11fnuz`
* `float8_e4m3fn`
* `float8_e4m3fnuz`
Expand Down Expand Up @@ -64,6 +65,10 @@ A `bfloat16` number is a single-precision float truncated at 16 bits.

Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.

### `float8_e4m3`

Exponent: 4, Mantissa: 3, bias: 7. IEEE 754, with NaN and inf.

### `float8_e4m3b11fnuz`

Exponent: 4, Mantissa: 3, bias: 11.
Expand Down
3 changes: 3 additions & 0 deletions ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"__version__",
"bfloat16",
"finfo",
"float8_e4m3",
"float8_e4m3b11fnuz",
"float8_e4m3fn",
"float8_e4m3fnuz",
Expand All @@ -34,6 +35,7 @@
from ml_dtypes._finfo import finfo
from ml_dtypes._iinfo import iinfo
from ml_dtypes._ml_dtypes_ext import bfloat16
from ml_dtypes._ml_dtypes_ext import float8_e4m3
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
from ml_dtypes._ml_dtypes_ext import float8_e4m3fn
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
Expand All @@ -46,6 +48,7 @@
import numpy as np

bfloat16: Type[np.generic]
float8_e4m3: Type[np.generic]
float8_e4m3b11fnuz: Type[np.generic]
float8_e4m3fn: Type[np.generic]
float8_e4m3fnuz: Type[np.generic]
Expand Down
63 changes: 63 additions & 0 deletions ml_dtypes/_finfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Dict

from ml_dtypes._ml_dtypes_ext import bfloat16
from ml_dtypes._ml_dtypes_ext import float8_e4m3
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
from ml_dtypes._ml_dtypes_ext import float8_e4m3fn
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
Expand All @@ -25,6 +26,7 @@
import numpy as np

_bfloat16_dtype = np.dtype(bfloat16)
_float8_e4m3_dtype = np.dtype(float8_e4m3)
_float8_e4m3b11fnuz_dtype = np.dtype(float8_e4m3b11fnuz)
_float8_e4m3fn_dtype = np.dtype(float8_e4m3fn)
_float8_e4m3fnuz_dtype = np.dtype(float8_e4m3fnuz)
Expand All @@ -41,6 +43,14 @@ def __init__(self):
self.smallest_subnormal = bfloat16(smallest_subnormal)


class _Float8E4m3MachArLike:

def __init__(self):
smallest_normal = float.fromhex("0x1p-6")
self.smallest_normal = float8_e4m3(smallest_normal)
smallest_subnormal = float.fromhex("0x1p-9")
self.smallest_subnormal = float8_e4m3(smallest_subnormal)

class _Float8E4m3b11fnuzMachArLike:

def __init__(self):
Expand Down Expand Up @@ -135,6 +145,51 @@ def float_to_str(f):
# pylint: enable=protected-access
return obj

@staticmethod
def _float8_e4m3_finfo():
def float_to_str(f):
return "%6.2e" % float(f)

tiny = float.fromhex("0x1p-6") # 1/64 min normal
resolution = 0.1
eps = float.fromhex("0x1p-3") # 1/8
epsneg = float.fromhex("0x1p-4") # 1/16
max_ = float.fromhex("0x1.Ep7") # 240 max normal

obj = object.__new__(np.finfo)
obj.dtype = _float8_e4m3_dtype
obj.bits = 8
obj.eps = float8_e4m3(eps)
obj.epsneg = float8_e4m3(epsneg)
obj.machep = -3
obj.negep = -4
obj.max = float8_e4m3(max_)
obj.min = float8_e4m3(-max_)
obj.nexp = 4
obj.nmant = 3
obj.iexp = obj.nexp
obj.maxexp = 8
obj.minexp = -6
obj.precision = 1
obj.resolution = float8_e4m3(resolution)
# pylint: disable=protected-access
obj._machar = _Float8E4m3MachArLike()
if not hasattr(obj, "tiny"):
obj.tiny = float8_e4m3(tiny)
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = obj._machar.smallest_normal
obj.smallest_subnormal = obj._machar.smallest_subnormal

obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
obj._str_max = float_to_str(max_)
obj._str_epsneg = float_to_str(epsneg)
obj._str_eps = float_to_str(eps)
obj._str_resolution = float_to_str(resolution)
# pylint: enable=protected-access
return obj

@staticmethod
def _float8_e4m3b11fnuz_finfo():
def float_to_str(f):
Expand Down Expand Up @@ -369,6 +424,14 @@ def __new__(cls, dtype):
if _bfloat16_dtype not in cls._finfo_cache:
cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo()
return cls._finfo_cache[_bfloat16_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e4m3"
or dtype == _float8_e4m3_dtype
):
if _float8_e4m3_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e4m3_dtype] = cls._float8_e4m3_finfo()
return cls._finfo_cache[_float8_e4m3_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e4m3b11fnuz"
Expand Down
30 changes: 30 additions & 0 deletions ml_dtypes/_src/dtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,22 @@ struct TypeDescriptor<bfloat16> : CustomFloatType<bfloat16> {
static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor<float8_e4m3> : CustomFloatType<float8_e4m3> {
typedef float8_e4m3 T;
static constexpr bool is_floating = true;
static constexpr bool is_integral = false;
static constexpr const char* kTypeName = "float8_e4m3";
static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e4m3";
static constexpr const char* kTpDoc = "float8_e4m3 floating-point values";
// Treating e4m3 as the natural "float" type since it is IEEE-754 compliant.
static constexpr char kNpyDescrKind = 'f';
// TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
// character is unique.
static constexpr char kNpyDescrType = '4';
static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor<float8_e4m3b11fnuz>
: CustomFloatType<float8_e4m3b11fnuz> {
Expand Down Expand Up @@ -269,6 +285,9 @@ bool Initialize() {
if (!RegisterFloatDtype<bfloat16>(numpy.get())) {
return false;
}
if (!RegisterFloatDtype<float8_e4m3>(numpy.get())) {
return false;
}
if (!RegisterFloatDtype<float8_e4m3b11fnuz>(numpy.get())) {
return false;
}
Expand Down Expand Up @@ -319,6 +338,12 @@ bool Initialize() {
success &= RegisterTwoWayCustomCast<float8_e5m2fnuz, float8_e4m3fn, float>();
success &= RegisterTwoWayCustomCast<float8_e4m3fnuz, float8_e5m2, float>();
success &= RegisterTwoWayCustomCast<float8_e5m2fnuz, float8_e5m2, float>();
success &= RegisterTwoWayCustomCast<float8_e4m3, bfloat16, float>();
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e4m3b11fnuz, float>();
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e5m2fnuz, float>();
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e4m3fnuz, float>();
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e4m3fn, float>();
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e5m2, float>();
success &= RegisterOneWayCustomCast<int2, int4, int8_t>();
success &= RegisterOneWayCustomCast<uint2, uint4, uint8_t>();
return success;
Expand Down Expand Up @@ -349,6 +374,11 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() {
return nullptr;
}

if (PyObject_SetAttrString(m.get(), "float8_e4m3",
reinterpret_cast<PyObject*>(
TypeDescriptor<float8_e4m3>::type_ptr)) < 0) {
return nullptr;
}
if (PyObject_SetAttrString(
m.get(), "float8_e4m3b11fnuz",
reinterpret_cast<PyObject*>(
Expand Down
Loading

0 comments on commit ab2fbdc

Please sign in to comment.