mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix inf bug in float_to_fp8 (#12085)
This commit is contained in:
@@ -186,13 +186,17 @@ class TestHelpers(unittest.TestCase):
|
||||
|
||||
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
|
||||
def test_truncate_fp8e4m3(self, x):
|
||||
if x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX)
|
||||
if math.isnan(x): np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), x)
|
||||
elif math.isinf(x): np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), math.copysign(math.nan, x))
|
||||
elif x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX)
|
||||
elif x < -FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), -FP8E4M3_MAX)
|
||||
else: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), ml_dtypes.float8_e4m3fn(x))
|
||||
|
||||
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
|
||||
def test_truncate_fp8e5m2(self, x):
|
||||
if x > FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), FP8E5M2_MAX)
|
||||
if math.isnan(x): np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), x)
|
||||
elif math.isinf(x): np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), x)
|
||||
elif x > FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), FP8E5M2_MAX)
|
||||
elif x < -FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), -FP8E5M2_MAX)
|
||||
else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), ml_dtypes.float8_e5m2(x))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user