remove arm64, caching for cuda (#2201)

* remove arm64, caching for cuda

* caching in llvm

* switch cache_compiled to new cache

* fix clang

* caching for metal

* fix pylint

* cleanups

* perf_counter and binary
This commit is contained in:
George Hotz
2023-11-01 18:44:00 -07:00
committed by GitHub
parent 7103b716c4
commit 8932816816
9 changed files with 76 additions and 125 deletions

View File

@@ -31,3 +31,9 @@ repos:
language: system language: system
always_run: true always_run: true
pass_filenames: false pass_filenames: false
- id: pylint
name: pylint
entry: python -m pylint tinygrad/
language: system
always_run: true
pass_filenames: false

View File

@@ -16,5 +16,5 @@ if __name__ == "__main__":
cur3.execute(f"SELECT * FROM {table} LIMIT 10") cur3.execute(f"SELECT * FROM {table} LIMIT 10")
for f in cur3.fetchall(): for f in cur3.fetchall():
v = pickle.loads(f[-1]) v = pickle.loads(f[-1])
print(" ", len(f[0]), f[1:-1], v) print(" ", len(f[0]) if isinstance(f[0], str) else f[0], f[1:-1], str(v)[0:50])
#print(f"{len(k):10d}, {sk} -> {v}") #print(f"{len(k):10d}, {sk} -> {v}")

View File

@@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, pathlib, hashlib, tempfile, pickle, sqlite3 import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3
import numpy as np import numpy as np
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10 if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
@@ -156,12 +156,11 @@ class GlobalCounters:
# *** compiled cache decorator *** # *** compiled cache decorator ***
def cache_compiled(func): def cache_compiled(func):
if getenv("DISABLE_COMPILER_CACHE"): return func
def wrapper(self, prg:str, *args, **kwargs) -> bytes: def wrapper(self, prg:str, *args, **kwargs) -> bytes:
cache_path, output_file = pathlib.Path(f"{tempfile.gettempdir()}/tinygrad_cc_{hashlib.sha256(prg.encode()).hexdigest()}"), pathlib.Path(tempfile.mktemp()) table, key = f"compiler_cache_{type(self).__name__}", hashlib.sha256(prg.encode()).hexdigest()
if not cache_path.exists(): if (ret:=diskcache_get(table, key)): return ret
output_file.write_bytes(func(self, prg, *args, **kwargs)) return diskcache_put(table, key, func(self, prg, *args, **kwargs))
output_file.rename(cache_path)
return cache_path.read_bytes()
return wrapper return wrapper
# *** universal database cache *** # *** universal database cache ***

View File

@@ -1,85 +1,38 @@
import time, ctypes, subprocess, platform, functools, pathlib, tempfile import time, ctypes, subprocess, platform, functools, pathlib, tempfile
from typing import Any from typing import Any
from functools import partial, reduce
from tinygrad.ops import Compiled from tinygrad.ops import Compiled
from tinygrad.helpers import fromimport, getenv, DEBUG, CI, cache_compiled from tinygrad.helpers import cache_compiled
from tinygrad.runtime.lib import RawMallocBuffer from tinygrad.runtime.lib import RawMallocBuffer
from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
import struct
import numpy as np
ARM64 = getenv('ARM64', False)
if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const # type: ignore
args = { args = {
'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport) '}, 'Linux': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'so'},
'Linux': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'so', 'exp':''}, 'Darwin': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'dylib'}
'Darwin': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'dylib', 'exp':''}
}[platform.system()] }[platform.system()]
CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include <stdbool.h>\n' CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include <stdbool.h>\n'
ADDRESS = 0x10000
# Unicorn doesn't support external calls
def align(addr): return (addr+4095) & ~(4095)
mock_lm = {"sinf": np.sin, "sqrtf": np.sqrt, "exp2f": np.exp2, "log2f": np.log2}
def emulate_ext_calls(fn, uc, address, size, user_data):
s_in = struct.unpack('f', struct.pack('I', uc.reg_read(getattr(arm64_const, f'UC_ARM64_REG_S{fn[2][1:]}'))))[0]
uc.reg_write(getattr(arm64_const, f'UC_ARM64_REG_S{fn[1][1:]}'), struct.unpack('I', struct.pack('f', mock_lm[fn[0]](s_in)))[0]) # type: ignore
class ClangProgram: class ClangProgram:
def __init__(self, name:str, prg:str, binary:bool=False): def __init__(self, name:str, prg:str, binary=False):
if binary and DEBUG >= 5: print(prg) self.prg: bytes = prg if binary else self.compile(CLANG_PROGRAM_HEADER+prg)
self.prg: Any = self.compile(prg if binary else CLANG_PROGRAM_HEADER+prg, binary)
# TODO: is there a way to not write this to disk? # write to disk so we can load it
# A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file with tempfile.NamedTemporaryFile(delete=True) as cached_file_path:
# because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file pathlib.Path(cached_file_path.name).write_bytes(self.prg)
if not (CI and ARM64): self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name]
with tempfile.NamedTemporaryFile(delete=True) as cached_file_path:
pathlib.Path(cached_file_path.name).write_bytes(self.prg)
self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name]
@cache_compiled @cache_compiled
def compile(self, prg, binary) -> bytes: def compile(self, prg) -> bytes:
with tempfile.NamedTemporaryFile(delete=True) as output_file, tempfile.NamedTemporaryFile(delete=True) as temp_file: # TODO: sadly clang doesn't like the use of /dev/stdout here
if not binary: with tempfile.NamedTemporaryFile(delete=True) as output_file:
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file.name)).split(), input=prg.encode('utf-8')) subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file.name)).split(), input=prg.encode('utf-8'))
elif CI and ARM64:
prg = prg.split('\n') # type: ignore
self.varsize = align(int(prg[0].split(" ")[1]))
self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'}
prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n'])
subprocess.check_output(args=('aarch64-linux-gnu-as -o '+str(temp_file.name)).split(), input=prg.encode('utf-8'))
subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+str(temp_file.name)+' '+str(output_file.name)).split())
else:
subprocess.check_output(args=('as -o' + str(temp_file.name)).split(), input=prg.encode('utf-8'))
subprocess.check_output(args=('clang -lm -shared '+str(temp_file.name)+' -o'+str(output_file.name)).split())
return pathlib.Path(output_file.name).read_bytes() return pathlib.Path(output_file.name).read_bytes()
def __call__(self, global_size, local_size, *args, wait=False): def __call__(self, unused_global_size, unused_local_size, *args, wait=False):
if wait: st = time.monotonic() if wait: st = time.perf_counter()
if CI and ARM64: self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
mu = Uc(UC_ARCH_ARM64, UC_MODE_ARM) if wait: return time.perf_counter()-st
total_mem = align(reduce(lambda total, arg: total + arg.size * arg.dtype.itemsize, args, len(self.prg)+self.varsize))
mu.mem_map(ADDRESS, total_mem)
for k, fn in self.ext_calls.items(): mu.hook_add(UC_HOOK_CODE, partial(emulate_ext_calls, fn), begin=k, end=k)
mu.mem_write(ADDRESS, self.prg + b''.join(bytes(arg._buf) for arg in args))
addr = ADDRESS + len(self.prg)
for i, arg in enumerate(args):
if i<=7:
mu.reg_write(getattr(arm64_const, f'UC_ARM64_REG_X{i}'), addr)
else:
# NOTE: In ARM, args beyond the first 8 are placed on the stack it also account for the stack red zone.
mu.mem_write(ADDRESS + total_mem - (len(args[8:])+2)*8 + 8*(i-8), addr.to_bytes(8, 'little'))
addr += arg.size * arg.dtype.itemsize
mu.reg_write(arm64_const.UC_ARM64_REG_SP, ADDRESS + total_mem - (len(args[8:])+2)*8)
mu.emu_start(ADDRESS, ADDRESS + len(self.prg))
args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize)
else:
self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
if wait: return time.monotonic()-st
renderer = fromimport("tinygrad.renderer.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict", arg_int_prefix="const int")) renderer = functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict", arg_int_prefix="const int"))
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram) ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)

