mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
FP8s truncate (#9937)
* truncate fp8
* fix
* maybe like that?
* fix linters
* ruff
* move from extra and add ml_types to tests
* minor changes
* str to dtypes and nan support

---------

Co-authored-by: pkotzbach <pawkotz@gmail.com>
This commit is contained in:
@@ -5,10 +5,12 @@ from typing import Any, List
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import getenv, DEBUG, CI
|
||||
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16, truncate_bf16, to_dtype
|
||||
from tinygrad.dtype import truncate, fp8_to_float, float_to_fp8
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from hypothesis import assume, given, settings, strategies as strat
|
||||
from test.helpers import rand_for_dtype
|
||||
import ml_dtypes
|
||||
import pytest
|
||||
pytestmark = pytest.mark.filterwarnings("ignore")
|
||||
|
||||
@@ -19,6 +21,8 @@ core_dtypes = list(DTYPES_DICT.values())
|
||||
if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
|
||||
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
|
||||
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
|
||||
FP8E4M3_MAX = 448.0  # largest finite value of the fp8 E4M3 (fn variant) format; encodings beyond it saturate (see TestFp8sConversions)
|
||||
FP8E5M2_MAX = 57344.0  # largest finite value of the fp8 E5M2 format
|
||||
|
||||
def get_available_cast_dtypes(dtype: DType) -> List[DType]:
|
||||
if not is_dtype_supported(dtype): return []
|
||||
@@ -146,6 +150,39 @@ class TestFp8s(unittest.TestCase):
|
||||
def test_fp8e4m3_creation(self):
  """A Tensor constructed with dtype=fp8e4m3 must report that dtype."""
  t = Tensor([-1, 1, 2], dtype=dtypes.fp8e4m3)
  assert t.dtype == dtypes.fp8e4m3
|
||||
def test_fp8e5m2_creation(self):
  """A Tensor constructed with dtype=fp8e5m2 must report that dtype."""
  assert Tensor([-1, 1, 2], dtype=dtypes.fp8e5m2).dtype == dtypes.fp8e5m2
|
||||
|
||||
class TestFp8sConversions(unittest.TestCase):
  """Check tinygrad's scalar fp8 conversion helpers against the ml_dtypes reference encodings."""

  @given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E4M3_MAX, max_value=FP8E4M3_MAX))
  def test_float_to_fp8e4m3(self, x):
    # any in-range float must encode to the exact byte ml_dtypes produces for e4m3fn
    np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e4m3), ml_dtypes.float8_e4m3fn(x).tobytes()[0])

  def test_float_to_fp8e4m3_extreme_values(self):
    # overflow and +/-inf saturate to the largest finite encodings (0x7e / 0xfe);
    # nan encodes to 0x7f / 0xff, keeping the input's sign bit
    for val, encoded in [
      (FP8E4M3_MAX, 126),
      (FP8E4M3_MAX*1.01, 126),
      (math.inf, 126),
      (-FP8E4M3_MAX, 254),
      (-FP8E4M3_MAX*1.01, 254),
      (-math.inf, 254),
      (math.nan, 127),
      (-math.nan, 255),
    ]:
      np.testing.assert_equal(float_to_fp8(val, dtypes.fp8e4m3), encoded)

  @given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E5M2_MAX, max_value=FP8E5M2_MAX))
  def test_float_to_fp8e5m2(self, x):
    # any in-range float must encode to the exact byte ml_dtypes produces for e5m2
    np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e5m2), ml_dtypes.float8_e5m2(x).tobytes()[0])

  def test_float_to_fp8e5m2_extreme_values(self):
    # overflow and +/-inf saturate to the largest finite encodings (0x7b / 0xfb);
    # nan encodes to 0x7e / 0xfe, keeping the input's sign bit
    for val, encoded in [
      (FP8E5M2_MAX, 123),
      (FP8E5M2_MAX*1.01, 123),
      (math.inf, 123),
      (-FP8E5M2_MAX, 251),
      (-FP8E5M2_MAX*1.01, 251),
      (-math.inf, 251),
      (math.nan, 126),
      (-math.nan, 254),
    ]:
      np.testing.assert_equal(float_to_fp8(val, dtypes.fp8e5m2), encoded)

  @given(strat.integers(min_value=0, max_value=255))
  def test_fp8e4m3_to_float(self, x):
    # every possible byte must decode to the same float ml_dtypes decodes it to
    np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e4m3), np.uint8(x).view(ml_dtypes.float8_e4m3fn).item())

  @given(strat.integers(min_value=0, max_value=255))
  def test_fp8e5m2_to_float(self, x):
    # every possible byte must decode to the same float ml_dtypes decodes it to
    np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e5m2), np.uint8(x).view(ml_dtypes.float8_e5m2).item())
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
|
||||
class TestBFloat16(unittest.TestCase):
|
||||
def test_bf16_creation_numpy(self):
|
||||
@@ -459,6 +496,18 @@ class TestHelpers(unittest.TestCase):
|
||||
self.assertEqual(truncate_bf16(max_bf16 * 1.00001), math.inf)
|
||||
self.assertEqual(truncate_bf16(min_bf16 * 1.00001), -math.inf)
|
||||
|
||||
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
def test_truncate_fp8e4m3(self, x):
  """truncate[fp8e4m3] clamps out-of-range inputs to +/-FP8E4M3_MAX; in-range inputs (incl. nan) match ml_dtypes rounding."""
  got = truncate[dtypes.fp8e4m3](x)
  if x > FP8E4M3_MAX:
    np.testing.assert_equal(got, FP8E4M3_MAX)
  elif x < -FP8E4M3_MAX:
    np.testing.assert_equal(got, -FP8E4M3_MAX)
  else:
    # nan falls through here too (comparisons with nan are False)
    np.testing.assert_equal(got, ml_dtypes.float8_e4m3fn(x))
|
||||
|
||||
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
def test_truncate_fp8e5m2(self, x):
  """truncate[fp8e5m2] clamps out-of-range inputs to +/-FP8E5M2_MAX; in-range inputs (incl. nan) match ml_dtypes rounding."""
  got = truncate[dtypes.fp8e5m2](x)
  if x > FP8E5M2_MAX:
    np.testing.assert_equal(got, FP8E5M2_MAX)
  elif x < -FP8E5M2_MAX:
    np.testing.assert_equal(got, -FP8E5M2_MAX)
  else:
    # nan falls through here too (comparisons with nan are False)
    np.testing.assert_equal(got, ml_dtypes.float8_e5m2(x))
|
||||
|
||||
class TestTypeSpec(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
|
||||
|
||||
Reference in New Issue
Block a user