From 0cb82f308cc2446fb519d5e81ff70924ef3378d1 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 1 Oct 2024 18:13:53 +0800 Subject: [PATCH] viz don't include graph_rewrites that return a non-UOp (#6832) * viz don't include graph_rewrites that return a non-UOp * delete bad things --- tinygrad/ops.py | 2 +- viz/serve.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 6c33e3556f..330aad2955 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -551,7 +551,7 @@ class TrackedPatternMatcher(PatternMatcher): match_stats[p][2] += (et:=time.perf_counter()-st) match_stats[p][3] += et if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable()) - if TRACK_MATCH_STATS >= 2 and contexts: contexts[-1].rewrites.append((uop, ret, p)) + if TRACK_MATCH_STATS >= 2 and contexts and isinstance(ret, UOp): contexts[-1].rewrites.append((uop, ret, p)) return ret # NOTE: if it returns None, we keep trying to match match_stats[p][2] += time.perf_counter()-st return None diff --git a/viz/serve.py b/viz/serve.py index fb9269a826..4d9f399817 100755 --- a/viz/serve.py +++ b/viz/serve.py @@ -11,10 +11,6 @@ from tinygrad.engine.graph import uops_colors, word_wrap # **** /graph - detailed UOp + rewrites -# NOTE: UPats in ops.py are spec -def graph_rewrites(ctx:TrackedRewriteContext): - return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py"] - @dataclass(frozen=True) class RewriteLocation: filename: str @@ -27,7 +23,7 @@ class RewriteLocation: p = r"graph_rewrite\([^,]+,\s*([^>]+)\)" match = re.search(p, code:=lines(fp)[lineno-1].strip()) return RewriteLocation(f"{fp.split('/')[-1]}:{lineno}", code, match.group(1).split(",")[0] if match is not None else None, - len(graph_rewrites(ctx))) + len(ctx.rewrites)) def to_json(self): return asdict(self) @dataclass(frozen=True) @@ -44,7 +40,7 @@ class UOpRet: extra: List[List[str]] = [[str(ctx.sink)]] additions: List[List[int]] = [[]] seen_replaces: Dict[bytes, UOp] = {} - for i, (first, rewritten, pattern) in enumerate(graph_rewrites(ctx)): + for i, (first, rewritten, pattern) in enumerate(ctx.rewrites): # first, rewrite this UOp with the current rewrite + all the seen rewrites before this seen_replaces[first.key] = rewritten new_sink = replace_uop(uops[-1], {**seen_replaces})