optim empty shard #13513 (#13598)

* optim empty shard

* remove tuple

* simplify

* lint

* lint2

* test

* remove original buffer unique id

* new rule

* reset shard

* update

* reset shard
This commit is contained in:
Nino Risteski
2025-12-09 18:28:36 +01:00
committed by GitHub
parent 47a170be2e
commit 76d465dbc3
2 changed files with 8 additions and 0 deletions

View File

@@ -57,6 +57,12 @@ class TestMultiTensor(unittest.TestCase):
assert lb.shape == (128,)
(X + X).realize()
def test_shard_empty(self):
  # start from a clean slate so the kernel count below only reflects this test
  GlobalCounters.reset()
  uninitialized = Tensor.empty(256)
  sharded = uninitialized.shard(devices_2, 0).realize()
  # realizing a sharded empty tensor is expected to run zero kernels
  assert GlobalCounters.kernel_count == 0
  # the sharded result must still be usable as an input to a later computation
  (sharded + sharded).realize()
def _test_shard_op(self, op, out, n=4):
t = Tensor.ones(n).contiguous().realize().shard(devices_2, 0)
r = op(t).realize()

View File

@@ -102,6 +102,8 @@ def mstack_early_shrink(ms:UOp, shrink:UOp):
replace_allreduce = PatternMatcher([
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce_multirank),
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
(UPat(Ops.COPY, src=(UPat(Ops.BUFFER, name="buf"), UPat(Ops.DEVICE, name="dev"))),lambda buf,dev: UOp.new_buffer(dev.arg, buf.arg, buf.dtype)
if buf.device not in {"DISK", "NPY"} and isinstance(dev.arg, tuple) and isinstance(buf.device, str) else None),
# BROADCAST: explicitly expand broadcast copies and combine with MSTACK
(UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
UOp(Ops.MSTACK, c.dtype, tuple(x.copy_to_device(d) for d in c.device)) if isinstance(c.device, tuple) and isinstance(x.device, str) else None),