more divmod recombine (#15162)

This commit is contained in:
chenyu
2026-03-05 12:53:22 -05:00
committed by GitHub
parent 167a1d56a6
commit da61088ca4
2 changed files with 28 additions and 15 deletions

View File

@@ -220,7 +220,7 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(usum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) % 2, 0, 0, "0")
def test_sum_div_some_factor(self):
self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "((a*2)+(b*2)+(a//2))")
def test_sum_div_trim_const(self):
self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, "(((a+b)+1)//4)")
@@ -228,10 +228,10 @@ class TestSymbolic(unittest.TestCase):
def test_sum_div_some_partial_factor(self):
self.helper_test_variable(usum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
self.helper_test_variable(usum([uconst(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
self.helper_test_variable((Variable("a", 0, 7)*30+20)//20, 1, 11, "(((a*3)//2)+1)")
self.helper_test_variable((Variable("a", 0, 7)*30+20)//20, 1, 11, "((a+(a//2))+1)")
def test_sum_div_no_factor(self):
self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "((a*2)+(b*2)+((a+b)//2))")
def test_mod_min_max(self):
self.helper_test_variable(Variable("x", 0, 10)%Variable("y", 1, 10), 0, 9, "(x%y)")
@@ -598,8 +598,7 @@ class TestSymbolic(unittest.TestCase):
gidx0 = Variable("gidx0", 0, 2)
lidx2 = Variable("lidx2", 0, 12)
lidx3 = Variable("lidx3", 0, 12)
# TODO: improve nest_div_by_smallest_factor to get ((lidx2+(lidx3*2))//3)
self.helper_test_variable((gidx0*3+lidx2*19+lidx3*38)//(3*19), 0, 12, "((gidx0+(lidx2*19+lidx3*38)//3)//19)")
self.helper_test_variable((gidx0*3+lidx2*19+lidx3*38)//(3*19), 0, 12, "((lidx2+(lidx3*2))//3)")
def test_sum_mul_distribute(self):
gidx0 = Variable("gidx0", 0, 7)
@@ -671,8 +670,7 @@ class TestSymbolic(unittest.TestCase):
a = Variable("a", 0, 2)
b = Variable("b", 0, 100)
self.helper_test_variable((31 * a + 1) % 30 + ((31 * a + 1) // 30) * 30, 1, 63, "((a*31)+1)")
with self.assertRaises(AssertionError):
self.helper_test_variable((31 * b + 1) % 18 + ((31 * b + 1) // 18) * 18, 1, 3101, "((b*31)+1)")
self.helper_test_variable((31 * b + 1) % 18 + ((31 * b + 1) // 18) * 18, 1, 3101, "((b*31)+1)")
def test_div_mod_recombine_3level(self):
gidx = Variable("gidx", 0, 150527)
@@ -696,8 +694,21 @@ class TestSymbolic(unittest.TestCase):
b = Variable("b", 0, 100)
exp = (16 * b + 2) % 18 + ((16 * b + 2) // 18) * 18
self.helper_test_variable(exp, 2, 1602, "((b*16)+2)")
with self.assertRaises(AssertionError):
self.helper_test_variable((30 * b + 1) % 18 + ((30 * b + 1) // 18) * 18, 1, 3001, "((b*30)+1)")
self.helper_test_variable((30 * b + 1) % 18 + ((30 * b + 1) // 18) * 18, 1, 3001, "((b*30)+1)")
def test_div_partial_quotient(self):
# IDIV should extract partial quotients when const_factor > divisor, matching what MOD already does
# (f*x+c)//d -> ((f%d)*x+c)//d + (f//d)*x when f >= d
# NOTE(review): helper_test_variable's arguments look like (expr, expected min, expected max, expected rendering) -- confirm against the helper
b = Variable("b", 0, 100)
# 31 = 1*18 + 13: coefficient 13 stays under the //18 and 1*b is pulled out; (31*100+1)//18 == 172
self.helper_test_variable((31*b+1)//18, 0, 172, "(((b*13)+1)//18+b)")
# 19 = 2*7 + 5: coefficient 5 stays under the //7 and 2*b is pulled out; (19*100+3)//7 == 271
self.helper_test_variable((19*b+3)//7, 0, 271, "(((b*5)+3)//7+(b*2))")
def test_div_mod_recombine_large_coeff(self):
# recombine must work even when coeff > divisor: both mod and div reduce the coeff the same way
# each expression below has the form x%d + (x//d)*d and the expected rendering is x itself
b = Variable("b", 0, 100)
# coefficient 19 > divisor 7; taking 0 and 100 as b's bounds, 19*b+3 spans 3..1903
self.helper_test_variable((19*b+3)%7 + ((19*b+3)//7)*7, 3, 1903, "((b*19)+3)")
a = Variable("a", 0, 10)
# coefficient 25 > divisor 10; taking 0 and 10 as a's bounds, 25*a+3 spans 3..253
self.helper_test_variable((25*a+3)%10 + ((25*a+3)//10)*10, 3, 253, "((a*25)+3)")
def test_gated_load(self):
idx = Variable("idx", 0, 24)

View File

@@ -63,12 +63,14 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
ret = new_x.alu(d.op, x.ufix(c//gcd.arg))
return ret*gcd + const%gcd.arg if d.op is Ops.MOD else ret+const//c
# nest_div_by_smallest_factor: try and nest the div and see if it allows the numerator to be simplified
# nest_div_by_factor: try nesting the div with each candidate factor and pick the simplest result
if d.op is Ops.IDIV and x.vmin >= 0:
div = min([c] + [abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and abs(f) > 1 and (c%f)==0])
# NOTE: this is recursive!
if div < c and (newxs := fold_divmod_general(x//div, correct_divmod_folding)) is not None and newxs.vmin >= 0:
return newxs // (c // div)
results = []
for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and 1 < abs(f) < c and (c%f)==0}:
if (newxs := fold_divmod_general(x//div, correct_divmod_folding)) is not None and newxs.vmin >= 0:
results.append((len(newxs.backward_slice), newxs // (c // div)))
if results: return min(results)[1]
# ** Variable Denominator / Fallback Rules **
# These rules apply to variables OR constants that failed the checks above.
@@ -86,9 +88,9 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
quo, rem = [], []
for u in all_uops:
if (q:=u.divide_exact(y)) is not None: quo.append(q)
elif d.op is Ops.MOD and y.op is Ops.CONST and (c:=u.const_factor())%y.arg!=c:
elif y.op is Ops.CONST and (c:=u.const_factor())%y.arg!=c:
rem.append(u.divides(c)*(c%y.arg))
quo.append(u.const_like(0))
quo.append(u.divides(c)*(c//y.arg) if d.op is Ops.IDIV else u.const_like(0))
else: rem.append(u)
if not quo: return None