mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
more divmod recombine (#15162)
This commit is contained in:
@@ -220,7 +220,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) % 2, 0, 0, "0")
|
||||
|
||||
def test_sum_div_some_factor(self):
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "((a*2)+(b*2)+(a//2))")
|
||||
|
||||
def test_sum_div_trim_const(self):
|
||||
self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, "(((a+b)+1)//4)")
|
||||
@@ -228,10 +228,10 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_sum_div_some_partial_factor(self):
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
|
||||
self.helper_test_variable(usum([uconst(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
|
||||
self.helper_test_variable((Variable("a", 0, 7)*30+20)//20, 1, 11, "(((a*3)//2)+1)")
|
||||
self.helper_test_variable((Variable("a", 0, 7)*30+20)//20, 1, 11, "((a+(a//2))+1)")
|
||||
|
||||
def test_sum_div_no_factor(self):
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "((a*2)+(b*2)+((a+b)//2))")
|
||||
|
||||
def test_mod_min_max(self):
|
||||
self.helper_test_variable(Variable("x", 0, 10)%Variable("y", 1, 10), 0, 9, "(x%y)")
|
||||
@@ -598,8 +598,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
gidx0 = Variable("gidx0", 0, 2)
|
||||
lidx2 = Variable("lidx2", 0, 12)
|
||||
lidx3 = Variable("lidx3", 0, 12)
|
||||
# TODO: improve nest_div_by_smallest_factor to get ((lidx2+(lidx3*2))//3)
|
||||
self.helper_test_variable((gidx0*3+lidx2*19+lidx3*38)//(3*19), 0, 12, "((gidx0+(lidx2*19+lidx3*38)//3)//19)")
|
||||
self.helper_test_variable((gidx0*3+lidx2*19+lidx3*38)//(3*19), 0, 12, "((lidx2+(lidx3*2))//3)")
|
||||
|
||||
def test_sum_mul_distribute(self):
|
||||
gidx0 = Variable("gidx0", 0, 7)
|
||||
@@ -671,8 +670,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
a = Variable("a", 0, 2)
|
||||
b = Variable("b", 0, 100)
|
||||
self.helper_test_variable((31 * a + 1) % 30 + ((31 * a + 1) // 30) * 30, 1, 63, "((a*31)+1)")
|
||||
with self.assertRaises(AssertionError):
|
||||
self.helper_test_variable((31 * b + 1) % 18 + ((31 * b + 1) // 18) * 18, 1, 3101, "((b*31)+1)")
|
||||
self.helper_test_variable((31 * b + 1) % 18 + ((31 * b + 1) // 18) * 18, 1, 3101, "((b*31)+1)")
|
||||
|
||||
def test_div_mod_recombine_3level(self):
|
||||
gidx = Variable("gidx", 0, 150527)
|
||||
@@ -696,8 +694,21 @@ class TestSymbolic(unittest.TestCase):
|
||||
b = Variable("b", 0, 100)
|
||||
exp = (16 * b + 2) % 18 + ((16 * b + 2) // 18) * 18
|
||||
self.helper_test_variable(exp, 2, 1602, "((b*16)+2)")
|
||||
with self.assertRaises(AssertionError):
|
||||
self.helper_test_variable((30 * b + 1) % 18 + ((30 * b + 1) // 18) * 18, 1, 3001, "((b*30)+1)")
|
||||
self.helper_test_variable((30 * b + 1) % 18 + ((30 * b + 1) // 18) * 18, 1, 3001, "((b*30)+1)")
|
||||
|
||||
def test_div_partial_quotient(self):
|
||||
# IDIV should extract partial quotients when const_factor >= divisor, matching what MOD already does
|
||||
# (f*x+c)//d -> (f%d*x+c)//d + (f//d)*x when f >= d
|
||||
b = Variable("b", 0, 100)
|
||||
self.helper_test_variable((31*b+1)//18, 0, 172, "(((b*13)+1)//18+b)")
|
||||
self.helper_test_variable((19*b+3)//7, 0, 271, "(((b*5)+3)//7+(b*2))")
|
||||
|
||||
def test_div_mod_recombine_large_coeff(self):
|
||||
# recombine must work even when coeff > divisor: both mod and div reduce the coeff the same way
|
||||
b = Variable("b", 0, 100)
|
||||
self.helper_test_variable((19*b+3)%7 + ((19*b+3)//7)*7, 3, 1903, "((b*19)+3)")
|
||||
a = Variable("a", 0, 10)
|
||||
self.helper_test_variable((25*a+3)%10 + ((25*a+3)//10)*10, 3, 253, "((a*25)+3)")
|
||||
|
||||
def test_gated_load(self):
|
||||
idx = Variable("idx", 0, 24)
|
||||
|
||||
@@ -63,12 +63,14 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
|
||||
ret = new_x.alu(d.op, x.ufix(c//gcd.arg))
|
||||
return ret*gcd + const%gcd.arg if d.op is Ops.MOD else ret+const//c
|
||||
|
||||
# nest_div_by_smallest_factor: try and nest the div and see if it allows the numerator to be simplified
|
||||
# nest_div_by_factor: try nesting the div with each candidate factor and pick the result whose folded numerator has the smallest backward_slice
|
||||
if d.op is Ops.IDIV and x.vmin >= 0:
|
||||
div = min([c] + [abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and abs(f) > 1 and (c%f)==0])
|
||||
# NOTE: this is recursive!
|
||||
if div < c and (newxs := fold_divmod_general(x//div, correct_divmod_folding)) is not None and newxs.vmin >= 0:
|
||||
return newxs // (c // div)
|
||||
results = []
|
||||
for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and 1 < abs(f) < c and (c%f)==0}:
|
||||
if (newxs := fold_divmod_general(x//div, correct_divmod_folding)) is not None and newxs.vmin >= 0:
|
||||
results.append((len(newxs.backward_slice), newxs // (c // div)))
|
||||
if results: return min(results)[1]
|
||||
|
||||
# ** Variable Denominator / Fallback Rules **
|
||||
# These rules apply to variables OR constants that failed the checks above.
|
||||
@@ -86,9 +88,9 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
|
||||
quo, rem = [], []
|
||||
for u in all_uops:
|
||||
if (q:=u.divide_exact(y)) is not None: quo.append(q)
|
||||
elif d.op is Ops.MOD and y.op is Ops.CONST and (c:=u.const_factor())%y.arg!=c:
|
||||
elif y.op is Ops.CONST and (c:=u.const_factor())%y.arg!=c:
|
||||
rem.append(u.divides(c)*(c%y.arg))
|
||||
quo.append(u.const_like(0))
|
||||
quo.append(u.divides(c)*(c//y.arg) if d.op is Ops.IDIV else u.const_like(0))
|
||||
else: rem.append(u)
|
||||
|
||||
if not quo: return None
|
||||
|
||||
Reference in New Issue
Block a user