make method_cache account for compiler (#12156)

* make method_cache account for compiler

* sorry
This commit is contained in:
nimlgen
2025-09-13 17:00:11 +03:00
committed by GitHub
parent 0c392089d9
commit 92df52d79a
4 changed files with 15 additions and 15 deletions

View File

@@ -20,7 +20,7 @@ class TestLLaMASpeed(unittest.TestCase):
def test_llama_compile(self):
backup_program = Device[Device.DEFAULT].runtime
backup_allocator = Device[Device.DEFAULT].allocator
backup_compiler = Device[Device.DEFAULT].compiler
backup_compiler = Device[Device.DEFAULT].compiler.compile_cached
Device[Device.DEFAULT].runtime = FakeProgram
Device[Device.DEFAULT].allocator = FakeAllocator(Device.default)
@@ -44,14 +44,14 @@ class TestLLaMASpeed(unittest.TestCase):
run_llama("codegen(1)")
# test no compiler use for this
Device[Device.DEFAULT].compiler = None
Device[Device.DEFAULT].compiler.compile_cached = None
run_llama("methodcache", False)
with Profiling(sort='time', frac=0.1, fn="/tmp/llama.prof", ts=5):
run_llama("profile", False)
Device[Device.DEFAULT].runtime = backup_program
Device[Device.DEFAULT].allocator = backup_allocator
Device[Device.DEFAULT].compiler = backup_compiler
Device[Device.DEFAULT].compiler.compile_cached = backup_compiler
if __name__ == '__main__':
TestLLaMASpeed().test_llama_compile()

View File

@@ -16,14 +16,14 @@ class TestKernelCache(unittest.TestCase):
a1 = Tensor.rand(4,4).realize()
b1 = Tensor.rand(4,4).realize()
orig_compile_func = Device['CPU'].compiler
Device['CPU'].compiler = None # making it not callable
orig_compile_func = Device['CPU'].compiler.compile_cached
Device['CPU'].compiler.compile_cached = None # making it not callable
try:
x1 = a1 + b1 + unique_const
x1.realize() # Same kernel should be from cache.
finally:
Device['CPU'].compiler = orig_compile_func
Device['CPU'].compiler.compile_cached = orig_compile_func
if __name__ == "__main__":
unittest.main()

View File

@@ -5,9 +5,9 @@ from tinygrad.nn.state import get_state_dict
class TestMethodCache(unittest.TestCase):
def setUp(self):
self.backup_compiler = Device[Device.DEFAULT].compiler
self.backup_compiler = Device[Device.DEFAULT].compiler.compile_cached
def tearDown(self):
Device[Device.DEFAULT].compiler = self.backup_compiler
Device[Device.DEFAULT].compiler.compile_cached = self.backup_compiler
def test_simple_methodcache(self):
a = Tensor([1])
@@ -15,19 +15,19 @@ class TestMethodCache(unittest.TestCase):
c = Tensor([3])
d = Tensor([4])
(a+b).realize()
Device[Device.DEFAULT].compiler = None
Device[Device.DEFAULT].compiler.compile_cached = None
(c+d).realize()
def test_nested_methodcache(self):
a,b,c,d = Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])
((a+b)+(a+b)).realize()
Device[Device.DEFAULT].compiler = None
Device[Device.DEFAULT].compiler.compile_cached = None
((c+d)+(c+d)).realize()
def test_nested_methodcache_swap(self):
a,b,c,d = Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])
((a+b)+(c+d)).realize()
Device[Device.DEFAULT].compiler = None
Device[Device.DEFAULT].compiler.compile_cached = None
((c+d)+(a+b)).realize()
@unittest.skip("incorrect use of transformer")
@@ -38,7 +38,7 @@ class TestMethodCache(unittest.TestCase):
# NOTE: you have to do this twice due to the k-v cache
for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
Device[Device.DEFAULT].compiler = None
Device[Device.DEFAULT].compiler.compile_cached = None
for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
if __name__ == '__main__':

View File

@@ -140,13 +140,13 @@ class BufferXfer(BufferCopy):
# **************** method cache ****************
method_cache: dict[tuple[str, bytes, tuple[int, ...], bool], CompiledRunner] = {}
method_cache: dict[tuple[str, type, bytes, tuple[int, ...], bool], CompiledRunner] = {}
def get_runner(device:str, ast:UOp) -> CompiledRunner:
# TODO: this should be all context relevant to rendering
context = (BEAM.value, NOOPT.value, DEVECTORIZE.value)
ckey = (device, ast.key, context, False)
ckey = (device, type(Device[device].compiler), ast.key, context, False)
if cret:=method_cache.get(ckey): return cret
bkey = (device.split(":")[0], ast.key, context, True)
bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
if bret:=method_cache.get(bkey):
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
else: