make method_cache account for compiler (#12156)
* make method_cache account for compiler
* sorry
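The gist of the change: the method cache's key now also carries the class of the device's compiler, so a kernel compiled by one compiler is never handed back after the compiler has been swapped for another. A minimal sketch of the idea (hypothetical names, not tinygrad's actual code):

# toy cache keyed by (device, compiler class, ast key); illustration only
from typing import Any, Callable

toy_cache: dict[tuple[str, type, bytes], Any] = {}

def get_cached(device:str, compiler:Any, ast_key:bytes, build:Callable) -> Any:
  key = (device, type(compiler), ast_key)   # type(compiler) keeps different compilers apart
  if key not in toy_cache: toy_cache[key] = build(compiler, ast_key)
  return toy_cache[key]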
test/external/external_test_speed_llama.py (6 changed lines)
@@ -20,7 +20,7 @@ class TestLLaMASpeed(unittest.TestCase):
   def test_llama_compile(self):
     backup_program = Device[Device.DEFAULT].runtime
     backup_allocator = Device[Device.DEFAULT].allocator
-    backup_compiler = Device[Device.DEFAULT].compiler
+    backup_compiler = Device[Device.DEFAULT].compiler.compile_cached
     Device[Device.DEFAULT].runtime = FakeProgram
     Device[Device.DEFAULT].allocator = FakeAllocator(Device.default)

@@ -44,14 +44,14 @@ class TestLLaMASpeed(unittest.TestCase):
     run_llama("codegen(1)")

     # test no compiler use for this
-    Device[Device.DEFAULT].compiler = None
+    Device[Device.DEFAULT].compiler.compile_cached = None
     run_llama("methodcache", False)
     with Profiling(sort='time', frac=0.1, fn="/tmp/llama.prof", ts=5):
       run_llama("profile", False)

     Device[Device.DEFAULT].runtime = backup_program
     Device[Device.DEFAULT].allocator = backup_allocator
-    Device[Device.DEFAULT].compiler = backup_compiler
+    Device[Device.DEFAULT].compiler.compile_cached = backup_compiler

 if __name__ == '__main__':
   TestLLaMASpeed().test_llama_compile()
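Why the test now nulls out compile_cached rather than the whole compiler object: with the compiler's class in the cache key, replacing the compiler with None would change the key (NoneType instead of the real compiler class) and turn an intended cache hit into a miss. A toy illustration (not tinygrad code):

# toy keys only; shows that swapping the compiler object for None changes the cache key
class SomeCompiler: pass
key_with_compiler = ("CPU", type(SomeCompiler()), b"ast")   # ('CPU', SomeCompiler, b'ast')
key_with_none     = ("CPU", type(None), b"ast")             # ('CPU', NoneType, b'ast')
assert key_with_compiler != key_with_none
# breaking only compile_cached keeps type(compiler) stable while still failing loudly on any real compile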
@@ -16,14 +16,14 @@ class TestKernelCache(unittest.TestCase):

     a1 = Tensor.rand(4,4).realize()
     b1 = Tensor.rand(4,4).realize()
-    orig_compile_func = Device['CPU'].compiler
-    Device['CPU'].compiler = None # making it not callable
+    orig_compile_func = Device['CPU'].compiler.compile_cached
+    Device['CPU'].compiler.compile_cached = None # making it not callable

     try:
       x1 = a1 + b1 + unique_const
       x1.realize() # Same kernel should be from cache.
     finally:
-      Device['CPU'].compiler = orig_compile_func
+      Device['CPU'].compiler.compile_cached = orig_compile_func

 if __name__ == "__main__":
   unittest.main()
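The same back-up / break / restore dance appears in several of these tests. A hypothetical context manager (not part of tinygrad) could package it, assuming a device object whose compiler exposes compile_cached:

from contextlib import contextmanager

@contextmanager
def no_compile(dev):
  # temporarily disable compilation so anything not served from a cache raises TypeError
  backup = dev.compiler.compile_cached
  dev.compiler.compile_cached = None
  try: yield
  finally: dev.compiler.compile_cached = backup

# usage sketch: with no_compile(Device['CPU']): x1.realize()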
@@ -5,9 +5,9 @@ from tinygrad.nn.state import get_state_dict

 class TestMethodCache(unittest.TestCase):
   def setUp(self):
-    self.backup_compiler = Device[Device.DEFAULT].compiler
+    self.backup_compiler = Device[Device.DEFAULT].compiler.compile_cached
   def tearDown(self):
-    Device[Device.DEFAULT].compiler = self.backup_compiler
+    Device[Device.DEFAULT].compiler.compile_cached = self.backup_compiler

   def test_simple_methodcache(self):
     a = Tensor([1])
@@ -15,19 +15,19 @@ class TestMethodCache(unittest.TestCase):
     c = Tensor([3])
     d = Tensor([4])
     (a+b).realize()
-    Device[Device.DEFAULT].compiler = None
+    Device[Device.DEFAULT].compiler.compile_cached = None
     (c+d).realize()

   def test_nested_methodcache(self):
     a,b,c,d = Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])
     ((a+b)+(a+b)).realize()
-    Device[Device.DEFAULT].compiler = None
+    Device[Device.DEFAULT].compiler.compile_cached = None
     ((c+d)+(c+d)).realize()

   def test_nested_methodcache_swap(self):
     a,b,c,d = Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])
     ((a+b)+(c+d)).realize()
-    Device[Device.DEFAULT].compiler = None
+    Device[Device.DEFAULT].compiler.compile_cached = None
     ((c+d)+(a+b)).realize()

   @unittest.skip("incorrect use of transformer")
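These three tests lean on the fact that structurally identical expressions lower to the same AST key, so (c+d)+(c+d) can reuse the runner compiled for (a+b)+(a+b) even though the buffers differ. A toy illustration of caching by structure rather than by operands (illustration only, not tinygrad code):

# the "structure" string stands in for ast.key; inputs play the role of the buffers
toy_runners: dict[str, str] = {}

def toy_run(structure:str, inputs:tuple) -> tuple:
  if structure not in toy_runners: toy_runners[structure] = f"compiled<{structure}>"
  return (toy_runners[structure], inputs)

toy_run("(x+y)+(x+y)", ("a", "b"))
toy_run("(x+y)+(x+y)", ("c", "d"))   # same structure, different buffers: cache hit
assert len(toy_runners) == 1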
@@ -38,7 +38,7 @@ class TestMethodCache(unittest.TestCase):
     # NOTE: you have to do this twice due to the k-v cache
     for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
     for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
-    Device[Device.DEFAULT].compiler = None
+    Device[Device.DEFAULT].compiler.compile_cached = None
     for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()

 if __name__ == '__main__':
@@ -140,13 +140,13 @@ class BufferXfer(BufferCopy):

 # **************** method cache ****************

-method_cache: dict[tuple[str, bytes, tuple[int, ...], bool], CompiledRunner] = {}
+method_cache: dict[tuple[str, type, bytes, tuple[int, ...], bool], CompiledRunner] = {}
 def get_runner(device:str, ast:UOp) -> CompiledRunner:
   # TODO: this should be all context relevant to rendering
   context = (BEAM.value, NOOPT.value, DEVECTORIZE.value)
-  ckey = (device, ast.key, context, False)
+  ckey = (device, type(Device[device].compiler), ast.key, context, False)
   if cret:=method_cache.get(ckey): return cret
-  bkey = (device.split(":")[0], ast.key, context, True)
+  bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
   if bret:=method_cache.get(bkey):
     method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
   else:
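For reference, the lookup above is two-level: a device-specific key (ckey) is tried first, then a base-device key (bkey, e.g. "CUDA" for "CUDA:1") whose already-compiled library gets re-wrapped for the specific device; both keys now carry type(Device[device].compiler). A simplified sketch of that pattern with hypothetical helpers (not the actual get_runner):

def lookup(device:str, compiler_type:type, ast_key:bytes, context:tuple, cache:dict, compile_fn, clone_fn):
  ckey = (device, compiler_type, ast_key, context, False)
  if (cret := cache.get(ckey)): return cret
  bkey = (device.split(":")[0], compiler_type, ast_key, context, True)
  if (bret := cache.get(bkey)):
    ret = clone_fn(bret, device)        # reuse the base device's compiled library
  else:
    ret = compile_fn(device, ast_key)   # compile fresh (details elided in the diff above)
  cache[ckey] = ret
  return ret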