remove arm64, caching for cuda (#2201)

* remove arm64, caching for cuda

* caching in llvm

* switch cache_compiled to new cache

* fix clang

* caching for metal

* fix pylint

* cleanups

* perf_counter and binary
This commit is contained in:
George Hotz
2023-11-01 18:44:00 -07:00
committed by GitHub
parent 7103b716c4
commit 8932816816
9 changed files with 76 additions and 125 deletions

View File

@@ -31,3 +31,9 @@ repos:
language: system
always_run: true
pass_filenames: false
- id: pylint
name: pylint
entry: python -m pylint tinygrad/
language: system
always_run: true
pass_filenames: false

View File

@@ -16,5 +16,5 @@ if __name__ == "__main__":
cur3.execute(f"SELECT * FROM {table} LIMIT 10")
for f in cur3.fetchall():
v = pickle.loads(f[-1])
print(" ", len(f[0]), f[1:-1], v)
print(" ", len(f[0]) if isinstance(f[0], str) else f[0], f[1:-1], str(v)[0:50])
#print(f"{len(k):10d}, {sk} -> {v}")

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, pathlib, hashlib, tempfile, pickle, sqlite3
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3
import numpy as np
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
@@ -156,12 +156,11 @@ class GlobalCounters:
# *** compiled cache decorator ***
def cache_compiled(func):
if getenv("DISABLE_COMPILER_CACHE"): return func
def wrapper(self, prg:str, *args, **kwargs) -> bytes:
cache_path, output_file = pathlib.Path(f"{tempfile.gettempdir()}/tinygrad_cc_{hashlib.sha256(prg.encode()).hexdigest()}"), pathlib.Path(tempfile.mktemp())
if not cache_path.exists():
output_file.write_bytes(func(self, prg, *args, **kwargs))
output_file.rename(cache_path)
return cache_path.read_bytes()
table, key = f"compiler_cache_{type(self).__name__}", hashlib.sha256(prg.encode()).hexdigest()
if (ret:=diskcache_get(table, key)): return ret
return diskcache_put(table, key, func(self, prg, *args, **kwargs))
return wrapper
# *** universal database cache ***

View File

@@ -1,85 +1,38 @@
import time, ctypes, subprocess, platform, functools, pathlib, tempfile
from typing import Any
from functools import partial, reduce
from tinygrad.ops import Compiled
from tinygrad.helpers import fromimport, getenv, DEBUG, CI, cache_compiled
from tinygrad.helpers import cache_compiled
from tinygrad.runtime.lib import RawMallocBuffer
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
import struct
import numpy as np
ARM64 = getenv('ARM64', False)
if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const # type: ignore
args = {
'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport) '},
'Linux': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'so', 'exp':''},
'Darwin': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'dylib', 'exp':''}
'Linux': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'so'},
'Darwin': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'dylib'}
}[platform.system()]
CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include <stdbool.h>\n'
ADDRESS = 0x10000
# Unicorn doesn't support external calls
def align(addr): return (addr+4095) & ~(4095)
mock_lm = {"sinf": np.sin, "sqrtf": np.sqrt, "exp2f": np.exp2, "log2f": np.log2}
def emulate_ext_calls(fn, uc, address, size, user_data):
s_in = struct.unpack('f', struct.pack('I', uc.reg_read(getattr(arm64_const, f'UC_ARM64_REG_S{fn[2][1:]}'))))[0]
uc.reg_write(getattr(arm64_const, f'UC_ARM64_REG_S{fn[1][1:]}'), struct.unpack('I', struct.pack('f', mock_lm[fn[0]](s_in)))[0]) # type: ignore
class ClangProgram:
def __init__(self, name:str, prg:str, binary:bool=False):
if binary and DEBUG >= 5: print(prg)
self.prg: Any = self.compile(prg if binary else CLANG_PROGRAM_HEADER+prg, binary)
def __init__(self, name:str, prg:str, binary=False):
self.prg: bytes = prg if binary else self.compile(CLANG_PROGRAM_HEADER+prg)
# TODO: is there a way to not write this to disk?
# A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file
# because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file
if not (CI and ARM64):
with tempfile.NamedTemporaryFile(delete=True) as cached_file_path:
pathlib.Path(cached_file_path.name).write_bytes(self.prg)
self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name]
# write to disk so we can load it
with tempfile.NamedTemporaryFile(delete=True) as cached_file_path:
pathlib.Path(cached_file_path.name).write_bytes(self.prg)
self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name]
@cache_compiled
def compile(self, prg, binary) -> bytes:
with tempfile.NamedTemporaryFile(delete=True) as output_file, tempfile.NamedTemporaryFile(delete=True) as temp_file:
if not binary:
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file.name)).split(), input=prg.encode('utf-8'))
elif CI and ARM64:
prg = prg.split('\n') # type: ignore
self.varsize = align(int(prg[0].split(" ")[1]))
self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'}
prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n'])
subprocess.check_output(args=('aarch64-linux-gnu-as -o '+str(temp_file.name)).split(), input=prg.encode('utf-8'))
subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+str(temp_file.name)+' '+str(output_file.name)).split())
else:
subprocess.check_output(args=('as -o' + str(temp_file.name)).split(), input=prg.encode('utf-8'))
subprocess.check_output(args=('clang -lm -shared '+str(temp_file.name)+' -o'+str(output_file.name)).split())
def compile(self, prg) -> bytes:
# TODO: sadly clang doesn't like the use of /dev/stdout here
with tempfile.NamedTemporaryFile(delete=True) as output_file:
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file.name)).split(), input=prg.encode('utf-8'))
return pathlib.Path(output_file.name).read_bytes()
def __call__(self, global_size, local_size, *args, wait=False):
if wait: st = time.monotonic()
if CI and ARM64:
mu = Uc(UC_ARCH_ARM64, UC_MODE_ARM)
total_mem = align(reduce(lambda total, arg: total + arg.size * arg.dtype.itemsize, args, len(self.prg)+self.varsize))
mu.mem_map(ADDRESS, total_mem)
for k, fn in self.ext_calls.items(): mu.hook_add(UC_HOOK_CODE, partial(emulate_ext_calls, fn), begin=k, end=k)
mu.mem_write(ADDRESS, self.prg + b''.join(bytes(arg._buf) for arg in args))
addr = ADDRESS + len(self.prg)
for i, arg in enumerate(args):
if i<=7:
mu.reg_write(getattr(arm64_const, f'UC_ARM64_REG_X{i}'), addr)
else:
# NOTE: In ARM, args beyond the first 8 are placed on the stack it also account for the stack red zone.
mu.mem_write(ADDRESS + total_mem - (len(args[8:])+2)*8 + 8*(i-8), addr.to_bytes(8, 'little'))
addr += arg.size * arg.dtype.itemsize
mu.reg_write(arm64_const.UC_ARM64_REG_SP, ADDRESS + total_mem - (len(args[8:])+2)*8)
mu.emu_start(ADDRESS, ADDRESS + len(self.prg))
args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize)
else:
self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
if wait: return time.monotonic()-st
def __call__(self, unused_global_size, unused_local_size, *args, wait=False):
if wait: st = time.perf_counter()
self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
if wait: return time.perf_counter()-st
renderer = fromimport("tinygrad.renderer.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict", arg_int_prefix="const int"))
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict", arg_int_prefix="const int"))
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)