View File

@@ -3,7 +3,7 @@ from pathlib import Path
from typing import Optional, List, Any, Tuple from typing import Optional, List, Any, Tuple
import numpy as np import numpy as np
from pycuda.compiler import compile as cuda_compile # type: ignore from pycuda.compiler import compile as cuda_compile # type: ignore
from tinygrad.helpers import DEBUG, getenv, colored from tinygrad.helpers import DEBUG, getenv, colored, cache_compiled
from tinygrad.ops import Compiled, GraphBatchExecutor, ASTRunner from tinygrad.ops import Compiled, GraphBatchExecutor, ASTRunner
from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.codegen.kernel import LinearizerOptions
@@ -89,12 +89,8 @@ class CUDAGraph(GraphBatchExecutor):
def exec_instance(self, instid): self.graphs[instid][0].launch() def exec_instance(self, instid): self.graphs[instid][0].launch()
class CUDAProgram: class CUDAProgram:
def __init__(self, name:str, prg:str, binary=False, shared = 0, local_size_override=None): def __init__(self, name:str, prg:str, binary=False, shared=0, local_size_override=None):
if not binary: if not binary: prg = self.compile(prg).decode('utf-8')
try: prg = cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']).decode('utf-8')
except cuda.CompileError as e:
if DEBUG >= 3: print("FAILED TO BUILD", prg)
raise e
if DEBUG >= 5: print(pretty_ptx(prg)) if DEBUG >= 5: print(pretty_ptx(prg))
if DEBUG >= 6: if DEBUG >= 6:
try: try:
@@ -106,6 +102,10 @@ class CUDAProgram:
# TODO: name is wrong, so we get it from the ptx using hacks # TODO: name is wrong, so we get it from the ptx using hacks
self.prg, self.shared, self.local_size_override = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0]), shared, local_size_override self.prg, self.shared, self.local_size_override = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0]), shared, local_size_override
@cache_compiled
def compile(self, prg) -> bytes:
return cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets'])
def __call__(self, global_size, local_size, *args, wait=False): def __call__(self, global_size, local_size, *args, wait=False):
if wait: if wait:
start, end = cuda.Event(), cuda.Event() start, end = cuda.Event(), cuda.Event()

