diff --git a/test/test_kernel_cache.py b/test/test_kernel_cache.py new file mode 100644 index 0000000000..83c87d1dfe --- /dev/null +++ b/test/test_kernel_cache.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +import unittest +import secrets +import string +import tempfile +import pathlib +from tinygrad.tensor import Tensor +from tinygrad.ops import Device +from tinygrad.helpers import cache_compiled +import tinygrad.runtime.ops_clang + +def generate_random_string(length=16): + alphabet = string.ascii_letters + string.digits + return ''.join(secrets.choice(alphabet) for _ in range(length)) + +class TestKernelCache(unittest.TestCase): + compile_call_count = 0 + + @cache_compiled + def __helper_test_compile(self, prg, output_file=pathlib.Path(tempfile.mktemp()), **kwargs): + self.compile_call_count += 1 + return prg.encode() + + def test_compile_cache(self): + prg1 = generate_random_string(64) + "a" + prg2 = generate_random_string(64) + "b" + cold_compile_res = self.__helper_test_compile(prg1) + warm_compile_res = self.__helper_test_compile(prg1) + assert cold_compile_res == warm_compile_res == prg1.encode() + assert self.compile_call_count == 1 + + prg2_res = self.__helper_test_compile(prg2) + assert prg2_res == prg2.encode() + assert self.compile_call_count == 2 + + def test_kernel_cache_in_action(self): + if Device.DEFAULT not in ["CLANG"]: + self.skipTest("No custom kernel cache is implemented") + + a = Tensor.rand(4,4) + b = Tensor.rand(4,4) + x = a + b + x.realize() + + orig_compile_func = tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile + tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile = None # making it not callable + + a1 = Tensor.rand(4,4) + b1 = Tensor.rand(4,4) + x1 = a1 + b1 + x1.realize() # Same kernel should be from cache. + + tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile = orig_compile_func + +if __name__ == "__main__": + unittest.main() diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index a77447ee97..f24dc938e8 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -1,5 +1,5 @@ from __future__ import annotations -import os, functools, platform, time, re, contextlib, operator +import os, functools, platform, time, re, contextlib, operator, pathlib, hashlib, tempfile import numpy as np from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10 @@ -148,3 +148,14 @@ class GlobalCounters: mem_cached: ClassVar[int] = 0 # NOTE: this is not reset @staticmethod def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0 + +# *** compiled cache decorator *** + +def cache_compiled(func): + def wrapper(self, prg:str, *args, **kwargs) -> bytes: + cache_path, output_file = pathlib.Path(f"{tempfile.gettempdir()}/tinygrad_cc_{hashlib.sha256(prg.encode()).hexdigest()}"), pathlib.Path(tempfile.mktemp()) + if not cache_path.exists(): + output_file.write_bytes(func(self, prg, *args, **kwargs)) + output_file.rename(cache_path) + return cache_path.read_bytes() + return wrapper \ No newline at end of file diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index 2959f75634..d3cbda7999 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -1,7 +1,8 @@ -import os, time, ctypes, hashlib, subprocess, platform, tempfile, functools +import time, ctypes, subprocess, platform, functools, pathlib, tempfile +from typing import Any from functools import partial, reduce from tinygrad.ops import Compiled -from tinygrad.helpers import fromimport, getenv, DEBUG, CI +from tinygrad.helpers import fromimport, getenv, DEBUG, CI, cache_compiled from tinygrad.runtime.lib import RawMallocBuffer from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage @@ -29,32 +30,34 @@ def emulate_ext_calls(fn, uc, address, size, user_data): class ClangProgram: def __init__(self, name:str, prg:str, binary:bool=False): + if binary and DEBUG >= 5: print(prg) + self.prg: Any = self.compile(prg if binary else CLANG_PROGRAM_HEADER+prg, binary) + # TODO: is there a way to not write this to disk? # A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file # because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file - fn = f"{tempfile.gettempdir()}/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{args['ext']}" - if binary and DEBUG >= 5: print(prg) - if not os.path.exists(fn): - tmp = f"{fn}.{os.getpid()}.tmp" - if not binary: - prg = CLANG_PROGRAM_HEADER + prg - subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8')) - os.rename(tmp, fn) - else: - if CI and ARM64: - prg = prg.split('\n') # type: ignore - self.varsize = align(int(prg[0].split(" ")[1])) - self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'} - prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n']) - subprocess.check_output(args=('aarch64-linux-gnu-as -o '+tmp).split(), input=prg.encode('utf-8')) - subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+tmp+' '+fn+'.bin').split()) - with open(fn + '.bin', 'rb') as f: - self.prg = f.read() - return - subprocess.check_output(args=('as -o' + tmp).split(), input=prg.encode('utf-8')) - subprocess.check_output(args=('clang -lm -shared '+tmp+' -o'+fn).split()) - self.lib = ctypes.CDLL(fn) - self.fxn = self.lib[name] + if not (CI and ARM64): + cached_file_path = pathlib.Path(tempfile.mktemp()) + cached_file_path.write_bytes(self.prg) + self.fxn: Any = ctypes.CDLL(str(cached_file_path))[name] + + @cache_compiled + def compile(self, prg, binary) -> bytes: + output_file, temp_file = pathlib.Path(tempfile.mktemp()), pathlib.Path(tempfile.mktemp()) + if not binary: + subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file)).split(), input=prg.encode('utf-8')) + elif CI and ARM64: + prg = prg.split('\n') # type: ignore + self.varsize = align(int(prg[0].split(" ")[1])) + self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'} + prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n']) + subprocess.check_output(args=('aarch64-linux-gnu-as -o '+str(temp_file)).split(), input=prg.encode('utf-8')) + subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+str(temp_file)+' '+str(output_file)).split()) + else: + subprocess.check_output(args=('as -o' + str(temp_file)).split(), input=prg.encode('utf-8')) + subprocess.check_output(args=('clang -lm -shared '+str(temp_file)+' -o'+str(output_file)).split()) + return output_file.read_bytes() + def __call__(self, global_size, local_size, *args, wait=False): if wait: st = time.monotonic() if CI and ARM64: diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py index a96e46a3bb..a5b68712e5 100644 --- a/tinygrad/runtime/ops_hip.py +++ b/tinygrad/runtime/ops_hip.py @@ -2,7 +2,7 @@ import numpy as np import ctypes, functools, math, collections import extra.hip_wrapper as hip from typing import Tuple, Any, List -from tinygrad.helpers import DEBUG, getenv +from tinygrad.helpers import DEBUG, getenv, cache_compiled from tinygrad.ops import Compiled, ASTRunner, BasicBatchExecutor from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer from tinygrad.codegen.kernel import LinearizerOptions @@ -88,15 +88,8 @@ class RawHIPBuffer(RawBufferCopyInOut, RawBufferTransfer): class HIPProgram: def __init__(self, name:str, prg:str, binary=False): - try: - if not binary: - prog = hip.hiprtcCreateProgram(prg, name, [], []) - device_properties = hip.hipGetDeviceProperties(HIP.default_device) - hip.hiprtcCompileProgram(prog, [f'--offload-arch={device_properties.gcnArchName}']) - prg = hip.hiprtcGetCode(prog) - except Exception as e: - if DEBUG >= 3: print("FAILED TO BUILD", prg) - raise e + prg = prg if binary else self.compile(prg, name) + if DEBUG >= 6: asm = early_exec((["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], prg)) print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x])) @@ -106,6 +99,16 @@ class HIPProgram: hip.hipSetDevice(i) self.prgs.append(hip.hipModuleGetFunction(hip.hipModuleLoadData(prg), name)) + @cache_compiled + def compile(self, prg, name) -> bytes: + try: + prog = hip.hiprtcCreateProgram(prg, name, [], []) + hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}']) + return hip.hiprtcGetCode(prog) + except Exception as e: + if DEBUG >= 3: print("FAILED TO BUILD", prg) + raise e + def __call__(self, global_size, local_size, *args, wait=False): hip.hipSetDevice(args[0]._device) if wait: