saving rewrites [run_process_replay] (#6501)

* save rewrites with TRACK_MATCH_STATS=2 [run_process_replay]

* cleaner
This commit is contained in:
George Hotz
2024-09-13 15:02:27 +08:00
committed by GitHub
parent 7c078191ce
commit 774bf39f85

View File

@@ -720,6 +720,7 @@ class PatternMatcher:
# *** tracking pattern matcher ***
TRACK_MATCH_STATS = getenv("TRACK_MATCH_STATS", 0)
contexts: List[Tuple[UOp, List[Tuple[UOp, UOp]]]] = []
match_stats:Dict[UPat, List[Union[int, float]]] = dict()
class TrackedPattenMatcher(PatternMatcher):
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
@@ -740,23 +741,27 @@ class TrackedPattenMatcher(PatternMatcher):
match_stats[p][0] += 1
match_stats[p][2] += (et:=time.perf_counter()-st)
match_stats[p][3] += et
if TRACK_MATCH_STATS >= 2: print(f"{et*1e6:7.2f} us -- ", p.printable())
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
if TRACK_MATCH_STATS >= 2: contexts[-1][1].append((uop, ret))
return ret # NOTE: if it returns None, we keep trying to match
match_stats[p][2] += time.perf_counter()-st
return None
if TRACK_MATCH_STATS:
PatternMatcher = TrackedPattenMatcher # type: ignore
import atexit
import atexit, pickle
@atexit.register
def print_match_stats():
ret = [0,0,0.0,0.0]
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]):
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
if getenv("UPAT_FILE", loc_str) not in loc_str: continue
print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
ret = [x+y for x,y in zip(ret, v)]
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL")
if TRACK_MATCH_STATS >= 2:
with open("/tmp/rewrites.pkl", "wb") as f:
print(f"rewrote {len(contexts)} graphs and applied {sum(len(x[1]) for x in contexts)} rules, saved to /tmp/rewrites.pkl")
pickle.dump(contexts, f)
# *** simple graph rewrite engine ***
@@ -773,4 +778,6 @@ class RewriteContext:
x = UOp(*replace_source) if new_src != n.src else n
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x
return found
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: return RewriteContext(pm).rewrite(sink)
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
if TRACK_MATCH_STATS >= 2: contexts.append((sink, []))
return RewriteContext(pm).rewrite(sink)