Use MTLCompiler directly (v2) (#7920)

* Use MTLCompiler directly (v2)

* to_block_literal and REQUEST_TYPE_COMPILE

* Rewrite command encoding

* Revert to_block_literal

* Maybe that's more readable to some people?

* Typo and comment about stdlib caching

* Update ops_metal.py

* Update ops_metal.py

* Update ops_metal.py

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
uuuvn
2024-12-04 10:36:48 +02:00
committed by GitHub
parent bb98bae751
commit e9c5b23ba1
3 changed files with 41 additions and 23 deletions

View File

@@ -225,6 +225,6 @@ from tinygrad.runtime.ops_metal import MetalDevice, MetalCompiler, MetalProgram
if __name__ == "__main__":
dev = MetalDevice("METAL")
lib = MetalCompiler(dev).compile(src)
lib = MetalCompiler().compile(src)
prg = MetalProgram(dev, "r_64_32_8_16_4_6_6_4", lib)

View File

@@ -10,14 +10,12 @@ class TestMetal(unittest.TestCase):
device.allocator.alloc(10000000000000000000)
def test_compile_error(self):
device = MetalDevice("metal")
compiler = MetalCompiler(device)
compiler = MetalCompiler()
with self.assertRaises(CompileError):
compiler.compile("this is not valid metal")
def test_compile_success(self):
device = MetalDevice("metal")
compiler = MetalCompiler(device)
compiler = MetalCompiler()
ret = compiler.compile("""
#include <metal_stdlib>
using namespace metal;
@@ -41,7 +39,7 @@ class TestMetal(unittest.TestCase):
def test_failed_newLibraryWithData(self):
device = MetalDevice("metal")
compiler = MetalCompiler(device)
compiler = MetalCompiler()
compiled = compiler.compile("""
#include <metal_stdlib>
kernel void r_5(device int* data0, const device int* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]){

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import os, subprocess, pathlib, ctypes, tempfile, functools
from typing import List, Any, Tuple, Optional, cast
from tinygrad.helpers import prod, getenv, T
import os, pathlib, struct, ctypes, tempfile, functools
from typing import List, Any, Union, Tuple, cast
from tinygrad.helpers import prod, to_mv, getenv, round_up, _cache_dir, T
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
from tinygrad.renderer.cstyle import MetalRenderer
@@ -22,14 +22,19 @@ class MTLResourceOptions:
class MTLPipelineOption:
MTLPipelineOptionNone = 0
# 13 is requestType that metal uses to compile source code into MTLB, there aren't any docs or symbols.
REQUEST_TYPE_COMPILE = 13
libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
compiler = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler")
# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac
libobjc.objc_getClass.restype = objc_id
libobjc.sel_registerName.restype = objc_id
libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
compiler.MTLCodeGenServiceCreate.restype = ctypes.c_void_p
libdispatch.dispatch_data_create.restype = objc_instance
# Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12
@@ -65,20 +70,35 @@ def metal_src_to_library(device:MetalDevice, src:str) -> objc_instance:
return library
class MetalCompiler(Compiler):
def __init__(self, dev:Optional[MetalDevice]=None):
self.dev = dev
super().__init__("compile_metal_xcode" if self.dev is None else "compile_metal")
def __init__(self):
self.cgs = ctypes.c_void_p(compiler.MTLCodeGenServiceCreate(b"tinygrad"))
super().__init__("compile_metal_direct")
def __reduce__(self): return (MetalCompiler,()) # force pickle to create new instance for each multiprocessing fork
def compile(self, src:str) -> bytes:
if self.dev is None:
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=src.encode('utf-8'))
lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
else:
library = metal_src_to_library(self.dev, src)
library_contents = msg(library, "libraryDataContents", restype=objc_instance)
lib = ctypes.string_at(msg(library_contents, "bytes"), cast(int, msg(library_contents, "length", restype=ctypes.c_ulong)))
assert lib[:4] == b"MTLB", "Invalid Metal library. Using conda? Corrupt XCode?"
return lib
ret: Union[Exception, bytes] = CompileError("MTLCodeGenServiceBuildRequest returned without calling the callback")
@ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_int32, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_char_p)
def callback(blockptr, error, dataPtr, dataLen, errorMessage):
nonlocal ret
if error == 0:
reply = bytes(to_mv(dataPtr, dataLen))
# offset from beginning to data = header size + warning size
ret = reply[sum(struct.unpack('<LL', reply[8:16])):]
else:
ret = CompileError(errorMessage.decode())
# llvm will create modules.timestamp in cache path and cache compilation of metal stdlib (250ms => 8ms compilation time)
# note that llvm won't necessarily create anything else here as apple has prebuilt versions of many standard libraries
params = f'-fno-fast-math -std=metal3.1 --driver-mode=metal -x metal -fmodules-cache-path="{os.path.join(_cache_dir, "tinygrad")}"'
# source blob has to be padded to multiple of 4 but at least one 'b\x00' should be added, params blob just has to be null terminated
src_padded, params_padded = src.encode() + b'\x00'*(round_up(len(src) + 1, 4) - len(src)), params.encode() + b'\x00'
request = struct.pack('<QQ', len(src_padded), len(params_padded)) + src_padded + params_padded
# The callback is actually not a callback but a block which is apple's non-standard extension to add closures to C.
# See https://clang.llvm.org/docs/Block-ABI-Apple.html#high-level for struct layout.
# Fields other than invoke are unused in this case so we can just use ctypes.byref with negative offset to invoke field, add blockptr as a first
# argument and pretend it's a normal callback
compiler.MTLCodeGenServiceBuildRequest(self.cgs, None, REQUEST_TYPE_COMPILE, request, len(request), ctypes.byref(callback, -0x10))
if isinstance(ret, Exception): raise ret
assert ret[:4] == b"MTLB" and ret[-4:] == b"ENDT", f"Invalid Metal library. {ret!r}"
return ret
def disassemble(self, lib:bytes):
with tempfile.NamedTemporaryFile(delete=True) as shader:
shader.write(lib)
@@ -176,7 +196,7 @@ class MetalDevice(Compiled):
self.timeline_value = 0
from tinygrad.runtime.graph.metal import MetalGraph
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_XCODE") else Compiler(),
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_DIRECT", 1) else Compiler(),
functools.partial(MetalProgram, self), MetalGraph)
def synchronize(self):
for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)