mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
better bound for mod negative number (#10227)
This commit is contained in:
@@ -170,32 +170,37 @@ class TestVminVmaxDivMod(unittest.TestCase):
|
||||
|
||||
def test_vmin_vmax_mod_positive(self):
|
||||
# vmin and vmax for modulo of a variable by a positive constant
|
||||
x = UOp.variable('x', 10, 20)
|
||||
uop = x % 3
|
||||
positive = UOp.variable('positive', 10, 20)
|
||||
uop = positive % 3
|
||||
self.assertEqual(uop.vmin, 0)
|
||||
self.assertEqual(uop.vmax, 2)
|
||||
|
||||
@unittest.skip("broken")
|
||||
def test_vmin_vmax_mod_negative(self):
|
||||
# vmin and vmax for modulo of a variable by a negative constant
|
||||
x = UOp.variable('x', 10, 20)
|
||||
uop = x % -3
|
||||
negative = UOp.variable('negative', -20, -10)
|
||||
uop = negative % 3
|
||||
self.assertEqual(uop.vmin, -2)
|
||||
self.assertEqual(uop.vmax, 0)
|
||||
|
||||
def test_vmin_vmax_division_with_mixed_range(self):
|
||||
# vmin and vmax for division of a variable with a range crossing zero
|
||||
x = UOp.variable('x', -10, 10)
|
||||
uop = x // 3
|
||||
self.assertEqual(uop.vmin, -3) # -10//3 = -3 (in C)
|
||||
self.assertEqual(uop.vmax, 3) # 10//3 = 3
|
||||
mixed = UOp.variable('mixed', -20, 20)
|
||||
uop = mixed % 3
|
||||
self.assertEqual(uop.vmin, -2)
|
||||
self.assertEqual(uop.vmax, 2)
|
||||
|
||||
def test_vmin_vmax_mod_with_mixed_range(self):
|
||||
# vmin and vmax for modulo of a variable with a range crossing zero
|
||||
x = UOp.variable('x', -10, 10)
|
||||
uop = x % 4
|
||||
self.assertEqual(uop.vmin, -3)
|
||||
self.assertEqual(uop.vmax, 3)
|
||||
def test_vmin_vmax_mod_negative(self):
|
||||
# vmin and vmax for modulo of a variable by a negative constant
|
||||
positive = UOp.variable('positive', 10, 20)
|
||||
uop = positive % -3
|
||||
self.assertEqual(uop.vmin, 0)
|
||||
self.assertEqual(uop.vmax, 2)
|
||||
|
||||
negative = UOp.variable('negative', -20, -10)
|
||||
uop = negative % -3
|
||||
self.assertEqual(uop.vmin, -2)
|
||||
self.assertEqual(uop.vmax, 0)
|
||||
|
||||
mixed = UOp.variable('mixed', -20, 20)
|
||||
uop = mixed % -3
|
||||
self.assertEqual(uop.vmin, -2)
|
||||
self.assertEqual(uop.vmax, 2)
|
||||
|
||||
class TestVminVmaxVConst(unittest.TestCase):
|
||||
def test_vmin_vmax_vconst_single_element(self):
|
||||
|
||||
@@ -627,8 +627,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
# SHL/SHR on consts only
|
||||
if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
|
||||
if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2]
|
||||
if self.op is Ops.MOD and s1_vmin > 0:
|
||||
return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), s1_vmax-1)
|
||||
if self.op is Ops.MOD:
|
||||
if s1_vmin > 0: return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), 0) if s0_vmax <= 0 else (-(s1_vmax-1), s1_vmax-1)
|
||||
if s1_vmax < 0: return (0, -s1_vmax-1) if s0_vmin >= 0 else (-(-s1_vmax-1), 0) if s0_vmax <= 0 else (-(-s1_vmax-1), -s1_vmax-1)
|
||||
if self.op is Ops.IDIV:
|
||||
assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int)
|
||||
if (c:=s1_vmin) == s1_vmax: # s1 is a const
|
||||
|
||||
Reference in New Issue
Block a user