From 4ef531003972dfebac78166394deb5aa053807e8 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 10 Oct 2024 18:33:15 +0300 Subject: [PATCH] track viz context even if rewrite errors [pr] (#6976) --- test/test_viz.py | 11 +++++++++++ tinygrad/ops.py | 5 +++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/test/test_viz.py b/test/test_viz.py index e3592b9897..2a403df055 100644 --- a/test/test_viz.py +++ b/test/test_viz.py @@ -81,6 +81,17 @@ class TestViz(unittest.TestCase): self.assertEqual(key, "uop_1") self.assertEqual(len(m.upats), 0) + def test_track_with_exception(self): + simple = TrackedPatternMatcher([(UPat.var("x")*1, lambda x:x)]) + @track_rewrites + def do_rewrite(key:str, x:UOp): + x = graph_rewrite(x, simple) # NOTE: viz tracks this + raise Exception("test") + ld = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0))) + with self.assertRaises(Exception): do_rewrite("uop_0", ld*1) + ret = self.assert_valid_ctx() + self.assertEqual(len(ret), 1) + def test_dedup_ast(self): a = Tensor.empty(4, 4).contiguous().realize()+2 b = Tensor.empty(4, 4).contiguous().realize()+2 diff --git a/tinygrad/ops.py b/tinygrad/ops.py index f3bed7d1ab..35ba0ab102 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -597,8 +597,9 @@ contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = [] def track_rewrites(func): def __wrapper(self, *args, **kwargs): if TRACK_MATCH_STATS >= 2: rewrite_stack.append((self, [])) - ret = func(self, *args, **kwargs) - if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop()) + try: ret = func(self, *args, **kwargs) + finally: # NOTE: save everything in the stack + if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop()) return ret return __wrapper