From ddbdb52f77ddedbe1afdec4003dad60ca09cc1ce Mon Sep 17 00:00:00 2001 From: Francis Lam Date: Fri, 12 Jan 2024 13:25:28 -0800 Subject: [PATCH] wmma: enable METAL half tensor cores and clean up cstyle (#3095) * wmma: enable METAL half tensor cores and clean up cstyle * revert simple_matmul rand changes and break line in tensor * added metal fp16->fp32 tensor core --- extra/gemm/simple_matmul.py | 4 ++-- test/test_linearizer.py | 5 +---- tinygrad/codegen/kernel.py | 10 +++++----- tinygrad/codegen/linearizer.py | 28 ++++++++++++++-------------- tinygrad/features/image.py | 8 ++++---- tinygrad/renderer/cstyle.py | 9 +++------ tinygrad/tensor.py | 15 ++++++++------- 7 files changed, 37 insertions(+), 42 deletions(-) diff --git a/extra/gemm/simple_matmul.py b/extra/gemm/simple_matmul.py index 08272d12a4..2d362c9b15 100644 --- a/extra/gemm/simple_matmul.py +++ b/extra/gemm/simple_matmul.py @@ -2,14 +2,14 @@ import numpy as np from tinygrad.helpers import getenv from tinygrad import dtypes, Tensor dtype_in = dtypes.half if getenv("HALF") else dtypes.float +acc_dtype = dtypes.half if getenv("ACC_HALF") else None N = getenv("N", 4096) CNT = getenv("CNT", 10) a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize() for i in range(CNT): if i > 0 and getenv("RAND", 0) != 0: a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize() - # NOTE: accumulate is in float32 - c = (a @ b).realize() + c = a.matmul(b, acc_dtype=acc_dtype).realize() comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32) nc = c.numpy() np.testing.assert_allclose(nc, comp, atol=1e-4, rtol=3e-2) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index fa0bff627d..6ab43defa2 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -87,10 +87,7 @@ class TestLinearizer(unittest.TestCase): if tc.arch is not None and tc.arch != os.uname().machine: continue a, b = Tensor.rand(tc.dims[0], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[1], dtype=tc.dtype_in) np_a, np_b = a.numpy(), b.numpy() - if tc.dtype_out != tc.dtype_in: - r = (a.reshape(tc.dims[0], 1, tc.dims[2]) * b.permute(1,0).reshape(1, tc.dims[1], tc.dims[2])).cast(tc.dtype_out).sum(axis=2) - else: - r = a @ b + r = a.matmul(b, acc_dtype=tc.dtype_out) realized_ast, _ = helper_realized_ast(r) k = Linearizer(realized_ast) k.apply_tensor_cores(1) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 2849905a4d..e29f6c0900 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -33,18 +33,18 @@ class TensorCore: upcast_dim: int # which TC dim to upcast thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501 thread_local_sizes: List[int] # in each thread, the number of elements stored in registers for each TC dim + wmma_func: str # name of wmma function to call arch: Optional[str] = None def __str__(self): return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>" tensor_cores: Dict[str, List[TensorCore]] = { "METAL": [ - TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501 - # TODO: enable 
half @ half -> half tensor core with correct dtypes in uop - # TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501 + TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, wmma_func="__metal_wmma", upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501 + TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma", upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501 + TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma", upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501 ], "HIP": [ - TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501 - TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501 + TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501 ] } diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 1c33d58460..c51c9245fb 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -265,7 +265,7 @@ class Linearizer(Kernel): # define accumulator acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop)) - if self.tensor_core: + if (tc:=self.tensor_core): def calc_tc_idxs(local_size: int, aliases: List[List[int]]): replace_idxs = [] for alias in aliases: @@ -277,11 +277,11 @@ class Linearizer(Kernel): full_var_sz *= next_var.max+1 replace_idxs.append(full_var) return replace_idxs - replace_acc_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[2], self.tensor_core.thread_local_aliases[2]) - for n in range(len(self.tensor_core.threads)): - local_idxs[self.local_dims-len(self.tensor_core.threads)+n] = replace_acc_idxs[n] # replace locals - for n in range(len(replace_acc_idxs)-len(self.tensor_core.threads)): - upcast_idxs[n] = replace_acc_idxs[len(self.tensor_core.threads)+n] # replace upcasts + replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2]) + for n in range(len(tc.threads)): + local_idxs[self.local_dims-len(tc.threads)+n] = replace_acc_idxs[n] # replace locals + for n in 
range(len(replace_acc_idxs)-len(tc.threads)): + upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts # reduce loop loop_ctx = render_loop(reduce_idxs) @@ -306,8 +306,8 @@ class Linearizer(Kernel): locals_to_store.append((localbuf_idx, buf_idxs, ll)) # copy in any global buffers - if self.tensor_core: - wmma_sz, dtype_in, dtype_out = self.tensor_core.thread_local_sizes, self.tensor_core.dtype_in, self.tensor_core.dtype_out + if (tc:=self.tensor_core): + wmma_sz = tc.thread_local_sizes # calculate the number of local accumulator reduces and render WMMAs: this is bad... this needs to come from someplace else nx, ny, nacc = (len(locals_to_store[0][2])//wmma_sz[0]), (len(locals_to_store[1][2])//wmma_sz[1]), (len(acc)//wmma_sz[2]) acc_reds = math.isqrt((nx*ny)//nacc) @@ -315,12 +315,12 @@ class Linearizer(Kernel): for y in range(by): for x in range(bx): for j in range(acc_reds): - ops = (self.uop(UOps.CAST, dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]])), - self.uop(UOps.CAST, dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]])), - self.uop(UOps.CAST, dtype_out.vec(wmma_sz[2]), tuple(op3:=acc[i:i+wmma_sz[2]]))) - ret = self.uop(UOps.WMMA, dtype_out.vec(wmma_sz[2]), ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,)) - for z in range(cast(DType, ret.dtype).sz): - acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + loop_ctx) + ops = (self.uop(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]])), + self.uop(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]])), + self.uop(UOps.CAST, tc.dtype_out.vec(wmma_sz[2]), tuple(op3:=acc[i:i+wmma_sz[2]]))) + ret = self.uop(UOps.WMMA, tc.dtype_out.vec(wmma_sz[2]), ops, tc.wmma_func) + for z in range(wmma_sz[2]): + acc[i+z] = self.uop(UOps.PHI, tc.dtype_out, (op3[z], self.uop(UOps.GEP, tc.dtype_out, (ret,), z)) + loop_ctx) i += wmma_sz[2] else: if locals_to_store: diff --git a/tinygrad/features/image.py b/tinygrad/features/image.py index 6ca5567603..1d071bfc59 100644 --- a/tinygrad/features/image.py +++ b/tinygrad/features/image.py @@ -4,7 +4,7 @@ from tinygrad.dtype import dtypes # *** image Tensor function replacements *** -def image_dot(self, w): +def image_dot(self, w, acc_dtype=None): # NOTE: we use a 1x1 conv2d to do the matmul. 
mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1) n1, n2 = len(self.shape), len(w.shape) assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D" @@ -17,9 +17,9 @@ def image_dot(self, w): cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1)) # groups*cout x cin x H, W cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1)) - return image_conv2d(cx, cw, groups=groups).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2) + return image_conv2d(cx, cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2) -def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0): +def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None): base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape @@ -72,7 +72,7 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W)) # the conv! - ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1)) + ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), acc_dtype=acc_dtype) # undo hack for non multiples of 4 on C.rcout if added_output_channels != 0: diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index d4acf6ff54..75e6cf3aef 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -138,10 +138,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st kk(lang.render_for(ssa(u,'ridx'), r[vin[0]], r[vin[1]])) depth += 1 elif uop == UOps.WMMA: - if args[0] == "METAL" and dtype == dtypes.float.vec(2): wmma_func = "__metal_wmma" - elif args[0] == "HIP" and dtype == dtypes.float.vec(8): wmma_func = "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32" - else: raise NotImplementedError(f"WMMA not implemented for {args}") - kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = {wmma_func}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});") # noqa: E501 + kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = {args}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});") # noqa: E501 elif uop == UOps.ALU: # remove parens if ALU types are the same. 
TODO: can do more here
       if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.XOR}:
@@ -218,10 +215,10 @@ class OpenCLLanguage(CStyleLanguage):
 OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())
 
 class MetalLanguage(CStyleLanguage):
-  kernel_prefix = """#include <metal_stdlib>\nusing namespace metal;\ntemplate<typename T, typename S> T __metal_wmma(T m, T n, T o) {
+  kernel_prefix = """#include <metal_stdlib>\nusing namespace metal;\ntemplate<typename T, typename S, typename U> U __metal_wmma(T m, T n, U o) {
   S a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x; b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
-  return T(c.thread_elements()[0], c.thread_elements()[1]);\n}\nkernel """
+  return U(c.thread_elements()[0], c.thread_elements()[1]);\n}\nkernel """
   buffer_prefix = "device "
   smem_prefix = "threadgroup "
   arg_int_prefix = "constant int&"
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index b38aa5e016..7714353888 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -522,10 +522,10 @@ class Tensor:
     ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
     return ret if keepdim else ret.reshape(shape=shape)
 
-  def sum(self, axis=None, keepdim=False):
-    acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
-                least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \
-                least_upper_dtype(self.dtype, dtypes.float)
+  def sum(self, axis=None, keepdim=False, acc_dtype=None):
+    if acc_dtype is None: acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
+      least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \
+      least_upper_dtype(self.dtype, dtypes.float)
     # cast back to float16 or bfloat16 to match torch / jax behavior, but we use float for acc
     output_dtype = self.dtype if self.dtype in (dtypes.float16, dtypes.bfloat16) else acc_dtype
     return self.cast(acc_dtype)._reduce(mlops.Sum, axis, keepdim).cast(output_dtype)
@@ -680,15 +680,16 @@
     return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
 
-  def dot(self, w:Tensor) -> Tensor:
+  def dot(self, w:Tensor, acc_dtype=None) -> Tensor:
     n1, n2 = len(self.shape), len(w.shape)
     assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
     assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
     x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
     w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
-    return (x*w).sum(-1).cast(least_upper_dtype(x.dtype, w.dtype))
+    return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype))
 
-  def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
+  def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
+    return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
 
   def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor: return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
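
A minimal usage sketch of the acc_dtype plumbing added above (illustrative only; tensor names and sizes are not part of the patch). Tensor.matmul/dot/sum now accept an optional acc_dtype, image_dot/image_conv2d forward it, and the linearizer emits the WMMA uop with the tensor core's wmma_func, so a half matmul can accumulate in float32 or half:

    from tinygrad import Tensor, dtypes

    a = Tensor.rand(256, 256, dtype=dtypes.half)
    b = Tensor.rand(256, 256, dtype=dtypes.half)

    c_f32 = a.matmul(b, acc_dtype=dtypes.float).realize()  # half @ half with a float32 accumulator (half->float TC path)
    c_f16 = a.matmul(b, acc_dtype=dtypes.half).realize()   # half @ half with a half accumulator (half->half TC path)
    c_def = (a @ b).realize()                              # acc_dtype=None: default promotion, floats still accumulate in float32

The same switch is exposed in extra/gemm/simple_matmul.py through the HALF and ACC_HALF env vars (e.g. HALF=1 ACC_HALF=1 python3 extra/gemm/simple_matmul.py); whether a tensor core is actually selected still depends on the device and arch checks in kernel.py.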