gid is array, metal works

George Hotz
2023-02-17 11:54:50 -08:00
parent f9af0322e7
commit 67d1df80ba
4 changed files with 14 additions and 11 deletions
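
In short, this commit turns CLProgram.gid from a @staticmethod that formats one index expression per call into a per-backend list of three pre-rendered index strings, and lifts the shared-memory qualifier and barrier into smem_prefix / barrier class attributes so the kernel generator no longer branches on CUDA. A minimal sketch of the call-site change (illustrative only, using the OpenCL strings from this diff):

    # before: codegen called a method per launch dimension
    idx_expr = CLProgram.gid(i)   # e.g. "get_global_id(0)" on the OpenCL backend
    # after: codegen indexes a plain class-level list of strings
    idx_expr = CLProgram.gid[i]   # same string, no call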


@@ -255,7 +255,7 @@ class CLASTKernel(ASTKernel):
 # output_shape[-1] is get_global_id(0)
 MAX_OUTPUT_SHAPE = 3
-self.kernel += [f"int idx{len(self.output_shape)-1-i} = {CLProgram.gid(i)}; /* {self.output_shape[-1-i]} */\n" for i in range(min(MAX_OUTPUT_SHAPE, len(self.output_shape))) if self.output_shape[-1-i] != 1]
+self.kernel += [f"int idx{len(self.output_shape)-1-i} = {CLProgram.gid[i]}; /* {self.output_shape[-1-i]} */\n" for i in range(min(MAX_OUTPUT_SHAPE, len(self.output_shape))) if self.output_shape[-1-i] != 1]
 if len(self.output_shape) > MAX_OUTPUT_SHAPE:
   # sometimes, there's more dimensions. compact all the dimensions into the first one
   # TODO: these compactions should be searchable
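
As a concrete (hypothetical) example: with self.output_shape = (32, 64, 128) on the OpenCL backend, where gid is ['get_global_id(0)', 'get_global_id(1)', 'get_global_id(2)'], the comprehension above emits one declaration per non-1 dimension, innermost dimension first:

    int idx2 = get_global_id(0); /* 128 */
    int idx1 = get_global_id(1); /* 64 */
    int idx0 = get_global_id(2); /* 32 */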
@@ -284,9 +284,9 @@ class CLASTKernel(ASTKernel):
 assert lvalid.min == 1, "local buffer must always be valid"
 self.kernel.append(f"int mid_idx = {lidx.render(render_cl)};\n")
 for i,acc in enumerate(accumulators):
-  self.kernel.append(("__shared__ " if CUDA else "__local ") + f"{acc.decltype()} {self.buftokens[-1].tok}{i}[{prod(self.group_for_reduce)}];")
+  self.kernel.append(CLProgram.smem_prefix + f"{acc.decltype()} {self.buftokens[-1].tok}{i}[{prod(self.group_for_reduce)}];")
   self.kernel.append(f"{self.buftokens[-1].tok}{i}[mid_idx] = {acc.tok};\n")
-self.kernel.append("barrier(CLK_LOCAL_MEM_FENCE);\n" if not CUDA else "__syncthreads();\n")
+self.kernel.append(CLProgram.barrier+"\n")
 if self.upcast_in_mid_reduce:
   assert len(self.group_for_reduce) == 2

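With smem_prefix and barrier read off CLProgram, the same three appends now render to backend-specific source without the CUDA branch. A hypothetical rendering (the accumulator type float, token temp0, value acc0, and size 16 are made up for illustration), as it would come out on the OpenCL backend:

    __local float temp0[16];
    temp0[mid_idx] = acc0;
    barrier(CLK_LOCAL_MEM_FENCE);

On the CUDA backend the same appends produce __shared__ and __syncthreads();, and on Metal threadgroup and threadgroup_barrier(mem_flags::mem_threadgroup);.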

@@ -16,8 +16,9 @@ class CLBuffer:
 class CLProgram:
   kernel_prefix = "__global__"
   buffer_prefix = ""
-  @staticmethod
-  def gid(i): return f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}'
+  smem_prefix = "__shared__ "
+  barrier = "__syncthreads();"
+  gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)]
   def __init__(self, name:str, prg:str, op_estimate:int=0, mem_estimate:int=0):
     self.name, self.op_estimate, self.mem_estimate = name, op_estimate, mem_estimate
     if DEBUG >= 4: print("CUDA compile", prg)

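Since chr(120+i) is 'x', 'y', and 'z' for i = 0..2, the CUDA gid list above evaluates to:

    ['blockDim.x*blockIdx.x+threadIdx.x',
     'blockDim.y*blockIdx.y+threadIdx.y',
     'blockDim.z*blockIdx.z+threadIdx.z']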

@@ -1,5 +1,5 @@
 # pip3 install pyobjc-framework-Metal
-import Metal
+import Metal # type: ignore
 import numpy as np
 from typing import List, Any
 from tinygrad.ops import DEBUG
@@ -33,10 +33,11 @@ class CLBuffer:
   def copyout(self, a:np.ndarray): np.copyto(a, self.toCPU().reshape(a.shape))
 class CLProgram:
-  kernel_prefix = "kernel"
+  kernel_prefix = "using namespace metal;\nkernel"
   buffer_prefix = "device "
-  @staticmethod
-  def gid(i): return f"gid.{chr(120+i)}"
+  smem_prefix = "threadgroup "
+  barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
+  gid = [f"gid.{chr(120+i)}" for i in range(3)]
   def __init__(self, name:str, prg:str, op_estimate:int=0, mem_estimate:int=0):
     self.name, self.op_estimate, self.mem_estimate = name, op_estimate, mem_estimate
     options = Metal.MTLCompileOptions.alloc().init()

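Here gid evaluates to ['gid.x', 'gid.y', 'gid.z'], and the new kernel_prefix pulls "using namespace metal;" into the generated source so mem_flags in the barrier resolves. A hypothetical sketch of a generated kernel header under these prefixes (the kernel name, argument list, and the [[thread_position_in_grid]] binding for gid come from codegen paths not shown in this diff):

    using namespace metal;
    kernel void r_example(device float *data0, device float *data1,
                          uint3 gid [[thread_position_in_grid]]) {
      int idx0 = gid.x; /* ... */
      ...
    }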

@@ -64,9 +64,10 @@ class CLImage:
 class CLProgram:
   kernel_prefix = "__kernel"
   buffer_prefix = "__global "
+  smem_prefix = "__local "
   kernel_cnt : Final[Dict[str, int]] = defaultdict(int)
-  @staticmethod
-  def gid(i): return f'get_global_id({i})'
+  barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
+  gid = [f'get_global_id({i})' for i in range(3)]
   def __init__(self, name:str, prg:str, options:Tuple[str, ...]=tuple(), argdtypes=None, rename=True, binary=False, op_estimate=0, mem_estimate=0):
     self.name = f"{name}{('_N'+str(CLProgram.kernel_cnt[name])) if CLProgram.kernel_cnt[name] else str()}" if rename else name
     self.prg, self.options, self.argdtypes, self.op_estimate, self.mem_estimate = prg.replace(f"{name}(", f"{self.name}(") if rename else prg, options, argdtypes, op_estimate, mem_estimate
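
After this change all three backends expose the same class-level surface (kernel_prefix, buffer_prefix, smem_prefix, barrier, gid), so the kernel generator can be written against any of them uniformly. A tiny illustrative snippet (Prog here is hypothetical and stands for whichever backend's CLProgram was selected at runtime):

    # hypothetical: Prog is the selected backend's CLProgram class
    lines = [f"int idx{i} = {expr};" for i, expr in enumerate(Prog.gid)]
    # OpenCL: int idx0 = get_global_id(0); ...
    # CUDA:   int idx0 = blockDim.x*blockIdx.x+threadIdx.x; ...
    # Metal:  int idx0 = gid.x; ...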