mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
Use MTLCompiler directly (v2) (#7920)
* Use MTLCompiler directly (v2)
* to_block_literal and REQUEST_TYPE_COMPILE
* Rewrite command encoding
* Revert to_block_literal
* Maybe that's more readable to some people?
* Typo and comment about stdlib caching
* Update ops_metal.py
* Update ops_metal.py
* Update ops_metal.py

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
2
test/external/external_metal_compile_fail.py
vendored
2
test/external/external_metal_compile_fail.py
vendored
@@ -225,6 +225,6 @@ from tinygrad.runtime.ops_metal import MetalDevice, MetalCompiler, MetalProgram
|
||||
|
||||
if __name__ == "__main__":
|
||||
dev = MetalDevice("METAL")
|
||||
lib = MetalCompiler(dev).compile(src)
|
||||
lib = MetalCompiler().compile(src)
|
||||
prg = MetalProgram(dev, "r_64_32_8_16_4_6_6_4", lib)
|
||||
|
||||
|
||||
@@ -10,14 +10,12 @@ class TestMetal(unittest.TestCase):
|
||||
device.allocator.alloc(10000000000000000000)
|
||||
|
||||
def test_compile_error(self):
|
||||
device = MetalDevice("metal")
|
||||
compiler = MetalCompiler(device)
|
||||
compiler = MetalCompiler()
|
||||
with self.assertRaises(CompileError):
|
||||
compiler.compile("this is not valid metal")
|
||||
|
||||
def test_compile_success(self):
|
||||
device = MetalDevice("metal")
|
||||
compiler = MetalCompiler(device)
|
||||
compiler = MetalCompiler()
|
||||
ret = compiler.compile("""
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
@@ -41,7 +39,7 @@ class TestMetal(unittest.TestCase):
|
||||
|
||||
def test_failed_newLibraryWithData(self):
|
||||
device = MetalDevice("metal")
|
||||
compiler = MetalCompiler(device)
|
||||
compiler = MetalCompiler()
|
||||
compiled = compiler.compile("""
|
||||
#include <metal_stdlib>
|
||||
kernel void r_5(device int* data0, const device int* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]){
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import os, subprocess, pathlib, ctypes, tempfile, functools
|
||||
from typing import List, Any, Tuple, Optional, cast
|
||||
from tinygrad.helpers import prod, getenv, T
|
||||
import os, pathlib, struct, ctypes, tempfile, functools
|
||||
from typing import List, Any, Union, Tuple, cast
|
||||
from tinygrad.helpers import prod, to_mv, getenv, round_up, _cache_dir, T
|
||||
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
|
||||
from tinygrad.renderer.cstyle import MetalRenderer
|
||||
|
||||
@@ -22,14 +22,19 @@ class MTLResourceOptions:
|
||||
class MTLPipelineOption:
|
||||
MTLPipelineOptionNone = 0
|
||||
|
||||
# 13 is the requestType that Metal uses to compile source code into MTLB; there are no docs or symbols for it.
|
||||
REQUEST_TYPE_COMPILE = 13
|
||||
|
||||
libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
|
||||
libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
|
||||
compiler = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler")
|
||||
# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
|
||||
ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
|
||||
libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac
|
||||
libobjc.objc_getClass.restype = objc_id
|
||||
libobjc.sel_registerName.restype = objc_id
|
||||
libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
|
||||
compiler.MTLCodeGenServiceCreate.restype = ctypes.c_void_p
|
||||
libdispatch.dispatch_data_create.restype = objc_instance
|
||||
|
||||
# Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12
|
||||
@@ -65,20 +70,35 @@ def metal_src_to_library(device:MetalDevice, src:str) -> objc_instance:
|
||||
return library
|
||||
|
||||
class MetalCompiler(Compiler):
|
||||
def __init__(self, dev:Optional[MetalDevice]=None):
|
||||
self.dev = dev
|
||||
super().__init__("compile_metal_xcode" if self.dev is None else "compile_metal")
|
||||
def __init__(self):
|
||||
self.cgs = ctypes.c_void_p(compiler.MTLCodeGenServiceCreate(b"tinygrad"))
|
||||
super().__init__("compile_metal_direct")
|
||||
def __reduce__(self): return (MetalCompiler,()) # force pickle to create new instance for each multiprocessing fork
|
||||
def compile(self, src:str) -> bytes:
|
||||
if self.dev is None:
|
||||
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
|
||||
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=src.encode('utf-8'))
|
||||
lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
|
||||
else:
|
||||
library = metal_src_to_library(self.dev, src)
|
||||
library_contents = msg(library, "libraryDataContents", restype=objc_instance)
|
||||
lib = ctypes.string_at(msg(library_contents, "bytes"), cast(int, msg(library_contents, "length", restype=ctypes.c_ulong)))
|
||||
assert lib[:4] == b"MTLB", "Invalid Metal library. Using conda? Corrupt XCode?"
|
||||
return lib
|
||||
ret: Union[Exception, bytes] = CompileError("MTLCodeGenServiceBuildRequest returned without calling the callback")
|
||||
@ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_int32, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_char_p)
|
||||
def callback(blockptr, error, dataPtr, dataLen, errorMessage):
|
||||
nonlocal ret
|
||||
if error == 0:
|
||||
reply = bytes(to_mv(dataPtr, dataLen))
|
||||
# offset from beginning to data = header size + warning size
|
||||
ret = reply[sum(struct.unpack('<LL', reply[8:16])):]
|
||||
else:
|
||||
ret = CompileError(errorMessage.decode())
|
||||
# llvm will create modules.timestamp in cache path and cache compilation of metal stdlib (250ms => 8ms compilation time)
|
||||
# note that llvm won't necessarily create anything else here as apple has prebuilt versions of many standard libraries
|
||||
params = f'-fno-fast-math -std=metal3.1 --driver-mode=metal -x metal -fmodules-cache-path="{os.path.join(_cache_dir, "tinygrad")}"'
|
||||
# source blob has to be padded to a multiple of 4, with at least one b'\x00' added; params blob just has to be null terminated
|
||||
src_padded, params_padded = src.encode() + b'\x00'*(round_up(len(src) + 1, 4) - len(src)), params.encode() + b'\x00'
|
||||
request = struct.pack('<QQ', len(src_padded), len(params_padded)) + src_padded + params_padded
|
||||
# The callback is actually not a callback but a block, which is Apple's non-standard extension that adds closures to C.
|
||||
# See https://clang.llvm.org/docs/Block-ABI-Apple.html#high-level for struct layout.
|
||||
# Fields other than invoke are unused in this case so we can just use ctypes.byref with negative offset to invoke field, add blockptr as a first
|
||||
# argument and pretend it's a normal callback
|
||||
compiler.MTLCodeGenServiceBuildRequest(self.cgs, None, REQUEST_TYPE_COMPILE, request, len(request), ctypes.byref(callback, -0x10))
|
||||
if isinstance(ret, Exception): raise ret
|
||||
assert ret[:4] == b"MTLB" and ret[-4:] == b"ENDT", f"Invalid Metal library. {ret!r}"
|
||||
return ret
|
||||
def disassemble(self, lib:bytes):
|
||||
with tempfile.NamedTemporaryFile(delete=True) as shader:
|
||||
shader.write(lib)
|
||||
@@ -176,7 +196,7 @@ class MetalDevice(Compiled):
|
||||
self.timeline_value = 0
|
||||
|
||||
from tinygrad.runtime.graph.metal import MetalGraph
|
||||
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_XCODE") else Compiler(),
|
||||
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_DIRECT", 1) else Compiler(),
|
||||
functools.partial(MetalProgram, self), MetalGraph)
|
||||
def synchronize(self):
|
||||
for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
|
||||
|
||||
Reference in New Issue
Block a user