FP8 support on NVIDIA (#8631)

* squashed fp8 commits

* tensorcore start

* minor changes

* pre-commit

* pylint

* Delete fp8mul.cu

* clean

* small bugfix

* fix test_dtype

* fix test_dtype_alu

* add EMULATE_CUDA_SM89

* fix ci

* fix test_linearizer

* fix test_linearizer

* fix swizzle

* add debug to simple_matmul

* fixed swizzle

* python emulator

* refactor python emulator

* setup fix

* numpy setup

* ml_dtypes only in emulate_cuda_sm89

* fix pylint

* fix tests

* fix mypy

* fix mypy

* fix ruff

* done python emulator

* add acc type

* tests

* mypy

* clean code

* add cuda tensor core tests to CI

* minor fix

* clean test_dtype.py

* clean cstyle.py

* clean test_ops.py

* fix test

* fix test

* whitespaces

* pylint

* pylint

* amd?

* amd?

* amd

* reduce lines

* mockgpu remove

* fix

* ruff

* ruff

* fix mypy

* ruff

* test only for cuda

* fixed formatting

* small fixes

* small fix

* least_upper_dtype if fp8s not supported

* log and reciprocal are supported for fp8s

* ops python fixes

* dtypes.fp8s use

* e4m3 + e5m2 result dtype test

* truncate linter fix

---------

Co-authored-by: pkotzbach <pawkotz@gmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
pkotzbach authored on 2025-04-09 03:54:04 +02:00, committed by GitHub
parent 5d85765327
commit 2c8e4ea865
13 changed files with 203 additions and 45 deletions

View File

@@ -243,6 +243,8 @@ jobs:
DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
DEBUG=2 EMULATE_CUDA=1 ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
DEBUG=2 EMULATE_CUDA_SM75=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
DEBUG=2 EMULATE_CUDA_SM89=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOpsFp8s.test_gemm_fp8e4m3
DEBUG=2 EMULATE_CUDA_SM89=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOpsFp8s.test_gemm_fp8e5m2
PYTHONPATH="." DEBUG=2 EMULATE_CUDA=1 ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
- name: Test emulated INTEL OpenCL tensor cores
run: DEBUG=2 EMULATE_INTEL=1 FORWARD_ONLY=1 PYTHON=1 HALF=1 N=64 python3 ./extra/gemm/simple_matmul.py

View File

@@ -3,9 +3,11 @@ from tinygrad.helpers import getenv
from tinygrad.dtype import _to_np_dtype
from tinygrad import dtypes, Tensor
dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
dtype_in = (dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else
  dtypes.fp8e4m3 if getenv("FP8E4M3") else dtypes.fp8e5m2 if getenv("FP8E5M2") else dtypes.float)
acc_dtype = (dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else
  dtypes.fp8e4m3 if getenv("ACC_FP8E4M3") else dtypes.fp8e5m2 if getenv("ACC_FP8E5M2") else None)
if getenv("INT"): dtype_in, acc_dtype = dtypes.int8, dtypes.int32
if getenv("UINT"): dtype_in, acc_dtype = dtypes.uint8, dtypes.int32
N = getenv("N", 4096)

View File

@@ -70,6 +70,7 @@ setup(name='tinygrad',
"bottle",
"ggml-python",
"capstone",
"ml_dtypes",
"pycocotools",
"boto3",
"pandas"

View File

@@ -39,7 +39,8 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float
if DEBUG >= 2: print(tensor.numpy())
try:
assert tensor.dtype == target_dtype
np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, tol_target_dtype))
np.testing.assert_allclose(tensor.numpy(), target,
rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2, dtypes.fp8e5m2: 1, dtypes.fp8e4m3: 1e-1}.get(target_dtype, tol_target_dtype))
except AssertionError as e:
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
@@ -56,6 +57,7 @@ def _test_cast(a:Tensor, target_dtype:DType):
_test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype))))
def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
if target_dtype in dtypes.fp8s: raise unittest.SkipTest("no test for fp8s bitcast yet")
if getenv("PTX") and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize:
raise unittest.SkipTest("shape changing bitcast of int8 broken on PTX")
_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(_to_np_dtype(target_dtype)).tolist())
@@ -109,7 +111,7 @@ class TestDType(unittest.TestCase):
fields = dtypes.fields()
self.assertIn("float", fields)
self.assertIn("float32", fields)
self.assertEqual(len(fields), 24)
self.assertEqual(len(fields), 26)
self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
self.assertTrue(all(issubclass(_to_np_dtype(value), np.generic) for value in fields.values() if _to_np_dtype(value) is not None))
@@ -205,6 +207,31 @@ class TestBFloat16DTypeCast(unittest.TestCase):
converted = random_values.cast(dtypes.bfloat16).cast(dtypes.float32)
np.testing.assert_allclose(converted.numpy(), random_values.cast(dtypes.float32).numpy(), rtol=1e-2, atol=1e-3)
class TestFp8sDType(unittest.TestCase):
def _float_to_fp8_conversion_test(self, dtype, input_values, expected_values):
test_tensor = Tensor(input_values).cast(dtype).realize()
back_to_float32 = test_tensor.cast(dtypes.float32)
np.testing.assert_equal(tuple(back_to_float32.numpy().tolist()), expected_values)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), "fp8e4m3 not supported")
def test_float_to_fp8e4m3_conversion(self):
self._float_to_fp8_conversion_test(dtypes.fp8e4m3,
[10000000.0, -1.0, 402.0, -300.0, -10000000.0, 20.0, 1.4123, 0.0, math.inf, math.nan],
[448.0, -1.0, 416.0, -288.0, -448.0, 20.0, 1.375, 0.0, 448.0, math.nan])
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), "fp8e5m2 not supported")
def test_float_to_fp8e5m2_conversion(self):
self._float_to_fp8_conversion_test(dtypes.fp8e5m2,
[10000000.0, -1.0, 402.0, -300.0, -10000000.0, 20.0, 1.4123, 0.0, math.inf, math.nan],
[57344.0, -1, 384, -320, -57344.0, 20, 1.5, 0.0, 57344.0, math.nan])
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3) and is_dtype_supported(dtypes.fp8e5m2), "fp8s not supported")
def test_fp8e4m3_plus_fp8e5m2_output_dtype(self):
a = Tensor([1.0, 2.0, 3.0], dtype=dtypes.fp8e4m3)
b = Tensor([1.0, 2.0, 3.0], dtype=dtypes.fp8e5m2)
result = a + b
self.assertEqual(result.dtype, dtypes.half)
class TestHalfDType(TestDType): DTYPE = dtypes.half
class TestFloatDType(TestDType):
@@ -264,6 +291,7 @@ class TestBitCast(unittest.TestCase):
def test_shape_change_bitcast(self, dt1, dt2):
# NOTE: this has to be assume to prevent hypothesis from skipping all samples
assume(dt2 != dtypes.bfloat16 and dt1 != dtypes.bfloat16) # no test for bf16 bitcast yet
assume(dt1 not in dtypes.fp8s and dt2 not in dtypes.fp8s) # no test for fp8 bitcast yet
assume(not (getenv("PTX") and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX
data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
_test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, data.view(_to_np_dtype(dt2)).tolist())
@@ -399,6 +427,10 @@ class TestHelpers(unittest.TestCase):
def test_bf16_is_float(self):
assert dtypes.is_float(dtypes.bfloat16)
def test_fp8s_are_float(self):
assert dtypes.is_float(dtypes.fp8e4m3)
assert dtypes.is_float(dtypes.fp8e5m2)
@given(strat.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), strat.integers(min_value=2, max_value=8))
def test_scalar(self, dtype, amt):
assert dtype.vec(amt).scalar() == dtype
@@ -462,7 +494,7 @@ class TestTypeSpec(unittest.TestCase):
dtypes.default_int = default_int
assert dtypes.default_int == default_int
for default_float in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
for default_float in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
dtypes.default_float = default_float
assert dtypes.default_float == default_float
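
As the loop above exercises, an fp8 type can now be made the default float. A hedged sketch, only meaningful where is_dtype_supported(dtypes.fp8e4m3) holds:

from tinygrad import Tensor, dtypes
old_default = dtypes.default_float
dtypes.default_float = dtypes.fp8e4m3
print(Tensor([1.0, 2.0, 3.0]).dtype)  # dtypes.fp8e4m3
dtypes.default_float = old_default    # restore the previous default
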
@@ -693,7 +725,9 @@ class TestAutoCastType(unittest.TestCase):
assert (Tensor([0, 1], dtype=dtypes.uint32)).sum().dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint64)).sum().dtype == dtypes.uint64
assert (Tensor([0, 1], dtype=dtypes.float16)).sum().dtype == dtypes.float16
#assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).sum().dtype == dtypes.fp8e4m3
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).sum().dtype == dtypes.fp8e5m2
assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
@@ -725,7 +759,9 @@ class TestAutoCastType(unittest.TestCase):
assert (Tensor([0, 1], dtype=dtypes.uint32)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.uint64)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float16)).mean().dtype == dtypes.float16
#assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).mean().dtype == dtypes.fp8e4m3
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).mean().dtype == dtypes.fp8e5m2
assert (Tensor([0, 1], dtype=dtypes.float32)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).mean().dtype == dtypes.float64
@@ -740,7 +776,9 @@ class TestAutoCastType(unittest.TestCase):
assert (Tensor([0, 1], dtype=dtypes.uint32)).cumsum(0).dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint64)).cumsum(0).dtype == dtypes.uint64
assert (Tensor([0, 1], dtype=dtypes.float16)).cumsum(0).dtype == dtypes.float16
#assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).cumsum().dtype == dtypes.fp8e4m3
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).cumsum().dtype == dtypes.fp8e5m2
assert (Tensor([0, 1], dtype=dtypes.float32)).cumsum(0).dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).cumsum(0).dtype == dtypes.float64
@@ -801,10 +839,10 @@ class TestAutoCastType(unittest.TestCase):
def test_gradient_dtype(self):
old_default_float = dtypes.default_float
for default_dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
for default_dtype in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
if not is_dtype_supported(default_dtype): continue
dtypes.default_float = default_dtype
for dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
for dtype in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
if not is_dtype_supported(dtype): continue
if DEBUG >= 2:
print(f"testing {default_dtype=}, {dtype=}")

View File

@@ -3,7 +3,7 @@ import unittest
from tinygrad import Tensor, dtypes, Device
import operator
import numpy as np
from hypothesis import given, strategies as strat, settings, HealthCheck
from hypothesis import given, strategies as strat, settings, HealthCheck, assume
from tinygrad.dtype import DType
from tinygrad.helpers import CI, getenv
from tinygrad.engine.realize import run_schedule
@@ -48,6 +48,8 @@ class ht:
float32 = strat.floats(width=32, allow_subnormal=False)
float16 = strat.floats(width=16, allow_subnormal=False)
bfloat16 = strat.floats(width=16, allow_subnormal=False)
fp8e4m3 = strat.sampled_from([0.0, 0.1, 0.5, 1.0, -0.1, -0.5, -1.0])
fp8e5m2 = strat.sampled_from([0.0, 0.1, 0.5, 1.0, -0.1, -0.5, -1.0])
uint8 = strat.integers(0, 255)
uint16 = strat.integers(0, 65535)
uint32 = strat.integers(0, 2**32-1)
@@ -64,7 +66,8 @@ def universal_test(a, b, dtype, op):
if not isinstance(op, tuple): op = (op, op)
tensor_value = (op[0](Tensor([a], dtype=dtype), Tensor([b], dtype=dtype))).numpy()
numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype)), np.array([b]).astype(_to_np_dtype(dtype)))
if dtype is dtypes.bfloat16: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
if dtype in dtypes.fp8s: np.testing.assert_allclose(tensor_value, numpy_value, atol=0.5, rtol=1e-2)
elif dtype is dtypes.bfloat16: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
elif dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-10)
else: np.testing.assert_equal(tensor_value, numpy_value)
@@ -76,7 +79,8 @@ def universal_test_unary(a, dtype, op):
run_schedule(sched)
tensor_value = out.numpy()
numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype)))
if dtype in (*dtypes_float, dtypes.bfloat16):
if dtype in dtypes.fp8s: np.testing.assert_allclose(tensor_value, numpy_value, atol=2, rtol=1e-2)
elif dtype in (*dtypes_float, dtypes.bfloat16):
np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
else: np.testing.assert_equal(tensor_value, numpy_value)
if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends
@@ -114,6 +118,26 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.bfloat16, ht.bfloat16, strat.sampled_from(binary_operations))
def test_bfloat16(self, a, b, op): universal_test(a, b, dtypes.bfloat16, op)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3, Device.DEFAULT), f"no fp8e4m3 on {Device.DEFAULT}")
@given(ht.fp8e4m3, ht.fp8e4m3, strat.sampled_from(binary_operations))
def test_fp8e4m3(self, a, b, op): universal_test(a, b, dtypes.fp8e4m3, op)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2, Device.DEFAULT), f"no fp8e5m2 on {Device.DEFAULT}")
@given(ht.fp8e5m2, ht.fp8e5m2, strat.sampled_from(binary_operations))
def test_fp8e5m2(self, a, b, op): universal_test(a, b, dtypes.fp8e5m2, op)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3, Device.DEFAULT), f"no fp8e4m3 on {Device.DEFAULT}")
@given(ht.fp8e4m3, strat.sampled_from(unary_operations))
def test_fp8e4m3_unary(self, a, op):
if (op[1] == np.reciprocal or op[1] == np.log): assume(a != 0.0) # reciprocal(0) and log(0) are undefined
universal_test_unary(a, dtypes.fp8e4m3, op)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2, Device.DEFAULT), f"no fp8e5m2 on {Device.DEFAULT}")
@given(ht.fp8e5m2, strat.sampled_from(unary_operations))
def test_fp8e5m2_unary(self, a, op):
if (op[1] == np.reciprocal or op[1] == np.log): assume(a != 0.0) # reciprocal(0) and log(0) are undefined
universal_test_unary(a, dtypes.fp8e5m2, op)
@given(ht.float32, strat.sampled_from(unary_operations))
def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)

View File

@@ -41,6 +41,7 @@ def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axi
np_c = np_a @ np_b
if dtype_in == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3
elif dtype_in == dtypes.bfloat16: tc_atol, tc_rtol = 1e-2, 1e-2
elif dtype_in in dtypes.fp8s: tc_atol, tc_rtol = 1e-1, 1e-2
else: tc_atol, tc_rtol = 5e-3, 1e-4
np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)
@@ -62,6 +63,13 @@ def helper_tc_ensure_uops_and_opts_count(N: int, M:int, K:int, dtype_in:DType, d
assert wmmas == 0, "tensor core is incorrectly triggered"
assert tcs == 0, "tensor core opt is incorrectly included"
def is_emulated(tc) -> bool:
return (getenv("EMULATE_CUDA_SM89") or getenv("EMULATE_CUDA") or getenv("EMULATE_INTEL") or getenv("EMULATE_METAL")
or getenv("EMULATE_AMD_MFMA") or getenv("EMULATE_AMD")) and \
((tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16) or \
(tc.dtype_in == dtypes.fp8e4m3 or tc.dtype_out == dtypes.fp8e4m3) or \
(tc.dtype_in == dtypes.fp8e5m2 or tc.dtype_out == dtypes.fp8e5m2))
class TestLinearizer(unittest.TestCase):
def test_arg_dedup(self):
# NOTE: this realize exists because Tensor.numpy calls .contiguous() internally
@@ -1024,7 +1032,8 @@ class TestLinearizer(unittest.TestCase):
def test_sum_acc_dtype(self):
for tensor_dtype, acc_dtype in (
(dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
(dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float),
(dtypes.fp8e4m3, dtypes.float), (dtypes.fp8e5m2, dtypes.float)):
if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
k = Kernel(a.schedule()[-1].ast)
@@ -1046,6 +1055,10 @@ class TestLinearizer(unittest.TestCase):
(dtypes.float16, dtypes.float16, dtypes.float16),
(dtypes.bfloat16, dtypes.bfloat16, dtypes.bfloat16),
(dtypes.float, dtypes.float16, dtypes.float16),
(dtypes.fp8e5m2, dtypes.fp8e5m2, dtypes.fp8e5m2),
(dtypes.fp8e4m3, dtypes.fp8e4m3, dtypes.fp8e4m3),
(dtypes.fp8e4m3, None, dtypes.float),
(dtypes.fp8e5m2, None, dtypes.float),
)
for tensor_dtype, acc_dtype, expected_dtype in tests:
if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype) and is_dtype_supported(expected_dtype):
@@ -1060,8 +1073,7 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if (getenv("EMULATE_CUDA") or getenv("EMULATE_INTEL") or getenv("EMULATE_METAL") or getenv("EMULATE_AMD_MFMA") or getenv("EMULATE_AMD")) and \
(tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
if is_emulated(tc): continue
if CI and Device.DEFAULT in ("METAL", "AMD") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
# for AMX, tc.dims[2] == 1 so reduceop is None thus tensor_cores are not triggered
helper_tc_allclose(tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
@@ -1089,8 +1101,7 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_padded(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if (getenv("EMULATE_CUDA") or getenv("EMULATE_METAL") or getenv("EMULATE_AMD_MFMA") or getenv("EMULATE_AMD")) and \
(tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
if is_emulated(tc): continue
if CI and Device.DEFAULT in ("METAL", "AMD") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
pad = 1
@@ -1117,7 +1128,8 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_multi_reduce(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16: continue
if tc.dtype_in in [dtypes.bfloat16, *dtypes.fp8s] or tc.dtype_out in [dtypes.bfloat16, *dtypes.fp8s]:
continue
# this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes
golden_result = None
for axis in range(9):

View File

@@ -3003,6 +3003,17 @@ class TestOpsUint8(unittest.TestCase):
lambda x: x.type(torch.uint8).min(),
lambda x: x.cast(dtypes.uint8).min(), forward_only=True, vals=[[0, 128, 255, 64, 32, 16]])
@unittest.skipUnless("CUDA" in Device.get_available_devices() and Device.DEFAULT == "PYTHON" and getenv("EMULATE_CUDA_SM89"),
"only for emulated CUDA")
class TestOpsFp8s(unittest.TestCase):
def _compare_to_cuda(self, shp_a, shp_b, op, dtype):
a = Tensor.rand(shp_a, dtype=dtype)
b = Tensor.rand(shp_b, dtype=dtype)
np.testing.assert_equal(op(a, b).numpy(), op(a.to("CUDA"), b.to("CUDA")).numpy())
def test_gemm_fp8e4m3(self): self._compare_to_cuda((64, 64), (64, 64), lambda x, y: x.matmul(y), dtypes.fp8e4m3)
def test_gemm_fp8e5m2(self): self._compare_to_cuda((64, 64), (64, 64), lambda x, y: x.matmul(y), dtypes.fp8e5m2)
if __name__ == '__main__':
np.random.seed(1337)
unittest.main(verbosity=2)

View File

@@ -193,6 +193,26 @@ class TestRandomness(unittest.TestCase):
assert nx[nx == 0].size > 0
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.bfloat16).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), "need fp8e4m3 support")
def test_rand_fp8e4m3(self):
N = 128
x = Tensor.rand((2, N, N), dtype=dtypes.fp8e4m3)
assert x.dtype == dtypes.fp8e4m3
nx = x.numpy()
assert nx[nx == 1].size == 0
assert nx[nx == 0].size > 0
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.fp8e4m3).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), "need fp8e5m2 support")
def test_rand_fp8e5m2(self):
N = 128
x = Tensor.rand((2, N, N), dtype=dtypes.fp8e5m2)
assert x.dtype == dtypes.fp8e5m2
nx = x.numpy()
assert nx[nx == 1].size == 0
assert nx[nx == 0].size > 0
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.fp8e5m2).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
def test_rand_like(self):
empty = Tensor.empty((80, 44))
rand = Tensor.rand_like(empty)

View File

@@ -332,6 +332,8 @@ def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
if dtype == dtypes.bfloat16:
# NOTE: this requires bf16 buffer support
return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
if dtype in dtypes.fp8s:
return (device in {"CUDA", "NV"} and not CI and not getenv("PTX")) or (device in {"PYTHON"} and getenv("EMULATE_CUDA_SM89"))
if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
# for CI GPU and OSX, cl_khr_fp16 isn't supported
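
With the rule above, fp8 reports as supported on real CUDA/NV outside CI (and not under PTX), or on the PYTHON backend when emulating sm_89. A hedged check:

from tinygrad import dtypes
from tinygrad.device import is_dtype_supported
# True on a CUDA/NV box outside CI without PTX=1, or with PYTHON=1 EMULATE_CUDA_SM89=1
print(is_dtype_supported(dtypes.fp8e4m3, "CUDA"), is_dtype_supported(dtypes.fp8e5m2, "PYTHON"))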

View File

@@ -112,7 +112,8 @@ class dtypes:
def finfo(dtype:DType) -> tuple[int, int]:
"""(exponent, mantissa)"""
if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")
return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52)}[dtype]
return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52),
dtypes.fp8e5m2: (5, 2), dtypes.fp8e4m3: (4, 3)}[dtype]
@staticmethod
def fields() -> dict[str, DType]: return DTYPES_DICT
void: Final[DType] = DType.new(-1, 0, "void", None)
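
The (exponent, mantissa) pairs added to finfo line up with the saturation values asserted in TestFp8sDType earlier in this diff. A hedged worked example (e4m3 follows the OCP convention where only the all-ones mantissa of the top exponent is NaN, while e5m2 reserves its whole top exponent for inf/NaN):

def fp8_max(exp_bits: int, man_bits: int, reserves_inf: bool) -> float:
  bias = 2**(exp_bits - 1) - 1
  top_exp = (2**exp_bits - 2 if reserves_inf else 2**exp_bits - 1) - bias
  top_man = 2 - 2**-man_bits if reserves_inf else 2 - 2**-(man_bits - 1)
  return 2.0**top_exp * top_man
print(fp8_max(4, 3, False), fp8_max(5, 2, True))  # 448.0 57344.0
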
@@ -125,11 +126,13 @@ class dtypes:
uint32: Final[DType] = DType.new(6, 4, "unsigned int", 'I')
int64: Final[DType] = DType.new(7, 8, "long", 'q')
uint64: Final[DType] = DType.new(8, 8, "unsigned long", 'Q')
float16: Final[DType] = DType.new(9, 2, "half", 'e')
fp8e4m3: Final[DType] = DType.new(9, 1, "float8_e4m3", None)
fp8e5m2: Final[DType] = DType.new(10, 1, "float8_e5m2", None)
float16: Final[DType] = DType.new(11, 2, "half", 'e')
# bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
bfloat16: Final[DType] = DType.new(10, 2, "__bf16", None)
float32: Final[DType] = DType.new(11, 4, "float", 'f')
float64: Final[DType] = DType.new(12, 8, "double", 'd')
bfloat16: Final[DType] = DType.new(12, 2, "__bf16", None)
float32: Final[DType] = DType.new(13, 4, "float", 'f')
float64: Final[DType] = DType.new(14, 8, "double", 'd')
# dtype aliases
half = float16; float = float32; double = float64 # noqa: E702
@@ -145,7 +148,8 @@ class dtypes:
default_float: ClassVar[DType] = float32
default_int: ClassVar[DType] = int32
floats = (float16, bfloat16, float32, float64)
floats = (fp8e4m3, fp8e5m2, float16, bfloat16, float32, float64)
fp8s = (fp8e4m3, fp8e5m2)
uints = (uint8, uint16, uint32, uint64)
sints = (int8, int16, int32, int64)
ints = uints + sints
@@ -161,8 +165,9 @@ def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType)
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
# we don't support weak type and complex type
promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
dtypes.int64: [dtypes.float16, dtypes.bfloat16], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
dtypes.int64: [dtypes.fp8e5m2, dtypes.fp8e4m3], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.fp8e5m2, dtypes.fp8e4m3],
dtypes.fp8e5m2: [dtypes.float16, dtypes.bfloat16], dtypes.fp8e4m3: [dtypes.float16, dtypes.bfloat16],
dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
@functools.cache
@@ -170,6 +175,9 @@ def _get_recursive_parents(dtype:DType) -> set[DType]:
return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
@functools.cache
def least_upper_dtype(*ds:DType) -> DType:
from tinygrad.device import is_dtype_supported
if not is_dtype_supported(dtypes.fp8e4m3) and not is_dtype_supported(dtypes.fp8e5m2):
promo_lattice[dtypes.int64] = promo_lattice[dtypes.uint64] = [dtypes.float16, dtypes.bfloat16]
return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float)
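
The new lattice edges explain the result dtype checked in test_fp8e4m3_plus_fp8e5m2_output_dtype above: both fp8 types promote to float16 and bfloat16, and float16 is their least upper bound. A hedged sketch:

from tinygrad import dtypes
from tinygrad.dtype import least_upper_dtype
print(least_upper_dtype(dtypes.fp8e4m3, dtypes.fp8e5m2))  # dtypes.half
print(least_upper_dtype(dtypes.fp8e4m3, dtypes.float16))  # dtypes.half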

View File

@@ -37,7 +37,8 @@ base_rewrite = PatternMatcher([
(UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
(UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
# consts are rendered to larger type and casted
(UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
(UPat(Ops.CONST, (*dtypes.fp8s, dtypes.bfloat16, dtypes.half), name="x"),
lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
(UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}u')})"),
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, x.arg)})"),
# default const render
@@ -331,14 +332,20 @@ class CUDARenderer(CStyleLanguage):
(dtypes.half,dtypes.half)]]
tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
tc_81632_f8 = [TensorCore(dims=(8,16,32), threads=32, elements_per_thread=(16,8,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
swizzle=(((7,8,2,3,4),(0,1,10,5,6,11,9)), ((7,8,10,0,1),(2,3,4,11,5,6,9))))
for di,do in [(dtypes.fp8e4m3,dtypes.float),(dtypes.fp8e5m2,dtypes.float)]]
tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))]
tc_sm89 = tc_81616 + tc_8168_f16 + tc_81632_f8
tc_sm80 = tc_81616 + tc_8168_f16
if getenv("ALLOW_TF32", 0): tc_sm80 += tc_8168_tf32
tc_sm75 = tc_8168_f16
def __init__(self, arch:str):
self.tensor_cores, self.arch = CUDARenderer.tc_sm80 if int(arch[3:]) >= 80 else CUDARenderer.tc_sm75 if int(arch[3:]) >= 75 else [], arch
def __init__(self, arch: str):
self.arch = arch
tensor_cores_map = {89: CUDARenderer.tc_sm89, 80: CUDARenderer.tc_sm80, 75: CUDARenderer.tc_sm75}
self.tensor_cores = next((tc for version, tc in sorted(tensor_cores_map.items(), reverse=True) if int(arch[3:]) >= version), [])
def __reduce__(self): return self.__class__, (self.arch,)
# language options
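
The rewritten __init__ picks the newest tensor-core set whose compute capability the arch string satisfies. A hedged standalone illustration of just the selection logic, with placeholder strings standing in for the tc_* lists:

tc_map = {89: "tc_sm89", 80: "tc_sm80", 75: "tc_sm75"}
def pick_tcs(arch: str):
  return next((tc for version, tc in sorted(tc_map.items(), reverse=True) if int(arch[3:]) >= version), [])
print(pick_tcs("sm_90"), pick_tcs("sm_86"), pick_tcs("sm_75"), pick_tcs("sm_70"))  # tc_sm89 tc_sm80 tc_sm75 []
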
@@ -355,7 +362,19 @@ class CUDARenderer(CStyleLanguage):
Ops.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",
Ops.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
Ops.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" }
type_map = {dtypes.bfloat16: "nv_bfloat16"}
type_map = {dtypes.bfloat16: "nv_bfloat16", dtypes.fp8e4m3: "__nv_fp8_e4m3", dtypes.fp8e5m2: "__nv_fp8_e5m2"}
@staticmethod
def __create_fp8_patterns(dtype) -> list:
return [(UPat(Ops.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtype), UPat.var("y", dtype=dtype))),
lambda b, x, y, dtype=dtype: UOp(Ops.WHERE, dtype=dtypes.float, src=(b, x.cast(dtypes.float), y.cast(dtypes.float))).cast(dtype)),
(UPat(GroupOp.ALU, dtype=dtype, name="x"),
lambda x, dtype=dtype: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtype)),
(UPat(GroupOp.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtype), UPat.var("y", dtype=dtype))),
lambda alu, x, y, dtype=dtype: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
(UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtype, name="x"),
lambda x, dtype=dtype: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtype)),]
extra_matcher = PatternMatcher(__create_fp8_patterns(dtypes.fp8e4m3) + __create_fp8_patterns(dtypes.fp8e5m2)) + extra_pm
def render_vector_prefix(self, dt:DType) -> str:
vec, scal = self.render_dtype(dt), self.render_dtype(dt.scalar()),
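
Conceptually, the __create_fp8_patterns rules above keep fp8 out of the ALU: operands are cast up to float32, the op runs in float, and only the result is cast back to the fp8 dtype (rendered with the __nv_fp8_e4m3 / __nv_fp8_e5m2 types from type_map). A minimal Python sketch of that rewrite shape, with hypothetical cast helpers rather than real UOps:

def fp8_alu_via_float(op, to_float, to_fp8, *operands):
  # e.g. an fp8 add becomes cast(cast(a, float) + cast(b, float), fp8)
  return to_fp8(op(*[to_float(v) for v in operands]))
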
@@ -367,11 +386,13 @@ class CUDARenderer(CStyleLanguage):
prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
used_dtypes = uops_to_dtypes(uops)
if any(dt.scalar() in dtypes.fp8s for dt in used_dtypes): prefix.append("#include <cuda_fp8.h>")
if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#include <cuda_fp16.h>")
if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]
prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if (dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16})
or (dt.count in (8,16) and dt.scalar() in dtypes.fp8s)]
dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16", dtypes.fp8e4m3: "e4m3", dtypes.fp8e5m2: "e5m2"}
dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]

View File

@@ -10,6 +10,10 @@ from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.ops import exec_alu, Ops, UOp, GroupOp
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer
if getenv("EMULATE_CUDA_SM89"):
import numpy as np
from ml_dtypes import float8_e4m3, float8_e5m2
truncate.update({dtypes.fp8e4m3: float8_e4m3, dtypes.fp8e5m2: float8_e5m2, float8_e5m2.dtype: np.float32, float8_e4m3.dtype: np.float32})
def _load(m, i):
if i is None: return 0.0
@@ -67,10 +71,13 @@ class PythonProgram:
assert dtype is not None, f"{uop} is missing a dtype"
dl[i] = dtype
if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL}:
assert dtype.fmt is not None and isinstance(dtype, PtrDType)
assert isinstance(dtype, PtrDType)
if TYPE_CHECKING or sys.version_info < (3, 12): assert dtype.fmt != "e"
buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is Ops.DEFINE_LOCAL else pbufs.pop(0)
ul[i] = [buf.cast(dtype.fmt)] * warp_size
if dtype.base in dtypes.fp8s: ul[i] = [np.frombuffer(buf, dtype=dtype.name)] * warp_size
else:
assert dtype.fmt is not None
ul[i] = [buf.cast(dtype.fmt)] * warp_size
elif uop is Ops.DEFINE_VAR:
ul[i] = [pvals.pop(0)] * warp_size
elif uop is Ops.SPECIAL:
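
Since fp8 has no struct format character, the emulator above views raw buffer bytes through ml_dtypes instead of memoryview.cast. A hedged standalone sketch of that trick (ml_dtypes is added to the testing deps by this commit):

import numpy as np
from ml_dtypes import float8_e4m3
buf = bytearray(4)                           # 4 bytes -> 4 fp8 elements
arr = np.frombuffer(buf, dtype=float8_e4m3)  # writable fp8 view over the same bytes
arr[:] = [0.5, 1.0, 1.5, 2.0]
print(arr.astype(np.float32))                # [0.5 1.  1.5 2. ]
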
@@ -100,11 +107,13 @@ class PythonProgram:
i = loop_ends[i] + 1
continue
elif uop is Ops.VECTORIZE: ul[i] = inp
elif uop in {Ops.CAST, Ops.BITCAST}:
assert dtp[0].fmt and dtype.fmt
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
if uop is Ops.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
else: ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]]
elif uop is Ops.BITCAST:
assert (dtp[0].fmt and dtype.fmt) or (dtp[0] in dtypes.fp8s and dtype) or (dtype in dtypes.fp8s and dtp[0])
if dtp[0] in dtypes.fp8s: packed = b''.join([truncate.get(dtp[0], lambda dt: dt)(z).tobytes() for z in inp[0]])
else: packed = struct.pack(str(warp_size) + str(dtp[0].fmt), *inp[0])
if dtype in dtypes.fp8s: ul[i] = np.frombuffer(packed,dtype=truncate.get(dtype, lambda x: x)).tolist()
else: ul[i] = list(struct.unpack(str(warp_size) + str(dtype.fmt), packed))
elif uop is Ops.CAST: ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]]
elif uop is Ops.LOAD:
if dtype.count > 1:
ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
@@ -126,7 +135,8 @@ class PythonProgram:
for lane_id in range(WARP_THREADS):
for elem_idx in range(NUM_C): # calculate new muls and add to acc
(c_i, c_j) = c_map(lane_id, elem_idx)
out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
def cast_fn(x): return truncate.get(x.dtype, lambda x: x)(x) if dtp[0].scalar() in dtypes.fp8s else x
out[elem_idx][goff+lane_id] += sum(cast_fn(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff)) for _k in range(K))
return out
# TODO: refactor these to a shared TensorCoreLayout in kernel.py
@@ -169,6 +179,11 @@ class PythonProgram:
def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4]
ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
elif arg[1] == (8,16,32):
def a_elem(x, k, row, goff): return x[k%4 + (k//16)*8 + (row//8)*4][goff + (k//4)%4 + (row%8)*4]
def b_elem(x, col, k, goff): return x[k%4 + (k//16)*4][goff + (k//4)%4 + col*4]
ul[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif arg[4] == "INTEL":
# A (16 elements on 8 threads)
@@ -199,6 +214,7 @@ class PythonRenderer(Renderer):
if getenv("EMULATE_AMD_MFMA"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores_mfma
if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm80
if getenv("EMULATE_CUDA_SM75"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm75
if getenv("EMULATE_CUDA_SM89"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm89
if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", IntelRenderer.tensor_cores
if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CPU", ClangRenderer.tensor_cores

View File

@@ -149,7 +149,8 @@ class Tensor(SimpleMathTrait):
if dtype is None:
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True
if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
if dtype in [dtypes.bfloat16, *dtypes.fp8s]:
data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtype).lazydata
else: data = _frompy(data, dtype)
elif str(type(data)) == "<class 'numpy.ndarray'>":
import numpy as np
@@ -344,7 +345,7 @@ class Tensor(SimpleMathTrait):
"""
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
import numpy as np
if self.dtype.base == dtypes.bfloat16: return self.float().numpy()
if self.dtype.base in [dtypes.bfloat16, *dtypes.fp8s]: return self.float().numpy()
if 0 in self.shape: return np.empty(self.shape, dtype=_to_np_dtype(self.dtype.base))
return self._buffer().numpy().reshape(self.shape)
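
Like bfloat16, fp8 has no direct numpy dtype on this path, so numpy() upcasts through float32. A hedged usage sketch, assuming a device where fp8 is supported:

from tinygrad import Tensor, dtypes
t = Tensor([1.0, 2.0, 3.0], dtype=dtypes.fp8e5m2)
print(t.dtype, t.numpy().dtype)  # dtypes.fp8e5m2 float32
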
@@ -1590,7 +1591,7 @@ class Tensor(SimpleMathTrait):
```
"""
ret = self.cast(sum_acc_dtype(self.dtype) if dtype is None else dtype)._reduce(Ops.ADD, axis, keepdim)
return ret.cast(self.dtype) if dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
return ret.cast(self.dtype) if dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16, *dtypes.fp8s) else ret
def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, dtype:DTypeLike|None=None) -> Tensor:
"""