mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
fix inf bug in float_to_fp8 (#12085)
This commit is contained in:
@@ -155,10 +155,10 @@ class TestFp8sConversions(unittest.TestCase):
|
||||
def test_float_to_fp8e4m3_extreme_values(self):
|
||||
np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX, dtypes.fp8e4m3), 126)
|
||||
np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX*1.01, dtypes.fp8e4m3), 126)
|
||||
np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e4m3), 126)
|
||||
np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e4m3), 127)
|
||||
np.testing.assert_equal(float_to_fp8(-FP8E4M3_MAX, dtypes.fp8e4m3), 254)
|
||||
np.testing.assert_equal(float_to_fp8(-FP8E4M3_MAX*1.01, dtypes.fp8e4m3), 254)
|
||||
np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e4m3), 254)
|
||||
np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e4m3), 255)
|
||||
np.testing.assert_equal(float_to_fp8(math.nan, dtypes.fp8e4m3), 127)
|
||||
np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e4m3), 255)
|
||||
|
||||
@@ -168,10 +168,10 @@ class TestFp8sConversions(unittest.TestCase):
|
||||
def test_float_to_fp8e5m2_extreme_values(self):
|
||||
np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX, dtypes.fp8e5m2), 123)
|
||||
np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX*1.01, dtypes.fp8e5m2), 123)
|
||||
np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e5m2), 123)
|
||||
np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e5m2), 124)
|
||||
np.testing.assert_equal(float_to_fp8(-FP8E5M2_MAX, dtypes.fp8e5m2), 251)
|
||||
np.testing.assert_equal(float_to_fp8(-FP8E5M2_MAX*1.01, dtypes.fp8e5m2), 251)
|
||||
np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e5m2), 251)
|
||||
np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e5m2), 252)
|
||||
np.testing.assert_equal(float_to_fp8(math.nan, dtypes.fp8e5m2), 126)
|
||||
np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e5m2), 254)
|
||||
|
||||
|
||||
@@ -186,13 +186,17 @@ class TestHelpers(unittest.TestCase):
|
||||
|
||||
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
|
||||
def test_truncate_fp8e4m3(self, x):
|
||||
if x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX)
|
||||
if math.isnan(x): np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), x)
|
||||
elif math.isinf(x): np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), math.copysign(math.nan, x))
|
||||
elif x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX)
|
||||
elif x < -FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), -FP8E4M3_MAX)
|
||||
else: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), ml_dtypes.float8_e4m3fn(x))
|
||||
|
||||
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
|
||||
def test_truncate_fp8e5m2(self, x):
|
||||
if x > FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), FP8E5M2_MAX)
|
||||
if math.isnan(x): np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), x)
|
||||
elif math.isinf(x): np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), x)
|
||||
elif x > FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), FP8E5M2_MAX)
|
||||
elif x < -FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), -FP8E5M2_MAX)
|
||||
else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), ml_dtypes.float8_e5m2(x))
|
||||
|
||||
|
||||
@@ -230,6 +230,10 @@ def float_to_bf16(x):
|
||||
# fp8-float conversions based on https://gitlab.com/nvidia/headers/cuda-individual/cudart/-/blob/main/cuda_fp8.hpp
|
||||
def float_to_fp8(x: float, dtype: DType) -> int:
|
||||
assert dtype in dtypes.fp8s, "Only for fp8s"
|
||||
# e4m3 don't support inf, return 0x7f(+NaN) and 0xff(-NaN) to match jax
|
||||
# NaN is unordered, can't compare with zero, use math.copysign to get sign
|
||||
if dtype == dtypes.fp8e4m3 and not math.isfinite(x): return 0x7f if math.copysign(1, x) > 0 else 0xff
|
||||
if dtype == dtypes.fp8e5m2 and math.isinf(x): return 0x7c if math.copysign(1, x) > 0 else 0xfc
|
||||
config = {
|
||||
dtypes.fp8e4m3: {"EXP_BIAS": 7, "SIGNIFICAND_BITS": 4, "MANTISSA_MASK": 0x7, "MINDENORM_O2": 0x3F50000000000000,
|
||||
"OVERFLOW_THRESHOLD": 0x407D000000000000, "MAXNORM": 0x7E, "MINNORM": 0x3F90000000000000, "INF_VALUE": 0x7F},
|
||||
|
||||
Reference in New Issue
Block a user