kernel cache (#2035)

* init compiled cache

* clang not compile to stdout

* use kwrags in compile

* remove some useless lines

* slimmer

* fix

* tabs

* retry

* remove decorator

* no race in hip

* smaller hip

* unused import

* unused pathlib

* path to str

* add test

* fix linter

* less lines?

* decorator is back

* update tests

* no hip version

* better comments

* a bit better test

* linter

* work wo decorator

* linter happy

* simpler return type

* more tests

* better comment

* readable

* readable

* readable

* compile returns bytes

* no ununsed imports

* readable
This commit is contained in:
nimlgen
2023-10-13 16:32:01 +03:00
committed by GitHub
parent 6f1810af2d
commit bd42fa0b73
4 changed files with 109 additions and 36 deletions

56
test/test_kernel_cache.py Normal file
View File

@@ -0,0 +1,56 @@
#!/usr/bin/env python
import unittest
import secrets
import string
import tempfile
import pathlib
from tinygrad.tensor import Tensor
from tinygrad.ops import Device
from tinygrad.helpers import cache_compiled
import tinygrad.runtime.ops_clang
def generate_random_string(length=16):
alphabet = string.ascii_letters + string.digits
return ''.join(secrets.choice(alphabet) for _ in range(length))
class TestKernelCache(unittest.TestCase):
compile_call_count = 0
@cache_compiled
def __helper_test_compile(self, prg, output_file=pathlib.Path(tempfile.mktemp()), **kwargs):
self.compile_call_count += 1
return prg.encode()
def test_compile_cache(self):
prg1 = generate_random_string(64) + "a"
prg2 = generate_random_string(64) + "b"
cold_compile_res = self.__helper_test_compile(prg1)
warm_compile_res = self.__helper_test_compile(prg1)
assert cold_compile_res == warm_compile_res == prg1.encode()
assert self.compile_call_count == 1
prg2_res = self.__helper_test_compile(prg2)
assert prg2_res == prg2.encode()
assert self.compile_call_count == 2
def test_kernel_cache_in_action(self):
if Device.DEFAULT not in ["CLANG"]:
self.skipTest("No custom kernel cache is implemented")
a = Tensor.rand(4,4)
b = Tensor.rand(4,4)
x = a + b
x.realize()
orig_compile_func = tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile
tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile = None # making it not callable
a1 = Tensor.rand(4,4)
b1 = Tensor.rand(4,4)
x1 = a1 + b1
x1.realize() # Same kernel should be from cache.
tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile = orig_compile_func
if __name__ == "__main__":
unittest.main()

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator
import os, functools, platform, time, re, contextlib, operator, pathlib, hashlib, tempfile
import numpy as np
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
@@ -148,3 +148,14 @@ class GlobalCounters:
mem_cached: ClassVar[int] = 0 # NOTE: this is not reset
@staticmethod
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
# *** compiled cache decorator ***
def cache_compiled(func):
def wrapper(self, prg:str, *args, **kwargs) -> bytes:
cache_path, output_file = pathlib.Path(f"{tempfile.gettempdir()}/tinygrad_cc_{hashlib.sha256(prg.encode()).hexdigest()}"), pathlib.Path(tempfile.mktemp())
if not cache_path.exists():
output_file.write_bytes(func(self, prg, *args, **kwargs))
output_file.rename(cache_path)
return cache_path.read_bytes()
return wrapper

View File

@@ -1,7 +1,8 @@
import os, time, ctypes, hashlib, subprocess, platform, tempfile, functools
import time, ctypes, subprocess, platform, functools, pathlib, tempfile
from typing import Any
from functools import partial, reduce
from tinygrad.ops import Compiled
from tinygrad.helpers import fromimport, getenv, DEBUG, CI
from tinygrad.helpers import fromimport, getenv, DEBUG, CI, cache_compiled
from tinygrad.runtime.lib import RawMallocBuffer
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
@@ -29,32 +30,34 @@ def emulate_ext_calls(fn, uc, address, size, user_data):
class ClangProgram:
def __init__(self, name:str, prg:str, binary:bool=False):
if binary and DEBUG >= 5: print(prg)
self.prg: Any = self.compile(prg if binary else CLANG_PROGRAM_HEADER+prg, binary)
# TODO: is there a way to not write this to disk?
# A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file
# because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file
fn = f"{tempfile.gettempdir()}/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{args['ext']}"
if binary and DEBUG >= 5: print(prg)
if not os.path.exists(fn):
tmp = f"{fn}.{os.getpid()}.tmp"
if not binary:
prg = CLANG_PROGRAM_HEADER + prg
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
os.rename(tmp, fn)
else:
if CI and ARM64:
prg = prg.split('\n') # type: ignore
self.varsize = align(int(prg[0].split(" ")[1]))
self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'}
prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n'])
subprocess.check_output(args=('aarch64-linux-gnu-as -o '+tmp).split(), input=prg.encode('utf-8'))
subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+tmp+' '+fn+'.bin').split())
with open(fn + '.bin', 'rb') as f:
self.prg = f.read()
return
subprocess.check_output(args=('as -o' + tmp).split(), input=prg.encode('utf-8'))
subprocess.check_output(args=('clang -lm -shared '+tmp+' -o'+fn).split())
self.lib = ctypes.CDLL(fn)
self.fxn = self.lib[name]
if not (CI and ARM64):
cached_file_path = pathlib.Path(tempfile.mktemp())
cached_file_path.write_bytes(self.prg)
self.fxn: Any = ctypes.CDLL(str(cached_file_path))[name]
@cache_compiled
def compile(self, prg, binary) -> bytes:
output_file, temp_file = pathlib.Path(tempfile.mktemp()), pathlib.Path(tempfile.mktemp())
if not binary:
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file)).split(), input=prg.encode('utf-8'))
elif CI and ARM64:
prg = prg.split('\n') # type: ignore
self.varsize = align(int(prg[0].split(" ")[1]))
self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'}
prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n'])
subprocess.check_output(args=('aarch64-linux-gnu-as -o '+str(temp_file)).split(), input=prg.encode('utf-8'))
subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+str(temp_file)+' '+str(output_file)).split())
else:
subprocess.check_output(args=('as -o' + str(temp_file)).split(), input=prg.encode('utf-8'))
subprocess.check_output(args=('clang -lm -shared '+str(temp_file)+' -o'+str(output_file)).split())
return output_file.read_bytes()
def __call__(self, global_size, local_size, *args, wait=False):
if wait: st = time.monotonic()
if CI and ARM64:

