failing test for allreduce memory usage (#15106)

This commit is contained in:
wozeparrot
2026-03-04 15:38:38 +08:00
committed by GitHub
parent 5ecfe549e7
commit 759c7fc81c

View File

@@ -1,5 +1,6 @@
import gc, unittest
from tinygrad import Tensor, GlobalCounters, dtypes
from tinygrad.engine.jit import TinyJit
class TestMultiRamUsage(unittest.TestCase):
def setUp(self):
@@ -107,6 +108,38 @@ class TestMultiRamUsage(unittest.TestCase):
def test_matmul_half(self):
  # Delegate to the shared matmul-half helper, sharding across 2 devices.
  self._test_matmul_half(dev_count=2)
def test_matmul_half_alt(self):
  # Same matmul-half helper as above, but sharding across 4 devices.
  self._test_matmul_half(dev_count=4)
@unittest.expectedFailure
def test_multi_layer_allreduce(self):
  """Memory used by a jitted, sharded matmul chain should be independent of depth.

  Marked expectedFailure: currently the jitted graph's memory grows with the
  number of layers, so the 2-layer and 4-layer runs do not match.
  """
  N = 32
  shard_devs = ("NULL:1", "NULL:2")
  def fresh_tensors():
    # Replicated activation, column-sharded w1, row-sharded w2 — the w2 matmul
    # forces an allreduce across shard_devs.
    x = Tensor.zeros(N, N).contiguous().shard(shard_devs, axis=None).realize()
    w1 = Tensor.zeros(N, N).contiguous().shard(shard_devs, axis=1).realize()
    w2 = Tensor.zeros(N, N).contiguous().shard(shard_devs, axis=0).realize()
    return x, w1, w2
  def measure(n_layers):
    # Returns the memory still held after 3 jitted runs of an n_layers-deep chain.
    GlobalCounters.reset()
    @TinyJit
    def f(x, w1, w2):
      for _ in range(n_layers):
        x = (x @ w1 @ w2)
      return x.contiguous()
    for _ in range(3):
      args = fresh_tensors()
      out = f(*args)
      # Drop our references before collecting so only jit-held memory remains.
      del args, out
      gc.collect()
    return GlobalCounters.mem_used
  mem_2 = measure(2)
  mem_4 = measure(4)
  self.assertEqual(mem_2, mem_4, f"graph memory should not grow with layers: 2 layers={mem_2}, 4 layers={mem_4}")
class TestMultiAxis(unittest.TestCase):
def test_reshape_shard_invalid(self):
devices = ("NULL:0", "NULL:1")