add tests for multi ram usage [pr] (#10376)
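This commit does three things: it adds a TestMultiRamUsage suite that asserts exact GlobalCounters.mem_used deltas after realize, del, cross-device copy, and shard; it fixes a skipIf decorator that passed a function object instead of calling it; and it adds DEBUG >= 7 tracing to Buffer.allocate and Buffer.deallocate.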
@@ -1135,7 +1135,7 @@ def helper_test_shard_op(shps, fxn, atol=1e-6, rtol=1e-3):
   except Exception as e:
     raise Exception(f"Failed shape {single_out.shape}: {e}")
 
-@unittest.skipIf(not_support_multi_device, "no multi")
+@unittest.skipIf(not_support_multi_device(), "no multi")
 class TestTensorOps(unittest.TestCase):
   def test_interpolate(self):
     helper_test_shard_op([(4,16,16),(4,24,24)], lambda x: Tensor.interpolate(x, (19,19)))
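The one-line change above is a real fix, not a cleanup: unittest.skipIf tests the truthiness of its first argument, and a function object is always truthy, so passing not_support_multi_device uncalled unconditionally skipped every test in TestTensorOps. Calling it, not_support_multi_device(), skips only when multi-device support is genuinely absent. A minimal self-contained sketch of the pitfall (always_false is a hypothetical stand-in predicate):

import unittest

def always_false() -> bool:
  return False  # stand-in: "multi-device IS supported", so nothing should be skipped

class TestSkipIfPitfall(unittest.TestCase):
  @unittest.skipIf(always_false, "oops")  # BUG: the function object is truthy, always skips
  def test_silently_skipped(self): self.fail("never runs, so the suite stays green")

  @unittest.skipIf(always_false(), "ok")  # FIX: call the predicate; False means run the test
  def test_actually_runs(self): self.assertTrue(True)

if __name__ == "__main__":
  unittest.main()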
@@ -1143,5 +1143,40 @@ class TestTensorOps(unittest.TestCase):
   def test_bitcast(self):
     helper_test_shard_op([(256,), (256,)], lambda x: x.bitcast(dtypes.int))
 
+# TODO: make these tests pass with VIZ=1
+@unittest.skipIf(not_support_multi_device(), "no multi")
+class TestMultiRamUsage(unittest.TestCase):
+  def setUp(self):
+    self.baseline = GlobalCounters.mem_used
+    self.N = 100
+  def assertUsed(self, amt, strict=True):
+    used = GlobalCounters.mem_used - self.baseline
+    print(f"used {used} bytes")
+    if strict: self.assertEqual(used, amt)
+    else: self.assertLessEqual(used, amt)
+
+  def test_zeros(self):
+    _ = Tensor.zeros(self.N, self.N).contiguous().realize()
+    self.assertUsed(self.N*self.N*4)
+
+  def test_zeros_del(self):
+    _ = Tensor.zeros(self.N, self.N).contiguous().realize()
+    del _
+    self.assertUsed(0)
+
+  def test_zeros_copy(self):
+    _ = Tensor.zeros(self.N, self.N).contiguous().to(devices_2).realize()
+    # NOTE: the first one on the DEFAULT device should be freed
+    self.assertUsed(self.N*self.N*4*2)
+
+  @unittest.skip("TODO: this is broken")
+  def test_zeros_shard(self):
+    _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).realize()
+    self.assertUsed(self.N*self.N*4)  # sharding should not increase total ram usage
+
+  def test_zeros_contiguous_shard(self):
+    _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize()
+    self.assertUsed(self.N*self.N*4)  # sharding should not increase total ram usage
+
 if __name__ == '__main__':
   unittest.main()
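The expected byte counts follow from the dtype: Tensor.zeros defaults to float32, so a 100x100 tensor is 100*100*4 = 40000 bytes, a copy resident on two devices should cost exactly twice that, and a shard across two devices should cost no more than the single-device total. The hunk does not show how devices_2 is defined; the sketch below assumes it is two virtualized instances of the default device, the usual pattern in tinygrad's multi-device tests, and reproduces the same measurement by hand:

from tinygrad import Tensor, Device
from tinygrad.helpers import GlobalCounters

devices_2 = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))  # assumption, not from the diff

baseline = GlobalCounters.mem_used
t = Tensor.zeros(100, 100).contiguous().realize()  # float32: 4 bytes per element
print(GlobalCounters.mem_used - baseline)          # expect 40000
del t                                              # CPython refcounting frees the buffer now
print(GlobalCounters.mem_used - baseline)          # expect 0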
@@ -130,6 +130,7 @@ class Buffer:
   def ensure_allocated(self) -> Buffer: return self.allocate() if not self.is_allocated() else self
   def allocate(self, opaque=None, external_ptr=None) -> Buffer:
     assert not self.is_allocated(), "can't allocate already allocated buffer"
+    if DEBUG >= 7: print(f"buffer: allocate {self.nbytes} bytes on {self.device}")
     self.allocator:Allocator = Device[self.device].allocator
     if external_ptr is not None:
       self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferSpec(external_ptr=external_ptr)
@@ -144,6 +145,7 @@ class Buffer:
     return self
   def deallocate(self):
     assert self.is_allocated(), "buffer must be allocated to deallocate"
+    if DEBUG >= 7: print(f"buffer: deallocate {self.nbytes} bytes on {self.device}")
     if self._base is None and (self.options is None or self.options.external_ptr is None):
       if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
       self.allocator.free(self._buf, self.nbytes, self.options)
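These two added prints are what make mem_used regressions like the ones tested above diagnosable: at high verbosity every buffer reports its size and device as it comes and goes. Note the symmetry deallocate already enforces: DISK buffers and buffers wrapping an external_ptr were never counted into GlobalCounters.mem_used at allocation, so they are excluded from the decrement as well. A hedged usage sketch (Context and the DEBUG ContextVar live in tinygrad.helpers; setting the DEBUG=7 environment variable when running the tests should be equivalent):

from tinygrad import Tensor
from tinygrad.helpers import Context

# With DEBUG >= 7, Buffer.allocate/deallocate print one line per buffer.
with Context(DEBUG=7):
  Tensor.zeros(100, 100).contiguous().realize()  # expect: "buffer: allocate 40000 bytes on <device>"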