View File

@@ -64,11 +64,7 @@ class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
class CLProgram: class CLProgram:
def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None): def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None):
self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) if binary else cl.Program(ctx, prg) for ctx in CL.cl_ctxs] # type: ignore self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) if binary else cl.Program(ctx, prg) for ctx in CL.cl_ctxs] # type: ignore
try: self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms]
self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms]
except cl.RuntimeError as e:
if DEBUG >= 3: print("FAILED TO BUILD", prg)
raise e
self.clprgs = [clprg.__getattr__(name) for clprg in self._clprgs] self.clprgs = [clprg.__getattr__(name) for clprg in self._clprgs]
if DEBUG >= 5 and not OSX: if DEBUG >= 5 and not OSX:
if 'Adreno' in CL.cl_ctxs[0].devices[0].name: if 'Adreno' in CL.cl_ctxs[0].devices[0].name:

View File

@@ -94,13 +94,9 @@ class HIPProgram:
@cache_compiled @cache_compiled
def compile(self, prg, name) -> bytes: def compile(self, prg, name) -> bytes:
try: prog = hip.hiprtcCreateProgram(prg, name, [], [])
prog = hip.hiprtcCreateProgram(prg, name, [], []) hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}'])
hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}']) return hip.hiprtcGetCode(prog)
return hip.hiprtcGetCode(prog)
except Exception as e:
if DEBUG >= 3: print("FAILED TO BUILD", prg)
raise e
def __call__(self, global_size, local_size, *args, wait=False): def __call__(self, global_size, local_size, *args, wait=False):
hip.hipSetDevice(args[0]._device) hip.hipSetDevice(args[0]._device)

View File

