simplify CacheCollector (#1944)

* rewrite cc

* fix

* fix tests

* fix all tests

* is it better

* better with shape

* cleaner

* linter fix

* no ;

* better comment

* better comments

* no thneed changes
This commit is contained in:
nimlgen
2023-09-29 20:13:04 +03:00
committed by GitHub
parent 90326dbdc3
commit 692bec7b6f
3 changed files with 41 additions and 58 deletions

View File

@@ -59,8 +59,9 @@ class TestCacheCollector(unittest.TestCase):
assert cache[0][1][1] == inps[0], "Input should be on its place."
assert cache[1][1][2] == inps[1], "Input should be on its place."
assert cache[-1][1][0] == out, "Output does not match."
assert get_bufs_count(cache) == 4, "Should have 4 buffers in total"
assert cache[-1][1][0] == cache[0][1][0], "Should reuse final output buffer as output in 1st kernel"
assert get_bufs_count(cache) == 5, "Should have 5 buffers in total"
# This is not worth the added complexity on real models
# assert cache[-1][1][0] == cache[0][1][0], "Should reuse final output buffer as output in 1st kernel"
FAKE_GLOBAL_ALLOCATOR = None
def test_cache_collector_cycle_avoidance(self):
@@ -78,8 +79,8 @@ class TestCacheCollector(unittest.TestCase):
assert cache[0][1][1] == inps[0], "Input should be on its place."
assert cache[1][1][2] == inps[1], "Input should be on its place."
assert cache[-1][1][0] == out, "Output does not match."
assert get_bufs_count(cache) == 6, "Should have 6 buffers in total"
assert cache[-1][1][0] != cache[0][1][0] and cache[0][1][0] == cache[3][1][0], "Output buffers from 1st and 4th kernel could not be the same as the 5th."
assert get_bufs_count(cache) == 6, "Should have 6 buffers in total"
FAKE_GLOBAL_ALLOCATOR = None
def test_cache_collector_all_alive(self):
@@ -144,7 +145,7 @@ class TestCacheCollector(unittest.TestCase):
assert cache[i][1][0]._device == '1', f"Device does not match {i}, has {cache[i][1][0]._device}."
for i in range(3, 6):
assert cache[i][1][0]._device == '2', f"Device does not match {i}, has {cache[i][1][0]._device}."
assert get_bufs_count(cache) == 6
assert get_bufs_count(cache) == 7
FAKE_GLOBAL_ALLOCATOR = None
def test_cache_collector_anybufs_inputs(self):
@@ -161,7 +162,7 @@ class TestCacheCollector(unittest.TestCase):
cache = CacheCollector.finish()
assert cache[0][1][1] == inps[0], "Input should be on its place."
assert cache[1][1][2] == inps[1], "Input should be on its place."
assert get_bufs_count(cache) == 7
assert get_bufs_count(cache) == 8
FAKE_GLOBAL_ALLOCATOR = None
def test_cache_collector_optimize_when_not_cached_anymore(self):
@@ -201,7 +202,7 @@ class TestCacheCollector(unittest.TestCase):
assert cache[1][1][2] == inps[1], "Input should be on its place."
assert cache[-1][1][0] == out, "Output does not match."
assert cache[0][1][0] != cache[3][1][0], "Cannot reuse 4th output buffer, it's an output buffer which might ovewrite itself"
assert get_bufs_count(cache) == 7, "Should have 7 buffers in total"
assert get_bufs_count(cache) == 6, "Should have 6 buffers in total"
FAKE_GLOBAL_ALLOCATOR = None
if __name__ == "__main__":