schedule sink folding with graph_rewrite [pr] (#7963)

* schedule sink folding with graph_rewrite [pr]

* x is reserved, use u

* match lazy const folding
This commit is contained in:
qazal
2024-11-30 05:30:41 -05:00
committed by GitHub
parent 10f431b96d
commit 4529c5d0da
3 changed files with 19 additions and 7 deletions

View File

@@ -1936,14 +1936,22 @@ class TestBigGraph(unittest.TestCase):
# Feeds a sink whose only source is an int CONST through big_graph_rewrite and checks the result,
# plus that the `realizes` dict passed in as context stays empty.
def test_sink_childless_const(self):
x = UOp.const(dtypes.int, 0)
# `realizes` is bound by the walrus so it can be inspected after the rewrite
big_graph = big_graph_rewrite(x.sink(), realizes:={})
# NOTE(review): the two assertIs lines below expect different results for the same big_graph
# (a SINK over x vs. a bare NOOP). This looks like the old and new line of a rendered diff whose
# +/- markers were lost; presumably only the NOOP expectation is current — confirm.
self.assertIs(big_graph, UOp(Ops.SINK, dtypes.void, (x,)))
self.assertIs(big_graph, UOp(Ops.NOOP))
self.assertEqual(len(realizes), 0)
# Sinks an int CONST together with a VIEW whose sources are a BUFFER and a CONST, runs
# big_graph_rewrite on it, and checks the result plus that `realizes` stays empty.
def test_sink_childless_const_alt(self):
x = UOp.const(dtypes.int, 0)
# NOTE(review): `y` is assigned twice with only the ShapeTracker shape differing ((10, 10) vs ()),
# so the second assignment wins. Presumably the first is the removed line of a rendered diff — confirm.
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape((10, 10)))
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
# `realizes` is bound by the walrus so it can be inspected after the rewrite
big_graph = big_graph_rewrite(UOp.sink(x, y), realizes:={})
# NOTE(review): the two assertIs lines below expect different results for the same big_graph
# (a SINK over (x, y) vs. a bare NOOP); presumably only the NOOP expectation is current — confirm.
self.assertIs(big_graph, UOp(Ops.SINK, dtypes.void, (x, y)))
self.assertIs(big_graph, UOp(Ops.NOOP))
self.assertEqual(len(realizes), 0)
def test_sink_childless_const_alt_expanded(self):
  # a CONST that is reshaped, expanded and made contiguous is a real STORE of CONST:
  # the sink is expected to come back unchanged and exactly one realize is recorded
  scalar_view = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
  expanded = scalar_view.reshape((1,)).expand((2,)).contiguous()
  out = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), expanded), ShapeTracker.from_shape((2,)))
  # `realizes` is bound by the walrus so the context dict can be inspected afterwards
  rewritten = big_graph_rewrite(out.sink(), realizes:={})
  self.assertIs(rewritten, out.sink())
  self.assertEqual(len(realizes), 1)
if __name__ == '__main__':

View File

@@ -350,9 +350,13 @@ def fold_img_cast(ctx:Dict[UOp, UOp], xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kw
del ctx[b]
return to_cast.view(unwrap(view.st))
def init_big_graph(sink:UOp) -> Optional[UOp]:
  # Rebuild the sources of `sink`: keep the base of every source whose base is scheduled,
  # is not a CONST (per uval) and has a non-zero size.
  kept = tuple(s.base for s in sink.src if is_scheduled(s.base) and uval(s.base).op is not Ops.CONST and s.base.size != 0)
  # sources already in that form: return None
  # NOTE(review): presumably None tells the pattern matcher "no rewrite" — confirm
  if kept == sink.src: return None
  # every source was filtered out: return a bare NOOP instead of a sink
  if not kept: return UOp(Ops.NOOP)
  return UOp.sink(*kept)
do_realize = PatternMatcher([
# always realize sinked ops
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.update((x.buf_uop, x) for x in sink.src if is_scheduled(x))),
(UPat(Ops.SINK, name="sink"), init_big_graph),
# always realize meta ops
(UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize),
# realize before expand or unsafe pad ops
@@ -388,12 +392,12 @@ break_sched = PatternMatcher([
@track_rewrites(named=True)
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
if len(outs:=dedup(x.base for x in outs if x.base.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {}
# create the big graph
ctx = ScheduleContext()
cache: Dict[LazyBuffer, UOp] = {}
buffers: Dict[UOp, Buffer] = {}
big_graph = graph_rewrite(UOp.sink(*(to_uop(x, ctx, buffers, cache) for x in outs)), do_realize, ctx.realizes)
for u in big_graph.src: ctx.realizes[u.buf_uop] = u
# group realizes into kernels
store_groups = group_realizes(ctx)
graph_rewrite(big_graph, break_sched, ctx)

View File

@@ -274,6 +274,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
  # a VIEW reports the shape held in its arg directly
  if self.op is Ops.VIEW: return self.arg.shape
  # otherwise combine the full shapes of every source that has an st, taking smax per dimension
  src_shapes = [s.full_shape for s in self.src if s.has_st]
  return tuple(smax(dims) for dims in zip(*src_shapes))
@property
def size(self) -> int:
  # a BUFFER reads its size from arg[1][1]; any other op uses the size of its unwrapped st
  if self.op is Ops.BUFFER: return self.arg[1][1]
  return unwrap(self.st).size
# *** uop evaluation ***
@@ -384,8 +386,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
case Ops.BUFFER: return self.arg[1][0]
case _: return self.src[0].device
# Returns the value stored at arg[1][1] of this UOp's buf_uop.
# NOTE(review): this is a second `size` property in the same view. SOURCE reads like a rendered diff
# whose +/- markers were lost, and the hunk header above (-384,8 +386,6) drops two lines, so this is
# presumably the pre-change definition that the commit removes — confirm before relying on it.
@property
def size(self) -> int: return self.buf_uop.arg[1][1]
@property
def buf_uop(self) -> UOp:
if self.op is Ops.BUFFER: return self
assert self.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}"