diff --git a/test/test_jit.py b/test/test_jit.py
index 5ff0bb1b2b..a88f9a52c4 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -492,7 +492,43 @@ class TestCopyInsideJit(unittest.TestCase):
     a = Tensor.rand(16,16,device="CLANG").realize()
     b = Tensor.rand(16,16).realize()
     out = add(a,b)
-    self.assertListEqual(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])
+    np.testing.assert_allclose(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])
+
+class TestJitPrune(unittest.TestCase):
+  def test_simple_prune(self):
+    weights = Tensor.rand(16).realize()
+    def w2(x) -> Tensor: return (weights*2).contiguous() + x
+    w2_noprune = TinyJit(w2)
+    w2_prune = TinyJit(w2, prune=True)
+
+    for _ in range(3):
+      a = Tensor.rand(16).realize()
+      out = w2_noprune(a)
+      np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
+    assert len(w2_noprune.captured.jit_cache) == 2
+
+    for _ in range(3):
+      a = Tensor.rand(16).realize()
+      out = w2_prune(a)
+      np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
+    assert len(w2_prune.captured.jit_cache) == 1
+
+  def test_prune_w_copy_correct(self):
+    weights = Tensor.rand(16).realize()
+    def w2(x) -> Tensor: return (weights*2).contiguous() + x.to(Device.DEFAULT)
+    w2_noprune = TinyJit(w2)
+    w2_prune = TinyJit(w2, prune=True)
+
+    for _ in range(3):
+      a = Tensor.rand(16, device="CLANG").realize()
+      out = w2_noprune(a)
+      np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
+
+    for _ in range(3):
+      a = Tensor.rand(16, device="CLANG").realize()
+      out = w2_prune(a)
+      np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index 15f1611f9c..531fd1a19d 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -8,7 +8,7 @@ from tinygrad.device import Buffer, Compiled, Device
 from tinygrad.dtype import DType
 from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
+from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner
 from tinygrad.engine.memory import _internal_memory_planner
 from tinygrad.nn.state import get_parameters
 from dataclasses import dataclass
@@ -268,6 +268,8 @@ class TinyJit(Generic[ReturnType]):
       if any(b in depends for b in ei.bufs):
         if isinstance(ei.prg, CompiledRunner): depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs)
+        if isinstance(ei.prg, (BufferCopy, BufferXfer)):
+          depends.add(cast(Buffer, ei.bufs[0]))
     pruned, onetime = partition(jit_cache, lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
     if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
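
For reviewers, here is a minimal, self-contained sketch of the dependency walk the second hunk implements. `Step`, `prune`, and the buffer names are hypothetical stand-ins for tinygrad's `ExecItem`, `partition`, and `Buffer`; this is an illustration under simplified assumptions, not the actual implementation.

```python
from dataclasses import dataclass

@dataclass
class Step:
  name: str
  reads: list[str]        # buffer names this step reads
  writes: list[str]       # buffer names this step writes; for a copy this is
                          # just its destination, matching the ei.bufs[0] taint
  is_kernel: bool = True  # False models BufferCopy/BufferXfer/ViewOp items

def prune(schedule: list[Step], input_bufs: set[str]) -> tuple[list[Step], list[Step]]:
  # Walk the captured schedule once, tainting every buffer written by a step
  # that reads an already-tainted buffer. `depends` starts as the JIT inputs.
  depends = set(input_bufs)
  for step in schedule:
    if any(b in depends for b in step.reads):
      depends.update(step.writes)
  # Keep a step for replay if it is not a plain compiled kernel, or if any of
  # its outputs is tainted; everything else only ever sees constant weights,
  # so it can run once at capture time and be dropped from the cache.
  keep, onetime = [], []
  for step in schedule:
    (keep if not step.is_kernel or any(w in depends for w in step.writes) else onetime).append(step)
  return keep, onetime

# Mirrors test_simple_prune: (weights*2).contiguous() never reads the JIT
# input, so its kernel is pruned and the cache shrinks from 2 entries to 1.
mul = Step("weights*2", reads=["weights"], writes=["w2"])
add = Step("add", reads=["w2", "x"], writes=["out"])
keep, onetime = prune([mul, add], {"x"})
assert [s.name for s in keep] == ["add"]
assert [s.name for s in onetime] == ["weights*2"]
```

In this framing, the `BufferCopy`/`BufferXfer` hunk is the step that taints a copy's destination buffer: without it, a kernel that only reads the copied-in input would look input-independent and be wrongly pruned, which is the scenario `test_prune_w_copy_correct` guards against.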