mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
viz: pickle UPat location (#11086)
This commit is contained in:
@@ -763,7 +763,7 @@ match_stats:dict[UPat, list[Union[int, float]]] = dict()
|
||||
class TrackedGraphRewrite:
|
||||
loc: tuple[str, int] # location that called graph_rewrite
|
||||
sink: int # the sink input to graph_rewrite
|
||||
matches: list[tuple[int, int, UPat]] # before+after of all the matches
|
||||
matches: list[tuple[int, int, tuple]] # before/after UOp, UPat location
|
||||
name: str|None # optional name of the rewrite
|
||||
depth: int # depth if it's a subrewrite
|
||||
bottom_up: bool
|
||||
@@ -829,7 +829,8 @@ class TrackedPatternMatcher(PatternMatcher):
|
||||
match_stats[p][0] += 1
|
||||
match_stats[p][3] += (et:=time.perf_counter()-st)
|
||||
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", printable(p.location))
|
||||
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites: active_rewrites[-1].matches.append((track_uop(uop),track_uop(ret), p))
|
||||
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites:
|
||||
active_rewrites[-1].matches.append((track_uop(uop), track_uop(ret), p.location))
|
||||
return ret
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
return None
|
||||
|
||||
@@ -83,12 +83,12 @@ def _reconstruct(a:int):
|
||||
def get_details(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
|
||||
yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink)), "uop":str(next_sink), "changed_nodes":None, "diff":None, "upat":None}
|
||||
replaces: dict[UOp, UOp] = {}
|
||||
for u0_num,u1_num,upat in tqdm(ctx.matches):
|
||||
for u0_num,u1_num,upat_loc in tqdm(ctx.matches):
|
||||
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
|
||||
try: new_sink = next_sink.substitute(replaces)
|
||||
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
|
||||
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":str(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
|
||||
"diff":list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())), "upat":(upat.location, printable(upat.location))}
|
||||
"diff":list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())), "upat":(upat_loc, printable(upat_loc))}
|
||||
if not ctx.bottom_up: next_sink = new_sink
|
||||
|
||||
# Profiler API
|
||||
|
||||
Reference in New Issue
Block a user