free intermediate buffers in the jit [pr] (#8581)

* free intermediate buffers in the jit [pr]

* intermediates_freed

* deallocate if not allocated

* self._first_run is simpler
This commit is contained in:
George Hotz
2025-01-12 15:41:41 -08:00
committed by GitHub
parent d817dc10db
commit 4ac4c1415a
6 changed files with 75 additions and 19 deletions

View File

@@ -7,7 +7,7 @@ from test.helpers import assert_jit_cache_len
from tinygrad.tensor import Tensor
from tinygrad.engine.jit import TinyJit
from tinygrad.device import Device
from tinygrad.helpers import CI, Context, JIT
from tinygrad.helpers import CI, Context, JIT, GlobalCounters
from tinygrad.dtype import dtypes
from extra.models.unet import ResBlock
@@ -537,6 +537,38 @@ class TestJitPrune(unittest.TestCase):
out = w2_prune(a)
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
class TestJitFree(unittest.TestCase):
def test_free_intermediates(self):
ext_tensor = Tensor([1,24,23,45,1])
@TinyJit
def fxn(x:Tensor):
out = (x*2+ext_tensor).reshape(5,1).expand(5, 100).contiguous()
return out.sum()
for i in range(5):
out = fxn(Tensor([i,1,2,3,4]))
self.assertEqual(out.item(), 11400+200*i)
pre_free = GlobalCounters.mem_used
fxn.captured.free_intermediates()
savings_after_free = pre_free - GlobalCounters.mem_used
self.assertEqual(savings_after_free, 2024)
out = fxn(Tensor([11,1,2,3,4]))
self.assertEqual(out.item(), 13600)
def test_updated_not_freed(self):
x = Tensor([1]).realize()
@TinyJit
def fxn(y):
nonlocal x
x += y
return x
for _ in range(5): fxn(Tensor([1]))
self.assertEqual(x.item(), 6)
pre_free = GlobalCounters.mem_used
fxn.captured.free_intermediates()
savings_after_free = pre_free - GlobalCounters.mem_used
self.assertEqual(savings_after_free, 0)
fxn(Tensor([2]))
self.assertEqual(x.item(), 8)
if __name__ == '__main__':
unittest.main()