mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
minor improvements to rewrite (#12454)
* minor improvements to rewrite * need that continue * faster
This commit is contained in:
@@ -10,6 +10,7 @@ class FastEnum(IntEnum):
|
||||
class Ops(FastEnum):
|
||||
# uops that aren't rendered
|
||||
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto(); REWRITE_ERROR = auto() # noqa: E702
|
||||
SENTINEL = auto()
|
||||
|
||||
# track children
|
||||
CHILD = auto(); CHILDREN = auto() # noqa: E702
|
||||
|
||||
@@ -1023,6 +1023,7 @@ if TRACK_MATCH_STATS or PROFILE:
|
||||
|
||||
# *** simple graph rewrite engine ***
|
||||
|
||||
SENTINEL = UOp(Ops.SENTINEL)
|
||||
class RewriteNotReady(Exception): pass
|
||||
class BottomUpGate(Exception): pass
|
||||
class RewriteContext:
|
||||
@@ -1035,45 +1036,58 @@ class RewriteContext:
|
||||
self.replace: dict[UOp, UOp] = {}
|
||||
|
||||
def cached_pm_rewrite(self, x:UOp):
|
||||
if (ret:=self.pm_cache.get(x,False)) is not False: return ret
|
||||
if (ret:=self.pm_cache.get(x,SENTINEL)) is not SENTINEL: return ret
|
||||
ret = self.pm_cache[x] = cast(PatternMatcher, self.pm).rewrite(x, self.ctx)
|
||||
return ret
|
||||
|
||||
def cached_bpm_rewrite(self, x:UOp):
|
||||
if (ret:=self.bpm_cache.get(x,False)) is not False: return ret
|
||||
if (ret:=self.bpm_cache.get(x,SENTINEL)) is not SENTINEL: return ret
|
||||
ret = self.bpm_cache[x] = cast(PatternMatcher, self.bpm).rewrite(x, self.ctx)
|
||||
return ret
|
||||
|
||||
def unified_rewrite(self, root:UOp) -> UOp:
|
||||
stack: collections.deque[tuple[UOp, int, UOp]] = collections.deque([(root, 0, root)])
|
||||
on_stack = {root} # all UOps either on the stack or in self.replace, i.e. dont have to be placed again
|
||||
REWRITE_STACK_LIMIT = getenv("REWRITE_STACK_LIMIT", 250000)
|
||||
while stack:
|
||||
if len(stack) > getenv("REWRITE_STACK_LIMIT", 250000): raise RuntimeError("infinite loop in graph_rewrite (stack too big)")
|
||||
if len(stack) > REWRITE_STACK_LIMIT: raise RuntimeError("infinite loop in graph_rewrite (stack too big)")
|
||||
n, stage, new_n = stack.pop()
|
||||
if n in self.replace: continue # skip any nodes we have seen
|
||||
try:
|
||||
if stage == 0:
|
||||
if stage == 0:
|
||||
# if bottom up, we rewrite this node early. in both cases, we add its parents to the stack
|
||||
if self.bpm is not None:
|
||||
# apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match
|
||||
test_n: UOp|None = n
|
||||
seen = set()
|
||||
try:
|
||||
# if bottom up, we rewrite this node early. in both cases, we add its parents to the stack
|
||||
if self.bpm is not None:
|
||||
# apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match
|
||||
test_n: UOp|None = n
|
||||
seen = set()
|
||||
while test_n is not None:
|
||||
if test_n in seen: raise RuntimeError("infinite loop in fixed_point_rewrite")
|
||||
seen.add(test_n)
|
||||
new_n, test_n = test_n, self.cached_bpm_rewrite(test_n)
|
||||
stack.append((n, 1, new_n))
|
||||
for x in reversed(new_n.src):
|
||||
if x in on_stack: continue
|
||||
stack.append((x, 0, x))
|
||||
on_stack.add(x)
|
||||
# if the bpm matching raised a gate, we are done with this node and dont continue down the srcs
|
||||
except BottomUpGate: self.replace[n] = new_n
|
||||
elif stage == 1:
|
||||
try: new_src = tuple([self.replace[x] for x in new_n.src])
|
||||
except KeyError: raise RewriteNotReady
|
||||
if new_src == new_n.src:
|
||||
while test_n is not None:
|
||||
if test_n in seen: raise RuntimeError("infinite loop in fixed_point_rewrite")
|
||||
seen.add(test_n)
|
||||
new_n, test_n = test_n, self.cached_bpm_rewrite(test_n)
|
||||
except RewriteNotReady:
|
||||
# try the full thing again later
|
||||
stack.appendleft((n, 0, n))
|
||||
continue
|
||||
except BottomUpGate:
|
||||
# if the bpm matching raised a gate, we are done with this node and dont continue down the srcs
|
||||
self.replace[n] = new_n
|
||||
continue
|
||||
stack.append((n, 1, new_n))
|
||||
for x in reversed(new_n.src):
|
||||
if x in on_stack: continue
|
||||
stack.append((x, 0, x))
|
||||
on_stack.add(x)
|
||||
elif stage == 1:
|
||||
tmp = []
|
||||
for x in new_n.src:
|
||||
if (rx:=self.replace.get(x, SENTINEL)) is SENTINEL:
|
||||
# if some new sources aren't ready, we try this again later
|
||||
stack.appendleft((n, 1, new_n))
|
||||
break
|
||||
tmp.append(rx)
|
||||
else:
|
||||
# in stage 1, once all srcs are rewritten, rebuild (if changed) or run top-down rewrite
|
||||
if (new_src:=tuple(tmp)) == new_n.src:
|
||||
# if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
|
||||
if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None:
|
||||
self.replace[n] = new_n
|
||||
@@ -1084,13 +1098,14 @@ class RewriteContext:
|
||||
# trigger a rewrite of new_src_n, then after that rewrite is done, link it back to n
|
||||
stack.append((n, 2, new_src_n))
|
||||
stack.append((new_src_n, 0, new_src_n))
|
||||
else:
|
||||
# in stage 2, we link the result of new_n to the result of n
|
||||
if (replaced_new_n:=self.replace.get(new_n, SENTINEL)) is SENTINEL:
|
||||
# not ready, try the link later
|
||||
stack.appendleft((n, 2, new_n))
|
||||
else:
|
||||
# in stage 2, we link the result of new_n to the result of n
|
||||
try: self.replace[n] = self.replace[new_n]
|
||||
except KeyError: raise RewriteNotReady
|
||||
except RewriteNotReady:
|
||||
# retry this later
|
||||
stack.appendleft((n, stage, new_n))
|
||||
# otherwise we are done
|
||||
self.replace[n] = replaced_new_n
|
||||
return self.replace[root]
|
||||
|
||||
@track_matches
|
||||
|
||||
Reference in New Issue
Block a user