mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
matmul example on metal showing off tensor core (#13033)
* matmul example on metal showing off tensor core
* flip the args of placeholder
* mat_idx
* imp
This commit is contained in:
@@ -68,17 +68,17 @@ def hand_spec_kernel3():
|
||||
blockIdx_x = UOp.special(N // BLOCK_N, "gidx0")
|
||||
blockIdx_y = UOp.special(N // BLOCK_M, "gidx1")
|
||||
|
||||
a = UOp.placeholder(dtypes.float, (N, N), slot=1)
|
||||
b = UOp.placeholder(dtypes.float, (N, N), slot=2)
|
||||
c = UOp.placeholder(dtypes.float, (N, N), slot=0)
|
||||
a = UOp.placeholder((N, N), dtypes.float, slot=1)
|
||||
b = UOp.placeholder((N, N), dtypes.float, slot=2)
|
||||
c = UOp.placeholder((N, N), dtypes.float, slot=0)
|
||||
|
||||
BM_As_stride = (BLOCK_M + 4) if is_kernel5 else BLOCK_M
|
||||
As = UOp.placeholder(dtypes.float, (BLOCK_K, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M))
|
||||
Bs = UOp.placeholder(dtypes.float, (BLOCK_K, BLOCK_N), slot=1, addrspace=AddrSpace.LOCAL)
|
||||
As = UOp.placeholder((BLOCK_K, BM_As_stride), dtypes.float, slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M))
|
||||
Bs = UOp.placeholder((BLOCK_K, BLOCK_N), dtypes.float, slot=1, addrspace=AddrSpace.LOCAL)
|
||||
|
||||
A_col = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM), slot=0, addrspace=AddrSpace.REG)
|
||||
B_row = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_N, TN), slot=1, addrspace=AddrSpace.REG)
|
||||
c_regs = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), slot=2, addrspace=AddrSpace.REG)
|
||||
A_col = UOp.placeholder((ITERS_PER_WAVE_M, TM), dtypes.float, slot=0, addrspace=AddrSpace.REG)
|
||||
B_row = UOp.placeholder((ITERS_PER_WAVE_N, TN), dtypes.float, slot=1, addrspace=AddrSpace.REG)
|
||||
c_regs = UOp.placeholder((ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
|
||||
|
||||
i = UOp.range(c_regs.size, 16)
|
||||
c_regs = c_regs[i].set(0.0, end=i)
|
||||
@@ -151,15 +151,13 @@ def hand_spec_kernel3():
|
||||
|
||||
return sink.sink(arg=KernelInfo(opts_to_apply=())).simplify()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def test_matmul(sink:UOp, N=N):
|
||||
with Context(DEBUG=0):
|
||||
a = Tensor.randn(N, N)
|
||||
b = Tensor.randn(N, N)
|
||||
hc = Tensor.empty(N, N)
|
||||
Tensor.realize(a, b, hc)
|
||||
|
||||
sink = hand_spec_kernel3()
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in [hc, a, b]])
|
||||
|
||||
GlobalCounters.reset()
|
||||
@@ -177,3 +175,6 @@ if __name__ == "__main__":
|
||||
print(f"mean squared error {err}")
|
||||
if err > 1e-06:
|
||||
raise RuntimeError("matmul is wrong!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_matmul(hand_spec_kernel3(), N=N)
|
||||
|
||||
Reference in New Issue
Block a user