diff --git a/test/test_rewrite_tracked_childen.py b/test/test_rewrite_tracked_childen.py index e9b64d82fe..bdc52a8382 100644 --- a/test/test_rewrite_tracked_childen.py +++ b/test/test_rewrite_tracked_childen.py @@ -1,6 +1,7 @@ import unittest from tinygrad import Tensor -from tinygrad.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp +from tinygrad.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp, merge_views +from tinygrad.engine.schedule import sym class TestRewriteTrackedChildren(unittest.TestCase): def test_children_in_context(self): @@ -47,5 +48,15 @@ class TestRewriteTrackedChildren(unittest.TestCase): print([x.arg for x in sink.get_children_map()[view_w_child]]) self.assertSetEqual(set([x.arg for x in sink.get_children_map()[view_w_child]]), set((3,4))) + @unittest.expectedFailure + def test_child_after_parent_update(self): + def print_children(ctx, r): + ctx.update_children() + print(ctx.children[r]) + extra = PatternMatcher([(UPat(Ops.REDUCE_AXIS, name="r"), print_children)]) + a = Tensor.empty(3, 3) + r = (a+0).sum() + graph_rewrite(r.lazydata, merge_views+sym+extra, track_children=True) + if __name__ == '__main__': unittest.main()