mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
tie break for divmod (#15169)
This commit is contained in:
@@ -703,6 +703,12 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable((31*b+1)//18, 0, 172, "(((b*13)+1)//18+b)")
|
||||
self.helper_test_variable((19*b+3)//7, 0, 271, "(((b*5)+3)//7+(b*2))")
|
||||
|
||||
def test_div_by_factor_tie_break(self):
|
||||
a = Variable("a", 0, 1)
|
||||
b = Variable("b", 0, 1)
|
||||
with Context(CORRECT_DIVMOD_FOLDING=1):
|
||||
self.helper_test_variable((a*2+b*3+2)//6, 0, 1, "((a+b+1)//3)")
|
||||
|
||||
def test_div_mod_recombine_large_coeff(self):
|
||||
# recombine must work even when coeff > divisor: both mod and div reduce the coeff the same way
|
||||
b = Variable("b", 0, 100)
|
||||
|
||||
@@ -70,7 +70,7 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
|
||||
for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and 1 < abs(f) < c and (c%f)==0}:
|
||||
if (newxs := fold_divmod_general(x//div, correct_divmod_folding)) is not None and newxs.vmin >= 0:
|
||||
results.append((len(newxs.backward_slice), newxs // (c // div)))
|
||||
if results: return min(results)[1]
|
||||
if results: return min(results, key=lambda r: r[0])[1]
|
||||
|
||||
# ** Variable Denominator / Fallback Rules **
|
||||
# These rules apply to variables OR constants that failed the checks above.
|
||||
|
||||
Reference in New Issue
Block a user