View File

@@ -2,7 +2,7 @@ import numpy as np
import ctypes, functools, math, collections
import extra.hip_wrapper as hip
from typing import Tuple, Any, List
from tinygrad.helpers import DEBUG, getenv
from tinygrad.helpers import DEBUG, getenv, cache_compiled
from tinygrad.ops import Compiled, ASTRunner, BasicBatchExecutor
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
from tinygrad.codegen.kernel import LinearizerOptions
@@ -88,15 +88,8 @@ class RawHIPBuffer(RawBufferCopyInOut, RawBufferTransfer):
class HIPProgram:
def __init__(self, name:str, prg:str, binary=False):
try:
if not binary:
prog = hip.hiprtcCreateProgram(prg, name, [], [])
device_properties = hip.hipGetDeviceProperties(HIP.default_device)
hip.hiprtcCompileProgram(prog, [f'--offload-arch={device_properties.gcnArchName}'])
prg = hip.hiprtcGetCode(prog)
except Exception as e:
if DEBUG >= 3: print("FAILED TO BUILD", prg)
raise e
prg = prg if binary else self.compile(prg, name)
if DEBUG >= 6:
asm = early_exec((["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], prg))
print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
@@ -106,6 +99,16 @@ class HIPProgram:
hip.hipSetDevice(i)
self.prgs.append(hip.hipModuleGetFunction(hip.hipModuleLoadData(prg), name))
@cache_compiled
def compile(self, prg, name) -> bytes:
try:
prog = hip.hiprtcCreateProgram(prg, name, [], [])
hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}'])
return hip.hiprtcGetCode(prog)
except Exception as e:
if DEBUG >= 3: print("FAILED TO BUILD", prg)
raise e
def __call__(self, global_size, local_size, *args, wait=False):
hip.hipSetDevice(args[0]._device)
if wait: