diff --git a/test/test_schedule.py b/test/test_schedule.py index f3c0d14da1..7647e6ff7c 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1992,8 +1992,6 @@ class TestBigGraph(unittest.TestCase): sink = tensor_rewrite(a) assert UPat.cvar().match(sink, {}) - # failure: View doesn't support __lt__, UOp.tuplize needs it. - @unittest.expectedFailure def test_masked_const_elementwise(self): a = Tensor.eye(10)@Tensor.eye(10) sink = tensor_rewrite(a) diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index d3b75cfcc0..550f383aa5 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -87,6 +87,12 @@ class View: mask:Optional[tuple[tuple[sint, sint], ...]] contiguous:bool + @functools.cached_property + def t(self): + return tuple(x.tuplize if isinstance(x, UOp) else (x,) \ + for x in self.shape+self.strides+(self.offset,)+(tuple(flatten(self.mask)) if self.mask is not None else tuple())) + def __lt__(self, o:View): return self.t < o.t + def to_indexed_uops(self:View, idxs:Optional[Sequence[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]: """(idx, valid)""" if idxs is None: idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)]