diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 8e815624b7..411fe2c7d7 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -204,7 +204,7 @@ jobs:
         DEBUG=2 EMULATE=CUDA FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
         DEBUG=2 EMULATE=CUDA ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
         DEBUG=2 EMULATE=CUDA_SM75 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
-        DEBUG=2 EMULATE=CUDA ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/opt/test_tensor_cores.py
+        DEBUG=2 EMULATE=CUDA_SM89 ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/opt/test_tensor_cores.py
     - name: Test emulated INTEL OpenCL tensor cores
       run: DEBUG=2 EMULATE=INTEL FORWARD_ONLY=1 PYTHON=1 HALF=1 N=64 python3 ./extra/gemm/simple_matmul.py
     - name: Test emulated AMX tensor cores
diff --git a/extra/gemm/simple_matmul.py b/extra/gemm/simple_matmul.py
index 0c91005a16..5a9f2da940 100644
--- a/extra/gemm/simple_matmul.py
+++ b/extra/gemm/simple_matmul.py
@@ -5,8 +5,10 @@ from tinygrad.dtype import _to_np_dtype
 from tinygrad.codegen.opt import OptOps
 from tinygrad.engine.realize import lower_schedule
-dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
-acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
+dtype_in = (dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else
+            dtypes.fp8e4m3 if getenv("FP8E4M3") else dtypes.fp8e5m2 if getenv("FP8E5M2") else dtypes.float)
+acc_dtype = (dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else
+             dtypes.fp8e4m3 if getenv("ACC_FP8E4M3") else dtypes.fp8e5m2 if getenv("ACC_FP8E5M2") else None)
 if getenv("INT"): dtype_in, acc_dtype = dtypes.int8, dtypes.int32
 if getenv("UINT"): dtype_in, acc_dtype = dtypes.uint8, dtypes.int32
@@ -14,8 +16,10 @@ N = getenv("N", 4096)
 M = getenv("M", N)
 K = getenv("K", N)
 CNT = getenv("CNT", 10)
-ATOL = getenv("ATOL", 1e-4)
-RTOL = getenv("RTOL", 3e-2)
+
+atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype_in, (1e-4, 3e-2))
+ATOL, RTOL = getenv("ATOL", atol), getenv("RTOL", rtol)
+
 INT_LOW = getenv("INT_LOW", 0)
 INT_HIGH = getenv("INT_HIGH", 10)
diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py
index 446f3899d2..3f51c28c3c 100644
--- a/test/test_dtype_alu.py
+++ b/test/test_dtype_alu.py
@@ -75,7 +75,10 @@ def universal_test_unary(a, dtype, op):
   out: Tensor = op[0](ta)
   tensor_value = out.numpy()
   numpy_value = op[1](ta.numpy())
-  if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value)
+  if dtype in dtypes.fp8s:
+    # cuda cast f32 inf to f8 MAX, amd cast it to nan(E4M3)/inf(E5M2)
+    if math.isinf(numpy_value): return
+    numpy_value = truncate[dtype](numpy_value)
   if dtype in dtypes.floats:
     atol, rtol = { dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 2e-2),
                    dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2: (1.0, 5e-1)}.get(dtype, (1e-6, 1e-5))
diff --git a/tinygrad/codegen/opt/tc.py b/tinygrad/codegen/opt/tc.py
index d32266c47a..b5b4dedd31 100644
--- a/tinygrad/codegen/opt/tc.py
+++ b/tinygrad/codegen/opt/tc.py
@@ -80,6 +80,10 @@ cuda_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4)
   swizzle=((('r1', 'r2', 'l2', 'l3', 'l4'), ('u1', 'r3'), ('l0', 'l1', 'u0', 'r0')),
            (('r1', 'r2', 'u0', 'l0', 'l1'), ('r0', 'r3'), ('l2', 'l3', 'l4', 'u1'))))
   for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float), (dtypes.half,dtypes.half)]]
+cuda_81632_f8 = [TensorCore(dims=(8,16,32), threads=32, elements_per_thread=(16,8,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+  swizzle=((('r2', 'r3', 'l2', 'l3', 'l4'), ('u1', 'r4'), ('l0', 'l1', 'u0', 'r0', 'r1')),
+           (('r2', 'r3', 'u0', 'l0', 'l1'), ('r1', 'r4'), ('l2', 'l3', 'l4', 'u1', 'r0'))))
+  for di,do in [(dtypes.fp8e4m3,dtypes.float),(dtypes.fp8e5m2,dtypes.float)]]
 cuda_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
   swizzle=((('r1', 'r2', 'l2', 'l3', 'l4'), ('r0', 'u1'), ('l0', 'l1', 'u0')), (('r1', 'r2', 'u0', 'l0', 'l1'), ('u1', 'r0'), ('l2', 'l3', 'l4'))))
@@ -87,9 +91,10 @@ cuda_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,
 cuda_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
   swizzle=((('r0', 'r1', 'l2', 'l3', 'l4'), ('u1', 'r2'), ('l0', 'l1', 'u0')), (('r0', 'r1', 'u0', 'l0', 'l1'), ('u1', 'r2'), ('l2', 'l3', 'l4'))))]
+cuda_sm75: list[TensorCore] = cuda_8168_f16
 cuda_sm80: list[TensorCore] = cuda_81616 + cuda_8168_f16
 if getenv("ALLOW_TF32", 0): cuda_sm80 += cuda_8168_tf32
-cuda_sm75: list[TensorCore] = cuda_8168_f16
+cuda_sm89: list[TensorCore] = cuda_sm80 + cuda_81632_f8
 
 # ***** AMD *****
diff --git a/tinygrad/device.py b/tinygrad/device.py
index 2e4b1f2520..f3af6cf044 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -331,7 +331,9 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
     if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX") and not getenv("NV_NAK")
     if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and not getenv("CPU_LVP")
     return device in {"AMD", "PYTHON", "NULL"}
-  if dtype in dtypes.fp8s: return device in {"PYTHON", "NULL"}
+  if dtype in dtypes.fp8s:
+    if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX") and not getenv("NV_NAK")
+    return device in {"PYTHON", "NULL"}
   if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
                                           dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
   # for CI GPU and OSX, cl_khr_fp16 isn't supported
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 5afcd0711a..e5031ea79d 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -37,7 +37,7 @@ base_rewrite = PatternMatcher([
   (UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}u"),
   (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
   # consts are rendered to larger type and casted
-  (UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
+  (UPat(Ops.CONST, (*dtypes.fp8s, dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
   (UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}u')})"),
   (UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"),
   # default const render
@@ -345,7 +345,8 @@ class CUDARenderer(CStyleLanguage):
   shared_max = 49152
   def __init__(self, arch:str):
-    self.tensor_cores, self.arch = tc.cuda_sm80 if int(arch[3:]) >= 80 else tc.cuda_sm75 if int(arch[3:]) >= 75 else [], arch
+    self.arch = arch
+    self.tensor_cores = tc.cuda_sm89 if int(arch[3:]) >= 89 else tc.cuda_sm80 if int(arch[3:]) >= 80 else tc.cuda_sm75 if int(arch[3:]) >= 75 else []
   def __reduce__(self): return self.__class__, (self.arch,)
 
   # language options
@@ -364,8 +365,14 @@ class CUDARenderer(CStyleLanguage):
     Ops.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",
     Ops.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
     Ops.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" }
-  type_map = {dtypes.bfloat16: "nv_bfloat16"}
-
+  type_map = {dtypes.bfloat16: "nv_bfloat16", dtypes.fp8e4m3: "__nv_fp8_e4m3", dtypes.fp8e5m2: "__nv_fp8_e5m2"}
+  extra_matcher = PatternMatcher([
+    (UPat(Ops.CAST, dtypes.fp8s, UPat.var("x", dtypes.fp8s), name='y'), lambda x,y: x.cast(dtypes.float).cast(y.dtype) if x.dtype!=y.dtype else None),
+    (UPat(GroupOp.ALU, dtype=dtypes.fp8s, name="x"),
+     lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(x.dtype)),
+    (UPat(GroupOp.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.fp8s), UPat.var("y", dtype=dtypes.fp8s))),
+     lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
+  ]) + extra_pm
   def render_vector_prefix(self, dt:DType) -> str:
     vec, scal = self.render_dtype(dt), self.render_dtype(dt.scalar()),
     elems, header = ', '.join(_nms[:dt.count]), ', '.join([f"{scal} {x}" for x in _nms[:dt.count]])
@@ -376,11 +383,12 @@ class CUDARenderer(CStyleLanguage):
     prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
     used_dtypes = uops_to_dtypes(uops)
+    if any(dt.scalar() in dtypes.fp8s for dt in used_dtypes): prefix.append("#include <cuda_fp8.h>")
     if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#include <cuda_fp16.h>")
     if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
-    prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]
-
-    dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
+    prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if (dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16})
+               or (dt.count in (8,16) and dt.scalar() in dtypes.fp8s)]
+    dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16", dtypes.fp8e4m3: "e4m3", dtypes.fp8e5m2: "e5m2" }
     dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
     for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in wmma_args(uops):
       upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index 9a8ade8e18..780762539d 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -177,6 +177,11 @@ class PythonProgram:
           def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4]
           ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
+        elif dims == (8,16,32):
+          def a_elem(x, k, row, goff): return x[k%4 + (row//8)*4 + (k//16)*8][goff + (k//4)%4 + (row%8)*4]
+          def b_elem(x, col, k, goff): return x[k%4 + (k//16)*4][goff + (k//4)%4 + col*4]
+          ul[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map)
+
         elif dims == (8,16,8) and dtype_in == dtypes.half:
           def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4]
           def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4]
@@ -220,6 +225,7 @@ class PythonRenderer(Renderer):
       case "AMD_RDNA4": self.device, self.tensor_cores = "AMD", tc.amd_rdna4
       case "CUDA": self.device, self.tensor_cores = "CUDA", tc.cuda_sm80
       case "CUDA_SM75": self.device, self.tensor_cores = "CUDA", tc.cuda_sm75
+      case "CUDA_SM89": self.device, self.tensor_cores = "CUDA", tc.cuda_sm89
       case "INTEL": self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", tc.intel
       case "AMX": self.device, self.tensor_cores = "CPU", tc.amx
       case "": pass
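
The emulated SM89 FP8 path can be exercised locally the same way the updated CI step drives it. The first command below is the CI line from the workflow change; the second is an illustrative combination of the new FP8E4M3 switch in simple_matmul.py with the emulated PYTHON backend (small N), not a line taken from CI:

  DEBUG=2 EMULATE=CUDA_SM89 ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/opt/test_tensor_cores.py
  DEBUG=2 EMULATE=CUDA_SM89 FORWARD_ONLY=1 PYTHON=1 FP8E4M3=1 N=64 python3 ./extra/gemm/simple_matmul.py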