mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Improved mod folding (#7887)
* Remove uneccessary if statement In all paths where something_changed was set to True, remainder is appended so the list can't be empty * Working version of improved mod folding * Fix offset calculation Passing fuzz_symbolic.py to 130_000 so far Added an extra test * Cleaner offset calculation
This commit is contained in:
@@ -192,6 +192,22 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_mod_to_sub(self):
|
||||
self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, "(a+-1)")
|
||||
|
||||
def test_mod_congruence(self):
|
||||
self.helper_test_variable((3+3*Variable("a",0,3))%4, 0, 3, "((a*-1)+3)")
|
||||
self.helper_test_variable((17+13*Variable("a",0,3))%18, 2, 17, "((a*-5)+17)")
|
||||
|
||||
def test_mod_congruence_mul_add(self):
|
||||
self.helper_test_variable((6*(Variable("a", 0, 2)+1))%9, 0, 6, "((a*-3)+6)")
|
||||
|
||||
def test_mod_congruence_multiple_vars(self):
|
||||
self.helper_test_variable((9+9*Variable("x",0,3)+9*Variable("y",0,3))%10, 3, 9, "(((x*-1)+(y*-1))+9)")
|
||||
self.helper_test_variable((7+9*Variable("x",0,2)+9*Variable("y",0,2)+Variable("z",0,2))%10, 3, 9, "(((z+(x*-1))+(y*-1))+7)")
|
||||
self.helper_test_variable((10+12*Variable("x",0,2)+Variable("y", 0, 4)%3)%13, 8, 12, "(((x*-1)+(y%3))+10)")
|
||||
|
||||
def test_mod_binary_expression(self):
|
||||
self.helper_test_variable((3+Variable("a",0,1))%4, 0, 3, "((a*-3)+3)")
|
||||
self.helper_test_variable((3+Variable("a",4,5))%4, 0, 3, "((a*-3)+15)")
|
||||
|
||||
def test_sum_div_const(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a")
|
||||
|
||||
|
||||
@@ -885,19 +885,32 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
# simple cancel mod case
|
||||
if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c
|
||||
|
||||
remainder, something_changed = [], False
|
||||
terms, rem_const, something_changed, offset = [], 0, False, 0
|
||||
for u in split_uop(x, Ops.ADD):
|
||||
if (factor:=u.const_factor())%c != factor:
|
||||
divides = u.divides(factor)*(factor%c)
|
||||
assert divides is not None
|
||||
remainder.append(divides)
|
||||
something_changed = True
|
||||
factor = u.const_factor()
|
||||
e: UOp = u.divides(factor)
|
||||
if (new_factor:=factor%c) != factor: something_changed = True
|
||||
elif u.op is Ops.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0:
|
||||
remainder.append(u.src[0])
|
||||
e = u.src[0]
|
||||
something_changed = True
|
||||
else: remainder.append(u)
|
||||
offset += new_factor * e.vmin
|
||||
if u.op is Ops.CONST: rem_const += new_factor
|
||||
else: terms.append((new_factor, e))
|
||||
|
||||
match terms: # cases like (x[4-5] + 3) % 4 -> -3*x[4-5]+15
|
||||
case [(f, e)] if e.vmax-e.vmin == 1: return ((offset+f)%c - offset%c)*(e - e.vmin) + offset%c
|
||||
|
||||
# cases like (3+3x[0-3])%4 -> 3-x[0-3]
|
||||
lbound = ubound = offset = offset % c
|
||||
for (f, e) in terms:
|
||||
if f > c//2:
|
||||
if (lbound := lbound + (f-c)*(e.vmax-e.vmin)) < 0: break
|
||||
elif (ubound := ubound + f*(e.vmax-e.vmin)) >= c: break
|
||||
else: # we have found factors such that vmin/vmax of the final expression is between 0 and c, we can remove the mod
|
||||
return functools.reduce(lambda r, t: r + min(t[0], t[0]-c, key=abs)*(t[1]-t[1].vmin), terms, x.const_like(offset))
|
||||
|
||||
if not something_changed: return None
|
||||
return functools.reduce(operator.add, remainder)%c if remainder else x.const_like(0)
|
||||
return functools.reduce(lambda r, t: r + t[0]*t[1], terms, x.const_like(rem_const)) % c
|
||||
|
||||
def div_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
# simplify x // c, None means no change
|
||||
|
||||
Reference in New Issue
Block a user