limit AND const min max to ints [pr] (#12918)

This commit is contained in:
chenyu
2025-10-25 16:07:52 -04:00
committed by GitHub
parent 92324172be
commit e18922f111
2 changed files with 5 additions and 7 deletions

View File

@@ -40,15 +40,14 @@ class TestVminVmaxProperties(unittest.TestCase):
self.assertEqual(uop.vmin, 0)
self.assertEqual(uop.vmax, 5)
# this can be improved
uop = x & 15
self.assertEqual(uop.vmin, 0)
self.assertEqual(uop.vmax, 15)
# this can be improved
# TODO: this can be improved
uop = x & 32
self.assertEqual(uop.vmin, 0)
self.assertEqual(uop.vmax, 20)
self.assertEqual(uop.vmax, 20) # should be 0
def test_vmin_vmax_multiplication_with_variable(self):
# vmin and vmax for multiplication with a variable

View File

@@ -677,7 +677,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
(s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
if self.op is Ops.SUB: return s0_vmin-s1_vmax, s0_vmax-s1_vmin
if self.op is Ops.AND and s1_vmin == s1_vmax and s0_vmin >= 0 and s1_vmin >= 0: return min(0, s0_vmin), min(s0_vmax, s1_vmax)
if self.op is Ops.AND and dtypes.is_int(self.dtype) and s1_vmin == s1_vmax >= 0 and s0_vmin >= 0: return min(0, s0_vmin), min(s0_vmax, s1_vmax)
if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
# SHL/SHR on consts only
if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
@@ -692,9 +692,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
if self.dtype == dtypes.bool:
if self.op is Ops.OR: return s0_vmin or s1_vmin, s0_vmax or s1_vmax
if self.op is Ops.AND: return s0_vmin and s1_vmin, s0_vmax and s1_vmax
if self.op is Ops.OR and self.dtype == dtypes.bool: return s0_vmin or s1_vmin, s0_vmax or s1_vmax
if self.op is Ops.AND and self.dtype == dtypes.bool: return s0_vmin and s1_vmin, s0_vmax and s1_vmax
# float has NAN issue and we use explicit NAN in transcendental
if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax)
# NOTE: returned UOp is assumed to be CONST