More conditions for (x//c1+a)//c2 -> (x+a*c1)//(c1*c2) (#10834)

* add rule and test

* typo

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Sieds Lykles
2025-06-16 22:34:52 +02:00
committed by GitHub
parent 18d936f981
commit b1fefb76dd
2 changed files with 5 additions and 1 deletions

View File

@@ -271,6 +271,10 @@ class TestSymbolic(unittest.TestCase):
a = Variable("a", 0, 124)
self.helper_test_variable(((a-10)//2+10)//2, 2, 33, "((((a+-10)//2)+10)//2)")
def test_div_const_div_wrong_sign_divisor(self):
a = Variable("a", 0, 124)
self.helper_test_variable(((a+10)//-2+10)//-4, -1, 14, "(((((a//2)*-1)+5)//4)*-1)")
def test_neg_mod(self):
a = Variable("a", 0, 124)
self.helper_test_variable((-a)%4, -3, 0, "((a%4)*-1)")

View File

@@ -283,7 +283,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
# ** div **
# div folding
((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)
if (x.vmin>=0 and a.vmin>=0) or (x.vmax<=0 and a.vmax<=0) else None), # (x//c+a)//d -> (x+a*c)//(c*d)
if c.vmin>0 and d.vmin>0 and ((x.vmin>=0 and a.vmin>=0) or (x.vmax<=0 and a.vmax<=0)) else None), # (x//c+a)//d -> (x+a*c)//(c*d)
(UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
(UPat.var("x") // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax <=0 else None),
(UPat.var("x") // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <=0 else None),