From da61088ca4f2d01078a5e4f5a66fc78cbb9b8615 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 5 Mar 2026 12:53:22 -0500 Subject: [PATCH] more divmod recombine (#15162) --- test/null/test_uop_symbolic.py | 29 ++++++++++++++++++++--------- tinygrad/uop/divandmod.py | 14 ++++++++------ 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/test/null/test_uop_symbolic.py b/test/null/test_uop_symbolic.py index 488d016b98..23317d3a03 100644 --- a/test/null/test_uop_symbolic.py +++ b/test/null/test_uop_symbolic.py @@ -220,7 +220,7 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(usum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) % 2, 0, 0, "0") def test_sum_div_some_factor(self): - self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))") + self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "((a*2)+(b*2)+(a//2))") def test_sum_div_trim_const(self): self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, "(((a+b)+1)//4)") @@ -228,10 +228,10 @@ class TestSymbolic(unittest.TestCase): def test_sum_div_some_partial_factor(self): self.helper_test_variable(usum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)") self.helper_test_variable(usum([uconst(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)") - self.helper_test_variable((Variable("a", 0, 7)*30+20)//20, 1, 11, "(((a*3)//2)+1)") + self.helper_test_variable((Variable("a", 0, 7)*30+20)//20, 1, 11, "((a+(a//2))+1)") def test_sum_div_no_factor(self): - self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)") + self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "((a*2)+(b*2)+((a+b)//2))") def test_mod_min_max(self): self.helper_test_variable(Variable("x", 0, 10)%Variable("y", 1, 10), 0, 9, "(x%y)") @@ -598,8 +598,7 @@ class TestSymbolic(unittest.TestCase): gidx0 = Variable("gidx0", 0, 2) lidx2 = Variable("lidx2", 0, 12) lidx3 = Variable("lidx3", 0, 12) - # TODO: improve nest_div_by_smallest_factor to get ((lidx2+(lidx3*2))//3) - self.helper_test_variable((gidx0*3+lidx2*19+lidx3*38)//(3*19), 0, 12, "((gidx0+(lidx2*19+lidx3*38)//3)//19)") + self.helper_test_variable((gidx0*3+lidx2*19+lidx3*38)//(3*19), 0, 12, "((lidx2+(lidx3*2))//3)") def test_sum_mul_distribute(self): gidx0 = Variable("gidx0", 0, 7) @@ -671,8 +670,7 @@ class TestSymbolic(unittest.TestCase): a = Variable("a", 0, 2) b = Variable("b", 0, 100) self.helper_test_variable((31 * a + 1) % 30 + ((31 * a + 1) // 30) * 30, 1, 63, "((a*31)+1)") - with self.assertRaises(AssertionError): - self.helper_test_variable((31 * b + 1) % 18 + ((31 * b + 1) // 18) * 18, 1, 3101, "((b*31)+1)") + self.helper_test_variable((31 * b + 1) % 18 + ((31 * b + 1) // 18) * 18, 1, 3101, "((b*31)+1)") def test_div_mod_recombine_3level(self): gidx = Variable("gidx", 0, 150527) @@ -696,8 +694,21 @@ class TestSymbolic(unittest.TestCase): b = Variable("b", 0, 100) exp = (16 * b + 2) % 18 + ((16 * b + 2) // 18) * 18 self.helper_test_variable(exp, 2, 1602, "((b*16)+2)") - with self.assertRaises(AssertionError): - self.helper_test_variable((30 * b + 1) % 18 + ((30 * b + 1) // 18) * 18, 1, 3001, "((b*30)+1)") + self.helper_test_variable((30 * b + 1) % 18 + ((30 * b + 1) // 18) * 18, 1, 3001, "((b*30)+1)") + + def test_div_partial_quotient(self): + # IDIV should extract partial quotients when const_factor > divisor, matching what MOD already does + # (f*x+c)//d -> (f%d*x+c)//d + (f//d)*x when f >= d + b = Variable("b", 0, 100) + self.helper_test_variable((31*b+1)//18, 0, 172, "(((b*13)+1)//18+b)") + self.helper_test_variable((19*b+3)//7, 0, 271, "(((b*5)+3)//7+(b*2))") + + def test_div_mod_recombine_large_coeff(self): + # recombine must work even when coeff > divisor: both mod and div reduce the coeff the same way + b = Variable("b", 0, 100) + self.helper_test_variable((19*b+3)%7 + ((19*b+3)//7)*7, 3, 1903, "((b*19)+3)") + a = Variable("a", 0, 10) + self.helper_test_variable((25*a+3)%10 + ((25*a+3)//10)*10, 3, 253, "((a*25)+3)") def test_gated_load(self): idx = Variable("idx", 0, 24) diff --git a/tinygrad/uop/divandmod.py b/tinygrad/uop/divandmod.py index 8be86b73ec..57b222c194 100644 --- a/tinygrad/uop/divandmod.py +++ b/tinygrad/uop/divandmod.py @@ -63,12 +63,14 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None: ret = new_x.alu(d.op, x.ufix(c//gcd.arg)) return ret*gcd + const%gcd.arg if d.op is Ops.MOD else ret+const//c - # nest_div_by_smallest_factor: try and nest the div and see if it allows the numerator to be simplified + # nest_div_by_factor: try nesting the div with each candidate factor and pick the simplest result if d.op is Ops.IDIV and x.vmin >= 0: - div = min([c] + [abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and abs(f) > 1 and (c%f)==0]) # NOTE: this is recursive! - if div < c and (newxs := fold_divmod_general(x//div, correct_divmod_folding)) is not None and newxs.vmin >= 0: - return newxs // (c // div) + results = [] + for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and 1 < abs(f) < c and (c%f)==0}: + if (newxs := fold_divmod_general(x//div, correct_divmod_folding)) is not None and newxs.vmin >= 0: + results.append((len(newxs.backward_slice), newxs // (c // div))) + if results: return min(results)[1] # ** Variable Denominator / Fallback Rules ** # These rules apply to variables OR constants that failed the checks above. @@ -86,9 +88,9 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None: quo, rem = [], [] for u in all_uops: if (q:=u.divide_exact(y)) is not None: quo.append(q) - elif d.op is Ops.MOD and y.op is Ops.CONST and (c:=u.const_factor())%y.arg!=c: + elif y.op is Ops.CONST and (c:=u.const_factor())%y.arg!=c: rem.append(u.divides(c)*(c%y.arg)) - quo.append(u.const_like(0)) + quo.append(u.divides(c)*(c//y.arg) if d.op is Ops.IDIV else u.const_like(0)) else: rem.append(u) if not quo: return None