viz don't include graph_rewrites that return a non-UOp (#6832)

* viz don't include graph_rewrites that return a non-UOp

* delete bad things
This commit is contained in:
qazal
2024-10-01 18:13:53 +08:00
committed by GitHub
parent 2a540d87e7
commit 0cb82f308c
2 changed files with 3 additions and 7 deletions

View File

@@ -551,7 +551,7 @@ class TrackedPatternMatcher(PatternMatcher):
match_stats[p][2] += (et:=time.perf_counter()-st)
match_stats[p][3] += et
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
if TRACK_MATCH_STATS >= 2 and contexts: contexts[-1].rewrites.append((uop, ret, p))
if TRACK_MATCH_STATS >= 2 and contexts and isinstance(ret, UOp): contexts[-1].rewrites.append((uop, ret, p))
return ret # NOTE: if it returns None, we keep trying to match
match_stats[p][2] += time.perf_counter()-st
return None

View File

@@ -11,10 +11,6 @@ from tinygrad.engine.graph import uops_colors, word_wrap
# **** /graph - detailed UOp + rewrites
# NOTE: UPats in ops.py are spec
def graph_rewrites(ctx:TrackedRewriteContext):
return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py"]
@dataclass(frozen=True)
class RewriteLocation:
filename: str
@@ -27,7 +23,7 @@ class RewriteLocation:
p = r"graph_rewrite\([^,]+,\s*([^>]+)\)"
match = re.search(p, code:=lines(fp)[lineno-1].strip())
return RewriteLocation(f"{fp.split('/')[-1]}:{lineno}", code, match.group(1).split(",")[0] if match is not None else None,
len(graph_rewrites(ctx)))
len(ctx.rewrites))
def to_json(self): return asdict(self)
@dataclass(frozen=True)
@@ -44,7 +40,7 @@ class UOpRet:
extra: List[List[str]] = [[str(ctx.sink)]]
additions: List[List[int]] = [[]]
seen_replaces: Dict[bytes, UOp] = {}
for i, (first, rewritten, pattern) in enumerate(graph_rewrites(ctx)):
for i, (first, rewritten, pattern) in enumerate(ctx.rewrites):
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
seen_replaces[first.key] = rewritten
new_sink = replace_uop(uops[-1], {**seen_replaces})