use more UPat.var and UPat.cvar [run_process_replay] (#6491)

This commit is contained in:
qazal
2024-09-12 13:52:41 +08:00
committed by GitHub
parent e5e14fc4ef
commit 4dc9436d63

View File

@@ -361,8 +361,8 @@ constant_folder = PatternMatcher([
(UPat(UOps.SINK, name="root"),
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None),
# ** move add consts to end (NOTE: this is still happening before constant folding) **
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat(UOps.CONST, name='c1'), UPat(name='x'))), lambda c1,x: x+c1 if x.op is not UOps.CONST else None),
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat(name='x'), UPat(UOps.CONST, name='c1'))), UPat(name='y')]),
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.cvar('c1'), UPat.var('x'))), lambda c1,x: x+c1 if x.op is not UOps.CONST else None),
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.var('x'), UPat.cvar('c1'))), UPat.var('y')]),
lambda x,c1,y: (x+y)+c1),
])