mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
tighter bound for MOD (#13550)
This commit is contained in:
@@ -322,12 +322,12 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_mod_mod_wrong_sign(self):
|
||||
v1=Variable("v1", 0, 128)
|
||||
v3=Variable("v3", 0, 7)
|
||||
self.helper_test_variable((((((v1%2)*2)+((v3+-1)%5))+-2)%5), -4, 4, "(((((v1%2)*2)+((v3+-1)%5))+-2)%5)")
|
||||
self.helper_test_variable((((((v1%2)*2)+((v3+-1)%5))+-2)%5), -3, 4, "(v1%2*2+(v3+-1)%5+-2)")
|
||||
|
||||
def test_mod_mod_wrong_sign2(self):
|
||||
v2=Variable("v2", 0, 8)
|
||||
v3=Variable("v3", 0, 4)
|
||||
self.helper_test_variable((((((v3+3)%7)+(v2+-2))%7)%7), -6, 6, "(((v2+((v3+3)%7))+-2)%7)")
|
||||
self.helper_test_variable((((((v3+3)%7)+(v2+-2))%7)%7), -2, 6, "(((v2+((v3+3)%7))+-2)%7)")
|
||||
|
||||
def test_mul_mul(self):
|
||||
self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)")
|
||||
@@ -377,9 +377,9 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_big_mod(self):
|
||||
self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)")
|
||||
self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(((a*-1)%10)*-1)")
|
||||
self.helper_test_variable(Variable("a", -20, 1)%10, -9, 9, "(a%10)") # TODO: tighter max
|
||||
self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)")
|
||||
self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)")
|
||||
self.helper_test_variable(Variable("a", -1, 20)%10, -9, 9, "(a%10)") # TODO: tighter min
|
||||
self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
|
||||
|
||||
def test_ge_remove(self):
|
||||
self.helper_test_variable(Variable("a", 0, 6) >= 25, 0, 0, "False")
|
||||
|
||||
@@ -768,6 +768,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
|
||||
if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2]
|
||||
if self.op is Ops.MOD:
|
||||
if (c:=s1_vmin) == s1_vmax > 0:
|
||||
return (0 if s0_vmin > 0 else s0_vmin if 0 >= s0_vmin > -c else -(s1_vmax-1), 0 if s0_vmax < 0 else s0_vmax if 0 <= s0_vmax < c else c-1)
|
||||
if s1_vmin > 0: return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), 0) if s0_vmax <= 0 else (-(s1_vmax-1), s1_vmax-1)
|
||||
if s1_vmax < 0: return (0, -s1_vmin-1) if s0_vmin >= 0 else (-(-s1_vmin-1), 0) if s0_vmax <= 0 else (-(-s1_vmin-1), -s1_vmin-1)
|
||||
if self.op is Ops.IDIV:
|
||||
|
||||
Reference in New Issue
Block a user