From a508c2b429b696202e304f6cb5a33ecb2ba5a579 Mon Sep 17 00:00:00 2001
From: Diogo
Date: Sun, 19 Feb 2023 14:25:13 -0500
Subject: [PATCH] small tweaks to the metal runtime (#562)

* small tweaks to the metal runtime

* create buffer straight from numpy

* reverted back due to bug when adding 1+1

* removed comments
---
 tinygrad/runtime/metal.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/tinygrad/runtime/metal.py b/tinygrad/runtime/metal.py
index 486bd9ea24..007adf5ca0 100644
--- a/tinygrad/runtime/metal.py
+++ b/tinygrad/runtime/metal.py
@@ -3,6 +3,7 @@ import Metal # type: ignore
 import numpy as np
 from typing import List, Any
 from tinygrad.ops import DEBUG, GlobalCounters
+from tinygrad.helpers import prod
 
 device = Metal.MTLCreateSystemDefaultDevice()
 mtl_queue = device.newCommandQueue()
@@ -19,11 +20,7 @@ class CLImage:
 class CLBuffer:
   def __init__(self, size): self._cl = device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
   def copyin(self, b:np.ndarray):
-    # TODO: don't reallocate buffer!
-    self._cl = device.newBufferWithBytes_length_options_(
-      b.astype(np.float32).data,
-      b.size*4,
-      Metal.MTLResourceStorageModeShared)
+    np.copyto(np.frombuffer(self._cl.contents().as_buffer(self._cl.length()), dtype=np.float32), b.data)
 
   def toCPU(self):
     sync()
@@ -45,15 +42,19 @@ class CLProgram:
     self.library = device.newLibraryWithSource_options_error_(prg, options, None)
     assert self.library[0] is not None, str(self.library)
     self.fxn = self.library[0].newFunctionWithName_(name)
+    self.pipeline_state = device.newComputePipelineStateWithFunction_error_(self.fxn, None)
+    assert self.pipeline_state[0] is not None, str(self.pipeline_state)
+
   def __call__(self, global_size, local_size, *args):
     global_size += [1] * (3-len(global_size))
     if local_size is None: local_size = [32]
     local_size += [1] * (3-len(local_size))
-    pipeline_state = device.newComputePipelineStateWithFunction_error_(self.fxn, None)
-    assert pipeline_state[0] is not None, str(pipeline_state)
+
+    assert prod(local_size) <= self.pipeline_state[0].maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state[0].maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state[0].threadExecutionWidth()} memory length {self.pipeline_state[0].staticThreadgroupMemoryLength()}"
     command_buffer = mtl_queue.commandBuffer()
     encoder = command_buffer.computeCommandEncoder()
-    encoder.setComputePipelineState_(pipeline_state[0])
+
+    encoder.setComputePipelineState_(self.pipeline_state[0])
     for i,a in enumerate(args):
       encoder.setBuffer_offset_atIndex_(a, 0, i)
     encoder.dispatchThreads_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
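
Note: for readers of the new copyin path, below is a minimal standalone sketch of the
zero-copy upload it switches to. It is a sketch under assumptions, not part of the patch:
it assumes the same pyobjc Metal bindings the runtime imports, a float32 source array,
and a shared-storage buffer sized to match; the names src/view are illustrative only.

  import Metal  # type: ignore
  import numpy as np

  device = Metal.MTLCreateSystemDefaultDevice()

  src = np.arange(4, dtype=np.float32)
  # Allocate a shared-storage buffer once; CPU and GPU address the same memory.
  buf = device.newBufferWithLength_options_(src.nbytes, Metal.MTLResourceStorageModeShared)

  # Wrap the buffer's bytes in a numpy view and copy into it in place,
  # rather than reallocating with newBufferWithBytes_length_options_ per upload.
  view = np.frombuffer(buf.contents().as_buffer(buf.length()), dtype=np.float32)
  np.copyto(view, src)

  assert np.array_equal(view, src)

The same idea motivates hoisting newComputePipelineStateWithFunction_error_ into
__init__: the pipeline state is fixed per kernel, so building it once and reusing it
in __call__ avoids recreating it on every dispatch.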