mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-07 22:23:55 -05:00
matmul example on metal showing off tensor core (#13033)

* matmul example on metal showing off tensor core
* flip the args of placeholder
* mat_idx
* imp
extra/gemm/amd_uop_matmul.py

@@ -68,17 +68,17 @@ def hand_spec_kernel3():
   blockIdx_x = UOp.special(N // BLOCK_N, "gidx0")
   blockIdx_y = UOp.special(N // BLOCK_M, "gidx1")

-  a = UOp.placeholder(dtypes.float, (N, N), slot=1)
-  b = UOp.placeholder(dtypes.float, (N, N), slot=2)
-  c = UOp.placeholder(dtypes.float, (N, N), slot=0)
+  a = UOp.placeholder((N, N), dtypes.float, slot=1)
+  b = UOp.placeholder((N, N), dtypes.float, slot=2)
+  c = UOp.placeholder((N, N), dtypes.float, slot=0)

   BM_As_stride = (BLOCK_M + 4) if is_kernel5 else BLOCK_M
-  As = UOp.placeholder(dtypes.float, (BLOCK_K, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M))
-  Bs = UOp.placeholder(dtypes.float, (BLOCK_K, BLOCK_N), slot=1, addrspace=AddrSpace.LOCAL)
+  As = UOp.placeholder((BLOCK_K, BM_As_stride), dtypes.float, slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M))
+  Bs = UOp.placeholder((BLOCK_K, BLOCK_N), dtypes.float, slot=1, addrspace=AddrSpace.LOCAL)

-  A_col = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM), slot=0, addrspace=AddrSpace.REG)
-  B_row = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_N, TN), slot=1, addrspace=AddrSpace.REG)
-  c_regs = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), slot=2, addrspace=AddrSpace.REG)
+  A_col = UOp.placeholder((ITERS_PER_WAVE_M, TM), dtypes.float, slot=0, addrspace=AddrSpace.REG)
+  B_row = UOp.placeholder((ITERS_PER_WAVE_N, TN), dtypes.float, slot=1, addrspace=AddrSpace.REG)
+  c_regs = UOp.placeholder((ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)

   i = UOp.range(c_regs.size, 16)
   c_regs = c_regs[i].set(0.0, end=i)
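The change here flips placeholder to shape-first: UOp.placeholder(shape, dtype, slot, addrspace). A minimal sketch of the three address spaces, using only the API in this diff (each addrspace maps to a different define op, per the lookup table in the placeholder implementation further down; shapes and slots are illustrative):

from tinygrad import UOp, dtypes
from tinygrad.uop.ops import AddrSpace

# global device memory: lowers to DEFINE_GLOBAL
g = UOp.placeholder((64, 64), dtypes.float, slot=0)
# threadgroup/shared memory: lowers to DEFINE_LOCAL
l = UOp.placeholder((8, 64), dtypes.float, slot=0, addrspace=AddrSpace.LOCAL)
# per-thread registers: lowers to DEFINE_REG
r = UOp.placeholder((4,), dtypes.float, slot=0, addrspace=AddrSpace.REG)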
@@ -151,15 +151,13 @@ def hand_spec_kernel3():
   return sink.sink(arg=KernelInfo(opts_to_apply=())).simplify()

-if __name__ == "__main__":
+def test_matmul(sink:UOp, N=N):
   with Context(DEBUG=0):
     a = Tensor.randn(N, N)
     b = Tensor.randn(N, N)
     hc = Tensor.empty(N, N)
     Tensor.realize(a, b, hc)

-  sink = hand_spec_kernel3()
   ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in [hc, a, b]])

   GlobalCounters.reset()
@@ -177,3 +175,6 @@ if __name__ == "__main__":
   print(f"mean squared error {err}")
   if err > 1e-06:
     raise RuntimeError("matmul is wrong!")
+
+if __name__ == "__main__":
+  test_matmul(hand_spec_kernel3(), N=N)
extra/gemm/metal_uop_matmul.py (new file, 42 lines)

@@ -0,0 +1,42 @@
+from tinygrad import UOp, dtypes
+from tinygrad.uop.ops import AxisType, Ops, KernelInfo, AddrSpace
+from extra.gemm.amd_uop_matmul import test_matmul
+
+N = 2048
+
+# metal has an 8x8 tensor core. this is the indexing
+def mat_idx(buf, g0, g1, warp, u):
+  l = [(warp//2**i)%2 for i in range(5)]
+  return buf[g0, l[4]*4 + l[2]*2 + l[1], g1, l[3]*4 + l[0]*2 + u]
+
+def hand_spec_tc_cores():
+  gx = UOp.special(N // 8, "gidx0")
+  gy = UOp.special(N // 8, "gidx1")
+  warp = UOp.special(32, "lidx0")
+
+  c = UOp.placeholder((N, N), dtypes.float, slot=0).reshape((N//8, 8, N//8, 8))
+  a = UOp.placeholder((N, N), dtypes.float, slot=1).reshape((N//8, 8, N//8, 8))
+  b = UOp.placeholder((N, N), dtypes.float, slot=2).reshape((N//8, 8, N//8, 8))
+
+  gk = UOp.range(N // 8, 0, AxisType.REDUCE)
+
+  a_tc = UOp.vectorize(*[mat_idx(a, gx, gk, warp, i) for i in range(2)])
+  b_tc = UOp.vectorize(*[mat_idx(b, gk, gy, warp, i) for i in range(2)])
+
+  acc = UOp.placeholder((2,), dtypes.float, slot=0, addrspace=AddrSpace.REG)
+  acc = acc[0].set(0.0)
+  acc = acc[1].set(0.0)
+
+  # TODO: make this simple
+  wmma_arg = ('WMMA_8_8_8_float_float', (8, 8, 8), dtypes.float, dtypes.float, 'METAL', 32, (((3, 2),), ((3, 2),), ((3, 2),)), ())
+
+  acc_load = UOp.vectorize(acc.after(gk)[0], acc.after(gk)[1])
+  out = UOp(Ops.WMMA, dtypes.float.vec(2), (a_tc, b_tc, acc_load), arg=wmma_arg)
+
+  end_loop = UOp.group(*[acc[i].store(out.gep(i)) for i in range(2)]).end(gk)
+
+  sink = UOp.group(*[mat_idx(c.after(end_loop), gx, gy, warp, i).store(acc[i]) for i in range(2)])
+  return sink.sink(arg=KernelInfo(name="custom_metal_matmul", opts_to_apply=())).simplify()
+
+if __name__ == "__main__":
+  test_matmul(hand_spec_tc_cores(), N=N)
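mat_idx encodes Metal's 8x8 simdgroup matrix layout: the 32-wide lane id is split into 5 bits, bits 1, 2, 4 pick the row and bits 0, 3 (plus the per-lane element u) pick the column, so each lane holds 2 of the 64 tile elements fed to the 8x8x8 WMMA. A plain-Python sketch (no tinygrad needed) re-deriving the formula above and checking that it covers the tile exactly once:

covered = set()
for warp in range(32):                    # lane id within the simdgroup
  l = [(warp//2**i)%2 for i in range(5)]  # the 5 lane-id bits, exactly as in mat_idx
  for u in range(2):                      # each lane owns 2 elements of the tile
    covered.add((l[4]*4 + l[2]*2 + l[1], l[3]*4 + l[0]*2 + u))
assert covered == {(r, c) for r in range(8) for c in range(8)}, "must be a bijection onto the 8x8 tile"

To run the kernel itself (requires a Metal device): python3 extra/gemm/metal_uop_matmul.py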
@@ -572,7 +572,7 @@ class TestUOpPrograms(unittest.TestCase):
   def test_simple(self):
     out = Tensor.empty(10,10,dtype=dtypes.int)

-    ptr = UOp.placeholder(out.dtype, out.shape, slot=0)
+    ptr = UOp.placeholder(out.shape, out.dtype, slot=0)
     i, j = UOp.range(10, axis_id=0), UOp.range(10, axis_id=1)
     prog = ptr[i,j].set(42).end(i,j)
     self._run(prog.sink(), out)
@@ -592,9 +592,9 @@ class TestUOpPrograms(unittest.TestCase):
     DT = dtypes.float32

     # Placeholders (bind slots explicitly)
-    A = UOp.placeholder(DT, (M, K), slot=0)
-    B = UOp.placeholder(DT, (K, N), slot=1)
-    C = UOp.placeholder(DT, (M, N), slot=2)
+    A = UOp.placeholder((M, K), DT, slot=0)
+    B = UOp.placeholder((K, N), DT, slot=1)
+    C = UOp.placeholder((M, N), DT, slot=2)

     # Axes: i,j are spatial; k is a reduction axis over the shared dim K
     i = UOp.range(M, axis_id=0)  # rows of A/C
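The hunk cuts off after the first axis; a hedged sketch of how the remaining two would look, mirroring the AxisType.REDUCE usage in metal_uop_matmul.py above (the names j and k are assumptions, since the rest of the test is not shown):

from tinygrad.uop.ops import AxisType
j = UOp.range(N, axis_id=1)           # columns of B/C
k = UOp.range(K, 2, AxisType.REDUCE)  # reduction axis over the shared dim K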
tinygrad/uop/ops.py

@@ -338,10 +338,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def group(*srcs:UOp|None): # pylint: disable=no-self-argument
     if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
     return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
+  def vectorize(self, *srcs, **kwargs):
+    return UOp(Ops.VECTORIZE, self.dtype.vec(len(srcs)+1), (self,)+srcs, **kwargs)
   def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
   def index(self, *srcs:UOp|None, ptr=False, **kwargs):
     return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
-  def __getitem__(self, idx): return self.index(*argfix(idx))
+  def __getitem__(self, idx):
+    return self.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in argfix(idx)])
   def const_like(self, b:ConstLike):
     # constants can optionally have a DEVICE source
     return UOp.const(self.dtype, b, device=self._device, shape=self._shape)
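The new __getitem__ promotes plain Python ints to dtypes.index constants, which is what lets the metal example write acc[0] and acc[1] directly. A small sketch of the equivalence (the identity check assumes UOps are interned by UOpMetaClass, so identical graphs are the same object):

from tinygrad import UOp, dtypes
from tinygrad.uop.ops import AddrSpace

acc = UOp.placeholder((2,), dtypes.float, slot=0, addrspace=AddrSpace.REG)
e0 = acc[0]                           # int index, auto-wrapped by __getitem__
e1 = acc[UOp.const(dtypes.index, 0)]  # the explicit form it expands to
assert e0 is e1                       # same INDEX UOp, assuming interned UOps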
@@ -761,14 +764,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def shrink_to(self, arg:tuple[sint, ...]): return self.shrink(tuple([(0,x) for x in arg]))

   @staticmethod
-  def placeholder(dtype:DType, shape:tuple[int, ...], slot:int, addrspace=AddrSpace.GLOBAL):
+  def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
     lookup = {AddrSpace.GLOBAL: Ops.DEFINE_GLOBAL, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
     ret = UOp(lookup[addrspace], dtype.ptr(prod(shape), addrspace), arg=slot)
     if len(shape) > 1: ret = ret.reshape(shape)
     return ret
   def placeholder_like(self, slot:int):
     assert all_int(self.shape), "no placeholder-like on symbolic shape"
-    return UOp.placeholder(self.dtype, self.shape, slot)
+    return UOp.placeholder(self.shape, self.dtype, slot)

   # set is store+end+after
   def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]=()) -> UOp:
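Per the comment, set is store+end+after rolled into one call; it is the pattern behind the c_regs zero-init in the first hunk. A minimal sketch with a 1D register buffer (shape and slot are illustrative):

from tinygrad import UOp, dtypes
from tinygrad.uop.ops import AddrSpace

regs = UOp.placeholder((16,), dtypes.float, slot=2, addrspace=AddrSpace.REG)
i = UOp.range(16, 0)            # loop axis over the 16 register slots
regs = regs[i].set(0.0, end=i)  # store 0.0 at regs[i], close the i range, yield the buffer after the loop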