mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
python float8 support (#11960)
* basic support * alu * nan in exec_alu * rand_for_dtype * inf + 0.0 * finfo * revert rand_for_dtype * clean * truncate fp8s inf * spec ok * float_to_fp8 nan/inf * least_upper_dtype * clean up --------- Co-authored-by: b1tg <b1tg@users.noreply.github.com>
This commit is contained in:
@@ -21,7 +21,9 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float
|
||||
if DEBUG >= 2: print(tensor.numpy())
|
||||
try:
|
||||
assert tensor.dtype == target_dtype
|
||||
np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, tol_target_dtype))
|
||||
np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2,
|
||||
dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1}.get(target_dtype, tol_target_dtype))
|
||||
|
||||
except AssertionError as e:
|
||||
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
|
||||
|
||||
@@ -576,10 +578,10 @@ class TestAutoCastType(unittest.TestCase):
|
||||
def test_gradient_dtype(self):
|
||||
old_default_float = dtypes.default_float
|
||||
|
||||
for default_dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
|
||||
for default_dtype in dtypes.floats:
|
||||
if not is_dtype_supported(default_dtype): continue
|
||||
dtypes.default_float = default_dtype
|
||||
for dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
|
||||
for dtype in dtypes.floats:
|
||||
if not is_dtype_supported(dtype): continue
|
||||
if DEBUG >= 2:
|
||||
print(f"testing {default_dtype=}, {dtype=}")
|
||||
|
||||
Reference in New Issue
Block a user