mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
simple failing test for graph_rewrite children [pr] (#9489)
* simple failing test for graph_rewrite children [pr] * lint * update too
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp
|
||||
from tinygrad.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp, merge_views
|
||||
from tinygrad.engine.schedule import sym
|
||||
|
||||
class TestRewriteTrackedChildren(unittest.TestCase):
|
||||
def test_children_in_context(self):
|
||||
@@ -47,5 +48,15 @@ class TestRewriteTrackedChildren(unittest.TestCase):
|
||||
print([x.arg for x in sink.get_children_map()[view_w_child]])
|
||||
self.assertSetEqual(set([x.arg for x in sink.get_children_map()[view_w_child]]), set((3,4)))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_child_after_parent_update(self):
|
||||
def print_children(ctx, r):
|
||||
ctx.update_children()
|
||||
print(ctx.children[r])
|
||||
extra = PatternMatcher([(UPat(Ops.REDUCE_AXIS, name="r"), print_children)])
|
||||
a = Tensor.empty(3, 3)
|
||||
r = (a+0).sum()
|
||||
graph_rewrite(r.lazydata, merge_views+sym+extra, track_children=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user