UOp.const(x.dtype, y) -> x.const(y) [run_process_replay] (#5642)

This commit is contained in:
chenyu
2024-07-22 17:09:40 -04:00
committed by GitHub
parent 97b116bb1d
commit 24505199fb
4 changed files with 20 additions and 21 deletions

View File

@@ -10,7 +10,7 @@ simple_pm = PatternMatcher([
(UOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
(UOp.cvar('x') + UOp.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
(UOp.cvar('x') * UOp.cvar('y') * UOp.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x + UOp.const(x.dtype, c1.arg+c2.arg)),
((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x + x.const(c1.arg+c2.arg)),
])
class TestGraphRewrite(unittest.TestCase):