add tests for multi ram usage [pr] (#10376)

This commit is contained in:
George Hotz
2025-05-17 15:33:40 -07:00
committed by GitHub
parent 5a18eab908
commit 6ec88d94df
2 changed files with 38 additions and 1 deletions

View File

@@ -1135,7 +1135,7 @@ def helper_test_shard_op(shps, fxn, atol=1e-6, rtol=1e-3):
except Exception as e:
raise Exception(f"Failed shape {single_out.shape}: {e}")
@unittest.skipIf(not_support_multi_device, "no multi")
@unittest.skipIf(not_support_multi_device(), "no multi")
class TestTensorOps(unittest.TestCase):
def test_interpolate(self):
  # push Tensor.interpolate with a fixed (19,19) target size through the shard-op helper
  shapes = [(4,16,16), (4,24,24)]
  helper_test_shard_op(shapes, lambda t: Tensor.interpolate(t, (19,19)))
@@ -1143,5 +1143,40 @@ class TestTensorOps(unittest.TestCase):
def test_bitcast(self):
  # push a bitcast to dtypes.int through the shard-op helper, using two (256,) shapes
  shapes = [(256,), (256,)]
  helper_test_shard_op(shapes, lambda t: t.bitcast(dtypes.int))
# TODO: make these tests pass with VIZ=1
@unittest.skipIf(not_support_multi_device(), "no multi")
class TestMultiRamUsage(unittest.TestCase):
  """Track how GlobalCounters.mem_used changes when an N x N zeros tensor is realized, deleted, copied or sharded."""

  def setUp(self):
    # side length of the square test tensor
    self.N = 100
    # snapshot the counter so every test only looks at the delta it caused itself
    self.baseline = GlobalCounters.mem_used

  def assertUsed(self, amt, strict=True):
    # bytes accounted for since setUp; must equal amt when strict, otherwise amt is only an upper bound
    used = GlobalCounters.mem_used - self.baseline
    print(f"used {used} bytes")
    compare = self.assertEqual if strict else self.assertLessEqual
    compare(used, amt)

  def test_zeros(self):
    # the expected sizes in these tests use 4 bytes per element (presumably the default float dtype)
    nbytes = self.N*self.N*4
    _ = Tensor.zeros(self.N, self.N).contiguous().realize()
    self.assertUsed(nbytes)

  def test_zeros_del(self):
    _ = Tensor.zeros(self.N, self.N).contiguous().realize()
    # once the only reference is dropped, usage is expected to be back at the baseline
    del _
    self.assertUsed(0)

  def test_zeros_copy(self):
    nbytes_both = self.N*self.N*4*2
    _ = Tensor.zeros(self.N, self.N).contiguous().to(devices_2).realize()
    # NOTE: the first one on the DEFAULT device should be freed
    self.assertUsed(nbytes_both)

  @unittest.skip("TODO: this is broken")
  def test_zeros_shard(self):
    nbytes = self.N*self.N*4
    _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).realize()
    # sharding should not increase total ram usage
    self.assertUsed(nbytes)

  def test_zeros_contiguous_shard(self):
    nbytes = self.N*self.N*4
    _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize()
    # sharding should not increase total ram usage
    self.assertUsed(nbytes)
if __name__ == '__main__':
  # run every test case in this module when the file is executed as a script
  unittest.main()