mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
skip copies of reshaped buffers
This commit is contained in:
@@ -46,6 +46,8 @@ earliest_rewrites = PatternMatcher([
|
||||
# just removing it works...
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
|
||||
|
||||
(UPat(Ops.BUFFER).f(Ops.RESHAPE, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
|
||||
|
||||
# split_reduceop
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
|
||||
|
||||
@@ -736,7 +738,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
|
||||
# MSTACK stacks multiple BUFFERIZEs in one tagged tensor
|
||||
# if it's not tagged by here, it's out
|
||||
tsink = UOp.sink(*[x for x in tsink.parents if x.base.op in {Ops.BUFFERIZE, Ops.MSTACK, Ops.CONST} and x.tag is not None])
|
||||
tsink = UOp.sink(*[x for x in tsink.parents if x.base.op in {Ops.BUFFERIZE, Ops.MSTACK, Ops.CONST, Ops.BUFFER} and x.tag is not None])
|
||||
|
||||
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user