mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 06:34:03 -05:00
remove arm64, caching for cuda (#2201)
* remove arm64, caching for cuda * caching in llvm * switch cache_compiled to new cache * fix clang * caching for metal * fix pylint * cleanups * perf_counter and binary
This commit is contained in:
@@ -31,3 +31,9 @@ repos:
|
||||
language: system
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
- id: pylint
|
||||
name: pylint
|
||||
entry: python -m pylint tinygrad/
|
||||
language: system
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
|
||||
@@ -16,5 +16,5 @@ if __name__ == "__main__":
|
||||
cur3.execute(f"SELECT * FROM {table} LIMIT 10")
|
||||
for f in cur3.fetchall():
|
||||
v = pickle.loads(f[-1])
|
||||
print(" ", len(f[0]), f[1:-1], v)
|
||||
print(" ", len(f[0]) if isinstance(f[0], str) else f[0], f[1:-1], str(v)[0:50])
|
||||
#print(f"{len(k):10d}, {sk} -> {v}")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import os, functools, platform, time, re, contextlib, operator, pathlib, hashlib, tempfile, pickle, sqlite3
|
||||
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3
|
||||
import numpy as np
|
||||
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING
|
||||
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
|
||||
@@ -156,12 +156,11 @@ class GlobalCounters:
|
||||
# *** compiled cache decorator ***
|
||||
|
||||
def cache_compiled(func):
|
||||
if getenv("DISABLE_COMPILER_CACHE"): return func
|
||||
def wrapper(self, prg:str, *args, **kwargs) -> bytes:
|
||||
cache_path, output_file = pathlib.Path(f"{tempfile.gettempdir()}/tinygrad_cc_{hashlib.sha256(prg.encode()).hexdigest()}"), pathlib.Path(tempfile.mktemp())
|
||||
if not cache_path.exists():
|
||||
output_file.write_bytes(func(self, prg, *args, **kwargs))
|
||||
output_file.rename(cache_path)
|
||||
return cache_path.read_bytes()
|
||||
table, key = f"compiler_cache_{type(self).__name__}", hashlib.sha256(prg.encode()).hexdigest()
|
||||
if (ret:=diskcache_get(table, key)): return ret
|
||||
return diskcache_put(table, key, func(self, prg, *args, **kwargs))
|
||||
return wrapper
|
||||
|
||||
# *** universal database cache ***
|
||||
|
||||
@@ -1,85 +1,38 @@
|
||||
import time, ctypes, subprocess, platform, functools, pathlib, tempfile
|
||||
from typing import Any
|
||||
from functools import partial, reduce
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.helpers import fromimport, getenv, DEBUG, CI, cache_compiled
|
||||
from tinygrad.helpers import cache_compiled
|
||||
from tinygrad.runtime.lib import RawMallocBuffer
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
|
||||
import struct
|
||||
import numpy as np
|
||||
|
||||
ARM64 = getenv('ARM64', False)
|
||||
if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const # type: ignore
|
||||
|
||||
args = {
|
||||
'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport) '},
|
||||
'Linux': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'so', 'exp':''},
|
||||
'Darwin': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'dylib', 'exp':''}
|
||||
'Linux': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'so'},
|
||||
'Darwin': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'dylib'}
|
||||
}[platform.system()]
|
||||
|
||||
CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include <stdbool.h>\n'
|
||||
ADDRESS = 0x10000
|
||||
|
||||
# Unicorn doesn't support external calls
|
||||
def align(addr): return (addr+4095) & ~(4095)
|
||||
mock_lm = {"sinf": np.sin, "sqrtf": np.sqrt, "exp2f": np.exp2, "log2f": np.log2}
|
||||
def emulate_ext_calls(fn, uc, address, size, user_data):
|
||||
s_in = struct.unpack('f', struct.pack('I', uc.reg_read(getattr(arm64_const, f'UC_ARM64_REG_S{fn[2][1:]}'))))[0]
|
||||
uc.reg_write(getattr(arm64_const, f'UC_ARM64_REG_S{fn[1][1:]}'), struct.unpack('I', struct.pack('f', mock_lm[fn[0]](s_in)))[0]) # type: ignore
|
||||
|
||||
class ClangProgram:
|
||||
def __init__(self, name:str, prg:str, binary:bool=False):
|
||||
if binary and DEBUG >= 5: print(prg)
|
||||
self.prg: Any = self.compile(prg if binary else CLANG_PROGRAM_HEADER+prg, binary)
|
||||
def __init__(self, name:str, prg:str, binary=False):
|
||||
self.prg: bytes = prg if binary else self.compile(CLANG_PROGRAM_HEADER+prg)
|
||||
|
||||
# TODO: is there a way to not write this to disk?
|
||||
# A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file
|
||||
# because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file
|
||||
if not (CI and ARM64):
|
||||
with tempfile.NamedTemporaryFile(delete=True) as cached_file_path:
|
||||
pathlib.Path(cached_file_path.name).write_bytes(self.prg)
|
||||
self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name]
|
||||
# write to disk so we can load it
|
||||
with tempfile.NamedTemporaryFile(delete=True) as cached_file_path:
|
||||
pathlib.Path(cached_file_path.name).write_bytes(self.prg)
|
||||
self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name]
|
||||
|
||||
@cache_compiled
|
||||
def compile(self, prg, binary) -> bytes:
|
||||
with tempfile.NamedTemporaryFile(delete=True) as output_file, tempfile.NamedTemporaryFile(delete=True) as temp_file:
|
||||
if not binary:
|
||||
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file.name)).split(), input=prg.encode('utf-8'))
|
||||
elif CI and ARM64:
|
||||
prg = prg.split('\n') # type: ignore
|
||||
self.varsize = align(int(prg[0].split(" ")[1]))
|
||||
self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'}
|
||||
prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n'])
|
||||
subprocess.check_output(args=('aarch64-linux-gnu-as -o '+str(temp_file.name)).split(), input=prg.encode('utf-8'))
|
||||
subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+str(temp_file.name)+' '+str(output_file.name)).split())
|
||||
else:
|
||||
subprocess.check_output(args=('as -o' + str(temp_file.name)).split(), input=prg.encode('utf-8'))
|
||||
subprocess.check_output(args=('clang -lm -shared '+str(temp_file.name)+' -o'+str(output_file.name)).split())
|
||||
def compile(self, prg) -> bytes:
|
||||
# TODO: sadly clang doesn't like the use of /dev/stdout here
|
||||
with tempfile.NamedTemporaryFile(delete=True) as output_file:
|
||||
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file.name)).split(), input=prg.encode('utf-8'))
|
||||
return pathlib.Path(output_file.name).read_bytes()
|
||||
|
||||
def __call__(self, global_size, local_size, *args, wait=False):
|
||||
if wait: st = time.monotonic()
|
||||
if CI and ARM64:
|
||||
mu = Uc(UC_ARCH_ARM64, UC_MODE_ARM)
|
||||
total_mem = align(reduce(lambda total, arg: total + arg.size * arg.dtype.itemsize, args, len(self.prg)+self.varsize))
|
||||
mu.mem_map(ADDRESS, total_mem)
|
||||
for k, fn in self.ext_calls.items(): mu.hook_add(UC_HOOK_CODE, partial(emulate_ext_calls, fn), begin=k, end=k)
|
||||
mu.mem_write(ADDRESS, self.prg + b''.join(bytes(arg._buf) for arg in args))
|
||||
addr = ADDRESS + len(self.prg)
|
||||
for i, arg in enumerate(args):
|
||||
if i<=7:
|
||||
mu.reg_write(getattr(arm64_const, f'UC_ARM64_REG_X{i}'), addr)
|
||||
else:
|
||||
# NOTE: In ARM, args beyond the first 8 are placed on the stack it also account for the stack red zone.
|
||||
mu.mem_write(ADDRESS + total_mem - (len(args[8:])+2)*8 + 8*(i-8), addr.to_bytes(8, 'little'))
|
||||
addr += arg.size * arg.dtype.itemsize
|
||||
mu.reg_write(arm64_const.UC_ARM64_REG_SP, ADDRESS + total_mem - (len(args[8:])+2)*8)
|
||||
mu.emu_start(ADDRESS, ADDRESS + len(self.prg))
|
||||
args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize)
|
||||
else:
|
||||
self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
|
||||
if wait: return time.monotonic()-st
|
||||
def __call__(self, unused_global_size, unused_local_size, *args, wait=False):
|
||||
if wait: st = time.perf_counter()
|
||||
self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
|
||||
if wait: return time.perf_counter()-st
|
||||
|
||||
renderer = fromimport("tinygrad.renderer.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict", arg_int_prefix="const int"))
|
||||
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict", arg_int_prefix="const int"))
|
||||
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)
|
||||
|
||||
@@ -3,7 +3,7 @@ from pathlib import Path
|
||||
from typing import Optional, List, Any, Tuple
|
||||
import numpy as np
|
||||
from pycuda.compiler import compile as cuda_compile # type: ignore
|
||||
from tinygrad.helpers import DEBUG, getenv, colored
|
||||
from tinygrad.helpers import DEBUG, getenv, colored, cache_compiled
|
||||
from tinygrad.ops import Compiled, GraphBatchExecutor, ASTRunner
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
@@ -89,12 +89,8 @@ class CUDAGraph(GraphBatchExecutor):
|
||||
def exec_instance(self, instid): self.graphs[instid][0].launch()
|
||||
|
||||
class CUDAProgram:
|
||||
def __init__(self, name:str, prg:str, binary=False, shared = 0, local_size_override=None):
|
||||
if not binary:
|
||||
try: prg = cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']).decode('utf-8')
|
||||
except cuda.CompileError as e:
|
||||
if DEBUG >= 3: print("FAILED TO BUILD", prg)
|
||||
raise e
|
||||
def __init__(self, name:str, prg:str, binary=False, shared=0, local_size_override=None):
|
||||
if not binary: prg = self.compile(prg).decode('utf-8')
|
||||
if DEBUG >= 5: print(pretty_ptx(prg))
|
||||
if DEBUG >= 6:
|
||||
try:
|
||||
@@ -106,6 +102,10 @@ class CUDAProgram:
|
||||
# TODO: name is wrong, so we get it from the ptx using hacks
|
||||
self.prg, self.shared, self.local_size_override = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0]), shared, local_size_override
|
||||
|
||||
@cache_compiled
|
||||
def compile(self, prg) -> bytes:
|
||||
return cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets'])
|
||||
|
||||
def __call__(self, global_size, local_size, *args, wait=False):
|
||||
if wait:
|
||||
start, end = cuda.Event(), cuda.Event()
|
||||
|
||||
@@ -64,11 +64,7 @@ class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
|
||||
class CLProgram:
|
||||
def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None):
|
||||
self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) if binary else cl.Program(ctx, prg) for ctx in CL.cl_ctxs] # type: ignore
|
||||
try:
|
||||
self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms]
|
||||
except cl.RuntimeError as e:
|
||||
if DEBUG >= 3: print("FAILED TO BUILD", prg)
|
||||
raise e
|
||||
self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms]
|
||||
self.clprgs = [clprg.__getattr__(name) for clprg in self._clprgs]
|
||||
if DEBUG >= 5 and not OSX:
|
||||
if 'Adreno' in CL.cl_ctxs[0].devices[0].name:
|
||||
|
||||
@@ -94,13 +94,9 @@ class HIPProgram:
|
||||
|
||||
@cache_compiled
|
||||
def compile(self, prg, name) -> bytes:
|
||||
try:
|
||||
prog = hip.hiprtcCreateProgram(prg, name, [], [])
|
||||
hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}'])
|
||||
return hip.hiprtcGetCode(prog)
|
||||
except Exception as e:
|
||||
if DEBUG >= 3: print("FAILED TO BUILD", prg)
|
||||
raise e
|
||||
prog = hip.hiprtcCreateProgram(prg, name, [], [])
|
||||
hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}'])
|
||||
return hip.hiprtcGetCode(prog)
|
||||
|
||||
def __call__(self, global_size, local_size, *args, wait=False):
|
||||
hip.hipSetDevice(args[0]._device)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import time, hashlib, ctypes
|
||||
import time, ctypes
|
||||
from typing import ClassVar
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.helpers import getenv, DEBUG
|
||||
from tinygrad.helpers import getenv, DEBUG, cache_compiled
|
||||
from ctypes import CFUNCTYPE
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
from tinygrad.renderer.llvmir import uops_to_llvm_ir
|
||||
@@ -46,22 +46,22 @@ class LLVM:
|
||||
|
||||
class LLVMProgram:
|
||||
def __init__(self, name:str, prg:str, binary=False):
|
||||
self.mod = llvm.parse_assembly(prg)
|
||||
self.mod.verify()
|
||||
LLVM().optimizer.run(self.mod)
|
||||
self.mod.name = hashlib.sha1(prg.encode('utf-8')).hexdigest()
|
||||
if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(self.mod))
|
||||
LLVM.engine.add_module(self.mod)
|
||||
LLVM.engine.finalize_object()
|
||||
self.prg = prg if binary else self.compile(prg)
|
||||
LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(self.prg))
|
||||
self.fxn = LLVM.engine.get_function_address(name)
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, 'mod'): LLVM.engine.remove_module(self.mod)
|
||||
@cache_compiled
|
||||
def compile(self, prg) -> bytes:
|
||||
mod = llvm.parse_assembly(prg)
|
||||
mod.verify()
|
||||
LLVM().optimizer.run(mod)
|
||||
if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(mod))
|
||||
return LLVM.target_machine.emit_object(mod)
|
||||
|
||||
def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
|
||||
cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)
|
||||
if wait: st = time.monotonic()
|
||||
if wait: st = time.perf_counter()
|
||||
cfunc(*[x._buf if not isinstance(x, int) else x for x in bufs])
|
||||
if wait: return time.monotonic()-st
|
||||
if wait: return time.perf_counter()-st
|
||||
|
||||
LLVMBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, LLVMProgram)
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
# pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
|
||||
import os, subprocess, pathlib, ctypes
|
||||
import os, subprocess, pathlib, ctypes, tempfile
|
||||
import Metal, Cocoa, libdispatch # type: ignore
|
||||
from typing import List, Any, Tuple
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, cache_compiled
|
||||
from tinygrad.ops import Compiled, ASTRunner, BasicBatchExecutor
|
||||
from tinygrad.renderer.metal import MetalRenderer
|
||||
from tinygrad.runtime.lib import RawBufferMapped, LRUAllocator
|
||||
|
||||
METAL_XCODE = getenv("METAL_XCODE")
|
||||
|
||||
class MetalAllocator(LRUAllocator):
|
||||
def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
|
||||
def _do_free(self, buf): buf.release()
|
||||
@@ -62,27 +60,30 @@ def unwrap(x):
|
||||
|
||||
class MetalProgram:
|
||||
def __init__(self, name:str, prg:str, binary:bool=False):
|
||||
if METAL_XCODE:
|
||||
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
|
||||
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
|
||||
lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
|
||||
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
|
||||
self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None))
|
||||
else:
|
||||
options = Metal.MTLCompileOptions.alloc().init()
|
||||
self.library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
|
||||
lib = prg if binary else self.compile(prg)
|
||||
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
|
||||
self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None))
|
||||
self.fxn = self.library.newFunctionWithName_(name)
|
||||
# hacks to disassemble shader
|
||||
if DEBUG >= 5:
|
||||
arc = unwrap(METAL.device.newBinaryArchiveWithDescriptor_error_(Metal.MTLBinaryArchiveDescriptor.alloc().init(), None))
|
||||
desc = Metal.MTLComputePipelineDescriptor.alloc().init()
|
||||
desc.setComputeFunction_(self.fxn)
|
||||
unwrap(arc.addComputePipelineFunctionsWithDescriptor_error_(desc, None))
|
||||
unwrap(arc.serializeToURL_error_(Cocoa.NSURL.URLWithString_("file:///tmp/shader.bin"), None))
|
||||
# clone https://github.com/dougallj/applegpu.git in tinygrad/disassemblers
|
||||
os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py /tmp/shader.bin")
|
||||
with tempfile.NamedTemporaryFile(delete=True) as shader:
|
||||
shader.write(lib)
|
||||
shader.flush()
|
||||
os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
|
||||
self.pipeline_state = unwrap(METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None))
|
||||
|
||||
@cache_compiled
|
||||
def compile(self, prg) -> bytes:
|
||||
if getenv("METAL_XCODE"):
|
||||
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
|
||||
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
|
||||
return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
|
||||
options = Metal.MTLCompileOptions.alloc().init()
|
||||
library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
|
||||
# TODO: avoid file write here?
|
||||
with tempfile.NamedTemporaryFile(delete=True) as output_file:
|
||||
library.serializeToURL_error_(Cocoa.NSURL.URLWithString_(f"file://{output_file.name}"), None)
|
||||
return pathlib.Path(output_file.name).read_bytes()
|
||||
|
||||
def __call__(self, global_size, local_size, *bufs, wait=False):
|
||||
assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
|
||||
command_buffer = METAL.mtl_queue.commandBuffer()
|
||||
|
||||
Reference in New Issue
Block a user