mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
free intermediate buffers in the jit [pr] (#8581)
* free intermediate buffers in the jit [pr] * intermediates_freed * deallocate if not allocated * self._first_run is simpler
This commit is contained in:
@@ -7,7 +7,7 @@ from test.helpers import assert_jit_cache_len
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.helpers import CI, Context, JIT
|
||||
from tinygrad.helpers import CI, Context, JIT, GlobalCounters
|
||||
from tinygrad.dtype import dtypes
|
||||
from extra.models.unet import ResBlock
|
||||
|
||||
@@ -537,6 +537,38 @@ class TestJitPrune(unittest.TestCase):
|
||||
out = w2_prune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
|
||||
class TestJitFree(unittest.TestCase):
|
||||
def test_free_intermediates(self):
|
||||
ext_tensor = Tensor([1,24,23,45,1])
|
||||
@TinyJit
|
||||
def fxn(x:Tensor):
|
||||
out = (x*2+ext_tensor).reshape(5,1).expand(5, 100).contiguous()
|
||||
return out.sum()
|
||||
for i in range(5):
|
||||
out = fxn(Tensor([i,1,2,3,4]))
|
||||
self.assertEqual(out.item(), 11400+200*i)
|
||||
pre_free = GlobalCounters.mem_used
|
||||
fxn.captured.free_intermediates()
|
||||
savings_after_free = pre_free - GlobalCounters.mem_used
|
||||
self.assertEqual(savings_after_free, 2024)
|
||||
out = fxn(Tensor([11,1,2,3,4]))
|
||||
self.assertEqual(out.item(), 13600)
|
||||
|
||||
def test_updated_not_freed(self):
|
||||
x = Tensor([1]).realize()
|
||||
@TinyJit
|
||||
def fxn(y):
|
||||
nonlocal x
|
||||
x += y
|
||||
return x
|
||||
for _ in range(5): fxn(Tensor([1]))
|
||||
self.assertEqual(x.item(), 6)
|
||||
pre_free = GlobalCounters.mem_used
|
||||
fxn.captured.free_intermediates()
|
||||
savings_after_free = pre_free - GlobalCounters.mem_used
|
||||
self.assertEqual(savings_after_free, 0)
|
||||
fxn(Tensor([2]))
|
||||
self.assertEqual(x.item(), 8)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user