From 774bf39f85685aa55196b2716fa27f36d0dc88ba Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 13 Sep 2024 15:02:27 +0800 Subject: [PATCH] saving rewrites [run_process_replay] (#6501) * save rewrites with TRACK_MATCH_STATS=2 [run_process_replay] * cleaner --- tinygrad/ops.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index d467653af1..dcdd6fd8a1 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -720,6 +720,7 @@ class PatternMatcher: # *** tracking pattern matcher *** TRACK_MATCH_STATS = getenv("TRACK_MATCH_STATS", 0) +contexts: List[Tuple[UOp, List[Tuple[UOp, UOp]]]] = [] match_stats:Dict[UPat, List[Union[int, float]]] = dict() class TrackedPattenMatcher(PatternMatcher): def __init__(self, patterns:List[Tuple[UPat, Callable]]): @@ -740,23 +741,27 @@ class TrackedPattenMatcher(PatternMatcher): match_stats[p][0] += 1 match_stats[p][2] += (et:=time.perf_counter()-st) match_stats[p][3] += et - if TRACK_MATCH_STATS >= 2: print(f"{et*1e6:7.2f} us -- ", p.printable()) + if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable()) + if TRACK_MATCH_STATS >= 2: contexts[-1][1].append((uop, ret)) return ret # NOTE: if it returns None, we keep trying to match match_stats[p][2] += time.perf_counter()-st return None if TRACK_MATCH_STATS: PatternMatcher = TrackedPattenMatcher # type: ignore - import atexit + import atexit, pickle @atexit.register def print_match_stats(): ret = [0,0,0.0,0.0] for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]): loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}" - if getenv("UPAT_FILE", loc_str) not in loc_str: continue - print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable()) + if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable()) ret = [x+y for x,y in zip(ret, v)] print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL") + if TRACK_MATCH_STATS >= 2: + with open("/tmp/rewrites.pkl", "wb") as f: + print(f"rewrote {len(contexts)} graphs and applied {sum(len(x[1]) for x in contexts)} rules, saved to /tmp/rewrites.pkl") + pickle.dump(contexts, f) # *** simple graph rewrite engine *** @@ -773,4 +778,6 @@ class RewriteContext: x = UOp(*replace_source) if new_src != n.src else n self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x return found -def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: return RewriteContext(pm).rewrite(sink) +def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: + if TRACK_MATCH_STATS >= 2: contexts.append((sink, [])) + return RewriteContext(pm).rewrite(sink)