diff --git a/extra/gemm/metal_matvec.py b/extra/gemm/metal_matvec.py
new file mode 100644
index 0000000000..8c21a3711a
--- /dev/null
+++ b/extra/gemm/metal_matvec.py
@@ -0,0 +1,126 @@
+import os
+#os.environ["METAL"] = "1"
+import numpy as np
+import time, torch, torch.mps
+
+from tinygrad.ops import GlobalCounters
+from tinygrad.tensor import Tensor
+from tinygrad.jit import TinyJit
+from tinygrad.ops import Device
+from tinygrad.helpers import colored, getenv, CI
+
+import os
+os.environ["METAL"] = "1"
+import time
+import numpy as np
+from tinygrad.helpers import dtypes, getenv
+from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram
+
+N = 16384
+M = 4096
+FLOPS = N*M*2
+
+nb = np.random.default_rng().standard_normal(size=(N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
+nc = np.random.default_rng().standard_normal(size=(N,M), dtype=np.float32) #.astype(np.int32).astype(np.float32)
+
+import torch, torch.mps
+b = torch.from_numpy(nb).to('mps')
+c = torch.from_numpy(nc).to('mps')
+
+def torch_prog(b, c):
+  st = time.perf_counter()
+  a = b@c
+  torch.mps.synchronize()
+  return time.perf_counter() - st
+tm = min([torch_prog(b, c) for _ in range(200)])
+print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in torch")
+torch_a = (b@c).cpu()
+
+WORKSIZE_ROW = 16
+WORKSIZE_COL = 1
+LOCAL_SIZE = [32, WORKSIZE_COL, WORKSIZE_ROW]
+GLOBAL_SIZE = [M//(LOCAL_SIZE[0]*LOCAL_SIZE[1]*4), 1, 1]
+prog_string = f"""
+#include <metal_stdlib>
+using namespace metal;
+kernel void test(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
+  int gidx0 = gid.x; /* {GLOBAL_SIZE[0]} */
+  int lidx1 = lid.x; /* {LOCAL_SIZE[0]} */
+  int lidx2 = lid.y; /* {LOCAL_SIZE[1]} */
+  int lidx3 = lid.z; /* {LOCAL_SIZE[2]} */
+
+  // 4 rows per thread
+  threadgroup float4 acc0[{LOCAL_SIZE[0]*LOCAL_SIZE[1]*LOCAL_SIZE[2]}];
+  int acc0_index = ((lidx1*{LOCAL_SIZE[1]})+lidx2)+({LOCAL_SIZE[0]*LOCAL_SIZE[1]}*lidx3);
+  acc0[acc0_index] = float4(0.0f,0.0f,0.0f,0.0f);
+
+  threadgroup float4 val1[{LOCAL_SIZE[0]*LOCAL_SIZE[1]*LOCAL_SIZE[2]}];
+
+  // iterate over the columns
+  for (int ridx2 = 0; ridx2 < {N//(4*LOCAL_SIZE[0]*LOCAL_SIZE[1]*(LOCAL_SIZE[2]))}; ++ridx2) {{
+    // load 4*threadgroup_size columns into shared memory
+    int col_1 = (((lidx3*{N//(4*LOCAL_SIZE[0]*LOCAL_SIZE[1]*(LOCAL_SIZE[2]))})+ridx2)*{LOCAL_SIZE[0]*LOCAL_SIZE[1]})+(lidx1*{LOCAL_SIZE[1]})+lidx2;
+    val1[(lidx3*{LOCAL_SIZE[1]*LOCAL_SIZE[0]})+((lidx1*{LOCAL_SIZE[1]})+lidx2)] = *((device float4*)(data1+(col_1*4)));
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (int ridx3 = 0; ridx3 < {LOCAL_SIZE[0]*LOCAL_SIZE[1]}; ++ridx3) {{
+      int col = ((((lidx3*{N//(4*LOCAL_SIZE[0]*LOCAL_SIZE[1]*(LOCAL_SIZE[2]))})+ridx2)*{LOCAL_SIZE[0]*LOCAL_SIZE[1]})+ridx3);
+      float4 val1_0 = val1[(lidx3*{LOCAL_SIZE[1]*LOCAL_SIZE[0]})+ridx3];
+      float4 val2_0 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*0})));
+      float4 val2_1 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*1})));
+      float4 val2_2 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*2})));
+      float4 val2_3 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*3})));
+      acc0[acc0_index] = ((val1_0.x*val2_0)+acc0[acc0_index]);
+      acc0[acc0_index] = ((val1_0.y*val2_1)+acc0[acc0_index]);
+      acc0[acc0_index] = ((val1_0.z*val2_2)+acc0[acc0_index]);
+      acc0[acc0_index] = ((val1_0.w*val2_3)+acc0[acc0_index]);
+    }}
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }} /* reduce */
+
+  if (lidx3 == 0) {{
+    float4 out = float4(0.0f,0.0f,0.0f,0.0f);
+    for (int n = 0; n < {LOCAL_SIZE[2]}; n++) {{
+      out += acc0[((lidx1*{LOCAL_SIZE[1]})+lidx2)+({LOCAL_SIZE[0]*LOCAL_SIZE[1]}*n)];
+    }}
+    *( (device float4 *) (data0 + (gidx0*{M//GLOBAL_SIZE[0]}) + ( ( (lidx1*{LOCAL_SIZE[1]})+lidx2 ) * 4 ) ) ) = out;
+  }}
+}}
+"""
+prog = MetalProgram("test", prog_string)
+# print(prog_string)
+na = np.zeros(M, dtype=np.float32)
+b = RawMetalBuffer.fromCPU(nb)
+c = RawMetalBuffer.fromCPU(nc)
+def metalrun():
+  a = RawMetalBuffer.fromCPU(na)
+  prog(GLOBAL_SIZE, LOCAL_SIZE, a, b, c, wait=True)
+  return a
+def timeit(fxn):
+  st = time.perf_counter()
+  et = fxn()
+  # NOTE: et doesn't contain the launch overhead
+  return time.perf_counter() - st
+tm = min([timeit(metalrun) for _ in range(200)])
+print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in metal")
+metal_a = metalrun().toCPU().reshape(M)
+np.testing.assert_allclose(metal_a, torch_a, atol=5e-3)
+
+from tinygrad.tensor import Tensor
+from tinygrad.jit import TinyJit
+from tinygrad.runtime.ops_metal import METAL
+b = Tensor(nb)
+c = Tensor(nc)
+# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
+@TinyJit
+def tiny_jit(b, c):
+  return (b@c).realize()
+def tiny_prog(b, c):
+  st = time.perf_counter()
+  a = tiny_jit(b, c)
+  METAL.synchronize()
+  return time.perf_counter() - st
+tm = min([tiny_prog(b, c) for _ in range(200)])
+print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in tinygrad")
+tiny_a = tiny_jit(b, c).numpy()
+np.testing.assert_allclose(tiny_a, torch_a, atol=5e-3)
\ No newline at end of file
diff --git a/tinygrad/codegen/optimizer.py b/tinygrad/codegen/optimizer.py
index 401cbaf8c5..4571bb9e39 100644
--- a/tinygrad/codegen/optimizer.py
+++ b/tinygrad/codegen/optimizer.py
@@ -304,6 +304,29 @@ class OptimizedKernel(Kernel):
       # early exit
       return
 
+    # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
+    MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
+    if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
+        self.reduceop and self.reduceop.op == ReduceOps.SUM and len(self.full_shape) >= 2 and \
+        isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \
+        self.reduceop.src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[1].op == BufferOps.MEM:
+      buf0 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[0]).arg)
+      buf0_strides = self.sts[buf0].real_strides()
+      if buf0_strides[self.first_reduce] == 1:
+        for global_idx in range(self.global_dims):
+          if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
+            if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}")
+            if MV_THREADS_PER_ROW > 1:
+              self.shift_to(self.first_reduce, MV_THREADS_PER_ROW, top=False, insert_before=self.first_reduce + len(self.group_for_reduce))
+              self.group_for_reduce.append(MV_THREADS_PER_ROW)
+            if MV_BLOCKSIZE > 1:
+              self.shift_to(global_idx, MV_BLOCKSIZE, insert_before=self.first_reduce)
+              self.local_dims += 1
+            if MV_ROWS_PER_THREAD > 1:
+              self.shift_to(global_idx, MV_ROWS_PER_THREAD)
+              self.upcast()
+            return
+
     if self.opts.has_local and self.opts.has_shared and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]):
       # are we grouping? (requires local shape support)
       if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
@@ -407,4 +430,4 @@ class OptimizedKernel(Kernel):
       for axis, local_sz in sorted(to_local[:3]):
         self.shift_to(axis, local_sz, insert_before=self.first_reduce)
         self.local_dims += 1
-    self.simplify_ones()
\ No newline at end of file
+    self.simplify_ones()
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 98882d247f..0fad2fa2c9 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -11,6 +11,7 @@ class CStyleLanguage(NamedTuple):
   kernel_prefix: str = ""
   buffer_prefix: str = ""
   buffer_suffix: str = ""
+  smem_align: str = ""
   smem_prefix: str = ""
   arg_int_prefix: str = ""
   barrier: str = ""
@@ -70,7 +71,7 @@ class CStyleLanguage(NamedTuple):
     return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
 
   def render_local(self, name:str, size:int):
-    return self.smem_prefix + f"float {name}[{size}];"
+    return self.smem_align + self.smem_prefix + f"float {name}[{size}];"
 
   def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str:
     return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"
diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py
index 23310dd7e3..8142c504d2 100644
--- a/tinygrad/runtime/ops_gpu.py
+++ b/tinygrad/runtime/ops_gpu.py
@@ -104,7 +104,7 @@ class CLProgram:
     return None
 
 renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
-  kernel_prefix = "__kernel ", buffer_prefix = "__global ", smem_prefix = "__local ", arg_int_prefix = "const int",
+  kernel_prefix = "__kernel ", buffer_prefix = "__global ", smem_align = "__attribute__ ((aligned (16))) ", smem_prefix = "__local ", arg_int_prefix = "const int",
   half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
   barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
   gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True))
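
Usage sketch (an illustration, not part of the diff): the matvec path added to hand_coded_optimizations is gated by the MV, MV_BLOCKSIZE, MV_THREADS_PER_ROW and MV_ROWS_PER_THREAD environment variables read via getenv, so it can be exercised and tuned per run. A minimal end-to-end check, assuming the tinygrad revision in this diff with the Metal backend available:

# set the knobs before tinygrad reads them (getenv results are cached on first use)
import os
os.environ["METAL"] = "1"
os.environ["MV_BLOCKSIZE"] = "4"
os.environ["MV_THREADS_PER_ROW"] = "8"
os.environ["MV_ROWS_PER_THREAD"] = "4"

import numpy as np
from tinygrad.tensor import Tensor

N, M = 16384, 4096
nb = np.random.default_rng().standard_normal(size=(N,), dtype=np.float32)
nc = np.random.default_rng().standard_normal(size=(N, M), dtype=np.float32)

# vector @ matrix hits the ReduceOps.SUM-of-BinaryOps.MUL pattern the heuristic matches;
# run with DEBUG=3 to see the "MATVEC: ..." print confirming it fired
out = (Tensor(nb) @ Tensor(nc)).numpy()
# loose tolerances: float32 accumulation over N=16384 terms vs a float64 reference
np.testing.assert_allclose(out, nb.astype(np.float64) @ nc.astype(np.float64), rtol=1e-3, atol=1e-2)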
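
For the renderer change, a quick standalone check of what render_local emits for OpenCL once the new smem_align field is set (a sketch, assuming this revision's CStyleLanguage):

from tinygrad.renderer.cstyle import CStyleLanguage

lang = CStyleLanguage(smem_align="__attribute__ ((aligned (16))) ", smem_prefix="__local ")
# prints: __attribute__ ((aligned (16))) __local float temp[32];
print(lang.render_local("temp", 32))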