mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
show rewrite errors in viz (#11684)
This commit is contained in:
@@ -76,7 +76,7 @@ class TestViz(BaseTestViz):
|
||||
self.assertEqual(lineno, inner.__code__.co_firstlineno)
|
||||
|
||||
def test_exceptions(self):
|
||||
# VIZ tracks rewrites up to the error
|
||||
# VIZ tracks rewrites up to and including the error
|
||||
def count_3(x:UOp):
|
||||
assert x.arg <= 3
|
||||
return x.replace(arg=x.arg+1)
|
||||
@@ -85,7 +85,7 @@ class TestViz(BaseTestViz):
|
||||
with self.assertRaises(AssertionError): exec_rewrite(a, [err_pm])
|
||||
lst = get_viz_list()
|
||||
err_step = lst[0]["steps"][0]
|
||||
self.assertEqual(err_step["match_count"], 3)
|
||||
self.assertEqual(err_step["match_count"], 4) # 3 successful rewrites + 1 err
|
||||
|
||||
def test_default_name(self):
|
||||
a = UOp.variable("a", 1, 10)
|
||||
|
||||
@@ -857,7 +857,12 @@ class TrackedPatternMatcher(PatternMatcher):
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
continue
|
||||
match_stats[p][1] += 1
|
||||
if (ret:=match(uop, ctx)) is not None and ret is not uop:
|
||||
try: ret = match(uop, ctx)
|
||||
except Exception:
|
||||
if TRACK_MATCH_STATS >= 2 and active_rewrites:
|
||||
active_rewrites[-1].matches.append((track_uop(uop), track_uop(UOp(Ops.NOOP, arg=str(sys.exc_info()[1]))), p.location))
|
||||
raise
|
||||
if ret is not None and ret is not uop:
|
||||
match_stats[p][0] += 1
|
||||
match_stats[p][3] += (et:=time.perf_counter()-st)
|
||||
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", printable(p.location))
|
||||
|
||||
Reference in New Issue
Block a user