diff --git a/test/null/test_multitensor.py b/test/null/test_multitensor.py index 117f7c3937..6161f5394b 100644 --- a/test/null/test_multitensor.py +++ b/test/null/test_multitensor.py @@ -66,16 +66,17 @@ class TestMultiRamUsage(unittest.TestCase): self.assertUsed(256 * 4) # TODO: can be zero def test_zeros_per_device(self): - _ = Tensor.zeros(self.N, self.N).contiguous().realize() + _ = Tensor.zeros(self.N, self.N, device="NULL").contiguous().realize() self.assertDeviceUsed({"NULL": self.N*self.N*4}) def test_zeros_del_per_device(self): - _ = Tensor.zeros(self.N, self.N).contiguous().realize() + _ = Tensor.zeros(self.N, self.N, device="NULL").contiguous().realize() del _ self.assertDeviceUsed({"NULL": 0}) def test_zeros_copy_per_device(self): - _ = Tensor.zeros(self.N, self.N).contiguous().to(("NULL:1", "NULL:2")).realize() + devices_2 = ("NULL:1", "NULL:2") + _ = Tensor.zeros(self.N, self.N).contiguous().to(devices_2).realize() self.assertDeviceUsed({"NULL:1": self.N*self.N*4, "NULL:2": self.N*self.N*4}) def test_zeros_shard_per_device(self): @@ -85,7 +86,7 @@ class TestMultiRamUsage(unittest.TestCase): def test_sharded_memory_replicated_per_device(self): devices_4 = tuple(f"NULL:{i+1}" for i in range(4)) - X = Tensor.ones(256).contiguous().realize() + X = Tensor.ones(256, device="NULL").contiguous().realize() self.assertDeviceUsed({"NULL": 256*4}) X.shard_(devices_4).realize() for d in devices_4: