viz: pickle UPat location (#11086)

This commit is contained in:
qazal
2025-07-04 13:09:00 +03:00
committed by GitHub
parent 2403f126ed
commit f6d55d9272
2 changed files with 5 additions and 4 deletions

View File

@@ -763,7 +763,7 @@ match_stats:dict[UPat, list[Union[int, float]]] = dict()
class TrackedGraphRewrite:
loc: tuple[str, int] # location that called graph_rewrite
sink: int # the sink input to graph_rewrite
matches: list[tuple[int, int, UPat]] # before+after of all the matches
matches: list[tuple[int, int, tuple]] # before/after UOp, UPat location
name: str|None # optional name of the rewrite
depth: int # depth if it's a subrewrite
bottom_up: bool
@@ -829,7 +829,8 @@ class TrackedPatternMatcher(PatternMatcher):
match_stats[p][0] += 1
match_stats[p][3] += (et:=time.perf_counter()-st)
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", printable(p.location))
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites: active_rewrites[-1].matches.append((track_uop(uop),track_uop(ret), p))
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites:
active_rewrites[-1].matches.append((track_uop(uop), track_uop(ret), p.location))
return ret
match_stats[p][2] += time.perf_counter()-st
return None

View File

@@ -83,12 +83,12 @@ def _reconstruct(a:int):
def get_details(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink)), "uop":str(next_sink), "changed_nodes":None, "diff":None, "upat":None}
replaces: dict[UOp, UOp] = {}
for u0_num,u1_num,upat in tqdm(ctx.matches):
for u0_num,u1_num,upat_loc in tqdm(ctx.matches):
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
try: new_sink = next_sink.substitute(replaces)
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":str(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
"diff":list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())), "upat":(upat.location, printable(upat.location))}
"diff":list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())), "upat":(upat_loc, printable(upat_loc))}
if not ctx.bottom_up: next_sink = new_sink
# Profiler API