Files
tinygrad/test/test_cache_collector.py
nimlgen 1c0449e190 add cache collector (#1595)
* init cache collector

* add test_cache_collector.py

* switch GlobalCounters.cache to CacheCollector

* init jit models test

* jitted SD

* add debug msg to print loaded bufs count

* moved cache collector to jit

* clearer SD

* no double device import
2023-08-28 19:59:55 -07:00

169 lines
8.0 KiB
Python

#!/usr/bin/env python
import unittest
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes
from tinygrad.jit import CacheCollector
from weakref import ref
class FakeDeviceBuffer:
  """Stand-in for a device allocation that only remembers the size, dtype and device it was requested with."""
  def __init__(self, sz, dt, device):
    # Exposed under the attribute names that FakeBuffer checks after allocation (size / dtype / device).
    self.size, self.dtype, self.device = sz, dt, device
class FakeAllocator(LRUAllocator):
  """LRUAllocator whose `_do_alloc` hook hands out FakeDeviceBuffer objects instead of real device memory."""
  def _do_alloc(self, size, dtype, device, **kwargs):
    # Any extra allocation kwargs are accepted and ignored; only size/dtype/device are recorded.
    fake_mem = FakeDeviceBuffer(size, dtype, device)
    return fake_mem
# Allocator that FakeBuffer draws its backing storage from. Tests install a FakeAllocator here and
# reset it to None afterwards. FakeBuffer only reads this name, so it needs no `global` declaration.
FAKE_GLOBAL_ALLOCATOR = None
class FakeBuffer(RawBuffer):
  """RawBuffer backed by FAKE_GLOBAL_ALLOCATOR, used to exercise CacheCollector without a real device.

  Args:
    size: number of elements requested.
    dtype: element dtype requested.
    device: device name forwarded to the allocator (defaults to '0').
  """
  def __init__(self, size, dtype, device='0'):
    # Reading the module global does not require `global`; `device` is passed as a plain keyword.
    super().__init__(size, dtype, allocator=FAKE_GLOBAL_ALLOCATOR, device=device)
    # The (LRU) allocator may presumably return a cached allocation; insist it matches the request exactly.
    assert self._buf.size == size and self._buf.dtype == dtype and self._buf.device == device, "This allocator requires 100% match of dtype and size."
def alloc(allocator, size, dtype, **kwargs):
  """Create a FakeBuffer of `size` elements of `dtype`, drawing storage from `allocator`.

  FAKE_GLOBAL_ALLOCATOR is switched to `allocator` only while the FakeBuffer is being constructed. The
  previous value is restored in a `finally` clause, so a failing construction cannot leak the temporary
  allocator into later tests.

  Args:
    allocator: allocator to install temporarily.
    size: number of elements.
    dtype: element dtype.
    **kwargs: forwarded to FakeBuffer (e.g. `device`).

  Returns:
    The newly created FakeBuffer.
  """
  global FAKE_GLOBAL_ALLOCATOR
  previous = FAKE_GLOBAL_ALLOCATOR
  FAKE_GLOBAL_ALLOCATOR = allocator
  try:
    buf = FakeBuffer(size, dtype, **kwargs)
  finally:
    # Always undo the temporary swap, even if FakeBuffer's own checks raise.
    FAKE_GLOBAL_ALLOCATOR = previous
  assert buf.dtype == dtype and buf.size == size
  return buf
def anybuf(size, dtype):
  """Make a FakeBuffer on FakeBuffer's default device, using whatever FAKE_GLOBAL_ALLOCATOR holds right now."""
  new_buf = FakeBuffer(size=size, dtype=dtype)
  return new_buf
def add_to_cache(bufs):
  """Record a fake kernel call over `bufs` with CacheCollector and return its output buffer.

  No program and no var_vals are passed (both None); only the buffer list is recorded.
  """
  CacheCollector.add(None, bufs, None)
  output_buf = bufs[0]  # the tests treat the first buffer of each call as that call's output
  return output_buf
def add_to_cache_refed(bufs):
  """Like add_to_cache, but also return a weak reference to every buffer in `bufs`.

  The weakrefs let a test reach buffers it did not keep a strong reference to, e.g. a temporary input
  created inline in the call.

  Returns:
    (bufs[0], list of weakref.ref objects, one per entry of `bufs`, in order)
  """
  CacheCollector.add(None, bufs, None)
  weak_refs = list(map(ref, bufs))
  return bufs[0], weak_refs
def get_bufs_count(cache):
  """Count distinct (by hash/equality) entries across the buffer lists of all (prg, bufs, var_vals) cache items.

  Every element of each buffer list is counted, including non-buffer values such as ints or None.
  """
  distinct = {buf for _, bufs, _ in cache for buf in bufs}
  return len(distinct)
class TestCacheCollector(unittest.TestCase):
  """CacheCollector tests driven by FakeBuffers instead of real device buffers.

  Each test records a short chain of fake kernel calls (output buffer first, then inputs) between
  CacheCollector.start() and CacheCollector.finish(), then inspects which buffers the collected cache uses.
  """
  def setUp(self):
    # A fresh allocator per test, installed where FakeBuffer looks for it.
    # 256 << 30 is presumably the allocator's byte budget (256 GiB) -- large enough not to constrain these tests.
    global FAKE_GLOBAL_ALLOCATOR
    FAKE_GLOBAL_ALLOCATOR = FakeAllocator(256 << 30)
  def tearDown(self):
    # Runs even when an assertion fails, so the allocator can never leak into the next test.
    global FAKE_GLOBAL_ALLOCATOR
    FAKE_GLOBAL_ALLOCATOR = None
  def test_cache_collector_optimization(self):
    inps = [FakeBuffer(64, dtypes.float32) for _ in range(2)]
    CacheCollector.start()
    out = add_to_cache([FakeBuffer(32, dtypes.float32), inps[0]])
    out = add_to_cache([FakeBuffer(32, dtypes.float32), out, inps[1]])
    out = add_to_cache([FakeBuffer(32, dtypes.float32), out])
    cache = CacheCollector.finish()
    assert cache[0][1][1] == inps[0], "Input should be on its place."
    assert cache[1][1][2] == inps[1], "Input should be on its place."
    assert cache[-1][1][0] == out, "Output does not match."
    assert get_bufs_count(cache) == 4, "Should have 4 buffers in total"
    assert cache[-1][1][0] == cache[0][1][0], "Should reuse final output buffer as output in 1st kernel"
  def test_cache_collector_cycle_avoidance(self):
    inps = [FakeBuffer(64, dtypes.float32) for _ in range(2)]
    CacheCollector.start()
    # Output buffer here cannot be shared with final output buffer, since we could get a cycle the next step as inps[1] has the same shape and dtype.
    out = add_to_cache([FakeBuffer(64, dtypes.float32), inps[0]])
    out = add_to_cache([FakeBuffer(32, dtypes.float32), out, inps[1]])
    out = add_to_cache([FakeBuffer(32, dtypes.float32), out])
    out = add_to_cache([FakeBuffer(64, dtypes.float32), out])
    out = add_to_cache([FakeBuffer(64, dtypes.float32), out])
    cache = CacheCollector.finish()
    assert cache[0][1][1] == inps[0], "Input should be on its place."
    assert cache[1][1][2] == inps[1], "Input should be on its place."
    assert cache[-1][1][0] == out, "Output does not match."
    assert get_bufs_count(cache) == 6, "Should have 6 buffers in total"
    assert cache[-1][1][0] != cache[0][1][0] and cache[0][1][0] == cache[3][1][0], "Output buffers from 1st and 4th kernel could not be the same as the 5th."
  def test_cache_collector_all_alive(self):
    inps = [FakeBuffer(64, dtypes.float32) for _ in range(2)]
    outs = [FakeBuffer(128, dtypes.float32) for _ in range(4)]
    CacheCollector.start()
    out = add_to_cache([outs[0], inps[0]])
    out = add_to_cache([outs[1], out, inps[1]])
    out = add_to_cache([outs[2], out])
    out = add_to_cache([outs[3], out])
    cache = CacheCollector.finish()
    assert cache[0][1][1] == inps[0], "Input should be on its place."
    assert cache[1][1][2] == inps[1], "Input should be on its place."
    assert cache[0][1][0] == outs[0], "Output0 should be on its place."
    assert cache[1][1][0] == outs[1], "Output1 should be on its place."
    assert cache[2][1][0] == outs[2], "Output2 should be on its place."
    assert cache[3][1][0] == outs[3], "Output3 should be on its place."
    assert cache[-1][1][0] == out, "Output does not match."
    assert get_bufs_count(cache) == len(outs) + len(inps), "Nothing to optimize, since buffers are alive and might be used as outputs"
  def test_cache_collector_middle_input(self):
    inps = [FakeBuffer(64, dtypes.float32) for _ in range(2)]
    outs = [FakeBuffer(32, dtypes.float32) for _ in range(1)]
    CacheCollector.start()
    out = add_to_cache([FakeBuffer(32, dtypes.float32), inps[0]])
    out = add_to_cache([FakeBuffer(32, dtypes.float32), out, inps[1]])
    # The third input is created inline; only the weakref in refs2 lets the test reach it afterwards.
    out,refs2 = add_to_cache_refed([outs[0], out, FakeBuffer(32, dtypes.float32)])
    out = add_to_cache([FakeBuffer(32, dtypes.float32), out])
    out = add_to_cache([FakeBuffer(32, dtypes.float32), out])
    cache = CacheCollector.finish()
    assert cache[0][1][1] == inps[0], "Input should be on its place."
    assert cache[1][1][2] == inps[1], "Input should be on its place."
    assert cache[2][1][2] == refs2[2](), "Input should be captured."
    for i in (0, 1, 3, 4):
      assert cache[i][1][0] != cache[2][1][2], "None of outputs buffer should reuse new_input."
    assert cache[-1][1][0] == out, "Output does not match."
    assert get_bufs_count(cache) == 7
  def test_cache_collector_multidev(self):
    inps = [FakeBuffer(64, dtypes.float32, '1') for _ in range(2)]
    CacheCollector.start()
    out = add_to_cache([FakeBuffer(32, dtypes.float32, '1'), inps[0]])
    out = add_to_cache([FakeBuffer(32, dtypes.float32, '1'), out, inps[1]])
    out = add_to_cache([FakeBuffer(32, dtypes.float32, '1'), out])
    out = add_to_cache([FakeBuffer(32, dtypes.float32, '2'), out])
    out = add_to_cache([FakeBuffer(32, dtypes.float32, '2'), out])
    out = add_to_cache([FakeBuffer(32, dtypes.float32, '2'), out])
    cache = CacheCollector.finish()
    assert cache[0][1][1] == inps[0], "Input should be on its place."
    assert cache[1][1][2] == inps[1], "Input should be on its place."
    # Output buffers must stay on the device they were requested on.
    for i in range(3):
      assert cache[i][1][0]._device == '1', f"Device does not match {i}, has {cache[i][1][0]._device}."
    for i in range(3, 6):
      assert cache[i][1][0]._device == '2', f"Device does not match {i}, has {cache[i][1][0]._device}."
    assert get_bufs_count(cache) == 6
  def test_cache_collector_anybufs_inputs(self):
    inps = [FakeBuffer(64, dtypes.float32, '1') for _ in range(2)]
    CacheCollector.start()
    out = add_to_cache([FakeBuffer(32, dtypes.float32), inps[0]])
    out = add_to_cache([FakeBuffer(32, dtypes.float32), out, inps[1]])
    # Non-buffer values (ints, None) mixed into the buffer list must be accepted by the collector.
    out = add_to_cache([FakeBuffer(32, dtypes.float32), 32, None])
    out = add_to_cache([FakeBuffer(32, dtypes.float32), out, 58, None])
    out = add_to_cache([FakeBuffer(32, dtypes.float32), out])
    out = add_to_cache([FakeBuffer(32, dtypes.float32), out])
    cache = CacheCollector.finish()
    assert cache[0][1][1] == inps[0], "Input should be on its place."
    assert cache[1][1][2] == inps[1], "Input should be on its place."
    assert get_bufs_count(cache) == 7
if __name__ == "__main__":
  # Allow running this file directly, in addition to discovery by a test runner.
  unittest.main()