mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
move realize_map fixup into realize_assign_src [pr] (#14990)
This commit is contained in:
@@ -18,6 +18,10 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
|
||||
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
|
||||
|
||||
def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
  # ctx is the realize map being built; buf is the ASSIGN target and x is the value assigned into it.
  # A COPY/BUFFER_VIEW/ENCDEC that is the direct source of ASSIGN does not need its own realize entry,
  # because the ASSIGN target buffer is the output. The entry is kept when the target has a
  # SHRINK/PERMUTE/FLIP/PAD in its backward slice.
  is_direct_output = x.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC} and x in ctx
  if is_direct_output and not buf.op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD):
    del ctx[x]
  # Realizing the source is normally unnecessary for assign; it is only needed when the source reads the
  # buffer being written (a WAR hazard, see TestAssign.test_assign_double_diamond_reduce).
  if buf.base in x.backward_slice_with_self:
    ctx[x] = None
|
||||
|
||||
@@ -162,12 +166,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||
rctx = IndexingContext()
|
||||
|
||||
# get ops to realize
|
||||
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, bottom_up=True, name="get realize")
|
||||
# don't realize COPY/BUFFER_VIEW/ENCDEC when they are the direct source of ASSIGN — the ASSIGN target buffer is the output
|
||||
for u in tsink.toposort():
|
||||
if u.op is Ops.ASSIGN and u.src[1].op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC} and u.src[1] in rctx.realize_map \
|
||||
and not u.src[0].op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD):
|
||||
del rctx.realize_map[u.src[1]]
|
||||
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, name="get realize")
|
||||
|
||||
# get the consumer map
|
||||
with cpu_profile("consumer map in rangeify", "TINY"):
|
||||
|
||||
Reference in New Issue
Block a user