mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Use MTLCompiler directly (v2) (#7920)
* Use MTLCompiler directly (v2) * to_block_literal and REQUEST_TYPE_COMPILE * Rewrite command encoding * Revert to_block_literal * Maybe that's more readable to some people? * Typo and comment about stdlib caching * Update ops_metal.py * Update ops_metal.py * Update ops_metal.py --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
2
test/external/external_metal_compile_fail.py
vendored
2
test/external/external_metal_compile_fail.py
vendored
@@ -225,6 +225,6 @@ from tinygrad.runtime.ops_metal import MetalDevice, MetalCompiler, MetalProgram
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
dev = MetalDevice("METAL")
|
dev = MetalDevice("METAL")
|
||||||
lib = MetalCompiler(dev).compile(src)
|
lib = MetalCompiler().compile(src)
|
||||||
prg = MetalProgram(dev, "r_64_32_8_16_4_6_6_4", lib)
|
prg = MetalProgram(dev, "r_64_32_8_16_4_6_6_4", lib)
|
||||||
|
|
||||||
|
|||||||
@@ -10,14 +10,12 @@ class TestMetal(unittest.TestCase):
|
|||||||
device.allocator.alloc(10000000000000000000)
|
device.allocator.alloc(10000000000000000000)
|
||||||
|
|
||||||
def test_compile_error(self):
|
def test_compile_error(self):
|
||||||
device = MetalDevice("metal")
|
compiler = MetalCompiler()
|
||||||
compiler = MetalCompiler(device)
|
|
||||||
with self.assertRaises(CompileError):
|
with self.assertRaises(CompileError):
|
||||||
compiler.compile("this is not valid metal")
|
compiler.compile("this is not valid metal")
|
||||||
|
|
||||||
def test_compile_success(self):
|
def test_compile_success(self):
|
||||||
device = MetalDevice("metal")
|
compiler = MetalCompiler()
|
||||||
compiler = MetalCompiler(device)
|
|
||||||
ret = compiler.compile("""
|
ret = compiler.compile("""
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
@@ -41,7 +39,7 @@ class TestMetal(unittest.TestCase):
|
|||||||
|
|
||||||
def test_failed_newLibraryWithData(self):
|
def test_failed_newLibraryWithData(self):
|
||||||
device = MetalDevice("metal")
|
device = MetalDevice("metal")
|
||||||
compiler = MetalCompiler(device)
|
compiler = MetalCompiler()
|
||||||
compiled = compiler.compile("""
|
compiled = compiler.compile("""
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
kernel void r_5(device int* data0, const device int* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]){
|
kernel void r_5(device int* data0, const device int* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]){
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import os, subprocess, pathlib, ctypes, tempfile, functools
|
import os, pathlib, struct, ctypes, tempfile, functools
|
||||||
from typing import List, Any, Tuple, Optional, cast
|
from typing import List, Any, Union, Tuple, cast
|
||||||
from tinygrad.helpers import prod, getenv, T
|
from tinygrad.helpers import prod, to_mv, getenv, round_up, _cache_dir, T
|
||||||
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
|
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
|
||||||
from tinygrad.renderer.cstyle import MetalRenderer
|
from tinygrad.renderer.cstyle import MetalRenderer
|
||||||
|
|
||||||
@@ -22,14 +22,19 @@ class MTLResourceOptions:
|
|||||||
class MTLPipelineOption:
|
class MTLPipelineOption:
|
||||||
MTLPipelineOptionNone = 0
|
MTLPipelineOptionNone = 0
|
||||||
|
|
||||||
|
# 13 is requestType that metal uses to compile source code into MTLB, there aren't any docs or symbols.
|
||||||
|
REQUEST_TYPE_COMPILE = 13
|
||||||
|
|
||||||
libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
|
libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
|
||||||
libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
|
libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
|
||||||
|
compiler = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler")
|
||||||
# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
|
# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
|
||||||
ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
|
ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
|
||||||
libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac
|
libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac
|
||||||
libobjc.objc_getClass.restype = objc_id
|
libobjc.objc_getClass.restype = objc_id
|
||||||
libobjc.sel_registerName.restype = objc_id
|
libobjc.sel_registerName.restype = objc_id
|
||||||
libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
|
libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
|
||||||
|
compiler.MTLCodeGenServiceCreate.restype = ctypes.c_void_p
|
||||||
libdispatch.dispatch_data_create.restype = objc_instance
|
libdispatch.dispatch_data_create.restype = objc_instance
|
||||||
|
|
||||||
# Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12
|
# Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12
|
||||||
@@ -65,20 +70,35 @@ def metal_src_to_library(device:MetalDevice, src:str) -> objc_instance:
|
|||||||
return library
|
return library
|
||||||
|
|
||||||
class MetalCompiler(Compiler):
|
class MetalCompiler(Compiler):
|
||||||
def __init__(self, dev:Optional[MetalDevice]=None):
|
def __init__(self):
|
||||||
self.dev = dev
|
self.cgs = ctypes.c_void_p(compiler.MTLCodeGenServiceCreate(b"tinygrad"))
|
||||||
super().__init__("compile_metal_xcode" if self.dev is None else "compile_metal")
|
super().__init__("compile_metal_direct")
|
||||||
|
def __reduce__(self): return (MetalCompiler,()) # force pickle to create new instance for each multiprocessing fork
|
||||||
def compile(self, src:str) -> bytes:
|
def compile(self, src:str) -> bytes:
|
||||||
if self.dev is None:
|
ret: Union[Exception, bytes] = CompileError("MTLCodeGenServiceBuildRequest returned without calling the callback")
|
||||||
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
|
@ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_int32, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_char_p)
|
||||||
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=src.encode('utf-8'))
|
def callback(blockptr, error, dataPtr, dataLen, errorMessage):
|
||||||
lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
|
nonlocal ret
|
||||||
else:
|
if error == 0:
|
||||||
library = metal_src_to_library(self.dev, src)
|
reply = bytes(to_mv(dataPtr, dataLen))
|
||||||
library_contents = msg(library, "libraryDataContents", restype=objc_instance)
|
# offset from beginning to data = header size + warning size
|
||||||
lib = ctypes.string_at(msg(library_contents, "bytes"), cast(int, msg(library_contents, "length", restype=ctypes.c_ulong)))
|
ret = reply[sum(struct.unpack('<LL', reply[8:16])):]
|
||||||
assert lib[:4] == b"MTLB", "Invalid Metal library. Using conda? Corrupt XCode?"
|
else:
|
||||||
return lib
|
ret = CompileError(errorMessage.decode())
|
||||||
|
# llvm will create modules.timestamp in cache path and cache compilation of metal stdlib (250ms => 8ms compilation time)
|
||||||
|
# note that llvm won't necessarily create anything else here as apple has prebuilt versions of many standard libraries
|
||||||
|
params = f'-fno-fast-math -std=metal3.1 --driver-mode=metal -x metal -fmodules-cache-path="{os.path.join(_cache_dir, "tinygrad")}"'
|
||||||
|
# source blob has to be padded to multiple of 4 but at least one 'b\x00' should be added, params blob just has to be null terminated
|
||||||
|
src_padded, params_padded = src.encode() + b'\x00'*(round_up(len(src) + 1, 4) - len(src)), params.encode() + b'\x00'
|
||||||
|
request = struct.pack('<QQ', len(src_padded), len(params_padded)) + src_padded + params_padded
|
||||||
|
# The callback is actually not a callback but a block which is apple's non-standard extension to add closures to C.
|
||||||
|
# See https://clang.llvm.org/docs/Block-ABI-Apple.html#high-level for struct layout.
|
||||||
|
# Fields other than invoke are unused in this case so we can just use ctypes.byref with negative offset to invoke field, add blockptr as a first
|
||||||
|
# argument and pretend it's a normal callback
|
||||||
|
compiler.MTLCodeGenServiceBuildRequest(self.cgs, None, REQUEST_TYPE_COMPILE, request, len(request), ctypes.byref(callback, -0x10))
|
||||||
|
if isinstance(ret, Exception): raise ret
|
||||||
|
assert ret[:4] == b"MTLB" and ret[-4:] == b"ENDT", f"Invalid Metal library. {ret!r}"
|
||||||
|
return ret
|
||||||
def disassemble(self, lib:bytes):
|
def disassemble(self, lib:bytes):
|
||||||
with tempfile.NamedTemporaryFile(delete=True) as shader:
|
with tempfile.NamedTemporaryFile(delete=True) as shader:
|
||||||
shader.write(lib)
|
shader.write(lib)
|
||||||
@@ -176,7 +196,7 @@ class MetalDevice(Compiled):
|
|||||||
self.timeline_value = 0
|
self.timeline_value = 0
|
||||||
|
|
||||||
from tinygrad.runtime.graph.metal import MetalGraph
|
from tinygrad.runtime.graph.metal import MetalGraph
|
||||||
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_XCODE") else Compiler(),
|
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_DIRECT", 1) else Compiler(),
|
||||||
functools.partial(MetalProgram, self), MetalGraph)
|
functools.partial(MetalProgram, self), MetalGraph)
|
||||||
def synchronize(self):
|
def synchronize(self):
|
||||||
for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
|
for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
|
||||||
|
|||||||
Reference in New Issue
Block a user