simplify c0*x<c1 for negative int c0,c1 (#6431)

* simplify c0*x<c1 for negative int c0,c1

* fine if rhs is zero
This commit is contained in:
chenyu
2024-09-09 21:05:53 -04:00
committed by GitHub
parent f6f4f3222f
commit fcc69adfc5
3 changed files with 7 additions and 2 deletions

View File

@@ -198,6 +198,7 @@ class TestSymbolic(unittest.TestCase):
def test_mul_lt(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,13), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,16), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*(-2),0), 0, 1, "((a*-1)<0)")
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, "((a*-1)<-2)")
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, "((a*-1)<-3)")

View File

@@ -254,8 +254,9 @@ class TestSymbolic(unittest.TestCase):
def test_mul_lt(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,13), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*4,16), 0, 1, "(a<4)")
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, {"((a*-1)<-2)", "((a*(-4))<(-11))"})
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, {"((a*-1)<-3)", "((a*(-4))<(-12))"})
self.helper_test_variable(create_lt_node(Variable("a", 0, 5)*(-2),0), 0, 1, {"((a*-1)<0)", "((a*(-1))<0)"})
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, {"((a*-1)<-2)", "((a*(-1))<(-2))"})
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, {"((a*-1)<-3)", "((a*(-1))<(-3))"})
def test_div_div(self):
self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")

View File

@@ -292,6 +292,9 @@ constant_folder = PatternMatcher([
# c0*x<c1 for positive int c0,c1
((NOp.cvar('c0')*NOp.var('x')).lt(NOp.cvar('c1')),
lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if dtypes.is_int(x.dtype) and c0.arg > 0 and c1.arg > 0 else None),
# c0*x<c1 for negative int c0 and non-positive c1
((NOp.cvar('c0')*NOp.var('x')).lt(NOp.cvar('c1')),
lambda x,c0,c1: (-x).lt(-(math.floor(-c1.arg/-c0.arg))) if dtypes.is_int(x.dtype) and c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
# mul add lt
(((NOp.cvar('c0')*NOp.var('x'))+NOp.var('x2')).lt(NOp.cvar('c1')),
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),