diff --git a/test/test_multitensor.py b/test/test_multitensor.py index a059899835..7bc21753fb 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -1135,7 +1135,7 @@ def helper_test_shard_op(shps, fxn, atol=1e-6, rtol=1e-3): except Exception as e: raise Exception(f"Failed shape {single_out.shape}: {e}") -@unittest.skipIf(not_support_multi_device, "no multi") +@unittest.skipIf(not_support_multi_device(), "no multi") class TestTensorOps(unittest.TestCase): def test_interpolate(self): helper_test_shard_op([(4,16,16),(4,24,24)], lambda x: Tensor.interpolate(x, (19,19))) @@ -1143,5 +1143,40 @@ class TestTensorOps(unittest.TestCase): def test_bitcast(self): helper_test_shard_op([(256,), (256,)], lambda x: x.bitcast(dtypes.int)) +# TODO: make these tests pass with VIZ=1 +@unittest.skipIf(not_support_multi_device(), "no multi") +class TestMultiRamUsage(unittest.TestCase): + def setUp(self): + self.baseline = GlobalCounters.mem_used + self.N = 100 + def assertUsed(self, amt, strict=True): + used = GlobalCounters.mem_used - self.baseline + print(f"used {used} bytes") + if strict: self.assertEqual(used, amt) + else: self.assertLessEqual(used, amt) + + def test_zeros(self): + _ = Tensor.zeros(self.N, self.N).contiguous().realize() + self.assertUsed(self.N*self.N*4) + + def test_zeros_del(self): + _ = Tensor.zeros(self.N, self.N).contiguous().realize() + del _ + self.assertUsed(0) + + def test_zeros_copy(self): + _ = Tensor.zeros(self.N, self.N).contiguous().to(devices_2).realize() + # NOTE: the first one on the DEFAULT device should be freed + self.assertUsed(self.N*self.N*4*2) + + @unittest.skip("TODO: this is broken") + def test_zeros_shard(self): + _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).realize() + self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage + + def test_zeros_contiguous_shard(self): + _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize() + self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/device.py b/tinygrad/device.py index 281983ab0d..016eb0d6a3 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -130,6 +130,7 @@ class Buffer: def ensure_allocated(self) -> Buffer: return self.allocate() if not self.is_allocated() else self def allocate(self, opaque=None, external_ptr=None) -> Buffer: assert not self.is_allocated(), "can't allocate already allocated buffer" + if DEBUG >= 7: print(f"buffer: allocate {self.nbytes} bytes on {self.device}") self.allocator:Allocator = Device[self.device].allocator if external_ptr is not None: self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferSpec(external_ptr=external_ptr) @@ -144,6 +145,7 @@ class Buffer: return self def deallocate(self): assert self.is_allocated(), "buffer must be allocated to deallocate" + if DEBUG >= 7: print(f"buffer: deallocate {self.nbytes} bytes on {self.device}") if self._base is None and (self.options is None or self.options.external_ptr is None): if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes self.allocator.free(self._buf, self.nbytes, self.options)