tiny metal cleanup (#8089)

* tiny metal cleanup

* cast

* sry
This commit is contained in:
nimlgen
2024-12-06 21:44:32 +03:00
committed by GitHub
parent d000c08f04
commit e180a31c5e

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import os, pathlib, struct, ctypes, tempfile, functools
from typing import List, Any, Union, Tuple, cast
from tinygrad.helpers import prod, to_mv, getenv, round_up, _cache_dir, T
from tinygrad.helpers import prod, to_mv, getenv, round_up, _cache_dir, T, init_c_struct_t
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
from tinygrad.renderer.cstyle import MetalRenderer
@@ -45,10 +45,7 @@ def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id)
def to_ns_str(s: str):
  """Send `stringWithUTF8String:` to the NSString class with `s` encoded to bytes; the reply is requested as an objc_instance."""
  ns_string_cls = libobjc.objc_getClass(b"NSString")
  encoded = s.encode()
  return msg(ns_string_cls, "stringWithUTF8String:", encoded, restype=objc_instance)
# Pack the positional values `t` into a ctypes.Structure instance, one `_type` field per value.
# NOTE(review): the same name is redefined as a one-liner on the line right after this block;
# presumably this is the removed side of the diff hunk — confirm before relying on it.
def to_struct(*t: int, _type: type = ctypes.c_ulong):
# a new Structure subclass is declared on every call
class Struct(ctypes.Structure): pass
# fields are named field0, field1, ... in argument order, all of type `_type`
Struct._fields_ = [(f"field{i}", _type) for i in range(len(t))]
# instantiate with the values in the same order as the declared fields
return Struct(*t)
def to_struct(*t: int, _type: type = ctypes.c_ulong):
  """Build a struct instance holding the values in `t`.

  One field named ``field<i>`` of type `_type` is declared per positional value. The field
  tuple is handed to `init_c_struct_t`, and whatever it returns is called with `*t`.
  """
  field_specs = tuple((f"field{idx}", _type) for idx, _ in enumerate(t))
  struct_type = init_c_struct_t(field_specs)
  return struct_type(*t)
def wait_check(cbuf: Any):
msg(cbuf, "waitUntilCompleted")
@@ -112,9 +109,8 @@ class MetalProgram:
if lib[:4] == b"MTLB":
# binary metal library
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
error_library_creation = objc_instance()
self.library = msg(self.dev.sysdevice, "newLibraryWithData:error:", data, ctypes.byref(error_library_creation), restype=objc_instance)
error_check(error_library_creation)
self.library = msg(self.dev.sysdevice, "newLibraryWithData:error:", data, ctypes.byref(error_lib:=objc_instance()), restype=objc_instance)
error_check(error_lib)
else:
# metal source. rely on OS caching
try: self.library = metal_src_to_library(self.dev, lib.decode())
@@ -137,7 +133,7 @@ class MetalProgram:
encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
msg(encoder, "setComputePipelineState:", self.pipeline_state)
for i,a in enumerate(bufs): msg(encoder, "setBuffer:offset:atIndex:", a.buf, a.offset, i)
for i,a in enumerate(vals,start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
for i,a in enumerate(vals, start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
msg(encoder, "endEncoding")
msg(command_buffer, "commit")
@@ -178,9 +174,7 @@ class MetalAllocator(LRUAllocator):
src_dev.mtl_buffers_in_flight.append(src_command_buffer)
# Return a memoryview over the contents of `src`, beginning at `src.offset`.
# NOTE(review): two bodies appear back to back below — the three-statement ctypes form and the
# single `to_mv` return — presumably the removed and added sides of one diff hunk; confirm.
def _as_buffer(self, src:MetalBuffer) -> memoryview:
# synchronize the device before the buffer contents are read
self.dev.synchronize()
ptr = msg(src.buf, "contents", restype=objc_id) # Shared memory, do not release here
# map the first (offset + size) bytes at that address, then slice off the leading `offset` bytes
array = (ctypes.c_char * (src.offset + src.size)).from_address(ptr.value)
return memoryview(array).cast("B")[src.offset:]
# same address, length and slice, routed through the `to_mv` helper (presumably equivalent — verify against to_mv)
return to_mv(cast(int, msg(src.buf, "contents", restype=objc_id).value), src.size + src.offset)[src.offset:]
def _copyin(self, dest:MetalBuffer, src:memoryview):
  """Copy the bytes of `src` into the view returned by `_as_buffer(dest)`."""
  target_view = self._as_buffer(dest)
  target_view[:] = src
def _copyout(self, dest:memoryview, src:MetalBuffer):
  """Fill `dest` with the contents of the view returned by `_as_buffer(src)`."""
  source_view = self._as_buffer(src)
  dest[:] = source_view
def _offset(self, buf:MetalBuffer, size:int, offset:int):
  """Return a new MetalBuffer that reuses `buf.buf` with the given `size` and `offset`."""
  underlying = buf.buf
  return MetalBuffer(underlying, size, offset)