diff --git a/test/null/test_uop_symbolic.py b/test/null/test_uop_symbolic.py index 23317d3a03..c7b24c04fa 100644 --- a/test/null/test_uop_symbolic.py +++ b/test/null/test_uop_symbolic.py @@ -703,6 +703,12 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((31*b+1)//18, 0, 172, "(((b*13)+1)//18+b)") self.helper_test_variable((19*b+3)//7, 0, 271, "(((b*5)+3)//7+(b*2))") + def test_div_by_factor_tie_break(self): + a = Variable("a", 0, 1) + b = Variable("b", 0, 1) + with Context(CORRECT_DIVMOD_FOLDING=1): + self.helper_test_variable((a*2+b*3+2)//6, 0, 1, "((a+b+1)//3)") + def test_div_mod_recombine_large_coeff(self): # recombine must work even when coeff > divisor: both mod and div reduce the coeff the same way b = Variable("b", 0, 100) diff --git a/tinygrad/uop/divandmod.py b/tinygrad/uop/divandmod.py index 57b222c194..bce9a3fc9f 100644 --- a/tinygrad/uop/divandmod.py +++ b/tinygrad/uop/divandmod.py @@ -70,7 +70,7 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None: for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and 1 < abs(f) < c and (c%f)==0}: if (newxs := fold_divmod_general(x//div, correct_divmod_folding)) is not None and newxs.vmin >= 0: results.append((len(newxs.backward_slice), newxs // (c // div))) - if results: return min(results)[1] + if results: return min(results, key=lambda r: r[0])[1] # ** Variable Denominator / Fallback Rules ** # These rules apply to variables OR constants that failed the checks above.