From 60215deb60462d362ceeb4ea0eac414070a71deb Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 9 Mar 2026 03:40:39 -0400 Subject: [PATCH] tiebreak in fold_divmod_congruence (#15190) need to try both direction --- test/null/test_uop_symbolic.py | 5 +++++ tinygrad/uop/divandmod.py | 12 +++++++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/test/null/test_uop_symbolic.py b/test/null/test_uop_symbolic.py index b3cdeaaf36..5219cb657e 100644 --- a/test/null/test_uop_symbolic.py +++ b/test/null/test_uop_symbolic.py @@ -286,6 +286,11 @@ class TestSymbolic(unittest.TestCase): "(((z+(x*-1))+(y*-1))+7)") self.helper_test_variable((10+12*Variable("x",0,2)+Variable("y", 0, 4)%3)%13, 8, 12, "(((x*-1)+(y%3))+10)") + def test_mod_congruence_tied_remainder(self): + # when f%c == c/2, both r and r-c have equal abs — try both signs + self.helper_test_variable((3+2*Variable("x",0,1)+3*Variable("y",0,1))%4, 0, 3, "((x*-2)+(y*-1)+3)") + self.helper_test_variable((3+6*Variable("x",0,1)+7*Variable("y",0,1))%4, 0, 3, "((x*-2)+(y*-1)+3)") + def test_div_congruence(self): self.helper_test_variable((3+3*Variable("a",0,3))//4, 0, 3, "a") self.helper_test_variable((18+17*Variable("a",0,2)+17)//18, 1, 3, "(a+1)") diff --git a/tinygrad/uop/divandmod.py b/tinygrad/uop/divandmod.py index c8ddd24c18..52dd8d9dd2 100644 --- a/tinygrad/uop/divandmod.py +++ b/tinygrad/uop/divandmod.py @@ -1,4 +1,4 @@ -import functools +import functools, itertools from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp from tinygrad.dtype import dtypes from tinygrad.helpers import cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap @@ -48,10 +48,12 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None: # fold_divmod_congruence: fold if a is congruent to an expression whose range is between 0 and c if not (x.vmin<0 and correct_divmod_folding): - rems = [min((r:=f%c), r-c, key=abs) for f in factors] - if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c==rem.vmax//c: - if d.op is Ops.MOD: return rem - rem.vmin//c*c - return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + (const-const%c+rem.vmin//c*c)//c + # when f%c == c//2, abs(r) == abs(r-c) is a tie, try both signs since either may fit in one period + rem_choices = [((r:=f%c), r-c) if (r:=f%c)*2 == c else (min(r, r-c, key=abs),) for f in factors] + for rems in itertools.product(*rem_choices): + if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c==rem.vmax//c: + if d.op is Ops.MOD: return rem - rem.vmin//c*c + return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + (const-const%c+rem.vmin//c*c)//c # gcd_with_remainder: factor out common gcd from numerator # Note: this rule uses uops_no_const to exclude the additive constant from the GCD calculation