mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
jit: prune independent copies (#9749)
* jit: prune independent copies * linter * check kernel cnt
This commit is contained in:
@@ -570,6 +570,24 @@ class TestJitPrune(unittest.TestCase):
|
||||
out = w2_prune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
|
||||
def test_prune_w_independent_copy_correct(self):
|
||||
weights = Tensor.rand(16, device="CPU").realize()
|
||||
def w2(x) -> Tensor: return (weights*2).contiguous().to(Device.DEFAULT) + x
|
||||
w2_noprune = TinyJit(w2)
|
||||
w2_prune = TinyJit(w2, prune=True)
|
||||
|
||||
for _ in range(3):
|
||||
a = Tensor.rand(16).realize()
|
||||
out = w2_noprune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
|
||||
for _ in range(3):
|
||||
a = Tensor.rand(16).realize()
|
||||
out = w2_prune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
|
||||
assert len(w2_prune.captured.jit_cache) == 1, "prune should have removed the copy"
|
||||
|
||||
class TestJitFree(unittest.TestCase):
|
||||
def test_free_intermediates(self):
|
||||
ext_tensor = Tensor([1,24,23,45,1])
|
||||
|
||||
Reference in New Issue
Block a user