From 8dfcdb123d39711e704aa633ff538fffa6205393 Mon Sep 17 00:00:00 2001 From: b1tg <33436708+b1tg@users.noreply.github.com> Date: Sun, 27 Jul 2025 12:24:05 +0800 Subject: [PATCH] less wmma args (#11385) * less wmma args * scalar * ops_python * mypy * lint * dedup * helper wmma_args --------- Co-authored-by: b1tg Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> --- tinygrad/renderer/cstyle.py | 36 +++++++++++++++++++--------------- tinygrad/renderer/ptx.py | 6 +++--- tinygrad/runtime/ops_python.py | 23 ++++++++++++---------- 3 files changed, 36 insertions(+), 29 deletions(-) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 82e30f726d..0b14577552 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -75,6 +75,10 @@ extra_pm = PatternMatcher([ def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType))) +# (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes) +def wmma_args(uops:list[UOp]): + return dedup((uop.arg[0], uop.arg[1], uop.src[0].dtype.scalar(), uop.dtype.scalar(), *(uop.arg[4:8])) for uop in uops if uop.op is Ops.WMMA) + class CStyleLanguage(Renderer): kernel_typedef: str = "void" buffer_prefix: str = "" @@ -210,7 +214,7 @@ class ClangRenderer(CStyleLanguage): def _render_defines(self, uops) -> list[str]: prefix = [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1] # https://github.com/corsix/amx - for name, (N, M, _), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): + for name, (N, M, _), dtype_in, _, _, _, _, _ in wmma_args(uops): prefix += [ '#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")', '#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")', @@ -270,9 +274,9 @@ class 
IntelRenderer(OpenCLRenderer): def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str: prefix = [] - for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): - dt_in = ("ushort", "bf16") if arg[2] == dtypes.bfloat16 else (arg[2].name, "f16") - prefix.append(f"""{arg[3].name}8 __{arg[0]}({dt_in[0]}16 a, {dt_in[0]}16 b, {arg[3].name}8 c) {{ + for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops): + dt_in = ("ushort", "bf16") if dtype_in == dtypes.bfloat16 else (dtype_in.name, "f16") + prefix.append(f"""{dtype_out.name}8 __{name}({dt_in[0]}16 a, {dt_in[0]}16 b, {dtype_out.name}8 c) {{ return intel_sub_group_{dt_in[1]}_{dt_in[1]}_matrix_mad_k16(as_int8(a), as_int8(b), c);\n}}""") return super().render_kernel(function_name, kernel, bufs, uops, prefix or None) @@ -308,13 +312,13 @@ class MetalRenderer(CStyleLanguage): ]) + base_rewrite def render_kernel(self, function_name, kernel, bufs, uops, prefix=None): - prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is Ops.WMMA]) - for arg in wmma_args: prefix.append( - f"""{(dtype_out:=self.render_dtype(arg[3].vec(2)))} __{arg[0]}({(dtype_in:=self.render_dtype(arg[2].vec(2)))} a, {dtype_in} b, {dtype_out} c){{ - simdgroup_{self.render_dtype(arg[2])}8x8 mat_a, mat_b; simdgroup_{self.render_dtype(arg[3])}8x8 mat_c; + prefix = ["#include <metal_stdlib>","using namespace metal;"] + for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops): prefix.append( + f"""{(dstr_out:=self.render_dtype(dtype_out.vec(2)))} __{name}({(dstr_in:=self.render_dtype(dtype_in.vec(2)))} a, {dstr_in} b, {dstr_out} c){{ + simdgroup_{self.render_dtype(dtype_in)}8x8 mat_a, mat_b; simdgroup_{self.render_dtype(dtype_out)}8x8 mat_c; mat_a.thread_elements()[0] = a[0]; mat_b.thread_elements()[0] = b[0]; mat_c.thread_elements()[0] = c[0]; mat_a.thread_elements()[1] = a[1]; mat_b.thread_elements()[1] = b[1]; mat_c.thread_elements()[1] = c[1]; - simdgroup_multiply_accumulate(mat_c, 
mat_a, mat_b, mat_c);\n return {dtype_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""") + simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {dstr_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""") return super().render_kernel(function_name, kernel, bufs, uops, prefix) _nms = "xyzwabcdefghijkl" @@ -363,7 +367,7 @@ class CUDARenderer(CStyleLanguage): dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" } dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" } - for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): + for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in wmma_args(uops): upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes] wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)] n_operands = [size*dtype.itemsize//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)] # 4 => CUDA reg size in bytes @@ -458,15 +462,15 @@ class AMDRenderer(CStyleLanguage): if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;") prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1] - for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper + for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper if self.tensor_cores == tc.amd_cdna: - prefix.append(f"#define __{arg[0]} __builtin_amdgcn_mfma_f32_16x16x16{'f16' if arg[2] == dtypes.half else 'bf16_1k'}") + prefix.append(f"#define __{name} __builtin_amdgcn_mfma_f32_16x16x16{'f16' if dtype_in == dtypes.half else 'bf16_1k'}") # #define __WMMA_16_16_16_half_half __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12 elif self.tensor_cores == tc.amd_rdna4: - 
prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_{type_map[arg[3]]}_16x16x16_{type_map[arg[2]]}_w32_gfx12") - elif arg[3] == dtypes.float: - prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_{'f16' if arg[2] == dtypes.half else 'bf16'}_w32") - else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) { + prefix.append(f"#define __{name} __builtin_amdgcn_wmma_{type_map[dtype_out]}_16x16x16_{type_map[dtype_in]}_w32_gfx12") + elif dtype_out == dtypes.float: + prefix.append(f"#define __{name} __builtin_amdgcn_wmma_f32_16x16x16_{'f16' if dtype_in == dtypes.half else 'bf16'}_w32") + else: prefix.append(f"static inline __attribute__((device)) half8 __{name}"+"""(half16 a, half16 b, half8 c) { half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; } c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false); for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""") diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 96db0836c3..63f98904c9 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -207,9 +207,9 @@ class PTXRenderer(Renderer): elif u.op is Ops.DEFINE_GLOBAL: bufs.append((f"data{u.arg}", u.dtype)) elif u.op is Ops.WMMA: # registers for packing/unpacking input and acc - self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.arg[2].itemsize)], - [ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.arg[2].itemsize)], - [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.arg[3].itemsize)]] + self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.src[0].dtype.scalar().itemsize)], + [ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.src[0].dtype.scalar().itemsize)], + [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]] r[u] = [ssa("wmma", 
dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx", None), Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL:("local",self.types[dtypes.ulong]), diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 01fd318546..d9b4052413 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -123,24 +123,27 @@ class PythonProgram: out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K)) return out + first_src_dtype = self.uops[idp[0]][1] + assert isinstance(first_src_dtype, DType) # mypy + dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5] # TODO: refactor these to a shared TensorCoreLayout in kernel.py - if arg[4] == "METAL": + if device == "METAL": # A (2 elements on 32 threads): row major def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16] # (i, j), C, D (2 elements on 32 threads): row major same as A/B def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4) ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map) - elif arg[4] == "AMD" and arg[5] == 64: + elif device == "AMD" and threads == 64: def a_elem(x, k, row, goff): return x[k%4][goff + (k//4)*16 + row] def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem) ul[i] = wmma_helper(64, 16, 4, 4, 4, a_elem, b_elem, c_map) - elif arg[4] == "AMD" and len(inp[0]) == 8: # RDNA4 + elif device == "AMD" and len(inp[0]) == 8: # RDNA4 def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]] def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) def c_map(lane, elem): return (lane%16, 
(lane//16)*8 + elem) ul[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map) - elif arg[4] == "AMD": + elif device == "AMD": # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15 def a_elem(x, k, row, goff): assert x[k][goff+row] == x[k][goff+row+16], "warp elements not duplicated properly across lanes" @@ -149,27 +152,27 @@ class PythonProgram: def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map) - elif arg[4] == "CUDA": + elif device == "CUDA": # (col, row) given (lane, elem) for C & D (4 elements on 32 threads); shared by all tc shapes with M=16 N=8 def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8) - if arg[1] == (8,16,16): + if dims == (8,16,16): def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2 + (k//8)*4][goff + (k//2)%4 + (row%8)*4] def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4] ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map) - elif arg[1] == (8,16,8) and arg[2] == dtypes.half: + elif dims == (8,16,8) and dtype_in == dtypes.half: def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4] def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4] ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map) - elif arg[1] == (8,16,8) and arg[2] == dtypes.float: + elif dims == (8,16,8) and dtype_in == dtypes.float: def a_elem(x, k, row, goff): return x[(k//4)*2 + row//8][goff + k%4 + (row%8)*4] def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4] ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map) else: raise NotImplementedError(f"unimplemented tensor core {arg}") - elif arg[4] == "INTEL": + elif device == "INTEL": # A (16 elements on 8 threads) def a_elem(x, k, row, goff): return x[k%2+row*2][goff+k//2] # B (16 
elements on 8 threads) @@ -177,7 +180,7 @@ class PythonProgram: # C, D (8 elements on 8 threads) def c_map(lane, elem): return (lane, elem) ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map) - elif arg[4] == "CPU": + elif device == "CPU": def elem(x, col, row, _): return x[col+row][0] # k is always 0 def c_map(_, elem): return (elem%16, elem//16) ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)