add tests for multi ram usage [pr] (#10376)

This commit is contained in:
George Hotz
2025-05-17 15:33:40 -07:00
committed by GitHub
parent 5a18eab908
commit 6ec88d94df
2 changed files with 38 additions and 1 deletions

View File

@@ -1135,7 +1135,7 @@ def helper_test_shard_op(shps, fxn, atol=1e-6, rtol=1e-3):
except Exception as e:
raise Exception(f"Failed shape {single_out.shape}: {e}")
@unittest.skipIf(not_support_multi_device, "no multi")
@unittest.skipIf(not_support_multi_device(), "no multi")
class TestTensorOps(unittest.TestCase):
def test_interpolate(self):
helper_test_shard_op([(4,16,16),(4,24,24)], lambda x: Tensor.interpolate(x, (19,19)))
@@ -1143,5 +1143,40 @@ class TestTensorOps(unittest.TestCase):
def test_bitcast(self):
helper_test_shard_op([(256,), (256,)], lambda x: x.bitcast(dtypes.int))
# TODO: make these tests pass with VIZ=1
@unittest.skipIf(not_support_multi_device(), "no multi")
class TestMultiRamUsage(unittest.TestCase):
def setUp(self):
self.baseline = GlobalCounters.mem_used
self.N = 100
def assertUsed(self, amt, strict=True):
used = GlobalCounters.mem_used - self.baseline
print(f"used {used} bytes")
if strict: self.assertEqual(used, amt)
else: self.assertLessEqual(used, amt)
def test_zeros(self):
_ = Tensor.zeros(self.N, self.N).contiguous().realize()
self.assertUsed(self.N*self.N*4)
def test_zeros_del(self):
_ = Tensor.zeros(self.N, self.N).contiguous().realize()
del _
self.assertUsed(0)
def test_zeros_copy(self):
_ = Tensor.zeros(self.N, self.N).contiguous().to(devices_2).realize()
# NOTE: the first one on the DEFAULT device should be freed
self.assertUsed(self.N*self.N*4*2)
@unittest.skip("TODO: this is broken")
def test_zeros_shard(self):
_ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).realize()
self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage
def test_zeros_contiguous_shard(self):
_ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize()
self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage
if __name__ == '__main__':
unittest.main()

View File

@@ -130,6 +130,7 @@ class Buffer:
def ensure_allocated(self) -> Buffer: return self.allocate() if not self.is_allocated() else self
def allocate(self, opaque=None, external_ptr=None) -> Buffer:
assert not self.is_allocated(), "can't allocate already allocated buffer"
if DEBUG >= 7: print(f"buffer: allocate {self.nbytes} bytes on {self.device}")
self.allocator:Allocator = Device[self.device].allocator
if external_ptr is not None:
self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferSpec(external_ptr=external_ptr)
@@ -144,6 +145,7 @@ class Buffer:
return self
def deallocate(self):
assert self.is_allocated(), "buffer must be allocated to deallocate"
if DEBUG >= 7: print(f"buffer: deallocate {self.nbytes} bytes on {self.device}")
if self._base is None and (self.options is None or self.options.external_ptr is None):
if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
self.allocator.free(self._buf, self.nbytes, self.options)