From c8ba48b223e5986e2711efba2bdd313ad87ecc39 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 15 Aug 2025 19:09:47 +0300 Subject: [PATCH] show rewrite errors in viz (#11684) --- test/unit/test_viz.py | 4 ++-- tinygrad/uop/ops.py | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index 020a8e465b..1d8189203d 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -76,7 +76,7 @@ class TestViz(BaseTestViz): self.assertEqual(lineno, inner.__code__.co_firstlineno) def test_exceptions(self): - # VIZ tracks rewrites up to the error + # VIZ tracks rewrites up to and including the error def count_3(x:UOp): assert x.arg <= 3 return x.replace(arg=x.arg+1) @@ -85,7 +85,7 @@ class TestViz(BaseTestViz): with self.assertRaises(AssertionError): exec_rewrite(a, [err_pm]) lst = get_viz_list() err_step = lst[0]["steps"][0] - self.assertEqual(err_step["match_count"], 3) + self.assertEqual(err_step["match_count"], 4) # 3 successful rewrites + 1 err def test_default_name(self): a = UOp.variable("a", 1, 10) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 753634a915..c56b709cf7 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -857,7 +857,12 @@ class TrackedPatternMatcher(PatternMatcher): match_stats[p][2] += time.perf_counter()-st continue match_stats[p][1] += 1 - if (ret:=match(uop, ctx)) is not None and ret is not uop: + try: ret = match(uop, ctx) + except Exception: + if TRACK_MATCH_STATS >= 2 and active_rewrites: + active_rewrites[-1].matches.append((track_uop(uop), track_uop(UOp(Ops.NOOP, arg=str(sys.exc_info()[1]))), p.location)) + raise + if ret is not None and ret is not uop: match_stats[p][0] += 1 match_stats[p][3] += (et:=time.perf_counter()-st) if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", printable(p.location))