diff --git a/test/backend/test_multitensor.py b/test/backend/test_multitensor.py index ebb1bd9660..7c9c423900 100644 --- a/test/backend/test_multitensor.py +++ b/test/backend/test_multitensor.py @@ -228,17 +228,17 @@ class TestMultiTensor(unittest.TestCase): a,b = _test_allreduce(Tensor.rand(256, 256)) np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5) - def test_multiple_to_single_device_naive(self): - with Context(RING=0): - t = Tensor.arange(32).shard(devices_4, 0).to(Device.DEFAULT).realize() - self.assertEqual(t.device, Device.DEFAULT) - np.testing.assert_equal(t.numpy(), np.arange(32)) - - def test_multiple_to_single_device_ring(self): - with Context(RING=2): - t = Tensor.arange(32).shard(devices_4, 0).to(Device.DEFAULT).realize() - self.assertEqual(t.device, Device.DEFAULT) - np.testing.assert_equal(t.numpy(), np.arange(32)) + def test_multiple_to_single_device(self): + kernel_counts = {} + for ring in (0, 2): + GlobalCounters.reset() + with Context(RING=ring, SCACHE=0): + t = Tensor.arange(32).contiguous().shard(devices_4, 0).to(Device.DEFAULT) + t.realize() + kernel_counts[ring] = GlobalCounters.kernel_count + self.assertEqual(t.device, Device.DEFAULT) + np.testing.assert_equal(t.numpy(), np.arange(32)) + self.assertNotEqual(kernel_counts[0], kernel_counts[2]) def test_allreduce_all2all(self): with Context(ALL2ALL=2):