mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
kernel cache (#2035)
* init compiled cache * clang not compile to stdout * use kwargs in compile * remove some useless lines * slimmer * fix * tabs * retry * remove decorator * no race in hip * smaller hip * unused import * unused pathlib * path to str * add test * fix linter * less lines? * decorator is back * update tests * no hip version * better comments * a bit better test * linter * work wo decorator * linter happy * simpler return type * more tests * better comment * readable * readable * readable * compile returns bytes * no unused imports * readable
This commit is contained in:
56
test/test_kernel_cache.py
Normal file
56
test/test_kernel_cache.py
Normal file
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import secrets
|
||||
import string
|
||||
import tempfile
|
||||
import pathlib
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import Device
|
||||
from tinygrad.helpers import cache_compiled
|
||||
import tinygrad.runtime.ops_clang
|
||||
|
||||
def generate_random_string(length=16):
  """Return a random string of `length` ASCII letters and digits.

  Characters are drawn with `secrets.choice`, so each test run gets a
  program text that is very unlikely to already exist in the kernel cache.
  """
  charset = string.ascii_letters + string.digits
  picked = [secrets.choice(charset) for _ in range(length)]
  return ''.join(picked)
|
||||
|
||||
class TestKernelCache(unittest.TestCase):
  """Tests for the on-disk compiled-kernel cache provided by `cache_compiled`."""

  # Number of times the undecorated compile helper actually ran. The first
  # `+=` through `self` shadows this class attribute with a per-instance value,
  # so every test method starts counting from 0.
  compile_call_count = 0

  @cache_compiled
  def __helper_test_compile(self, prg, output_file=None, **kwargs):
    """Stand-in compiler: record the call and return `prg` encoded as bytes.

    `output_file` and `**kwargs` are accepted only for signature
    compatibility and are not used. The default is `None` rather than a
    `tempfile.mktemp()` path so nothing is evaluated at class-definition time.
    """
    self.compile_call_count += 1
    return prg.encode()

  def test_compile_cache(self):
    """A repeated program is served from the cache; a new program is compiled."""
    prg1 = generate_random_string(64) + "a"
    prg2 = generate_random_string(64) + "b"

    cold_compile_res = self.__helper_test_compile(prg1)
    warm_compile_res = self.__helper_test_compile(prg1)
    self.assertEqual(cold_compile_res, prg1.encode())
    self.assertEqual(warm_compile_res, prg1.encode())
    # The second call must not have reached the underlying helper.
    self.assertEqual(self.compile_call_count, 1)

    prg2_res = self.__helper_test_compile(prg2)
    self.assertEqual(prg2_res, prg2.encode())
    self.assertEqual(self.compile_call_count, 2)

  def test_kernel_cache_in_action(self):
    """Realizing the same kernel twice must not invoke the compiler again."""
    if Device.DEFAULT not in ["CLANG"]:
      self.skipTest("No custom kernel cache is implemented")

    a = Tensor.rand(4,4)
    b = Tensor.rand(4,4)
    x = a + b
    x.realize()

    orig_compile_func = tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile
    tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile = None # making it not callable
    try:
      a1 = Tensor.rand(4,4)
      b1 = Tensor.rand(4,4)
      x1 = a1 + b1
      x1.realize() # Same kernel should be from cache.
    finally:
      # Always restore the real compiler, even when the realize above fails,
      # so a failure here cannot break every test that runs afterwards.
      tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile = orig_compile_func
|
||||
|
||||
# Allow running this test file directly as a script.
if __name__ == "__main__":
  unittest.main()
|
||||
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import os, functools, platform, time, re, contextlib, operator
|
||||
import os, functools, platform, time, re, contextlib, operator, pathlib, hashlib, tempfile
|
||||
import numpy as np
|
||||
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING
|
||||
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
|
||||
@@ -148,3 +148,14 @@ class GlobalCounters:
|
||||
mem_cached: ClassVar[int] = 0 # NOTE: this is not reset
|
||||
@staticmethod
|
||||
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
|
||||
|
||||
# *** compiled cache decorator ***
|
||||
|
||||
def cache_compiled(func):
  """Decorator that caches the bytes returned by a `compile(self, prg, ...)` method on disk.

  The cache entry lives in the system temp directory and is keyed only by the
  SHA-256 of `prg`; any other positional or keyword arguments are forwarded to
  `func` but do not take part in the key.

  Returns a wrapper with the same name/docstring as `func` that returns the
  cached (or freshly compiled) binary as `bytes`.
  """
  @functools.wraps(func)
  def wrapper(self, prg:str, *args, **kwargs) -> bytes:
    cache_path = pathlib.Path(f"{tempfile.gettempdir()}/tinygrad_cc_{hashlib.sha256(prg.encode()).hexdigest()}")
    if not cache_path.exists():
      binary = func(self, prg, *args, **kwargs)
      # Write to a uniquely named file created in the cache directory (never a
      # guessable mktemp() name), then move it into place in one step so that
      # concurrent readers never observe a partially written entry.
      with tempfile.NamedTemporaryFile(delete=False, dir=cache_path.parent) as tmp:
        tmp.write(binary)
      # replace() overwrites an entry published by a concurrent process
      # instead of failing where rename() would.
      pathlib.Path(tmp.name).replace(cache_path)
    return cache_path.read_bytes()
  return wrapper
|
||||
@@ -1,7 +1,8 @@
|
||||
import os, time, ctypes, hashlib, subprocess, platform, tempfile, functools
|
||||
import time, ctypes, subprocess, platform, functools, pathlib, tempfile
|
||||
from typing import Any
|
||||
from functools import partial, reduce
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.helpers import fromimport, getenv, DEBUG, CI
|
||||
from tinygrad.helpers import fromimport, getenv, DEBUG, CI, cache_compiled
|
||||
from tinygrad.runtime.lib import RawMallocBuffer
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
|
||||
@@ -29,32 +30,34 @@ def emulate_ext_calls(fn, uc, address, size, user_data):
|
||||
|
||||
class ClangProgram:
|
||||
def __init__(self, name:str, prg:str, binary:bool=False):
|
||||
if binary and DEBUG >= 5: print(prg)
|
||||
self.prg: Any = self.compile(prg if binary else CLANG_PROGRAM_HEADER+prg, binary)
|
||||
|
||||
# TODO: is there a way to not write this to disk?
|
||||
# A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file
|
||||
# because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file
|
||||
fn = f"{tempfile.gettempdir()}/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{args['ext']}"
|
||||
if binary and DEBUG >= 5: print(prg)
|
||||
if not os.path.exists(fn):
|
||||
tmp = f"{fn}.{os.getpid()}.tmp"
|
||||
if not binary:
|
||||
prg = CLANG_PROGRAM_HEADER + prg
|
||||
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
|
||||
os.rename(tmp, fn)
|
||||
else:
|
||||
if CI and ARM64:
|
||||
prg = prg.split('\n') # type: ignore
|
||||
self.varsize = align(int(prg[0].split(" ")[1]))
|
||||
self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'}
|
||||
prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n'])
|
||||
subprocess.check_output(args=('aarch64-linux-gnu-as -o '+tmp).split(), input=prg.encode('utf-8'))
|
||||
subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+tmp+' '+fn+'.bin').split())
|
||||
with open(fn + '.bin', 'rb') as f:
|
||||
self.prg = f.read()
|
||||
return
|
||||
subprocess.check_output(args=('as -o' + tmp).split(), input=prg.encode('utf-8'))
|
||||
subprocess.check_output(args=('clang -lm -shared '+tmp+' -o'+fn).split())
|
||||
self.lib = ctypes.CDLL(fn)
|
||||
self.fxn = self.lib[name]
|
||||
if not (CI and ARM64):
|
||||
cached_file_path = pathlib.Path(tempfile.mktemp())
|
||||
cached_file_path.write_bytes(self.prg)
|
||||
self.fxn: Any = ctypes.CDLL(str(cached_file_path))[name]
|
||||
|
||||
@cache_compiled
def compile(self, prg, binary) -> bytes:
  """Build `prg` and return the produced binary as bytes.

  Three paths, chosen by `binary` and the CI/ARM64 flags:
    * `binary` falsy: `prg` is C source, piped to clang to build a shared library.
    * `binary` truthy with CI and ARM64: `prg` is assembly text; it is
      assembled with aarch64-linux-gnu-as and the raw `.text` section is
      extracted with objcopy.
    * otherwise: `prg` is assembly, assembled with `as` and linked into a
      shared library with clang.

  Raises `subprocess.CalledProcessError` if any tool exits non-zero.
  """
  # All scratch files live in a private directory that is removed on exit, so
  # nothing is left behind in the temp dir and no guessable mktemp() name is used.
  with tempfile.TemporaryDirectory() as build_dir:
    output_file, temp_file = pathlib.Path(build_dir) / "out", pathlib.Path(build_dir) / "tmp"
    # Commands are passed as argument lists so a temp path containing
    # whitespace cannot be split into several arguments.
    if not binary:
      subprocess.check_output(args=['clang', '-shared', '-O2', '-Wall', '-Werror', '-x', 'c', *args['cflags'].split(), '-', '-o', str(output_file)], input=prg.encode('utf-8'))
    elif CI and ARM64:
      prg = prg.split('\n') # type: ignore
      # NOTE(review): varsize/ext_calls are only assigned when this body runs,
      # i.e. on a cache miss; on a cache hit the decorator skips this method.
      # Confirm that __call__ does not rely on them after a warm start.
      self.varsize = align(int(prg[0].split(" ")[1]))
      # Map the address of each `bl` instruction (4 bytes per instruction,
      # `loop` lines not counted) to its operands.
      self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'}
      # `bl` instructions are replaced by `nop` in the text that gets assembled.
      prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n'])
      subprocess.check_output(args=['aarch64-linux-gnu-as', '-o', str(temp_file)], input=prg.encode('utf-8'))
      subprocess.check_output(args=['aarch64-linux-gnu-objcopy', '-O', 'binary', '--only-section=.text', str(temp_file), str(output_file)])
    else:
      subprocess.check_output(args=['as', '-o', str(temp_file)], input=prg.encode('utf-8'))
      subprocess.check_output(args=['clang', '-lm', '-shared', str(temp_file), '-o', str(output_file)])
    return output_file.read_bytes()
