From d641e631891044972abb500a9687f08267982dfa Mon Sep 17 00:00:00 2001
From: chenyu
Date: Mon, 26 Jan 2026 15:44:18 -0500
Subject: [PATCH] improve min/max for AND (#14356)

---
 test/unit/test_uop_vmin_vmax.py | 18 ++++++++++++++++++
 tinygrad/uop/ops.py             |  3 ++-
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/test/unit/test_uop_vmin_vmax.py b/test/unit/test_uop_vmin_vmax.py
index 2935d971f6..daf36d4ac9 100644
--- a/test/unit/test_uop_vmin_vmax.py
+++ b/test/unit/test_uop_vmin_vmax.py
@@ -49,6 +49,24 @@ class TestVminVmaxProperties(unittest.TestCase):
     self.assertEqual(uop.vmin, 0)
     self.assertEqual(uop.vmax, 20) # shoud be 0
 
+  def test_vmin_vmax_and_with_negative_variable(self):
+    # when mask doesn't have sign bit set, result is always non-negative
+    x = UOp.variable('x', -100, 100, dtypes.int32)
+    # 511 = 0x1FF, doesn't have sign bit set for int32
+    uop = x & 511
+    self.assertEqual(uop.vmin, 0)
+    self.assertEqual(uop.vmax, 511)
+
+    # 0x7FFFFFFF is max positive int32, doesn't have sign bit
+    uop = x & 0x7FFFFFFF
+    self.assertEqual(uop.vmin, 0)
+    self.assertEqual(uop.vmax, 0x7FFFFFFF)
+
+    # negative mask: x & -1 could be anything since -1 has all bits set
+    uop = x & -1
+    self.assertEqual(uop.vmin, dtypes.min(dtypes.int32))
+    self.assertEqual(uop.vmax, dtypes.max(dtypes.int32))
+
   def test_vmin_vmax_multiplication_with_variable(self):
     # vmin and vmax for multiplication with a variable
     x = UOp.variable('x', -3, 4)
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 6ca898f4da..769de66b9d 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -749,7 +749,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
       (s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
       if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
       if self.op is Ops.SUB: return s0_vmin-s1_vmax, s0_vmax-s1_vmin
-      if self.op is Ops.AND and dtypes.is_int(self.dtype) and s1_vmin == s1_vmax >= 0 and s0_vmin >= 0: return min(0, s0_vmin), min(s0_vmax, s1_vmax)
+      if self.op is Ops.AND and dtypes.is_int(self.dtype) and s1_vmin == s1_vmax >= 0:
+        return 0, s1_vmax if s0_vmin < 0 else min(s0_vmax, s1_vmax)
       if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
       # SHL/SHR on consts only
       if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]