matmul example on metal showing off tensor core (#13033)

* matmul example on metal showing off tensor core

* flip the args of placeholder

* mat_idx

* imp
This commit is contained in:
George Hotz
2025-10-31 19:40:36 +08:00
committed by GitHub
parent e066b3176b
commit bc178d14a9
4 changed files with 64 additions and 18 deletions

View File

@@ -68,17 +68,17 @@ def hand_spec_kernel3():
blockIdx_x = UOp.special(N // BLOCK_N, "gidx0")
blockIdx_y = UOp.special(N // BLOCK_M, "gidx1")
a = UOp.placeholder(dtypes.float, (N, N), slot=1)
b = UOp.placeholder(dtypes.float, (N, N), slot=2)
c = UOp.placeholder(dtypes.float, (N, N), slot=0)
a = UOp.placeholder((N, N), dtypes.float, slot=1)
b = UOp.placeholder((N, N), dtypes.float, slot=2)
c = UOp.placeholder((N, N), dtypes.float, slot=0)
BM_As_stride = (BLOCK_M + 4) if is_kernel5 else BLOCK_M
As = UOp.placeholder(dtypes.float, (BLOCK_K, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M))
Bs = UOp.placeholder(dtypes.float, (BLOCK_K, BLOCK_N), slot=1, addrspace=AddrSpace.LOCAL)
As = UOp.placeholder((BLOCK_K, BM_As_stride), dtypes.float, slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M))
Bs = UOp.placeholder((BLOCK_K, BLOCK_N), dtypes.float, slot=1, addrspace=AddrSpace.LOCAL)
A_col = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM), slot=0, addrspace=AddrSpace.REG)
B_row = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_N, TN), slot=1, addrspace=AddrSpace.REG)
c_regs = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), slot=2, addrspace=AddrSpace.REG)
A_col = UOp.placeholder((ITERS_PER_WAVE_M, TM), dtypes.float, slot=0, addrspace=AddrSpace.REG)
B_row = UOp.placeholder((ITERS_PER_WAVE_N, TN), dtypes.float, slot=1, addrspace=AddrSpace.REG)
c_regs = UOp.placeholder((ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
i = UOp.range(c_regs.size, 16)
c_regs = c_regs[i].set(0.0, end=i)
@@ -151,15 +151,13 @@ def hand_spec_kernel3():
return sink.sink(arg=KernelInfo(opts_to_apply=())).simplify()
if __name__ == "__main__":
def test_matmul(sink:UOp, N=N):
with Context(DEBUG=0):
a = Tensor.randn(N, N)
b = Tensor.randn(N, N)
hc = Tensor.empty(N, N)
Tensor.realize(a, b, hc)
sink = hand_spec_kernel3()
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in [hc, a, b]])
GlobalCounters.reset()
@@ -177,3 +175,6 @@ if __name__ == "__main__":
print(f"mean squared error {err}")
if err > 1e-06:
raise RuntimeError("matmul is wrong!")
if __name__ == "__main__":
test_matmul(hand_spec_kernel3(), N=N)

View File