View File

@@ -3,7 +3,7 @@ from pathlib import Path
from typing import Optional, List, Any, Tuple
import numpy as np
from pycuda.compiler import compile as cuda_compile # type: ignore
from tinygrad.helpers import DEBUG, getenv, colored
from tinygrad.helpers import DEBUG, getenv, colored, cache_compiled
from tinygrad.ops import Compiled, GraphBatchExecutor, ASTRunner
from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
from tinygrad.codegen.kernel import LinearizerOptions
@@ -89,12 +89,8 @@ class CUDAGraph(GraphBatchExecutor):
def exec_instance(self, instid): self.graphs[instid][0].launch()
class CUDAProgram:
def __init__(self, name:str, prg:str, binary=False, shared = 0, local_size_override=None):
if not binary:
try: prg = cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']).decode('utf-8')
except cuda.CompileError as e:
if DEBUG >= 3: print("FAILED TO BUILD", prg)
raise e
def __init__(self, name:str, prg:str, binary=False, shared=0, local_size_override=None):
if not binary: prg = self.compile(prg).decode('utf-8')
if DEBUG >= 5: print(pretty_ptx(prg))
if DEBUG >= 6:
try:
@@ -106,6 +102,10 @@ class CUDAProgram:
# TODO: name is wrong, so we get it from the ptx using hacks
self.prg, self.shared, self.local_size_override = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0]), shared, local_size_override
@cache_compiled
def compile(self, prg) -> bytes:
return cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets'])
def __call__(self, global_size, local_size, *args, wait=False):
if wait:
start, end = cuda.Event(), cuda.Event()

View File

@@ -64,11 +64,7 @@ class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
class CLProgram:
def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None):
self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) if binary else cl.Program(ctx, prg) for ctx in CL.cl_ctxs] # type: ignore
try:
self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms]
except cl.RuntimeError as e:
if DEBUG >= 3: print("FAILED TO BUILD", prg)
raise e
self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms]
self.clprgs = [clprg.__getattr__(name) for clprg in self._clprgs]
if DEBUG >= 5 and not OSX:
if 'Adreno' in CL.cl_ctxs[0].devices[0].name:

