Improved mod folding (#7887)

* Remove uneccessary if statement

In all paths where something_changed was set to True, remainder is
appended so the list can't be empty

* Working version of improved mod folding

* Fix offset calculation

Passing fuzz_symbolic.py to 130_000 so far
Added an extra test

* Cleaner offset calculation
This commit is contained in:
Sieds Lykles
2024-11-25 04:21:34 +01:00
committed by GitHub
parent 5d92efb121
commit a49a7c4784
2 changed files with 38 additions and 9 deletions

View File

@@ -192,6 +192,22 @@ class TestSymbolic(unittest.TestCase):
def test_mod_to_sub(self):
self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, "(a+-1)")
def test_mod_congruence(self):
self.helper_test_variable((3+3*Variable("a",0,3))%4, 0, 3, "((a*-1)+3)")
self.helper_test_variable((17+13*Variable("a",0,3))%18, 2, 17, "((a*-5)+17)")
def test_mod_congruence_mul_add(self):
self.helper_test_variable((6*(Variable("a", 0, 2)+1))%9, 0, 6, "((a*-3)+6)")
def test_mod_congruence_multiple_vars(self):
self.helper_test_variable((9+9*Variable("x",0,3)+9*Variable("y",0,3))%10, 3, 9, "(((x*-1)+(y*-1))+9)")
self.helper_test_variable((7+9*Variable("x",0,2)+9*Variable("y",0,2)+Variable("z",0,2))%10, 3, 9, "(((z+(x*-1))+(y*-1))+7)")
self.helper_test_variable((10+12*Variable("x",0,2)+Variable("y", 0, 4)%3)%13, 8, 12, "(((x*-1)+(y%3))+10)")
def test_mod_binary_expression(self):
self.helper_test_variable((3+Variable("a",0,1))%4, 0, 3, "((a*-3)+3)")
self.helper_test_variable((3+Variable("a",4,5))%4, 0, 3, "((a*-3)+15)")
def test_sum_div_const(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a")

View File

@@ -885,19 +885,32 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
# simple cancel mod case
if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c
remainder, something_changed = [], False
terms, rem_const, something_changed, offset = [], 0, False, 0
for u in split_uop(x, Ops.ADD):
if (factor:=u.const_factor())%c != factor:
divides = u.divides(factor)*(factor%c)
assert divides is not None
remainder.append(divides)
something_changed = True
factor = u.const_factor()
e: UOp = u.divides(factor)
if (new_factor:=factor%c) != factor: something_changed = True
elif u.op is Ops.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0:
remainder.append(u.src[0])
e = u.src[0]
something_changed = True
else: remainder.append(u)
offset += new_factor * e.vmin
if u.op is Ops.CONST: rem_const += new_factor
else: terms.append((new_factor, e))
match terms: # cases like (x[4-5] + 3) % 4 -> -3*x[4-5]+15
case [(f, e)] if e.vmax-e.vmin == 1: return ((offset+f)%c - offset%c)*(e - e.vmin) + offset%c
# cases like (3+3x[0-3])%4 -> 3-x[0-3]
lbound = ubound = offset = offset % c
for (f, e) in terms:
if f > c//2:
if (lbound := lbound + (f-c)*(e.vmax-e.vmin)) < 0: break
elif (ubound := ubound + f*(e.vmax-e.vmin)) >= c: break
else: # we have found factors such that vmin/vmax of the final expression is between 0 and c, we can remove the mod
return functools.reduce(lambda r, t: r + min(t[0], t[0]-c, key=abs)*(t[1]-t[1].vmin), terms, x.const_like(offset))
if not something_changed: return None
return functools.reduce(operator.add, remainder)%c if remainder else x.const_like(0)
return functools.reduce(lambda r, t: r + t[0]*t[1], terms, x.const_like(rem_const)) % c
def div_folding(x:UOp, c:int) -> Optional[UOp]:
# simplify x // c, None means no change