mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
failing test for allreduce memory usage (#15106)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import gc, unittest
|
||||
from tinygrad import Tensor, GlobalCounters, dtypes
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
|
||||
class TestMultiRamUsage(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -107,6 +108,38 @@ class TestMultiRamUsage(unittest.TestCase):
|
||||
def test_matmul_half(self):
  # Two-device variant; delegates to the shared half-precision matmul helper.
  self._test_matmul_half(dev_count=2)
|
||||
def test_matmul_half_alt(self):
  # Four-device variant; delegates to the shared half-precision matmul helper.
  self._test_matmul_half(dev_count=4)
|
||||
|
||||
@unittest.expectedFailure
def test_multi_layer_allreduce(self):
  """JIT memory usage should stay flat as sharded matmul layers are stacked.

  Builds the same two-device sharded matmul chain at two depths (2 and 4
  layers) and asserts both leave GlobalCounters.mem_used identical. Marked
  expectedFailure: memory currently grows with layer count (see commit
  title, issue #15106).
  """
  size = 32
  shard_devs = ("NULL:1", "NULL:2")

  def fresh_tensors():
    # Replicated activation, column-sharded first weight, row-sharded
    # second weight — the classic tensor-parallel layout whose second
    # matmul forces an allreduce.
    act = Tensor.zeros(size, size).contiguous().shard(shard_devs, axis=None).realize()
    wa = Tensor.zeros(size, size).contiguous().shard(shard_devs, axis=1).realize()
    wb = Tensor.zeros(size, size).contiguous().shard(shard_devs, axis=0).realize()
    return act, wa, wb

  def measure(depth):
    # Run a depth-layer chain three times (enough to get the JIT captured)
    # and report the memory still held once inputs/outputs are released.
    GlobalCounters.reset()

    @TinyJit
    def net(t, wa, wb):
      for _ in range(depth):
        t = t @ wa @ wb
      return t.contiguous()

    for _ in range(3):
      tensors = fresh_tensors()
      result = net(*tensors)
      del tensors, result
    gc.collect()
    return GlobalCounters.mem_used

  mem_2 = measure(2)
  mem_4 = measure(4)
  self.assertEqual(mem_2, mem_4,
                   f"graph memory should not grow with layers: 2 layers={mem_2}, 4 layers={mem_4}")
|
||||
|
||||
class TestMultiAxis(unittest.TestCase):
|
||||
def test_reshape_shard_invalid(self):
|
||||
devices = ("NULL:0", "NULL:1")
|
||||
|
||||
Reference in New Issue
Block a user