From 784b919f7f7614898061654eafaee8939ea264c0 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 27 Dec 2025 21:10:23 -0500 Subject: [PATCH] Revert "optim empty shard #13513 (#13598)" (#13855) * Revert "optim empty shard #13513 (#13598)" This reverts commit 76d465dbc35fc1f47c3d19c8e2f4f6652a891084. * test_arange_shrink * update test --- test/test_multitensor.py | 6 ++++++ tinygrad/schedule/multi.py | 2 -- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index c68531654e..e090766904 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -57,12 +57,18 @@ class TestMultiTensor(unittest.TestCase): assert lb.shape == (128,) (X + X).realize() + @unittest.expectedFailure # TODO: fix def test_shard_empty(self): GlobalCounters.reset() X = Tensor.empty(256).shard(devices_2, 0).realize() assert GlobalCounters.kernel_count == 0 (X + X).realize() + def test_arange_shrink(self): + x = Tensor.arange(4) + self.assertEqual(x.shard(devices_2, 0).realize().shrink(((2, 4),)).tolist(), [2, 3]) + self.assertEqual(x.shard(devices_2, 0).realize().shrink(((0, 2),)).tolist(), [0, 1]) + def test_shard_like(self): X = Tensor.ones(256).shard(devices_2, 0) Y = Tensor.zeros(256).shard_like(X) diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index dbd9c7b098..145b3850be 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -102,8 +102,6 @@ def mstack_early_shrink(ms:UOp, shrink:UOp): replace_allreduce = PatternMatcher([ (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce_multirank), (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce), - (UPat(Ops.COPY, src=(UPat(Ops.BUFFER, name="buf"), UPat(Ops.DEVICE, name="dev"))),lambda buf,dev: UOp.new_buffer(dev.arg, buf.arg, buf.dtype) - if buf.device not in {"DISK", "NPY"} and isinstance(dev.arg, tuple) and isinstance(buf.device, str) else None), # BROADCAST: explicitly expand broadcast copies and combine with MSTACK (UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x: UOp(Ops.MSTACK, c.dtype, tuple(x.copy_to_device(d) for d in c.device)) if isinstance(c.device, tuple) and isinstance(x.device, str) else None),