From e9c5b23ba1d5d5693ffd050247fbf4cb7be77079 Mon Sep 17 00:00:00 2001 From: uuuvn <83587632+uuuvn@users.noreply.github.com> Date: Wed, 4 Dec 2024 10:36:48 +0200 Subject: [PATCH] Use MTLCompiler directly (v2) (#7920) * Use MTLCompiler directly (v2) * to_block_literal and REQUEST_TYPE_COMPILE * Rewrite command encoding * Revert to_block_literal * Maybe that's more readable to some people? * Typo and comment about stdlib caching * Update ops_metal.py * Update ops_metal.py * Update ops_metal.py --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> --- test/external/external_metal_compile_fail.py | 2 +- test/test_metal.py | 8 ++- tinygrad/runtime/ops_metal.py | 54 ++++++++++++++------ 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/test/external/external_metal_compile_fail.py b/test/external/external_metal_compile_fail.py index 26d9cc4a4e..d6e2bf9829 100644 --- a/test/external/external_metal_compile_fail.py +++ b/test/external/external_metal_compile_fail.py @@ -225,6 +225,6 @@ from tinygrad.runtime.ops_metal import MetalDevice, MetalCompiler, MetalProgram if __name__ == "__main__": dev = MetalDevice("METAL") - lib = MetalCompiler(dev).compile(src) + lib = MetalCompiler().compile(src) prg = MetalProgram(dev, "r_64_32_8_16_4_6_6_4", lib) diff --git a/test/test_metal.py b/test/test_metal.py index b324d81c4c..3230b7b731 100644 --- a/test/test_metal.py +++ b/test/test_metal.py @@ -10,14 +10,12 @@ class TestMetal(unittest.TestCase): device.allocator.alloc(10000000000000000000) def test_compile_error(self): - device = MetalDevice("metal") - compiler = MetalCompiler(device) + compiler = MetalCompiler() with self.assertRaises(CompileError): compiler.compile("this is not valid metal") def test_compile_success(self): - device = MetalDevice("metal") - compiler = MetalCompiler(device) + compiler = MetalCompiler() ret = compiler.compile(""" #include <metal_stdlib> using namespace metal; @@ -41,7 +39,7 @@ class TestMetal(unittest.TestCase): def 
test_failed_newLibraryWithData(self): device = MetalDevice("metal") - compiler = MetalCompiler(device) + compiler = MetalCompiler() compiled = compiler.compile(""" #include <metal_stdlib> kernel void r_5(device int* data0, const device int* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]){ diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index e7185b69e3..c50f0076fa 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -1,7 +1,7 @@ from __future__ import annotations -import os, subprocess, pathlib, ctypes, tempfile, functools -from typing import List, Any, Tuple, Optional, cast -from tinygrad.helpers import prod, getenv, T +import os, pathlib, struct, ctypes, tempfile, functools +from typing import List, Any, Union, Tuple, cast +from tinygrad.helpers import prod, to_mv, getenv, round_up, _cache_dir, T from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator from tinygrad.renderer.cstyle import MetalRenderer @@ -22,14 +22,19 @@ class MTLResourceOptions: class MTLPipelineOption: MTLPipelineOptionNone = 0 +# 13 is requestType that metal uses to compile source code into MTLB, there aren't any docs or symbols. 
+REQUEST_TYPE_COMPILE = 13 + libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib") libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal") +compiler = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler") # Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics") libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac libobjc.objc_getClass.restype = objc_id libobjc.sel_registerName.restype = objc_id libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance +compiler.MTLCodeGenServiceCreate.restype = ctypes.c_void_p libdispatch.dispatch_data_create.restype = objc_instance # Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12 @@ -65,20 +70,35 @@ def metal_src_to_library(device:MetalDevice, src:str) -> objc_instance: return library class MetalCompiler(Compiler): - def __init__(self, dev:Optional[MetalDevice]=None): - self.dev = dev - super().__init__("compile_metal_xcode" if self.dev is None else "compile_metal") + def __init__(self): + self.cgs = ctypes.c_void_p(compiler.MTLCodeGenServiceCreate(b"tinygrad")) + super().__init__("compile_metal_direct") + def __reduce__(self): return (MetalCompiler,()) # force pickle to create new instance for each multiprocessing fork def compile(self, src:str) -> bytes: - if self.dev is None: - # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode - air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=src.encode('utf-8')) - lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air) - else: - library = metal_src_to_library(self.dev, src) - library_contents = msg(library, "libraryDataContents", restype=objc_instance) - lib = 
ctypes.string_at(msg(library_contents, "bytes"), cast(int, msg(library_contents, "length", restype=ctypes.c_ulong))) - assert lib[:4] == b"MTLB", "Invalid Metal library. Using conda? Corrupt XCode?" - return lib + ret: Union[Exception, bytes] = CompileError("MTLCodeGenServiceBuildRequest returned without calling the callback") + @ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_int32, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_char_p) + def callback(blockptr, error, dataPtr, dataLen, errorMessage): + nonlocal ret + if error == 0: + reply = bytes(to_mv(dataPtr, dataLen)) + # offset from beginning to data = header size + warning size + ret = reply[sum(struct.unpack(' 8ms compilation time) + # note that llvm won't necessarily create anything else here as apple has prebuilt versions of many standard libraries + params = f'-fno-fast-math -std=metal3.1 --driver-mode=metal -x metal -fmodules-cache-path="{os.path.join(_cache_dir, "tinygrad")}"' + # source blob has to be padded to multiple of 4 but at least one 'b\x00' should be added, params blob just has to be null terminated + src_padded, params_padded = src.encode() + b'\x00'*(round_up(len(src) + 1, 4) - len(src)), params.encode() + b'\x00' + request = struct.pack('