diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index e7c8aa401c..95d7825f8e 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -1273,19 +1273,20 @@ class TestMultiRamUsage(unittest.TestCase):
     _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize()
     self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage
 
-  def _test_matmul_half(self, devs):
+  def _test_matmul_half(self, dev_count:int):
     N = 32
     total_mem = {}
+    devs = tuple(f"NULL:{i}" for i in range(dev_count))
     for dtype in {dtypes.float, dtypes.half}:
       GlobalCounters.reset()
-      a = Tensor.empty((N, N), dtype=dtype).shard(devs, axis=0)
-      b = Tensor.empty((N, N), dtype=dtype).shard(devs, axis=None)
+      a = Tensor.empty((N, N), dtype=dtype, device=devs[0]).shard(devs, axis=0)
+      b = Tensor.empty((N, N), dtype=dtype, device=devs[0]).shard(devs, axis=None)
       (a @ b).realize()
       total_mem[dtype] = GlobalCounters.global_mem
     self.assertEqual(total_mem[dtypes.half], total_mem[dtypes.float] // 2)
 
-  def test_matmul_half(self): self._test_matmul_half(devices_2)
-  def test_matmul_half_alt(self): self._test_matmul_half(devices_4)
+  def test_matmul_half(self): self._test_matmul_half(dev_count=2)
+  def test_matmul_half_alt(self): self._test_matmul_half(dev_count=4)
 
 @unittest.skipIf(not_support_multi_device(), "need multi")
 class TestMultiFromUnrenderable(unittest.TestCase):