|
||||
|
||||
def __call__(self, global_size, local_size, *args, wait=False):
|
||||
if wait: st = time.monotonic()
|
||||
if CI and ARM64:
|
||||
|
||||
@@ -2,7 +2,7 @@ import numpy as np
|
||||
import ctypes, functools, math, collections
|
||||
import extra.hip_wrapper as hip
|
||||
from typing import Tuple, Any, List
|
||||
from tinygrad.helpers import DEBUG, getenv
|
||||
from tinygrad.helpers import DEBUG, getenv, cache_compiled
|
||||
from tinygrad.ops import Compiled, ASTRunner, BasicBatchExecutor
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
@@ -88,15 +88,8 @@ class RawHIPBuffer(RawBufferCopyInOut, RawBufferTransfer):
|
||||
|
||||
class HIPProgram:
|
||||
def __init__(self, name:str, prg:str, binary=False):
|
||||
try:
|
||||
if not binary:
|
||||
prog = hip.hiprtcCreateProgram(prg, name, [], [])
|
||||
device_properties = hip.hipGetDeviceProperties(HIP.default_device)
|
||||
hip.hiprtcCompileProgram(prog, [f'--offload-arch={device_properties.gcnArchName}'])
|
||||
prg = hip.hiprtcGetCode(prog)
|
||||
except Exception as e:
|
||||
if DEBUG >= 3: print("FAILED TO BUILD", prg)
|
||||
raise e
|
||||
prg = prg if binary else self.compile(prg, name)
|
||||
|
||||
if DEBUG >= 6:
|
||||
asm = early_exec((["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], prg))
|
||||
print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
|
||||
@@ -106,6 +99,16 @@ class HIPProgram:
|
||||
hip.hipSetDevice(i)
|
||||
self.prgs.append(hip.hipModuleGetFunction(hip.hipModuleLoadData(prg), name))
|
||||
|
||||
@cache_compiled
def compile(self, prg, name) -> bytes:
  """Compile `prg` with hiprtc for the default device and return the generated code.

  The offload architecture is read from the properties of `HIP.default_device`.
  On any failure the source is printed when DEBUG >= 3 and the original
  exception is re-raised.
  """
  try:
    program = hip.hiprtcCreateProgram(prg, name, [], [])
    arch = hip.hipGetDeviceProperties(HIP.default_device).gcnArchName
    hip.hiprtcCompileProgram(program, [f'--offload-arch={arch}'])
    code = hip.hiprtcGetCode(program)
  except Exception as e:
    if DEBUG >= 3:
      print("FAILED TO BUILD", prg)
    raise e
  return code
|
||||
|
||||
def __call__(self, global_size, local_size, *args, wait=False):
|
||||
hip.hipSetDevice(args[0]._device)
|
||||
if wait:
|
||||
|
||||
Reference in New Issue
Block a user