From c13d9d29ff37d34b61cd20ef7d33779406c31a57 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 21 Mar 2026 16:16:03 +0800 Subject: [PATCH] add SHAPED_WMMA (#15400) * add SHAPED_WMMA * shaped wmma * less bad --- extra/gemm/amd_copy_matmul.py | 13 ++----------- tinygrad/schedule/rangeify.py | 13 +++++++++++++ tinygrad/uop/__init__.py | 2 +- tinygrad/uop/ops.py | 3 +++ tinygrad/uop/spec.py | 4 ++++ tinygrad/viz/serve.py | 2 +- 6 files changed, 24 insertions(+), 13 deletions(-) diff --git a/extra/gemm/amd_copy_matmul.py b/extra/gemm/amd_copy_matmul.py index 85e6479447..956c0bba7b 100644 --- a/extra/gemm/amd_copy_matmul.py +++ b/extra/gemm/amd_copy_matmul.py @@ -72,17 +72,8 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp: a_frag = A_local.reshape(WAVES_M, TM // WMMA_ACC, WMMA_M, BLOCK_K // WMMA_K, WMMA_K)[wave_m, tile_m, lane_n, k] b_frag = B_local.reshape(WAVES_N, TN, WMMA_N, BLOCK_K // WMMA_K, WMMA_K)[wave_n, tile_n, lane_n, k] - # TODO: remove unneeded CONTRACTS - k_upcast_a = UOp.range(WMMA_K, 301, axis_type=AxisType.UPCAST) - k_upcast_b = UOp.range(WMMA_K, 311, axis_type=AxisType.UPCAST) - acc_upcast = UOp.range(WMMA_ACC, 302, axis_type=AxisType.UPCAST) - wmma_arg = ('WMMA_16_16_16_half_float', (16, 16, 16), dtypes.half, dtypes.float, 'AMD', 32, - (((301, 16),), ((311, 16),), ((302, WMMA_ACC),)), ()) - out = UOp(Ops.WMMA, dtypes.float.vec(WMMA_ACC), (a_frag[k_upcast_a].contract(k_upcast_a), - b_frag[k_upcast_b].contract(k_upcast_b), - acc_frag.after(k)[acc_upcast].contract(acc_upcast)), arg=wmma_arg) - - acc_store = UOp.group(*[acc_frag[e].store(out.gep(e)) for e in range(WMMA_ACC)]).end(tile_m, tile_n) + wmma = UOp(Ops.SHAPED_WMMA, dtypes.float, (a_frag, b_frag, acc_frag.after(k)), arg=((16, 16, 16), 'AMD', 32)) + acc_store = acc_frag.store(wmma).end(tile_m, tile_n) else: # registers for LOCAL -> REG a_frag = UOp.placeholder((TM//UNROLL_M, UNROLL_M), dtypes.float, slot=0, addrspace=AddrSpace.REG) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index d3ce0fcdcf..6a03638f41 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -22,6 +22,17 @@ def add_ranges_to_store(ctx, x): idxs = [UOp.range(r, next(ctx), AxisType.LOOP) for r in x.src[0].shape] return UOp.store(x.src[0].index(*idxs), x.src[1].index(*idxs)).end(*idxs) +def lower_shaped_wmma(ctx, x): + dims, device, threads = x.arg + dtype_in, dtype_out = x.src[0].dtype.base, x.dtype + upcasts = [(s, UOp.range(s.shape[-1], next(ctx), axis_type=AxisType.UPCAST)) for s in x.src] + tc_upcast_axes = tuple(((u.arg[0], s.shape[-1]),) for s, u in upcasts) + name = f"WMMA_{'_'.join(map(str, dims))}_{dtype_in.name}_{dtype_out.name}" + wmma_arg = (name, dims, dtype_in, dtype_out, device, threads, tc_upcast_axes, ()) + wmma = UOp(Ops.WMMA, dtype_out.vec(x.src[2].shape[-1]), tuple(s[u].contract(u) for s, u in upcasts), arg=wmma_arg) + tmp = UOp.placeholder((x.src[2].shape[-1],), dtype_out, slot=next(ctx), addrspace=AddrSpace.REG) + return tmp.after(UOp.group(*[tmp[e].store(wmma.gep(e)) for e in range(x.src[2].shape[-1])])) + pm_store_ranges = PatternMatcher([ (UPat(Ops.STORE, name="x"), add_ranges_to_store), ]) @@ -63,6 +74,8 @@ pm_mops = PatternMatcher([ lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg) if a.src[0]._shape is not None and not any(s.op is Ops.STORE and s.src[0]._shape is not None for s in a.src[1:]) else None), (UPat(GroupOp.Movement, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])), + # lower SHAPED_WMMA to WMMA with CONTRACT/UNROLL + (UPat(Ops.SHAPED_WMMA, name="x"), lower_shaped_wmma), ]) # ***************** diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index 2780f23eb5..a67a2f8e98 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -53,7 +53,7 @@ class Ops(FastEnum): # ** 4 -- math ** # tensor core math op, not elementwise - WMMA = auto() + WMMA = auto(); SHAPED_WMMA = auto() # UnaryOps CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto() diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 76a944639b..ffe9a84303 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -249,6 +249,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass): if len(self.src) >= 1: return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count)) return None + # SHAPED_WMMA output shape = accumulator shape (src[2]) + case Ops.SHAPED_WMMA: return self.src[2]._shape + # passthrough ops case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END: return self.src[0]._shape diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index bde80ab2f7..0874e8644c 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -205,6 +205,10 @@ kernel_spec = PatternMatcher([ (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)), (UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)), + # SHAPED_WMMA has with shaped inputs, arg=((M,N,K), device, threads), lowered to WMMA+CONTRACT later + (UPat(Ops.SHAPED_WMMA, src=(UPat(), UPat(), UPat()), name="x"), + lambda x: isinstance(x.arg, tuple) and len(x.arg) == 3 and isinstance(x.arg[0], tuple)), + # END can end multiple axes here (UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True), lambda: True), diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index f3b0b838ca..d16d542812 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -44,7 +44,7 @@ from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphE from tinygrad.dtype import dtypes uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B", - **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B", + **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B", Ops.SHAPED_WMMA: "#FF5B5B", Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",