From 5f7c79676fa3218eb083e16e6c6ff6c1f327d299 Mon Sep 17 00:00:00 2001
From: nimlgen <138685161+nimlgen@users.noreply.github.com>
Date: Sat, 5 Apr 2025 20:50:28 +0300
Subject: [PATCH] jit: prune independent copies (#9749)

* jit: prune independent copies

* linter

* check kernel cnt
---
 test/test_jit.py       | 18 ++++++++++++++++++
 tinygrad/engine/jit.py | 14 +++++++-------
 2 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index 7c50b93adb..e4615a6d45 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -570,6 +570,24 @@ class TestJitPrune(unittest.TestCase):
     out = w2_prune(a)
     np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
 
+  def test_prune_w_independent_copy_correct(self):
+    weights = Tensor.rand(16, device="CPU").realize()
+    def w2(x) -> Tensor: return (weights*2).contiguous().to(Device.DEFAULT) + x
+    w2_noprune = TinyJit(w2)
+    w2_prune = TinyJit(w2, prune=True)
+
+    for _ in range(3):
+      a = Tensor.rand(16).realize()
+      out = w2_noprune(a)
+      np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
+
+    for _ in range(3):
+      a = Tensor.rand(16).realize()
+      out = w2_prune(a)
+      np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
+
+    assert len(w2_prune.captured.jit_cache) == 1, "prune should have removed the copy"
+
 class TestJitFree(unittest.TestCase):
   def test_free_intermediates(self):
     ext_tensor = Tensor([1,24,23,45,1])
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index 0e36dac219..3d1105baa3 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -130,13 +130,14 @@ class GraphRunner(Runner):
 # a marker for your graph supporting multiple devices of the same type
 class MultiGraphRunner(GraphRunner): pass
 
+def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]:
+  if isinstance(ei.prg, CompiledRunner): return [cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins]
+  if isinstance(ei.prg, (BufferCopy, BufferXfer)): return [cast(Buffer, ei.bufs[0])]
+  return []
+
 def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
   for ei in jit_cache:
-    if any(b in depends for b in ei.bufs):
-      if isinstance(ei.prg, CompiledRunner):
-        depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins)
-      if isinstance(ei.prg, (BufferCopy, BufferXfer)):
-        depends.add(cast(Buffer, ei.bufs[0]))
+    if any(b in depends for b in ei.bufs): depends.update(get_out_buffers_for_ei(ei))
 
 ReturnType = TypeVar('ReturnType')
 @dataclass
@@ -294,8 +295,7 @@ class TinyJit(Generic[ReturnType]):
       if self.prune:
         depends = set(input_buffers)
         update_depends(depends, jit_cache)
-        pruned, onetime = partition(jit_cache,
-          lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
+        pruned, onetime = partition(jit_cache, lambda ei: any(b in depends for b in get_out_buffers_for_ei(ei)))
         if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
         # run the onetime kernels here
         for ei in onetime: