mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
saving rewrites [run_process_replay] (#6501)
* save rewrites with TRACK_MATCH_STATS=2 [run_process_replay] * cleaner
This commit is contained in:
@@ -720,6 +720,7 @@ class PatternMatcher:
|
||||
# *** tracking pattern matcher ***
|
||||
|
||||
TRACK_MATCH_STATS = getenv("TRACK_MATCH_STATS", 0)
|
||||
contexts: List[Tuple[UOp, List[Tuple[UOp, UOp]]]] = []
|
||||
match_stats:Dict[UPat, List[Union[int, float]]] = dict()
|
||||
class TrackedPattenMatcher(PatternMatcher):
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
@@ -740,23 +741,27 @@ class TrackedPattenMatcher(PatternMatcher):
|
||||
match_stats[p][0] += 1
|
||||
match_stats[p][2] += (et:=time.perf_counter()-st)
|
||||
match_stats[p][3] += et
|
||||
if TRACK_MATCH_STATS >= 2: print(f"{et*1e6:7.2f} us -- ", p.printable())
|
||||
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
|
||||
if TRACK_MATCH_STATS >= 2: contexts[-1][1].append((uop, ret))
|
||||
return ret # NOTE: if it returns None, we keep trying to match
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
return None
|
||||
|
||||
if TRACK_MATCH_STATS:
|
||||
PatternMatcher = TrackedPattenMatcher # type: ignore
|
||||
import atexit
|
||||
import atexit, pickle
|
||||
@atexit.register
|
||||
def print_match_stats():
|
||||
ret = [0,0,0.0,0.0]
|
||||
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]):
|
||||
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
|
||||
if getenv("UPAT_FILE", loc_str) not in loc_str: continue
|
||||
print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
|
||||
if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
|
||||
ret = [x+y for x,y in zip(ret, v)]
|
||||
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL")
|
||||
if TRACK_MATCH_STATS >= 2:
|
||||
with open("/tmp/rewrites.pkl", "wb") as f:
|
||||
print(f"rewrote {len(contexts)} graphs and applied {sum(len(x[1]) for x in contexts)} rules, saved to /tmp/rewrites.pkl")
|
||||
pickle.dump(contexts, f)
|
||||
|
||||
# *** simple graph rewrite engine ***
|
||||
|
||||
@@ -773,4 +778,6 @@ class RewriteContext:
|
||||
x = UOp(*replace_source) if new_src != n.src else n
|
||||
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x
|
||||
return found
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: return RewriteContext(pm).rewrite(sink)
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
|
||||
if TRACK_MATCH_STATS >= 2: contexts.append((sink, []))
|
||||
return RewriteContext(pm).rewrite(sink)
|
||||
|
||||
Reference in New Issue
Block a user