better bound for mod negative number (#10227)

This commit is contained in:
chenyu
2025-05-09 01:19:47 -04:00
committed by GitHub
parent 99f6d89dfb
commit 56def6c319
2 changed files with 27 additions and 21 deletions

View File

@@ -170,32 +170,37 @@ class TestVminVmaxDivMod(unittest.TestCase):
def test_vmin_vmax_mod_positive(self):
# vmin and vmax for modulo of a variable by a positive constant
x = UOp.variable('x', 10, 20)
uop = x % 3
positive = UOp.variable('positive', 10, 20)
uop = positive % 3
self.assertEqual(uop.vmin, 0)
self.assertEqual(uop.vmax, 2)
@unittest.skip("broken")
def test_vmin_vmax_mod_negative(self):
# vmin and vmax for modulo of a variable by a negative constant
x = UOp.variable('x', 10, 20)
uop = x % -3
negative = UOp.variable('negative', -20, -10)
uop = negative % 3
self.assertEqual(uop.vmin, -2)
self.assertEqual(uop.vmax, 0)
def test_vmin_vmax_division_with_mixed_range(self):
# vmin and vmax for division of a variable with a range crossing zero
x = UOp.variable('x', -10, 10)
uop = x // 3
self.assertEqual(uop.vmin, -3) # -10//3 = -3 (in C)
self.assertEqual(uop.vmax, 3) # 10//3 = 3
mixed = UOp.variable('mixed', -20, 20)
uop = mixed % 3
self.assertEqual(uop.vmin, -2)
self.assertEqual(uop.vmax, 2)
def test_vmin_vmax_mod_with_mixed_range(self):
# vmin and vmax for modulo of a variable with a range crossing zero
x = UOp.variable('x', -10, 10)
uop = x % 4
self.assertEqual(uop.vmin, -3)
self.assertEqual(uop.vmax, 3)
def test_vmin_vmax_mod_negative(self):
# vmin and vmax for modulo of a variable by a negative constant
positive = UOp.variable('positive', 10, 20)
uop = positive % -3
self.assertEqual(uop.vmin, 0)
self.assertEqual(uop.vmax, 2)
negative = UOp.variable('negative', -20, -10)
uop = negative % -3
self.assertEqual(uop.vmin, -2)
self.assertEqual(uop.vmax, 0)
mixed = UOp.variable('mixed', -20, 20)
uop = mixed % -3
self.assertEqual(uop.vmin, -2)
self.assertEqual(uop.vmax, 2)
class TestVminVmaxVConst(unittest.TestCase):
def test_vmin_vmax_vconst_single_element(self):

View File

@@ -627,8 +627,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# SHL/SHR on consts only
if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2]
if self.op is Ops.MOD and s1_vmin > 0:
return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), s1_vmax-1)
if self.op is Ops.MOD:
if s1_vmin > 0: return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), 0) if s0_vmax <= 0 else (-(s1_vmax-1), s1_vmax-1)
if s1_vmax < 0: return (0, -s1_vmax-1) if s0_vmin >= 0 else (-(-s1_vmax-1), 0) if s0_vmax <= 0 else (-(-s1_vmax-1), -s1_vmax-1)
if self.op is Ops.IDIV:
assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int)
if (c:=s1_vmin) == s1_vmax: # s1 is a const