minor improvements to rewrite (#12454)

* minor improvements to rewrite

* need that continue

* faster
George Hotz
2025-10-05 18:09:32 +08:00
committed by GitHub
parent 4b60121498
commit a976ace404
2 changed files with 47 additions and 31 deletions


@@ -10,6 +10,7 @@ class FastEnum(IntEnum):
 class Ops(FastEnum):
   # uops that aren't rendered
   NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto(); REWRITE_ERROR = auto() # noqa: E702
+  SENTINEL = auto()
   # track children
   CHILD = auto(); CHILDREN = auto() # noqa: E702

@@ -1023,6 +1023,7 @@ if TRACK_MATCH_STATS or PROFILE:
 # *** simple graph rewrite engine ***
+SENTINEL = UOp(Ops.SENTINEL)
 class RewriteNotReady(Exception): pass
 class BottomUpGate(Exception): pass
 class RewriteContext:
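RewriteNotReady and BottomUpGate are exceptions used purely for control flow: a bottom-up match can signal "requeue this node, a dependency isn't processed yet" or "accept this node and don't descend into its srcs". A rough, self-contained sketch of that shape on a generic work queue (the item type and rules are made up; only the requeue/gate mechanics mirror the loop below):

import collections

class NotReady(Exception): pass   # "requeue me, a dependency hasn't been processed yet"
class Gate(Exception): pass       # "accept this item as-is and don't visit anything under it"

def process(item, done):
  # stand-in for a bottom-up PatternMatcher: may raise NotReady or Gate
  if item == "b" and "c" not in done: raise NotReady
  if item == "stop": raise Gate
  return item.upper()

def drain(items):
  queue, done = collections.deque(items), {}
  while queue:
    it = queue.pop()
    try:
      done[it] = process(it, done)
    except NotReady:
      queue.appendleft(it)   # try the full thing again later, like stack.appendleft((n, 0, n))
    except Gate:
      done[it] = it          # done with this item, don't continue down
  return done

print(drain(["a", "b", "c", "stop"]))

In the real loop the handlers additionally need a continue so a gated or not-ready node never falls through to the code that pushes its srcs, which appears to be what the "need that continue" bullet in the commit message refers to.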
@@ -1035,45 +1036,58 @@ class RewriteContext:
     self.replace: dict[UOp, UOp] = {}
   def cached_pm_rewrite(self, x:UOp):
-    if (ret:=self.pm_cache.get(x,False)) is not False: return ret
+    if (ret:=self.pm_cache.get(x,SENTINEL)) is not SENTINEL: return ret
     ret = self.pm_cache[x] = cast(PatternMatcher, self.pm).rewrite(x, self.ctx)
     return ret
   def cached_bpm_rewrite(self, x:UOp):
-    if (ret:=self.bpm_cache.get(x,False)) is not False: return ret
+    if (ret:=self.bpm_cache.get(x,SENTINEL)) is not SENTINEL: return ret
     ret = self.bpm_cache[x] = cast(PatternMatcher, self.bpm).rewrite(x, self.ctx)
     return ret
   def unified_rewrite(self, root:UOp) -> UOp:
     stack: collections.deque[tuple[UOp, int, UOp]] = collections.deque([(root, 0, root)])
     on_stack = {root} # all UOps either on the stack or in self.replace, i.e. dont have to be placed again
+    REWRITE_STACK_LIMIT = getenv("REWRITE_STACK_LIMIT", 250000)
     while stack:
-      if len(stack) > getenv("REWRITE_STACK_LIMIT", 250000): raise RuntimeError("infinite loop in graph_rewrite (stack too big)")
+      if len(stack) > REWRITE_STACK_LIMIT: raise RuntimeError("infinite loop in graph_rewrite (stack too big)")
       n, stage, new_n = stack.pop()
       if n in self.replace: continue # skip any nodes we have seen
-      try:
-        if stage == 0:
+      if stage == 0:
+        # if bottom up, we rewrite this node early. in both cases, we add its parents to the stack
+        if self.bpm is not None:
+          # apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match
+          test_n: UOp|None = n
+          seen = set()
           try:
-            # if bottom up, we rewrite this node early. in both cases, we add its parents to the stack
-            if self.bpm is not None:
-              # apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match
-              test_n: UOp|None = n
-              seen = set()
-              while test_n is not None:
-                if test_n in seen: raise RuntimeError("infinite loop in fixed_point_rewrite")
-                seen.add(test_n)
-                new_n, test_n = test_n, self.cached_bpm_rewrite(test_n)
-            stack.append((n, 1, new_n))
-            for x in reversed(new_n.src):
-              if x in on_stack: continue
-              stack.append((x, 0, x))
-              on_stack.add(x)
-          # if the bpm matching raised a gate, we are done with this node and dont continue down the srcs
-          except BottomUpGate: self.replace[n] = new_n
-        elif stage == 1:
-          try: new_src = tuple([self.replace[x] for x in new_n.src])
-          except KeyError: raise RewriteNotReady
-          if new_src == new_n.src:
+            while test_n is not None:
+              if test_n in seen: raise RuntimeError("infinite loop in fixed_point_rewrite")
+              seen.add(test_n)
+              new_n, test_n = test_n, self.cached_bpm_rewrite(test_n)
+          except RewriteNotReady:
+            # try the full thing again later
+            stack.appendleft((n, 0, n))
+            continue
+          except BottomUpGate:
+            # if the bpm matching raised a gate, we are done with this node and dont continue down the srcs
+            self.replace[n] = new_n
+            continue
+        stack.append((n, 1, new_n))
+        for x in reversed(new_n.src):
+          if x in on_stack: continue
+          stack.append((x, 0, x))
+          on_stack.add(x)
+      elif stage == 1:
+        tmp = []
+        for x in new_n.src:
+          if (rx:=self.replace.get(x, SENTINEL)) is SENTINEL:
+            # if some new sources aren't ready, we try this again later
+            stack.appendleft((n, 1, new_n))
+            break
+          tmp.append(rx)
+        else:
+          # in stage 1, once all srcs are rewritten, rebuild (if changed) or run top-down rewrite
+          if (new_src:=tuple(tmp)) == new_n.src:
             # if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
             if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None:
               self.replace[n] = new_n
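Two details in the rewritten stage logic are easy to miss: the getenv call is hoisted out of the hot loop into REWRITE_STACK_LIMIT (likely the "faster" bullet), and the old list comprehension guarded by KeyError/RewriteNotReady becomes a for/else, where the else branch only runs if the loop never hit break, i.e. only when every src already had a replacement. A small, generic demo of that for/else pattern (not tinygrad code):

_MISS = object()

def gather(replace: dict, srcs: tuple):
  # mirror of the stage-1 structure: collect rewritten srcs, bail out early on the first miss
  tmp = []
  for x in srcs:
    if (rx := replace.get(x, _MISS)) is _MISS:
      print("not ready:", x)   # the real loop does stack.appendleft((n, 1, new_n)) here
      break
    tmp.append(rx)
  else:
    # only reached when the for loop finished without break, i.e. all srcs were found
    return tuple(tmp)
  return None

assert gather({"a": 1, "b": 2}, ("a", "b")) == (1, 2)
assert gather({"a": 1}, ("a", "b")) is None   # "b" missing -> break -> else skipped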
@@ -1084,13 +1098,14 @@ class RewriteContext:
           # trigger a rewrite of new_src_n, then after that rewrite is done, link it back to n
           stack.append((n, 2, new_src_n))
           stack.append((new_src_n, 0, new_src_n))
+      else:
+        # in stage 2, we link the result of new_n to the result of n
+        if (replaced_new_n:=self.replace.get(new_n, SENTINEL)) is SENTINEL:
+          # not ready, try the link later
+          stack.appendleft((n, 2, new_n))
         else:
-          # in stage 2, we link the result of new_n to the result of n
-          try: self.replace[n] = self.replace[new_n]
-          except KeyError: raise RewriteNotReady
-      except RewriteNotReady:
-        # retry this later
-        stack.appendleft((n, stage, new_n))
+          # otherwise we are done
+          self.replace[n] = replaced_new_n
     return self.replace[root]
 @track_matches
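Stepping back, unified_rewrite is an explicit-stack post-order rewriter: stage 0 schedules a node and pushes its srcs, stage 1 rebuilds the node once every src has a result (optionally running the top-down matcher on the rebuilt node), and stage 2 links the result of that follow-up rewrite back to the original node. A stripped-down, hypothetical version of just that three-stage control flow, with no pattern matchers, caches, gates, or retry logic, might look like:

from dataclasses import dataclass
import collections

@dataclass(frozen=True)
class Node:
  op: str
  src: tuple = ()

def rewrite_once(n: Node) -> Node|None:
  # stand-in for a top-down PatternMatcher: fold ("add", const, const) into a const
  if n.op == "add" and all(s.op == "const" for s in n.src):
    return Node("const", ())
  return None

def unified_rewrite(root: Node) -> Node:
  stack = collections.deque([(root, 0, root)])
  replace: dict[Node, Node] = {}
  while stack:
    n, stage, new_n = stack.pop()
    if n in replace: continue
    if stage == 0:
      # stage 0: schedule the rebuild of n after all of its srcs
      stack.append((n, 1, new_n))
      for x in reversed(new_n.src): stack.append((x, 0, x))
    elif stage == 1:
      # stage 1: all srcs are done, rebuild n and maybe rewrite the result
      new_src = tuple(replace[x] for x in new_n.src)
      rebuilt = new_n if new_src == new_n.src else Node(new_n.op, new_src)
      if (rewritten := rewrite_once(rebuilt)) is None:
        replace[n] = rebuilt
      else:
        # stage 2: rewrite the replacement, then link its result back to n
        stack.append((n, 2, rewritten))
        stack.append((rewritten, 0, rewritten))
    else:
      replace[n] = replace[new_n]
  return replace[root]

tree = Node("mul", (Node("add", (Node("const"), Node("const"))), Node("x")))
print(unified_rewrite(tree))   # the add of two consts folds to a const

The real implementation adds on_stack bookkeeping, bottom-up matching in stage 0, and the sentinel-based retries shown above; the stack discipline and the three stages are what this diff is refining.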