In-tree autogen: objective c (#13223)

* checkout changes from autogen branch

* move assert

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Christopher Milan
2025-11-14 17:08:42 -05:00
committed by GitHub
parent e5351699bd
commit d1bb08c5a1
8 changed files with 4726 additions and 142 deletions

View File

@@ -130,6 +130,28 @@ jobs:
mv tinygrad/runtime/autogen/mesa.py /tmp/mesa.py.bak
python3 -c "from tinygrad.runtime.autogen import mesa"
diff /tmp/mesa.py.bak tinygrad/runtime/autogen/mesa.py
- name: Verify libclang autogen
run: |
mv tinygrad/runtime/autogen/libclang.py /tmp/libclang.py.bak
python3 -c "from tinygrad.runtime.autogen import libclang"
diff /tmp/libclang.py.bak tinygrad/runtime/autogen/libclang.py
autogen-mac:
name: In-tree Autogen (macos)
runs-on: macos-14
timeout-minutes: 15
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
llvm: 'true'
pydeps: 'clang>=20'
- name: Verify macos autogen
run: |
mv tinygrad/runtime/autogen/metal.py /tmp/metal.py.bak
LIBCLANG_PATH=/opt/homebrew/opt/llvm@20/lib/libclang.dylib python3 -c "from tinygrad.runtime.autogen import metal"
diff /tmp/metal.py.bak tinygrad/runtime/autogen/metal.py
autogen-comgr-3:
name: In-tree Autogen (comgr 3)
runs-on: ubuntu-24.04

View File

@@ -4,6 +4,7 @@ from tinygrad.helpers import fetch, flatten, system
root = (here:=pathlib.Path(__file__).parent).parents[2]
nv_src = {"nv_570": "https://github.com/NVIDIA/open-gpu-kernel-modules/archive/81fe4fb417c8ac3b9bdcc1d56827d116743892a5.tar.gz",
"nv_580": "https://github.com/NVIDIA/open-gpu-kernel-modules/archive/2af9f1f0f7de4988432d4ae875b5858ffdb09cc2.tar.gz"}
macossdk = "/var/db/xcode_select_link/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk"
def load(name, dll, files, **kwargs):
if not (f:=(root/(path:=kwargs.pop("path", __name__)).replace('.','/')/f"{name}.py")).exists():
@@ -120,4 +121,12 @@ python3 src/compiler/nir/nir_intrinsics_h.py --outdir gen
python3 src/compiler/builtin_types_h.py gen/builtin_types.h""", cwd=path, shell=True, check=True),
tarball="https://gitlab.freedesktop.org/mesa/mesa/-/archive/mesa-25.2.4/mesa-25.2.4.tar.gz",
prolog=["import gzip, base64", "from tinygrad.helpers import OSX"], epilog=lambda path: [system(f"{root}/extra/mesa/lvp_nir_options.sh {path}")])
case "libclang":
return load("libclang", ["os.getenv('LIBCLANG_PATH', find_library('clang-20'))"],
lambda: [system("llvm-config-20 --includedir")+"/clang-c/Index.h"], args=lambda: system("llvm-config-20 --cflags").split(),
types={"CXString":"ci._CXString","CXType":"ci.Type","CXCursor":"ci.Cursor"}, prolog=["import clang.cindex as ci"])
case "metal":
return load("metal", ["find_library('Metal')"],[f"{macossdk}/System/Library/Frameworks/Metal.framework/Headers/MTL{s}.h" for s in
["ComputeCommandEncoder", "ComputePipeline", "CommandQueue", "Device", "IndirectCommandBuffer", "Resource", "CommandEncoder"]],
args=["-xobjective-c","-isysroot",macossdk], types={"dispatch_data_t":"objc.id_"})
case _: raise AttributeError(f"no such autogen: {nm}")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -5,15 +5,9 @@ from tinygrad.helpers import dedup, getenv, merge_dicts, PROFILE
from tinygrad.device import Buffer, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
MTLResourceOptions, cmdbuf_st_time, cmdbuf_en_time, objc_id, to_ns_str
class MTLIndirectCommandType:
MTLIndirectCommandTypeConcurrentDispatch = (1 << 5)
class MTLResourceUsage:
MTLResourceUsageRead = 0b01
MTLResourceUsageWrite = 0b10
from tinygrad.runtime.ops_metal import wait_check, to_ns_str
from tinygrad.runtime.autogen import metal
from tinygrad.runtime.support import objc
class MetalGraph(GraphRunner):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
@@ -21,16 +15,17 @@ class MetalGraph(GraphRunner):
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
# create metal batch exec
icb_descriptor = msg("new", objc_instance)(libobjc.objc_getClass(b"MTLIndirectCommandBufferDescriptor"))
msg("setCommandTypes:")(icb_descriptor, MTLIndirectCommandType.MTLIndirectCommandTypeConcurrentDispatch)
msg("setInheritBuffers:")(icb_descriptor, False)
msg("setInheritPipelineState:")(icb_descriptor, False)
msg("setMaxKernelBufferBindCount:")(icb_descriptor, 31)
icb_descriptor = metal.MTLIndirectCommandBufferDescriptor.new()
icb_descriptor.setCommandTypes(metal.MTLIndirectCommandTypeConcurrentDispatch)
icb_descriptor.setInheritBuffers(False)
icb_descriptor.setInheritPipelineState(False)
icb_descriptor.setMaxKernelBufferBindCount(31)
self.icb = msg("newIndirectCommandBufferWithDescriptor:maxCommandCount:options:", objc_instance)(self.dev.sysdevice,
icb_descriptor, len(jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache)
self.icb = self.dev.sysdevice.newIndirectCommandBufferWithDescriptor_maxCommandCount_options(icb_descriptor, len(jit_cache),
metal.MTLResourceCPUCacheModeDefaultCache)
if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
icb_label = bytes(msg("UTF8String", ctypes.c_char_p)(msg("description", objc_instance)(self.icb))).decode()
# TODO: needs categories
icb_label = bytes(objc.msg("UTF8String", ctypes.c_char_p)(objc.msg("description")(self.icb).retained())).decode()
self.needs_icb_fix = int((m := re.search(r'AGXG(\d+)XFamily', icb_label)) is None or int(m.group(1)) < 15) # not required on M3+
self.fixedvars = merge_dicts([ji.fixedvars for ji in jit_cache])
@@ -40,26 +35,25 @@ class MetalGraph(GraphRunner):
all_pipelines, all_resources = [], [self.int_buf.buf] if len(self.varlist) else []
for j,ji in enumerate(jit_cache):
prg: CompiledRunner = cast(CompiledRunner, ji.prg)
icb_command = msg("indirectComputeCommandAtIndex:", objc_instance)(self.icb, j)
icb_command = self.icb.indirectComputeCommandAtIndex(j).retained()
all_pipelines.append(prg._prg.pipeline_state)
msg("setComputePipelineState:")(icb_command, prg._prg.pipeline_state)
icb_command.setComputePipelineState(prg._prg.pipeline_state)
for i,b in enumerate(ji.bufs):
if b is not None and b not in input_rawbuffers:
msg("setKernelBuffer:offset:atIndex:")(icb_command, b._buf.buf, b._buf.offset, i)
icb_command.setKernelBuffer_offset_atIndex(b._buf.buf, b._buf.offset, i)
all_resources.append(b._buf.buf)
for i,v in enumerate(prg.p.vars):
msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
global_size, local_size = prg.p.launch_dims(var_vals)
msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(icb_command, to_struct(*global_size), to_struct(*local_size))
msg("setBarrier")(icb_command)
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(*global_size), metal.MTLSize(*local_size))
icb_command.setBarrier()
self.all_resources = dedup(all_resources)
self.all_pipelines = dedup(all_pipelines)
self.command_buffer: Any = None
if len(self.varlist): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var]
self.range = to_struct(0, len(jit_cache))
self.range = metal.NSRange(0, len(jit_cache))
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
@@ -68,18 +62,18 @@ class MetalGraph(GraphRunner):
all_resources = dedup(self.all_resources + [input_rawbuffers[input_idx]._buf.buf for input_idx in self.input_replace.values()])
for (j,i),input_idx in self.input_replace.items():
computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
msg("setKernelBuffer:offset:atIndex:")(computeCommand, input_rawbuffers[input_idx]._buf.buf, input_rawbuffers[input_idx]._buf.offset, i)
computeCommand = self.icb.indirectComputeCommandAtIndex(j)
computeCommand.setKernelBuffer_offset_atIndex(input_rawbuffers[input_idx]._buf.buf, input_rawbuffers[input_idx]._buf.offset, i)
for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(computeCommand, to_struct(*global_dims), to_struct(*local_dims))
computeCommand = self.icb.indirectComputeCommandAtIndex(j)
computeCommand.concurrentDispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(*global_dims), metal.MTLSize(*local_dims))
for var in self.vars: self.int_buf_view[self.varlist.index(var)] = var_vals[var]
command_buffer = msg("commandBuffer", objc_instance)(self.dev.mtl_queue)
encoder = msg("computeCommandEncoder", objc_instance)(command_buffer)
msg("useResources:count:usage:")(encoder, (objc_id * len(all_resources))(*all_resources), len(all_resources),
MTLResourceUsage.MTLResourceUsageRead | MTLResourceUsage.MTLResourceUsageWrite)
command_buffer = self.dev.mtl_queue.commandBuffer().retained()
encoder = command_buffer.computeCommandEncoder().retained()
encoder.useResources_count_usage(ctypes.cast((metal.MTLBuffer * len(all_resources))(*all_resources), ctypes.POINTER(metal.MTLResource)),
len(all_resources), metal.MTLResourceUsageRead | metal.MTLResourceUsageWrite)
# NOTE: the pipelines likely need to be added to the used resources to fix the crash on M1/M2, but I haven't figured out how
# this is a O(n) hack to get them used. what should work is:
@@ -88,24 +82,24 @@ class MetalGraph(GraphRunner):
# to repro the crash (which can also crash other running GPU apps), run with FIX_METAL_ICB=0
if getenv("FIX_METAL_ICB", self.needs_icb_fix):
for ps in self.all_pipelines:
msg("setComputePipelineState:")(encoder, ps)
msg("dispatchThreadgroups:threadsPerThreadgroup:")(encoder, to_struct(0,0,0), to_struct(0,0,0))
encoder.setComputePipelineState(ps)
encoder.dispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(0,0,0), metal.MTLSize(0,0,0))
msg("executeCommandsInBuffer:withRange:")(encoder, self.icb, self.range)
msg("endEncoding")(encoder)
msg("setLabel:")(command_buffer, to_ns_str(f"batched {len(self.jit_cache)}"))
msg("commit")(command_buffer)
encoder.executeCommandsInBuffer_withRange(self.icb, self.range)
encoder.endEncoding()
command_buffer.setLabel(to_ns_str(f"batched {len(self.jit_cache)}"))
command_buffer.commit()
self.command_buffer = command_buffer
self.dev.mtl_buffers_in_flight.append(command_buffer)
if wait:
wait_check(command_buffer)
return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
return None
def collect_timestamps(self):
# create a graph event and evenly space each program
st, en = decimal.Decimal(cmdbuf_st_time(self.command_buffer)) * 1000000, decimal.Decimal(cmdbuf_en_time(self.command_buffer)) * 1000000
st, en = decimal.Decimal(self.command_buffer.GPUStartTime()) * 1000000, decimal.Decimal(self.command_buffer.GPUEndTime()) * 1000000
ents = [ProfileGraphEntry(self.device, cast(CompiledRunner, ji.prg)._prg.name, i, i+1, is_copy=False) for i,ji in enumerate(self.jit_cache)]
step = (en-st)/len(ents)
self.dev.profile_events += [ProfileGraphEvent(ents, [], [st+step*i for i in range(len(ents)+1)])]

View File

@@ -1,74 +1,42 @@
import subprocess, pathlib, struct, ctypes, tempfile, functools, contextlib, decimal, platform
from typing import Any, cast
from tinygrad.helpers import prod, to_mv, getenv, round_up, cache_dir, T, init_c_struct_t, PROFILE, ProfileRangeEvent, cpu_profile, unwrap
import subprocess, pathlib, struct, ctypes, tempfile, functools, contextlib, decimal, platform, sys
from tinygrad.helpers import prod, to_mv, getenv, round_up, cache_dir, init_c_struct_t, PROFILE, ProfileRangeEvent, cpu_profile, unwrap
import tinygrad.runtime.support.objc as objc
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator, ProfileDeviceEvent
from tinygrad.renderer.cstyle import MetalRenderer
class objc_id(ctypes.c_void_p): # This prevents ctypes from converting response to plain int, and dict.fromkeys() can use it to dedup
  """Opaque Objective-C object pointer; hashed/compared by pointer value so it can be deduped."""
  def __hash__(self): return hash(self.value)
  def __eq__(self, other): return self.value == other.value
class objc_instance(objc_id): # method with name "new", "alloc" should be freed after use
  """objc_id that owns its reference: sends "release" when the Python wrapper is collected."""
  def __del__(self):
    # CPython doesn't make any guarantees about order in which globals (like `msg` or `libobjc`) are destroyed when the interpreter shuts down
    # https://github.com/tinygrad/tinygrad/pull/8949 triggered the unlucky ordering which lead to a bunch of errors at exit
    # TODO: Why isn't `sys.is_finalizing` working?
    if msg is not None and libobjc is not None: msg("release")(self)
class MTLResourceOptions:
  # Hand-mirrored Metal MTLResourceOptions constants (avoids needing the framework headers).
  MTLResourceCPUCacheModeDefaultCache = 0
  MTLResourceStorageModeShared = 0 << 4
class MTLPipelineOption:
  # Hand-mirrored Metal MTLPipelineOption constant.
  MTLPipelineOptionNone = 0
from tinygrad.runtime.autogen import metal
# 13 is requestType that metal uses to compile source code into MTLB, there aren't any docs or symbols.
REQUEST_TYPE_COMPILE = 13
libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac
libobjc.objc_getClass.restype = objc_id
libobjc.sel_registerName.restype = objc_id
libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
libdispatch.dispatch_data_create.restype = objc_instance
# FIXME: these need autogen to support objc categories
# https://developer.apple.com/library/archive/documentation/Cocoa/Conceptual/ObjectiveC/Chapters/ocCategories.html
@functools.cache
def msg(selector: str, restype: type[T] = objc_id): # type: ignore [assignment]
  """Return a cached wrapper that sends `selector` through objc_msgSend with the given ctypes restype."""
  resname = libobjc.sel_registerName(selector.encode())
  sender = libobjc["objc_msgSend"] # NOTE: item access (unlike attribute access) returns a fresh function object, so setting restype here is safe
  sender.restype = restype
  def _msg(ptr: objc_id, *args: Any) -> T: return sender(ptr, resname, *args)
  return _msg
@functools.cache
def to_ns_str(s: str): return msg("stringWithUTF8String:", objc_instance)(libobjc.objc_getClass(b"NSString"), s.encode())
def from_ns_str(s): return bytes(msg("UTF8String", ctypes.c_char_p)(s)).decode()
def to_ns_str(s: str): return ctypes.cast(objc.msg("stringWithUTF8String:")(metal.NSString._objc_class_, s.encode()), metal.NSString)
def from_ns_str(s): return bytes(objc.msg("UTF8String", ctypes.c_char_p)(s)).decode()
def to_struct(*t: int, _type: type[ctypes._SimpleCData] = ctypes.c_ulong):
  """Pack the given ints into an anonymous ctypes struct with fields named field0..field{n-1}."""
  field_spec = tuple((f"field{idx}", _type) for idx in range(len(t)))
  return init_c_struct_t(field_spec)(*t)
def wait_check(cbuf: Any):
msg("waitUntilCompleted")(cbuf)
error_check(msg("error", objc_instance)(cbuf))
def wait_check(cbuf:metal.MTLCommandBuffer):
cbuf.waitUntilCompleted()
error_check(cbuf.error().retained())
def cmdbuf_label(cbuf: objc_id) -> str|None: return from_ns_str(label) if (label:=msg("label", objc_id)(cbuf)).value is not None else None
# GPU start/end timestamps of a command buffer (returned as ctypes.c_double; presumably seconds,
# since callers scale by 1e6 for microsecond profiling — confirm against Metal's GPUStartTime docs).
def cmdbuf_st_time(cbuf: objc_id) -> float: return cast(float, msg("GPUStartTime", ctypes.c_double)(cbuf))
def cmdbuf_en_time(cbuf: objc_id) -> float: return cast(float, msg("GPUEndTime", ctypes.c_double)(cbuf))
def cmdbuf_label(cbuf:metal.MTLCommandBuffer) -> str|None: return from_ns_str(label) if (label:=cbuf.label()).value is not None else None
def error_check(error: objc_instance, error_constructor: type[Exception] = RuntimeError):
def error_check(error: metal.NSError, error_constructor: type[Exception] = RuntimeError):
if error.value is None: return None
raise error_constructor(from_ns_str(msg("localizedDescription", objc_instance)(error)))
raise error_constructor(from_ns_str(error.localizedDescription().retained()))
class MetalDevice(Compiled):
def __init__(self, device:str):
self.sysdevice = libmetal.MTLCreateSystemDefaultDevice()
self.mtl_queue = msg("newCommandQueueWithMaxCommandBufferCount:", objc_instance)(self.sysdevice, 1024)
self.sysdevice = metal.MTLCreateSystemDefaultDevice()
self.mtl_queue = self.sysdevice.newCommandQueueWithMaxCommandBufferCount(1024)
if self.mtl_queue is None: raise RuntimeError("Cannot allocate a new command queue")
self.mtl_buffers_in_flight: list[Any] = []
self.timeline_signal = msg("newSharedEvent", objc_instance)(self.sysdevice)
self.mtl_buffers_in_flight: list[metal.MTLCommandBuffer] = []
self.timeline_signal = self.sysdevice.newSharedEvent()
self.timeline_value = 0
Compiled.profile_events += [ProfileDeviceEvent(device)]
@@ -76,23 +44,22 @@ class MetalDevice(Compiled):
from tinygrad.runtime.graph.metal import MetalGraph
# NOTE: GitHub CI macOS runners use paravirtualized metal which is broken with graph.
# This can be reproduced locally with any virtualization software (like utm) that can create macOS VMs with apple's own virtualization framework.
super().__init__(device, MetalAllocator(self), [(MetalRenderer, MetalCompiler), (MetalRenderer, Compiler)],
functools.partial(MetalProgram, self), MetalGraph if 'virtual' not in from_ns_str(msg('name')(self.sysdevice)).lower() else None)
super().__init__(device, MetalAllocator(self), [(MetalRenderer, MetalCompiler), (MetalRenderer, Compiler)], functools.partial(MetalProgram, self),
MetalGraph if 'virtual' not in from_ns_str(self.sysdevice.name()).lower() else None)
def synchronize(self):
for cbuf in self.mtl_buffers_in_flight:
wait_check(cbuf)
st, en = decimal.Decimal(cmdbuf_st_time(cbuf)) * 1000000, decimal.Decimal(cmdbuf_en_time(cbuf)) * 1000000
st, en = decimal.Decimal(cbuf.GPUStartTime()) * 1000000, decimal.Decimal(cbuf.GPUEndTime()) * 1000000
# NOTE: command buffers from MetalGraph are not profiled here
if PROFILE and (lb:=cmdbuf_label(cbuf)) is not None and not lb.startswith("batched"):
Compiled.profile_events += [ProfileRangeEvent(self.device, lb, st, en, is_copy=lb.startswith("COPY"))]
self.mtl_buffers_in_flight.clear()
def metal_src_to_library(device:MetalDevice, src:str) -> objc_instance:
options = msg("new", objc_instance)(libobjc.objc_getClass(b"MTLCompileOptions"))
msg("setFastMathEnabled:")(options, getenv("METAL_FAST_MATH"))
library = msg("newLibraryWithSource:options:error:", objc_instance)(device.sysdevice, to_ns_str(src),
options, ctypes.byref(compileError:=objc_instance()))
def metal_src_to_library(device:MetalDevice, src:str) -> metal.MTLLibrary:
options = metal.MTLCompileOptions.new()
options.setFastMathEnabled(getenv("METAL_FAST_MATH"))
library = device.sysdevice.newLibraryWithSource_options_error(to_ns_str(src), options, ctypes.byref(compileError:=metal.NSError().retained()))
error_check(compileError, CompileError)
return library
@@ -155,71 +122,71 @@ class MetalProgram:
self.dev, self.name, self.lib = dev, name, lib
if lib[:4] == b"MTLB":
# binary metal library
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
self.library = msg("newLibraryWithData:error:", objc_instance)(self.dev.sysdevice, data, ctypes.byref(error_lib:=objc_instance()))
data = objc.dispatch_data_create(lib, len(lib), None, None)
self.library = self.dev.sysdevice.newLibraryWithData_error(data, ctypes.byref(error_lib:=metal.NSError().retained())).retained()
error_check(error_lib)
else:
# metal source. rely on OS caching
try: self.library = metal_src_to_library(self.dev, lib.decode())
except CompileError as e: raise RuntimeError from e
self.fxn = msg("newFunctionWithName:", objc_instance)(self.library, to_ns_str(name))
descriptor = msg("new", objc_instance)(libobjc.objc_getClass(b"MTLComputePipelineDescriptor"))
msg("setComputeFunction:")(descriptor, self.fxn)
msg("setSupportIndirectCommandBuffers:")(descriptor, True)
self.pipeline_state = msg("newComputePipelineStateWithDescriptor:options:reflection:error:", objc_instance)(self.dev.sysdevice,
descriptor, MTLPipelineOption.MTLPipelineOptionNone, None, ctypes.byref(error_pipeline_creation:=objc_instance()))
self.fxn = self.library.newFunctionWithName(to_ns_str(name)).retained()
descriptor = metal.MTLComputePipelineDescriptor.new()
descriptor.setComputeFunction(self.fxn)
descriptor.setSupportIndirectCommandBuffers(True)
self.pipeline_state = self.dev.sysdevice.newComputePipelineStateWithDescriptor_options_reflection_error(descriptor, metal.MTLPipelineOptionNone,
None, ctypes.byref(error_pipeline_creation:=metal.NSError().retained()))
error_check(error_pipeline_creation)
# cache these msg calls
self.max_total_threads: int = cast(int, msg("maxTotalThreadsPerThreadgroup", ctypes.c_ulong)(self.pipeline_state))
self.max_total_threads: int = self.pipeline_state.maxTotalThreadsPerThreadgroup()
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
if prod(local_size) > self.max_total_threads:
exec_width = msg("threadExecutionWidth", ctypes.c_ulong)(self.pipeline_state)
memory_length = msg("staticThreadgroupMemoryLength", ctypes.c_ulong)(self.pipeline_state)
exec_width = self.pipeline_state.threadExecutionWidth()
memory_length = self.pipeline_state.staticThreadgroupMemoryLength()
raise RuntimeError(f"local size {local_size} bigger than {self.max_total_threads} with exec width {exec_width} memory length {memory_length}")
command_buffer = msg("commandBuffer", objc_instance)(self.dev.mtl_queue)
encoder = msg("computeCommandEncoder", objc_instance)(command_buffer)
msg("setComputePipelineState:")(encoder, self.pipeline_state)
for i,a in enumerate(bufs): msg("setBuffer:offset:atIndex:")(encoder, a.buf, a.offset, i)
for i,a in enumerate(vals, start=len(bufs)): msg("setBytes:length:atIndex:")(encoder, bytes(ctypes.c_int(a)), 4, i)
msg("dispatchThreadgroups:threadsPerThreadgroup:")(encoder, to_struct(*global_size), to_struct(*local_size))
msg("endEncoding")(encoder)
msg("setLabel:")(command_buffer, to_ns_str(self.name)) # TODO: is this always needed?
msg("commit")(command_buffer)
command_buffer = self.dev.mtl_queue.commandBuffer().retained() # FIXME: is this really ARC?
encoder = command_buffer.computeCommandEncoder().retained() # FIXME: is this really ARC?
encoder.setComputePipelineState(self.pipeline_state)
for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex(a.buf, a.offset, i)
for i,a in enumerate(vals, start=len(bufs)): encoder.setBytes_length_atIndex(bytes(ctypes.c_int(a)), 4, i)
encoder.dispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(*global_size), metal.MTLSize(*local_size))
encoder.endEncoding()
command_buffer.setLabel(to_ns_str(self.name)) # TODO: is this always needed?
command_buffer.commit()
self.dev.mtl_buffers_in_flight.append(command_buffer)
if wait:
wait_check(command_buffer)
return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
class MetalBuffer:
def __init__(self, buf:Any, size:int, offset=0): self.buf, self.size, self.offset = buf, size, offset
def __init__(self, buf:metal.MTLBuffer, size:int, offset=0): self.buf, self.size, self.offset = buf, size, offset
class MetalAllocator(LRUAllocator[MetalDevice]):
def _alloc(self, size:int, options) -> MetalBuffer:
if options.external_ptr: return MetalBuffer(objc_id(options.external_ptr), size)
if options.external_ptr: return MetalBuffer(metal.MTLBuffer(options.external_ptr), size)
# Buffer is explicitly released in _free() rather than garbage collected via reference count
ret = msg("newBufferWithLength:options:", objc_id)(self.dev.sysdevice, ctypes.c_ulong(size), MTLResourceOptions.MTLResourceStorageModeShared)
ret = self.dev.sysdevice.newBufferWithLength_options(size, metal.MTLResourceStorageModeShared)
ret.retain = False
if ret.value is None: raise MemoryError(f"Metal OOM while allocating {size=}")
return MetalBuffer(ret, size)
def _free(self, opaque:MetalBuffer, options):
  """Explicitly release the Metal buffer allocated in _alloc (its refcount is managed manually, not by GC)."""
  # Skip during interpreter shutdown, when the objc runtime globals this relies on may already be torn down.
  # FIX: `opaque.buf.release` was a bare attribute access, which never invokes release (leaking the buffer);
  # it must be called.
  if not sys.is_finalizing(): opaque.buf.release()
def _transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
dest_dev.synchronize()
src_command_buffer = msg("commandBuffer", objc_instance)(src_dev.mtl_queue)
encoder = msg("blitCommandEncoder", objc_instance)(src_command_buffer)
msg("copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size:")(encoder, src.buf, ctypes.c_ulong(src.offset),
dest.buf, ctypes.c_ulong(dest.offset), ctypes.c_ulong(sz))
msg("endEncoding")(encoder)
src_command_buffer = src_dev.mtl_queue.commandBuffer().retained()
encoder = src_command_buffer.blitCommandEncoder().retained()
encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size(src.buf, src.offset, dest.buf, dest.offset, sz)
encoder.endEncoding()
if src_dev != dest_dev:
msg("encodeSignalEvent:value:")(src_command_buffer, src_dev.timeline_signal, src_dev.timeline_value)
dest_command_buffer = msg("commandBuffer", objc_instance)(dest_dev.mtl_queue)
msg("encodeWaitForEvent:value:")(dest_command_buffer, src_dev.timeline_signal, src_dev.timeline_value)
msg("commit")(dest_command_buffer)
src_command_buffer.encodeSignalEvent_value(ctypes.cast(src_dev.timeline_signal, metal.MTLEvent), src_dev.timeline_value)
dest_command_buffer = dest_dev.mtl_queue.commandBuffer().retained()
dest_command_buffer.encodeWaitForEvent_value(ctypes.cast(src_dev.timeline_signal, metal.MTLEvent), src_dev.timeline_value)
dest_command_buffer.commit()
dest_dev.mtl_buffers_in_flight.append(dest_command_buffer)
src_dev.timeline_value += 1
msg("setLabel:")(src_command_buffer, to_ns_str(f"COPY {src_dev.device} -> {dest_dev.device}"))
msg("commit")(src_command_buffer)
src_command_buffer.setLabel(to_ns_str(f"COPY {src_dev.device} -> {dest_dev.device}"))
src_command_buffer.commit()
src_dev.mtl_buffers_in_flight.append(src_command_buffer)
# Transfers currently synchronize the completion. Otherwise, copies can sometimes lead to incorrect values.
# There is no real metal multidevice support for now, so transfer is used only for tests.
@@ -228,7 +195,7 @@ class MetalAllocator(LRUAllocator[MetalDevice]):
with cpu_profile(prof_desc, self.dev.device, is_copy=True): dst[:] = src
def _as_buffer(self, src:MetalBuffer) -> memoryview:
self.dev.synchronize()
return to_mv(cast(int, msg("contents", objc_id)(src.buf).value), src.size + src.offset)[src.offset:]
return to_mv(src.buf.contents(), src.size + src.offset)[src.offset:]
def _copyin(self, dest:MetalBuffer, src:memoryview): self._cp_mv(self._as_buffer(dest), src, "TINY -> METAL")
def _copyout(self, dest:memoryview, src:MetalBuffer): self._cp_mv(dest, self._as_buffer(src), "METAL -> TINY")
def _offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)

View File

@@ -1,9 +1,12 @@
import ctypes.util, importlib.metadata, itertools, re, functools, os
from tinygrad.helpers import flatten, unwrap
from tinygrad.helpers import flatten, unwrap, fromimport
assert importlib.metadata.version('clang')[:2] == "20", 'clang version 20 required, pip install "clang==20.1.0"'
from clang.cindex import Config, Index, CursorKind as CK, TranslationUnit as TU, LinkageKind as LK, TokenKind as ToK, TypeKind as TK
from clang.cindex import Config, Index, Cursor, Type, CursorKind as CK, TranslationUnit as TU, LinkageKind as LK, TokenKind as ToK, TypeKind as TK
from clang.cindex import PrintingPolicy as PP, PrintingPolicyProperty as PPP, SourceRange
libclang = functools.partial(fromimport, "tinygrad.runtime.autogen.libclang") # we can't actually import this, because then we can't generate it
if not Config.loaded: Config.set_library_file(os.getenv("LIBCLANG_PATH", ctypes.util.find_library("clang-20")))
def fst(c):
  """Return the first child cursor of `c` (StopIteration if it has none)."""
  child_iter = iter(c.get_children())
  return next(child_iter)
@@ -14,18 +17,25 @@ def readext(f, fst, snd=None):
return f.read((fst.end.offset if isinstance(fst, SourceRange) else snd)-start)
def attrs(c):
  """Kinds of c's children whose enum value falls in the attribute range [400, 500)."""
  child_kinds = (child.kind for child in c.get_children())
  return [k for k in child_kinds if 400 <= k.value < 500]
def protocols(t):
  """Yield the Objective-C protocol declaration cursors referenced by type `t` (via raw libclang calls)."""
  yield from (Cursor.from_result(libclang("clang_Type_getObjCProtocolDecl")(t, i), t)
              for i in range(libclang("clang_Type_getNumObjCProtocolRefs")(t)))
def basetype(t):
  """Return the base (protocol-unqualified) type of an Objective-C object type `t`."""
  return Type.from_result(libclang("clang_Type_getObjCObjectBaseType")(t), (t,))
base_rules = [(r'\s*\\\n\s*', ' '), (r'\s*\n\s*', ' '), (r'//.*', ''), (r'/\*.*?\*/', ''), (r'\b(0[xX][0-9a-fA-F]+|\d+)[uUlL]+\b', r'\1'),
(r'\b0+(?=\d)', ''), (r'\s*&&\s*', r' and '), (r'\s*\|\|\s*', r' or '), (r'\s*!\s*', ' not '),
(r'(struct|union|enum)\s*([a-zA-Z_][a-zA-Z0-9_]*\b)', r'\1_\2'),
(r'\((unsigned )?(char|uint64_t)\)', ''), (r'^.*\d+:\d+.*$', ''), (r'^.*\w##\w.*$', '')]
ints = (TK.INT, TK.UINT, TK.LONG, TK.ULONG, TK.LONGLONG, TK.ULONGLONG)
specs = (CK.OBJC_SUPER_CLASS_REF,)
# https://clang.llvm.org/docs/AutomaticReferenceCounting.html#arc-method-families
arc_families = ['alloc', 'copy', 'mutableCopy', 'new']
def gen(dll, files, args=[], prolog=[], rules=[], epilog=[], recsym=False, use_errno=False, anon_names={}, types={}, parse_macros=True):
macros, lines, anoncnt, types = [], [], itertools.count().__next__, {k:(v,True) for k,v in types.items()}
macros, lines, anoncnt, types, objc = [], [], itertools.count().__next__, {k:(v,True) for k,v in types.items()}, False
def tname(t, suggested_name=None, typedef=None) -> str:
suggested_name = anon_names.get(f"{(decl:=t.get_declaration()).location.file}:{decl.location.line}", suggested_name)
nonlocal lines, types, anoncnt
nonlocal lines, types, anoncnt, objc
tmap = {TK.VOID:"None", TK.CHAR_U:"ctypes.c_ubyte", TK.UCHAR:"ctypes.c_ubyte", TK.CHAR_S:"ctypes.c_char", TK.SCHAR:"ctypes.c_char",
**{getattr(TK, k):f"ctypes.c_{k.lower()}" for k in ["BOOL", "WCHAR", "FLOAT", "DOUBLE", "LONGDOUBLE"]},
**{getattr(TK, k):f"ctypes.c_{'u' if 'U' in k else ''}int{sz}" for sz,k in
@@ -37,6 +47,7 @@ def gen(dll, files, args=[], prolog=[], rules=[], epilog=[], recsym=False, use_e
return f"ctypes.CFUNCTYPE({tname(f.get_result())}{(', '+', '.join(map(tname, f.argument_types()))) if f.kind==TK.FUNCTIONPROTO else ''})"
match t.kind:
case TK.POINTER: return "ctypes.c_void_p" if (ptr:=t.get_pointee()).kind == TK.VOID else f"ctypes.POINTER({tname(ptr)})"
case TK.OBJCOBJECTPOINTER: return tname(t.get_pointee()) # TODO: this seems wrong
case TK.ELABORATED: return tname(t.get_named_type(), suggested_name)
case TK.TYPEDEF if t.spelling == t.get_canonical().spelling: return tname(t.get_canonical())
case TK.TYPEDEF:
@@ -76,8 +87,62 @@ def gen(dll, files, args=[], prolog=[], rules=[], epilog=[], recsym=False, use_e
case TK.CONSTANTARRAY:
return f"({tname(t.get_array_element_type(), suggested_name.rstrip('s') if suggested_name else None)} * {t.get_array_size()})"
case TK.INCOMPLETEARRAY: return f"({tname(t.get_array_element_type(), suggested_name.rstrip('s') if suggested_name else None)} * 0)"
case TK.OBJCINTERFACE:
is_defn = bool([f.kind for f in decl.get_children() if f.kind in (CK.OBJC_INSTANCE_METHOD_DECL, CK.OBJC_CLASS_METHOD_DECL)])
if (nm:=t.spelling) not in types: lines.append(f"class {nm}(objc.Spec): pass")
types[nm] = nm, is_defn
if is_defn:
ims, cms = parse_objc_spec(decl, t.spelling, CK.OBJC_INSTANCE_METHOD_DECL), parse_objc_spec(decl, t.spelling, CK.OBJC_CLASS_METHOD_DECL)
lines.extend([*([f"{nm}._bases_ = [{', '.join(bs)}]"] if (bs:=[tname(b.type) for b in decl.get_children() if b.kind in specs]) else []),
*([f"{nm}._methods_ = [", *ims, ']'] if ims else []), *([f"{nm}._classmethods_ = [", *cms, ']'] if cms else [])])
return nm
case TK.OBJCSEL: return "objc.id_"
case TK.OBJCID: return (objc:=True, "objc.id_")[1]
case TK.OBJCOBJECT:
if basetype(t).kind != TK.OBJCID: raise NotImplementedError(f"generics unsupported: {t.spelling}")
ps = [proto(p) for p in protocols(t)]
if len(ps) == 0:
types[t.spelling] = "objc.id_", True
return "objc.id_"
if len(ps) == 1:
types[t.spelling] = ps[0], True
return ps[0]
types[t.spelling] = (nm:=f"_anondynamic{anoncnt()}"), True
lines.append(f"class {nm}({', '.join(p for p in ps)}): pass # {t.spelling}")
return nm
case _: raise NotImplementedError(f"unsupported type {t.kind}")
# Parses an objc @interface or @protocol, returning a list of declarations that objc.Spec can parse, for the specified kind
# (CK.OBJC_INSTANCE_METHOD_DECL or CK.OBJC_CLASS_METHOD_DECL). Each emitted string is one source line of the form
# "  ('selector:name:', <restype>, [<argtypes>][, True])," where the trailing True marks an ARC +1 (retained) return.
# NB: ivars are unsupported
def parse_objc_spec(decl:Cursor, nm:str, kind:CK) -> list[str]:
  nonlocal lines, types
  if decl is None: return []
  ms = []
  for d in filter(lambda d: d.kind == kind, decl.get_children()):
    # snapshot emitted output so a method with an unsupported type can be skipped without leaving partial types behind
    rollback = lines, types
    # 'instancetype' is kept as a literal string: objc.Spec substitutes the concrete class at bind time
    try: ms.append(f"  ('{d.spelling}', {repr('instancetype') if (rt:=d.result_type).spelling=='instancetype' else tname(rt)}, "
                   f"[{', '.join('instancetype' if a.spelling == 'instancetype' else tname(a.type) for a in d.get_arguments())}]" +
                   # retained if explicitly annotated NS_RETURNS_RETAINED, or if the selector is in an ARC method family
                   (", True" if CK.NS_RETURNS_RETAINED in attrs(d) or (any(d.spelling.startswith(s) for s in arc_families) and rt.kind!=TK.VOID) else "") + "),")
    except NotImplementedError as e:
      print(f"skipping {nm}.{d.spelling}: {e}")
      lines, types = rollback
  return ms
# libclang doesn't have a "type" for @protocol, so we have to do this here...
# Emits (into `lines`) an objc.Spec subclass for protocol `decl` and returns its name.
# Forward declarations produce only the class stub; a full definition also emits _bases_/_methods_/_classmethods_.
def proto(decl):
  nonlocal lines, types
  # already emitted as a full definition -> reuse it
  if (nm:=decl.spelling) in types and types[nm][1]: return types[nm][0]
  # check if this is a forward declaration
  is_defn = bool([f.kind for f in decl.get_children() if f.kind in (CK.OBJC_INSTANCE_METHOD_DECL, CK.OBJC_CLASS_METHOD_DECL)])
  if nm not in types: lines.append(f"class {nm}(objc.Spec): pass")
  # record (name, is_definition) so a later full definition can upgrade a forward declaration
  types[nm] = nm, is_defn
  if is_defn:
    # inherited protocols (skip the self-reference every @protocol carries)
    bs = [proto(b) for b in decl.get_children() if b.kind==CK.OBJC_PROTOCOL_REF and b.spelling != decl.spelling]
    ims, cms = parse_objc_spec(decl, nm, CK.OBJC_INSTANCE_METHOD_DECL), parse_objc_spec(decl, nm, CK.OBJC_CLASS_METHOD_DECL)
    lines.extend([*([f"{nm}._bases_ = [{', '.join(bs)}]"] if bs else []),
                  *([f"{nm}._methods_ = [", *ims, "]"] if ims else []), *([f"{nm}._classmethods_ = [", *cms, "]"] if cms else [])])
  return nm
for f in files:
tu = Index.create().parse(f, args, options=TU.PARSE_DETAILED_PROCESSING_RECORD)
(pp:=PP.create(tu.cursor)).set_property(PPP.TerseOutput, 1)
@@ -90,7 +155,8 @@ def gen(dll, files, args=[], prolog=[], rules=[], epilog=[], recsym=False, use_e
# TODO: we could support name-mangling
lines.append(f"# {c.pretty_printed(pp)}\ntry: ({c.spelling}:=dll.{c.spelling}).restype, {c.spelling}.argtypes = "
f"{tname(c.result_type)}, [{', '.join(tname(arg.type) for arg in c.get_arguments())}]\nexcept AttributeError: pass\n")
case CK.STRUCT_DECL | CK.UNION_DECL | CK.TYPEDEF_DECL | CK.ENUM_DECL: tname(c.type)
if CK.NS_RETURNS_RETAINED in attrs(c): lines.append(f"{c.spelling} = objc.returns_retained({c.spelling})")
case CK.STRUCT_DECL | CK.UNION_DECL | CK.TYPEDEF_DECL | CK.ENUM_DECL | CK.OBJC_INTERFACE_DECL: tname(c.type)
case CK.MACRO_DEFINITION if parse_macros and len(toks:=list(c.get_tokens())) > 1:
if toks[1].spelling == '(' and toks[0].extent.end.column == toks[1].extent.start.column:
it = iter(toks[1:])
@@ -107,12 +173,13 @@ def gen(dll, files, args=[], prolog=[], rules=[], epilog=[], recsym=False, use_e
else: macros += [f"{c.spelling} = {tname(c.type)}({readext(f, last(c).extent)})"]
case CK.VAR_DECL if c.linkage == LK.EXTERNAL and dll:
lines.append(f"try: {c.spelling} = {tname(c.type)}.in_dll(dll, '{c.spelling}')\nexcept (ValueError,AttributeError): pass")
case CK.OBJC_PROTOCOL_DECL: proto(c)
except NotImplementedError as e:
print(f"skipping {c.spelling}: {e}")
lines, types = rollback
main = (f"# mypy: ignore-errors\nimport ctypes{', os' if any('os' in s for s in dll) else ''}\n"
"from tinygrad.helpers import unwrap\nfrom tinygrad.runtime.support.c import Struct, CEnum, _IO, _IOW, _IOR, _IOWR\n" + '\n'.join([*prolog,
*(["from ctypes.util import find_library"]*any('find_library' in s for s in dll)),
*(["from ctypes.util import find_library"]*any('find_library' in s for s in dll)), *(["from tinygrad.runtime.support import objc"]*objc),
*(["def dll():",*flatten([[f" try: return ctypes.CDLL(unwrap({d}){', use_errno=True' if use_errno else ''})",' except: pass'] for d in dll]),
" return None", "dll = dll()\n"]*bool(dll)), *lines]) + '\n')
macros = [r for m in macros if (r:=functools.reduce(lambda s,r:re.sub(r[0], r[1], s), rules + base_rules, m))]

View File

@@ -0,0 +1,70 @@
import ctypes, ctypes.util, functools, sys
from typing import TYPE_CHECKING, Any
# At type-check time id_ is a plain ctypes.c_void_p; the runtime subclass below layers
# Objective-C reference counting (release-on-gc for owned references) on the raw pointer.
if TYPE_CHECKING: id_ = ctypes.c_void_p
else:
  class id_(ctypes.c_void_p):
    retain: bool = False  # when True this wrapper owns a +1 reference and releases it in __del__
    # This prevents ctypes from converting response to plain int, and dict.fromkeys() can use it to dedup
    def __hash__(self): return hash(self.value)
    def __eq__(self, other):
      # follow the equality protocol: defer to the other operand for foreign types instead of
      # raising AttributeError on `other.value`
      if not isinstance(other, ctypes.c_void_p): return NotImplemented
      return self.value == other.value
    def __del__(self):
      # skip the release during interpreter shutdown: the objc runtime bindings may already be gone
      if self.retain and not sys.is_finalizing(): self.release()
    def release(self): msg("release")(self)
    def retained(self):
      # mark this reference as owned (+1) so it is released when the wrapper is collected
      setattr(self, 'retain', True)
      return self
def returns_retained(f):
  """Wrap *f* so every value it returns is marked as an owned (+1) reference via .retained()."""
  @functools.wraps(f)
  def wrapper(*args, **kwargs):
    return f(*args, **kwargs).retained()
  return wrapper
# Handle to the Objective-C runtime; objc_msgSend, selector and class lookup all live here.
lib = ctypes.CDLL(ctypes.util.find_library('objc'))
lib.sel_registerName.restype = id_
# selector lookup is pure and called on every message send, so memoize it
getsel = functools.cache(lib.sel_registerName)
lib.objc_getClass.restype = id_
# dispatch_data_create follows the create rule (caller owns the result), hence the
# returns_retained wrapper so the wrapper object releases it on collection
dispatch_data_create = ctypes.CDLL("/usr/lib/libSystem.dylib").dispatch_data_create
dispatch_data_create.restype = id_
dispatch_data_create = returns_retained(dispatch_data_create)
def msg(sel:str, restype=id_, argtypes=[], retain=False, clsmeth=False):
  """Build a callable that sends Objective-C message `sel` (with `restype`/`argtypes`) to its first argument.

  clsmeth sends to the receiver's _objc_class_ instead of the instance; retain marks the result as a +1 reference.
  """
  # NOTE: ctypes item access (lib["objc_msgSend"]) returns a NEW function object on every call,
  # unlike attribute access which is cached -- so setting restype/argtypes here is safe and
  # cannot clobber a sender previously configured with different types.
  # When argtypes is empty the whole list is left [] (the [id_, id_] receiver/SEL prefix included),
  # letting the variadic objc_msgSend accept anything unchecked -- presumably intentional, verify.
  (sender:=lib["objc_msgSend"]).restype, sender.argtypes = restype, [id_, id_]+argtypes if argtypes else []
  def f(ptr, *args): return sender(ptr._objc_class_ if clsmeth else ptr, getsel(sel.encode()), *args)
  return returns_retained(f) if retain else f
# Metaclass for generated ObjC interface/protocol stubs: assigning _methods_/_classmethods_/_bases_
# on a Spec subclass binds the declared selectors as Python (class)methods and propagates them to
# subclasses registered through _bases_. Method specs are tuples:
# (selector, restype, argtypes[, retained]) with 'instancetype' standing in for the class itself.
if TYPE_CHECKING:
  import _ctypes
  class MetaSpec(_ctypes._PyCSimpleType):
    _objc_class_: id_
    def __getattr__(cls, nm:str) -> Any: ...
else:
  class MetaSpec(type(id_)):
    def __new__(mcs, name, bases, dct):
      # resolve the ObjC class object by name up front; _children_ tracks subclasses for re-inheritance
      cls = super().__new__(mcs, name, bases, {'_objc_class_': lib.objc_getClass(name.encode()), '_children_': set(), **dct})
      # assigning through the class triggers __setattr__ below, which performs the actual binding
      cls._methods_, cls._classmethods_ = dct.get('_methods_', []), dct.get('_classmethods_', [])
      return cls
    def __setattr__(cls, k, v):
      super().__setattr__(k, v)
      if k in ("_methods_", "_classmethods_"):
        # bind each spec; compare the attribute NAME (k) -- v is the list of specs, so the
        # previous `v=="_classmethods_"` was always False and bound class methods as instance methods
        for m in v: cls._addmeth(m, clsmeth=(k=="_classmethods_"))
        # re-propagate to any subclass that registered via _bases_ before these methods arrived
        for c in cls._children_: c._inherit(cls)
      if k == "_bases_":
        for b in v:
          b._children_.add(cls)
          cls._inherit(b)
    def _inherit(cls, b):
      # pull in the whole ancestor chain first, then b's own methods, then push down to children
      for _b in getattr(b, "_bases_", []): cls._inherit(_b)
      for m in getattr(b, "_methods_", []): cls._addmeth(m)
      for m in getattr(b, "_classmethods_", []): cls._addmeth(m, True)
      for c in cls._children_: c._inherit(cls)
    def _addmeth(cls, m, clsmeth=False):
      # "initWithBytes:length:" -> python name "initWithBytes_length_"
      nm = m[0].strip(':').replace(':', '_')
      if clsmeth: setattr(cls, nm, classmethod(msg(m[0], cls if m[1] == 'instancetype' else m[1],
                  [cls if a == 'instancetype' else a for a in m[2]], *m[3:], clsmeth=True))) # type: ignore[misc]
      else: setattr(cls, nm, msg(m[0], cls if m[1] == 'instancetype' else m[1], [cls if a == 'instancetype' else a for a in m[2]], *m[3:]))
class Spec(id_, metaclass=MetaSpec):
if TYPE_CHECKING:
def __getattr__(self, nm:str) -> Any: ...