simple failing test for graph_rewrite children [pr] (#9489)

* simple failing test for graph_rewrite children [pr]

* lint

* update too
This commit is contained in:
qazal
2025-03-18 13:07:21 +08:00
committed by GitHub
parent d20494e6d7
commit 935cd01f56

View File

@@ -1,6 +1,7 @@
import unittest
from tinygrad import Tensor
from tinygrad.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp
from tinygrad.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp, merge_views
from tinygrad.engine.schedule import sym
class TestRewriteTrackedChildren(unittest.TestCase):
def test_children_in_context(self):
@@ -47,5 +48,15 @@ class TestRewriteTrackedChildren(unittest.TestCase):
print([x.arg for x in sink.get_children_map()[view_w_child]])
self.assertSetEqual(set([x.arg for x in sink.get_children_map()[view_w_child]]), set((3,4)))
@unittest.expectedFailure
def test_child_after_parent_update(self):
def print_children(ctx, r):
ctx.update_children()
print(ctx.children[r])
extra = PatternMatcher([(UPat(Ops.REDUCE_AXIS, name="r"), print_children)])
a = Tensor.empty(3, 3)
r = (a+0).sum()
graph_rewrite(r.lazydata, merge_views+sym+extra, track_children=True)
if __name__ == '__main__':
unittest.main()