graph_rewrite_map in the other order [pr] (#10476)

* graph_rewrite_map in the other order [pr]

* reversed to preserve behavior
This commit is contained in:
George Hotz
2025-05-22 20:22:07 -07:00
committed by GitHub
parent 9fc01c1e03
commit d2bb50d75b

View File

@@ -988,8 +988,8 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=N
@track_matches
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, input_map:dict[UOp, UOp]|None=None) -> dict[UOp, UOp]:
rewrite_ctx = RewriteContext(pm, ctx)
new_map = {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort())[::-1]}
all_metadata.update((v, k.metadata) for k,v in new_map.items() if k.metadata is not None)
new_map = {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in sink.toposort()}
all_metadata.update((v, k.metadata) for k,v in reversed(new_map.items()) if k.metadata is not None)
if input_map is not None:
for k,v in input_map.items(): new_map[k] = new_map.get(v,v)
return new_map