add test that the compiler isn't used (#3025)

* add test that the compiler isn't used

* one print_tree

* improve speed with st size cache

* switch to gpt-2
This commit is contained in:
George Hotz
2024-01-05 17:24:01 -08:00
committed by GitHub
parent 520406cf3a
commit 2a2d3233d2
5 changed files with 63 additions and 6 deletions

View File

@@ -20,6 +20,7 @@ class TestLLaMASpeed(unittest.TestCase):
def test_llama_compile(self):
backup_program = Device[Device.DEFAULT].runtime
backup_allocator = Device[Device.DEFAULT].allocator
backup_compiler = Device[Device.DEFAULT].compiler
Device[Device.DEFAULT].runtime = FakeProgram
Device[Device.DEFAULT].allocator = FakeAllocator()
@@ -33,20 +34,24 @@ class TestLLaMASpeed(unittest.TestCase):
def run_llama(st, empty_method_cache=True):
if empty_method_cache: Device[Device.DEFAULT].get_runner.cache_clear()
tms = [time.perf_counter()]
for i in range(10):
for i in range(5):
model(Tensor([[1,2,3,4]]), i).realize()
tms.append(time.perf_counter())
timings = [(tms[i+1]-tms[i])*1000 for i in range(len(tms)-1)]
print(f"{st:15s} mean runtime: {sum(timings)/len(timings):7.2f}ms, runs: ", ", ".join(f'{x:7.2f}' for x in timings))
run_llama("codegen")
run_llama("methodcache", False)
run_llama("codegen(0)")
run_llama("codegen(1)")
with Profiling(sort='time', frac=0.1):
run_llama("profile")
# test no compiler use for this
Device[Device.DEFAULT].compiler = None
run_llama("methodcache", False)
with Profiling(sort='time', frac=0.1, fn="/tmp/llama.prof"):
run_llama("profile", False)
Device[Device.DEFAULT].runtime = backup_program
Device[Device.DEFAULT].allocator = backup_allocator
Device[Device.DEFAULT].compiler = backup_compiler
if __name__ == '__main__':
unittest.main()

49
test/test_method_cache.py Normal file
View File

@@ -0,0 +1,49 @@
import unittest
from tinygrad import Tensor, Device, Variable
from tinygrad.device import Compiled
from examples.gpt2 import Transformer
from tinygrad.nn.state import get_state_dict
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends")
class TestMethodCache(unittest.TestCase):
def setUp(self):
self.backup_compiler = Device[Device.DEFAULT].compiler
def tearDown(self):
Device[Device.DEFAULT].compiler = self.backup_compiler
def test_simple_methodcache(self):
a = Tensor([1])
b = Tensor([2])
c = Tensor([3])
d = Tensor([4])
(a+b).realize()
Device[Device.DEFAULT].compiler = None
(c+d).realize()
def test_nested_methodcache(self):
a,b,c,d = Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])
((a+b)+(a+b)).realize()
Device[Device.DEFAULT].compiler = None
((c+d)+(c+d)).realize()
def test_nested_methodcache_swap(self):
a,b,c,d = Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])
((a+b)+(c+d)).realize()
Device[Device.DEFAULT].compiler = None
((c+d)+(a+b)).realize()
def test_small_transformer(self):
args_tiny = {"dim": 16, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 10}
model = Transformer(**args_tiny)
for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype).realize())
# NOTE: you have to do this twice due to the k-v cache
for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
Device[Device.DEFAULT].compiler = None
for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
if __name__ == '__main__':
unittest.main()

View File

@@ -285,6 +285,7 @@ class Compiled:
def synchronize(self): pass # override this in your device
def to_program(self, k:Linearizer) -> CompiledASTRunner:
assert self.compiler is not None, f"compiler is None, can't build {k.ast}"
k.linearize()
src = self.renderer(to_function_name(k.name), k.uops)
if getenv("DISABLE_COMPILER_CACHE") or '<' in self.compiler.__name__:

View File

@@ -91,13 +91,14 @@ class Timing(contextlib.ContextDecorator):
if self.enabled: print(f"{self.prefix}{self.et*1e-6:6.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
class Profiling(contextlib.ContextDecorator):
def __init__(self, enabled=True, sort='cumtime', frac=0.2): self.enabled, self.sort, self.frac = enabled, sort, frac
def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None): self.enabled, self.sort, self.frac, self.fn = enabled, sort, frac, fn
def __enter__(self):
self.pr = cProfile.Profile(timer=lambda: int(time.time()*1e9), timeunit=1e-6)
if self.enabled: self.pr.enable()
def __exit__(self, *exc):
if self.enabled:
self.pr.disable()
if self.fn: self.pr.dump_stats(self.fn)
pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort).print_stats(self.frac)
# *** universal database cache ***

View File

@@ -75,6 +75,7 @@ class ShapeTracker:
@property
def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
@functools.lru_cache(maxsize=None) # NOTE: this keeps all ShapeTrackers alive
def size(self) -> int:
if 0 in self.shape: return 0
ret = self.expr_idxs()[0].max