mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
update test_multiple_to_single_device (#15056)
follow up to #14482, add SCACHE=0 to the test
This commit is contained in:
@@ -228,17 +228,17 @@ class TestMultiTensor(unittest.TestCase):
|
||||
a,b = _test_allreduce(Tensor.rand(256, 256))
|
||||
np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
|
||||
|
||||
def test_multiple_to_single_device_naive(self):
|
||||
with Context(RING=0):
|
||||
t = Tensor.arange(32).shard(devices_4, 0).to(Device.DEFAULT).realize()
|
||||
self.assertEqual(t.device, Device.DEFAULT)
|
||||
np.testing.assert_equal(t.numpy(), np.arange(32))
|
||||
|
||||
def test_multiple_to_single_device_ring(self):
|
||||
with Context(RING=2):
|
||||
t = Tensor.arange(32).shard(devices_4, 0).to(Device.DEFAULT).realize()
|
||||
self.assertEqual(t.device, Device.DEFAULT)
|
||||
np.testing.assert_equal(t.numpy(), np.arange(32))
|
||||
def test_multiple_to_single_device(self):
|
||||
kernel_counts = {}
|
||||
for ring in (0, 2):
|
||||
GlobalCounters.reset()
|
||||
with Context(RING=ring, SCACHE=0):
|
||||
t = Tensor.arange(32).contiguous().shard(devices_4, 0).to(Device.DEFAULT)
|
||||
t.realize()
|
||||
kernel_counts[ring] = GlobalCounters.kernel_count
|
||||
self.assertEqual(t.device, Device.DEFAULT)
|
||||
np.testing.assert_equal(t.numpy(), np.arange(32))
|
||||
self.assertNotEqual(kernel_counts[0], kernel_counts[2])
|
||||
|
||||
def test_allreduce_all2all(self):
|
||||
with Context(ALL2ALL=2):
|
||||
|
||||
Reference in New Issue
Block a user