show rewrite errors in viz (#11684)

This commit is contained in:
qazal
2025-08-15 19:09:47 +03:00
committed by GitHub
parent 560984fd8d
commit c8ba48b223
2 changed files with 8 additions and 3 deletions

View File

@@ -76,7 +76,7 @@ class TestViz(BaseTestViz):
self.assertEqual(lineno, inner.__code__.co_firstlineno)
def test_exceptions(self):
# VIZ tracks rewrites up to the error
# VIZ tracks rewrites up to and including the error
def count_3(x:UOp):
assert x.arg <= 3
return x.replace(arg=x.arg+1)
@@ -85,7 +85,7 @@ class TestViz(BaseTestViz):
with self.assertRaises(AssertionError): exec_rewrite(a, [err_pm])
lst = get_viz_list()
err_step = lst[0]["steps"][0]
self.assertEqual(err_step["match_count"], 3)
self.assertEqual(err_step["match_count"], 4) # 3 successful rewrites + 1 err
def test_default_name(self):
a = UOp.variable("a", 1, 10)

View File

@@ -857,7 +857,12 @@ class TrackedPatternMatcher(PatternMatcher):
match_stats[p][2] += time.perf_counter()-st
continue
match_stats[p][1] += 1
if (ret:=match(uop, ctx)) is not None and ret is not uop:
try: ret = match(uop, ctx)
except Exception:
if TRACK_MATCH_STATS >= 2 and active_rewrites:
active_rewrites[-1].matches.append((track_uop(uop), track_uop(UOp(Ops.NOOP, arg=str(sys.exc_info()[1]))), p.location))
raise
if ret is not None and ret is not uop:
match_stats[p][0] += 1
match_stats[p][3] += (et:=time.perf_counter()-st)
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", printable(p.location))