mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
track viz context even if rewrite errors [pr] (#6976)
This commit is contained in:
@@ -81,6 +81,17 @@ class TestViz(unittest.TestCase):
|
||||
self.assertEqual(key, "uop_1")
|
||||
self.assertEqual(len(m.upats), 0)
|
||||
|
||||
def test_track_with_exception(self):
|
||||
simple = TrackedPatternMatcher([(UPat.var("x")*1, lambda x:x)])
|
||||
@track_rewrites
|
||||
def do_rewrite(key:str, x:UOp):
|
||||
x = graph_rewrite(x, simple) # NOTE: viz tracks this
|
||||
raise Exception("test")
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
|
||||
with self.assertRaises(Exception): do_rewrite("uop_0", ld*1)
|
||||
ret = self.assert_valid_ctx()
|
||||
self.assertEqual(len(ret), 1)
|
||||
|
||||
def test_dedup_ast(self):
|
||||
a = Tensor.empty(4, 4).contiguous().realize()+2
|
||||
b = Tensor.empty(4, 4).contiguous().realize()+2
|
||||
|
||||
@@ -597,8 +597,9 @@ contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = []
|
||||
def track_rewrites(func):
|
||||
def __wrapper(self, *args, **kwargs):
|
||||
if TRACK_MATCH_STATS >= 2: rewrite_stack.append((self, []))
|
||||
ret = func(self, *args, **kwargs)
|
||||
if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
|
||||
try: ret = func(self, *args, **kwargs)
|
||||
finally: # NOTE: save everything in the stack
|
||||
if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop())
|
||||
return ret
|
||||
return __wrapper
|
||||
|
||||
|
||||
Reference in New Issue
Block a user