Use MTLCompiler directly (v2) (#7920)

* Use MTLCompiler directly (v2)

* to_block_literal and REQUEST_TYPE_COMPILE

* Rewrite command encoding

* Revert to_block_literal

* Maybe that's more readable to some people?

* Typo and comment about stdlib caching

* Update ops_metal.py

* Update ops_metal.py

* Update ops_metal.py

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
uuuvn
2024-12-04 10:36:48 +02:00
committed by GitHub
parent bb98bae751
commit e9c5b23ba1
3 changed files with 41 additions and 23 deletions

View File

@@ -225,6 +225,6 @@ from tinygrad.runtime.ops_metal import MetalDevice, MetalCompiler, MetalProgram
if __name__ == "__main__": if __name__ == "__main__":
dev = MetalDevice("METAL") dev = MetalDevice("METAL")
lib = MetalCompiler(dev).compile(src) lib = MetalCompiler().compile(src)
prg = MetalProgram(dev, "r_64_32_8_16_4_6_6_4", lib) prg = MetalProgram(dev, "r_64_32_8_16_4_6_6_4", lib)

View File

@@ -10,14 +10,12 @@ class TestMetal(unittest.TestCase):
device.allocator.alloc(10000000000000000000) device.allocator.alloc(10000000000000000000)
def test_compile_error(self): def test_compile_error(self):
device = MetalDevice("metal") compiler = MetalCompiler()
compiler = MetalCompiler(device)
with self.assertRaises(CompileError): with self.assertRaises(CompileError):
compiler.compile("this is not valid metal") compiler.compile("this is not valid metal")
def test_compile_success(self): def test_compile_success(self):
device = MetalDevice("metal") compiler = MetalCompiler()
compiler = MetalCompiler(device)
ret = compiler.compile(""" ret = compiler.compile("""
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
@@ -41,7 +39,7 @@ class TestMetal(unittest.TestCase):
def test_failed_newLibraryWithData(self): def test_failed_newLibraryWithData(self):
device = MetalDevice("metal") device = MetalDevice("metal")
compiler = MetalCompiler(device) compiler = MetalCompiler()
compiled = compiler.compile(""" compiled = compiler.compile("""
#include <metal_stdlib> #include <metal_stdlib>
kernel void r_5(device int* data0, const device int* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]){ kernel void r_5(device int* data0, const device int* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]){

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import os, subprocess, pathlib, ctypes, tempfile, functools import os, pathlib, struct, ctypes, tempfile, functools
from typing import List, Any, Tuple, Optional, cast from typing import List, Any, Union, Tuple, cast
from tinygrad.helpers import prod, getenv, T from tinygrad.helpers import prod, to_mv, getenv, round_up, _cache_dir, T
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
from tinygrad.renderer.cstyle import MetalRenderer from tinygrad.renderer.cstyle import MetalRenderer
@@ -22,14 +22,19 @@ class MTLResourceOptions:
class MTLPipelineOption: class MTLPipelineOption:
MTLPipelineOptionNone = 0 MTLPipelineOptionNone = 0
# 13 is requestType that metal uses to compile source code into MTLB, there aren't any docs or symbols.
REQUEST_TYPE_COMPILE = 13
libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib") libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal") libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
compiler = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler")
# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc # Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics") ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac
libobjc.objc_getClass.restype = objc_id libobjc.objc_getClass.restype = objc_id
libobjc.sel_registerName.restype = objc_id libobjc.sel_registerName.restype = objc_id
libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
compiler.MTLCodeGenServiceCreate.restype = ctypes.c_void_p
libdispatch.dispatch_data_create.restype = objc_instance libdispatch.dispatch_data_create.restype = objc_instance
# Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12 # Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12
@@ -65,20 +70,35 @@ def metal_src_to_library(device:MetalDevice, src:str) -> objc_instance:
return library return library
class MetalCompiler(Compiler): class MetalCompiler(Compiler):
def __init__(self, dev:Optional[MetalDevice]=None): def __init__(self):
self.dev = dev self.cgs = ctypes.c_void_p(compiler.MTLCodeGenServiceCreate(b"tinygrad"))
super().__init__("compile_metal_xcode" if self.dev is None else "compile_metal") super().__init__("compile_metal_direct")
def __reduce__(self): return (MetalCompiler,()) # force pickle to create new instance for each multiprocessing fork
def compile(self, src:str) -> bytes: def compile(self, src:str) -> bytes:
if self.dev is None: ret: Union[Exception, bytes] = CompileError("MTLCodeGenServiceBuildRequest returned without calling the callback")
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode @ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_int32, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_char_p)
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=src.encode('utf-8')) def callback(blockptr, error, dataPtr, dataLen, errorMessage):
lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air) nonlocal ret
else: if error == 0:
library = metal_src_to_library(self.dev, src) reply = bytes(to_mv(dataPtr, dataLen))
library_contents = msg(library, "libraryDataContents", restype=objc_instance) # offset from beginning to data = header size + warning size
lib = ctypes.string_at(msg(library_contents, "bytes"), cast(int, msg(library_contents, "length", restype=ctypes.c_ulong))) ret = reply[sum(struct.unpack('<LL', reply[8:16])):]
assert lib[:4] == b"MTLB", "Invalid Metal library. Using conda? Corrupt XCode?" else:
return lib ret = CompileError(errorMessage.decode())
# llvm will create modules.timestamp in cache path and cache compilation of metal stdlib (250ms => 8ms compilation time)
# note that llvm won't necessarily create anything else here as apple has prebuilt versions of many standard libraries
params = f'-fno-fast-math -std=metal3.1 --driver-mode=metal -x metal -fmodules-cache-path="{os.path.join(_cache_dir, "tinygrad")}"'
# source blob has to be padded to multiple of 4 but at least one 'b\x00' should be added, params blob just has to be null terminated
src_padded, params_padded = src.encode() + b'\x00'*(round_up(len(src) + 1, 4) - len(src)), params.encode() + b'\x00'
request = struct.pack('<QQ', len(src_padded), len(params_padded)) + src_padded + params_padded
# The callback is actually not a callback but a block which is apple's non-standard extension to add closures to C.
# See https://clang.llvm.org/docs/Block-ABI-Apple.html#high-level for struct layout.
# Fields other than invoke are unused in this case so we can just use ctypes.byref with negative offset to invoke field, add blockptr as a first
# argument and pretend it's a normal callback
compiler.MTLCodeGenServiceBuildRequest(self.cgs, None, REQUEST_TYPE_COMPILE, request, len(request), ctypes.byref(callback, -0x10))
if isinstance(ret, Exception): raise ret
assert ret[:4] == b"MTLB" and ret[-4:] == b"ENDT", f"Invalid Metal library. {ret!r}"
return ret
def disassemble(self, lib:bytes): def disassemble(self, lib:bytes):
with tempfile.NamedTemporaryFile(delete=True) as shader: with tempfile.NamedTemporaryFile(delete=True) as shader:
shader.write(lib) shader.write(lib)
@@ -176,7 +196,7 @@ class MetalDevice(Compiled):
self.timeline_value = 0 self.timeline_value = 0
from tinygrad.runtime.graph.metal import MetalGraph from tinygrad.runtime.graph.metal import MetalGraph
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_XCODE") else Compiler(), super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_DIRECT", 1) else Compiler(),
functools.partial(MetalProgram, self), MetalGraph) functools.partial(MetalProgram, self), MetalGraph)
def synchronize(self): def synchronize(self):
for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf) for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)