diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 7858b7da94..363e8190eb 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -763,7 +763,7 @@ match_stats:dict[UPat, list[Union[int, float]]] = dict() class TrackedGraphRewrite: loc: tuple[str, int] # location that called graph_rewrite sink: int # the sink input to graph_rewrite - matches: list[tuple[int, int, UPat]] # before+after of all the matches + matches: list[tuple[int, int, tuple]] # before/after UOp, UPat location name: str|None # optional name of the rewrite depth: int # depth if it's a subrewrite bottom_up: bool @@ -829,7 +829,8 @@ class TrackedPatternMatcher(PatternMatcher): match_stats[p][0] += 1 match_stats[p][3] += (et:=time.perf_counter()-st) if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", printable(p.location)) - if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites: active_rewrites[-1].matches.append((track_uop(uop),track_uop(ret), p)) + if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites: + active_rewrites[-1].matches.append((track_uop(uop), track_uop(ret), p.location)) return ret match_stats[p][2] += time.perf_counter()-st return None diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 8faea19861..30aca97e77 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -83,12 +83,12 @@ def _reconstruct(a:int): def get_details(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]: yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink)), "uop":str(next_sink), "changed_nodes":None, "diff":None, "upat":None} replaces: dict[UOp, UOp] = {} - for u0_num,u1_num,upat in tqdm(ctx.matches): + for u0_num,u1_num,upat_loc in tqdm(ctx.matches): replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num) try: new_sink = next_sink.substitute(replaces) except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e)) yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":str(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json], - "diff":list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())), "upat":(upat.location, printable(upat.location))} + "diff":list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())), "upat":(upat_loc, printable(upat_loc))} if not ctx.bottom_up: next_sink = new_sink # Profiler API