skip_0 in graph rewrite [pr] (#11627)

* skip_0 in graph rewrite [pr]

* no track_rewrites on test

* use dict instead of set
This commit is contained in:
George Hotz
2025-08-11 18:29:04 -07:00
committed by GitHub
parent ca7a641442
commit ca41b5e38b
2 changed files with 60 additions and 16 deletions

View File

@@ -7,9 +7,6 @@ from tinygrad.uop.ops import PatternMatcher, UOp, graph_rewrite, Ops, UPat, Grou
@dataclass
class ChildrenContext:
children: dict[UOp, list[UOp]]|None = None
seen_children: dict[UOp, set[int]] = field(default_factory=dict)
seen_consts:int = 0
saved_seen_consts:int = 0
# this is a generic child labeller
def extract_children(ctx:ChildrenContext, x:UOp):
@@ -25,14 +22,25 @@ pm_children = PatternMatcher([
(UPat(GroupOp.All-{Ops.CHILD}, name="x"), mark_children),
])
@dataclass
class TestContext:
seen_children: dict[UOp, set[int]] = field(default_factory=dict)
ready_children: dict[UOp, set[int]] = field(default_factory=dict)
seen_consts:int = 0
saved_seen_consts:int = 0
exp2_visit_count:int = 0
# this is a generic pattern
def visit_child(ctx:ChildrenContext, x:UOp):
if x.src[0] not in ctx.seen_children: ctx.seen_children[x.src[0]] = set()
if x.src[0] not in ctx.seen_children:
ctx.seen_children[x.src[0]] = set()
ctx.ready_children[x.src[0]] = set()
ctx.seen_children[x.src[0]].add(x.arg[0])
if len(ctx.seen_children[x.src[0]]) != x.arg[1]:
print(f"visit CHILD {x.arg} bottom up -- not ready {ctx.seen_children[x.src[0]]}")
raise RewriteNotReady
print(f"visit CHILD {x.arg} bottom up -- READY {ctx.seen_children[x.src[0]]}")
ctx.ready_children[x.src[0]].add(x.arg[0])
pm_child_visitor = PatternMatcher([
(UPat(Ops.CHILD, name="x"), visit_child),
@@ -40,30 +48,63 @@ pm_child_visitor = PatternMatcher([
# this is for the test
def see_const(ctx:ChildrenContext, c:UOp): ctx.seen_consts += c.arg
def see_exp2(ctx:ChildrenContext): ctx.exp2_visit_count += 1
def save_seen_consts(ctx:ChildrenContext, x:UOp): ctx.saved_seen_consts = ctx.seen_consts
pm_consts = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), save_seen_consts),
(UPat()+UPat.cvar("c"), see_const),
(UPat(Ops.EXP2), see_exp2),
])
class TestChildrenRewrite(unittest.TestCase):
def test_not_ready_double_simple(self):
global_a = UOp.variable("a", 0, 10).exp2()
inter = (global_a+global_a).exp2()
global_sink = (inter+inter).sink()
sink = graph_rewrite(global_sink, pm_children, ctx=ChildrenContext(), bottom_up=True)
ctx = TestContext()
graph_rewrite(sink, pm_consts, ctx=ctx, bottom_up=True)
self.assertEqual(ctx.exp2_visit_count, 2)
def test_not_ready_double(self):
global_a = UOp.variable("a", 0, 10).exp2()
inter = ((global_a+1000)+(global_a+100)).exp2()
global_sink = ((inter+10)+(inter+1)).sink()
sink = graph_rewrite(global_sink, pm_children, ctx=ChildrenContext(), bottom_up=True)
print("test_not_ready_double")
ctx = TestContext()
graph_rewrite(sink, pm_child_visitor+pm_consts, ctx=ctx, bottom_up=True)
self.assertEqual(ctx.exp2_visit_count, 2)
self.assertEqual(ctx.seen_consts, ctx.saved_seen_consts)
self.assertEqual(ctx.seen_consts, 1111)
def test_in_srcs_twice(self):
global_a = UOp.variable("a", 0, 10).exp2()
global_sink = (global_a+global_a).sink()
ctx = TestContext()
graph_rewrite(global_sink, pm_consts, ctx=ctx, bottom_up=True)
self.assertEqual(ctx.exp2_visit_count, 1)
def test_not_ready(self):
a = UOp.variable("a", 0, 10).exp2()
b = a+2
c = a+3
d = b+c
sink = d.sink()
global_a = UOp.variable("a", 0, 10).exp2()
global_sink = ((global_a+2)+(global_a+3)).sink()
# without children and not ready, we don't see both adds before the DEFINE_VAR
ctx = ChildrenContext()
sink = graph_rewrite(sink, pm_consts, ctx=ctx, bottom_up=True)
ctx = TestContext()
graph_rewrite(global_sink, pm_consts, ctx=ctx, bottom_up=True)
self.assertNotEqual(ctx.seen_consts, ctx.saved_seen_consts)
self.assertEqual(ctx.exp2_visit_count, 1)
# with children and not ready we do
ctx = ChildrenContext()
sink = graph_rewrite(sink, pm_children, ctx=ctx, bottom_up=True)
sink = graph_rewrite(sink, pm_child_visitor+pm_consts, ctx=ctx, bottom_up=True)
sink = graph_rewrite(global_sink, pm_children, ctx=ChildrenContext(), bottom_up=True)
ctx = TestContext()
graph_rewrite(sink, pm_child_visitor+pm_consts, ctx=ctx, bottom_up=True)
self.assertEqual(ctx.seen_consts, ctx.saved_seen_consts)
self.assertEqual(ctx.exp2_visit_count, 1)
self.assertSetEqual(list(ctx.ready_children.values())[0], {0,1})
if __name__ == '__main__':
unittest.main()

View File

@@ -887,19 +887,22 @@ class RewriteContext:
self.bpm: PatternMatcher|None = bpm
self.ctx = ctx
self.replace: dict[UOp, UOp] = {}
self.skip_0: dict[UOp, None] = {} # NOTE: this is needed for RewriteNotReady. it also detects some infinite loops
def unified_rewrite(self, root:UOp) -> UOp:
stack: list[tuple[UOp, int, UOp]] = [(root, 0, root)]
while stack:
if len(stack) >= 200000: raise RuntimeError("infinite loop in graph_rewrite")
if len(stack) >= 200000: raise RuntimeError("infinite loop in graph_rewrite (stack too big)")
n, stage, new_n = stack.pop()
if n in self.replace: continue # skip any nodes we have seen
try:
if stage == 0:
if n in self.skip_0: continue
# if bottom up, we rewrite this node early. in both cases, we add its parents to the stack
if self.bpm is not None: new_n = self.bpm.fixed_point_rewrite(new_n, self.ctx)
stack.append((n, 1, new_n))
for x in reversed(new_n.src): stack.append((x, 0, x))
self.skip_0[n] = None
elif stage == 1:
try: new_src = tuple([self.replace[x] for x in new_n.src])
except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from
@@ -917,7 +920,7 @@ class RewriteContext:
else:
# in stage 2, we link the result of new_n to the result of n
try: self.replace[n] = self.replace[new_n]
except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from
except KeyError: raise RuntimeError("infinite loop in graph_rewrite (explicit)") # pylint: disable=raise-missing-from
except RewriteNotReady:
# retry this later
stack.insert(0, (n, stage, new_n))