delete uneven shard tests and mentions (#14867)

chenyu
2026-02-18 14:10:33 -05:00
committed by GitHub
parent 1c8c17a593
commit f84a11bb9f
2 changed files with 2 additions and 114 deletions


@@ -19,8 +19,8 @@ cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
 BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000)
 EVAL_BS = getenv("EVAL_BS", BS)
 GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
-assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
-assert EVAL_BS % len(GPUS) == 0, f"{EVAL_BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
+assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}"
+assert EVAL_BS % len(GPUS) == 0, f"{EVAL_BS=} is not a multiple of {len(GPUS)=}"
 class UnsyncedBatchNorm:
   def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1, num_devices=len(GPUS)):
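
The removed assertion messages ended with "uneven multi GPU is slow", a leftover from when uneven sharding was merely discouraged; now that it is unsupported, the new messages state the requirement plainly. A minimal arithmetic sketch of what the asserts enforce (BS=512 mirrors the script's default; 4 GPUs is an assumed example):

BS, NGPUS = 512, 4        # sketch only: BS is the script default, NGPUS=4 is assumed
assert BS % NGPUS == 0    # each device must receive an identical sub-batch
per_device = BS // NGPUS  # 128 samples per device, no remainder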


@@ -654,54 +654,6 @@ class TestMultiTensor(unittest.TestCase):
     assert isinstance(jf.jit_cache[4].prg, BufferCopy)
     assert isinstance(jf.jit_cache[5].prg, graph_d1)

-  @unittest.skip("no longer supports uneven shard")
-  def test_uneven_shard(self):
-    for N in range(1, 6):
-      X = Tensor.rand(4, 1, 257).contiguous().realize()
-      n = X.numpy()
-      devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
-      X.shard_(devices, 2)
-      np.testing.assert_equal(X.numpy(), n)
-      np.testing.assert_equal(X.reshape(2, 2, 257).numpy(), n.reshape((2, 2, 257)))
-      np.testing.assert_equal(X.shrink(((0,2), (0, 1), (0,257))).numpy(), n[0:2, 0:1, 0:257])
-      np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1)))
-      np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1)))
-
-  @unittest.skip("no longer supports uneven shard")
-  def test_uneven_multiple_zeros(self):
-    for data in ([1, 2, 3, 4], [1, 2, 3], [1, 2], [1], []):
-      for N in (1, 2, 3, 4):
-        devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
-        # make sure something is computed on each device
-        X = ((Tensor(data).shard(devices, axis=0) + 1).realize() - 1).realize()
-        np.testing.assert_equal(X.numpy(), data)
-
-  @unittest.skip("no longer supports uneven shard")
-  def test_uneven_shard_with_empty(self):
-    N = 4
-    X = Tensor.rand(16, 1, 3).contiguous().realize()
-    np_x = X.numpy()
-    devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
-    # test empty shard
-    np.testing.assert_equal(X.shard(devices, 0).numpy(), np_x)
-    # test reshape with empty shard
-    np.testing.assert_equal(X.shard(devices, 0).reshape(8, 1, 6).numpy(), np_x.reshape(8, 1, 6))
-
-  @unittest.skip("no longer supports uneven shard")
-  def test_multiple_uneven_shard(self):
-    N = 4
-    X = Tensor.rand(4, 1, 257).contiguous().realize()
-    Y = Tensor.rand(4, 1, 257).contiguous().realize()
-    np_x, np_y = X.numpy(), Y.numpy()
-    devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
-    X.shard_(devices, 2)
-    Y.shard_(devices, 2)
-    np.testing.assert_equal(X.numpy(), np_x)
-    np.testing.assert_equal(Y.numpy(), np_y)
-    np.testing.assert_equal((X + Y).numpy(), np_x + np_y)
-
   def test_bn_ast_on_devices(self):
     t = Tensor.empty((16, 64, 112, 112)).shard(devices_4, axis=0)
     bn = nn.BatchNorm2d(64)
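
For context, an "uneven shard" split an axis whose length is not a multiple of the device count, e.g. the length-257 axis in the deleted tests above. A numpy-only sketch of the split sizes involved (numpy is used purely to illustrate; it is not how tinygrad partitioned data):

import numpy as np
# a length-257 axis over 4 devices cannot split evenly
sizes = [len(p) for p in np.array_split(np.arange(257), 4)]
print(sizes)  # [65, 64, 64, 64]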
@@ -754,33 +706,6 @@ class TestMultiTensor(unittest.TestCase):
     with self.assertRaises((AssertionError, ValueError)):
       t0.reshape((26*15,7)).schedule()

-  @unittest.skip("no longer supports uneven shard")
-  def test_reshape_on_axis_uneven(self):
-    def reshape_helper(t0, t, t_axis):
-      assert t.uop.axis == t_axis
-      np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())
-    t0 = Tensor.rand((4, 42, 15)).shard(devices_3, axis=1, splits=[14, 7, 21])
-    # ok to reshape as long as elements remain on same device
-    reshape_helper(t0, t0.reshape(2, 2, 42, 3, 5), 2)
-    # split to the right
-    reshape_helper(t0, t0.reshape(2, 2, 6, 7, 15), 2)
-    # split off and merge to the right
-    reshape_helper(t0, t0.reshape(4, 6, 105), 1)
-    # really blend the axes together
-    reshape_helper(t0, t0.reshape(4, 30, 21), 1)
-    # split off 1-shape
-    reshape_helper(t0, t0.reshape(4, 1, 42, 15), 2)
-    reshape_helper(t0, t0.reshape(4, 6, 1, 7, 15), 1)
-    # assert if cannot maintain shard axis without moving items between devices
-    with self.assertRaises(AssertionError): t0.reshape(4, 7, 6, 15)
-    # assert for degenerate reshape
-    with self.assertRaises(AssertionError): t0.reshape(4, 5, 7, 15)
-    # assert for cannot maintain axis
-    with self.assertRaises(AssertionError): t0.reshape(4, 3, 2, 7, 15)
-
   # it doesn't work like this anymore
   # NOTE: this never failed in assign_multi, it failed tensor spec because MULTI was never pushed in the graph
   @unittest.skip("this test is broken")
@@ -849,16 +774,6 @@ class TestMultiTensor(unittest.TestCase):
     self.assertEqual(rab.device, devices_4)
     self.assertEqual(rab.uop.axis, 0)

-  @unittest.skip("no longer supports uneven shard")
-  def test_rand_like_uneven_shard(self):
-    t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1)
-    t2 = Tensor.rand_like(t)
-    self.assertEqual(t.shape, t2.shape)
-    self.assertEqual(t.device, t2.device)
-    self.assertEqual(t.dtype, t2.dtype)
-    self.assertEqual(t.uop.axis, t2.uop.axis)
-    assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.uop.src, t2.uop.src))
-
   def test_rand_like_none_shard(self):
     t = Tensor.empty((16, 16)).shard(devices_2)
     t2 = Tensor.rand_like(t)
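
test_rand_like_uneven_shard is gone, but its shape/device/dtype invariants survive in the even case via test_rand_like_none_shard, which continues past this hunk. A sketch of those invariants, assuming a working tinygrad install:

from tinygrad import Tensor, Device
devices_2 = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
t = Tensor.empty((16, 16)).shard(devices_2)
t2 = Tensor.rand_like(t)
assert t.shape == t2.shape and t.device == t2.device and t.dtype == t2.dtype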
@@ -910,15 +825,6 @@ class TestMultiTensor(unittest.TestCase):
       assert set(unique) == {0, 2}, unique
       assert 200 < counts[0] < 312, counts[0]

-  @unittest.skip("no longer supports uneven shard")
-  def test_dropout_on_uneven_shard_axis(self):
-    with Tensor.train():
-      X = Tensor.ones(256).shard(devices_3, axis=0)
-      output = X.dropout(0.5).numpy()
-      unique, counts = np.unique(output, return_counts=True)
-      assert set(unique) == {0, 2}, unique
-      assert 100 < counts[0] < 156, counts[0]
-
   @unittest.skip("TODO: this requires forced_realize to be deleted.")
   def test_shard_memory(self):
     devices = (d0, d1, d2, d3)
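
Both dropout tests, the kept one above and the deleted uneven variant, rely on the same arithmetic: at p=0.5, dropout zeros each element with probability 0.5 and scales survivors by 1/(1-p), hence the unique values {0, 2}. For the deleted 256-element case:

p, n = 0.5, 256
scale = 1 / (1 - p)     # survivors become 1 * 2 = 2, so unique values are {0, 2}
expected_zeros = n * p  # ~128 zeros on average, inside the asserted 100..156 band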
@@ -1042,24 +948,6 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
     np.testing.assert_allclose(a.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), b.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), rtol=1e-7, atol=1e-3)
     np.testing.assert_allclose(a.flip(-1).numpy(), b.flip(-1).numpy(), rtol=1e-7, atol=1e-3)

-  @unittest.skip("no longer supports uneven shard")
-  def test_uneven(self):
-    t = Tensor.arange(24).reshape(3, 8).contiguous().realize()
-    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(2)], axis=0)
-    a = t.shrink(((0, 2), None))
-    b = t.shrink(((2, 3), None))
-    na = t.numpy()[0:2]
-    nb = t.numpy()[2:3]
-    np.testing.assert_equal(a.numpy(), na)
-    np.testing.assert_equal(b.numpy(), nb)
-    np.testing.assert_equal((a+1).numpy(), na+1)
-    np.testing.assert_equal((b+1).numpy(), nb+1)
-    np.testing.assert_equal((1+a).numpy(), 1+na)
-    np.testing.assert_equal((1+b).numpy(), 1+nb)
-    np.testing.assert_equal((a+a).numpy(), na+na)
-    np.testing.assert_equal((b+b).numpy(), nb+nb)
-
   def test_add_two_partitions(self):
     t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
     t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
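
test_add_two_partitions continues beyond this hunk; the even-shard shrink pattern it opens with looks like this (a sketch assuming a tinygrad install with at least four devices on the default backend):

from tinygrad import Tensor, Device
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
t.shard_(devices, axis=0)     # 8 rows over 4 devices: an even 2 rows per device
a = t.shrink(((0, 2), None))  # device 0's partition, rows 0 and 1
print(a.numpy())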