From 67d1df80ba998d4ee57942e3745049e6ab461f6c Mon Sep 17 00:00:00 2001 From: George Hotz Date: Fri, 17 Feb 2023 11:54:50 -0800 Subject: [PATCH] gid is array, metal works --- tinygrad/llops/ops_gpu.py | 6 +++--- tinygrad/runtime/cuda.py | 5 +++-- tinygrad/runtime/metal.py | 9 +++++---- tinygrad/runtime/opencl.py | 5 +++-- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index fb8fc245ce..a2465a8433 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -255,7 +255,7 @@ class CLASTKernel(ASTKernel): # output_shape[-1] is get_global_id(0) MAX_OUTPUT_SHAPE = 3 - self.kernel += [f"int idx{len(self.output_shape)-1-i} = {CLProgram.gid(i)}; /* {self.output_shape[-1-i]} */\n" for i in range(min(MAX_OUTPUT_SHAPE, len(self.output_shape))) if self.output_shape[-1-i] != 1] + self.kernel += [f"int idx{len(self.output_shape)-1-i} = {CLProgram.gid[i]}; /* {self.output_shape[-1-i]} */\n" for i in range(min(MAX_OUTPUT_SHAPE, len(self.output_shape))) if self.output_shape[-1-i] != 1] if len(self.output_shape) > MAX_OUTPUT_SHAPE: # sometimes, there's more dimensions. compact all the dimensions into the first one # TODO: these compactions should be searchable @@ -284,9 +284,9 @@ class CLASTKernel(ASTKernel): assert lvalid.min == 1, "local buffer must always be valid" self.kernel.append(f"int mid_idx = {lidx.render(render_cl)};\n") for i,acc in enumerate(accumulators): - self.kernel.append(("__shared__ " if CUDA else "__local ") + f"{acc.decltype()} {self.buftokens[-1].tok}{i}[{prod(self.group_for_reduce)}];") + self.kernel.append(CLProgram.smem_prefix + f"{acc.decltype()} {self.buftokens[-1].tok}{i}[{prod(self.group_for_reduce)}];") self.kernel.append(f"{self.buftokens[-1].tok}{i}[mid_idx] = {acc.tok};\n") - self.kernel.append("barrier(CLK_LOCAL_MEM_FENCE);\n" if not CUDA else "__syncthreads();\n") + self.kernel.append(CLProgram.barrier+"\n") if self.upcast_in_mid_reduce: assert len(self.group_for_reduce) == 2 diff --git a/tinygrad/runtime/cuda.py b/tinygrad/runtime/cuda.py index 89c8973c95..e1328dbd49 100644 --- a/tinygrad/runtime/cuda.py +++ b/tinygrad/runtime/cuda.py @@ -16,8 +16,9 @@ class CLBuffer: class CLProgram: kernel_prefix = "__global__" buffer_prefix = "" - @staticmethod - def gid(i): return f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' + smem_prefix = "__shared__ " + barrier = "__syncthreads();" + gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)] def __init__(self, name:str, prg:str, op_estimate:int=0, mem_estimate:int=0): self.name, self.op_estimate, self.mem_estimate = name, op_estimate, mem_estimate if DEBUG >= 4: print("CUDA compile", prg) diff --git a/tinygrad/runtime/metal.py b/tinygrad/runtime/metal.py index 09e9193aae..bb62565a2a 100644 --- a/tinygrad/runtime/metal.py +++ b/tinygrad/runtime/metal.py @@ -1,5 +1,5 @@ # pip3 install pyobjc-framework-Metal -import Metal +import Metal # type: ignore import numpy as np from typing import List, Any from tinygrad.ops import DEBUG @@ -33,10 +33,11 @@ class CLBuffer: def copyout(self, a:np.ndarray): np.copyto(a, self.toCPU().reshape(a.shape)) class CLProgram: - kernel_prefix = "kernel" + kernel_prefix = "using namespace metal;\nkernel" buffer_prefix = "device " - @staticmethod - def gid(i): return f"gid.{chr(120+i)}" + smem_prefix = "threadgroup " + barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);" + gid = [f"gid.{chr(120+i)}" for i in range(3)] def __init__(self, name:str, prg:str, op_estimate:int=0, mem_estimate:int=0): self.name, self.op_estimate, self.mem_estimate = name, op_estimate, mem_estimate options = Metal.MTLCompileOptions.alloc().init() diff --git a/tinygrad/runtime/opencl.py b/tinygrad/runtime/opencl.py index ea766c8965..81c573544b 100644 --- a/tinygrad/runtime/opencl.py +++ b/tinygrad/runtime/opencl.py @@ -64,9 +64,10 @@ class CLImage: class CLProgram: kernel_prefix = "__kernel" buffer_prefix = "__global " + smem_prefix = "__local " kernel_cnt : Final[Dict[str, int]] = defaultdict(int) - @staticmethod - def gid(i): return f'get_global_id({i})' + barrier = "barrier(CLK_LOCAL_MEM_FENCE);" + gid = [f'get_global_id({i})' for i in range(3)] def __init__(self, name:str, prg:str, options:Tuple[str, ...]=tuple(), argdtypes=None, rename=True, binary=False, op_estimate=0, mem_estimate=0): self.name = f"{name}{('_N'+str(CLProgram.kernel_cnt[name])) if CLProgram.kernel_cnt[name] else str()}" if rename else name self.prg, self.options, self.argdtypes, self.op_estimate, self.mem_estimate = prg.replace(f"{name}(", f"{self.name}(") if rename else prg, options, argdtypes, op_estimate, mem_estimate