From 6d07087fe141e2fe76b42e1733f96e0d851b87b0 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 26 May 2025 19:19:01 +0300 Subject: [PATCH] remove contiguous from MSELECT 2 (#10522) * remove contiguous from MSELECT * test_shrink_on_shard_axis --------- Co-authored-by: George Hotz --- test/test_multitensor.py | 14 ++++++++++++++ tinygrad/engine/grouper.py | 10 ++++++++++ tinygrad/engine/schedule.py | 5 +++-- tinygrad/uop/ops.py | 3 +-- 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index b8dcf3f610..87ed97a546 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -163,6 +163,20 @@ class TestMultiTensor(unittest.TestCase): O = X.shrink(((0, 2), None)) * W.shrink(((0, 2), None)) < 2 np.testing.assert_allclose(O.numpy(), X.numpy()[0:2]*W.numpy()[0:2] < 2) + def test_shrink_on_shard_axis(self): + X = Tensor.arange(4*4).reshape(4,4).realize() + X_np = X.numpy() + X.shard_(devices_2, 0) + # only shrink on the device that owns the shard, this is enabled by the mselect simplifier + for i in range(2): + xt = X[i*2:i*2+2].contiguous() + sched = xt.schedule() + kernels = [s for s in sched if s.ast.op is Ops.SINK] + self.assertEqual(len(kernels), 1) + self.assertEqual(kernels[0].bufs[0].device, devices_2[i]) + run_schedule(sched) + np.testing.assert_equal(xt.numpy(), X_np[i*2:i*2+2]) + @given(strat.sampled_from((4, 5)), strat.sampled_from((devices_2, devices_3)), strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)), strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1))) diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py index f6feb8db04..fc71a7fa86 100644 --- a/tinygrad/engine/grouper.py +++ b/tinygrad/engine/grouper.py @@ -71,6 +71,14 @@ def copy_reorder_view(copy:UOp, view:UOp, base:UOp): if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device) return base.copy_to_device(copy.device).view(view.arg) +def mselect_reorder_view(ms:UOp, view:UOp, base:UOp): + st = unwrap(view.st) + # replace dnum in ShapeTracker with literal const for this mselect + if (dnums:=[x for x in st.vars() if x.arg[0] == '_device_num']): + assert len(dnums) == 1, f"view must have exactly 0 or 1 dnum, got {dnums}" + st = st.substitute({dnums[0]:dnums[0].const_like(ms.arg)}) + return base.mselect(ms.arg).view(st) + ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT} sym = symbolic_simple+PatternMatcher([ @@ -95,6 +103,8 @@ sym = symbolic_simple+PatternMatcher([ (UPat(Ops.COPY, name="c", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda c,x: x if c.device == x.device else None), # store a shrink before COPY, otherwise view after the COPY (UPat(Ops.COPY, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"), UPat(Ops.DEVICE)), name="copy"), copy_reorder_view), + # MSELECT must select a base, if there are views apply them after selecting the base + (UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"),), name="ms"), mselect_reorder_view), # remove cast to image when it's already a contiguous image (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)), lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None), diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 1f8796fa7e..ea074ae9df 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -48,8 +48,9 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[ children[s.src[1]].append(k) in_degree[k] += 1 elif s.op is Ops.MSELECT: - children[s.src[0].src[1]].append(k) - in_degree[k] += 1 + if s.src[0].op is not Ops.BUFFER: + children[s.src[0].src[1]].append(k) + in_degree[k] += 1 elif s.op is Ops.BUFFER: pass # a BUFFER is already realized, nothing to do here else: diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 5c19fa9044..1812bca9e1 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -490,8 +490,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val) def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None): assert arg is None or isinstance(self.device, tuple) - # TODO: this contiguous should not be required!!! - inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self.contiguous(),), arg=arg) + inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self,), arg=arg) return UOp(Ops.COPY, self.dtype, (inp, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device)) def mselect(self, arg:int) -> UOp: return UOp(Ops.MSELECT, self.dtype, (self,), arg) @property