From ada6b92b2df6d6442110dad81df86688d8edec62 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 30 Nov 2025 17:40:52 -0800 Subject: [PATCH] add a gate to rewrite if there's no rules [pr] (#13506) --- tinygrad/uop/ops.py | 52 +++++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 55bd7bf90b..0b49d24a0a 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -1034,10 +1034,11 @@ class PatternMatcher: def __add__(self, more:PatternMatcher) -> PatternMatcher: return PatternMatcher(self.patterns+more.patterns) def rewrite(self, uop:UOp, ctx=None) -> UOp|None: - ler = {u.op for u in uop.src} - for _,match,early_reject in self.pdict.get(uop.op, []): - if not early_reject.issubset(ler): continue - if (ret:=match(uop, ctx)) is not None and ret is not uop: return ret + if len(pats:=self.pdict.get(uop.op, [])): + ler = {u.op for u in uop.src} + for _,match,early_reject in pats: + if not early_reject.issubset(ler): continue + if (ret:=match(uop, ctx)) is not None and ret is not uop: return ret return None # *** tracking pattern matcher *** @@ -1119,28 +1120,29 @@ def profile_matches(fxn:Callable): class TrackedPatternMatcher(PatternMatcher): def rewrite(self, uop:UOp, ctx=None) -> UOp|None: - ret = None - ler = {u.op for u in uop.src} - for p,match,early_reject in self.pdict.get(uop.op, []): - if p not in match_stats: match_stats[p] = [0,0,0.0,0.0] - st = time.perf_counter() - if not early_reject.issubset(ler): + if len(pats:=self.pdict.get(uop.op, [])): + ret = None + ler = {u.op for u in uop.src} + for p,match,early_reject in pats: + if p not in match_stats: match_stats[p] = [0,0,0.0,0.0] + st = time.perf_counter() + if not early_reject.issubset(ler): + match_stats[p][2] += time.perf_counter()-st + continue + match_stats[p][1] += 1 + try: ret = match(uop, ctx) + except Exception: + if TRACK_MATCH_STATS >= 2 and active_rewrites: + active_rewrites[-1].matches.append((uop.trace_num, UOp(Ops.REWRITE_ERROR,src=uop.src,arg=str(sys.exc_info()[1])).trace_num,p.location,0)) + raise + if ret is not None and ret is not uop: + match_stats[p][0] += 1 + match_stats[p][3] += (et:=time.perf_counter()-st) + if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", printable(p.location)) + if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites: + active_rewrites[-1].matches.append((uop.trace_num, ret.trace_num, p.location, et)) + return ret match_stats[p][2] += time.perf_counter()-st - continue - match_stats[p][1] += 1 - try: ret = match(uop, ctx) - except Exception: - if TRACK_MATCH_STATS >= 2 and active_rewrites: - active_rewrites[-1].matches.append((uop.trace_num, UOp(Ops.REWRITE_ERROR,src=uop.src,arg=str(sys.exc_info()[1])).trace_num,p.location,0)) - raise - if ret is not None and ret is not uop: - match_stats[p][0] += 1 - match_stats[p][3] += (et:=time.perf_counter()-st) - if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", printable(p.location)) - if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites: - active_rewrites[-1].matches.append((uop.trace_num, ret.trace_num, p.location, et)) - return ret - match_stats[p][2] += time.perf_counter()-st return None @dataclass(frozen=True)