mirror of https://github.com/tinygrad/tinygrad.git
add tests for multi ram usage [pr] (#10376)
@@ -1135,7 +1135,7 @@ def helper_test_shard_op(shps, fxn, atol=1e-6, rtol=1e-3):
     except Exception as e:
       raise Exception(f"Failed shape {single_out.shape}: {e}")
 
-@unittest.skipIf(not_support_multi_device, "no multi")
+@unittest.skipIf(not_support_multi_device(), "no multi")
 class TestTensorOps(unittest.TestCase):
   def test_interpolate(self):
     helper_test_shard_op([(4,16,16),(4,24,24)], lambda x: Tensor.interpolate(x, (19,19)))
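The one-character change above matters because unittest.skipIf evaluates the truthiness of its first argument at decoration time: passing the function object not_support_multi_device (always truthy) skips the class unconditionally, while calling it skips only when multi-device support is really missing. A minimal self-contained sketch of the difference, using a hypothetical stub in place of tinygrad's real helper:

import unittest

def not_support_multi_device():  # hypothetical stub; tinygrad's real helper checks the backend
  return False                   # pretend multiple devices are available

@unittest.skipIf(not_support_multi_device, "no multi")    # bug: the function object is truthy, so always skipped
class AlwaysSkipped(unittest.TestCase):
  def test_never_runs(self): self.fail("unreachable: the whole class is skipped")

@unittest.skipIf(not_support_multi_device(), "no multi")  # fixed: skipped only when the call returns True
class RunsWhenSupported(unittest.TestCase):
  def test_runs(self): self.assertTrue(True)

if __name__ == "__main__":
  unittest.main(verbosity=2)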
@@ -1143,5 +1143,40 @@ class TestTensorOps(unittest.TestCase):
   def test_bitcast(self):
     helper_test_shard_op([(256,), (256,)], lambda x: x.bitcast(dtypes.int))
 
+# TODO: make these tests pass with VIZ=1
+@unittest.skipIf(not_support_multi_device(), "no multi")
+class TestMultiRamUsage(unittest.TestCase):
+  def setUp(self):
+    self.baseline = GlobalCounters.mem_used
+    self.N = 100
+  def assertUsed(self, amt, strict=True):
+    used = GlobalCounters.mem_used - self.baseline
+    print(f"used {used} bytes")
+    if strict: self.assertEqual(used, amt)
+    else: self.assertLessEqual(used, amt)
+
+  def test_zeros(self):
+    _ = Tensor.zeros(self.N, self.N).contiguous().realize()
+    self.assertUsed(self.N*self.N*4)
+
+  def test_zeros_del(self):
+    _ = Tensor.zeros(self.N, self.N).contiguous().realize()
+    del _
+    self.assertUsed(0)
+
+  def test_zeros_copy(self):
+    _ = Tensor.zeros(self.N, self.N).contiguous().to(devices_2).realize()
+    # NOTE: the first one on the DEFAULT device should be freed
+    self.assertUsed(self.N*self.N*4*2)
+
+  @unittest.skip("TODO: this is broken")
+  def test_zeros_shard(self):
+    _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).realize()
+    self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage
+
+  def test_zeros_contiguous_shard(self):
+    _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize()
+    self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage
+
 if __name__ == '__main__':
   unittest.main()
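The new TestMultiRamUsage tests all share one accounting pattern: snapshot GlobalCounters.mem_used in setUp, then assert on the delta after realizing tensors. A 100x100 float32 buffer should account for exactly 100*100*4 = 40000 bytes per device copy, and dropping the last reference should return the counter to baseline. A minimal standalone sketch of that pattern on the default device, assuming GlobalCounters.mem_used tracks allocated buffer bytes as the diff does:

from tinygrad import Tensor
from tinygrad.helpers import GlobalCounters

N = 100
baseline = GlobalCounters.mem_used             # snapshot before allocating anything

t = Tensor.zeros(N, N).contiguous().realize()  # .contiguous() forces a real N*N float32 buffer
used = GlobalCounters.mem_used - baseline
print(f"used {used} bytes, expected {N*N*4}")  # 4 bytes per float32 element

del t                                          # dropping the last reference should free the buffer
print(f"over baseline after del: {GlobalCounters.mem_used - baseline} bytes")

The strict flag on assertUsed lets a test either demand an exact byte count or only cap it, presumably to tolerate backends that over-allocate.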