diff --git a/test/test_schedule.py b/test/test_schedule.py index 9e140fc38c..a7ad17ea9c 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1936,22 +1936,14 @@ class TestBigGraph(unittest.TestCase): def test_sink_childless_const(self): x = UOp.const(dtypes.int, 0) big_graph = big_graph_rewrite(x.sink(), realizes:={}) - self.assertIs(big_graph, UOp(Ops.NOOP)) + self.assertIs(big_graph, UOp(Ops.SINK, dtypes.void, (x,))) self.assertEqual(len(realizes), 0) def test_sink_childless_const_alt(self): x = UOp.const(dtypes.int, 0) - y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(())) + y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape((10, 10))) big_graph = big_graph_rewrite(UOp.sink(x, y), realizes:={}) - self.assertIs(big_graph, UOp(Ops.NOOP)) - self.assertEqual(len(realizes), 0) - - def test_sink_childless_const_alt_expanded(self): - # this is a real STORE of CONST - y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(())) - out = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), y.reshape((1,)).expand((2,)).contiguous(),), ShapeTracker.from_shape((2,))) - big_graph = big_graph_rewrite(out.sink(), realizes:={}) - self.assertIs(big_graph, out.sink()) + self.assertIs(big_graph, UOp(Ops.SINK, dtypes.void, (x, y))) self.assertEqual(len(realizes), 1) if __name__ == '__main__': diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index fcfa57a1d9..1bb662fa7c 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -350,13 +350,9 @@ def fold_img_cast(ctx:Dict[UOp, UOp], xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kw del ctx[b] return to_cast.view(unwrap(view.st)) -def init_big_graph(sink:UOp) -> Optional[UOp]: - new_src = tuple(x.base for x in sink.src if is_scheduled(x.base) and uval(x.base).op is not Ops.CONST and x.base.size != 0) - return None if new_src == sink.src else UOp(Ops.NOOP) if len(new_src) == 0 else UOp.sink(*new_src) - do_realize = PatternMatcher([ # always realize sinked ops - (UPat(Ops.SINK, name="sink"), init_big_graph), + (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.update((x.buf_uop, x) for x in sink.src if is_scheduled(x))), # always realize meta ops (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize), # realize before expand or unsafe pad ops @@ -392,12 +388,12 @@ break_sched = PatternMatcher([ @track_rewrites(named=True) def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: + if len(outs:=dedup(x.base for x in outs if x.base.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {} # create the big graph ctx = ScheduleContext() cache: Dict[LazyBuffer, UOp] = {} buffers: Dict[UOp, Buffer] = {} big_graph = graph_rewrite(UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs)), do_realize, ctx.realizes) - for u in big_graph.src: ctx.realizes[u.buf_uop] = u # group realizes into kernels store_groups = group_realizes(ctx) graph_rewrite(big_graph, break_sched, ctx) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index fa19bbd22f..ef2c52d59e 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -274,8 +274,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @functools.cached_property def full_shape(self) -> Tuple[sint, ...]: return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st])) - @property - def size(self) -> int: return self.arg[1][1] if self.op is Ops.BUFFER else unwrap(self.st).size # *** uop evaluation *** @@ -386,6 +384,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): case Ops.BUFFER: return self.arg[1][0] case _: return self.src[0].device @property + def size(self) -> int: return self.buf_uop.arg[1][1] + @property def buf_uop(self) -> UOp: if self.op is Ops.BUFFER: return self assert self.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"