mirror of https://github.com/tinygrad/tinygrad.git
metal can add
@@ -12,9 +12,10 @@ render_cl = render_python.copy()
 render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops)}/{self.b})"
 from tinygrad.helpers import getenv
 
-CUDA = getenv("CUDA", 0)
-if not CUDA: from tinygrad.runtime.opencl import CLBuffer, CLImage, CLProgram # NOTE: using CL will not work for the CUDA runtime # noqa: F401
-else: from tinygrad.runtime.cuda import CLBuffer, CLImage, CLProgram # type: ignore
+CUDA,METAL = getenv("CUDA", 0), getenv("METAL", 0)
+if not CUDA and not METAL: from tinygrad.runtime.opencl import CLBuffer, CLImage, CLProgram # NOTE: using CL will not work for the CUDA runtime # noqa: F401
+elif CUDA: from tinygrad.runtime.cuda import CLBuffer, CLImage, CLProgram # type: ignore
+elif METAL: from tinygrad.runtime.metal import CLBuffer, CLImage, CLProgram # type: ignore
 
 VALIDHACKS = getenv("VALIDHACKS", 0) # TODO: remove the need for this
 NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0) # this is needed as a switch for the tests to pass
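The backend is now chosen once, at import time, from environment variables. A minimal sketch of exercising that switch by hand (assumes the selected runtime's dependencies are installed; the print line is added for illustration):

# run as e.g. `METAL=1 python demo.py` on a Mac, `CUDA=1 python demo.py` with
# pycuda installed, or with neither set for the OpenCL default
from tinygrad.helpers import getenv

CUDA, METAL = getenv("CUDA", 0), getenv("METAL", 0)
if not CUDA and not METAL: from tinygrad.runtime.opencl import CLBuffer, CLImage, CLProgram
elif CUDA: from tinygrad.runtime.cuda import CLBuffer, CLImage, CLProgram
elif METAL: from tinygrad.runtime.metal import CLBuffer, CLImage, CLProgram

print(CLProgram.kernel_prefix)  # "__kernel", "__global__", or "kernel"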
@@ -254,7 +255,7 @@ class CLASTKernel(ASTKernel):
 
     # output_shape[-1] is get_global_id(0)
     MAX_OUTPUT_SHAPE = 3
-    self.kernel += [f"int idx{len(self.output_shape)-1-i} = {f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' if CUDA else f'get_global_id({i})'}; /* {self.output_shape[-1-i]} */\n" for i in range(min(MAX_OUTPUT_SHAPE, len(self.output_shape))) if self.output_shape[-1-i] != 1]
+    self.kernel += [f"int idx{len(self.output_shape)-1-i} = {CLProgram.gid(i)}; /* {self.output_shape[-1-i]} */\n" for i in range(min(MAX_OUTPUT_SHAPE, len(self.output_shape))) if self.output_shape[-1-i] != 1]
     if len(self.output_shape) > MAX_OUTPUT_SHAPE:
       # sometimes, there's more dimensions. compact all the dimensions into the first one
       # TODO: these compactions should be searchable
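For intuition, a pure-Python sketch of what that comprehension emits, using the OpenCL gid() from the runtime diff below and a hypothetical output_shape of (32, 32):

output_shape = (32, 32)                # hypothetical shape for illustration
MAX_OUTPUT_SHAPE = 3
gid = lambda i: f"get_global_id({i})"  # the OpenCL CLProgram.gid
kernel = [f"int idx{len(output_shape)-1-i} = {gid(i)}; /* {output_shape[-1-i]} */\n"
          for i in range(min(MAX_OUTPUT_SHAPE, len(output_shape))) if output_shape[-1-i] != 1]
print("".join(kernel))
# int idx1 = get_global_id(0); /* 32 */
# int idx0 = get_global_id(1); /* 32 */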
@@ -311,9 +312,9 @@ class CLASTKernel(ASTKernel):
 
     # kernel function definition
     function_name = ("re_S" if self.reduceop else "ew_S") + '_'.join([str(x) for x in self.bufs[0].shape if x != 1])
-    buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if isinstance(x._buf, CLImage) else ("__global "+self.buftokens[i].decltype()) for i,x in enumerate(self.bufs)] if not CUDA else [self.buftokens[i].decltype() for i,x in enumerate(self.bufs)]
-    self.kernel = list(self.prekernel) + [f"{'__global__' if CUDA else '__kernel'} void {function_name}(",] + \
-      [', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete])] + \
+    buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if isinstance(x._buf, CLImage) else (CLProgram.buffer_prefix+self.buftokens[i].decltype()) for i,x in enumerate(self.bufs)]
+    self.kernel = list(self.prekernel) + [f"{CLProgram.kernel_prefix} void {function_name}(",] + \
+      [', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + (['uint3 gid [[thread_position_in_grid]]'] if METAL else []))] + \
       [") {\n"] + self.kernel
 
     # compile kernel
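A hedged sketch of the signatures this assembles per backend, for a hypothetical two-buffer kernel named ew_S16 (using "float*" as the decltype, an assumption for illustration; the prefix values come from the runtime diffs below):

for kernel_prefix, buffer_prefix, extra in [
    ("__kernel",   "__global ", []),                                          # opencl
    ("__global__", "",          []),                                          # cuda
    ("kernel",     "device ",   ["uint3 gid [[thread_position_in_grid]]"]),   # metal
]:
  params = ', '.join([f"{buffer_prefix}float* data{i}" for i in range(2)] + extra)
  print(f"{kernel_prefix} void ew_S16({params}) {{")
# __kernel void ew_S16(__global float* data0, __global float* data1) {
# __global__ void ew_S16(float* data0, float* data1) {
# kernel void ew_S16(device float* data0, device float* data1, uint3 gid [[thread_position_in_grid]]) {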
tinygrad/runtime/cuda.py
@@ -14,6 +14,10 @@ class CLBuffer:
   def copyout(self, a:np.ndarray): cuda.memcpy_dtoh(a, self._cl)
 
 class CLProgram:
+  kernel_prefix = "__global__"
+  buffer_prefix = ""
+  @staticmethod
+  def gid(i): return f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}'
   def __init__(self, name:str, prg:str, op_estimate:int=0, mem_estimate:int=0):
     self.name, self.op_estimate, self.mem_estimate = name, op_estimate, mem_estimate
     if DEBUG >= 4: print("CUDA compile", prg)
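The chr(120+i) trick works because ord('x') is 120, so dimension indices 0 through 2 map to the .x/.y/.z vector components. A quick pure-Python check:

for i in range(3):
  print(f"blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}")
# blockDim.x*blockIdx.x+threadIdx.x
# blockDim.y*blockIdx.y+threadIdx.y
# blockDim.z*blockIdx.z+threadIdx.z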
tinygrad/runtime/metal.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+# pip3 install pyobjc-framework-Metal
+import Metal
+import numpy as np
+from typing import List, Any
+from tinygrad.ops import DEBUG
+
+device = Metal.MTLCreateSystemDefaultDevice()
+mtl_queue = device.newCommandQueue()
+mtl_buffers_in_flight : List[Any] = []
+
+def sync():
+  global mtl_buffers_in_flight
+  for cbuf in mtl_buffers_in_flight: cbuf.waitUntilCompleted()
+  mtl_buffers_in_flight = []
+
+class CLImage:
+  def __init__(self, shape): raise NotImplementedError("Metal runtime doesn't support images")
+
+class CLBuffer:
+  def __init__(self, size): self._cl = device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
+  def copyin(self, b:np.ndarray):
+    # TODO: don't reallocate buffer!
+    self._cl = device.newBufferWithBytes_length_options_(
+      b.astype(np.float32).data,
+      b.size*4,
+      Metal.MTLResourceStorageModeShared)
+
+  def toCPU(self):
+    sync()
+    return np.frombuffer(self._cl.contents().as_buffer(self._cl.length()), dtype=np.float32)
+
+  # TODO: remove copyout everywhere
+  def copyout(self, a:np.ndarray): np.copyto(a, self.toCPU().reshape(a.shape))
+
+class CLProgram:
+  kernel_prefix = "kernel"
+  buffer_prefix = "device "
+  @staticmethod
+  def gid(i): return f"gid.{chr(120+i)}"
+  def __init__(self, name:str, prg:str, op_estimate:int=0, mem_estimate:int=0):
+    self.name, self.op_estimate, self.mem_estimate = name, op_estimate, mem_estimate
+    options = Metal.MTLCompileOptions.alloc().init()
+    if DEBUG >= 4: print("Metal compile", prg)
+    self.library = device.newLibraryWithSource_options_error_(prg, options, None)
+    assert self.library[0] is not None, str(self.library)
+    self.fxn = self.library[0].newFunctionWithName_(name)
+  def __call__(self, global_size, local_size, *args):
+    global_size += [1] * (3-len(global_size))
+    if local_size is None: local_size = []
+    local_size += [1] * (3-len(local_size))
+    pipeline_state = device.newComputePipelineStateWithFunction_error_(self.fxn, None)
+    assert pipeline_state[0] is not None, str(pipeline_state)
+    command_buffer = mtl_queue.commandBuffer()
+    encoder = command_buffer.computeCommandEncoder()
+    encoder.setComputePipelineState_(pipeline_state[0])
+    for i,a in enumerate(args):
+      encoder.setBuffer_offset_atIndex_(a, 0, i)
+    encoder.dispatchThreads_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
+    encoder.endEncoding()
+    command_buffer.commit()
+    mtl_buffers_in_flight.append(command_buffer)
+    return command_buffer
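Matching the commit title, a hedged end-to-end sketch of adding two vectors through this new runtime by hand (assumes macOS with pyobjc-framework-Metal installed; the kernel name ew_S16 and the argument order, output buffer first, are inferred from the codegen diffs above rather than spelled out in this file):

import numpy as np
from tinygrad.runtime.metal import CLBuffer, CLProgram

N = 16
a, b = np.arange(N, dtype=np.float32), np.full(N, 10, dtype=np.float32)
bufo, bufa, bufb = CLBuffer(N*4), CLBuffer(N*4), CLBuffer(N*4)  # sizes are in bytes
bufa.copyin(a); bufb.copyin(b)  # copyin reallocates ._cl, so read ._cl only afterwards

# hand-written source in the shape the codegen above emits: data0 is the output
prg = CLProgram("ew_S16", """kernel void ew_S16(device float *data0, device float *data1,
    device float *data2, uint3 gid [[thread_position_in_grid]]) {
  int idx0 = gid.x;  /* 16 */
  data0[idx0] = data1[idx0] + data2[idx0];
}""")
prg([N], None, bufo._cl, bufa._cl, bufb._cl)  # one thread per element

out = np.empty(N, dtype=np.float32)
bufo.copyout(out)  # waits on in-flight command buffers, then copies back
print(out)         # a + b, i.e. 10.0 .. 25.0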
tinygrad/runtime/opencl.py
@@ -62,7 +62,11 @@ class CLImage:
 
 @functools.lru_cache(maxsize=None)
 class CLProgram:
+  kernel_prefix = "__kernel"
+  buffer_prefix = "__global "
   kernel_cnt : Final[Dict[str, int]] = defaultdict(int)
+  @staticmethod
+  def gid(i): return f'get_global_id({i})'
   def __init__(self, name:str, prg:str, options:Tuple[str, ...]=tuple(), argdtypes=None, rename=True, binary=False, op_estimate=0, mem_estimate=0):
     self.name = f"{name}{('_N'+str(CLProgram.kernel_cnt[name])) if CLProgram.kernel_cnt[name] else str()}" if rename else name
     self.prg, self.options, self.argdtypes, self.op_estimate, self.mem_estimate = prg.replace(f"{name}(", f"{self.name}(") if rename else prg, options, argdtypes, op_estimate, mem_estimate
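Taken together, the three runtimes now expose the same tiny codegen interface. A pure-Python recap of the values from the diffs above (copied by hand here so it runs without importing any runtime):

backends = {
  "opencl": ("__kernel",   "__global ", lambda i: f"get_global_id({i})"),
  "cuda":   ("__global__", "",          lambda i: f"blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}"),
  "metal":  ("kernel",     "device ",   lambda i: f"gid.{chr(120+i)}"),
}
for name, (kpfx, bpfx, gid) in backends.items():
  print(f"{name:7s} kernel_prefix={kpfx!r:13s} buffer_prefix={bpfx!r:12s} gid(0)={gid(0)}")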