tiebreak in fold_divmod_congruence (#15190)

When a remainder is tied, both directions (signs) need to be tried.
This commit is contained in:
chenyu
2026-03-09 03:40:39 -04:00
committed by GitHub
parent a8d8351e5a
commit 60215deb60
2 changed files with 12 additions and 5 deletions

View File

@@ -286,6 +286,11 @@ class TestSymbolic(unittest.TestCase):
"(((z+(x*-1))+(y*-1))+7)")
self.helper_test_variable((10+12*Variable("x",0,2)+Variable("y", 0, 4)%3)%13, 8, 12, "(((x*-1)+(y%3))+10)")
def test_mod_congruence_tied_remainder(self):
# when f%c == c/2, r and r-c have equal absolute value, so both signs must be tried
self.helper_test_variable((3+2*Variable("x",0,1)+3*Variable("y",0,1))%4, 0, 3, "((x*-2)+(y*-1)+3)")
self.helper_test_variable((3+6*Variable("x",0,1)+7*Variable("y",0,1))%4, 0, 3, "((x*-2)+(y*-1)+3)")
def test_div_congruence(self):
self.helper_test_variable((3+3*Variable("a",0,3))//4, 0, 3, "a")
self.helper_test_variable((18+17*Variable("a",0,2)+17)//18, 1, 3, "(a+1)")

View File

@@ -1,4 +1,4 @@
import functools
import functools, itertools
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp
from tinygrad.dtype import dtypes
from tinygrad.helpers import cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap
@@ -48,10 +48,12 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
# fold_divmod_congruence: fold if a is congruent to an expression whose range is between 0 and c
if not (x.vmin<0 and correct_divmod_folding):
rems = [min((r:=f%c), r-c, key=abs) for f in factors]
if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c==rem.vmax//c:
if d.op is Ops.MOD: return rem - rem.vmin//c*c
return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + (const-const%c+rem.vmin//c*c)//c
# when (f%c)*2 == c (c is even and f%c is exactly half of c), abs(r) == abs(r-c) is a tie; try both signs since either may fit in one period
rem_choices = [((r:=f%c), r-c) if (r:=f%c)*2 == c else (min(r, r-c, key=abs),) for f in factors]
for rems in itertools.product(*rem_choices):
if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c==rem.vmax//c:
if d.op is Ops.MOD: return rem - rem.vmin//c*c
return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + (const-const%c+rem.vmin//c*c)//c
# gcd_with_remainder: factor out common gcd from numerator
# Note: this rule uses uops_no_const to exclude the additive constant from the GCD calculation