@@ -0,0 +1,42 @@
from tinygrad import UOp, dtypes
from tinygrad.uop.ops import AxisType, Ops, KernelInfo, AddrSpace
from extra.gemm.amd_uop_matmul import test_matmul
N = 2048
# metal has an 8x8 tensor core. this is the indexing
def mat_idx(buf, g0, g1, warp, u):
l = [(warp//2**i)%2 for i in range(5)]
return buf[g0, l[4]*4 + l[2]*2 + l[1], g1, l[3]*4 + l[0]*2 + u]
def hand_spec_tc_cores():
gx = UOp.special(N // 8, "gidx0")
gy = UOp.special(N // 8, "gidx1")
warp = UOp.special(32, "lidx0")
c = UOp.placeholder((N, N), dtypes.float, slot=0).reshape((N//8, 8, N//8, 8))
a = UOp.placeholder((N, N), dtypes.float, slot=1).reshape((N//8, 8, N//8, 8))
b = UOp.placeholder((N, N), dtypes.float, slot=2).reshape((N//8, 8, N//8, 8))
gk = UOp.range(N // 8, 0, AxisType.REDUCE)
a_tc = UOp.vectorize(*[mat_idx(a, gx, gk, warp, i) for i in range(2)])
b_tc = UOp.vectorize(*[mat_idx(b, gk, gy, warp, i) for i in range(2)])
acc = UOp.placeholder((2,), dtypes.float, slot=0, addrspace=AddrSpace.REG)
acc = acc[0].set(0.0)
acc = acc[1].set(0.0)
# TODO: make this simple
wmma_arg = ('WMMA_8_8_8_float_float', (8, 8, 8), dtypes.float, dtypes.float, 'METAL', 32, (((3, 2),), ((3, 2),), ((3, 2),)), ())
acc_load = UOp.vectorize(acc.after(gk)[0], acc.after(gk)[1])
out = UOp(Ops.WMMA, dtypes.float.vec(2), (a_tc, b_tc, acc_load), arg=wmma_arg)
end_loop = UOp.group(*[acc[i].store(out.gep(i)) for i in range(2)]).end(gk)
sink = UOp.group(*[mat_idx(c.after(end_loop), gx, gy, warp, i).store(acc[i]) for i in range(2)])
return sink.sink(arg=KernelInfo(name="custom_metal_matmul", opts_to_apply=())).simplify()
if __name__ == "__main__":
test_matmul(hand_spec_tc_cores(), N=N)

View File

@@ -572,7 +572,7 @@ class TestUOpPrograms(unittest.TestCase):
def test_simple(self):
out = Tensor.empty(10,10,dtype=dtypes.int)
ptr = UOp.placeholder(out.dtype, out.shape, slot=0)
ptr = UOp.placeholder(out.shape, out.dtype, slot=0)
i, j = UOp.range(10, axis_id=0), UOp.range(10, axis_id=1)
prog = ptr[i,j].set(42).end(i,j)
self._run(prog.sink(), out)
@@ -592,9 +592,9 @@ class TestUOpPrograms(unittest.TestCase):
DT = dtypes.float32
# Placeholders (bind slots explicitly)
A = UOp.placeholder(DT, (M, K), slot=0)
B = UOp.placeholder(DT, (K, N), slot=1)
C = UOp.placeholder(DT, (M, N), slot=2)
A = UOp.placeholder((M, K), DT, slot=0)
B = UOp.placeholder((K, N), DT, slot=1)
C = UOp.placeholder((M, N), DT, slot=2)
# Axes: i,j are spatial; k is a reduction axis over the shared dim K
i = UOp.range(M, axis_id=0) # rows of A/C

View File

@@ -338,10 +338,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def group(*srcs:UOp|None): # pylint: disable=no-self-argument
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
def vectorize(self, *srcs, **kwargs):
return UOp(Ops.VECTORIZE, self.dtype.vec(len(srcs)+1), (self,)+srcs, **kwargs)
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, idx): return self.index(*argfix(idx))
def __getitem__(self, idx):
return self.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in argfix(idx)])
def const_like(self, b:ConstLike):
# constants can optionally have a DEVICE source
return UOp.const(self.dtype, b, device=self._device, shape=self._shape)
@@ -761,14 +764,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def shrink_to(self, arg:tuple[sint, ...]): return self.shrink(tuple([(0,x) for x in arg]))
@staticmethod
def placeholder(dtype:DType, shape:tuple[int, ...], slot:int, addrspace=AddrSpace.GLOBAL):
def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
lookup = {AddrSpace.GLOBAL: Ops.DEFINE_GLOBAL, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
ret = UOp(lookup[addrspace], dtype.ptr(prod(shape), addrspace), arg=slot)
if len(shape) > 1: ret = ret.reshape(shape)
return ret
def placeholder_like(self, slot:int):
assert all_int(self.shape), "no placeholder-like on symbolic shape"
return UOp.placeholder(self.dtype, self.shape, slot)
return UOp.placeholder(self.shape, self.dtype, slot)
# set is store+end+after
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]=()) -> UOp: