From bc178d14a953ba95b25f55047b88ee3594ca1c2d Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 31 Oct 2025 19:40:36 +0800 Subject: [PATCH] matmul example on metal showing off tensor core (#13033) * matmul example on metal showing off tensor core * flip the args of placeholder * mat_idx * imp --- extra/gemm/amd_uop_matmul.py | 23 ++++++++++--------- extra/gemm/metal_uop_matmul.py | 42 ++++++++++++++++++++++++++++++++++ test/test_uops.py | 8 +++---- tinygrad/uop/ops.py | 9 +++++--- 4 files changed, 64 insertions(+), 18 deletions(-) create mode 100644 extra/gemm/metal_uop_matmul.py diff --git a/extra/gemm/amd_uop_matmul.py b/extra/gemm/amd_uop_matmul.py index f269afdddd..1637a59987 100644 --- a/extra/gemm/amd_uop_matmul.py +++ b/extra/gemm/amd_uop_matmul.py @@ -68,17 +68,17 @@ def hand_spec_kernel3(): blockIdx_x = UOp.special(N // BLOCK_N, "gidx0") blockIdx_y = UOp.special(N // BLOCK_M, "gidx1") - a = UOp.placeholder(dtypes.float, (N, N), slot=1) - b = UOp.placeholder(dtypes.float, (N, N), slot=2) - c = UOp.placeholder(dtypes.float, (N, N), slot=0) + a = UOp.placeholder((N, N), dtypes.float, slot=1) + b = UOp.placeholder((N, N), dtypes.float, slot=2) + c = UOp.placeholder((N, N), dtypes.float, slot=0) BM_As_stride = (BLOCK_M + 4) if is_kernel5 else BLOCK_M - As = UOp.placeholder(dtypes.float, (BLOCK_K, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M)) - Bs = UOp.placeholder(dtypes.float, (BLOCK_K, BLOCK_N), slot=1, addrspace=AddrSpace.LOCAL) + As = UOp.placeholder((BLOCK_K, BM_As_stride), dtypes.float, slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M)) + Bs = UOp.placeholder((BLOCK_K, BLOCK_N), dtypes.float, slot=1, addrspace=AddrSpace.LOCAL) - A_col = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM), slot=0, addrspace=AddrSpace.REG) - B_row = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_N, TN), slot=1, addrspace=AddrSpace.REG) - c_regs = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), slot=2, addrspace=AddrSpace.REG) + A_col = UOp.placeholder((ITERS_PER_WAVE_M, TM), dtypes.float, slot=0, addrspace=AddrSpace.REG) + B_row = UOp.placeholder((ITERS_PER_WAVE_N, TN), dtypes.float, slot=1, addrspace=AddrSpace.REG) + c_regs = UOp.placeholder((ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG) i = UOp.range(c_regs.size, 16) c_regs = c_regs[i].set(0.0, end=i) @@ -151,15 +151,13 @@ def hand_spec_kernel3(): return sink.sink(arg=KernelInfo(opts_to_apply=())).simplify() - -if __name__ == "__main__": +def test_matmul(sink:UOp, N=N): with Context(DEBUG=0): a = Tensor.randn(N, N) b = Tensor.randn(N, N) hc = Tensor.empty(N, N) Tensor.realize(a, b, hc) - sink = hand_spec_kernel3() ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in [hc, a, b]]) GlobalCounters.reset() @@ -177,3 +175,6 @@ if __name__ == "__main__": print(f"mean squared error {err}") if err > 1e-06: raise RuntimeError("matmul is wrong!") + +if __name__ == "__main__": + test_matmul(hand_spec_kernel3(), N=N) diff --git a/extra/gemm/metal_uop_matmul.py b/extra/gemm/metal_uop_matmul.py new file mode 100644 index 0000000000..a2d619b45e --- /dev/null +++ b/extra/gemm/metal_uop_matmul.py @@ -0,0 +1,42 @@ +from tinygrad import UOp, dtypes +from tinygrad.uop.ops import AxisType, Ops, KernelInfo, AddrSpace +from extra.gemm.amd_uop_matmul import test_matmul + +N = 2048 + +# metal has an 8x8 tensor core. this is the indexing +def mat_idx(buf, g0, g1, warp, u): + l = [(warp//2**i)%2 for i in range(5)] + return buf[g0, l[4]*4 + l[2]*2 + l[1], g1, l[3]*4 + l[0]*2 + u] + +def hand_spec_tc_cores(): + gx = UOp.special(N // 8, "gidx0") + gy = UOp.special(N // 8, "gidx1") + warp = UOp.special(32, "lidx0") + + c = UOp.placeholder((N, N), dtypes.float, slot=0).reshape((N//8, 8, N//8, 8)) + a = UOp.placeholder((N, N), dtypes.float, slot=1).reshape((N//8, 8, N//8, 8)) + b = UOp.placeholder((N, N), dtypes.float, slot=2).reshape((N//8, 8, N//8, 8)) + + gk = UOp.range(N // 8, 0, AxisType.REDUCE) + + a_tc = UOp.vectorize(*[mat_idx(a, gx, gk, warp, i) for i in range(2)]) + b_tc = UOp.vectorize(*[mat_idx(b, gk, gy, warp, i) for i in range(2)]) + + acc = UOp.placeholder((2,), dtypes.float, slot=0, addrspace=AddrSpace.REG) + acc = acc[0].set(0.0) + acc = acc[1].set(0.0) + + # TODO: make this simple + wmma_arg = ('WMMA_8_8_8_float_float', (8, 8, 8), dtypes.float, dtypes.float, 'METAL', 32, (((3, 2),), ((3, 2),), ((3, 2),)), ()) + + acc_load = UOp.vectorize(acc.after(gk)[0], acc.after(gk)[1]) + out = UOp(Ops.WMMA, dtypes.float.vec(2), (a_tc, b_tc, acc_load), arg=wmma_arg) + + end_loop = UOp.group(*[acc[i].store(out.gep(i)) for i in range(2)]).end(gk) + + sink = UOp.group(*[mat_idx(c.after(end_loop), gx, gy, warp, i).store(acc[i]) for i in range(2)]) + return sink.sink(arg=KernelInfo(name="custom_metal_matmul", opts_to_apply=())).simplify() + +if __name__ == "__main__": + test_matmul(hand_spec_tc_cores(), N=N) diff --git a/test/test_uops.py b/test/test_uops.py index ac8ec77f67..eb23f10b33 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -572,7 +572,7 @@ class TestUOpPrograms(unittest.TestCase): def test_simple(self): out = Tensor.empty(10,10,dtype=dtypes.int) - ptr = UOp.placeholder(out.dtype, out.shape, slot=0) + ptr = UOp.placeholder(out.shape, out.dtype, slot=0) i, j = UOp.range(10, axis_id=0), UOp.range(10, axis_id=1) prog = ptr[i,j].set(42).end(i,j) self._run(prog.sink(), out) @@ -592,9 +592,9 @@ class TestUOpPrograms(unittest.TestCase): DT = dtypes.float32 # Placeholders (bind slots explicitly) - A = UOp.placeholder(DT, (M, K), slot=0) - B = UOp.placeholder(DT, (K, N), slot=1) - C = UOp.placeholder(DT, (M, N), slot=2) + A = UOp.placeholder((M, K), DT, slot=0) + B = UOp.placeholder((K, N), DT, slot=1) + C = UOp.placeholder((M, N), DT, slot=2) # Axes: i,j are spatial; k is a reduction axis over the shared dim K i = UOp.range(M, axis_id=0) # rows of A/C diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index bfdcce2c1d..e2dc13cf69 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -338,10 +338,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def group(*srcs:UOp|None): # pylint: disable=no-self-argument if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0] return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None])) + def vectorize(self, *srcs, **kwargs): + return UOp(Ops.VECTORIZE, self.dtype.vec(len(srcs)+1), (self,)+srcs, **kwargs) def detach(self): return UOp(Ops.DETACH, self.dtype, (self,)) def index(self, *srcs:UOp|None, ptr=False, **kwargs): return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs) - def __getitem__(self, idx): return self.index(*argfix(idx)) + def __getitem__(self, idx): + return self.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in argfix(idx)]) def const_like(self, b:ConstLike): # constants can optionally have a DEVICE source return UOp.const(self.dtype, b, device=self._device, shape=self._shape) @@ -761,14 +764,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def shrink_to(self, arg:tuple[sint, ...]): return self.shrink(tuple([(0,x) for x in arg])) @staticmethod - def placeholder(dtype:DType, shape:tuple[int, ...], slot:int, addrspace=AddrSpace.GLOBAL): + def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL): lookup = {AddrSpace.GLOBAL: Ops.DEFINE_GLOBAL, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG} ret = UOp(lookup[addrspace], dtype.ptr(prod(shape), addrspace), arg=slot) if len(shape) > 1: ret = ret.reshape(shape) return ret def placeholder_like(self, slot:int): assert all_int(self.shape), "no placeholder-like on symbolic shape" - return UOp.placeholder(self.dtype, self.shape, slot) + return UOp.placeholder(self.shape, self.dtype, slot) # set is store+end+after def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]=()) -> UOp: