This commit is contained in:
George Hotz
2025-10-05 14:58:52 +08:00
parent c600446299
commit 18552a3040

View File

@@ -1056,11 +1056,25 @@ class RewriteContext:
ret = self.bpm_cache[x] = cast(PatternMatcher, self.bpm).rewrite(x, self.ctx)
return ret
def canon(self, u: UOp) -> UOp:
# chase replace chains with path compression
path = []
while True:
v = self.replace.get(u)
if v is None or v is u: # no redirect or self
rep = u
break
path.append(u)
u = v
for x in path: self.replace[x] = rep
return rep
def unified_rewrite(self, root:UOp) -> UOp:
stack: collections.deque[tuple[UOp, int, UOp]] = collections.deque([(root, 0, root)])
while stack:
if len(stack) > getenv("REWRITE_STACK_LIMIT", 250000): raise RuntimeError("infinite loop in graph_rewrite (stack too big)")
n, stage, new_n = stack.pop()
#n, new_n = self.canon(n), self.canon(new_n)
#print(len(stack), stage)
if n in self.replace: continue # skip any nodes we have seen
try:
@@ -1080,10 +1094,7 @@ class RewriteContext:
# if the bpm matching raised a gate, we are done with this node and dont continue down the srcs
except BottomUpGate: self.replace[n] = new_n
elif stage == 1:
try: new_src = tuple([self.replace[x] for x in new_n.src])
except KeyError:
stack.append((new_n, 0, new_n))
continue
new_src = tuple([self.replace[x] for x in new_n.src])
if new_src == new_n.src:
# if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None:
@@ -1100,7 +1111,27 @@ class RewriteContext:
self.replace[n] = self.replace[new_n]
except ReprocessNode as e:
assert e.node is self.replace[e.node]
"""
# invalidate node and all children
invalid = [e.node]
tset = [e.node]
while len(tset):
u: UOp = tset.pop()
for c in u.children:
if c in self.replace:
tset.append(c)
invalid.append(c)
print(len(invalid))
for u in invalid:
del self.replace[u]
"""
del self.replace[e.node]
for n1,_,n2 in stack:
if n1 is e.node or n2 is e.node:
print("ISSUE")
#stack.clear()
#stack.append((root, 0, root))
stack.append((e.node, 0, e.node))
return self.replace[root]