Restore fast path for matching new_src in rewrite (#6870)

This commit is contained in:
Tim Becker
2024-10-03 23:22:24 -04:00
committed by GitHub
parent 8931f20765
commit d42cb5596f

View File

@@ -607,8 +607,8 @@ class RewriteContext:
def rewrite(self, n:UOp) -> UOp:
if (rn := self.replace.get(n)) is not None: return rn
new_src = tuple(map(self.rewrite, n.src))
x = UOp(n.op, n.dtype, new_src, n.arg) if new_src != n.src else n
self.replace[n] = ret = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) is not None else x
new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg)
self.replace[n] = ret = n if new_n is None else self.rewrite(new_n)
return ret
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp: