skip copies of reshaped buffers

This commit is contained in:
George Hotz
2025-10-03 10:55:58 +08:00
parent 9cd365c12e
commit a734437da8

View File

@@ -46,6 +46,8 @@ earliest_rewrites = PatternMatcher([
# just removing it works...
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
(UPat(Ops.BUFFER).f(Ops.RESHAPE, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
# split_reduceop
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
@@ -736,7 +738,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
# MSTACK stacks multiple BUFFERIZEs in one tagged tensor
# if it's not tagged by here, it's out
tsink = UOp.sink(*[x for x in tsink.parents if x.base.op in {Ops.BUFFERIZE, Ops.MSTACK, Ops.CONST} and x.tag is not None])
tsink = UOp.sink(*[x for x in tsink.parents if x.base.op in {Ops.BUFFERIZE, Ops.MSTACK, Ops.CONST, Ops.BUFFER} and x.tag is not None])
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")