remove contiguous from MSELECT 2 (#10522)

* remove contiguous from MSELECT

* test_shrink_on_shard_axis

---------

Co-authored-by: George Hotz <geohot@gmail.com>
This commit is contained in:
qazal
2025-05-26 19:19:01 +03:00
committed by GitHub
parent 602a145f8f
commit 6d07087fe1
4 changed files with 28 additions and 4 deletions

View File

@@ -163,6 +163,20 @@ class TestMultiTensor(unittest.TestCase):
O = X.shrink(((0, 2), None)) * W.shrink(((0, 2), None)) < 2
np.testing.assert_allclose(O.numpy(), X.numpy()[0:2]*W.numpy()[0:2] < 2)
def test_shrink_on_shard_axis(self):
X = Tensor.arange(4*4).reshape(4,4).realize()
X_np = X.numpy()
X.shard_(devices_2, 0)
# only shrink on the device that owns the shard, this is enabled by the mselect simplifier
for i in range(2):
xt = X[i*2:i*2+2].contiguous()
sched = xt.schedule()
kernels = [s for s in sched if s.ast.op is Ops.SINK]
self.assertEqual(len(kernels), 1)
self.assertEqual(kernels[0].bufs[0].device, devices_2[i])
run_schedule(sched)
np.testing.assert_equal(xt.numpy(), X_np[i*2:i*2+2])
@given(strat.sampled_from((4, 5)), strat.sampled_from((devices_2, devices_3)),
strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))

View File

@@ -71,6 +71,14 @@ def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
return base.copy_to_device(copy.device).view(view.arg)
def mselect_reorder_view(ms:UOp, view:UOp, base:UOp):
st = unwrap(view.st)
# replace dnum in ShapeTracker with literal const for this mselect
if (dnums:=[x for x in st.vars() if x.arg[0] == '_device_num']):
assert len(dnums) == 1, f"view must have exactly 0 or 1 dnum, got {dnums}"
st = st.substitute({dnums[0]:dnums[0].const_like(ms.arg)})
return base.mselect(ms.arg).view(st)
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT}
sym = symbolic_simple+PatternMatcher([
@@ -95,6 +103,8 @@ sym = symbolic_simple+PatternMatcher([
(UPat(Ops.COPY, name="c", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda c,x: x if c.device == x.device else None),
# store a shrink before COPY, otherwise view after the COPY
(UPat(Ops.COPY, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"), UPat(Ops.DEVICE)), name="copy"), copy_reorder_view),
# MSELECT must select a base, if there are views apply them after selecting the base
(UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"),), name="ms"), mselect_reorder_view),
# remove cast to image when it's already a contiguous image
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),

View File

@@ -48,8 +48,9 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
children[s.src[1]].append(k)
in_degree[k] += 1
elif s.op is Ops.MSELECT:
children[s.src[0].src[1]].append(k)
in_degree[k] += 1
if s.src[0].op is not Ops.BUFFER:
children[s.src[0].src[1]].append(k)
in_degree[k] += 1
elif s.op is Ops.BUFFER:
pass # a BUFFER is already realized, nothing to do here
else:

View File

@@ -490,8 +490,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
assert arg is None or isinstance(self.device, tuple)
# TODO: this contiguous should not be required!!!
inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self.contiguous(),), arg=arg)
inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self,), arg=arg)
return UOp(Ops.COPY, self.dtype, (inp, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device))
def mselect(self, arg:int) -> UOp: return UOp(Ops.MSELECT, self.dtype, (self,), arg)
@property