Skip to content

Commit

Permalink
[RFC] Add f8E4M3 and f8E3M4 types support (#2486)
Browse files Browse the repository at this point in the history
### Summary
This is a proposal to add `Float8E4M3` and `Float8E3M4` floating point
types to StableHLO.
Feedback welcome, see [RFC: Float8E4M3 and
Float8E3M4](https://github.com/apivovarov/stablehlo/blob/rfc_f8E4M3_f8E3M4/rfcs/20240808-f8E4M3_f8E3M4.md)
for more details.

### References and Links
- LLVM [PR-97179](llvm/llvm-project#97179)
[APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118)
[MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698)
[APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
- LLVM [PR-101230](llvm/llvm-project#101230)
[MLIR] Add f8E3M4 IEEE 754 type (Merged)
- [RFC: FP8 in
StableHLO](https://github.com/openxla/stablehlo/blob/main/rfcs/20221031-fp8.md)
- [RFC: Float8E4M3FNUZ and
Float8E5M2FNUZ](https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md)
- StableHLO [PR-2482](#2482)
Add f8E4M3 and f8E3M4 types support
- [Amazon EC2 Trn1
Instances](https://aws.amazon.com/ec2/instance-types/trn1/)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add
float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add
float8_e3m4 (Merged)
- XLA [PR-16585](openxla/xla#16585) Add support
for float8_e4m3
  • Loading branch information
apivovarov authored Sep 3, 2024
1 parent 1456dfa commit d68ab07
Showing 1 changed file with 178 additions and 0 deletions.
178 changes: 178 additions & 0 deletions rfcs/20240808-f8E4M3_f8E3M4.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# RFC: Float8E4M3 and Float8E3M4

Status: In Review<br/>
Initial version: 8/8/2024<br/>
Last updated: 8/9/2024<br/>
Discussion thread: [PR-2486](https://github.com/openxla/stablehlo/pull/2486)
[RFC] Add f8E4M3 and f8E3M4 types support

## Summary

Amazon has proposed two new FP8 types, Float8E4M3 and Float8E3M4. These
types are implemented in commercially available hardware[^1], and added to MLIR
builtin types[^2]˒[^3] and LLVM APFloat[^4]˒[^5].

Both Float8E4M3 and Float8E3M4 follows IEEE 754 convention similar to existing
type Float8E5M2.

### Float8E4M3

8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits mantissa
following IEEE-754 conventions with bit layout S1E4M3.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Minimum stored exponent value: 1 (binary 0001)
- Maximum stored exponent value: 14 (binary 1110)
- Minimum unbiased exponent value: 1 − 7 = −6
- Maximum unbiased exponent value: 14 - 7 = 7
- Precision specifies the total number of bits used for the significand
(mantisa), including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Min exp (unbiased): -6
- Max exp (unbiased): 7
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Min normal number: S.0001.000 = +/-2^(1 - 7) x (1 + 0) = +/-2^(-6)
- Max normal number: S.1110.111 = +/-2^(14 - 7) x (1 + 7/8) = +/-240
- Min subnormal number: S.0000.001 = +/-2^(-6) x 1/8 = +/-2^(-9)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 7/8 = +/-2^(-9) x 7
```
#### Comparison of Float8E4M3FN and Float8E4M3
| |Float8E4M3FN |Float8E4M3 |
|-------------------|------------------------------------------------------------------------|-------------------------------------------------------------------------|
|Bias |7 |7 |
|Min Normal Value |`0bS0001000` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-6</sup> |`0bS0001000` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-6</sup> |
|Max Normal Value |`0bS1111110` = -1<sup>S</sup> $\times$ 1.75 $\times$ 2<sup>8</sup> = 448|`0bS1110111` = -1<sup>S</sup> $\times$ 1.875 $\times$ 2<sup>7</sup> = 240|
|Min Subnormal Value|`0bS0000001` = -1<sup>S</sup> $\times$ 0.125 $\times$ 2<sup>-6</sup> |`0bS0000001` = -1<sup>S</sup> $\times$ 0.125 $\times$ 2<sup>-6</sup> |
|Max Subnormal Value|`0bS0000111` = -1<sup>S</sup> $\times$ 0.875 $\times$ 2<sup>-6</sup> |`0bS0000111` = -1<sup>S</sup> $\times$ 0.875 $\times$ 2<sup>-6</sup> |
|NaN |`0bS1111111` |`0bS1111MMM`, where `MMM` is non-zero. |
|Infinity |N/A |`0bS1111000` |
|-0.0 |`0b10000000` |`0b10000000` |
### Float8E3M4
8-bit floating point type with 1 sign bit, 3 bits exponent and 4 bits mantissa
following IEEE-754 conventions with bit layout S1E3M4.
```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Minimum stored exponent value: 1 (binary 001)
- Maximum stored exponent value: 6 (binary 110)
- Minimum unbiased exponent value: 1 − 3 = −2
- Maximum unbiased exponent value: 6 - 3 = 3
- Precision specifies the total number of bits used for the significand
(mantissa), including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs
Additional details:
- Min exp (unbiased): -2
- Max exp (unbiased): 3
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Min normal number: S.001.0000 = +/-2^(1 - 3) x (1 + 0) = +/-0.25
- Max normal number: S.110.1111 = +/-2^(6 - 3) x (1 + 15/16) = +/-15.5
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-6)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-6) x 15
```

### Comparison of Float8E5M2, Float8E4M3 and Float8E3M4

| |Float8E5M2 |Float8E4M3 |Float8E3M4 |
|-------------------|----------------------------------------------------------------------------|-------------------------------------------------------------------------|---------------------------------------------------------------------------|
|Bias |15 |7 |3 |
|Min Normal Value |`0bS0000100` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-14</sup> |`0bS0001000` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-6</sup> |`0bS0010000` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-2</sup> |
|Max Normal Value |`0bS1111011` = -1<sup>S</sup> $\times$ 1.75 $\times$ 2<sup>15</sup> = 57344 |`0bS1110111` = -1<sup>S</sup> $\times$ 1.875 $\times$ 2<sup>7</sup> = 240|`0bS1101111` = -1<sup>S</sup> $\times$ 1.9375 $\times$ 2<sup>3</sup> = 15.5|
|Min Subnormal Value|`0bS0000001` = -1<sup>S</sup> $\times$ 0.25 $\times$ 2<sup>-14</sup> |`0bS0000001` = -1<sup>S</sup> $\times$ 0.125 $\times$ 2<sup>-6</sup> |`0bS0000001` = -1<sup>S</sup> $\times$ 0.0625 $\times$ 2<sup>-2</sup> |
|Max Subnormal Value|`0bS0000011` = -1<sup>S</sup> $\times$ 0.75 $\times$ 2<sup>-14</sup> |`0bS0000111` = -1<sup>S</sup> $\times$ 0.875 $\times$ 2<sup>-6</sup> |`0bS0001111` = -1<sup>S</sup> $\times$ 0.9375 $\times$ 2<sup>-2</sup> |
|NaN |`0bS11111MM`, where `MM` is non-zero. |`0bS1111MMM`, where `MMM` is non-zero. |`0bS111MMMM`, where `MMMM` is non-zero. |
|Infinity |`0bS1111100` |`0bS1111000` |`0bS1110000` |
|-0.0 |`0b10000000` |`0b10000000` |`0b10000000` |

## Changes in StableHLO

I propose adding Float8E4M3 and Float8E3M4 types to StableHLO similar to the
previously introduces FP8 types (below) with some differences:

- [FP8 RFC](https://github.com/openxla/xla/discussions/22)
- [[RFC] Add Float8E4M3FNUZ and Float8E5M2FNUZ to StableHLO](https://github.com/openxla/stablehlo/pull/1342)

### StableHLO Interpreter

To provide a reference implementation, I intend to add support for
Float8E4M3 and Float8E3M4 in the StableHLO interpreter. This will be
useful for testing other backends and validating new implementations. This will
be achieved in two ways:

1. Map directly to the appropriate APFloat operation.
2. Cast up to the appropriate type, use that implementation, cast back down.

### Float8E4M3 and Float8E3M4 Arithmetic

I intend for Float8E4M3 and Float8E3M4 to be types that support the
appropriate arithmetic operations, like any other floating point type. For
platforms that don't have hardware support for these types, they may either
throw an error and reject the program or cast up to an appropriate higher
precision type that is supported, compute the answer, and cast back down.

This is a simple approach that aligns with user expectations of a floating
point data type, and is the approach taken by BFloat16. This also gives
backends freedom to exploit any hardware support.

Here's an example of a real JAX program (logging the MLIR) computing a simple
dot product in Float8E4M3. Note the answer is slightly "wrong", as expected
due to the lower precision (round-to-nearest).

```python
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.arange(8, dtype=jnp.float8_e4m3)
module @jit_iota {
func.func public @main() -> tensor<8xf8E4M3> {
%0 = stablehlo.iota dim = 0 : tensor<8xf8E4M3>
return %0 : tensor<8xf8E4M3>
}
}
>>> x
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=float8_e4m3)
>>> x @ x
module @jit_matmul {
func.func public @main(%arg0: tensor<8xf8E4M3> {mhlo.sharding = ""}, %arg1: tensor<8xf8E4M3> {mhlo.sharding = ""}) -> tensor<f8E4M3> {
%0 = "stablehlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #stablehlo.dot<lhs_contracting_dimensions = [0], rhs_contracting_dimensions = [0]>, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<8xf8E4M3>, tensor<8xf8E4M3>) -> tensor<f8E4M3>
return %0 : tensor<f8E4M3>
}
}
Array(144, dtype=float8_e4m3)
```

### Testing

Built on the StableHLO interpreter, I intend to introduce tests for all
possible operations with Float8E4M3 and Float8E3M4 inputs. This will at
a minimum mean adding additional cases to the `interpret_X.mlir` family of
tests.

### References and Links

- [RFC: FP8 in StableHLO](https://github.com/openxla/stablehlo/blob/main/rfcs/20221031-fp8.md)
- [RFC: Float8E4M3FNUZ and Float8E5M2FNUZ](https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md)

[^1]: [Amazon EC2 Trn1 Instances](https://aws.amazon.com/ec2/instance-types/trn1/)
[^2]: LLVM [PR-97118](https://github.com/llvm/llvm-project/pull/97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
[^3]: LLVM [PR-101230](https://github.com/llvm/llvm-project/pull/101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
[^4]: LLVM [PR-97179](https://github.com/llvm/llvm-project/pull/97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
[^5]: LLVM [PR-99698](https://github.com/llvm/llvm-project/pull/99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)

0 comments on commit d68ab07

Please sign in to comment.