mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
canon
This commit is contained in:
@@ -1056,11 +1056,25 @@ class RewriteContext:
|
||||
ret = self.bpm_cache[x] = cast(PatternMatcher, self.bpm).rewrite(x, self.ctx)
|
||||
return ret
|
||||
|
||||
def canon(self, u: UOp) -> UOp:
|
||||
# chase replace chains with path compression
|
||||
path = []
|
||||
while True:
|
||||
v = self.replace.get(u)
|
||||
if v is None or v is u: # no redirect or self
|
||||
rep = u
|
||||
break
|
||||
path.append(u)
|
||||
u = v
|
||||
for x in path: self.replace[x] = rep
|
||||
return rep
|
||||
|
||||
def unified_rewrite(self, root:UOp) -> UOp:
|
||||
stack: collections.deque[tuple[UOp, int, UOp]] = collections.deque([(root, 0, root)])
|
||||
while stack:
|
||||
if len(stack) > getenv("REWRITE_STACK_LIMIT", 250000): raise RuntimeError("infinite loop in graph_rewrite (stack too big)")
|
||||
n, stage, new_n = stack.pop()
|
||||
#n, new_n = self.canon(n), self.canon(new_n)
|
||||
#print(len(stack), stage)
|
||||
if n in self.replace: continue # skip any nodes we have seen
|
||||
try:
|
||||
@@ -1080,10 +1094,7 @@ class RewriteContext:
|
||||
# if the bpm matching raised a gate, we are done with this node and dont continue down the srcs
|
||||
except BottomUpGate: self.replace[n] = new_n
|
||||
elif stage == 1:
|
||||
try: new_src = tuple([self.replace[x] for x in new_n.src])
|
||||
except KeyError:
|
||||
stack.append((new_n, 0, new_n))
|
||||
continue
|
||||
new_src = tuple([self.replace[x] for x in new_n.src])
|
||||
if new_src == new_n.src:
|
||||
# if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
|
||||
if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None:
|
||||
@@ -1100,7 +1111,27 @@ class RewriteContext:
|
||||
self.replace[n] = self.replace[new_n]
|
||||
except ReprocessNode as e:
|
||||
assert e.node is self.replace[e.node]
|
||||
|
||||
"""
|
||||
# invalidate node and all children
|
||||
invalid = [e.node]
|
||||
tset = [e.node]
|
||||
while len(tset):
|
||||
u: UOp = tset.pop()
|
||||
for c in u.children:
|
||||
if c in self.replace:
|
||||
tset.append(c)
|
||||
invalid.append(c)
|
||||
print(len(invalid))
|
||||
for u in invalid:
|
||||
del self.replace[u]
|
||||
"""
|
||||
del self.replace[e.node]
|
||||
for n1,_,n2 in stack:
|
||||
if n1 is e.node or n2 is e.node:
|
||||
print("ISSUE")
|
||||
#stack.clear()
|
||||
#stack.append((root, 0, root))
|
||||
stack.append((e.node, 0, e.node))
|
||||
return self.replace[root]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user