fix inf bug in float_to_fp8 (#12085)

This commit is contained in:
b1tg
2025-09-10 00:02:56 +08:00
committed by GitHub
parent 14faf7a5c0
commit 82e955fe79
3 changed files with 14 additions and 6 deletions

View File

@@ -186,13 +186,17 @@ class TestHelpers(unittest.TestCase):
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
def test_truncate_fp8e4m3(self, x):
if x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX)
if math.isnan(x): np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), x)
elif math.isinf(x): np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), math.copysign(math.nan, x))
elif x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX)
elif x < -FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), -FP8E4M3_MAX)
else: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), ml_dtypes.float8_e4m3fn(x))
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
def test_truncate_fp8e5m2(self, x):
if x > FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), FP8E5M2_MAX)
if math.isnan(x): np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), x)
elif math.isinf(x): np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), x)
elif x > FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), FP8E5M2_MAX)
elif x < -FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), -FP8E5M2_MAX)
else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), ml_dtypes.float8_e5m2(x))