improve min/max for AND (#14356)

This commit is contained in:
chenyu
2026-01-26 15:44:18 -05:00
committed by GitHub
parent f16372487a
commit d641e63189
2 changed files with 20 additions and 1 deletions

View File

@@ -49,6 +49,24 @@ class TestVminVmaxProperties(unittest.TestCase):
self.assertEqual(uop.vmin, 0)
self.assertEqual(uop.vmax, 20) # shoud be 0
def test_vmin_vmax_and_with_negative_variable(self):
# when mask doesn't have sign bit set, result is always non-negative
x = UOp.variable('x', -100, 100, dtypes.int32)
# 511 = 0x1FF, doesn't have sign bit set for int32
uop = x & 511
self.assertEqual(uop.vmin, 0)
self.assertEqual(uop.vmax, 511)
# 0x7FFFFFFF is max positive int32, doesn't have sign bit
uop = x & 0x7FFFFFFF
self.assertEqual(uop.vmin, 0)
self.assertEqual(uop.vmax, 0x7FFFFFFF)
# negative mask: x & -1 could be anything since -1 has all bits set
uop = x & -1
self.assertEqual(uop.vmin, dtypes.min(dtypes.int32))
self.assertEqual(uop.vmax, dtypes.max(dtypes.int32))
def test_vmin_vmax_multiplication_with_variable(self):
# vmin and vmax for multiplication with a variable
x = UOp.variable('x', -3, 4)

View File

@@ -749,7 +749,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
(s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
if self.op is Ops.SUB: return s0_vmin-s1_vmax, s0_vmax-s1_vmin
if self.op is Ops.AND and dtypes.is_int(self.dtype) and s1_vmin == s1_vmax >= 0 and s0_vmin >= 0: return min(0, s0_vmin), min(s0_vmax, s1_vmax)
if self.op is Ops.AND and dtypes.is_int(self.dtype) and s1_vmin == s1_vmax >= 0:
return 0, s1_vmax if s0_vmin < 0 else min(s0_vmax, s1_vmax)
if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
# SHL/SHR on consts only
if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]