mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
small tweaks to the metal runtime (#562)
* small tweaks to the metal runtime
* create buffer straight from numpy
* reverted back due to bug when adding 1+1
* removed comments
This commit is contained in:
@@ -3,6 +3,7 @@ import Metal # type: ignore
|
||||
import numpy as np
|
||||
from typing import List, Any
|
||||
from tinygrad.ops import DEBUG, GlobalCounters
|
||||
from tinygrad.helpers import prod
|
||||
|
||||
device = Metal.MTLCreateSystemDefaultDevice()
|
||||
mtl_queue = device.newCommandQueue()
|
||||
@@ -19,11 +20,7 @@ class CLImage:
|
||||
class CLBuffer:
|
||||
# Allocate a new Metal buffer from `device`, passing `size` as the buffer length and requesting shared storage mode; the resulting handle is kept in `self._cl`.
def __init__(self, size): self._cl = device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
|
||||
def copyin(self, b:np.ndarray):
|
||||
# TODO: don't reallocate buffer!
|
||||
self._cl = device.newBufferWithBytes_length_options_(
|
||||
b.astype(np.float32).data,
|
||||
b.size*4,
|
||||
Metal.MTLResourceStorageModeShared)
|
||||
np.copyto(np.frombuffer(self._cl.contents().as_buffer(self._cl.length()), dtype=np.float32), b.data)
|
||||
|
||||
def toCPU(self):
|
||||
sync()
|
||||
@@ -45,15 +42,19 @@ class CLProgram:
|
||||
self.library = device.newLibraryWithSource_options_error_(prg, options, None)
|
||||
assert self.library[0] is not None, str(self.library)
|
||||
self.fxn = self.library[0].newFunctionWithName_(name)
|
||||
self.pipeline_state = device.newComputePipelineStateWithFunction_error_(self.fxn, None)
|
||||
assert self.pipeline_state[0] is not None, str(self.pipeline_state)
|
||||
|
||||
def __call__(self, global_size, local_size, *args):
|
||||
global_size += [1] * (3-len(global_size))
|
||||
if local_size is None: local_size = [32]
|
||||
local_size += [1] * (3-len(local_size))
|
||||
pipeline_state = device.newComputePipelineStateWithFunction_error_(self.fxn, None)
|
||||
assert pipeline_state[0] is not None, str(pipeline_state)
|
||||
|
||||
assert prod(local_size) <= self.pipeline_state[0].maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state[0].maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state[0].threadExecutionWidth()} memory length {self.pipeline_state[0].staticThreadgroupMemoryLength()}"
|
||||
command_buffer = mtl_queue.commandBuffer()
|
||||
encoder = command_buffer.computeCommandEncoder()
|
||||
encoder.setComputePipelineState_(pipeline_state[0])
|
||||
|
||||
encoder.setComputePipelineState_(self.pipeline_state[0])
|
||||
for i,a in enumerate(args):
|
||||
encoder.setBuffer_offset_atIndex_(a, 0, i)
|
||||
encoder.dispatchThreads_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
|
||||
|
||||
Reference in New Issue
Block a user