Merge branch 'master' into retinanet_mlperf

This commit is contained in:
Francis Lata
2024-09-25 17:50:55 -07:00
2 changed files with 39 additions and 1 deletion

View File

@@ -2,11 +2,16 @@
import gc
import unittest
import numpy as np
from tinygrad.device import Buffer
from tinygrad.engine.realize import run_schedule
from tinygrad.tensor import Tensor
def tensors_allocated():
  """Return the number of live Tensor objects tracked by the garbage collector."""
  # generator expression instead of sum([...]): avoids materializing a throwaway
  # list of bools over every gc-tracked object (ruff C419)
  return sum(isinstance(x, Tensor) for x in gc.get_objects())
def bufs_allocated():
  """Return the number of live Buffer objects tracked by the garbage collector."""
  # generator expression instead of sum([...]): avoids materializing a throwaway
  # list of bools over every gc-tracked object (ruff C419)
  return sum(isinstance(x, Buffer) for x in gc.get_objects())
class TestGC(unittest.TestCase):
def test_gc(self):
@@ -35,5 +40,26 @@ class TestGC(unittest.TestCase):
del b
assert (tensors_allocated() == 3)
def test_schedule_gc(self):
  """Scheduling without realizing must not leak Buffers once the tensors are freed."""
  before = bufs_allocated()
  realized = Tensor.ones(256).contiguous().realize()
  pending = Tensor.ones(5, 5).contiguous()
  pending.schedule()
  # dropping both tensors should release every Buffer created above
  del realized, pending
  self.assertEqual(bufs_allocated() - before, 0)
def test_schedule_gc_with_inputs(self):
  """A scheduled graph keeps its realized input's Buffer alive until the output is freed."""
  before = bufs_allocated()
  src = Tensor.ones(256).contiguous().realize()
  out = src + Tensor.ones(256).contiguous()
  sched = out.schedule()
  # deleting the input tensor must not free its Buffer: the schedule still needs it
  del src
  run_schedule(sched)
  np.testing.assert_equal(out.numpy(), np.full((256,), 2))
  self.assertEqual(bufs_allocated() - before, 1)
  # releasing the output drops the last reference to every Buffer
  del out
  self.assertEqual(bufs_allocated() - before, 0)
# Allow running this test file directly (e.g. `python test_gc.py`).
if __name__ == '__main__':
  unittest.main()

View File

@@ -15,7 +15,7 @@ from tinygrad.shape.view import View
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps
from tinygrad.ops import graph_rewrite
from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod
from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import create_schedule, reduceop_fusor, st_fixup
from tinygrad.engine.realize import CompiledRunner, run_schedule
@@ -1328,6 +1328,18 @@ class TestSchedule(unittest.TestCase):
# NOTE(review): presumably 63 buffers cannot be split into the 3 allowed kernels,
# hence expectedFailure — _test_buf_cnt is defined elsewhere in this class; confirm.
@unittest.expectedFailure
def test_buf_cnt_over_limit_alt(self): self._test_buf_cnt(63, allowed=3)
def test_schedule_mem_used(self):
  """Neither realizing ones() nor merely scheduling should change mem_used here."""
  start = GlobalCounters.mem_used
  Tensor.ones(256).contiguous().realize()
  # schedule() alone performs no device allocation
  Tensor.ones(5, 5).contiguous().schedule()
  self.assertEqual(GlobalCounters.mem_used - start, 0)
def test_schedule_mem_used_with_inputs(self):
  """A realized input accounts for exactly 1024 bytes; scheduling the sum adds nothing."""
  start = GlobalCounters.mem_used
  realized_input = Tensor.ones(256).contiguous().realize()
  # only the realized input holds device memory; the scheduled add does not allocate
  (realized_input + Tensor.ones(256).contiguous()).schedule()
  self.assertEqual(GlobalCounters.mem_used - start, 1024)
class TestIndexing(unittest.TestCase):
def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int):
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):