@@ -1,7 +1,7 @@
import time, hashlib, ctypes import time, ctypes
from typing import ClassVar from typing import ClassVar
from tinygrad.ops import Compiled from tinygrad.ops import Compiled
from tinygrad.helpers import getenv, DEBUG from tinygrad.helpers import getenv, DEBUG, cache_compiled
from ctypes import CFUNCTYPE from ctypes import CFUNCTYPE
from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.llvmir import uops_to_llvm_ir from tinygrad.renderer.llvmir import uops_to_llvm_ir
@@ -46,22 +46,22 @@ class LLVM:
class LLVMProgram: class LLVMProgram:
def __init__(self, name:str, prg:str, binary=False): def __init__(self, name:str, prg:str, binary=False):
self.mod = llvm.parse_assembly(prg) self.prg = prg if binary else self.compile(prg)
self.mod.verify() LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(self.prg))
LLVM().optimizer.run(self.mod)
self.mod.name = hashlib.sha1(prg.encode('utf-8')).hexdigest()
if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(self.mod))
LLVM.engine.add_module(self.mod)
LLVM.engine.finalize_object()
self.fxn = LLVM.engine.get_function_address(name) self.fxn = LLVM.engine.get_function_address(name)
def __del__(self): @cache_compiled
if hasattr(self, 'mod'): LLVM.engine.remove_module(self.mod) def compile(self, prg) -> bytes:
mod = llvm.parse_assembly(prg)
mod.verify()
LLVM().optimizer.run(mod)
if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(mod))
return LLVM.target_machine.emit_object(mod)
def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False): def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn) cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)
if wait: st = time.monotonic() if wait: st = time.perf_counter()
cfunc(*[x._buf if not isinstance(x, int) else x for x in bufs]) cfunc(*[x._buf if not isinstance(x, int) else x for x in bufs])
if wait: return time.monotonic()-st if wait: return time.perf_counter()-st
LLVMBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, LLVMProgram) LLVMBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, LLVMProgram)

View File

@@ -1,15 +1,13 @@
# pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch # pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
import os, subprocess, pathlib, ctypes import os, subprocess, pathlib, ctypes, tempfile
import Metal, Cocoa, libdispatch # type: ignore import Metal, Cocoa, libdispatch # type: ignore
from typing import List, Any, Tuple from typing import List, Any, Tuple
from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, cache_compiled
from tinygrad.ops import Compiled, ASTRunner, BasicBatchExecutor from tinygrad.ops import Compiled, ASTRunner, BasicBatchExecutor
from tinygrad.renderer.metal import MetalRenderer from tinygrad.renderer.metal import MetalRenderer
from tinygrad.runtime.lib import RawBufferMapped, LRUAllocator from tinygrad.runtime.lib import RawBufferMapped, LRUAllocator
METAL_XCODE = getenv("METAL_XCODE")
class MetalAllocator(LRUAllocator): class MetalAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared) def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
def _do_free(self, buf): buf.release() def _do_free(self, buf): buf.release()
@@ -62,27 +60,30 @@ def unwrap(x):
class MetalProgram: class MetalProgram:
def __init__(self, name:str, prg:str, binary:bool=False): def __init__(self, name:str, prg:str, binary:bool=False):
if METAL_XCODE: lib = prg if binary else self.compile(prg)
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8')) data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None))
lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None))
else:
options = Metal.MTLCompileOptions.alloc().init()
self.library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
self.fxn = self.library.newFunctionWithName_(name) self.fxn = self.library.newFunctionWithName_(name)
# hacks to disassemble shader
if DEBUG >= 5: if DEBUG >= 5:
arc = unwrap(METAL.device.newBinaryArchiveWithDescriptor_error_(Metal.MTLBinaryArchiveDescriptor.alloc().init(), None)) with tempfile.NamedTemporaryFile(delete=True) as shader:
desc = Metal.MTLComputePipelineDescriptor.alloc().init() shader.write(lib)
desc.setComputeFunction_(self.fxn) shader.flush()
unwrap(arc.addComputePipelineFunctionsWithDescriptor_error_(desc, None)) os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
unwrap(arc.serializeToURL_error_(Cocoa.NSURL.URLWithString_("file:///tmp/shader.bin"), None))
# clone https://github.com/dougallj/applegpu.git in tinygrad/disassemblers
os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py /tmp/shader.bin")
self.pipeline_state = unwrap(METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None)) self.pipeline_state = unwrap(METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None))
@cache_compiled
def compile(self, prg) -> bytes:
if getenv("METAL_XCODE"):
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
options = Metal.MTLCompileOptions.alloc().init()
library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
# TODO: avoid file write here?
with tempfile.NamedTemporaryFile(delete=True) as output_file:
library.serializeToURL_error_(Cocoa.NSURL.URLWithString_(f"file://{output_file.name}"), None)
return pathlib.Path(output_file.name).read_bytes()
def __call__(self, global_size, local_size, *bufs, wait=False): def __call__(self, global_size, local_size, *bufs, wait=False):
assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}" assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
command_buffer = METAL.mtl_queue.commandBuffer() command_buffer = METAL.mtl_queue.commandBuffer()