jit: prune independent copies (#9749)

* jit: prune independent copies

* linter

* check kernel cnt
This commit is contained in:
nimlgen
2025-04-05 20:50:28 +03:00
committed by GitHub
parent c2573b247c
commit 5f7c79676f
2 changed files with 25 additions and 7 deletions

View File

@@ -570,6 +570,24 @@ class TestJitPrune(unittest.TestCase):
out = w2_prune(a)
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
def test_prune_w_independent_copy_correct(self):
  """Check that a pruned JIT matches the unpruned one when the function contains an input-independent copy.

  `w2` doubles a CPU-resident weight tensor, makes it contiguous, moves it to
  `Device.DEFAULT` and adds the input. Both a plain `TinyJit` and a
  `TinyJit(..., prune=True)` wrapper are called three times each with fresh
  random inputs and compared against a reference computed from Python lists.
  Finally the pruned JIT's `captured.jit_cache` must hold exactly one entry.
  """
  weights = Tensor.rand(16, device="CPU").realize()

  def w2(x) -> Tensor:
    return (weights*2).contiguous().to(Device.DEFAULT) + x

  w2_noprune = TinyJit(w2)
  w2_prune = TinyJit(w2, prune=True)

  def run_and_verify(jitted):
    # three calls with fresh random inputs, each checked against weights*2 + input
    for _ in range(3):
      a = Tensor.rand(16).realize()
      out = jitted(a)
      actual = out.tolist()
      expected = [w*2+v for w,v in zip(weights.tolist(), a.tolist())]
      np.testing.assert_allclose(actual, expected)

  # unpruned first, then pruned -- same order as the reference results are expected in
  run_and_verify(w2_noprune)
  run_and_verify(w2_prune)

  # only one item may remain in the pruned JIT's cache
  cached_items = len(w2_prune.captured.jit_cache)
  assert cached_items == 1, "prune should have removed the copy"
class TestJitFree(unittest.TestCase):
def test_free_intermediates(self):
ext_tensor = Tensor([1,24,23,45,1])