matmul example on metal showing off tensor core (#13033)

* matmul example on metal showing off tensor core

* flip the args of placeholder

* mat_idx

* imp
This commit is contained in:
George Hotz
2025-10-31 19:40:36 +08:00
committed by GitHub
parent e066b3176b
commit bc178d14a9
4 changed files with 64 additions and 18 deletions

View File

@@ -68,17 +68,17 @@ def hand_spec_kernel3():
blockIdx_x = UOp.special(N // BLOCK_N, "gidx0")
blockIdx_y = UOp.special(N // BLOCK_M, "gidx1")
a = UOp.placeholder(dtypes.float, (N, N), slot=1)
b = UOp.placeholder(dtypes.float, (N, N), slot=2)
c = UOp.placeholder(dtypes.float, (N, N), slot=0)
a = UOp.placeholder((N, N), dtypes.float, slot=1)
b = UOp.placeholder((N, N), dtypes.float, slot=2)
c = UOp.placeholder((N, N), dtypes.float, slot=0)
BM_As_stride = (BLOCK_M + 4) if is_kernel5 else BLOCK_M
As = UOp.placeholder(dtypes.float, (BLOCK_K, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M))
Bs = UOp.placeholder(dtypes.float, (BLOCK_K, BLOCK_N), slot=1, addrspace=AddrSpace.LOCAL)
As = UOp.placeholder((BLOCK_K, BM_As_stride), dtypes.float, slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M))
Bs = UOp.placeholder((BLOCK_K, BLOCK_N), dtypes.float, slot=1, addrspace=AddrSpace.LOCAL)
A_col = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM), slot=0, addrspace=AddrSpace.REG)
B_row = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_N, TN), slot=1, addrspace=AddrSpace.REG)
c_regs = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), slot=2, addrspace=AddrSpace.REG)
A_col = UOp.placeholder((ITERS_PER_WAVE_M, TM), dtypes.float, slot=0, addrspace=AddrSpace.REG)
B_row = UOp.placeholder((ITERS_PER_WAVE_N, TN), dtypes.float, slot=1, addrspace=AddrSpace.REG)
c_regs = UOp.placeholder((ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
i = UOp.range(c_regs.size, 16)
c_regs = c_regs[i].set(0.0, end=i)
@@ -151,15 +151,13 @@ def hand_spec_kernel3():
return sink.sink(arg=KernelInfo(opts_to_apply=())).simplify()
if __name__ == "__main__":
def test_matmul(sink:UOp, N=N):
with Context(DEBUG=0):
a = Tensor.randn(N, N)
b = Tensor.randn(N, N)
hc = Tensor.empty(N, N)
Tensor.realize(a, b, hc)
sink = hand_spec_kernel3()
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in [hc, a, b]])
GlobalCounters.reset()
@@ -177,3 +175,6 @@ if __name__ == "__main__":
print(f"mean squared error {err}")
if err > 1e-06:
raise RuntimeError("matmul is wrong!")
if __name__ == "__main__":
test_matmul(hand_spec_kernel3(), N=N)