tinygrad/test/test_allocators.py

#!/usr/bin/env python
import unittest
import pytest
import numpy as np
from weakref import ref
from tinygrad.helpers import GlobalCounters
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes, prod
from tinygrad import Device
from tinygrad.tensor import Tensor
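
# These tests exercise the LRU buffer cache in LRUAllocator. A fake allocator and fake
# device buffers stand in for real device memory, so cache hits, evictions, and frees
# can be asserted directly; only the GPU/Metal-specific tests below touch a real backend.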
def check_gc():
  if Device.DEFAULT == "GPU":
    from extra.introspection import print_objects
    assert print_objects() == 0
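
# FakeDeviceBuffer mimics a single device allocation: id starts at 1 and _do_free()
# drops it to 0, so __del__ can assert every buffer was freed exactly once.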
class FakeDeviceBuffer:
  def __init__(self, sz, dt, device):
    self.id = 1
    self.size = sz
    self.dtype = dt
    self.device = device
  def __del__(self):
    assert self.id == 0, "Should have called _do_free() first"
class FakeAllocator(LRUAllocator):
  def _do_alloc(self, size, dtype, device, **kwargs):
    if size*dtype.itemsize > self._get_cur_free_space(device): raise Exception("OOM")
    return FakeDeviceBuffer(size, dtype, device)
  def _do_free(self, buf):
    buf.id -= 1
    assert buf.id == 0, f"Free should be called exactly once, but buf.id is {buf.id}"
  def __del__(self): # Fake allocator should clear all buffers after each test.
    for v in self.cached_buffers.values():
      for buf, _ in v: self._free_buffer(buf)
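
# FakeBuffer is a RawBuffer whose backing "device memory" comes from whichever
# allocator is currently installed in FAKE_GLOBAL_ALLOCATOR (see alloc() below).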
FAKE_GLOBAL_ALLOCATOR = None

class FakeBuffer(RawBuffer):
  def __init__(self, size, dtype, device='0'):
    global FAKE_GLOBAL_ALLOCATOR
    super().__init__(size, dtype, allocator=FAKE_GLOBAL_ALLOCATOR, **{'device': device})
    assert self._buf.size == size and self._buf.dtype == dtype and self._buf.device == device, "This allocator requires 100% match of dtype and size."
  @classmethod
  def fromCPU(cls, x:np.ndarray, **kwargs): return cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
  def toCPU(self): return np.empty(self.size, dtype=self.dtype.np)
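
# Helpers: alloc() installs the given allocator for the duration of one FakeBuffer
# allocation, alloc_free_trace() allocates and immediately drops the buffer (returning
# a weakref to its device buffer), and cmp_trace_and_buf() checks whether a new buffer
# reuses that same device buffer from the cache.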
def alloc(allocator, size, dtype, **kwargs):
  global FAKE_GLOBAL_ALLOCATOR
  FAKE_GLOBAL_ALLOCATOR = allocator
  buf = FakeBuffer(size, dtype, **kwargs)
  assert buf.dtype == dtype and buf.size == size
  FAKE_GLOBAL_ALLOCATOR = None
  return buf

def alloc_free_trace(allocator, size, dtype, **kwargs):
  buf = alloc(allocator, size, dtype, **kwargs)
  return ref(buf._buf)

def cmp_trace_and_buf(buf, trace_ref): return trace_ref and trace_ref() == buf._buf

class TestAllocators(unittest.TestCase):
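  # A buffer released by the caller goes back to the cache and is handed out again for a
  # matching size/dtype request, while a buffer that is still referenced is never reused.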
  def test_lru_allocator_reusage(self):
    mc, mu = GlobalCounters.mem_cached, GlobalCounters.mem_used
    def test():
      lru_allocator = FakeAllocator(2048)
      traced_buf = alloc_free_trace(lru_allocator, 16, dtypes.float32)
      assert GlobalCounters.mem_cached - mc == 16*dtypes.float32.itemsize, "Buffer should be cached"
      for _ in range(32):
        def __test():
          buf = alloc(lru_allocator, 16, dtypes.float32)
          assert cmp_trace_and_buf(buf, traced_buf), "Buffer should be reused"
        __test()
      usedbuf = alloc(lru_allocator, 16, dtypes.float32)
      for _ in range(32):
        def __test():
          buf = alloc(lru_allocator, 16, dtypes.float32)
          assert usedbuf != buf, "Nobody should get the used buffer"
        __test()
      assert GlobalCounters.mem_used - mu == 16*dtypes.float32.itemsize, "Only usedbuf is still allocated."
    test()
    check_gc()
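  # The cache is bounded by the allocator's capacity (128 bytes here): older cached
  # buffers are evicted and actually freed, so their weakrefs die.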
  def test_lru_allocator_cache_free(self):
    mc, mu = GlobalCounters.mem_cached, GlobalCounters.mem_used
    def test():
      lru_allocator = FakeAllocator(128)
      refs = []
      for _ in range(32):
        refs.append(alloc_free_trace(lru_allocator, 16, dtypes.float32))
      for sz in range(1, 32):
        alloc_free_trace(lru_allocator, sz, dtypes.float32)
      assert GlobalCounters.mem_used + GlobalCounters.mem_cached - mc - mu <= 128, "Should not allocate on device more than allowed (128)"
      for r in refs: assert r() is None, "All refs should be dead, since buffers were cleared from cache"
    test()
    check_gc()
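  # Caching is per device: each round-robin request is served from that device's own
  # cached buffer, and every per-device buffer stays cached throughout.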
  def test_lru_allocator_multidevice(self):
    def test():
      lru_allocator = FakeAllocator(256)
      refs = []
      for i in range(8):
        refs.append(alloc_free_trace(lru_allocator, 16, dtypes.float32, device=str(i)))
      for i in range(64):
        def __test():
          dev = str(i % 8)
          buf = alloc(lru_allocator, 16, dtypes.float32, device=dev)
          assert cmp_trace_and_buf(buf, refs[i%8]), "Buffer should be reused"
        __test()
      for r in refs: assert r() is not None, "All refs should be cached"
    test()
    check_gc()
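  # If the device allocation fails, the allocator frees everything it had cached for
  # that device before giving up.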
  def test_lru_allocator_failing_alloc_cleans_cache(self):
    def test():
      lru_allocator = FakeAllocator(128)
      for size in range(1, 4):
        alloc_free_trace(lru_allocator, size, dtypes.float32, device='0')
      assert len(lru_allocator.aging_order['0']) == 3, "All buffers should be cached"
      assert lru_allocator.free_space['0'] == 128 - 24, "24 bytes should be used by the current cached buffers"
      def always_raise_exception(*args, **kwargs):
        raise MemoryError("OOM")
      lru_allocator._do_alloc = always_raise_exception
      with pytest.raises(Exception):
        alloc(lru_allocator, 5, dtypes.float32, device='0')
      assert len(lru_allocator.aging_order['0']) == 0, "All buffers should be freed from cache due to failing alloc"
    test()
    check_gc()
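  # If only the first allocation attempt fails, some cached buffers are released and the
  # retried allocation succeeds.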
  def test_lru_allocator_fail_first_alloc_pass_after_clear_cache(self):
    def test():
      lru_allocator = FakeAllocator(128)
      for size in range(1, 4):
        alloc_free_trace(lru_allocator, size, dtypes.float32, device='0')
      cache_length = 3
      assert len(lru_allocator.aging_order['0']) == cache_length, "All buffers should be cached"
      assert lru_allocator.free_space['0'] == 128 - 24, "24 bytes should be used by the current cached buffers"
      original_do_alloc = lru_allocator._do_alloc  # save the original method
      def single_fail_then_pass(*args, **kwargs):
        lru_allocator._do_alloc = original_do_alloc  # restore the original method
        raise MemoryError("OOM")
      lru_allocator._do_alloc = single_fail_then_pass
      alloc(lru_allocator, 5, dtypes.float32, device='0')
      assert len(lru_allocator.aging_order['0']) < cache_length, "Some buffers should have been cleared because the first alloc failed"
    test()
    check_gc()
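  # End-to-end check against the real GPU allocator: after a throwaway buffer consumes
  # half of the free space, a fresh tensor can still be realized and copied back to CPU.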
@unittest.skip("failing in CI")
def test_gpu_copyout(self):
def test():
from tinygrad.runtime.ops_gpu import CL
# Allocation to init the allocator.
tx = Tensor.rand(1)
tx.realize()
free_space = CL.cl_allocator.free_space[tx.lazydata.realized._device]
# Spawning 128mb objects to fill half of free_space
will_allocate = free_space // 3
trash_allocation_size = free_space // 2
def sp():
trash_buffer = Tensor.rand(trash_allocation_size // 4)
trash_buffer.realize()
sp()
xx = Tensor.rand(will_allocate // 4)
_ = xx.numpy()
test()
check_gc()
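  # Requesting far more memory than is available should fail with the allocator's
  # out-of-memory assertion and a readable size report.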
  def test_lru_allocator_massive_buffer(self):
    with self.assertRaises(AssertionError) as context: alloc(allocator := FakeAllocator(), size := 1e13, dtypes.int8)
    self.assertEqual(str(context.exception), f"out of memory - requested: {size/1e9:5.2f} GB, available: {allocator._get_cur_free_space('0')/1e9:5.2f} GB")
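  # Metal buffers longer than the device's maxBufferLength are rejected with a
  # descriptive assertion.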
  @unittest.skipIf(Device.DEFAULT != "METAL", "only applies to Metal")
  def test_lru_allocator_metal_max_buffer_length(self):
    from tinygrad.runtime.ops_metal import METAL
    with self.assertRaises(AssertionError) as context: METAL.allocator._do_alloc(buf_len := (max_buf_len := METAL.device.maxBufferLength()+1), dtypes.int8, '0')
    self.assertEqual(str(context.exception), f"Buffer length of {buf_len/1e9:5.2f} GB exceeds Metal's max buffer length of {max_buf_len/1e9:5.2f} GB.")
if __name__ == "__main__":
  unittest.main()