diff --git a/test/test_allocators.py b/test/test_allocators.py
index 74119a90f3..849ccbb96a 100644
--- a/test/test_allocators.py
+++ b/test/test_allocators.py
@@ -174,5 +174,19 @@ class TestAllocators(unittest.TestCase):
     test()
     check_gc()
 
+  def test_lru_allocator_massive_buffer(self):
+    # A 1e13-byte int8 request must be refused with the allocator's out-of-memory message.
+    with self.assertRaises(AssertionError) as context: alloc(allocator := FakeAllocator(), size := 1e13, dtypes.int8)
+    self.assertEqual(str(context.exception), f"out of memory - requested: {size/1e9:5.2f} GB, available: {allocator._get_cur_free_space('0')/1e9:5.2f} GB")
+
+  @unittest.skipIf(Device.DEFAULT != "METAL", "only applies to Metal")
+  def test_lru_allocator_metal_max_buffer_length(self):
+    # One byte past the device limit must be rejected; max_buf_len is the real limit, buf_len the oversized request.
+    from tinygrad.runtime.ops_metal import METAL
+    max_buf_len = METAL.device.maxBufferLength()
+    buf_len = max_buf_len + 1
+    with self.assertRaises(AssertionError) as context: METAL.allocator._do_alloc(buf_len, dtypes.int8, '0')
+    self.assertEqual(str(context.exception), f"Buffer length of {buf_len/1e9:5.2f} GB exceeds Metal's max buffer length of {max_buf_len/1e9:5.2f} GB.")
+
 if __name__ == "__main__":
   unittest.main()
diff --git a/tinygrad/runtime/lib.py b/tinygrad/runtime/lib.py
index 4fe63486c4..e4d236a916 100644
--- a/tinygrad/runtime/lib.py
+++ b/tinygrad/runtime/lib.py
@@ -80,6 +80,7 @@ class LRUAllocator:
     while len(self.aging_order[device]) and self._get_cur_free_space(device) < space_to_free: # When OOM removing lru buffers.
       bucket, epoch = self.aging_order[device].popleft()
       if self.cached_buffers[bucket] and self.cached_buffers[bucket][-1][1] == epoch: self._free_buffer(self.cached_buffers[bucket].pop()[0]) # Free cached buffer if it is still in cache.
+    assert (curr_free := self._get_cur_free_space(device)) >= space_to_free, f"out of memory - requested: {space_to_free/1e9:5.2f} GB, available: {curr_free/1e9:5.2f} GB" # >=: exactly enough free space is not OOM (matches the eviction loop above).
 
   def _alloc_buffer(self, size, dtype, device, **kwargs):
     self.ensure_has_free_space(size*dtype.itemsize, device)
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index 4e07079c54..c3771c86b1 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -10,7 +10,13 @@ from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
 from tinygrad.shape.symbolic import Variable, Node
 
 class MetalAllocator(LRUAllocator):
-  def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
+  def _do_alloc(self, size, dtype, device, **kwargs):
+    buf_len, max_buf_len = size*dtype.itemsize, METAL.device.maxBufferLength()
+    # Fail early with a readable message; a request of exactly max_buf_len is still within the limit, hence <=.
+    assert buf_len <= max_buf_len, f"Buffer length of {buf_len/1e9:5.2f} GB exceeds Metal's max buffer length of {max_buf_len/1e9:5.2f} GB."
+    buf = METAL.device.newBufferWithLength_options_(buf_len, Metal.MTLResourceStorageModeShared)
+    assert buf, f"Metal buffer allocation failed with {buf}."
+    return buf
   def _do_free(self, buf): buf.release()
   def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
 