metal can add

This commit is contained in:
George Hotz
2023-02-17 11:45:33 -08:00
parent e172f0087a
commit f9af0322e7
4 changed files with 78 additions and 7 deletions

View File

@@ -12,9 +12,10 @@ render_cl = render_python.copy()
render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops)}/{self.b})"
from tinygrad.helpers import getenv
CUDA = getenv("CUDA", 0)
if not CUDA: from tinygrad.runtime.opencl import CLBuffer, CLImage, CLProgram # NOTE: using CL will not work for the CUDA runtime # noqa: F401
else: from tinygrad.runtime.cuda import CLBuffer, CLImage, CLProgram # type: ignore
CUDA,METAL = getenv("CUDA", 0), getenv("METAL", 0)
if not CUDA and not METAL: from tinygrad.runtime.opencl import CLBuffer, CLImage, CLProgram # NOTE: using CL will not work for the CUDA runtime # noqa: F401
elif CUDA: from tinygrad.runtime.cuda import CLBuffer, CLImage, CLProgram # type: ignore
elif METAL: from tinygrad.runtime.metal import CLBuffer, CLImage, CLProgram # type: ignore
VALIDHACKS = getenv("VALIDHACKS", 0) # TODO: remove the need for this
NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0) # this is needed as a switch for the tests to pass
@@ -254,7 +255,7 @@ class CLASTKernel(ASTKernel):
# output_shape[-1] is get_global_id(0)
MAX_OUTPUT_SHAPE = 3
self.kernel += [f"int idx{len(self.output_shape)-1-i} = {f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' if CUDA else f'get_global_id({i})'}; /* {self.output_shape[-1-i]} */\n" for i in range(min(MAX_OUTPUT_SHAPE, len(self.output_shape))) if self.output_shape[-1-i] != 1]
self.kernel += [f"int idx{len(self.output_shape)-1-i} = {CLProgram.gid(i)}; /* {self.output_shape[-1-i]} */\n" for i in range(min(MAX_OUTPUT_SHAPE, len(self.output_shape))) if self.output_shape[-1-i] != 1]
if len(self.output_shape) > MAX_OUTPUT_SHAPE:
# sometimes, there's more dimensions. compact all the dimensions into the first one
# TODO: these compactions should be searchable
@@ -311,9 +312,9 @@ class CLASTKernel(ASTKernel):
# kernel function definition
function_name = ("re_S" if self.reduceop else "ew_S") + '_'.join([str(x) for x in self.bufs[0].shape if x != 1])
buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if isinstance(x._buf, CLImage) else ("__global "+self.buftokens[i].decltype()) for i,x in enumerate(self.bufs)] if not CUDA else [self.buftokens[i].decltype() for i,x in enumerate(self.bufs)]
self.kernel = list(self.prekernel) + [f"{'__global__' if CUDA else '__kernel'} void {function_name}(",] + \
[', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete])] + \
buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if isinstance(x._buf, CLImage) else (CLProgram.buffer_prefix+self.buftokens[i].decltype()) for i,x in enumerate(self.bufs)]
self.kernel = list(self.prekernel) + [f"{CLProgram.kernel_prefix} void {function_name}(",] + \
[', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + (['uint3 gid [[thread_position_in_grid]]'] if METAL else []))] + \
[") {\n"] + self.kernel
# compile kernel

View File

@@ -14,6 +14,10 @@ class CLBuffer:
def copyout(self, a:np.ndarray): cuda.memcpy_dtoh(a, self._cl)
class CLProgram:
kernel_prefix = "__global__"
buffer_prefix = ""
@staticmethod
def gid(i): return f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}'
def __init__(self, name:str, prg:str, op_estimate:int=0, mem_estimate:int=0):
self.name, self.op_estimate, self.mem_estimate = name, op_estimate, mem_estimate
if DEBUG >= 4: print("CUDA compile", prg)

62
tinygrad/runtime/metal.py Normal file
View File

@@ -0,0 +1,62 @@
# pip3 install pyobjc-framework-Metal
import Metal
import numpy as np
from typing import List, Any
from tinygrad.ops import DEBUG
# Module-level Metal state shared by the whole runtime: the system default GPU
# and a single command queue that all kernels are submitted through.
device = Metal.MTLCreateSystemDefaultDevice()
mtl_queue = device.newCommandQueue()
# Command buffers that have been committed but not yet waited on.
mtl_buffers_in_flight : List[Any] = []

def sync():
  """Block until every in-flight command buffer has completed, then forget them all."""
  global mtl_buffers_in_flight
  for command_buffer in mtl_buffers_in_flight:
    command_buffer.waitUntilCompleted()
  mtl_buffers_in_flight = []
class CLImage:
  """Placeholder for image buffers, which the Metal runtime does not provide."""
  def __init__(self, shape):
    raise NotImplementedError("Metal runtime doesn't support images")
class CLBuffer:
  """GPU buffer backed by an MTLBuffer in shared storage mode, so CPU and GPU see the same memory."""
  def __init__(self, size): self._cl = device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
  def copyin(self, b:np.ndarray):
    # TODO: don't reallocate buffer!
    # NOTE(review): replaces self._cl with a brand-new buffer holding b's bytes.
    # Length is b.size*4 and the data is cast to float32, so this assumes
    # callers only ever copy in float data — confirm at the call sites.
    self._cl = device.newBufferWithBytes_length_options_(
      b.astype(np.float32).data,
      b.size*4,
      Metal.MTLResourceStorageModeShared)
  def toCPU(self) -> np.ndarray:
    # Storage is shared, so after waiting for all in-flight GPU work the CPU
    # can read the buffer contents directly; frombuffer wraps it without a copy.
    sync()
    return np.frombuffer(self._cl.contents().as_buffer(self._cl.length()), dtype=np.float32)
  # TODO: remove copyout everywhere
  def copyout(self, a:np.ndarray): np.copyto(a, self.toCPU().reshape(a.shape))
class CLProgram:
  """Compile a Metal compute kernel and dispatch it.

  Named CLProgram so it is a drop-in for the OpenCL/CUDA runtimes'
  classes of the same name; the codegen picks the right syntax through
  kernel_prefix / buffer_prefix / gid.
  """
  kernel_prefix = "kernel"
  buffer_prefix = "device "
  @staticmethod
  def gid(i):
    # gid.x / gid.y / gid.z, fed by the [[thread_position_in_grid]] kernel argument
    return f"gid.{chr(120+i)}"
  def __init__(self, name:str, prg:str, op_estimate:int=0, mem_estimate:int=0):
    self.name, self.op_estimate, self.mem_estimate = name, op_estimate, mem_estimate
    options = Metal.MTLCompileOptions.alloc().init()
    if DEBUG >= 4: print("Metal compile", prg)
    # PyObjC returns (result, error) tuples for the *_error_ APIs; index 0 is the object.
    self.library = device.newLibraryWithSource_options_error_(prg, options, None)
    assert self.library[0] is not None, str(self.library)
    self.fxn = self.library[0].newFunctionWithName_(name)
    # Build the pipeline state once here instead of on every __call__ —
    # it depends only on the compiled function, and creating it is expensive.
    self.pipeline_state = device.newComputePipelineStateWithFunction_error_(self.fxn, None)
    assert self.pipeline_state[0] is not None, str(self.pipeline_state)
  def __call__(self, global_size, local_size, *args):
    """Encode and commit one dispatch; args are MTLBuffers bound at indices 0..n.

    Returns the committed command buffer (also tracked in mtl_buffers_in_flight
    so sync() can wait on it). Does not block.
    """
    # Pad both sizes to 3D (Metal dispatch is always three-dimensional).
    # Build new lists rather than using +=, which would mutate the caller's lists in place.
    global_size = list(global_size) + [1] * (3-len(global_size))
    local_size = list(local_size) if local_size is not None else []
    local_size = local_size + [1] * (3-len(local_size))
    command_buffer = mtl_queue.commandBuffer()
    encoder = command_buffer.computeCommandEncoder()
    encoder.setComputePipelineState_(self.pipeline_state[0])
    for i,a in enumerate(args):
      encoder.setBuffer_offset_atIndex_(a, 0, i)
    encoder.dispatchThreads_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
    encoder.endEncoding()
    command_buffer.commit()
    mtl_buffers_in_flight.append(command_buffer)
    return command_buffer

View File

@@ -62,7 +62,11 @@ class CLImage:
@functools.lru_cache(maxsize=None)
class CLProgram:
kernel_prefix = "__kernel"
buffer_prefix = "__global "
kernel_cnt : Final[Dict[str, int]] = defaultdict(int)
@staticmethod
def gid(i): return f'get_global_id({i})'
def __init__(self, name:str, prg:str, options:Tuple[str, ...]=tuple(), argdtypes=None, rename=True, binary=False, op_estimate=0, mem_estimate=0):
self.name = f"{name}{('_N'+str(CLProgram.kernel_cnt[name])) if CLProgram.kernel_cnt[name] else str()}" if rename else name
self.prg, self.options, self.argdtypes, self.op_estimate, self.mem_estimate = prg.replace(f"{name}(", f"{self.name}(") if rename else prg, options, argdtypes, op_estimate, mem_estimate