View File

@@ -94,13 +94,9 @@ class HIPProgram:
@cache_compiled
def compile(self, prg, name) -> bytes:
try:
prog = hip.hiprtcCreateProgram(prg, name, [], [])
hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}'])
return hip.hiprtcGetCode(prog)
except Exception as e:
if DEBUG >= 3: print("FAILED TO BUILD", prg)
raise e
prog = hip.hiprtcCreateProgram(prg, name, [], [])
hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}'])
return hip.hiprtcGetCode(prog)
def __call__(self, global_size, local_size, *args, wait=False):
hip.hipSetDevice(args[0]._device)

View File

@@ -1,7 +1,7 @@
import time, hashlib, ctypes
import time, ctypes
from typing import ClassVar
from tinygrad.ops import Compiled
from tinygrad.helpers import getenv, DEBUG
from tinygrad.helpers import getenv, DEBUG, cache_compiled
from ctypes import CFUNCTYPE
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.llvmir import uops_to_llvm_ir
@@ -46,22 +46,22 @@ class LLVM:
class LLVMProgram:
def __init__(self, name:str, prg:str, binary=False):
self.mod = llvm.parse_assembly(prg)
self.mod.verify()
LLVM().optimizer.run(self.mod)
self.mod.name = hashlib.sha1(prg.encode('utf-8')).hexdigest()
if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(self.mod))
LLVM.engine.add_module(self.mod)
LLVM.engine.finalize_object()
self.prg = prg if binary else self.compile(prg)
LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(self.prg))
self.fxn = LLVM.engine.get_function_address(name)
def __del__(self):
if hasattr(self, 'mod'): LLVM.engine.remove_module(self.mod)
@cache_compiled
def compile(self, prg) -> bytes:
mod = llvm.parse_assembly(prg)
mod.verify()
LLVM().optimizer.run(mod)
if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(mod))
return LLVM.target_machine.emit_object(mod)
def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)
if wait: st = time.monotonic()
if wait: st = time.perf_counter()
cfunc(*[x._buf if not isinstance(x, int) else x for x in bufs])
if wait: return time.monotonic()-st
if wait: return time.perf_counter()-st
LLVMBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, LLVMProgram)

View File

@@ -1,15 +1,13 @@
# pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
import os, subprocess, pathlib, ctypes
import os, subprocess, pathlib, ctypes, tempfile
import Metal, Cocoa, libdispatch # type: ignore
from typing import List, Any, Tuple
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, cache_compiled
from tinygrad.ops import Compiled, ASTRunner, BasicBatchExecutor
from tinygrad.renderer.metal import MetalRenderer
from tinygrad.runtime.lib import RawBufferMapped, LRUAllocator
METAL_XCODE = getenv("METAL_XCODE")
class MetalAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
def _do_free(self, buf): buf.release()
@@ -62,27 +60,30 @@ def unwrap(x):
class MetalProgram:
def __init__(self, name:str, prg:str, binary:bool=False):
if METAL_XCODE:
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None))
else:
options = Metal.MTLCompileOptions.alloc().init()
self.library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
lib = prg if binary else self.compile(prg)
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None))
self.fxn = self.library.newFunctionWithName_(name)
# hacks to disassemble shader
if DEBUG >= 5:
arc = unwrap(METAL.device.newBinaryArchiveWithDescriptor_error_(Metal.MTLBinaryArchiveDescriptor.alloc().init(), None))
desc = Metal.MTLComputePipelineDescriptor.alloc().init()
desc.setComputeFunction_(self.fxn)
unwrap(arc.addComputePipelineFunctionsWithDescriptor_error_(desc, None))
unwrap(arc.serializeToURL_error_(Cocoa.NSURL.URLWithString_("file:///tmp/shader.bin"), None))
# clone https://github.com/dougallj/applegpu.git in tinygrad/disassemblers
os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py /tmp/shader.bin")
with tempfile.NamedTemporaryFile(delete=True) as shader:
shader.write(lib)
shader.flush()
os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
self.pipeline_state = unwrap(METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None))
@cache_compiled
def compile(self, prg) -> bytes:
if getenv("METAL_XCODE"):
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
options = Metal.MTLCompileOptions.alloc().init()
library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
# TODO: avoid file write here?
with tempfile.NamedTemporaryFile(delete=True) as output_file:
library.serializeToURL_error_(Cocoa.NSURL.URLWithString_(f"file://{output_file.name}"), None)
return pathlib.Path(output_file.name).read_bytes()
def __call__(self, global_size, local_size, *bufs, wait=False):
assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
command_buffer = METAL.mtl_queue.commandBuffer()