move realize_map fixup into realize_assign_src [pr] (#14990)

This commit is contained in:
chenyu
2026-02-24 15:51:40 -05:00
committed by GitHub
parent 9d9151a21e
commit 8dae9be573

View File

@@ -18,6 +18,10 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
  # ctx is the realize map; x is the source feeding an ASSIGN and buf is that ASSIGN's target
  # a COPY/BUFFER_VIEW/ENCDEC that is the direct ASSIGN source is taken back out of the realize map, since the
  # ASSIGN target buffer is the output. skipped when SHRINK/PERMUTE/FLIP/PAD is reported for buf's backward slice
  if x.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC} and x in ctx:
    if not buf.op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD): del ctx[x]
  # an assign source is not usually marked here; it is only when buf.base shows up in x's own backward slice,
  # i.e. a WAR hazard like TestAssign.test_assign_double_diamond_reduce
  if buf.base in x.backward_slice_with_self: ctx[x] = None
@@ -162,12 +166,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
rctx = IndexingContext()
# get ops to realize
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, bottom_up=True, name="get realize")
# don't realize COPY/BUFFER_VIEW/ENCDEC when they are the direct source of ASSIGN — the ASSIGN target buffer is the output
for u in tsink.toposort():
if u.op is Ops.ASSIGN and u.src[1].op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC} and u.src[1] in rctx.realize_map \
and not u.src[0].op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD):
del rctx.realize_map[u.src[1]]
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, name="get realize")
# get the consumer map
with cpu_profile("consumer map in rangeify", "TINY"):