diff --git a/test/test_gc.py b/test/test_gc.py index 3732802af7..37e632acda 100644 --- a/test/test_gc.py +++ b/test/test_gc.py @@ -2,11 +2,16 @@ import gc import unittest import numpy as np +from tinygrad.device import Buffer +from tinygrad.engine.realize import run_schedule from tinygrad.tensor import Tensor def tensors_allocated(): return sum([isinstance(x, Tensor) for x in gc.get_objects()]) +def bufs_allocated(): + return sum([isinstance(x, Buffer) for x in gc.get_objects()]) + class TestGC(unittest.TestCase): def test_gc(self): @@ -35,5 +40,26 @@ class TestGC(unittest.TestCase): del b assert (tensors_allocated() == 3) + def test_schedule_gc(self): + init = bufs_allocated() + x = Tensor.ones(256).contiguous().realize() + y = Tensor.ones(5, 5).contiguous() + y.schedule() + del x + del y + self.assertEqual(bufs_allocated()-init, 0) + + def test_schedule_gc_with_inputs(self): + init = bufs_allocated() + x = Tensor.ones(256).contiguous().realize() + y = x+Tensor.ones(256).contiguous() + ys = y.schedule() + del x + run_schedule(ys) + np.testing.assert_equal(y.numpy(), np.full((256,), 2)) + self.assertEqual(bufs_allocated()-init, 1) + del y + self.assertEqual(bufs_allocated()-init, 0) + if __name__ == '__main__': unittest.main() diff --git a/test/test_schedule.py b/test/test_schedule.py index 6c7e5036d8..a20137bc0a 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -15,7 +15,7 @@ from tinygrad.shape.view import View from tinygrad.tensor import Tensor from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps from tinygrad.ops import graph_rewrite -from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod +from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod from tinygrad.codegen.kernel import Kernel, verify_ast from tinygrad.engine.schedule import create_schedule, reduceop_fusor, st_fixup from tinygrad.engine.realize import CompiledRunner, run_schedule @@ -1328,6 +1328,18 @@ class TestSchedule(unittest.TestCase): @unittest.expectedFailure def test_buf_cnt_over_limit_alt(self): self._test_buf_cnt(63, allowed=3) + def test_schedule_mem_used(self): + base = GlobalCounters.mem_used + Tensor.ones(256).contiguous().realize() + Tensor.ones(5, 5).contiguous().schedule() + self.assertEqual(GlobalCounters.mem_used-base, 0) + + def test_schedule_mem_used_with_inputs(self): + base = GlobalCounters.mem_used + x = Tensor.ones(256).contiguous().realize() + (x+Tensor.ones(256).contiguous()).schedule() + self.assertEqual(GlobalCounters.mem_used-base, 1024) + class TestIndexing(unittest.TestCase): def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int): with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):