mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
remove contiguous from MSELECT 2 (#10522)
* remove contiguous from MSELECT * test_shrink_on_shard_axis --------- Co-authored-by: George Hotz <geohot@gmail.com>
This commit is contained in:
@@ -163,6 +163,20 @@ class TestMultiTensor(unittest.TestCase):
|
||||
O = X.shrink(((0, 2), None)) * W.shrink(((0, 2), None)) < 2
|
||||
np.testing.assert_allclose(O.numpy(), X.numpy()[0:2]*W.numpy()[0:2] < 2)
|
||||
|
||||
def test_shrink_on_shard_axis(self):
|
||||
X = Tensor.arange(4*4).reshape(4,4).realize()
|
||||
X_np = X.numpy()
|
||||
X.shard_(devices_2, 0)
|
||||
# only shrink on the device that owns the shard, this is enabled by the mselect simplifier
|
||||
for i in range(2):
|
||||
xt = X[i*2:i*2+2].contiguous()
|
||||
sched = xt.schedule()
|
||||
kernels = [s for s in sched if s.ast.op is Ops.SINK]
|
||||
self.assertEqual(len(kernels), 1)
|
||||
self.assertEqual(kernels[0].bufs[0].device, devices_2[i])
|
||||
run_schedule(sched)
|
||||
np.testing.assert_equal(xt.numpy(), X_np[i*2:i*2+2])
|
||||
|
||||
@given(strat.sampled_from((4, 5)), strat.sampled_from((devices_2, devices_3)),
|
||||
strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
|
||||
strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
|
||||
|
||||
@@ -71,6 +71,14 @@ def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
|
||||
if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
|
||||
return base.copy_to_device(copy.device).view(view.arg)
|
||||
|
||||
def mselect_reorder_view(ms:UOp, view:UOp, base:UOp):
|
||||
st = unwrap(view.st)
|
||||
# replace dnum in ShapeTracker with literal const for this mselect
|
||||
if (dnums:=[x for x in st.vars() if x.arg[0] == '_device_num']):
|
||||
assert len(dnums) == 1, f"view must have exactly 0 or 1 dnum, got {dnums}"
|
||||
st = st.substitute({dnums[0]:dnums[0].const_like(ms.arg)})
|
||||
return base.mselect(ms.arg).view(st)
|
||||
|
||||
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT}
|
||||
|
||||
sym = symbolic_simple+PatternMatcher([
|
||||
@@ -95,6 +103,8 @@ sym = symbolic_simple+PatternMatcher([
|
||||
(UPat(Ops.COPY, name="c", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda c,x: x if c.device == x.device else None),
|
||||
# store a shrink before COPY, otherwise view after the COPY
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"), UPat(Ops.DEVICE)), name="copy"), copy_reorder_view),
|
||||
# MSELECT must select a base, if there are views apply them after selecting the base
|
||||
(UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"),), name="ms"), mselect_reorder_view),
|
||||
# remove cast to image when it's already a contiguous image
|
||||
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
|
||||
lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
|
||||
|
||||
@@ -48,8 +48,9 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
|
||||
children[s.src[1]].append(k)
|
||||
in_degree[k] += 1
|
||||
elif s.op is Ops.MSELECT:
|
||||
children[s.src[0].src[1]].append(k)
|
||||
in_degree[k] += 1
|
||||
if s.src[0].op is not Ops.BUFFER:
|
||||
children[s.src[0].src[1]].append(k)
|
||||
in_degree[k] += 1
|
||||
elif s.op is Ops.BUFFER:
|
||||
pass # a BUFFER is already realized, nothing to do here
|
||||
else:
|
||||
|
||||
@@ -490,8 +490,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
|
||||
def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
|
||||
assert arg is None or isinstance(self.device, tuple)
|
||||
# TODO: this contiguous should not be required!!!
|
||||
inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self.contiguous(),), arg=arg)
|
||||
inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self,), arg=arg)
|
||||
return UOp(Ops.COPY, self.dtype, (inp, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device))
|
||||
def mselect(self, arg:int) -> UOp: return UOp(Ops.MSELECT, self.dtype, (self,), arg)
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user