update test_multiple_to_single_device (#15056)

follow up to #14482, add SCACHE=0 to the test
This commit is contained in:
chenyu
2026-02-27 21:44:33 -05:00
committed by GitHub
parent 5fd06f4f02
commit 151608aa90

View File

@@ -228,17 +228,17 @@ class TestMultiTensor(unittest.TestCase):
a,b = _test_allreduce(Tensor.rand(256, 256))
np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
def test_multiple_to_single_device_naive(self):
with Context(RING=0):
t = Tensor.arange(32).shard(devices_4, 0).to(Device.DEFAULT).realize()
self.assertEqual(t.device, Device.DEFAULT)
np.testing.assert_equal(t.numpy(), np.arange(32))
def test_multiple_to_single_device_ring(self):
with Context(RING=2):
t = Tensor.arange(32).shard(devices_4, 0).to(Device.DEFAULT).realize()
self.assertEqual(t.device, Device.DEFAULT)
np.testing.assert_equal(t.numpy(), np.arange(32))
def test_multiple_to_single_device(self):
kernel_counts = {}
for ring in (0, 2):
GlobalCounters.reset()
with Context(RING=ring, SCACHE=0):
t = Tensor.arange(32).contiguous().shard(devices_4, 0).to(Device.DEFAULT)
t.realize()
kernel_counts[ring] = GlobalCounters.kernel_count
self.assertEqual(t.device, Device.DEFAULT)
np.testing.assert_equal(t.numpy(), np.arange(32))
self.assertNotEqual(kernel_counts[0], kernel_counts[2])
def test_allreduce_all2all(self):
with Context(ALL2ALL=2):