mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-24 06:18:01 -05:00)
gid is array, metal works
@@ -255,7 +255,7 @@ class CLASTKernel(ASTKernel):
     # output_shape[-1] is get_global_id(0)
     MAX_OUTPUT_SHAPE = 3
-    self.kernel += [f"int idx{len(self.output_shape)-1-i} = {CLProgram.gid(i)}; /* {self.output_shape[-1-i]} */\n" for i in range(min(MAX_OUTPUT_SHAPE, len(self.output_shape))) if self.output_shape[-1-i] != 1]
+    self.kernel += [f"int idx{len(self.output_shape)-1-i} = {CLProgram.gid[i]}; /* {self.output_shape[-1-i]} */\n" for i in range(min(MAX_OUTPUT_SHAPE, len(self.output_shape))) if self.output_shape[-1-i] != 1]
     if len(self.output_shape) > MAX_OUTPUT_SHAPE:
       # sometimes, there's more dimensions. compact all the dimensions into the first one
       # TODO: these compactions should be searchable
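
The only change in this hunk is CLProgram.gid(i) becoming CLProgram.gid[i]: each backend now exposes its per-axis global-id expressions as a plain list of strings instead of a static method. A minimal standalone sketch of what the comprehension emits, assuming the OpenCL-style gid strings and a made-up output_shape of (32, 64, 128); this is not part of the commit:

# illustrates the generated index lines, using hypothetical shapes
gid = [f'get_global_id({i})' for i in range(3)]
output_shape, MAX_OUTPUT_SHAPE = (32, 64, 128), 3
kernel = [f"int idx{len(output_shape)-1-i} = {gid[i]}; /* {output_shape[-1-i]} */\n"
          for i in range(min(MAX_OUTPUT_SHAPE, len(output_shape))) if output_shape[-1-i] != 1]
print("".join(kernel), end="")
# int idx2 = get_global_id(0); /* 128 */
# int idx1 = get_global_id(1); /* 64 */
# int idx0 = get_global_id(2); /* 32 */
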
@@ -284,9 +284,9 @@ class CLASTKernel(ASTKernel):
       assert lvalid.min == 1, "local buffer must always be valid"
       self.kernel.append(f"int mid_idx = {lidx.render(render_cl)};\n")
       for i,acc in enumerate(accumulators):
-        self.kernel.append(("__shared__ " if CUDA else "__local ") + f"{acc.decltype()} {self.buftokens[-1].tok}{i}[{prod(self.group_for_reduce)}];")
+        self.kernel.append(CLProgram.smem_prefix + f"{acc.decltype()} {self.buftokens[-1].tok}{i}[{prod(self.group_for_reduce)}];")
         self.kernel.append(f"{self.buftokens[-1].tok}{i}[mid_idx] = {acc.tok};\n")
-      self.kernel.append("barrier(CLK_LOCAL_MEM_FENCE);\n" if not CUDA else "__syncthreads();\n")
+      self.kernel.append(CLProgram.barrier+"\n")

       if self.upcast_in_mid_reduce:
         assert len(self.group_for_reduce) == 2
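
With smem_prefix and barrier promoted to per-backend class attributes, this hunk drops the CUDA/OpenCL string branching from the mid-reduce codegen. A small sketch of why the same two appends now work for any backend, using hypothetical stand-in classes and names rather than tinygrad's real ones:

# hypothetical minimal stand-ins for two backends' CLProgram classes
class OpenCLProgram:
  smem_prefix = "__local "
  barrier = "barrier(CLK_LOCAL_MEM_FENCE);"

class CUDAProgram:
  smem_prefix = "__shared__ "
  barrier = "__syncthreads();"

def mid_reduce_snippet(CLProgram, decltype="float", tok="temp0", size=16):
  # mirrors the two changed appends above: declare local/shared storage, then sync
  return [f"{CLProgram.smem_prefix}{decltype} {tok}[{size}];", CLProgram.barrier]

print(mid_reduce_snippet(OpenCLProgram))  # ['__local float temp0[16];', 'barrier(CLK_LOCAL_MEM_FENCE);']
print(mid_reduce_snippet(CUDAProgram))    # ['__shared__ float temp0[16];', '__syncthreads();']
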
@@ -16,8 +16,9 @@ class CLBuffer:
 class CLProgram:
   kernel_prefix = "__global__"
   buffer_prefix = ""
-  @staticmethod
-  def gid(i): return f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}'
+  smem_prefix = "__shared__ "
+  barrier = "__syncthreads();"
+  gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)]
   def __init__(self, name:str, prg:str, op_estimate:int=0, mem_estimate:int=0):
     self.name, self.op_estimate, self.mem_estimate = name, op_estimate, mem_estimate
     if DEBUG >= 4: print("CUDA compile", prg)
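
chr(120+i) walks through 'x', 'y', 'z', so the new gid list literally spells out the three CUDA global-index expressions. A quick check of what the attribute evaluates to (not part of the diff):

gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)]
print(gid)
# ['blockDim.x*blockIdx.x+threadIdx.x',
#  'blockDim.y*blockIdx.y+threadIdx.y',
#  'blockDim.z*blockIdx.z+threadIdx.z']
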
@@ -1,5 +1,5 @@
 # pip3 install pyobjc-framework-Metal
-import Metal
+import Metal # type: ignore
 import numpy as np
 from typing import List, Any
 from tinygrad.ops import DEBUG
@@ -33,10 +33,11 @@ class CLBuffer:
   def copyout(self, a:np.ndarray): np.copyto(a, self.toCPU().reshape(a.shape))

 class CLProgram:
-  kernel_prefix = "kernel"
+  kernel_prefix = "using namespace metal;\nkernel"
   buffer_prefix = "device "
-  @staticmethod
-  def gid(i): return f"gid.{chr(120+i)}"
+  smem_prefix = "threadgroup "
+  barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
+  gid = [f"gid.{chr(120+i)}" for i in range(3)]
   def __init__(self, name:str, prg:str, op_estimate:int=0, mem_estimate:int=0):
     self.name, self.op_estimate, self.mem_estimate = name, op_estimate, mem_estimate
     options = Metal.MTLCompileOptions.alloc().init()
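
The Metal strings assume the generated kernel receives its thread index as a uint3 named gid, which Metal Shading Language provides via the [[thread_position_in_grid]] attribute; declaring that argument is presumably handled by the codegen outside this diff. A rough hand-written sketch, with a hypothetical kernel name and buffers, of the kind of source these prefix strings build toward:

# hypothetical example; the real assembly of kernel source happens in CLASTKernel
kernel_prefix = "using namespace metal;\nkernel"
buffer_prefix = "device "
gid = [f"gid.{chr(120+i)}" for i in range(3)]
src  = f"{kernel_prefix} void example({buffer_prefix}float *data0, {buffer_prefix}const float *data1, uint3 gid [[thread_position_in_grid]]) " + "{\n"
src += f"  int idx0 = {gid[0]};  /* global index along axis 0 */\n"
src += "  data0[idx0] = data1[idx0];\n}\n"
print(src)
# using namespace metal;
# kernel void example(device float *data0, device const float *data1, uint3 gid [[thread_position_in_grid]]) {
#   int idx0 = gid.x;  /* global index along axis 0 */
#   data0[idx0] = data1[idx0];
# }
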
@@ -64,9 +64,10 @@ class CLImage:
 class CLProgram:
   kernel_prefix = "__kernel"
   buffer_prefix = "__global "
+  smem_prefix = "__local "
   kernel_cnt : Final[Dict[str, int]] = defaultdict(int)
-  @staticmethod
-  def gid(i): return f'get_global_id({i})'
+  barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
+  gid = [f'get_global_id({i})' for i in range(3)]
   def __init__(self, name:str, prg:str, options:Tuple[str, ...]=tuple(), argdtypes=None, rename=True, binary=False, op_estimate=0, mem_estimate=0):
     self.name = f"{name}{('_N'+str(CLProgram.kernel_cnt[name])) if CLProgram.kernel_cnt[name] else str()}" if rename else name
     self.prg, self.options, self.argdtypes, self.op_estimate, self.mem_estimate = prg.replace(f"{name}(", f"{self.name}(") if rename else prg, options, argdtypes, op_estimate, mem_estimate