From 893281681667f9b9c4a8e4b6b383ca9a443360b5 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 1 Nov 2023 18:44:00 -0700 Subject: [PATCH] remove arm64, caching for cuda (#2201) * remove arm64, caching for cuda * caching in llvm * switch cache_compiled to new cache * fix clang * caching for metal * fix pylint * cleanups * perf_counter and binary --- .pre-commit-config.yaml | 6 +++ extra/dump_cache.py | 2 +- tinygrad/helpers.py | 11 +++-- tinygrad/runtime/ops_clang.py | 83 ++++++++--------------------------- tinygrad/runtime/ops_cuda.py | 14 +++--- tinygrad/runtime/ops_gpu.py | 6 +-- tinygrad/runtime/ops_hip.py | 10 ++--- tinygrad/runtime/ops_llvm.py | 26 +++++------ tinygrad/runtime/ops_metal.py | 43 +++++++++--------- 9 files changed, 76 insertions(+), 125 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0449f31a89..1b7047675e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,3 +31,9 @@ repos: language: system always_run: true pass_filenames: false + - id: pylint + name: pylint + entry: python -m pylint tinygrad/ + language: system + always_run: true + pass_filenames: false diff --git a/extra/dump_cache.py b/extra/dump_cache.py index f45f7a0b15..ca762238f3 100644 --- a/extra/dump_cache.py +++ b/extra/dump_cache.py @@ -16,5 +16,5 @@ if __name__ == "__main__": cur3.execute(f"SELECT * FROM {table} LIMIT 10") for f in cur3.fetchall(): v = pickle.loads(f[-1]) - print(" ", len(f[0]), f[1:-1], v) + print(" ", len(f[0]) if isinstance(f[0], str) else f[0], f[1:-1], str(v)[0:50]) #print(f"{len(k):10d}, {sk} -> {v}") diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 5938ec919a..418acd1da8 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -1,5 +1,5 @@ from __future__ import annotations -import os, functools, platform, time, re, contextlib, operator, pathlib, hashlib, tempfile, pickle, sqlite3 +import os, functools, platform, time, re, contextlib, 
operator, hashlib, pickle, sqlite3 import numpy as np from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10 @@ -156,12 +156,11 @@ class GlobalCounters: # *** compiled cache decorator *** def cache_compiled(func): + if getenv("DISABLE_COMPILER_CACHE"): return func def wrapper(self, prg:str, *args, **kwargs) -> bytes: - cache_path, output_file = pathlib.Path(f"{tempfile.gettempdir()}/tinygrad_cc_{hashlib.sha256(prg.encode()).hexdigest()}"), pathlib.Path(tempfile.mktemp()) - if not cache_path.exists(): - output_file.write_bytes(func(self, prg, *args, **kwargs)) - output_file.rename(cache_path) - return cache_path.read_bytes() + table, key = f"compiler_cache_{type(self).__name__}", hashlib.sha256(prg.encode()).hexdigest() + if (ret:=diskcache_get(table, key)): return ret + return diskcache_put(table, key, func(self, prg, *args, **kwargs)) return wrapper # *** universal database cache *** diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index 3e82b2204a..843704e1eb 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -1,85 +1,38 @@ import time, ctypes, subprocess, platform, functools, pathlib, tempfile from typing import Any -from functools import partial, reduce from tinygrad.ops import Compiled -from tinygrad.helpers import fromimport, getenv, DEBUG, CI, cache_compiled +from tinygrad.helpers import cache_compiled from tinygrad.runtime.lib import RawMallocBuffer from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage -import struct -import numpy as np - -ARM64 = getenv('ARM64', False) -if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const # type: ignore args = { - 'Windows': {'cflags':'', 'ext':'dll', 
'exp':'__declspec(dllexport) '}, - 'Linux': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'so', 'exp':''}, - 'Darwin': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'dylib', 'exp':''} + 'Linux': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'so'}, + 'Darwin': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'dylib'} }[platform.system()] CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include <stdbool.h>\n' -ADDRESS = 0x10000 - -# Unicorn doesn't support external calls -def align(addr): return (addr+4095) & ~(4095) -mock_lm = {"sinf": np.sin, "sqrtf": np.sqrt, "exp2f": np.exp2, "log2f": np.log2} -def emulate_ext_calls(fn, uc, address, size, user_data): - s_in = struct.unpack('f', struct.pack('I', uc.reg_read(getattr(arm64_const, f'UC_ARM64_REG_S{fn[2][1:]}'))))[0] - uc.reg_write(getattr(arm64_const, f'UC_ARM64_REG_S{fn[1][1:]}'), struct.unpack('I', struct.pack('f', mock_lm[fn[0]](s_in)))[0]) # type: ignore class ClangProgram: - def __init__(self, name:str, prg:str, binary:bool=False): - if binary and DEBUG >= 5: print(prg) - self.prg: Any = self.compile(prg if binary else CLANG_PROGRAM_HEADER+prg, binary) + def __init__(self, name:str, prg:str, binary=False): + self.prg: bytes = prg if binary else self.compile(CLANG_PROGRAM_HEADER+prg) - # TODO: is there a way to not write this to disk? 
- # A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file - # because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file - if not (CI and ARM64): - with tempfile.NamedTemporaryFile(delete=True) as cached_file_path: - pathlib.Path(cached_file_path.name).write_bytes(self.prg) - self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name] + # write to disk so we can load it + with tempfile.NamedTemporaryFile(delete=True) as cached_file_path: + pathlib.Path(cached_file_path.name).write_bytes(self.prg) + self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name] @cache_compiled - def compile(self, prg, binary) -> bytes: - with tempfile.NamedTemporaryFile(delete=True) as output_file, tempfile.NamedTemporaryFile(delete=True) as temp_file: - if not binary: - subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file.name)).split(), input=prg.encode('utf-8')) - elif CI and ARM64: - prg = prg.split('\n') # type: ignore - self.varsize = align(int(prg[0].split(" ")[1])) - self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'} - prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n']) - subprocess.check_output(args=('aarch64-linux-gnu-as -o '+str(temp_file.name)).split(), input=prg.encode('utf-8')) - subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+str(temp_file.name)+' '+str(output_file.name)).split()) - else: - subprocess.check_output(args=('as -o' + str(temp_file.name)).split(), input=prg.encode('utf-8')) - subprocess.check_output(args=('clang -lm -shared '+str(temp_file.name)+' -o'+str(output_file.name)).split()) + def compile(self, prg) -> bytes: + # TODO: sadly clang doesn't like the use of /dev/stdout here + with tempfile.NamedTemporaryFile(delete=True) as 
output_file: + subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file.name)).split(), input=prg.encode('utf-8')) return pathlib.Path(output_file.name).read_bytes() - def __call__(self, global_size, local_size, *args, wait=False): - if wait: st = time.monotonic() - if CI and ARM64: - mu = Uc(UC_ARCH_ARM64, UC_MODE_ARM) - total_mem = align(reduce(lambda total, arg: total + arg.size * arg.dtype.itemsize, args, len(self.prg)+self.varsize)) - mu.mem_map(ADDRESS, total_mem) - for k, fn in self.ext_calls.items(): mu.hook_add(UC_HOOK_CODE, partial(emulate_ext_calls, fn), begin=k, end=k) - mu.mem_write(ADDRESS, self.prg + b''.join(bytes(arg._buf) for arg in args)) - addr = ADDRESS + len(self.prg) - for i, arg in enumerate(args): - if i<=7: - mu.reg_write(getattr(arm64_const, f'UC_ARM64_REG_X{i}'), addr) - else: - # NOTE: In ARM, args beyond the first 8 are placed on the stack it also account for the stack red zone. - mu.mem_write(ADDRESS + total_mem - (len(args[8:])+2)*8 + 8*(i-8), addr.to_bytes(8, 'little')) - addr += arg.size * arg.dtype.itemsize - mu.reg_write(arm64_const.UC_ARM64_REG_SP, ADDRESS + total_mem - (len(args[8:])+2)*8) - mu.emu_start(ADDRESS, ADDRESS + len(self.prg)) - args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize) - else: - self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args]) - if wait: return time.monotonic()-st + def __call__(self, unused_global_size, unused_local_size, *args, wait=False): + if wait: st = time.perf_counter() + self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args]) + if wait: return time.perf_counter()-st -renderer = fromimport("tinygrad.renderer.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict", arg_int_prefix="const int")) +renderer = functools.partial(uops_to_cstyle, 
CStyleLanguage(buffer_suffix=" restrict", arg_int_prefix="const int")) ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram) diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index 71dc74c0a0..a987cf29fe 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import Optional, List, Any, Tuple import numpy as np from pycuda.compiler import compile as cuda_compile # type: ignore -from tinygrad.helpers import DEBUG, getenv, colored +from tinygrad.helpers import DEBUG, getenv, colored, cache_compiled from tinygrad.ops import Compiled, GraphBatchExecutor, ASTRunner from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator from tinygrad.codegen.kernel import LinearizerOptions @@ -89,12 +89,8 @@ class CUDAGraph(GraphBatchExecutor): def exec_instance(self, instid): self.graphs[instid][0].launch() class CUDAProgram: - def __init__(self, name:str, prg:str, binary=False, shared = 0, local_size_override=None): - if not binary: - try: prg = cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']).decode('utf-8') - except cuda.CompileError as e: - if DEBUG >= 3: print("FAILED TO BUILD", prg) - raise e + def __init__(self, name:str, prg:str, binary=False, shared=0, local_size_override=None): + if not binary: prg = self.compile(prg).decode('utf-8') if DEBUG >= 5: print(pretty_ptx(prg)) if DEBUG >= 6: try: @@ -106,6 +102,10 @@ class CUDAProgram: # TODO: name is wrong, so we get it from the ptx using hacks self.prg, self.shared, self.local_size_override = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0]), shared, local_size_override + @cache_compiled + def compile(self, prg) -> bytes: + return cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']) + def __call__(self, 
global_size, local_size, *args, wait=False): if wait: start, end = cuda.Event(), cuda.Event() diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index 8a9057b976..9c2715a8b0 100644 --- a/tinygrad/runtime/ops_gpu.py +++ b/tinygrad/runtime/ops_gpu.py @@ -64,11 +64,7 @@ class CLBuffer(RawBufferCopyInOut, RawBufferTransfer): class CLProgram: def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None): self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) if binary else cl.Program(ctx, prg) for ctx in CL.cl_ctxs] # type: ignore - try: - self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms] - except cl.RuntimeError as e: - if DEBUG >= 3: print("FAILED TO BUILD", prg) - raise e + self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms] self.clprgs = [clprg.__getattr__(name) for clprg in self._clprgs] if DEBUG >= 5 and not OSX: if 'Adreno' in CL.cl_ctxs[0].devices[0].name: diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py index 2b3cc832f8..e1b9491d02 100644 --- a/tinygrad/runtime/ops_hip.py +++ b/tinygrad/runtime/ops_hip.py @@ -94,13 +94,9 @@ class HIPProgram: @cache_compiled def compile(self, prg, name) -> bytes: - try: - prog = hip.hiprtcCreateProgram(prg, name, [], []) - hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}']) - return hip.hiprtcGetCode(prog) - except Exception as e: - if DEBUG >= 3: print("FAILED TO BUILD", prg) - raise e + prog = hip.hiprtcCreateProgram(prg, name, [], []) + hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}']) + return hip.hiprtcGetCode(prog) def __call__(self, global_size, local_size, *args, wait=False): hip.hipSetDevice(args[0]._device) diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index b6278ea0f5..a8c935efe3 100644 --- 
a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -1,7 +1,7 @@ -import time, hashlib, ctypes +import time, ctypes from typing import ClassVar from tinygrad.ops import Compiled -from tinygrad.helpers import getenv, DEBUG +from tinygrad.helpers import getenv, DEBUG, cache_compiled from ctypes import CFUNCTYPE from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.llvmir import uops_to_llvm_ir @@ -46,22 +46,22 @@ class LLVM: class LLVMProgram: def __init__(self, name:str, prg:str, binary=False): - self.mod = llvm.parse_assembly(prg) - self.mod.verify() - LLVM().optimizer.run(self.mod) - self.mod.name = hashlib.sha1(prg.encode('utf-8')).hexdigest() - if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(self.mod)) - LLVM.engine.add_module(self.mod) - LLVM.engine.finalize_object() + self.prg = prg if binary else self.compile(prg) + LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(self.prg)) self.fxn = LLVM.engine.get_function_address(name) - def __del__(self): - if hasattr(self, 'mod'): LLVM.engine.remove_module(self.mod) + @cache_compiled + def compile(self, prg) -> bytes: + mod = llvm.parse_assembly(prg) + mod.verify() + LLVM().optimizer.run(mod) + if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(mod)) + return LLVM.target_machine.emit_object(mod) def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False): cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn) - if wait: st = time.monotonic() + if wait: st = time.perf_counter() cfunc(*[x._buf if not isinstance(x, int) else x for x in bufs]) - if wait: return time.monotonic()-st + if wait: return time.perf_counter()-st LLVMBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, LLVMProgram) diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index 79fbac937d..43320c7e88 100644 --- a/tinygrad/runtime/ops_metal.py +++ 
b/tinygrad/runtime/ops_metal.py @@ -1,15 +1,13 @@ # pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch -import os, subprocess, pathlib, ctypes +import os, subprocess, pathlib, ctypes, tempfile import Metal, Cocoa, libdispatch # type: ignore from typing import List, Any, Tuple from tinygrad.codegen.kernel import LinearizerOptions -from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes +from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, cache_compiled from tinygrad.ops import Compiled, ASTRunner, BasicBatchExecutor from tinygrad.renderer.metal import MetalRenderer from tinygrad.runtime.lib import RawBufferMapped, LRUAllocator -METAL_XCODE = getenv("METAL_XCODE") - class MetalAllocator(LRUAllocator): def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared) def _do_free(self, buf): buf.release() @@ -62,27 +60,30 @@ def unwrap(x): class MetalProgram: def __init__(self, name:str, prg:str, binary:bool=False): - if METAL_XCODE: - air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8')) - # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode - lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air) - data = libdispatch.dispatch_data_create(lib, len(lib), None, None) - self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None)) - else: - options = Metal.MTLCompileOptions.alloc().init() - self.library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None)) + lib = prg if binary else self.compile(prg) + data = libdispatch.dispatch_data_create(lib, len(lib), None, None) + self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None)) self.fxn = self.library.newFunctionWithName_(name) - # hacks to disassemble shader if DEBUG >= 5: - arc = 
unwrap(METAL.device.newBinaryArchiveWithDescriptor_error_(Metal.MTLBinaryArchiveDescriptor.alloc().init(), None)) - desc = Metal.MTLComputePipelineDescriptor.alloc().init() - desc.setComputeFunction_(self.fxn) - unwrap(arc.addComputePipelineFunctionsWithDescriptor_error_(desc, None)) - unwrap(arc.serializeToURL_error_(Cocoa.NSURL.URLWithString_("file:///tmp/shader.bin"), None)) - # clone https://github.com/dougallj/applegpu.git in tinygrad/disassemblers - os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py /tmp/shader.bin") + with tempfile.NamedTemporaryFile(delete=True) as shader: + shader.write(lib) + shader.flush() + os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}") self.pipeline_state = unwrap(METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None)) + @cache_compiled + def compile(self, prg) -> bytes: + if getenv("METAL_XCODE"): + # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode + air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8')) + return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air) + options = Metal.MTLCompileOptions.alloc().init() + library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None)) + # TODO: avoid file write here? 
+ with tempfile.NamedTemporaryFile(delete=True) as output_file: + library.serializeToURL_error_(Cocoa.NSURL.URLWithString_(f"file://{output_file.name}"), None) + return pathlib.Path(output_file.name).read_bytes() + def __call__(self, global_size, local_size, *bufs, wait=False): assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}" command_buffer = METAL.mtl_queue.commandBuffer()