diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index c50f0076fa..be3f77c407 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -1,7 +1,7 @@ from __future__ import annotations import os, pathlib, struct, ctypes, tempfile, functools from typing import List, Any, Union, Tuple, cast -from tinygrad.helpers import prod, to_mv, getenv, round_up, _cache_dir, T +from tinygrad.helpers import prod, to_mv, getenv, round_up, _cache_dir, T, init_c_struct_t from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator from tinygrad.renderer.cstyle import MetalRenderer @@ -45,10 +45,7 @@ def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id) def to_ns_str(s: str): return msg(libobjc.objc_getClass(b"NSString"), "stringWithUTF8String:", s.encode(), restype=objc_instance) -def to_struct(*t: int, _type: type = ctypes.c_ulong): - class Struct(ctypes.Structure): pass - Struct._fields_ = [(f"field{i}", _type) for i in range(len(t))] - return Struct(*t) +def to_struct(*t: int, _type: type = ctypes.c_ulong): return init_c_struct_t(tuple([(f"field{i}", _type) for i in range(len(t))]))(*t) def wait_check(cbuf: Any): msg(cbuf, "waitUntilCompleted") @@ -112,9 +109,8 @@ class MetalProgram: if lib[:4] == b"MTLB": # binary metal library data = libdispatch.dispatch_data_create(lib, len(lib), None, None) - error_library_creation = objc_instance() - self.library = msg(self.dev.sysdevice, "newLibraryWithData:error:", data, ctypes.byref(error_library_creation), restype=objc_instance) - error_check(error_library_creation) + self.library = msg(self.dev.sysdevice, "newLibraryWithData:error:", data, ctypes.byref(error_lib:=objc_instance()), restype=objc_instance) + error_check(error_lib) else: # metal source. rely on OS caching try: self.library = metal_src_to_library(self.dev, lib.decode()) @@ -137,7 +133,7 @@ class MetalProgram: encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance) msg(encoder, "setComputePipelineState:", self.pipeline_state) for i,a in enumerate(bufs): msg(encoder, "setBuffer:offset:atIndex:", a.buf, a.offset, i) - for i,a in enumerate(vals,start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i) + for i,a in enumerate(vals, start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i) msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size)) msg(encoder, "endEncoding") msg(command_buffer, "commit") @@ -178,9 +174,7 @@ class MetalAllocator(LRUAllocator): src_dev.mtl_buffers_in_flight.append(src_command_buffer) def _as_buffer(self, src:MetalBuffer) -> memoryview: self.dev.synchronize() - ptr = msg(src.buf, "contents", restype=objc_id) # Shared memory, do not release here - array = (ctypes.c_char * (src.offset + src.size)).from_address(ptr.value) - return memoryview(array).cast("B")[src.offset:] + return to_mv(cast(int, msg(src.buf, "contents", restype=objc_id).value), src.size + src.offset)[src.offset:] def _copyin(self, dest:MetalBuffer, src:memoryview): self._as_buffer(dest)[:] = src def _copyout(self, dest:memoryview, src:MetalBuffer): dest[:] = self._as_buffer(src) def _offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)