fix inf bug in float_to_fp8 (#12085)

Author: b1tg
Date: 2025-09-10 00:02:56 +08:00
Committed by: GitHub
Parent: 14faf7a5c0
Commit: 82e955fe79
3 changed files with 14 additions and 6 deletions

@@ -155,10 +155,10 @@ class TestFp8sConversions(unittest.TestCase):
   def test_float_to_fp8e4m3_extreme_values(self):
     np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX, dtypes.fp8e4m3), 126)
     np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX*1.01, dtypes.fp8e4m3), 126)
-    np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e4m3), 126)
+    np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e4m3), 127)
     np.testing.assert_equal(float_to_fp8(-FP8E4M3_MAX, dtypes.fp8e4m3), 254)
     np.testing.assert_equal(float_to_fp8(-FP8E4M3_MAX*1.01, dtypes.fp8e4m3), 254)
-    np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e4m3), 254)
+    np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e4m3), 255)
     np.testing.assert_equal(float_to_fp8(math.nan, dtypes.fp8e4m3), 127)
     np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e4m3), 255)
@@ -168,10 +168,10 @@ class TestFp8sConversions(unittest.TestCase):
   def test_float_to_fp8e5m2_extreme_values(self):
     np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX, dtypes.fp8e5m2), 123)
     np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX*1.01, dtypes.fp8e5m2), 123)
-    np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e5m2), 123)
+    np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e5m2), 124)
     np.testing.assert_equal(float_to_fp8(-FP8E5M2_MAX, dtypes.fp8e5m2), 251)
     np.testing.assert_equal(float_to_fp8(-FP8E5M2_MAX*1.01, dtypes.fp8e5m2), 251)
-    np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e5m2), 251)
+    np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e5m2), 252)
     np.testing.assert_equal(float_to_fp8(math.nan, dtypes.fp8e5m2), 126)
     np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e5m2), 254)
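
A minimal sketch, not part of the commit: decoding the expected byte values from the tests above with ml_dtypes (which these tests already use) shows which fp8 values they encode. The decode helper is made up for illustration.

import numpy as np, ml_dtypes

def decode(byte: int, dt):
  # reinterpret a single byte as the given fp8 type
  return np.frombuffer(bytes([byte]), dtype=dt)[0]

print(decode(0x7f, ml_dtypes.float8_e4m3fn))  # nan: e4m3 has no inf, so +inf maps to the NaN pattern 0x7f (127)
print(decode(0xff, ml_dtypes.float8_e4m3fn))  # nan with the sign bit set: -inf maps to 0xff (255)
print(decode(0x7c, ml_dtypes.float8_e5m2))    # inf: e5m2 does encode inf, 0x7c (124)
print(decode(0xfc, ml_dtypes.float8_e5m2))    # -inf: 0xfc (252)
print(decode(0x7e, ml_dtypes.float8_e4m3fn))  # 448.0, the e4m3 max normal (126), which the old tests expected for inf
print(decode(0x7b, ml_dtypes.float8_e5m2))    # 57344.0, the e5m2 max normal (123)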

@@ -186,13 +186,17 @@ class TestHelpers(unittest.TestCase):
   @given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
   def test_truncate_fp8e4m3(self, x):
-    if x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX)
+    if math.isnan(x): np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), x)
+    elif math.isinf(x): np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), math.copysign(math.nan, x))
+    elif x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX)
     elif x < -FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), -FP8E4M3_MAX)
     else: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), ml_dtypes.float8_e4m3fn(x))
   @given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True))
   def test_truncate_fp8e5m2(self, x):
-    if x > FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), FP8E5M2_MAX)
+    if math.isnan(x): np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), x)
+    elif math.isinf(x): np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), x)
+    elif x > FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), FP8E5M2_MAX)
     elif x < -FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), -FP8E5M2_MAX)
     else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), ml_dtypes.float8_e5m2(x))
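
A side note, not from the commit: the new isnan/isinf branches work because np.testing.assert_equal treats two NaNs as equal, unlike the plain == operator, and math.copysign preserves the sign even on NaN, which the copysign(math.nan, x) expectation relies on. A quick illustration:

import math
import numpy as np

np.testing.assert_equal(math.nan, math.nan)  # passes: assert_equal special-cases NaN
assert not (math.nan == math.nan)            # plain equality is always False for NaN
assert math.copysign(1, math.copysign(math.nan, -1.0)) == -1.0  # copysign still sees the sign bit of a NaN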

@@ -230,6 +230,10 @@ def float_to_bf16(x):
 # fp8-float conversions based on https://gitlab.com/nvidia/headers/cuda-individual/cudart/-/blob/main/cuda_fp8.hpp
 def float_to_fp8(x: float, dtype: DType) -> int:
   assert dtype in dtypes.fp8s, "Only for fp8s"
+  # e4m3 doesn't support inf, so return 0x7f (+NaN) and 0xff (-NaN) to match jax
+  # NaN is unordered and can't be compared with zero, so use math.copysign to get the sign
+  if dtype == dtypes.fp8e4m3 and not math.isfinite(x): return 0x7f if math.copysign(1, x) > 0 else 0xff
+  if dtype == dtypes.fp8e5m2 and math.isinf(x): return 0x7c if math.copysign(1, x) > 0 else 0xfc
   config = {
     dtypes.fp8e4m3: {"EXP_BIAS": 7, "SIGNIFICAND_BITS": 4, "MANTISSA_MASK": 0x7, "MINDENORM_O2": 0x3F50000000000000,
                      "OVERFLOW_THRESHOLD": 0x407D000000000000, "MAXNORM": 0x7E, "MINNORM": 0x3F90000000000000, "INF_VALUE": 0x7F},