mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
schedule sink folding try 2 [pr] (#7968)
This commit is contained in:
@@ -1936,15 +1936,15 @@ class TestBigGraph(unittest.TestCase):
|
||||
def test_sink_childless_const(self):
|
||||
x = UOp.const(dtypes.int, 0)
|
||||
big_graph = big_graph_rewrite(x.sink(), realizes:={})
|
||||
self.assertIs(big_graph, UOp(Ops.SINK, dtypes.void, (x,)))
|
||||
self.assertIs(big_graph, UOp(Ops.NOOP))
|
||||
self.assertEqual(len(realizes), 0)
|
||||
|
||||
def test_sink_childless_const_alt(self):
|
||||
x = UOp.const(dtypes.int, 0)
|
||||
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
|
||||
big_graph = big_graph_rewrite(UOp.sink(x, y), realizes:={})
|
||||
self.assertIs(big_graph, UOp(Ops.SINK, dtypes.void, (x, y)))
|
||||
self.assertEqual(len(realizes), 1) # TODO: this should fold a flat CONST
|
||||
self.assertIs(big_graph, UOp(Ops.NOOP))
|
||||
self.assertEqual(len(realizes), 0)
|
||||
|
||||
def test_sink_childless_const_alt_expanded(self):
|
||||
# this is a real STORE of CONST (post expand)
|
||||
|
||||
@@ -350,9 +350,13 @@ def fold_img_cast(ctx:Dict[UOp, UOp], xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kw
|
||||
del ctx[b]
|
||||
return to_cast.view(unwrap(view.st))
|
||||
|
||||
def init_big_graph(ctx:ScheduleContext, sink:UOp) -> Optional[UOp]:
|
||||
new_src = tuple(x for x in sink.src if is_scheduled(x) and uval(x).op is not Ops.CONST)
|
||||
return None if new_src == sink.src else UOp(Ops.NOOP) if len(new_src) == 0 else UOp.sink(*new_src)
|
||||
|
||||
do_realize = PatternMatcher([
|
||||
# always realize sinked ops
|
||||
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.update((x.buf_uop, x) for x in sink.src if is_scheduled(x))),
|
||||
(UPat(Ops.SINK, name="sink"), init_big_graph),
|
||||
# always realize meta ops
|
||||
(UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize),
|
||||
# realize before expand or unsafe pad ops
|
||||
@@ -394,6 +398,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
cache: Dict[LazyBuffer, UOp] = {}
|
||||
buffers: Dict[UOp, Buffer] = {}
|
||||
big_graph = graph_rewrite(UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs)), do_realize, ctx.realizes)
|
||||
for u in big_graph.src: ctx.realizes[u.buf_uop] = u
|
||||
# group realizes into kernels
|
||||
store_groups = group_realizes(ctx)
|
||||
graph_rewrite(big_graph, break_sched, ctx)
|
||||
|
||||
Reference in New Issue
Block a user