mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-16 17:45:38 -05:00
simplify CacheCollector (#1944)
* rewrite cc * fix * fix tests * fix all tests * is it better * better with shape * cleaner * linter fix * no ; * better comment * better comments * no thneed changes
This commit is contained in:
@@ -59,8 +59,9 @@ class TestCacheCollector(unittest.TestCase):
|
||||
assert cache[0][1][1] == inps[0], "Input should be on its place."
|
||||
assert cache[1][1][2] == inps[1], "Input should be on its place."
|
||||
assert cache[-1][1][0] == out, "Output does not match."
|
||||
assert get_bufs_count(cache) == 4, "Should have 4 buffers in total"
|
||||
assert cache[-1][1][0] == cache[0][1][0], "Should reuse final output buffer as output in 1st kernel"
|
||||
assert get_bufs_count(cache) == 5, "Should have 5 buffers in total"
|
||||
# This is not worth added complexity on real models
|
||||
# assert cache[-1][1][0] == cache[0][1][0], "Should reuse final output buffer as output in 1st kernel"
|
||||
FAKE_GLOBAL_ALLOCATOR = None
|
||||
|
||||
def test_cache_collector_cycle_avoidance(self):
|
||||
@@ -78,8 +79,8 @@ class TestCacheCollector(unittest.TestCase):
|
||||
assert cache[0][1][1] == inps[0], "Input should be on its place."
|
||||
assert cache[1][1][2] == inps[1], "Input should be on its place."
|
||||
assert cache[-1][1][0] == out, "Output does not match."
|
||||
assert get_bufs_count(cache) == 6, "Should have 6 buffers in total"
|
||||
assert cache[-1][1][0] != cache[0][1][0] and cache[0][1][0] == cache[3][1][0], "Output buffers from 1st and 4th kernel could not be the same as the 5th."
|
||||
assert get_bufs_count(cache) == 6, "Should have 6 buffers in total"
|
||||
FAKE_GLOBAL_ALLOCATOR = None
|
||||
|
||||
def test_cache_collector_all_alive(self):
|
||||
@@ -144,7 +145,7 @@ class TestCacheCollector(unittest.TestCase):
|
||||
assert cache[i][1][0]._device == '1', f"Device does not match {i}, has {cache[i][1][0]._device}."
|
||||
for i in range(3, 6):
|
||||
assert cache[i][1][0]._device == '2', f"Device does not match {i}, has {cache[i][1][0]._device}."
|
||||
assert get_bufs_count(cache) == 6
|
||||
assert get_bufs_count(cache) == 7
|
||||
FAKE_GLOBAL_ALLOCATOR = None
|
||||
|
||||
def test_cache_collector_anybufs_inputs(self):
|
||||
@@ -161,7 +162,7 @@ class TestCacheCollector(unittest.TestCase):
|
||||
cache = CacheCollector.finish()
|
||||
assert cache[0][1][1] == inps[0], "Input should be on its place."
|
||||
assert cache[1][1][2] == inps[1], "Input should be on its place."
|
||||
assert get_bufs_count(cache) == 7
|
||||
assert get_bufs_count(cache) == 8
|
||||
FAKE_GLOBAL_ALLOCATOR = None
|
||||
|
||||
def test_cache_collector_optimize_when_not_cached_anymore(self):
|
||||
@@ -201,7 +202,7 @@ class TestCacheCollector(unittest.TestCase):
|
||||
assert cache[1][1][2] == inps[1], "Input should be on its place."
|
||||
assert cache[-1][1][0] == out, "Output does not match."
|
||||
assert cache[0][1][0] != cache[3][1][0], "Cannot reuse 4th output buffer, it's an output buffer which might ovewrite itself"
|
||||
assert get_bufs_count(cache) == 7, "Should have 7 buffers in total"
|
||||
assert get_bufs_count(cache) == 6, "Should have 6 buffers in total"
|
||||
FAKE_GLOBAL_ALLOCATOR = None
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user