track viz context even if rewrite errors [pr] (#6976)

This commit is contained in:
qazal
2024-10-10 18:33:15 +03:00
committed by GitHub
parent 592e5f1df2
commit 4ef5310039
2 changed files with 14 additions and 2 deletions

View File

@@ -81,6 +81,17 @@ class TestViz(unittest.TestCase):
self.assertEqual(key, "uop_1")
self.assertEqual(len(m.upats), 0)
def test_track_with_exception(self):
simple = TrackedPatternMatcher([(UPat.var("x")*1, lambda x:x)])
@track_rewrites
def do_rewrite(key:str, x:UOp):
x = graph_rewrite(x, simple) # NOTE: viz tracks this
raise Exception("test")
ld = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
with self.assertRaises(Exception): do_rewrite("uop_0", ld*1)
ret = self.assert_valid_ctx()
self.assertEqual(len(ret), 1)
def test_dedup_ast(self):
a = Tensor.empty(4, 4).contiguous().realize()+2
b = Tensor.empty(4, 4).contiguous().realize()+2

View File

@@ -597,8 +597,9 @@ contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
def track_rewrites(func):
def __wrapper(self, *args, **kwargs):
if TRACK_MATCH_STATS >= 2: rewrite_stack.append((self, []))
ret = func(self, *args, **kwargs)
if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
try: ret = func(self, *args, **kwargs)
finally: # NOTE: save everything in the stack
if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
return ret
return __wrapper