mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Merge branch 'master' into retinanet_mlperf
This commit is contained in:
@@ -2,11 +2,16 @@
|
||||
import gc
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
def tensors_allocated():
|
||||
return sum([isinstance(x, Tensor) for x in gc.get_objects()])
|
||||
|
||||
def bufs_allocated():
|
||||
return sum([isinstance(x, Buffer) for x in gc.get_objects()])
|
||||
|
||||
class TestGC(unittest.TestCase):
|
||||
|
||||
def test_gc(self):
|
||||
@@ -35,5 +40,26 @@ class TestGC(unittest.TestCase):
|
||||
del b
|
||||
assert (tensors_allocated() == 3)
|
||||
|
||||
def test_schedule_gc(self):
|
||||
init = bufs_allocated()
|
||||
x = Tensor.ones(256).contiguous().realize()
|
||||
y = Tensor.ones(5, 5).contiguous()
|
||||
y.schedule()
|
||||
del x
|
||||
del y
|
||||
self.assertEqual(bufs_allocated()-init, 0)
|
||||
|
||||
def test_schedule_gc_with_inputs(self):
|
||||
init = bufs_allocated()
|
||||
x = Tensor.ones(256).contiguous().realize()
|
||||
y = x+Tensor.ones(256).contiguous()
|
||||
ys = y.schedule()
|
||||
del x
|
||||
run_schedule(ys)
|
||||
np.testing.assert_equal(y.numpy(), np.full((256,), 2))
|
||||
self.assertEqual(bufs_allocated()-init, 1)
|
||||
del y
|
||||
self.assertEqual(bufs_allocated()-init, 0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -15,7 +15,7 @@ from tinygrad.shape.view import View
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps
|
||||
from tinygrad.ops import graph_rewrite
|
||||
from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod
|
||||
from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import create_schedule, reduceop_fusor, st_fixup
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule
|
||||
@@ -1328,6 +1328,18 @@ class TestSchedule(unittest.TestCase):
|
||||
@unittest.expectedFailure
|
||||
def test_buf_cnt_over_limit_alt(self): self._test_buf_cnt(63, allowed=3)
|
||||
|
||||
def test_schedule_mem_used(self):
|
||||
base = GlobalCounters.mem_used
|
||||
Tensor.ones(256).contiguous().realize()
|
||||
Tensor.ones(5, 5).contiguous().schedule()
|
||||
self.assertEqual(GlobalCounters.mem_used-base, 0)
|
||||
|
||||
def test_schedule_mem_used_with_inputs(self):
|
||||
base = GlobalCounters.mem_used
|
||||
x = Tensor.ones(256).contiguous().realize()
|
||||
(x+Tensor.ones(256).contiguous()).schedule()
|
||||
self.assertEqual(GlobalCounters.mem_used-base, 1024)
|
||||
|
||||
class TestIndexing(unittest.TestCase):
|
||||
def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int):
|
||||
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
|
||||
|
||||
Reference in New Issue
Block a user