* all2all

* um

* fix

* x

* um

* simler

* mypy

* fix

* t

* cmnts
This commit is contained in:
nimlgen
2025-12-31 16:38:32 +03:00
committed by GitHub
parent f7ee644950
commit 25440f0f72
6 changed files with 66 additions and 58 deletions

View File

@@ -256,6 +256,11 @@ class TestMultiTensor(unittest.TestCase):
a,b = _test_allreduce(Tensor.rand(256, 256))
np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
def test_allreduce_all2all(self):
with Context(ALL2ALL=2):
a,b = _test_allreduce(Tensor.rand(256, 256))
np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
def test_copy_jit(self):
@TinyJit
def copy_tensor(x:Tensor): return (x.to(f"{x.device.split(':')[0]}:1") + 1)