Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
FP8 support on NVIDIA (#8631)
* squashed fp8 commits * tensorcore start * minor changes * pre-commit * pylint * Delete fp8mul.cu * clean * small bugfix * fix test_dtype * fix test_dtype_alu * add EMULATE_CUDA_SM89 * fix ci * fix test_linearizer * fix test_linearizer * fix swizzle * add debug to simple_matmul * fixed swizzle * python emulator * refactor python emulator * setup fix * numpy setup * ml_dtypes only in emulate_cuda_sm89 * fix pylint * fix tests * fix mypy * fix mypy * fix ruff * done python emulator * add acc type * tests * mypy * clean code * add cuda tensor core tests to CI * minor fix * clean test_dtype.py * clean cstyle.py * clean test_ops.py * fix test * fix test * whitespaces * pylint * pylint * amd? * amd? * amd * reduce lines * mockgpu remove * fix * ruff * ruff * fix mypy * ruff * test only for cuda * fixed formatting * small fixes * small fix * least_upper_dtype if fp8s not supported * log and reciprocal are supported for fp8s * ops python fixes * dtypes.fp8s use * e4m3 + e5m2 result dtype test * truncate linter fix

---------

Co-authored-by: pkotzbach <pawkotz@gmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
.github/workflows/test.yml (vendored): 2 lines changed
@@ -243,6 +243,8 @@ jobs:
DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
DEBUG=2 EMULATE_CUDA=1 ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
DEBUG=2 EMULATE_CUDA_SM75=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
DEBUG=2 EMULATE_CUDA_SM89=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOpsFp8s.test_gemm_fp8e4m3
DEBUG=2 EMULATE_CUDA_SM89=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOpsFp8s.test_gemm_fp8e5m2
PYTHONPATH="." DEBUG=2 EMULATE_CUDA=1 ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
- name: Test emulated INTEL OpenCL tensor cores
run: DEBUG=2 EMULATE_INTEL=1 FORWARD_ONLY=1 PYTHON=1 HALF=1 N=64 python3 ./extra/gemm/simple_matmul.py
@@ -3,9 +3,11 @@ from tinygrad.helpers import getenv
from tinygrad.dtype import _to_np_dtype
from tinygrad import dtypes, Tensor

dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
if getenv("INT"): dtype_in, acc_dtype = dtypes.int8, dtypes.int32
dtype_in = (dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else
dtypes.fp8e4m3 if getenv("FP8E4M3") else dtypes.fp8e5m2 if getenv("FP8E5M2") else dtypes.float)
acc_dtype = (dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else
dtypes.fp8e4m3 if getenv("ACC_FP8E4M3") else dtypes.fp8e5m2 if getenv("ACC_FP8E5M2") else None)
if getenv("INT"): dtype_in, acc_dtype = dtypes.int8, dtypes.int32
if getenv("UINT"): dtype_in, acc_dtype = dtypes.uint8, dtypes.int32

N = getenv("N", 4096)
setup.py: 1 line changed
@@ -70,6 +70,7 @@ setup(name='tinygrad',
"bottle",
"ggml-python",
"capstone",
"ml_dtypes",
"pycocotools",
"boto3",
"pandas"
@@ -39,7 +39,8 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float
if DEBUG >= 2: print(tensor.numpy())
try:
assert tensor.dtype == target_dtype
np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, tol_target_dtype))
np.testing.assert_allclose(tensor.numpy(), target,
rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2, dtypes.fp8e5m2: 1, dtypes.fp8e4m3: 1e-1}.get(target_dtype, tol_target_dtype))
except AssertionError as e:
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e

@@ -56,6 +57,7 @@ def _test_cast(a:Tensor, target_dtype:DType):
_test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype))))
def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
if target_dtype in dtypes.fp8s: raise unittest.SkipTest("no test for fp8s bitcast yet")
if getenv("PTX") and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize:
raise unittest.SkipTest("shape changing bitcast of int8 broken on PTX")
_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(_to_np_dtype(target_dtype)).tolist())
@@ -109,7 +111,7 @@ class TestDType(unittest.TestCase):
fields = dtypes.fields()
self.assertIn("float", fields)
self.assertIn("float32", fields)
self.assertEqual(len(fields), 24)
self.assertEqual(len(fields), 26)
self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
self.assertTrue(all(issubclass(_to_np_dtype(value), np.generic) for value in fields.values() if _to_np_dtype(value) is not None))

@@ -205,6 +207,31 @@ class TestBFloat16DTypeCast(unittest.TestCase):
converted = random_values.cast(dtypes.bfloat16).cast(dtypes.float32)
np.testing.assert_allclose(converted.numpy(), random_values.cast(dtypes.float32).numpy(), rtol=1e-2, atol=1e-3)

class TestFp8sDType(unittest.TestCase):
def _float_to_fp8_conversion_test(self, dtype, input_values, expected_values):
test_tensor = Tensor(input_values).cast(dtype).realize()
back_to_float32 = test_tensor.cast(dtypes.float32)
np.testing.assert_equal(tuple(back_to_float32.numpy().tolist()), expected_values)

@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), "fp8e4m3 not supported")
def test_float_to_fp8e4m3_conversion(self):
self._float_to_fp8_conversion_test(dtypes.fp8e4m3,
[10000000.0, -1.0, 402.0, -300.0, -10000000.0, 20.0, 1.4123, 0.0, math.inf, math.nan],
[448.0, -1.0, 416.0, -288.0, -448.0, 20.0, 1.375, 0.0, 448.0, math.nan])

@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), "fp8e5m2 not supported")
def test_float_to_fp8e5m2_conversion(self):
self._float_to_fp8_conversion_test(dtypes.fp8e5m2,
[10000000.0, -1.0, 402.0, -300.0, -10000000.0, 20.0, 1.4123, 0.0, math.inf, math.nan],
[57344.0, -1, 384, -320, -57344.0, 20, 1.5, 0.0, 57344.0, math.nan])
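The saturation and rounding values asserted above follow directly from the two formats, which dtypes.finfo later in this diff reports as (exponent, mantissa) = (4, 3) for fp8e4m3 and (5, 2) for fp8e5m2. A quick arithmetic check (plain Python, not tinygrad code):

    # e4m3: 4 exponent bits (bias 7), 3 mantissa bits, no infinity; largest finite value is 1.75 * 2**8
    print(1.75 * 2**8)     # 448.0, the fp8e4m3 saturation value seen above
    # e5m2: 5 exponent bits (bias 15), 2 mantissa bits; largest finite value is 1.75 * 2**15
    print(1.75 * 2**15)    # 57344.0, the fp8e5m2 saturation value seen above
    # with 3 mantissa bits 1.4123 rounds to 1.375; with only 2 mantissa bits it rounds to 1.5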
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3) and is_dtype_supported(dtypes.fp8e5m2), "fp8s not supported")
def test_fp8e4m3_plus_fp8e5m2_output_dtype(self):
a = Tensor([1.0, 2.0, 3.0], dtype=dtypes.fp8e4m3)
b = Tensor([1.0, 2.0, 3.0], dtype=dtypes.fp8e5m2)
result = a + b
self.assertEqual(result.dtype, dtypes.half)

class TestHalfDType(TestDType): DTYPE = dtypes.half

class TestFloatDType(TestDType):
@@ -264,6 +291,7 @@ class TestBitCast(unittest.TestCase):
def test_shape_change_bitcast(self, dt1, dt2):
# NOTE: this has to be assume to prevent hypothesis from skipping all samples
assume(dt2 != dtypes.bfloat16 and dt1 != dtypes.bfloat16) # no test for bf16 bitcast yet
assume(dt1 not in dtypes.fp8s and dt2 not in dtypes.fp8s) # no test for fp8 bitcast yet
assume(not (getenv("PTX") and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX
data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
_test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, data.view(_to_np_dtype(dt2)).tolist())
@@ -399,6 +427,10 @@ class TestHelpers(unittest.TestCase):
def test_bf16_is_float(self):
assert dtypes.is_float(dtypes.bfloat16)

def test_fp8s_are_float(self):
assert dtypes.is_float(dtypes.fp8e4m3)
assert dtypes.is_float(dtypes.fp8e5m2)

@given(strat.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), strat.integers(min_value=2, max_value=8))
def test_scalar(self, dtype, amt):
assert dtype.vec(amt).scalar() == dtype
@@ -462,7 +494,7 @@ class TestTypeSpec(unittest.TestCase):
dtypes.default_int = default_int
assert dtypes.default_int == default_int

for default_float in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
for default_float in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
dtypes.default_float = default_float
assert dtypes.default_float == default_float

@@ -693,7 +725,9 @@ class TestAutoCastType(unittest.TestCase):
assert (Tensor([0, 1], dtype=dtypes.uint32)).sum().dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint64)).sum().dtype == dtypes.uint64
assert (Tensor([0, 1], dtype=dtypes.float16)).sum().dtype == dtypes.float16
#assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).sum().dtype == dtypes.fp8e4m3
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).sum().dtype == dtypes.fp8e5m2
assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64

@@ -725,7 +759,9 @@
assert (Tensor([0, 1], dtype=dtypes.uint32)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.uint64)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float16)).mean().dtype == dtypes.float16
#assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).mean().dtype == dtypes.fp8e4m3
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).mean().dtype == dtypes.fp8e5m2
assert (Tensor([0, 1], dtype=dtypes.float32)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).mean().dtype == dtypes.float64

@@ -740,7 +776,9 @@
assert (Tensor([0, 1], dtype=dtypes.uint32)).cumsum(0).dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint64)).cumsum(0).dtype == dtypes.uint64
assert (Tensor([0, 1], dtype=dtypes.float16)).cumsum(0).dtype == dtypes.float16
#assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).cumsum().dtype == dtypes.fp8e4m3
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).cumsum().dtype == dtypes.fp8e5m2
assert (Tensor([0, 1], dtype=dtypes.float32)).cumsum(0).dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).cumsum(0).dtype == dtypes.float64
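These assertions give fp8 the same reduction contract as half and bfloat16: accumulation happens in a wider dtype (see the Tensor.sum hunk near the end of this diff) and the result is cast back unless an explicit dtype is requested. A minimal sketch of that contract, assuming a device or the SM89 emulator with fp8 support:

    from tinygrad import Tensor, dtypes

    t = Tensor([1.0, 2.0, 3.0], dtype=dtypes.fp8e4m3)
    assert t.sum().dtype == dtypes.fp8e4m3                      # cast back to the input dtype
    assert t.mean().dtype == dtypes.fp8e4m3
    assert t.sum(dtype=dtypes.float32).dtype == dtypes.float32  # explicit accumulation dtype is kept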
@@ -801,10 +839,10 @@
def test_gradient_dtype(self):
old_default_float = dtypes.default_float

for default_dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
for default_dtype in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
if not is_dtype_supported(default_dtype): continue
dtypes.default_float = default_dtype
for dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
for dtype in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
if not is_dtype_supported(dtype): continue
if DEBUG >= 2:
print(f"testing {default_dtype=}, {dtype=}")
@@ -3,7 +3,7 @@ import unittest
from tinygrad import Tensor, dtypes, Device
import operator
import numpy as np
from hypothesis import given, strategies as strat, settings, HealthCheck
from hypothesis import given, strategies as strat, settings, HealthCheck, assume
from tinygrad.dtype import DType
from tinygrad.helpers import CI, getenv
from tinygrad.engine.realize import run_schedule
@@ -48,6 +48,8 @@ class ht:
float32 = strat.floats(width=32, allow_subnormal=False)
float16 = strat.floats(width=16, allow_subnormal=False)
bfloat16 = strat.floats(width=16, allow_subnormal=False)
fp8e4m3 = strat.sampled_from([0.0, 0.1, 0.5, 1.0, -0.1, -0.5, -1.0])
fp8e5m2 = strat.sampled_from([0.0, 0.1, 0.5, 1.0, -0.1, -0.5, -1.0])
uint8 = strat.integers(0, 255)
uint16 = strat.integers(0, 65535)
uint32 = strat.integers(0, 2**32-1)
@@ -64,7 +66,8 @@ def universal_test(a, b, dtype, op):
if not isinstance(op, tuple): op = (op, op)
tensor_value = (op[0](Tensor([a], dtype=dtype), Tensor([b], dtype=dtype))).numpy()
numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype)), np.array([b]).astype(_to_np_dtype(dtype)))
if dtype is dtypes.bfloat16: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
if dtype in dtypes.fp8s: np.testing.assert_allclose(tensor_value, numpy_value, atol=0.5, rtol=1e-2)
elif dtype is dtypes.bfloat16: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
elif dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-10)
else: np.testing.assert_equal(tensor_value, numpy_value)

@@ -76,7 +79,8 @@ def universal_test_unary(a, dtype, op):
run_schedule(sched)
tensor_value = out.numpy()
numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype)))
if dtype in (*dtypes_float, dtypes.bfloat16):
if dtype in dtypes.fp8s: np.testing.assert_allclose(tensor_value, numpy_value, atol=2, rtol=1e-2)
elif dtype in (*dtypes_float, dtypes.bfloat16):
np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
else: np.testing.assert_equal(tensor_value, numpy_value)
if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends
@@ -114,6 +118,26 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.bfloat16, ht.bfloat16, strat.sampled_from(binary_operations))
def test_bfloat16(self, a, b, op): universal_test(a, b, dtypes.bfloat16, op)

@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3, Device.DEFAULT), f"no fp8e4m3 on {Device.DEFAULT}")
@given(ht.fp8e4m3, ht.fp8e4m3, strat.sampled_from(binary_operations))
def test_fp8e4m3(self, a, b, op): universal_test(a, b, dtypes.fp8e4m3, op)

@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2, Device.DEFAULT), f"no fp8e5m2 on {Device.DEFAULT}")
@given(ht.fp8e5m2, ht.fp8e5m2, strat.sampled_from(binary_operations))
def test_fp8e5m2(self, a, b, op): universal_test(a, b, dtypes.fp8e5m2, op)

@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3, Device.DEFAULT), f"no fp8e4m3 on {Device.DEFAULT}")
@given(ht.fp8e4m3, strat.sampled_from(unary_operations))
def test_fp8e4m3_unary(self, a, op):
if (op[1] == np.reciprocal or op[1] == np.log): assume(a != 0.0) # reciprocal(0) and log(0) are undefined
universal_test_unary(a, dtypes.fp8e4m3, op)

@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2, Device.DEFAULT), f"no fp8e5m2 on {Device.DEFAULT}")
@given(ht.fp8e5m2, strat.sampled_from(unary_operations))
def test_fp8e5m2_unary(self, a, op):
if (op[1] == np.reciprocal or op[1] == np.log): assume(a != 0.0) # reciprocal(0) and log(0) are undefined
universal_test_unary(a, dtypes.fp8e5m2, op)

@given(ht.float32, strat.sampled_from(unary_operations))
def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)
@@ -41,6 +41,7 @@ def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axi
np_c = np_a @ np_b
if dtype_in == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3
elif dtype_in == dtypes.bfloat16: tc_atol, tc_rtol = 1e-2, 1e-2
elif dtype_in in dtypes.fp8s: tc_atol, tc_rtol = 1e-1, 1e-2
else: tc_atol, tc_rtol = 5e-3, 1e-4
np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)

@@ -62,6 +63,13 @@ def helper_tc_ensure_uops_and_opts_count(N: int, M:int, K:int, dtype_in:DType, d
assert wmmas == 0, "tensor core is incorrectly triggered"
assert tcs == 0, "tensor core opt is incorrectly included"

def is_emulated(tc) -> bool:
return (getenv("EMULATE_CUDA_SM89") or getenv("EMULATE_CUDA") or getenv("EMULATE_INTEL") or getenv("EMULATE_METAL")
or getenv("EMULATE_AMD_MFMA") or getenv("EMULATE_AMD")) and \
((tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16) or \
(tc.dtype_in == dtypes.fp8e4m3 or tc.dtype_out == dtypes.fp8e4m3) or \
(tc.dtype_in == dtypes.fp8e5m2 or tc.dtype_out == dtypes.fp8e5m2))

class TestLinearizer(unittest.TestCase):
def test_arg_dedup(self):
# NOTE: this realize exists because Tensor.numpy calls .contiguous() internally
@@ -1024,7 +1032,8 @@ class TestLinearizer(unittest.TestCase):

def test_sum_acc_dtype(self):
for tensor_dtype, acc_dtype in (
(dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
(dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float),
(dtypes.fp8e4m3, dtypes.float), (dtypes.fp8e5m2, dtypes.float)):
if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
k = Kernel(a.schedule()[-1].ast)
@@ -1046,6 +1055,10 @@ class TestLinearizer(unittest.TestCase):
(dtypes.float16, dtypes.float16, dtypes.float16),
(dtypes.bfloat16, dtypes.bfloat16, dtypes.bfloat16),
(dtypes.float, dtypes.float16, dtypes.float16),
(dtypes.fp8e5m2, dtypes.fp8e5m2, dtypes.fp8e5m2),
(dtypes.fp8e4m3, dtypes.fp8e4m3, dtypes.fp8e4m3),
(dtypes.fp8e4m3, None, dtypes.float),
(dtypes.fp8e5m2, None, dtypes.float),
)
for tensor_dtype, acc_dtype, expected_dtype in tests:
if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype) and is_dtype_supported(expected_dtype):
@@ -1060,8 +1073,7 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if (getenv("EMULATE_CUDA") or getenv("EMULATE_INTEL") or getenv("EMULATE_METAL") or getenv("EMULATE_AMD_MFMA") or getenv("EMULATE_AMD")) and \
(tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
if is_emulated(tc): continue
if CI and Device.DEFAULT in ("METAL", "AMD") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
# for AMX, tc.dims[2] == 1 so reduceop is None thus tensor_cores are not triggered
helper_tc_allclose(tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
@@ -1089,8 +1101,7 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_padded(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if (getenv("EMULATE_CUDA") or getenv("EMULATE_METAL") or getenv("EMULATE_AMD_MFMA") or getenv("EMULATE_AMD")) and \
(tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
if is_emulated(tc): continue
if CI and Device.DEFAULT in ("METAL", "AMD") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
pad = 1

@@ -1117,7 +1128,8 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_multi_reduce(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16: continue
if tc.dtype_in in [dtypes.bfloat16, *dtypes.fp8s] or tc.dtype_out in [dtypes.bfloat16, *dtypes.fp8s]:
continue
# this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes
golden_result = None
for axis in range(9):
@@ -3003,6 +3003,17 @@ class TestOpsUint8(unittest.TestCase):
lambda x: x.type(torch.uint8).min(),
lambda x: x.cast(dtypes.uint8).min(), forward_only=True, vals=[[0, 128, 255, 64, 32, 16]])

@unittest.skipUnless("CUDA" in Device.get_available_devices() and Device.DEFAULT == "PYTHON" and getenv("EMULATE_CUDA_SM89"),
"only for emulated CUDA")
class TestOpsFp8s(unittest.TestCase):
def _compare_to_cuda(self, shp_a, shp_b, op, dtype):
a = Tensor.rand(shp_a, dtype=dtype)
b = Tensor.rand(shp_b, dtype=dtype)
np.testing.assert_equal(op(a, b).numpy(), op(a.to("CUDA"), b.to("CUDA")).numpy())

def test_gemm_fp8e4m3(self): self._compare_to_cuda((64, 64), (64, 64), lambda x, y: x.matmul(y), dtypes.fp8e4m3)
def test_gemm_fp8e5m2(self): self._compare_to_cuda((64, 64), (64, 64), lambda x, y: x.matmul(y), dtypes.fp8e5m2)

if __name__ == '__main__':
np.random.seed(1337)
unittest.main(verbosity=2)
@@ -193,6 +193,26 @@ class TestRandomness(unittest.TestCase):
assert nx[nx == 0].size > 0
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.bfloat16).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))

@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), "need fp8e4m3 support")
def test_rand_fp8e4m3(self):
N = 128
x = Tensor.rand((2, N, N), dtype=dtypes.fp8e4m3)
assert x.dtype == dtypes.fp8e4m3
nx = x.numpy()
assert nx[nx == 1].size == 0
assert nx[nx == 0].size > 0
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.fp8e4m3).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))

@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), "need fp8e5m2 support")
def test_rand_fp8e5m2(self):
N = 128
x = Tensor.rand((2, N, N), dtype=dtypes.fp8e5m2)
assert x.dtype == dtypes.fp8e5m2
nx = x.numpy()
assert nx[nx == 1].size == 0
assert nx[nx == 0].size > 0
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.fp8e5m2).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))

def test_rand_like(self):
empty = Tensor.empty((80, 44))
rand = Tensor.rand_like(empty)
@@ -332,6 +332,8 @@ def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
if dtype == dtypes.bfloat16:
# NOTE: this requires bf16 buffer support
return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
if dtype in dtypes.fp8s:
return (device in {"CUDA", "NV"} and not CI and not getenv("PTX")) or (device in {"PYTHON"} and getenv("EMULATE_CUDA_SM89"))
if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
# for CI GPU and OSX, cl_khr_fp16 isn't supported
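A short sketch of how calling code can key off this check before picking an fp8 dtype (illustrative only; it imports is_dtype_supported from tinygrad.device, the same import the dtype.py hunk below uses):

    from tinygrad import dtypes
    from tinygrad.device import is_dtype_supported

    # fall back to half precision when fp8 buffers are not available on the default device
    matmul_dtype = dtypes.fp8e4m3 if is_dtype_supported(dtypes.fp8e4m3) else dtypes.half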
@@ -112,7 +112,8 @@ class dtypes:
def finfo(dtype:DType) -> tuple[int, int]:
"""(exponent, mantissa)"""
if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")
return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52)}[dtype]
return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52),
dtypes.fp8e5m2: (5, 2), dtypes.fp8e4m3: (4, 3)}[dtype]
@staticmethod
def fields() -> dict[str, DType]: return DTYPES_DICT
void: Final[DType] = DType.new(-1, 0, "void", None)
@@ -125,11 +126,13 @@ class dtypes:
uint32: Final[DType] = DType.new(6, 4, "unsigned int", 'I')
int64: Final[DType] = DType.new(7, 8, "long", 'q')
uint64: Final[DType] = DType.new(8, 8, "unsigned long", 'Q')
float16: Final[DType] = DType.new(9, 2, "half", 'e')
fp8e4m3: Final[DType] = DType.new(9, 1, "float8_e4m3", None)
fp8e5m2: Final[DType] = DType.new(10, 1, "float8_e5m2", None)
float16: Final[DType] = DType.new(11, 2, "half", 'e')
# bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
bfloat16: Final[DType] = DType.new(10, 2, "__bf16", None)
float32: Final[DType] = DType.new(11, 4, "float", 'f')
float64: Final[DType] = DType.new(12, 8, "double", 'd')
bfloat16: Final[DType] = DType.new(12, 2, "__bf16", None)
float32: Final[DType] = DType.new(13, 4, "float", 'f')
float64: Final[DType] = DType.new(14, 8, "double", 'd')

# dtype aliases
half = float16; float = float32; double = float64 # noqa: E702
@@ -145,7 +148,8 @@ class dtypes:
default_float: ClassVar[DType] = float32
default_int: ClassVar[DType] = int32

floats = (float16, bfloat16, float32, float64)
floats = (fp8e4m3, fp8e5m2, float16, bfloat16, float32, float64)
fp8s = (fp8e4m3, fp8e5m2)
uints = (uint8, uint16, uint32, uint64)
sints = (int8, int16, int32, int64)
ints = uints + sints
@@ -161,8 +165,9 @@ def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType)
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
# we don't support weak type and complex type
promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
dtypes.int64: [dtypes.float16, dtypes.bfloat16], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
dtypes.int64: [dtypes.fp8e5m2, dtypes.fp8e4m3], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.fp8e5m2, dtypes.fp8e4m3],
dtypes.fp8e5m2: [dtypes.float16, dtypes.bfloat16], dtypes.fp8e4m3: [dtypes.float16, dtypes.bfloat16],
dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }

@functools.cache
@@ -170,6 +175,9 @@ def _get_recursive_parents(dtype:DType) -> set[DType]:
return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
@functools.cache
def least_upper_dtype(*ds:DType) -> DType:
from tinygrad.device import is_dtype_supported
if not is_dtype_supported(dtypes.fp8e4m3) and not is_dtype_supported(dtypes.fp8e5m2):
promo_lattice[dtypes.int64] = promo_lattice[dtypes.uint64] = [dtypes.float16, dtypes.bfloat16]
return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float)
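With these lattice entries both fp8 types promote through float16 and bfloat16, so mixing e4m3 with e5m2 resolves to half, exactly what TestFp8sDType.test_fp8e4m3_plus_fp8e5m2_output_dtype asserts earlier in this diff. A small sketch, assuming fp8 support so the fallback branch above does not rewrite the lattice:

    from tinygrad import Tensor, dtypes
    from tinygrad.dtype import least_upper_dtype

    assert least_upper_dtype(dtypes.fp8e4m3, dtypes.fp8e5m2) == dtypes.half
    assert (Tensor([1.0], dtype=dtypes.fp8e4m3) + Tensor([1.0], dtype=dtypes.fp8e5m2)).dtype == dtypes.half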
@@ -37,7 +37,8 @@ base_rewrite = PatternMatcher([
(UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
(UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
# consts are rendered to larger type and casted
(UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
(UPat(Ops.CONST, (*dtypes.fp8s, dtypes.bfloat16, dtypes.half), name="x"),
lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
(UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}u')})"),
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, x.arg)})"),
# default const render
@@ -331,14 +332,20 @@ class CUDARenderer(CStyleLanguage):
(dtypes.half,dtypes.half)]]
tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
tc_81632_f8 = [TensorCore(dims=(8,16,32), threads=32, elements_per_thread=(16,8,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
swizzle=(((7,8,2,3,4),(0,1,10,5,6,11,9)), ((7,8,10,0,1),(2,3,4,11,5,6,9))))
for di,do in [(dtypes.fp8e4m3,dtypes.float),(dtypes.fp8e5m2,dtypes.float)]]
tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))]

tc_sm89 = tc_81616 + tc_8168_f16 + tc_81632_f8
tc_sm80 = tc_81616 + tc_8168_f16
if getenv("ALLOW_TF32", 0): tc_sm80 += tc_8168_tf32
tc_sm75 = tc_8168_f16
def __init__(self, arch:str):
self.tensor_cores, self.arch = CUDARenderer.tc_sm80 if int(arch[3:]) >= 80 else CUDARenderer.tc_sm75 if int(arch[3:]) >= 75 else [], arch
def __init__(self, arch: str):
self.arch = arch
tensor_cores_map = {89: CUDARenderer.tc_sm89, 80: CUDARenderer.tc_sm80, 75: CUDARenderer.tc_sm75}
self.tensor_cores = next((tc for version, tc in sorted(tensor_cores_map.items(), reverse=True) if int(arch[3:]) >= version), [])
def __reduce__(self): return self.__class__, (self.arch,)

# language options
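The rewritten constructor picks the newest tensor-core set whose SM version the device satisfies. A standalone sketch of that selection (function name is illustrative; it assumes arch strings of the form "sm_89", which is what int(arch[3:]) implies):

    def pick_tensor_cores(arch: str, tensor_cores_map: dict[int, list]) -> list:
        # highest SM version that the compute capability satisfies, otherwise no tensor cores
        sm = int(arch[3:])  # "sm_89" -> 89
        return next((tc for version, tc in sorted(tensor_cores_map.items(), reverse=True) if sm >= version), [])

    # e.g. "sm_89" selects tc_sm89 (the fp8 set), "sm_80" or "sm_86" select tc_sm80, "sm_70" gets []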
@@ -355,7 +362,19 @@ class CUDARenderer(CStyleLanguage):
Ops.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",
Ops.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
Ops.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" }
type_map = {dtypes.bfloat16: "nv_bfloat16"}
type_map = {dtypes.bfloat16: "nv_bfloat16", dtypes.fp8e4m3: "__nv_fp8_e4m3", dtypes.fp8e5m2: "__nv_fp8_e5m2"}

@staticmethod
def __create_fp8_patterns(dtype) -> list:
return [(UPat(Ops.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtype), UPat.var("y", dtype=dtype))),
lambda b, x, y, dtype=dtype: UOp(Ops.WHERE, dtype=dtypes.float, src=(b, x.cast(dtypes.float), y.cast(dtypes.float))).cast(dtype)),
(UPat(GroupOp.ALU, dtype=dtype, name="x"),
lambda x, dtype=dtype: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtype)),
(UPat(GroupOp.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtype), UPat.var("y", dtype=dtype))),
lambda alu, x, y, dtype=dtype: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
(UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtype, name="x"),
lambda x, dtype=dtype: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtype)),]
extra_matcher = PatternMatcher(__create_fp8_patterns(dtypes.fp8e4m3) + __create_fp8_patterns(dtypes.fp8e5m2)) + extra_pm
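These patterns keep fp8 out of the ALU entirely: WHERE, binary and unary ALU ops, and the listed transcendentals are lifted to float32 and the result is cast back to the fp8 type. The Python emulator below performs the numerically equivalent round trip with ml_dtypes; a small standalone illustration of that behavior (not tinygrad code):

    import numpy as np
    from ml_dtypes import float8_e5m2

    a = np.array([1.0, 2.0], dtype=float8_e5m2)
    b = np.array([0.5, 0.5], dtype=float8_e5m2)
    # compute in float32, then quantize the result back to fp8
    out = (a.astype(np.float32) + b.astype(np.float32)).astype(float8_e5m2)
    print(out)  # [1.5 2.5], both exactly representable in e5m2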
def render_vector_prefix(self, dt:DType) -> str:
vec, scal = self.render_dtype(dt), self.render_dtype(dt.scalar()),
@@ -367,11 +386,13 @@ class CUDARenderer(CStyleLanguage):
prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]

used_dtypes = uops_to_dtypes(uops)
if any(dt.scalar() in dtypes.fp8s for dt in used_dtypes): prefix.append("#include <cuda_fp8.h>")
if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#include <cuda_fp16.h>")
if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]
prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if (dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16})
or (dt.count in (8,16) and dt.scalar() in dtypes.fp8s)]

dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16", dtypes.fp8e4m3: "e4m3", dtypes.fp8e5m2: "e5m2"}
dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
|
||||
from tinygrad.ops import exec_alu, Ops, UOp, GroupOp
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer
|
||||
if getenv("EMULATE_CUDA_SM89"):
|
||||
import numpy as np
|
||||
from ml_dtypes import float8_e4m3, float8_e5m2
|
||||
truncate.update({dtypes.fp8e4m3: float8_e4m3, dtypes.fp8e5m2: float8_e5m2, float8_e5m2.dtype: np.float32, float8_e4m3.dtype: np.float32})
|
||||
|
||||
def _load(m, i):
|
||||
if i is None: return 0.0
|
||||
@@ -67,10 +71,13 @@ class PythonProgram:
|
||||
assert dtype is not None, f"{uop} is missing a dtype"
|
||||
dl[i] = dtype
|
||||
if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL}:
|
||||
assert dtype.fmt is not None and isinstance(dtype, PtrDType)
|
||||
assert isinstance(dtype, PtrDType)
|
||||
if TYPE_CHECKING or sys.version_info < (3, 12): assert dtype.fmt != "e"
|
||||
buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is Ops.DEFINE_LOCAL else pbufs.pop(0)
|
||||
ul[i] = [buf.cast(dtype.fmt)] * warp_size
|
||||
if dtype.base in dtypes.fp8s: ul[i] = [np.frombuffer(buf, dtype=dtype.name)] * warp_size
|
||||
else:
|
||||
assert dtype.fmt is not None
|
||||
ul[i] = [buf.cast(dtype.fmt)] * warp_size
|
||||
elif uop is Ops.DEFINE_VAR:
|
||||
ul[i] = [pvals.pop(0)] * warp_size
|
||||
elif uop is Ops.SPECIAL:
|
||||
@@ -100,11 +107,13 @@ class PythonProgram:
|
||||
i = loop_ends[i] + 1
|
||||
continue
|
||||
elif uop is Ops.VECTORIZE: ul[i] = inp
|
||||
elif uop in {Ops.CAST, Ops.BITCAST}:
|
||||
assert dtp[0].fmt and dtype.fmt
|
||||
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
|
||||
if uop is Ops.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
|
||||
else: ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]]
|
||||
elif uop is Ops.BITCAST:
|
||||
assert (dtp[0].fmt and dtype.fmt) or (dtp[0] in dtypes.fp8s and dtype) or (dtype in dtypes.fp8s and dtp[0])
|
||||
if dtp[0] in dtypes.fp8s: packed = b''.join([truncate.get(dtp[0], lambda dt: dt)(z).tobytes() for z in inp[0]])
|
||||
else: packed = struct.pack(str(warp_size) + str(dtp[0].fmt), *inp[0])
|
||||
if dtype in dtypes.fp8s: ul[i] = np.frombuffer(packed,dtype=truncate.get(dtype, lambda x: x)).tolist()
|
||||
else: ul[i] = list(struct.unpack(str(warp_size) + str(dtype.fmt), packed))
|
||||
elif uop is Ops.CAST: ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]]
|
||||
elif uop is Ops.LOAD:
|
||||
if dtype.count > 1:
|
||||
ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
|
||||
@@ -126,7 +135,8 @@ class PythonProgram:
|
||||
for lane_id in range(WARP_THREADS):
|
||||
for elem_idx in range(NUM_C): # calculate new muls and add to acc
|
||||
(c_i, c_j) = c_map(lane_id, elem_idx)
|
||||
out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
|
||||
def cast_fn(x): return truncate.get(x.dtype, lambda x: x)(x) if dtp[0].scalar() in dtypes.fp8s else x
|
||||
out[elem_idx][goff+lane_id] += sum(cast_fn(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff)) for _k in range(K))
|
||||
return out
|
||||
|
||||
# TODO: refactor these to a shared TensorCoreLayout in kernel.py
|
||||
@@ -169,6 +179,11 @@ class PythonProgram:
|
||||
def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4]
|
||||
ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
|
||||
|
||||
elif arg[1] == (8,16,32):
|
||||
def a_elem(x, k, row, goff): return x[k%4 + (k//16)*8 + (row//8)*4][goff + (k//4)%4 + (row%8)*4]
|
||||
def b_elem(x, col, k, goff): return x[k%4 + (k//16)*4][goff + (k//4)%4 + col*4]
|
||||
ul[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map)
|
||||
|
||||
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
|
||||
elif arg[4] == "INTEL":
|
||||
# A (16 elements on 8 threads)
|
||||
@@ -199,6 +214,7 @@ class PythonRenderer(Renderer):
|
||||
if getenv("EMULATE_AMD_MFMA"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores_mfma
|
||||
if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm80
|
||||
if getenv("EMULATE_CUDA_SM75"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm75
|
||||
if getenv("EMULATE_CUDA_SM89"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm89
|
||||
if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", IntelRenderer.tensor_cores
|
||||
if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CPU", ClangRenderer.tensor_cores
|
||||
|
||||
|
||||
@@ -149,7 +149,8 @@ class Tensor(SimpleMathTrait):
|
||||
if dtype is None:
|
||||
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
|
||||
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True
|
||||
if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
|
||||
if dtype in [dtypes.bfloat16, *dtypes.fp8s]:
|
||||
data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtype).lazydata
|
||||
else: data = _frompy(data, dtype)
|
||||
elif str(type(data)) == "<class 'numpy.ndarray'>":
|
||||
import numpy as np
|
||||
@@ -344,7 +345,7 @@ class Tensor(SimpleMathTrait):
|
||||
"""
|
||||
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
|
||||
import numpy as np
|
||||
if self.dtype.base == dtypes.bfloat16: return self.float().numpy()
|
||||
if self.dtype.base in [dtypes.bfloat16, *dtypes.fp8s]: return self.float().numpy()
|
||||
if 0 in self.shape: return np.empty(self.shape, dtype=_to_np_dtype(self.dtype.base))
|
||||
return self._buffer().numpy().reshape(self.shape)
|
||||
|
||||
@@ -1590,7 +1591,7 @@ class Tensor(SimpleMathTrait):
|
||||
```
|
||||
"""
|
||||
ret = self.cast(sum_acc_dtype(self.dtype) if dtype is None else dtype)._reduce(Ops.ADD, axis, keepdim)
|
||||
return ret.cast(self.dtype) if dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
|
||||
return ret.cast(self.dtype) if dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16, *dtypes.fp8s) else ret
|
||||
|
||||
def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, dtype:DTypeLike|None=None) -> Tensor:
|
||||
"""
|
||||
|
||||