mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 14:43:57 -05:00
* Revert "optim empty shard #13513 (#13598)"
This reverts commit 76d465dbc3.
* test_arange_shrink
* update test
This commit is contained in:
@@ -57,12 +57,18 @@ class TestMultiTensor(unittest.TestCase):
|
||||
assert lb.shape == (128,)
|
||||
(X + X).realize()
|
||||
|
||||
@unittest.expectedFailure # TODO: fix
|
||||
def test_shard_empty(self):
|
||||
GlobalCounters.reset()
|
||||
X = Tensor.empty(256).shard(devices_2, 0).realize()
|
||||
assert GlobalCounters.kernel_count == 0
|
||||
(X + X).realize()
|
||||
|
||||
def test_arange_shrink(self):
|
||||
x = Tensor.arange(4)
|
||||
self.assertEqual(x.shard(devices_2, 0).realize().shrink(((2, 4),)).tolist(), [2, 3])
|
||||
self.assertEqual(x.shard(devices_2, 0).realize().shrink(((0, 2),)).tolist(), [0, 1])
|
||||
|
||||
def test_shard_like(self):
|
||||
X = Tensor.ones(256).shard(devices_2, 0)
|
||||
Y = Tensor.zeros(256).shard_like(X)
|
||||
|
||||
@@ -102,8 +102,6 @@ def mstack_early_shrink(ms:UOp, shrink:UOp):
|
||||
replace_allreduce = PatternMatcher([
|
||||
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce_multirank),
|
||||
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.BUFFER, name="buf"), UPat(Ops.DEVICE, name="dev"))),lambda buf,dev: UOp.new_buffer(dev.arg, buf.arg, buf.dtype)
|
||||
if buf.device not in {"DISK", "NPY"} and isinstance(dev.arg, tuple) and isinstance(buf.device, str) else None),
|
||||
# BROADCAST: explicitly expand broadcast copies and combine with MSTACK
|
||||
(UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
|
||||
UOp(Ops.MSTACK, c.dtype, tuple(x.copy_to_device(d) for d in c.device)) if isinstance(c.device, tuple) and isinstance(x.device, str) else None),
|
||||
|
||||
Reference in New Issue
Block a user