mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
In-tree autogen: objective c (#13223)
* checkout changes from autogen branch * move assert --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
e5351699bd
commit
d1bb08c5a1
22
.github/workflows/autogen.yml
vendored
22
.github/workflows/autogen.yml
vendored
@@ -130,6 +130,28 @@ jobs:
|
||||
mv tinygrad/runtime/autogen/mesa.py /tmp/mesa.py.bak
|
||||
python3 -c "from tinygrad.runtime.autogen import mesa"
|
||||
diff /tmp/mesa.py.bak tinygrad/runtime/autogen/mesa.py
|
||||
- name: Verify libclang autogen
|
||||
run: |
|
||||
mv tinygrad/runtime/autogen/libclang.py /tmp/libclang.py.bak
|
||||
python3 -c "from tinygrad.runtime.autogen import libclang"
|
||||
diff /tmp/libclang.py.bak tinygrad/runtime/autogen/libclang.py
|
||||
autogen-mac:
|
||||
name: In-tree Autogen (macos)
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
llvm: 'true'
|
||||
pydeps: 'clang>=20'
|
||||
- name: Verify macos autogen
|
||||
run: |
|
||||
mv tinygrad/runtime/autogen/metal.py /tmp/metal.py.bak
|
||||
LIBCLANG_PATH=/opt/homebrew/opt/llvm@20/lib/libclang.dylib python3 -c "from tinygrad.runtime.autogen import metal"
|
||||
diff /tmp/metal.py.bak tinygrad/runtime/autogen/metal.py
|
||||
autogen-comgr-3:
|
||||
name: In-tree Autogen (comgr 3)
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
@@ -4,6 +4,7 @@ from tinygrad.helpers import fetch, flatten, system
|
||||
root = (here:=pathlib.Path(__file__).parent).parents[2]
|
||||
nv_src = {"nv_570": "https://github.com/NVIDIA/open-gpu-kernel-modules/archive/81fe4fb417c8ac3b9bdcc1d56827d116743892a5.tar.gz",
|
||||
"nv_580": "https://github.com/NVIDIA/open-gpu-kernel-modules/archive/2af9f1f0f7de4988432d4ae875b5858ffdb09cc2.tar.gz"}
|
||||
macossdk = "/var/db/xcode_select_link/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk"
|
||||
|
||||
def load(name, dll, files, **kwargs):
|
||||
if not (f:=(root/(path:=kwargs.pop("path", __name__)).replace('.','/')/f"{name}.py")).exists():
|
||||
@@ -120,4 +121,12 @@ python3 src/compiler/nir/nir_intrinsics_h.py --outdir gen
|
||||
python3 src/compiler/builtin_types_h.py gen/builtin_types.h""", cwd=path, shell=True, check=True),
|
||||
tarball="https://gitlab.freedesktop.org/mesa/mesa/-/archive/mesa-25.2.4/mesa-25.2.4.tar.gz",
|
||||
prolog=["import gzip, base64", "from tinygrad.helpers import OSX"], epilog=lambda path: [system(f"{root}/extra/mesa/lvp_nir_options.sh {path}")])
|
||||
case "libclang":
|
||||
return load("libclang", ["os.getenv('LIBCLANG_PATH', find_library('clang-20'))"],
|
||||
lambda: [system("llvm-config-20 --includedir")+"/clang-c/Index.h"], args=lambda: system("llvm-config-20 --cflags").split(),
|
||||
types={"CXString":"ci._CXString","CXType":"ci.Type","CXCursor":"ci.Cursor"}, prolog=["import clang.cindex as ci"])
|
||||
case "metal":
|
||||
return load("metal", ["find_library('Metal')"],[f"{macossdk}/System/Library/Frameworks/Metal.framework/Headers/MTL{s}.h" for s in
|
||||
["ComputeCommandEncoder", "ComputePipeline", "CommandQueue", "Device", "IndirectCommandBuffer", "Resource", "CommandEncoder"]],
|
||||
args=["-xobjective-c","-isysroot",macossdk], types={"dispatch_data_t":"objc.id_"})
|
||||
case _: raise AttributeError(f"no such autogen: {nm}")
|
||||
|
||||
2407
tinygrad/runtime/autogen/libclang.py
Normal file
2407
tinygrad/runtime/autogen/libclang.py
Normal file
File diff suppressed because it is too large
Load Diff
2048
tinygrad/runtime/autogen/metal.py
Normal file
2048
tinygrad/runtime/autogen/metal.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -5,15 +5,9 @@ from tinygrad.helpers import dedup, getenv, merge_dicts, PROFILE
|
||||
from tinygrad.device import Buffer, ProfileGraphEntry, ProfileGraphEvent
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
from tinygrad.engine.jit import GraphRunner, GraphException
|
||||
from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
|
||||
MTLResourceOptions, cmdbuf_st_time, cmdbuf_en_time, objc_id, to_ns_str
|
||||
|
||||
class MTLIndirectCommandType:
|
||||
MTLIndirectCommandTypeConcurrentDispatch = (1 << 5)
|
||||
|
||||
class MTLResourceUsage:
|
||||
MTLResourceUsageRead = 0b01
|
||||
MTLResourceUsageWrite = 0b10
|
||||
from tinygrad.runtime.ops_metal import wait_check, to_ns_str
|
||||
from tinygrad.runtime.autogen import metal
|
||||
from tinygrad.runtime.support import objc
|
||||
|
||||
class MetalGraph(GraphRunner):
|
||||
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
|
||||
@@ -21,16 +15,17 @@ class MetalGraph(GraphRunner):
|
||||
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
|
||||
|
||||
# create metal batch exec
|
||||
icb_descriptor = msg("new", objc_instance)(libobjc.objc_getClass(b"MTLIndirectCommandBufferDescriptor"))
|
||||
msg("setCommandTypes:")(icb_descriptor, MTLIndirectCommandType.MTLIndirectCommandTypeConcurrentDispatch)
|
||||
msg("setInheritBuffers:")(icb_descriptor, False)
|
||||
msg("setInheritPipelineState:")(icb_descriptor, False)
|
||||
msg("setMaxKernelBufferBindCount:")(icb_descriptor, 31)
|
||||
icb_descriptor = metal.MTLIndirectCommandBufferDescriptor.new()
|
||||
icb_descriptor.setCommandTypes(metal.MTLIndirectCommandTypeConcurrentDispatch)
|
||||
icb_descriptor.setInheritBuffers(False)
|
||||
icb_descriptor.setInheritPipelineState(False)
|
||||
icb_descriptor.setMaxKernelBufferBindCount(31)
|
||||
|
||||
self.icb = msg("newIndirectCommandBufferWithDescriptor:maxCommandCount:options:", objc_instance)(self.dev.sysdevice,
|
||||
icb_descriptor, len(jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache)
|
||||
self.icb = self.dev.sysdevice.newIndirectCommandBufferWithDescriptor_maxCommandCount_options(icb_descriptor, len(jit_cache),
|
||||
metal.MTLResourceCPUCacheModeDefaultCache)
|
||||
if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
|
||||
icb_label = bytes(msg("UTF8String", ctypes.c_char_p)(msg("description", objc_instance)(self.icb))).decode()
|
||||
# TODO: needs categories
|
||||
icb_label = bytes(objc.msg("UTF8String", ctypes.c_char_p)(objc.msg("description")(self.icb).retained())).decode()
|
||||
self.needs_icb_fix = int((m := re.search(r'AGXG(\d+)XFamily', icb_label)) is None or int(m.group(1)) < 15) # not required on M3+
|
||||
|
||||
self.fixedvars = merge_dicts([ji.fixedvars for ji in jit_cache])
|
||||
@@ -40,26 +35,25 @@ class MetalGraph(GraphRunner):
|
||||
all_pipelines, all_resources = [], [self.int_buf.buf] if len(self.varlist) else []
|
||||
for j,ji in enumerate(jit_cache):
|
||||
prg: CompiledRunner = cast(CompiledRunner, ji.prg)
|
||||
icb_command = msg("indirectComputeCommandAtIndex:", objc_instance)(self.icb, j)
|
||||
icb_command = self.icb.indirectComputeCommandAtIndex(j).retained()
|
||||
all_pipelines.append(prg._prg.pipeline_state)
|
||||
msg("setComputePipelineState:")(icb_command, prg._prg.pipeline_state)
|
||||
icb_command.setComputePipelineState(prg._prg.pipeline_state)
|
||||
for i,b in enumerate(ji.bufs):
|
||||
if b is not None and b not in input_rawbuffers:
|
||||
msg("setKernelBuffer:offset:atIndex:")(icb_command, b._buf.buf, b._buf.offset, i)
|
||||
icb_command.setKernelBuffer_offset_atIndex(b._buf.buf, b._buf.offset, i)
|
||||
all_resources.append(b._buf.buf)
|
||||
for i,v in enumerate(prg.p.vars):
|
||||
msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
|
||||
for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
|
||||
|
||||
global_size, local_size = prg.p.launch_dims(var_vals)
|
||||
msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(icb_command, to_struct(*global_size), to_struct(*local_size))
|
||||
msg("setBarrier")(icb_command)
|
||||
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(*global_size), metal.MTLSize(*local_size))
|
||||
icb_command.setBarrier()
|
||||
|
||||
self.all_resources = dedup(all_resources)
|
||||
self.all_pipelines = dedup(all_pipelines)
|
||||
self.command_buffer: Any = None
|
||||
if len(self.varlist): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
|
||||
for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var]
|
||||
self.range = to_struct(0, len(jit_cache))
|
||||
self.range = metal.NSRange(0, len(jit_cache))
|
||||
|
||||
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
|
||||
if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
|
||||
@@ -68,18 +62,18 @@ class MetalGraph(GraphRunner):
|
||||
|
||||
all_resources = dedup(self.all_resources + [input_rawbuffers[input_idx]._buf.buf for input_idx in self.input_replace.values()])
|
||||
for (j,i),input_idx in self.input_replace.items():
|
||||
computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
|
||||
msg("setKernelBuffer:offset:atIndex:")(computeCommand, input_rawbuffers[input_idx]._buf.buf, input_rawbuffers[input_idx]._buf.offset, i)
|
||||
computeCommand = self.icb.indirectComputeCommandAtIndex(j)
|
||||
computeCommand.setKernelBuffer_offset_atIndex(input_rawbuffers[input_idx]._buf.buf, input_rawbuffers[input_idx]._buf.offset, i)
|
||||
|
||||
for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
|
||||
computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
|
||||
msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(computeCommand, to_struct(*global_dims), to_struct(*local_dims))
|
||||
computeCommand = self.icb.indirectComputeCommandAtIndex(j)
|
||||
computeCommand.concurrentDispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(*global_dims), metal.MTLSize(*local_dims))
|
||||
for var in self.vars: self.int_buf_view[self.varlist.index(var)] = var_vals[var]
|
||||
|
||||
command_buffer = msg("commandBuffer", objc_instance)(self.dev.mtl_queue)
|
||||
encoder = msg("computeCommandEncoder", objc_instance)(command_buffer)
|
||||
msg("useResources:count:usage:")(encoder, (objc_id * len(all_resources))(*all_resources), len(all_resources),
|
||||
MTLResourceUsage.MTLResourceUsageRead | MTLResourceUsage.MTLResourceUsageWrite)
|
||||
command_buffer = self.dev.mtl_queue.commandBuffer().retained()
|
||||
encoder = command_buffer.computeCommandEncoder().retained()
|
||||
encoder.useResources_count_usage(ctypes.cast((metal.MTLBuffer * len(all_resources))(*all_resources), ctypes.POINTER(metal.MTLResource)),
|
||||
len(all_resources), metal.MTLResourceUsageRead | metal.MTLResourceUsageWrite)
|
||||
|
||||
# NOTE: the pipelines likely need to be added to the used resources to fix the crash on M1/M2, but I haven't figured out how
|
||||
# this is a O(n) hack to get them used. what should work is:
|
||||
@@ -88,24 +82,24 @@ class MetalGraph(GraphRunner):
|
||||
# to repro the crash (which can also crash other running GPU apps), run with FIX_METAL_ICB=0
|
||||
if getenv("FIX_METAL_ICB", self.needs_icb_fix):
|
||||
for ps in self.all_pipelines:
|
||||
msg("setComputePipelineState:")(encoder, ps)
|
||||
msg("dispatchThreadgroups:threadsPerThreadgroup:")(encoder, to_struct(0,0,0), to_struct(0,0,0))
|
||||
encoder.setComputePipelineState(ps)
|
||||
encoder.dispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(0,0,0), metal.MTLSize(0,0,0))
|
||||
|
||||
msg("executeCommandsInBuffer:withRange:")(encoder, self.icb, self.range)
|
||||
msg("endEncoding")(encoder)
|
||||
msg("setLabel:")(command_buffer, to_ns_str(f"batched {len(self.jit_cache)}"))
|
||||
msg("commit")(command_buffer)
|
||||
encoder.executeCommandsInBuffer_withRange(self.icb, self.range)
|
||||
encoder.endEncoding()
|
||||
command_buffer.setLabel(to_ns_str(f"batched {len(self.jit_cache)}"))
|
||||
command_buffer.commit()
|
||||
self.command_buffer = command_buffer
|
||||
|
||||
self.dev.mtl_buffers_in_flight.append(command_buffer)
|
||||
if wait:
|
||||
wait_check(command_buffer)
|
||||
return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
|
||||
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
|
||||
return None
|
||||
|
||||
def collect_timestamps(self):
|
||||
# create a graph event and evenly space each program
|
||||
st, en = decimal.Decimal(cmdbuf_st_time(self.command_buffer)) * 1000000, decimal.Decimal(cmdbuf_en_time(self.command_buffer)) * 1000000
|
||||
st, en = decimal.Decimal(self.command_buffer.GPUStartTime()) * 1000000, decimal.Decimal(self.command_buffer.GPUEndTime()) * 1000000
|
||||
ents = [ProfileGraphEntry(self.device, cast(CompiledRunner, ji.prg)._prg.name, i, i+1, is_copy=False) for i,ji in enumerate(self.jit_cache)]
|
||||
step = (en-st)/len(ents)
|
||||
self.dev.profile_events += [ProfileGraphEvent(ents, [], [st+step*i for i in range(len(ents)+1)])]
|
||||
|
||||
@@ -1,74 +1,42 @@
|
||||
import subprocess, pathlib, struct, ctypes, tempfile, functools, contextlib, decimal, platform
|
||||
from typing import Any, cast
|
||||
from tinygrad.helpers import prod, to_mv, getenv, round_up, cache_dir, T, init_c_struct_t, PROFILE, ProfileRangeEvent, cpu_profile, unwrap
|
||||
import subprocess, pathlib, struct, ctypes, tempfile, functools, contextlib, decimal, platform, sys
|
||||
from tinygrad.helpers import prod, to_mv, getenv, round_up, cache_dir, init_c_struct_t, PROFILE, ProfileRangeEvent, cpu_profile, unwrap
|
||||
import tinygrad.runtime.support.objc as objc
|
||||
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator, ProfileDeviceEvent
|
||||
from tinygrad.renderer.cstyle import MetalRenderer
|
||||
|
||||
class objc_id(ctypes.c_void_p): # This prevents ctypes from converting response to plain int, and dict.fromkeys() can use it to dedup
|
||||
def __hash__(self): return hash(self.value)
|
||||
def __eq__(self, other): return self.value == other.value
|
||||
|
||||
class objc_instance(objc_id): # method with name "new", "alloc" should be freed after use
|
||||
def __del__(self):
|
||||
# CPython doesn't make any guarantees about order in which globals (like `msg` or `libobjc`) are destroyed when the interpreter shuts down
|
||||
# https://github.com/tinygrad/tinygrad/pull/8949 triggered the unlucky ordering which lead to a bunch of errors at exit
|
||||
# TODO: Why isn't `sys.is_finalizing` working?
|
||||
if msg is not None and libobjc is not None: msg("release")(self)
|
||||
|
||||
class MTLResourceOptions:
|
||||
MTLResourceCPUCacheModeDefaultCache = 0
|
||||
MTLResourceStorageModeShared = 0 << 4
|
||||
|
||||
class MTLPipelineOption:
|
||||
MTLPipelineOptionNone = 0
|
||||
from tinygrad.runtime.autogen import metal
|
||||
|
||||
# 13 is requestType that metal uses to compile source code into MTLB, there aren't any docs or symbols.
|
||||
REQUEST_TYPE_COMPILE = 13
|
||||
|
||||
libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
|
||||
libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
|
||||
# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
|
||||
ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
|
||||
libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac
|
||||
libobjc.objc_getClass.restype = objc_id
|
||||
libobjc.sel_registerName.restype = objc_id
|
||||
libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
|
||||
libdispatch.dispatch_data_create.restype = objc_instance
|
||||
|
||||
# FIXME: these need autogen to support objc categories
|
||||
# https://developer.apple.com/library/archive/documentation/Cocoa/Conceptual/ObjectiveC/Chapters/ocCategories.html
|
||||
@functools.cache
|
||||
def msg(selector: str, restype: type[T] = objc_id): # type: ignore [assignment]
|
||||
resname = libobjc.sel_registerName(selector.encode())
|
||||
sender = libobjc["objc_msgSend"] # Using attribute access returns a new reference so setting restype is safe
|
||||
sender.restype = restype
|
||||
def _msg(ptr: objc_id, *args: Any) -> T: return sender(ptr, resname, *args)
|
||||
return _msg
|
||||
|
||||
@functools.cache
|
||||
def to_ns_str(s: str): return msg("stringWithUTF8String:", objc_instance)(libobjc.objc_getClass(b"NSString"), s.encode())
|
||||
def from_ns_str(s): return bytes(msg("UTF8String", ctypes.c_char_p)(s)).decode()
|
||||
def to_ns_str(s: str): return ctypes.cast(objc.msg("stringWithUTF8String:")(metal.NSString._objc_class_, s.encode()), metal.NSString)
|
||||
def from_ns_str(s): return bytes(objc.msg("UTF8String", ctypes.c_char_p)(s)).decode()
|
||||
|
||||
def to_struct(*t: int, _type: type[ctypes._SimpleCData] = ctypes.c_ulong):
  """Pack the positional integers `t` into a ctypes struct, one `_type` field per value.

  Fields are named `field0`, `field1`, ... in argument order. The struct type is obtained from
  `init_c_struct_t` and instantiated with the values of `t`.
  """
  field_spec = tuple((f"field{idx}", _type) for idx, _ in enumerate(t))
  struct_cls = init_c_struct_t(field_spec)
  return struct_cls(*t)
|
||||
|
||||
def wait_check(cbuf: Any):
|
||||
msg("waitUntilCompleted")(cbuf)
|
||||
error_check(msg("error", objc_instance)(cbuf))
|
||||
def wait_check(cbuf:metal.MTLCommandBuffer):
|
||||
cbuf.waitUntilCompleted()
|
||||
error_check(cbuf.error().retained())
|
||||
|
||||
def cmdbuf_label(cbuf: objc_id) -> str|None: return from_ns_str(label) if (label:=msg("label", objc_id)(cbuf)).value is not None else None
|
||||
def cmdbuf_st_time(cbuf: objc_id) -> float: return cast(float, msg("GPUStartTime", ctypes.c_double)(cbuf))
|
||||
def cmdbuf_en_time(cbuf: objc_id) -> float: return cast(float, msg("GPUEndTime", ctypes.c_double)(cbuf))
|
||||
def cmdbuf_label(cbuf:metal.MTLCommandBuffer) -> str|None: return from_ns_str(label) if (label:=cbuf.label()).value is not None else None
|
||||
|
||||
def error_check(error: objc_instance, error_constructor: type[Exception] = RuntimeError):
|
||||
def error_check(error: metal.NSError, error_constructor: type[Exception] = RuntimeError):
|
||||
if error.value is None: return None
|
||||
raise error_constructor(from_ns_str(msg("localizedDescription", objc_instance)(error)))
|
||||
raise error_constructor(from_ns_str(error.localizedDescription().retained()))
|
||||
|
||||
class MetalDevice(Compiled):
|
||||
def __init__(self, device:str):
|
||||
self.sysdevice = libmetal.MTLCreateSystemDefaultDevice()
|
||||
self.mtl_queue = msg("newCommandQueueWithMaxCommandBufferCount:", objc_instance)(self.sysdevice, 1024)
|
||||
self.sysdevice = metal.MTLCreateSystemDefaultDevice()
|
||||
self.mtl_queue = self.sysdevice.newCommandQueueWithMaxCommandBufferCount(1024)
|
||||
if self.mtl_queue is None: raise RuntimeError("Cannot allocate a new command queue")
|
||||
self.mtl_buffers_in_flight: list[Any] = []
|
||||
self.timeline_signal = msg("newSharedEvent", objc_instance)(self.sysdevice)
|
||||
self.mtl_buffers_in_flight: list[metal.MTLCommandBuffer] = []
|
||||
self.timeline_signal = self.sysdevice.newSharedEvent()
|
||||
self.timeline_value = 0
|
||||
|
||||
Compiled.profile_events += [ProfileDeviceEvent(device)]
|
||||
@@ -76,23 +44,22 @@ class MetalDevice(Compiled):
|
||||
from tinygrad.runtime.graph.metal import MetalGraph
|
||||
# NOTE: GitHub CI macOS runners use paravirtualized metal which is broken with graph.
|
||||
# This can be reproduced locally with any virtualization software (like utm) that can create macOS VMs with apple's own virtualization framework.
|
||||
super().__init__(device, MetalAllocator(self), [(MetalRenderer, MetalCompiler), (MetalRenderer, Compiler)],
|
||||
functools.partial(MetalProgram, self), MetalGraph if 'virtual' not in from_ns_str(msg('name')(self.sysdevice)).lower() else None)
|
||||
super().__init__(device, MetalAllocator(self), [(MetalRenderer, MetalCompiler), (MetalRenderer, Compiler)], functools.partial(MetalProgram, self),
|
||||
MetalGraph if 'virtual' not in from_ns_str(self.sysdevice.name()).lower() else None)
|
||||
|
||||
def synchronize(self):
|
||||
for cbuf in self.mtl_buffers_in_flight:
|
||||
wait_check(cbuf)
|
||||
st, en = decimal.Decimal(cmdbuf_st_time(cbuf)) * 1000000, decimal.Decimal(cmdbuf_en_time(cbuf)) * 1000000
|
||||
st, en = decimal.Decimal(cbuf.GPUStartTime()) * 1000000, decimal.Decimal(cbuf.GPUEndTime()) * 1000000
|
||||
# NOTE: command buffers from MetalGraph are not profiled here
|
||||
if PROFILE and (lb:=cmdbuf_label(cbuf)) is not None and not lb.startswith("batched"):
|
||||
Compiled.profile_events += [ProfileRangeEvent(self.device, lb, st, en, is_copy=lb.startswith("COPY"))]
|
||||
self.mtl_buffers_in_flight.clear()
|
||||
|
||||
def metal_src_to_library(device:MetalDevice, src:str) -> objc_instance:
|
||||
options = msg("new", objc_instance)(libobjc.objc_getClass(b"MTLCompileOptions"))
|
||||
msg("setFastMathEnabled:")(options, getenv("METAL_FAST_MATH"))
|
||||
library = msg("newLibraryWithSource:options:error:", objc_instance)(device.sysdevice, to_ns_str(src),
|
||||
options, ctypes.byref(compileError:=objc_instance()))
|
||||
def metal_src_to_library(device:MetalDevice, src:str) -> metal.MTLLibrary:
|
||||
options = metal.MTLCompileOptions.new()
|
||||
options.setFastMathEnabled(getenv("METAL_FAST_MATH"))
|
||||
library = device.sysdevice.newLibraryWithSource_options_error(to_ns_str(src), options, ctypes.byref(compileError:=metal.NSError().retained()))
|
||||
error_check(compileError, CompileError)
|
||||
return library
|
||||
|
||||
@@ -155,71 +122,71 @@ class MetalProgram:
|
||||
self.dev, self.name, self.lib = dev, name, lib
|
||||
if lib[:4] == b"MTLB":
|
||||
# binary metal library
|
||||
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
|
||||
self.library = msg("newLibraryWithData:error:", objc_instance)(self.dev.sysdevice, data, ctypes.byref(error_lib:=objc_instance()))
|
||||
data = objc.dispatch_data_create(lib, len(lib), None, None)
|
||||
self.library = self.dev.sysdevice.newLibraryWithData_error(data, ctypes.byref(error_lib:=metal.NSError().retained())).retained()
|
||||
error_check(error_lib)
|
||||
else:
|
||||
# metal source. rely on OS caching
|
||||
try: self.library = metal_src_to_library(self.dev, lib.decode())
|
||||
except CompileError as e: raise RuntimeError from e
|
||||
self.fxn = msg("newFunctionWithName:", objc_instance)(self.library, to_ns_str(name))
|
||||
descriptor = msg("new", objc_instance)(libobjc.objc_getClass(b"MTLComputePipelineDescriptor"))
|
||||
msg("setComputeFunction:")(descriptor, self.fxn)
|
||||
msg("setSupportIndirectCommandBuffers:")(descriptor, True)
|
||||
self.pipeline_state = msg("newComputePipelineStateWithDescriptor:options:reflection:error:", objc_instance)(self.dev.sysdevice,
|
||||
descriptor, MTLPipelineOption.MTLPipelineOptionNone, None, ctypes.byref(error_pipeline_creation:=objc_instance()))
|
||||
self.fxn = self.library.newFunctionWithName(to_ns_str(name)).retained()
|
||||
descriptor = metal.MTLComputePipelineDescriptor.new()
|
||||
descriptor.setComputeFunction(self.fxn)
|
||||
descriptor.setSupportIndirectCommandBuffers(True)
|
||||
self.pipeline_state = self.dev.sysdevice.newComputePipelineStateWithDescriptor_options_reflection_error(descriptor, metal.MTLPipelineOptionNone,
|
||||
None, ctypes.byref(error_pipeline_creation:=metal.NSError().retained()))
|
||||
error_check(error_pipeline_creation)
|
||||
# cache these msg calls
|
||||
self.max_total_threads: int = cast(int, msg("maxTotalThreadsPerThreadgroup", ctypes.c_ulong)(self.pipeline_state))
|
||||
self.max_total_threads: int = self.pipeline_state.maxTotalThreadsPerThreadgroup()
|
||||
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
|
||||
if prod(local_size) > self.max_total_threads:
|
||||
exec_width = msg("threadExecutionWidth", ctypes.c_ulong)(self.pipeline_state)
|
||||
memory_length = msg("staticThreadgroupMemoryLength", ctypes.c_ulong)(self.pipeline_state)
|
||||
exec_width = self.pipeline_state.threadExecutionWidth()
|
||||
memory_length = self.pipeline_state.staticThreadgroupMemoryLength()
|
||||
raise RuntimeError(f"local size {local_size} bigger than {self.max_total_threads} with exec width {exec_width} memory length {memory_length}")
|
||||
command_buffer = msg("commandBuffer", objc_instance)(self.dev.mtl_queue)
|
||||
encoder = msg("computeCommandEncoder", objc_instance)(command_buffer)
|
||||
msg("setComputePipelineState:")(encoder, self.pipeline_state)
|
||||
for i,a in enumerate(bufs): msg("setBuffer:offset:atIndex:")(encoder, a.buf, a.offset, i)
|
||||
for i,a in enumerate(vals, start=len(bufs)): msg("setBytes:length:atIndex:")(encoder, bytes(ctypes.c_int(a)), 4, i)
|
||||
msg("dispatchThreadgroups:threadsPerThreadgroup:")(encoder, to_struct(*global_size), to_struct(*local_size))
|
||||
msg("endEncoding")(encoder)
|
||||
msg("setLabel:")(command_buffer, to_ns_str(self.name)) # TODO: is this always needed?
|
||||
msg("commit")(command_buffer)
|
||||
command_buffer = self.dev.mtl_queue.commandBuffer().retained() # FIXME: is this really ARC?
|
||||
encoder = command_buffer.computeCommandEncoder().retained() # FIXME: is this really ARC?
|
||||
encoder.setComputePipelineState(self.pipeline_state)
|
||||
for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex(a.buf, a.offset, i)
|
||||
for i,a in enumerate(vals, start=len(bufs)): encoder.setBytes_length_atIndex(bytes(ctypes.c_int(a)), 4, i)
|
||||
encoder.dispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(*global_size), metal.MTLSize(*local_size))
|
||||
encoder.endEncoding()
|
||||
command_buffer.setLabel(to_ns_str(self.name)) # TODO: is this always needed?
|
||||
command_buffer.commit()
|
||||
self.dev.mtl_buffers_in_flight.append(command_buffer)
|
||||
if wait:
|
||||
wait_check(command_buffer)
|
||||
return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
|
||||
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
|
||||
|
||||
class MetalBuffer:
|
||||
def __init__(self, buf:Any, size:int, offset=0): self.buf, self.size, self.offset = buf, size, offset
|
||||
def __init__(self, buf:metal.MTLBuffer, size:int, offset=0): self.buf, self.size, self.offset = buf, size, offset
|
||||
|
||||
class MetalAllocator(LRUAllocator[MetalDevice]):
|
||||
def _alloc(self, size:int, options) -> MetalBuffer:
|
||||
if options.external_ptr: return MetalBuffer(objc_id(options.external_ptr), size)
|
||||
if options.external_ptr: return MetalBuffer(metal.MTLBuffer(options.external_ptr), size)
|
||||
|
||||
# Buffer is explicitly released in _free() rather than garbage collected via reference count
|
||||
ret = msg("newBufferWithLength:options:", objc_id)(self.dev.sysdevice, ctypes.c_ulong(size), MTLResourceOptions.MTLResourceStorageModeShared)
|
||||
ret = self.dev.sysdevice.newBufferWithLength_options(size, metal.MTLResourceStorageModeShared)
|
||||
ret.retain = False
|
||||
if ret.value is None: raise MemoryError(f"Metal OOM while allocating {size=}")
|
||||
return MetalBuffer(ret, size)
|
||||
def _free(self, opaque:MetalBuffer, options):
  """Release the Metal buffer wrapped by `opaque`.

  Buffers are released here explicitly instead of on garbage collection. Nothing is done while the
  interpreter is shutting down (`sys.is_finalizing()`).
  """
  # `release` has to be *called*: a bare `opaque.buf.release` only looks the attribute up and frees nothing,
  # which would leak every buffer handed back to the allocator.
  # NOTE(review): assumes `release` is a plain method on the objc wrapper (like `retained()`) -- confirm.
  if not sys.is_finalizing(): opaque.buf.release()
|
||||
def _transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
|
||||
dest_dev.synchronize()
|
||||
src_command_buffer = msg("commandBuffer", objc_instance)(src_dev.mtl_queue)
|
||||
encoder = msg("blitCommandEncoder", objc_instance)(src_command_buffer)
|
||||
msg("copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size:")(encoder, src.buf, ctypes.c_ulong(src.offset),
|
||||
dest.buf, ctypes.c_ulong(dest.offset), ctypes.c_ulong(sz))
|
||||
msg("endEncoding")(encoder)
|
||||
src_command_buffer = src_dev.mtl_queue.commandBuffer().retained()
|
||||
encoder = src_command_buffer.blitCommandEncoder().retained()
|
||||
encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size(src.buf, src.offset, dest.buf, dest.offset, sz)
|
||||
encoder.endEncoding()
|
||||
if src_dev != dest_dev:
|
||||
msg("encodeSignalEvent:value:")(src_command_buffer, src_dev.timeline_signal, src_dev.timeline_value)
|
||||
dest_command_buffer = msg("commandBuffer", objc_instance)(dest_dev.mtl_queue)
|
||||
msg("encodeWaitForEvent:value:")(dest_command_buffer, src_dev.timeline_signal, src_dev.timeline_value)
|
||||
msg("commit")(dest_command_buffer)
|
||||
src_command_buffer.encodeSignalEvent_value(ctypes.cast(src_dev.timeline_signal, metal.MTLEvent), src_dev.timeline_value)
|
||||
dest_command_buffer = dest_dev.mtl_queue.commandBuffer().retained()
|
||||
dest_command_buffer.encodeWaitForEvent_value(ctypes.cast(src_dev.timeline_signal, metal.MTLEvent), src_dev.timeline_value)
|
||||
dest_command_buffer.commit()
|
||||
dest_dev.mtl_buffers_in_flight.append(dest_command_buffer)
|
||||
src_dev.timeline_value += 1
|
||||
msg("setLabel:")(src_command_buffer, to_ns_str(f"COPY {src_dev.device} -> {dest_dev.device}"))
|
||||
msg("commit")(src_command_buffer)
|
||||
src_command_buffer.setLabel(to_ns_str(f"COPY {src_dev.device} -> {dest_dev.device}"))
|
||||
src_command_buffer.commit()
|
||||
src_dev.mtl_buffers_in_flight.append(src_command_buffer)
|
||||
# Transfers currently synchronize the completion. Otherwise, copies can sometimes lead to incorrect values.
|
||||
# There is no real metal multidevice support for now, so transfer is used only for tests.
|
||||
@@ -228,7 +195,7 @@ class MetalAllocator(LRUAllocator[MetalDevice]):
|
||||
with cpu_profile(prof_desc, self.dev.device, is_copy=True): dst[:] = src
|
||||
def _as_buffer(self, src:MetalBuffer) -> memoryview:
|
||||
self.dev.synchronize()
|
||||
return to_mv(cast(int, msg("contents", objc_id)(src.buf).value), src.size + src.offset)[src.offset:]
|
||||
return to_mv(src.buf.contents(), src.size + src.offset)[src.offset:]
|
||||
def _copyin(self, dest:MetalBuffer, src:memoryview):
  """Copy the host memoryview `src` into the Metal buffer `dest` (profiled as "TINY -> METAL")."""
  dest_view = self._as_buffer(dest)
  self._cp_mv(dest_view, src, "TINY -> METAL")
|
||||
def _copyout(self, dest:memoryview, src:MetalBuffer):
  """Copy the Metal buffer `src` into the host memoryview `dest` (profiled as "METAL -> TINY")."""
  src_view = self._as_buffer(src)
  self._cp_mv(dest, src_view, "METAL -> TINY")
|
||||
def _offset(self, buf:MetalBuffer, size:int, offset:int):
  """Return a new `MetalBuffer` over the same underlying `buf.buf`, with the given `size` and `offset`."""
  underlying = buf.buf
  return MetalBuffer(underlying, size, offset)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import ctypes.util, importlib.metadata, itertools, re, functools, os
|
||||
from tinygrad.helpers import flatten, unwrap
|
||||
from tinygrad.helpers import flatten, unwrap, fromimport
|
||||
|
||||
assert importlib.metadata.version('clang')[:2] == "20", 'clang version 20 required, pip install "clang==20.1.0"'
|
||||
from clang.cindex import Config, Index, CursorKind as CK, TranslationUnit as TU, LinkageKind as LK, TokenKind as ToK, TypeKind as TK
|
||||
from clang.cindex import Config, Index, Cursor, Type, CursorKind as CK, TranslationUnit as TU, LinkageKind as LK, TokenKind as ToK, TypeKind as TK
|
||||
from clang.cindex import PrintingPolicy as PP, PrintingPolicyProperty as PPP, SourceRange
|
||||
|
||||
libclang = functools.partial(fromimport, "tinygrad.runtime.autogen.libclang") # we can't actually import this, because then we can't generate it
|
||||
|
||||
if not Config.loaded: Config.set_library_file(os.getenv("LIBCLANG_PATH", ctypes.util.find_library("clang-20")))
|
||||
|
||||
def fst(c):
  """Return the first item produced by `c.get_children()`.

  `next` is applied directly to the result, so an exhausted iterator raises `StopIteration`.
  """
  children = c.get_children()
  return next(children)
|
||||
@@ -14,18 +17,25 @@ def readext(f, fst, snd=None):
|
||||
return f.read((fst.end.offset if isinstance(fst, SourceRange) else snd)-start)
|
||||
def attrs(c):
  """Return the `kind` of every child of `c` whose `kind.value` lies in the half-open range [400, 500).

  NOTE(review): 400-499 is presumably the attribute range of clang's CursorKind -- confirm against libclang.
  """
  return [kind for child in c.get_children() if 400 <= (kind := child.kind).value < 500]
|
||||
|
||||
def protocols(t):
  """Yield one `Cursor` per Objective-C protocol reference on type `t`.

  The count comes from `clang_Type_getNumObjCProtocolRefs` and each declaration from
  `clang_Type_getObjCProtocolDecl`; both are looked up through the lazy `libclang` accessor.
  """
  for idx in range(libclang("clang_Type_getNumObjCProtocolRefs")(t)):
    yield Cursor.from_result(libclang("clang_Type_getObjCProtocolDecl")(t, idx), t)
|
||||
def basetype(t):
  """Return the Objective-C object base type of `t`, wrapped as a clang Type."""
  raw = libclang("clang_Type_getObjCObjectBaseType")(t)
  return Type.from_result(raw, (t,))
|
||||
|
||||
base_rules = [(r'\s*\\\n\s*', ' '), (r'\s*\n\s*', ' '), (r'//.*', ''), (r'/\*.*?\*/', ''), (r'\b(0[xX][0-9a-fA-F]+|\d+)[uUlL]+\b', r'\1'),
|
||||
(r'\b0+(?=\d)', ''), (r'\s*&&\s*', r' and '), (r'\s*\|\|\s*', r' or '), (r'\s*!\s*', ' not '),
|
||||
(r'(struct|union|enum)\s*([a-zA-Z_][a-zA-Z0-9_]*\b)', r'\1_\2'),
|
||||
(r'\((unsigned )?(char|uint64_t)\)', ''), (r'^.*\d+:\d+.*$', ''), (r'^.*\w##\w.*$', '')]
|
||||
|
||||
ints = (TK.INT, TK.UINT, TK.LONG, TK.ULONG, TK.LONGLONG, TK.ULONGLONG)
|
||||
specs = (CK.OBJC_SUPER_CLASS_REF,)
|
||||
# https://clang.llvm.org/docs/AutomaticReferenceCounting.html#arc-method-families
|
||||
arc_families = ['alloc', 'copy', 'mutableCopy', 'new']
|
||||
|
||||
def gen(dll, files, args=[], prolog=[], rules=[], epilog=[], recsym=False, use_errno=False, anon_names={}, types={}, parse_macros=True):
|
||||
macros, lines, anoncnt, types = [], [], itertools.count().__next__, {k:(v,True) for k,v in types.items()}
|
||||
macros, lines, anoncnt, types, objc = [], [], itertools.count().__next__, {k:(v,True) for k,v in types.items()}, False
|
||||
def tname(t, suggested_name=None, typedef=None) -> str:
|
||||
suggested_name = anon_names.get(f"{(decl:=t.get_declaration()).location.file}:{decl.location.line}", suggested_name)
|
||||
nonlocal lines, types, anoncnt
|
||||
nonlocal lines, types, anoncnt, objc
|
||||
tmap = {TK.VOID:"None", TK.CHAR_U:"ctypes.c_ubyte", TK.UCHAR:"ctypes.c_ubyte", TK.CHAR_S:"ctypes.c_char", TK.SCHAR:"ctypes.c_char",
|
||||
**{getattr(TK, k):f"ctypes.c_{k.lower()}" for k in ["BOOL", "WCHAR", "FLOAT", "DOUBLE", "LONGDOUBLE"]},
|
||||
**{getattr(TK, k):f"ctypes.c_{'u' if 'U' in k else ''}int{sz}" for sz,k in
|
||||
@@ -37,6 +47,7 @@ def gen(dll, files, args=[], prolog=[], rules=[], epilog=[], recsym=False, use_e
|
||||
return f"ctypes.CFUNCTYPE({tname(f.get_result())}{(', '+', '.join(map(tname, f.argument_types()))) if f.kind==TK.FUNCTIONPROTO else ''})"
|
||||
match t.kind:
|
||||
case TK.POINTER: return "ctypes.c_void_p" if (ptr:=t.get_pointee()).kind == TK.VOID else f"ctypes.POINTER({tname(ptr)})"
|
||||
case TK.OBJCOBJECTPOINTER: return tname(t.get_pointee()) # TODO: this seems wrong
|
||||
case TK.ELABORATED: return tname(t.get_named_type(), suggested_name)
|
||||
case TK.TYPEDEF if t.spelling == t.get_canonical().spelling: return tname(t.get_canonical())
|
||||
case TK.TYPEDEF:
|
||||
@@ -76,8 +87,62 @@ def gen(dll, files, args=[], prolog=[], rules=[], epilog=[], recsym=False, use_e
|
||||
case TK.CONSTANTARRAY:
|
||||
return f"({tname(t.get_array_element_type(), suggested_name.rstrip('s') if suggested_name else None)} * {t.get_array_size()})"
|
||||
case TK.INCOMPLETEARRAY: return f"({tname(t.get_array_element_type(), suggested_name.rstrip('s') if suggested_name else None)} * 0)"
|
||||
case TK.OBJCINTERFACE:
|
||||
is_defn = bool([f.kind for f in decl.get_children() if f.kind in (CK.OBJC_INSTANCE_METHOD_DECL, CK.OBJC_CLASS_METHOD_DECL)])
|
||||
if (nm:=t.spelling) not in types: lines.append(f"class {nm}(objc.Spec): pass")
|
||||
types[nm] = nm, is_defn
|
||||
if is_defn:
|
||||
ims, cms = parse_objc_spec(decl, t.spelling, CK.OBJC_INSTANCE_METHOD_DECL), parse_objc_spec(decl, t.spelling, CK.OBJC_CLASS_METHOD_DECL)
|
||||
lines.extend([*([f"{nm}._bases_ = [{', '.join(bs)}]"] if (bs:=[tname(b.type) for b in decl.get_children() if b.kind in specs]) else []),
|
||||
*([f"{nm}._methods_ = [", *ims, ']'] if ims else []), *([f"{nm}._classmethods_ = [", *cms, ']'] if cms else [])])
|
||||
return nm
|
||||
case TK.OBJCSEL: return "objc.id_"
|
||||
case TK.OBJCID: return (objc:=True, "objc.id_")[1]
|
||||
case TK.OBJCOBJECT:
|
||||
if basetype(t).kind != TK.OBJCID: raise NotImplementedError(f"generics unsupported: {t.spelling}")
|
||||
ps = [proto(p) for p in protocols(t)]
|
||||
if len(ps) == 0:
|
||||
types[t.spelling] = "objc.id_", True
|
||||
return "objc.id_"
|
||||
if len(ps) == 1:
|
||||
types[t.spelling] = ps[0], True
|
||||
return ps[0]
|
||||
types[t.spelling] = (nm:=f"_anondynamic{anoncnt()}"), True
|
||||
lines.append(f"class {nm}({', '.join(p for p in ps)}): pass # {t.spelling}")
|
||||
return nm
|
||||
case _: raise NotImplementedError(f"unsupported type {t.kind}")
|
||||
|
||||
# parses an objc @interface or @protocol, returning a list of declerations that objc.Spec can parse, for the specified kind
|
||||
# NB: ivars are unsupported
|
||||
def parse_objc_spec(decl:Cursor, nm:str, kind:CK) -> list[str]:
|
||||
nonlocal lines, types
|
||||
if decl is None: return []
|
||||
ms = []
|
||||
for d in filter(lambda d: d.kind == kind, decl.get_children()):
|
||||
rollback = lines, types
|
||||
try: ms.append(f" ('{d.spelling}', {repr('instancetype') if (rt:=d.result_type).spelling=='instancetype' else tname(rt)}, "
|
||||
f"[{', '.join('instancetype' if a.spelling == 'instancetype' else tname(a.type) for a in d.get_arguments())}]" +
|
||||
(", True" if CK.NS_RETURNS_RETAINED in attrs(d) or (any(d.spelling.startswith(s) for s in arc_families) and rt.kind!=TK.VOID) else "") + "),")
|
||||
except NotImplementedError as e:
|
||||
print(f"skipping {nm}.{d.spelling}: {e}")
|
||||
lines, types = rollback
|
||||
return ms
|
||||
|
||||
# libclang doesn't have a "type" for @protocol, so we have to do this here...
|
||||
def proto(decl):
|
||||
nonlocal lines, types
|
||||
if (nm:=decl.spelling) in types and types[nm][1]: return types[nm][0]
|
||||
# check if this is a forward declaration
|
||||
is_defn = bool([f.kind for f in decl.get_children() if f.kind in (CK.OBJC_INSTANCE_METHOD_DECL, CK.OBJC_CLASS_METHOD_DECL)])
|
||||
if nm not in types: lines.append(f"class {nm}(objc.Spec): pass")
|
||||
types[nm] = nm, is_defn
|
||||
if is_defn:
|
||||
bs = [proto(b) for b in decl.get_children() if b.kind==CK.OBJC_PROTOCOL_REF and b.spelling != decl.spelling]
|
||||
ims, cms = parse_objc_spec(decl, nm, CK.OBJC_INSTANCE_METHOD_DECL), parse_objc_spec(decl, nm, CK.OBJC_CLASS_METHOD_DECL)
|
||||
lines.extend([*([f"{nm}._bases_ = [{', '.join(bs)}]"] if bs else []),
|
||||
*([f"{nm}._methods_ = [", *ims, "]"] if ims else []), *([f"{nm}._classmethods_ = [", *cms, "]"] if cms else [])])
|
||||
return nm
|
||||
|
||||
for f in files:
|
||||
tu = Index.create().parse(f, args, options=TU.PARSE_DETAILED_PROCESSING_RECORD)
|
||||
(pp:=PP.create(tu.cursor)).set_property(PPP.TerseOutput, 1)
|
||||
@@ -90,7 +155,8 @@ def gen(dll, files, args=[], prolog=[], rules=[], epilog=[], recsym=False, use_e
|
||||
# TODO: we could support name-mangling
|
||||
lines.append(f"# {c.pretty_printed(pp)}\ntry: ({c.spelling}:=dll.{c.spelling}).restype, {c.spelling}.argtypes = "
|
||||
f"{tname(c.result_type)}, [{', '.join(tname(arg.type) for arg in c.get_arguments())}]\nexcept AttributeError: pass\n")
|
||||
case CK.STRUCT_DECL | CK.UNION_DECL | CK.TYPEDEF_DECL | CK.ENUM_DECL: tname(c.type)
|
||||
if CK.NS_RETURNS_RETAINED in attrs(c): lines.append(f"{c.spelling} = objc.returns_retained({c.spelling})")
|
||||
case CK.STRUCT_DECL | CK.UNION_DECL | CK.TYPEDEF_DECL | CK.ENUM_DECL | CK.OBJC_INTERFACE_DECL: tname(c.type)
|
||||
case CK.MACRO_DEFINITION if parse_macros and len(toks:=list(c.get_tokens())) > 1:
|
||||
if toks[1].spelling == '(' and toks[0].extent.end.column == toks[1].extent.start.column:
|
||||
it = iter(toks[1:])
|
||||
@@ -107,12 +173,13 @@ def gen(dll, files, args=[], prolog=[], rules=[], epilog=[], recsym=False, use_e
|
||||
else: macros += [f"{c.spelling} = {tname(c.type)}({readext(f, last(c).extent)})"]
|
||||
case CK.VAR_DECL if c.linkage == LK.EXTERNAL and dll:
|
||||
lines.append(f"try: {c.spelling} = {tname(c.type)}.in_dll(dll, '{c.spelling}')\nexcept (ValueError,AttributeError): pass")
|
||||
case CK.OBJC_PROTOCOL_DECL: proto(c)
|
||||
except NotImplementedError as e:
|
||||
print(f"skipping {c.spelling}: {e}")
|
||||
lines, types = rollback
|
||||
main = (f"# mypy: ignore-errors\nimport ctypes{', os' if any('os' in s for s in dll) else ''}\n"
|
||||
"from tinygrad.helpers import unwrap\nfrom tinygrad.runtime.support.c import Struct, CEnum, _IO, _IOW, _IOR, _IOWR\n" + '\n'.join([*prolog,
|
||||
*(["from ctypes.util import find_library"]*any('find_library' in s for s in dll)),
|
||||
*(["from ctypes.util import find_library"]*any('find_library' in s for s in dll)), *(["from tinygrad.runtime.support import objc"]*objc),
|
||||
*(["def dll():",*flatten([[f" try: return ctypes.CDLL(unwrap({d}){', use_errno=True' if use_errno else ''})",' except: pass'] for d in dll]),
|
||||
" return None", "dll = dll()\n"]*bool(dll)), *lines]) + '\n')
|
||||
macros = [r for m in macros if (r:=functools.reduce(lambda s,r:re.sub(r[0], r[1], s), rules + base_rules, m))]
|
||||
|
||||
70
tinygrad/runtime/support/objc.py
Normal file
70
tinygrad/runtime/support/objc.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import ctypes, ctypes.util, functools, sys
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING: id_ = ctypes.c_void_p
else:
  class id_(ctypes.c_void_p):
    """Objective-C object pointer (`id`) represented as a ctypes pointer type.

    Instances hash and compare by pointer value. An instance flagged via `retained()` sends
    `release` to the object when it is garbage collected.
    """
    # when True, __del__ sends `release` to the underlying object
    retain: bool = False
    # Using a c_void_p subclass prevents ctypes from converting the response to a plain int,
    # and hashing by pointer value lets dict.fromkeys() dedup handles
    def __hash__(self): return hash(self.value)
    def __eq__(self, other):
      # operands without a pointer value (None, plain ints, ...) compare unequal instead of raising AttributeError
      try: other_value = other.value
      except AttributeError: return NotImplemented
      return self.value == other_value
    def __del__(self):
      # skip the release during interpreter shutdown
      if self.retain and not sys.is_finalizing(): self.release()
    def release(self): msg("release")(self)
    def retained(self):
      """Mark this handle as retained and return it, so calls can be chained."""
      self.retain = True
      return self
|
||||
|
||||
def returns_retained(f):
  """Wrap `f` so that `.retained()` is called on every result before it is returned."""
  @functools.wraps(f)
  def wrapper(*args, **kwargs):
    result = f(*args, **kwargs)
    return result.retained()
  return wrapper
|
||||
|
||||
# Handle to the Objective-C runtime library, located through ctypes.util.find_library
lib = ctypes.CDLL(ctypes.util.find_library('objc'))
# both runtime lookups return id_, so ctypes hands back an id_ instance rather than a plain int
lib.objc_getClass.restype = id_
lib.sel_registerName.restype = id_
# selector lookups are memoized per selector name
getsel = functools.cache(lib.sel_registerName)
# dispatch_data_create is taken from libSystem; its results are flagged as retained (see returns_retained)
dispatch_data_create = ctypes.CDLL("/usr/lib/libSystem.dylib").dispatch_data_create
dispatch_data_create.restype = id_
dispatch_data_create = returns_retained(dispatch_data_create)
|
||||
|
||||
def msg(sel:str, restype=id_, argtypes=[], retain=False, clsmeth=False):
  """Build a callable that sends the selector `sel` through objc_msgSend.

  restype:  ctypes result type of the call.
  argtypes: ctypes types of the message arguments (receiver and selector excluded); never mutated here.
  retain:   when True, results are passed through returns_retained.
  clsmeth:  when True, the message goes to `ptr._objc_class_` instead of `ptr` itself.
  """
  # with no message arguments, argtypes is left empty and receiver/selector use ctypes' default conversion
  arg_spec = [id_, id_]+argtypes if argtypes else []
  # Item access (lib[...]) returns a new function object on every lookup, unlike attribute access which
  # is cached, so setting restype/argtypes here does not affect other senders
  sender = lib["objc_msgSend"]
  sender.restype = restype
  sender.argtypes = arg_spec
  def f(ptr, *args):
    receiver = ptr._objc_class_ if clsmeth else ptr
    return sender(receiver, getsel(sel.encode()), *args)
  if retain: return returns_retained(f)
  return f
|
||||
|
||||
if TYPE_CHECKING:
  import _ctypes
  class MetaSpec(_ctypes._PyCSimpleType):
    _objc_class_: id_
    def __getattr__(cls, nm:str) -> Any: ...
else:
  class MetaSpec(type(id_)):
    """Metaclass that turns declarative method tables into Objective-C message senders.

    Assigning `_methods_` / `_classmethods_` (lists of `(selector, restype, [argtypes], *extra)` tuples)
    or `_bases_` (a list of other spec classes) on a class installs the matching Python callables on it.
    `_children_` records the classes that inherit from a spec so later additions can be propagated.
    """
    def __new__(mcs, name, bases, dct):
      # the Objective-C class is looked up by the Python class name
      cls = super().__new__(mcs, name, bases, {'_objc_class_': lib.objc_getClass(name.encode()), '_children_': set(), **dct})
      # assigning through __setattr__ installs any methods declared in the class body
      cls._methods_, cls._classmethods_ = dct.get('_methods_', []), dct.get('_classmethods_', [])
      return cls

    def __setattr__(cls, k, v):
      super().__setattr__(k, v)
      if k in ("_methods_", "_classmethods_"):
        # the attribute NAME decides whether these are class methods; the original compared the
        # list value `v` to the string, which was always False
        for m in v: cls._addmeth(m, clsmeth=(k=="_classmethods_"))
        # classes that already inherited from this one pick up the new methods
        for c in cls._children_: c._inherit(cls)
      if k == "_bases_":
        for b in v:
          b._children_.add(cls)
          cls._inherit(b)

    def _inherit(cls, b):
      """Install on `cls` every method declared by `b` and its bases, then propagate to `cls`'s children."""
      for _b in getattr(b, "_bases_", []): cls._inherit(_b)
      for m in getattr(b, "_methods_", []): cls._addmeth(m)
      for m in getattr(b, "_classmethods_", []): cls._addmeth(m, True)
      for c in cls._children_: c._inherit(cls)

    def _addmeth(cls, m, clsmeth=False):
      """Install one method tuple `m` on `cls`, as a classmethod when `clsmeth` is True."""
      # the Python name drops leading/trailing ':' from the selector and maps inner ':' to '_'
      nm = m[0].strip(':').replace(':', '_')
      # 'instancetype' stands for the class the method is being installed on
      restype = cls if m[1] == 'instancetype' else m[1]
      argtypes = [cls if a == 'instancetype' else a for a in m[2]]
      # m[3:] forwards the optional retain flag to msg
      fn = msg(m[0], restype, argtypes, *m[3:], clsmeth=clsmeth)
      setattr(cls, nm, classmethod(fn) if clsmeth else fn)  # type: ignore[misc]
|
||||
|
||||
class Spec(id_, metaclass=MetaSpec):
  """Base class for Objective-C interface/protocol bindings; MetaSpec installs the methods."""
  if TYPE_CHECKING:
    # methods are installed dynamically, so type checkers are told any attribute may exist
    def __getattr__(self, nm:str) -> Any: ...
|
||||
Reference in New Issue
Block a user