UOp pattern x + x -> x * 2 (#6224)

* UOp pattern x + x -> x * 2

now there's no NEG, with this it covers all kinds of a*x+b*x

* can remove x-x
This commit is contained in:
chenyu
2024-08-21 12:06:19 -04:00
committed by GitHub
parent c9a9631818
commit a666450e4d
3 changed files with 14 additions and 5 deletions

View File

@@ -97,8 +97,13 @@ class TestSymbolic(unittest.TestCase):
def test_sub_num_1(self):
self.helper_test_variable(Variable("a", 0, 8)-NumNode(1), -1, 7, "(-1+a)")
def test_add_self(self):
a = Variable("a", 0, 8)
self.helper_test_variable(a+a, 0, 16, "(a*2)")
def test_sub_self(self):
a = Variable("a", 0, 8)
self.helper_test_variable(a-a, 0, 0, "0")
self.helper_test_variable(a*3-a, 0, 16, "(a*2)")
def test_mul_0(self):

View File

@@ -150,8 +150,13 @@ class TestSymbolic(unittest.TestCase):
def test_sub_num_1(self):
self.helper_test_variable(Variable("a", 0, 8)-NumNode(1), -1, 7, {"(-1+a)", "(a+(-1))"})
def test_add_self(self):
a = Variable("a", 0, 8)
self.helper_test_variable(a+a, 0, 16, "(a*2)")
def test_sub_self(self):
a = Variable("a", 0, 8)
self.helper_test_variable(a-a, 0, 0, "0")
self.helper_test_variable(a*3-a, 0, 16, "(a*2)")
def test_mul_0(self):

View File

@@ -269,7 +269,6 @@ constant_folder = PatternMatcher([
# ** constant folding **
(UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: root.const(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
# ** self folding **
(-(-NOp.var('x')), lambda x: x), # -(-x) -> x
(NOp.var('x') + 0, lambda x: x), # x+0 -> x
(NOp.var('x') * 1, lambda x: x), # x*1 -> x
(NOp.var('x') // NOp.var('x'), lambda x: x.const(1)), # x//x -> 1
@@ -283,8 +282,6 @@ constant_folder = PatternMatcher([
# if x is nan or inf it should render the nan value.
# NOTE: this can be wrong for loaded NaN
(NOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
# x-x -> 0
(NOp.var('x') - NOp.var('x'), lambda x: x.const(0)),
# min==max -> CONST (slow!)
(UPat({UOps.ALU, UOps.DEFINE_VAR}, name='x'), lambda x: x.const(x.vmin.arg) if x.vmin.arg == x.vmax.arg else None),
# ** load/store folding **
@@ -322,6 +319,10 @@ constant_folder = PatternMatcher([
x*c1+c0.arg*c1.arg if dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
# (x*c0)+(x*c1) -> x*(c0+c1)
(NOp.var("x") * NOp.cvar("c0") + NOp.var("x") * NOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
# (x+x*c)-> x*(c+1)
(NOp.var("x") + NOp.var("x") * NOp.cvar("c"), lambda x,c: x*(c.arg+1)),
# (x+x)-> x*2
(NOp.var("x") + NOp.var("x"), lambda x: x*2),
# (x*c0)+(y*c0) -> (x+y)*c0
#((NOp.var("x") * NOp.cvar("c0")) + (NOp.var("y") * NOp.cvar("c0")), lambda x,y,c0: c0*(x+y)),
# (x*x2)/x2 -> x
@@ -332,8 +333,6 @@ constant_folder = PatternMatcher([
((NOp.var("x") / NOp.var("x2")) / NOp.var("x3"), lambda x,x2,x3: x/(x2*x3)),
# c0 + x < c1 -> x < c1 - c0
((NOp.cvar("c0") + NOp.var("x")).lt(NOp.cvar("c1")), lambda x,c0,c1: UOp.lt(x, x.const(exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
# (x+x*c0)-> x*(c0+1)
(NOp.var("x") + NOp.var("x") * NOp.cvar("c0"), lambda x,c0: x*(c0.arg+1)),
# x!=0 -> (bool)x
(NOp.var("x").ne(0), lambda x: x.cast(dtypes.bool)),
# TODO: can do the invert of this (flip alt/load) when we fix double ops