python float8 support (#11960)

* basic support

* alu

* nan in exec_alu

* rand_for_dtype

* inf + 0.0

* finfo

* revert rand_for_dtype

* clean

* truncate fp8s inf

* spec ok

* float_to_fp8 nan/inf

* least_upper_dtype

* clean up

---------

Co-authored-by: b1tg <b1tg@users.noreply.github.com>
This commit is contained in:
b1tg
2025-09-18 21:17:09 +08:00
committed by GitHub
parent dbbc261075
commit 54c15d74a4
6 changed files with 56 additions and 17 deletions

View File

@@ -21,7 +21,9 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float
if DEBUG >= 2: print(tensor.numpy())
try:
assert tensor.dtype == target_dtype
np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, tol_target_dtype))
np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2,
dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1}.get(target_dtype, tol_target_dtype))
except AssertionError as e:
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
@@ -576,10 +578,10 @@ class TestAutoCastType(unittest.TestCase):
def test_gradient_dtype(self):
old_default_float = dtypes.default_float
for default_dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
for default_dtype in dtypes.floats:
if not is_dtype_supported(default_dtype): continue
dtypes.default_float = default_dtype
for dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
for dtype in dtypes.floats:
if not is_dtype_supported(dtype): continue
if DEBUG >= 2:
print(f"testing {default_dtype=}, {dtype=}")