diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c362f9b8f2..baaca1418f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -243,6 +243,8 @@ jobs: DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16 DEBUG=2 EMULATE_CUDA=1 ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm DEBUG=2 EMULATE_CUDA_SM75=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16 + DEBUG=2 EMULATE_CUDA_SM89=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOpsFp8s.test_gemm_fp8e4m3 + DEBUG=2 EMULATE_CUDA_SM89=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOpsFp8s.test_gemm_fp8e5m2 PYTHONPATH="." DEBUG=2 EMULATE_CUDA=1 ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded - name: Test emulated INTEL OpenCL tensor cores run: DEBUG=2 EMULATE_INTEL=1 FORWARD_ONLY=1 PYTHON=1 HALF=1 N=64 python3 ./extra/gemm/simple_matmul.py diff --git a/extra/gemm/simple_matmul.py b/extra/gemm/simple_matmul.py index 1edad82c23..8e97b61ba2 100644 --- a/extra/gemm/simple_matmul.py +++ b/extra/gemm/simple_matmul.py @@ -3,9 +3,11 @@ from tinygrad.helpers import getenv from tinygrad.dtype import _to_np_dtype from tinygrad import dtypes, Tensor -dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float -acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None -if getenv("INT"): dtype_in, acc_dtype = dtypes.int8, dtypes.int32 +dtype_in = (dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else + dtypes.fp8e4m3 if getenv("FP8E4M3") else dtypes.fp8e5m2 if getenv("FP8E5M2") else dtypes.float) +acc_dtype = (dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else + dtypes.fp8e4m3 if getenv("ACC_FP8E4M3") else dtypes.fp8e5m2 if getenv("ACC_FP8E5M2") else None) +if getenv("INT"): dtype_in, acc_dtype = dtypes.int8, dtypes.int32 if getenv("UINT"): dtype_in, acc_dtype = dtypes.uint8, dtypes.int32 N = getenv("N", 4096) diff --git a/setup.py b/setup.py index c022ff342e..95d6d2676e 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,7 @@ setup(name='tinygrad', "bottle", "ggml-python", "capstone", + "ml_dtypes", "pycocotools", "boto3", "pandas" diff --git a/test/test_dtype.py b/test/test_dtype.py index 79fbfb8f6d..76cd18943e 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -39,7 +39,8 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float if DEBUG >= 2: print(tensor.numpy()) try: assert tensor.dtype == target_dtype - np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, tol_target_dtype)) + np.testing.assert_allclose(tensor.numpy(), target, + rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2, dtypes.fp8e5m2: 1, dtypes.fp8e4m3: 1e-1}.get(target_dtype, tol_target_dtype)) except AssertionError as e: raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e @@ -56,6 +57,7 @@ def _test_cast(a:Tensor, target_dtype:DType): _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype)))) def _test_bitcast(a:Tensor, target_dtype:DType, target=None): if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet") + if target_dtype in dtypes.fp8s: raise
unittest.SkipTest("no test for fp8s bitcast yet") if getenv("PTX") and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize: raise unittest.SkipTest("shape changing bitcast of int8 broken on PTX") _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(_to_np_dtype(target_dtype)).tolist()) @@ -109,7 +111,7 @@ class TestDType(unittest.TestCase): fields = dtypes.fields() self.assertIn("float", fields) self.assertIn("float32", fields) - self.assertEqual(len(fields), 24) + self.assertEqual(len(fields), 26) self.assertTrue(all(isinstance(value, DType) for value in fields.values())) self.assertTrue(all(issubclass(_to_np_dtype(value), np.generic) for value in fields.values() if _to_np_dtype(value) is not None)) @@ -205,6 +207,31 @@ class TestBFloat16DTypeCast(unittest.TestCase): converted = random_values.cast(dtypes.bfloat16).cast(dtypes.float32) np.testing.assert_allclose(converted.numpy(), random_values.cast(dtypes.float32).numpy(), rtol=1e-2, atol=1e-3) +class TestFp8sDType(unittest.TestCase): + def _float_to_fp8_conversion_test(self, dtype, input_values, expected_values): + test_tensor = Tensor(input_values).cast(dtype).realize() + back_to_float32 = test_tensor.cast(dtypes.float32) + np.testing.assert_equal(tuple(back_to_float32.numpy().tolist()), expected_values) + + @unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), "fp8e4m3 not supported") + def test_float_to_fp8e4m3_conversion(self): + self._float_to_fp8_conversion_test(dtypes.fp8e4m3, + [10000000.0, -1.0, 402.0, -300.0, -10000000.0, 20.0, 1.4123, 0.0, math.inf, math.nan], + [448.0, -1.0, 416.0, -288.0, -448.0, 20.0, 1.375, 0.0, 448.0, math.nan]) + + @unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), "fp8e5m2 not supported") + def test_float_to_fp8e5m2_conversion(self): + self._float_to_fp8_conversion_test(dtypes.fp8e5m2, + [10000000.0, -1.0, 402.0, -300.0, -10000000.0, 20.0, 1.4123, 0.0, math.inf, math.nan], + [57344.0, -1, 384, -320, -57344.0, 20, 1.5, 0.0, 57344.0, math.nan]) + + @unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3) and is_dtype_supported(dtypes.fp8e5m2), "fp8s not supported") + def test_fp8e4m3_plus_fp8e5m2_output_dtype(self): + a = Tensor([1.0, 2.0, 3.0], dtype=dtypes.fp8e4m3) + b = Tensor([1.0, 2.0, 3.0], dtype=dtypes.fp8e5m2) + result = a + b + self.assertEqual(result.dtype, dtypes.half) + class TestHalfDType(TestDType): DTYPE = dtypes.half class TestFloatDType(TestDType): @@ -264,6 +291,7 @@ class TestBitCast(unittest.TestCase): def test_shape_change_bitcast(self, dt1, dt2): # NOTE: this has to be assume to prevent hypothesis from skipping all samples assume(dt2 != dtypes.bfloat16 and dt1 != dtypes.bfloat16) # no test for bf16 bitcast yet + assume(dt1 not in dtypes.fp8s and dt2 not in dtypes.fp8s) # no test for fp8 bitcast yet assume(not (getenv("PTX") and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX data = rand_for_dtype(dt1, 32).reshape(2, 2, 8) _test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, data.view(_to_np_dtype(dt2)).tolist()) @@ -399,6 +427,10 @@ class TestHelpers(unittest.TestCase): def test_bf16_is_float(self): assert dtypes.is_float(dtypes.bfloat16) + def test_fp8s_are_float(self): + assert dtypes.is_float(dtypes.fp8e4m3) + assert dtypes.is_float(dtypes.fp8e5m2) + @given(strat.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), strat.integers(min_value=2, max_value=8)) def test_scalar(self, dtype, amt): assert dtype.vec(amt).scalar() == dtype @@ -462,7 +494,7 @@ class 
TestTypeSpec(unittest.TestCase): dtypes.default_int = default_int assert dtypes.default_int == default_int - for default_float in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]: + for default_float in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]: dtypes.default_float = default_float assert dtypes.default_float == default_float @@ -693,7 +725,9 @@ class TestAutoCastType(unittest.TestCase): assert (Tensor([0, 1], dtype=dtypes.uint32)).sum().dtype == dtypes.uint32 assert (Tensor([0, 1], dtype=dtypes.uint64)).sum().dtype == dtypes.uint64 assert (Tensor([0, 1], dtype=dtypes.float16)).sum().dtype == dtypes.float16 - #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16 + assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16 + assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).sum().dtype == dtypes.fp8e4m3 + assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).sum().dtype == dtypes.fp8e5m2 assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32 assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64 @@ -725,7 +759,9 @@ class TestAutoCastType(unittest.TestCase): assert (Tensor([0, 1], dtype=dtypes.uint32)).mean().dtype == dtypes.float32 assert (Tensor([0, 1], dtype=dtypes.uint64)).mean().dtype == dtypes.float32 assert (Tensor([0, 1], dtype=dtypes.float16)).mean().dtype == dtypes.float16 - #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16 + assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16 + assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).mean().dtype == dtypes.fp8e4m3 + assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).mean().dtype == dtypes.fp8e5m2 assert (Tensor([0, 1], dtype=dtypes.float32)).mean().dtype == dtypes.float32 assert (Tensor([0, 1], dtype=dtypes.float64)).mean().dtype == dtypes.float64 @@ -740,7 +776,9 @@ class TestAutoCastType(unittest.TestCase): assert (Tensor([0, 1], dtype=dtypes.uint32)).cumsum(0).dtype == dtypes.uint32 assert (Tensor([0, 1], dtype=dtypes.uint64)).cumsum(0).dtype == dtypes.uint64 assert (Tensor([0, 1], dtype=dtypes.float16)).cumsum(0).dtype == dtypes.float16 - #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16 + assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16 + assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).cumsum().dtype == dtypes.fp8e4m3 + assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).cumsum().dtype == dtypes.fp8e5m2 assert (Tensor([0, 1], dtype=dtypes.float32)).cumsum(0).dtype == dtypes.float32 assert (Tensor([0, 1], dtype=dtypes.float64)).cumsum(0).dtype == dtypes.float64 @@ -801,10 +839,10 @@ class TestAutoCastType(unittest.TestCase): def test_gradient_dtype(self): old_default_float = dtypes.default_float - for default_dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]: + for default_dtype in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]: if not is_dtype_supported(default_dtype): continue dtypes.default_float = default_dtype - for dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]: + for dtype in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]: if not is_dtype_supported(dtype): continue if DEBUG >= 2: print(f"testing {default_dtype=}, {dtype=}") diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py index a01246fb68..7d051723d1 100644 --- a/test/test_dtype_alu.py +++ b/test/test_dtype_alu.py @@ 
-3,7 +3,7 @@ import unittest from tinygrad import Tensor, dtypes, Device import operator import numpy as np -from hypothesis import given, strategies as strat, settings, HealthCheck +from hypothesis import given, strategies as strat, settings, HealthCheck, assume from tinygrad.dtype import DType from tinygrad.helpers import CI, getenv from tinygrad.engine.realize import run_schedule @@ -48,6 +48,8 @@ class ht: float32 = strat.floats(width=32, allow_subnormal=False) float16 = strat.floats(width=16, allow_subnormal=False) bfloat16 = strat.floats(width=16, allow_subnormal=False) + fp8e4m3 = strat.sampled_from([0.0, 0.1, 0.5, 1.0, -0.1, -0.5, -1.0]) + fp8e5m2 = strat.sampled_from([0.0, 0.1, 0.5, 1.0, -0.1, -0.5, -1.0]) uint8 = strat.integers(0, 255) uint16 = strat.integers(0, 65535) uint32 = strat.integers(0, 2**32-1) @@ -64,7 +66,8 @@ def universal_test(a, b, dtype, op): if not isinstance(op, tuple): op = (op, op) tensor_value = (op[0](Tensor([a], dtype=dtype), Tensor([b], dtype=dtype))).numpy() numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype)), np.array([b]).astype(_to_np_dtype(dtype))) - if dtype is dtypes.bfloat16: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2) + if dtype in dtypes.fp8s: np.testing.assert_allclose(tensor_value, numpy_value, atol=0.5, rtol=1e-2) + elif dtype is dtypes.bfloat16: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2) elif dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-10) else: np.testing.assert_equal(tensor_value, numpy_value) @@ -76,7 +79,8 @@ def universal_test_unary(a, dtype, op): run_schedule(sched) tensor_value = out.numpy() numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype))) - if dtype in (*dtypes_float, dtypes.bfloat16): + if dtype in dtypes.fp8s: np.testing.assert_allclose(tensor_value, numpy_value, atol=2, rtol=1e-2) + elif dtype in (*dtypes_float, dtypes.bfloat16): np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2) else: np.testing.assert_equal(tensor_value, numpy_value) if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends @@ -114,6 +118,26 @@ class TestDTypeALU(unittest.TestCase): @given(ht.bfloat16, ht.bfloat16, strat.sampled_from(binary_operations)) def test_bfloat16(self, a, b, op): universal_test(a, b, dtypes.bfloat16, op) + @unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3, Device.DEFAULT), f"no fp8e4m3 on {Device.DEFAULT}") + @given(ht.fp8e4m3, ht.fp8e4m3, strat.sampled_from(binary_operations)) + def test_fp8e4m3(self, a, b, op): universal_test(a, b, dtypes.fp8e4m3, op) + + @unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2, Device.DEFAULT), f"no fp8e5m2 on {Device.DEFAULT}") + @given(ht.fp8e5m2, ht.fp8e5m2, strat.sampled_from(binary_operations)) + def test_fp8e5m2(self, a, b, op): universal_test(a, b, dtypes.fp8e5m2, op) + + @unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3, Device.DEFAULT), f"no fp8e4m3 on {Device.DEFAULT}") + @given(ht.fp8e4m3, strat.sampled_from(unary_operations)) + def test_fp8e4m3_unary(self, a, op): + if (op[1] == np.reciprocal or op[1] == np.log): assume(a != 0.0) # reciprocal(0) and log(0) are undefined + universal_test_unary(a, dtypes.fp8e4m3, op) + + @unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2, Device.DEFAULT), f"no fp8e5m2 on {Device.DEFAULT}") + @given(ht.fp8e5m2, strat.sampled_from(unary_operations)) + def test_fp8e5m2_unary(self, a, op): + if (op[1] == np.reciprocal or op[1] == np.log): assume(a != 0.0) # reciprocal(0) and 
log(0) are undefined + universal_test_unary(a, dtypes.fp8e5m2, op) + @given(ht.float32, strat.sampled_from(unary_operations)) def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index c07f08047b..7ae1b99b80 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -41,6 +41,7 @@ def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axi np_c = np_a @ np_b if dtype_in == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3 elif dtype_in == dtypes.bfloat16: tc_atol, tc_rtol = 1e-2, 1e-2 + elif dtype_in in dtypes.fp8s: tc_atol, tc_rtol = 1e-1, 1e-2 else: tc_atol, tc_rtol = 5e-3, 1e-4 np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol) @@ -62,6 +63,13 @@ def helper_tc_ensure_uops_and_opts_count(N: int, M:int, K:int, dtype_in:DType, d assert wmmas == 0, "tensor core is incorrectly triggered" assert tcs == 0, "tensor core opt is incorrectly included" +def is_emulated(tc) -> bool: + return (getenv("EMULATE_CUDA_SM89") or getenv("EMULATE_CUDA") or getenv("EMULATE_INTEL") or getenv("EMULATE_METAL") + or getenv("EMULATE_AMD_MFMA") or getenv("EMULATE_AMD")) and \ + ((tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16) or \ + (tc.dtype_in == dtypes.fp8e4m3 or tc.dtype_out == dtypes.fp8e4m3) or \ + (tc.dtype_in == dtypes.fp8e5m2 or tc.dtype_out == dtypes.fp8e5m2)) + class TestLinearizer(unittest.TestCase): def test_arg_dedup(self): # NOTE: this realize exists because Tensor.numpy calls .contiguous() internally @@ -1024,7 +1032,8 @@ class TestLinearizer(unittest.TestCase): def test_sum_acc_dtype(self): for tensor_dtype, acc_dtype in ( - (dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)): + (dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float), + (dtypes.fp8e4m3, dtypes.float), (dtypes.fp8e5m2, dtypes.float)): if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype): a = Tensor([1, 2, 3], dtype=tensor_dtype).sum() k = Kernel(a.schedule()[-1].ast) @@ -1046,6 +1055,10 @@ class TestLinearizer(unittest.TestCase): (dtypes.float16, dtypes.float16, dtypes.float16), (dtypes.bfloat16, dtypes.bfloat16, dtypes.bfloat16), (dtypes.float, dtypes.float16, dtypes.float16), + (dtypes.fp8e5m2, dtypes.fp8e5m2, dtypes.fp8e5m2), + (dtypes.fp8e4m3, dtypes.fp8e4m3, dtypes.fp8e4m3), + (dtypes.fp8e4m3, None, dtypes.float), + (dtypes.fp8e5m2, None, dtypes.float), ) for tensor_dtype, acc_dtype, expected_dtype in tests: if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype) and is_dtype_supported(expected_dtype): @@ -1060,8 +1073,7 @@ class TestLinearizer(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_tensor_cores(self): for tc in Device[Device.DEFAULT].renderer.tensor_cores: - if (getenv("EMULATE_CUDA") or getenv("EMULATE_INTEL") or getenv("EMULATE_METAL") or getenv("EMULATE_AMD_MFMA") or getenv("EMULATE_AMD")) and \ - (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue + if is_emulated(tc): continue if CI and Device.DEFAULT in ("METAL", "AMD") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue # for AMX, tc.dims[2] == 1 so reduceop is None thus tensor_cores are not triggered helper_tc_allclose(tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0) @@ 
-1089,8 +1101,7 @@ class TestLinearizer(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_tensor_cores_padded(self): for tc in Device[Device.DEFAULT].renderer.tensor_cores: - if (getenv("EMULATE_CUDA") or getenv("EMULATE_METAL") or getenv("EMULATE_AMD_MFMA") or getenv("EMULATE_AMD")) and \ - (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue + if is_emulated(tc): continue if CI and Device.DEFAULT in ("METAL", "AMD") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue pad = 1 @@ -1117,7 +1128,8 @@ class TestLinearizer(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_tensor_cores_multi_reduce(self): for tc in Device[Device.DEFAULT].renderer.tensor_cores: - if tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16: continue + if tc.dtype_in in [dtypes.bfloat16, *dtypes.fp8s] or tc.dtype_out in [dtypes.bfloat16, *dtypes.fp8s]: + continue # this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes golden_result = None for axis in range(9): diff --git a/test/test_ops.py b/test/test_ops.py index 9a078b74d1..df47b8f675 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -3003,6 +3003,17 @@ class TestOpsUint8(unittest.TestCase): lambda x: x.type(torch.uint8).min(), lambda x: x.cast(dtypes.uint8).min(), forward_only=True, vals=[[0, 128, 255, 64, 32, 16]]) +@unittest.skipUnless("CUDA" in Device.get_available_devices() and Device.DEFAULT == "PYTHON" and getenv("EMULATE_CUDA_SM89"), + "only for emulated CUDA") +class TestOpsFp8s(unittest.TestCase): + def _compare_to_cuda(self, shp_a, shp_b, op, dtype): + a = Tensor.rand(shp_a, dtype=dtype) + b = Tensor.rand(shp_b, dtype=dtype) + np.testing.assert_equal(op(a, b).numpy(), op(a.to("CUDA"), b.to("CUDA")).numpy()) + + def test_gemm_fp8e4m3(self): self._compare_to_cuda((64, 64), (64, 64), lambda x, y: x.matmul(y), dtypes.fp8e4m3) + def test_gemm_fp8e5m2(self): self._compare_to_cuda((64, 64), (64, 64), lambda x, y: x.matmul(y), dtypes.fp8e5m2) + if __name__ == '__main__': np.random.seed(1337) unittest.main(verbosity=2) diff --git a/test/test_randomness.py b/test/test_randomness.py index b130aa733f..272df9e7ba 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -193,6 +193,26 @@ class TestRandomness(unittest.TestCase): assert nx[nx == 0].size > 0 equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.bfloat16).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N)) + @unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), "need fp8e4m3 support") + def test_rand_fp8e4m3(self): + N = 128 + x = Tensor.rand((2, N, N), dtype=dtypes.fp8e4m3) + assert x.dtype == dtypes.fp8e4m3 + nx = x.numpy() + assert nx[nx == 1].size == 0 + assert nx[nx == 0].size > 0 + equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.fp8e4m3).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N)) + + @unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), "need fp8e5m2 support") + def test_rand_fp8e5m2(self): + N = 128 + x = Tensor.rand((2, N, N), dtype=dtypes.fp8e5m2) + assert x.dtype == dtypes.fp8e5m2 + nx = x.numpy() + assert nx[nx == 1].size == 0 + assert nx[nx == 0].size > 0 + equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.fp8e5m2).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N)) + def test_rand_like(self): empty = Tensor.empty((80, 44)) rand = 
Tensor.rand_like(empty) diff --git a/tinygrad/device.py b/tinygrad/device.py index adbeefb70b..1b4572f263 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -332,6 +332,8 @@ def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool: if dtype == dtypes.bfloat16: # NOTE: this requires bf16 buffer support return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX")) + if dtype in dtypes.fp8s: + return (device in {"CUDA", "NV"} and not CI and not getenv("PTX")) or (device in {"PYTHON"} and getenv("EMULATE_CUDA_SM89")) if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short, dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half] # for CI GPU and OSX, cl_khr_fp16 isn't supported diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index c2cbc5dc31..e0f2ac386f 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -112,7 +112,8 @@ class dtypes: def finfo(dtype:DType) -> tuple[int, int]: """(exponent, mantissa)""" if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type") - return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52)}[dtype] + return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52), + dtypes.fp8e5m2: (5, 2), dtypes.fp8e4m3: (4, 3)}[dtype] @staticmethod def fields() -> dict[str, DType]: return DTYPES_DICT void: Final[DType] = DType.new(-1, 0, "void", None) @@ -125,11 +126,13 @@ class dtypes: uint32: Final[DType] = DType.new(6, 4, "unsigned int", 'I') int64: Final[DType] = DType.new(7, 8, "long", 'q') uint64: Final[DType] = DType.new(8, 8, "unsigned long", 'Q') - float16: Final[DType] = DType.new(9, 2, "half", 'e') + fp8e4m3: Final[DType] = DType.new(9, 1, "float8_e4m3", None) + fp8e5m2: Final[DType] = DType.new(10, 1, "float8_e5m2", None) + float16: Final[DType] = DType.new(11, 2, "half", 'e') # bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16 - bfloat16: Final[DType] = DType.new(10, 2, "__bf16", None) - float32: Final[DType] = DType.new(11, 4, "float", 'f') - float64: Final[DType] = DType.new(12, 8, "double", 'd') + bfloat16: Final[DType] = DType.new(12, 2, "__bf16", None) + float32: Final[DType] = DType.new(13, 4, "float", 'f') + float64: Final[DType] = DType.new(14, 8, "double", 'd') # dtype aliases half = float16; float = float32; double = float64 # noqa: E702 @@ -145,7 +148,8 @@ class dtypes: default_float: ClassVar[DType] = float32 default_int: ClassVar[DType] = int32 - floats = (float16, bfloat16, float32, float64) + floats = (fp8e4m3, fp8e5m2, float16, bfloat16, float32, float64) + fp8s = (fp8e4m3, fp8e5m2) uints = (uint8, uint16, uint32, uint64) sints = (int8, int16, int32, int64) ints = uints + sints @@ -161,8 +165,9 @@ def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html # we don't support weak type and complex type promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], - dtypes.int64: [dtypes.float16, dtypes.bfloat16], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32], - dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16], + dtypes.int64: [dtypes.fp8e5m2, dtypes.fp8e4m3], dtypes.uint8: [dtypes.int16, dtypes.uint16], 
dtypes.uint16: [dtypes.int32, dtypes.uint32], + dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.fp8e5m2, dtypes.fp8e4m3], + dtypes.fp8e5m2: [dtypes.float16, dtypes.bfloat16], dtypes.fp8e4m3: [dtypes.float16, dtypes.bfloat16], dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], } @functools.cache @@ -170,6 +175,9 @@ def _get_recursive_parents(dtype:DType) -> set[DType]: return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64} @functools.cache def least_upper_dtype(*ds:DType) -> DType: + from tinygrad.device import is_dtype_supported + if not is_dtype_supported(dtypes.fp8e4m3) and not is_dtype_supported(dtypes.fp8e5m2): + promo_lattice[dtypes.int64] = promo_lattice[dtypes.uint64] = [dtypes.float16, dtypes.bfloat16] return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0] def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 809cdcb3db..41a8fd4982 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -37,7 +37,8 @@ base_rewrite = PatternMatcher([ (UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"), (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"), # consts are rendered to larger type and casted - (UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"), + (UPat(Ops.CONST, (*dtypes.fp8s, dtypes.bfloat16, dtypes.half), name="x"), + lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"), (UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}u')})"), (UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, x.arg)})"), # default const render @@ -331,14 +332,20 @@ class CUDARenderer(CStyleLanguage): (dtypes.half,dtypes.half)]] tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts, swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]] + tc_81632_f8 = [TensorCore(dims=(8,16,32), threads=32, elements_per_thread=(16,8,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts, + swizzle=(((7,8,2,3,4),(0,1,10,5,6,11,9)), ((7,8,10,0,1),(2,3,4,11,5,6,9)))) + for di,do in [(dtypes.fp8e4m3,dtypes.float),(dtypes.fp8e5m2,dtypes.float)]] tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts, swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))] + tc_sm89 = tc_81616 + tc_8168_f16 + tc_81632_f8 tc_sm80 = tc_81616 + tc_8168_f16 if getenv("ALLOW_TF32", 0): tc_sm80 += tc_8168_tf32 tc_sm75 = tc_8168_f16 - def __init__(self, arch:str): - self.tensor_cores, self.arch = CUDARenderer.tc_sm80 if int(arch[3:]) >= 80 else CUDARenderer.tc_sm75 if int(arch[3:]) >= 75 else [], arch + def __init__(self, arch: str): + self.arch = arch + tensor_cores_map = {89: CUDARenderer.tc_sm89, 80: CUDARenderer.tc_sm80, 75: CUDARenderer.tc_sm75} + self.tensor_cores = next((tc for version, tc in sorted(tensor_cores_map.items(), reverse=True) if int(arch[3:]) >= version), 
[]) def __reduce__(self): return self.__class__, (self.arch,) # language options @@ -355,7 +362,19 @@ class CUDARenderer(CStyleLanguage): Ops.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})", Ops.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" } - type_map = {dtypes.bfloat16: "nv_bfloat16"} + type_map = {dtypes.bfloat16: "nv_bfloat16", dtypes.fp8e4m3: "__nv_fp8_e4m3", dtypes.fp8e5m2: "__nv_fp8_e5m2"} + + @staticmethod + def __create_fp8_patterns(dtype) -> list: + return [(UPat(Ops.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtype), UPat.var("y", dtype=dtype))), + lambda b, x, y, dtype=dtype: UOp(Ops.WHERE, dtype=dtypes.float, src=(b, x.cast(dtypes.float), y.cast(dtypes.float))).cast(dtype)), + (UPat(GroupOp.ALU, dtype=dtype, name="x"), + lambda x, dtype=dtype: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtype)), + (UPat(GroupOp.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtype), UPat.var("y", dtype=dtype))), + lambda alu, x, y, dtype=dtype: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)), + (UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtype, name="x"), + lambda x, dtype=dtype: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtype)),] + extra_matcher = PatternMatcher(__create_fp8_patterns(dtypes.fp8e4m3) + __create_fp8_patterns(dtypes.fp8e5m2)) + extra_pm def render_vector_prefix(self, dt:DType) -> str: vec, scal = self.render_dtype(dt), self.render_dtype(dt.scalar()), @@ -367,11 +386,13 @@ class CUDARenderer(CStyleLanguage): prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"] used_dtypes = uops_to_dtypes(uops) + if any(dt.scalar() in dtypes.fp8s for dt in used_dtypes): prefix.append("#include <cuda_fp8.h>") if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#include <cuda_fp16.h>") if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>") - prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}] + prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if (dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}) + or (dt.count in (8,16) and dt.scalar() in dtypes.fp8s)] - dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" } + dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16", dtypes.fp8e4m3: "e4m3", dtypes.fp8e5m2: "e5m2"} dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" } for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes] diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 3d66448c0c..cd5627b80e 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -10,6 +10,10 @@ from tinygrad.device import Compiled, Compiler, Allocator from tinygrad.ops import exec_alu, Ops, UOp, GroupOp from tinygrad.renderer import Renderer from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer +if getenv("EMULATE_CUDA_SM89"): + import numpy as np + from ml_dtypes import float8_e4m3, float8_e5m2 +
truncate.update({dtypes.fp8e4m3: float8_e4m3, dtypes.fp8e5m2: float8_e5m2, float8_e5m2.dtype: np.float32, float8_e4m3.dtype: np.float32}) def _load(m, i): if i is None: return 0.0 @@ -67,10 +71,13 @@ class PythonProgram: assert dtype is not None, f"{uop} is missing a dtype" dl[i] = dtype if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL}: - assert dtype.fmt is not None and isinstance(dtype, PtrDType) + assert isinstance(dtype, PtrDType) if TYPE_CHECKING or sys.version_info < (3, 12): assert dtype.fmt != "e" buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is Ops.DEFINE_LOCAL else pbufs.pop(0) - ul[i] = [buf.cast(dtype.fmt)] * warp_size + if dtype.base in dtypes.fp8s: ul[i] = [np.frombuffer(buf, dtype=dtype.name)] * warp_size + else: + assert dtype.fmt is not None + ul[i] = [buf.cast(dtype.fmt)] * warp_size elif uop is Ops.DEFINE_VAR: ul[i] = [pvals.pop(0)] * warp_size elif uop is Ops.SPECIAL: @@ -100,11 +107,13 @@ class PythonProgram: i = loop_ends[i] + 1 continue elif uop is Ops.VECTORIZE: ul[i] = inp - elif uop in {Ops.CAST, Ops.BITCAST}: - assert dtp[0].fmt and dtype.fmt - pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt - if uop is Ops.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0]))) - else: ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]] + elif uop is Ops.BITCAST: + assert (dtp[0].fmt and dtype.fmt) or (dtp[0] in dtypes.fp8s and dtype) or (dtype in dtypes.fp8s and dtp[0]) + if dtp[0] in dtypes.fp8s: packed = b''.join([truncate.get(dtp[0], lambda dt: dt)(z).tobytes() for z in inp[0]]) + else: packed = struct.pack(str(warp_size) + str(dtp[0].fmt), *inp[0]) + if dtype in dtypes.fp8s: ul[i] = np.frombuffer(packed,dtype=truncate.get(dtype, lambda x: x)).tolist() + else: ul[i] = list(struct.unpack(str(warp_size) + str(dtype.fmt), packed)) + elif uop is Ops.CAST: ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]] elif uop is Ops.LOAD: if dtype.count > 1: ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)] @@ -126,7 +135,8 @@ class PythonProgram: for lane_id in range(WARP_THREADS): for elem_idx in range(NUM_C): # calculate new muls and add to acc (c_i, c_j) = c_map(lane_id, elem_idx) - out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K)) + def cast_fn(x): return truncate.get(x.dtype, lambda x: x)(x) if dtp[0].scalar() in dtypes.fp8s else x + out[elem_idx][goff+lane_id] += sum(cast_fn(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff)) for _k in range(K)) return out # TODO: refactor these to a shared TensorCoreLayout in kernel.py @@ -169,6 +179,11 @@ class PythonProgram: def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4] ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map) + elif arg[1] == (8,16,32): + def a_elem(x, k, row, goff): return x[k%4 + (k//16)*8 + (row//8)*4][goff + (k//4)%4 + (row%8)*4] + def b_elem(x, col, k, goff): return x[k%4 + (k//16)*4][goff + (k//4)%4 + col*4] + ul[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map) + else: raise NotImplementedError(f"unimplemented tensor core {arg}") elif arg[4] == "INTEL": # A (16 elements on 8 threads) @@ -199,6 +214,7 @@ class PythonRenderer(Renderer): if getenv("EMULATE_AMD_MFMA"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores_mfma if getenv("EMULATE_CUDA"): self.device, self.tensor_cores 
= "CUDA", CUDARenderer.tc_sm80 if getenv("EMULATE_CUDA_SM75"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm75 + if getenv("EMULATE_CUDA_SM89"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm89 if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", IntelRenderer.tensor_cores if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CPU", ClangRenderer.tensor_cores diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index b31b4a8288..41ee307d99 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -149,7 +149,8 @@ class Tensor(SimpleMathTrait): if dtype is None: if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True - if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata + if dtype in [dtypes.bfloat16, *dtypes.fp8s]: + data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtype).lazydata else: data = _frompy(data, dtype) elif str(type(data)) == "": import numpy as np @@ -344,7 +345,7 @@ class Tensor(SimpleMathTrait): """ assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}" import numpy as np - if self.dtype.base == dtypes.bfloat16: return self.float().numpy() + if self.dtype.base in [dtypes.bfloat16, *dtypes.fp8s]: return self.float().numpy() if 0 in self.shape: return np.empty(self.shape, dtype=_to_np_dtype(self.dtype.base)) return self._buffer().numpy().reshape(self.shape) @@ -1590,7 +1591,7 @@ class Tensor(SimpleMathTrait): ``` """ ret = self.cast(sum_acc_dtype(self.dtype) if dtype is None else dtype)._reduce(Ops.ADD, axis, keepdim) - return ret.cast(self.dtype) if dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret + return ret.cast(self.dtype) if dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16, *dtypes.fp8s) else ret def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, dtype:DTypeLike|None=None) -> Tensor: """