diff --git a/test/null/test_multitensor.py b/test/null/test_multitensor.py
index 6161f5394b..f2f746dae3 100644
--- a/test/null/test_multitensor.py
+++ b/test/null/test_multitensor.py
@@ -1,5 +1,6 @@
 import gc, unittest
 from tinygrad import Tensor, GlobalCounters, dtypes
+from tinygrad.engine.jit import TinyJit
 
 class TestMultiRamUsage(unittest.TestCase):
   def setUp(self):
@@ -107,6 +108,38 @@ class TestMultiRamUsage(unittest.TestCase):
   def test_matmul_half(self): self._test_matmul_half(dev_count=2)
   def test_matmul_half_alt(self): self._test_matmul_half(dev_count=4)
 
+  @unittest.expectedFailure
+  def test_multi_layer_allreduce(self):
+    N = 32
+    devices_2 = ("NULL:1", "NULL:2")
+
+    def make_inp():
+      x = Tensor.zeros(N, N).contiguous().shard(devices_2, axis=None).realize()
+      w1 = Tensor.zeros(N, N).contiguous().shard(devices_2, axis=1).realize()
+      w2 = Tensor.zeros(N, N).contiguous().shard(devices_2, axis=0).realize()
+      return x, w1, w2
+
+    def run_layers(n_layers):
+      GlobalCounters.reset()
+
+      @TinyJit
+      def f(x, w1, w2):
+        for _ in range(n_layers):
+          x = (x @ w1 @ w2)
+        return x.contiguous()
+
+      for _ in range(3):
+        a = make_inp()
+        r = f(*a)
+        del a, r
+
+      gc.collect()
+      return GlobalCounters.mem_used
+
+    mem_2 = run_layers(2)
+    mem_4 = run_layers(4)
+    self.assertEqual(mem_2, mem_4, f"graph memory should not grow with layers: 2 layers={mem_2}, 4 layers={mem_4}")
+
 class TestMultiAxis(unittest.TestCase):
   def test_reshape_shard_invalid(self):
     devices = ("NULL:0", "NULL:1")