mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
skip_0 in graph rewrite [pr] (#11627)
* skip_0 in graph rewrite [pr]
* no track_rewrites on test
* use dict instead of set
This commit is contained in:
@@ -7,9 +7,6 @@ from tinygrad.uop.ops import PatternMatcher, UOp, graph_rewrite, Ops, UPat, Grou
|
||||
@dataclass
|
||||
class ChildrenContext:
|
||||
children: dict[UOp, list[UOp]]|None = None
|
||||
seen_children: dict[UOp, set[int]] = field(default_factory=dict)
|
||||
seen_consts:int = 0
|
||||
saved_seen_consts:int = 0
|
||||
|
||||
# this is a generic child labeller
|
||||
def extract_children(ctx:ChildrenContext, x:UOp):
|
||||
@@ -25,14 +22,25 @@ pm_children = PatternMatcher([
|
||||
(UPat(GroupOp.All-{Ops.CHILD}, name="x"), mark_children),
|
||||
])
|
||||
|
||||
# Mutable state shared by the rewrite callbacks used in these tests.
#   seen_children:     per CHILD source (x.src[0]), the set of x.arg[0] values visit_child has seen so far
#   ready_children:    per CHILD source, the x.arg[0] values visit_child accepted once the seen set reached x.arg[1] entries
#   seen_consts:       running sum of c.arg accumulated by see_const
#   saved_seen_consts: copy of seen_consts taken by save_seen_consts
#   exp2_visit_count:  number of times see_exp2 has been called
@dataclass
|
||||
class TestContext:
|
||||
seen_children: dict[UOp, set[int]] = field(default_factory=dict)
|
||||
ready_children: dict[UOp, set[int]] = field(default_factory=dict)
|
||||
seen_consts:int = 0
|
||||
saved_seen_consts:int = 0
|
||||
exp2_visit_count:int = 0
|
||||
|
||||
# this is a generic pattern
|
||||
def visit_child(ctx:ChildrenContext, x:UOp):
|
||||
if x.src[0] not in ctx.seen_children: ctx.seen_children[x.src[0]] = set()
|
||||
if x.src[0] not in ctx.seen_children:
|
||||
ctx.seen_children[x.src[0]] = set()
|
||||
ctx.ready_children[x.src[0]] = set()
|
||||
ctx.seen_children[x.src[0]].add(x.arg[0])
|
||||
if len(ctx.seen_children[x.src[0]]) != x.arg[1]:
|
||||
print(f"visit CHILD {x.arg} bottom up -- not ready {ctx.seen_children[x.src[0]]}")
|
||||
raise RewriteNotReady
|
||||
print(f"visit CHILD {x.arg} bottom up -- READY {ctx.seen_children[x.src[0]]}")
|
||||
ctx.ready_children[x.src[0]].add(x.arg[0])
|
||||
|
||||
pm_child_visitor = PatternMatcher([
|
||||
(UPat(Ops.CHILD, name="x"), visit_child),
|
||||
@@ -40,30 +48,63 @@ pm_child_visitor = PatternMatcher([
|
||||
|
||||
# this is for the test
|
||||
def see_const(ctx:ChildrenContext, c:UOp):
  # test callback: fold the matched constant's value (c.arg) into the running total kept on the context
  value = c.arg
  ctx.seen_consts += value
|
||||
def see_exp2(ctx:TestContext):
  # test callback: count each time the EXP2 pattern in pm_consts fires.
  # annotated as TestContext because exp2_visit_count is declared on that dataclass, not on ChildrenContext
  ctx.exp2_visit_count += 1
|
||||
def save_seen_consts(ctx:ChildrenContext, x:UOp):
  # test callback: snapshot the running constant total at the moment this pattern fires.
  # x (the matched UOp) is accepted but not read here
  snapshot = ctx.seen_consts
  ctx.saved_seen_consts = snapshot
|
||||
# Pattern table used by the tests to observe visit order during graph_rewrite:
#   - a DEFINE_VAR UOp calls save_seen_consts (snapshots the constant total seen so far)
#   - UPat()+UPat.cvar("c") calls see_const with the UOp bound to "c"
#     (presumably an add whose right operand is a constant -- confirm against UPat.cvar)
#   - an EXP2 UOp calls see_exp2 (bumps the visit counter)
pm_consts = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), save_seen_consts),
|
||||
(UPat()+UPat.cvar("c"), see_const),
|
||||
(UPat(Ops.EXP2), see_exp2),
|
||||
])
|
||||
|
||||
class TestChildrenRewrite(unittest.TestCase):
|
||||
def test_not_ready_double_simple(self):
|
||||
global_a = UOp.variable("a", 0, 10).exp2()
|
||||
inter = (global_a+global_a).exp2()
|
||||
global_sink = (inter+inter).sink()
|
||||
|
||||
sink = graph_rewrite(global_sink, pm_children, ctx=ChildrenContext(), bottom_up=True)
|
||||
ctx = TestContext()
|
||||
graph_rewrite(sink, pm_consts, ctx=ctx, bottom_up=True)
|
||||
self.assertEqual(ctx.exp2_visit_count, 2)
|
||||
|
||||
def test_not_ready_double(self):
|
||||
global_a = UOp.variable("a", 0, 10).exp2()
|
||||
inter = ((global_a+1000)+(global_a+100)).exp2()
|
||||
global_sink = ((inter+10)+(inter+1)).sink()
|
||||
|
||||
sink = graph_rewrite(global_sink, pm_children, ctx=ChildrenContext(), bottom_up=True)
|
||||
print("test_not_ready_double")
|
||||
ctx = TestContext()
|
||||
graph_rewrite(sink, pm_child_visitor+pm_consts, ctx=ctx, bottom_up=True)
|
||||
self.assertEqual(ctx.exp2_visit_count, 2)
|
||||
self.assertEqual(ctx.seen_consts, ctx.saved_seen_consts)
|
||||
self.assertEqual(ctx.seen_consts, 1111)
|
||||
|
||||
def test_in_srcs_twice(self):
|
||||
global_a = UOp.variable("a", 0, 10).exp2()
|
||||
global_sink = (global_a+global_a).sink()
|
||||
|
||||
ctx = TestContext()
|
||||
graph_rewrite(global_sink, pm_consts, ctx=ctx, bottom_up=True)
|
||||
self.assertEqual(ctx.exp2_visit_count, 1)
|
||||
|
||||
def test_not_ready(self):
|
||||
a = UOp.variable("a", 0, 10).exp2()
|
||||
b = a+2
|
||||
c = a+3
|
||||
d = b+c
|
||||
sink = d.sink()
|
||||
global_a = UOp.variable("a", 0, 10).exp2()
|
||||
global_sink = ((global_a+2)+(global_a+3)).sink()
|
||||
|
||||
# without children and not ready, we don't see both adds before the DEFINE_VAR
|
||||
ctx = ChildrenContext()
|
||||
sink = graph_rewrite(sink, pm_consts, ctx=ctx, bottom_up=True)
|
||||
ctx = TestContext()
|
||||
graph_rewrite(global_sink, pm_consts, ctx=ctx, bottom_up=True)
|
||||
self.assertNotEqual(ctx.seen_consts, ctx.saved_seen_consts)
|
||||
self.assertEqual(ctx.exp2_visit_count, 1)
|
||||
|
||||
# with children and not ready we do
|
||||
ctx = ChildrenContext()
|
||||
sink = graph_rewrite(sink, pm_children, ctx=ctx, bottom_up=True)
|
||||
sink = graph_rewrite(sink, pm_child_visitor+pm_consts, ctx=ctx, bottom_up=True)
|
||||
sink = graph_rewrite(global_sink, pm_children, ctx=ChildrenContext(), bottom_up=True)
|
||||
ctx = TestContext()
|
||||
graph_rewrite(sink, pm_child_visitor+pm_consts, ctx=ctx, bottom_up=True)
|
||||
self.assertEqual(ctx.seen_consts, ctx.saved_seen_consts)
|
||||
self.assertEqual(ctx.exp2_visit_count, 1)
|
||||
self.assertSetEqual(list(ctx.ready_children.values())[0], {0,1})
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -887,19 +887,22 @@ class RewriteContext:
|
||||
self.bpm: PatternMatcher|None = bpm
|
||||
self.ctx = ctx
|
||||
self.replace: dict[UOp, UOp] = {}
|
||||
self.skip_0: dict[UOp, None] = {} # NOTE: this is needed for RewriteNotReady. it also detects some infinite loops
|
||||
|
||||
def unified_rewrite(self, root:UOp) -> UOp:
|
||||
stack: list[tuple[UOp, int, UOp]] = [(root, 0, root)]
|
||||
while stack:
|
||||
if len(stack) >= 200000: raise RuntimeError("infinite loop in graph_rewrite")
|
||||
if len(stack) >= 200000: raise RuntimeError("infinite loop in graph_rewrite (stack too big)")
|
||||
n, stage, new_n = stack.pop()
|
||||
if n in self.replace: continue # skip any nodes we have seen
|
||||
try:
|
||||
if stage == 0:
|
||||
if n in self.skip_0: continue
|
||||
# if bottom up, we rewrite this node early. in both cases, we add its parents to the stack
|
||||
if self.bpm is not None: new_n = self.bpm.fixed_point_rewrite(new_n, self.ctx)
|
||||
stack.append((n, 1, new_n))
|
||||
for x in reversed(new_n.src): stack.append((x, 0, x))
|
||||
self.skip_0[n] = None
|
||||
elif stage == 1:
|
||||
try: new_src = tuple([self.replace[x] for x in new_n.src])
|
||||
except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from
|
||||
@@ -917,7 +920,7 @@ class RewriteContext:
|
||||
else:
|
||||
# in stage 2, we link the result of new_n to the result of n
|
||||
try: self.replace[n] = self.replace[new_n]
|
||||
except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from
|
||||
except KeyError: raise RuntimeError("infinite loop in graph_rewrite (explicit)") # pylint: disable=raise-missing-from
|
||||
except RewriteNotReady:
|
||||
# retry this later
|
||||
stack.insert(0, (n, stage, new_n))
|
||||
|
||||
Reference in New Issue
Block a user