FP8s truncate (#9937)

* truncate fp8

* fix

* maybe like that?

* fix linters

* ruff

* move from extra and add ml_types to tests

* minor changes

* str to dtypes and nan support

---------

Co-authored-by: pkotzbach <pawkotz@gmail.com>
This commit is contained in:
pkotzbach
2025-04-23 01:12:49 +02:00
committed by GitHub
parent 58180caad3
commit dbbd755cba
3 changed files with 113 additions and 0 deletions

View File

@@ -5,10 +5,12 @@ from typing import Any, List
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, DEBUG, CI
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16, truncate_bf16, to_dtype
from tinygrad.dtype import truncate, fp8_to_float, float_to_fp8
from tinygrad import Device, Tensor, dtypes
from tinygrad.tensor import _to_np_dtype
from hypothesis import assume, given, settings, strategies as strat
from test.helpers import rand_for_dtype
import ml_dtypes
import pytest
pytestmark = pytest.mark.filterwarnings("ignore")
@@ -19,6 +21,8 @@ core_dtypes = list(DTYPES_DICT.values())
if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
FP8E4M3_MAX = 448.0
FP8E5M2_MAX = 57344.0
def get_available_cast_dtypes(dtype: DType) -> List[DType]:
if not is_dtype_supported(dtype): return []
@@ -146,6 +150,39 @@ class TestFp8s(unittest.TestCase):
def test_fp8e4m3_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e4m3).dtype == dtypes.fp8e4m3
def test_fp8e5m2_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e5m2).dtype == dtypes.fp8e5m2
class TestFp8sConversions(unittest.TestCase):
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E4M3_MAX, max_value=FP8E4M3_MAX))
def test_float_to_fp8e4m3(self, x): np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e4m3), ml_dtypes.float8_e4m3fn(x).tobytes()[0])
def test_float_to_fp8e4m3_extreme_values(self):
np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX, dtypes.fp8e4m3), 126)
np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX*1.01, dtypes.fp8e4m3), 126)
np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e4m3), 126)
np.testing.assert_equal(float_to_fp8(-FP8E4M3_MAX, dtypes.fp8e4m3), 254)
np.testing.assert_equal(float_to_fp8(-FP8E4M3_MAX*1.01, dtypes.fp8e4m3), 254)
np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e4m3), 254)
np.testing.assert_equal(float_to_fp8(math.nan, dtypes.fp8e4m3), 127)
np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e4m3), 255)
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E5M2_MAX, max_value=FP8E5M2_MAX))
def test_float_to_fp8e5m2(self, x): np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e5m2), ml_dtypes.float8_e5m2(x).tobytes()[0])
def test_float_to_fp8e5m2_extreme_values(self):
np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX, dtypes.fp8e5m2), 123)
np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX*1.01, dtypes.fp8e5m2), 123)
np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e5m2), 123)
np.testing.assert_equal(float_to_fp8(-FP8E5M2_MAX, dtypes.fp8e5m2), 251)
np.testing.assert_equal(float_to_fp8(-FP8E5M2_MAX*1.01, dtypes.fp8e5m2), 251)
np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e5m2), 251)
np.testing.assert_equal(float_to_fp8(math.nan, dtypes.fp8e5m2), 126)
np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e5m2), 254)
@given(strat.integers(min_value=0, max_value=255))
def test_fp8e4m3_to_float(self, x): np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e4m3), np.uint8(x).view(ml_dtypes.float8_e4m3fn).item())
@given(strat.integers(min_value=0, max_value=255))
def test_fp8e5m2_to_float(self, x): np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e5m2), np.uint8(x).view(ml_dtypes.float8_e5m2).item())
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
class TestBFloat16(unittest.TestCase):
def test_bf16_creation_numpy(self):
@@ -459,6 +496,18 @@ class TestHelpers(unittest.TestCase):
self.assertEqual(truncate_bf16(max_bf16 * 1.00001), math.inf)
self.assertEqual(truncate_bf16(min_bf16 * 1.00001), -math.inf)
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
def test_truncate_fp8e4m3(self, x):
if x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX)
elif x < -FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), -FP8E4M3_MAX)
else: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), ml_dtypes.float8_e4m3fn(x))
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
def test_truncate_fp8e5m2(self, x):
if x > FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), FP8E5M2_MAX)
elif x < -FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), -FP8E5M2_MAX)
else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), ml_dtypes.float8_e5m2(x))
class TestTypeSpec(unittest.TestCase):
def setUp(self):
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float