mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
improve min/max for AND (#14356)
This commit is contained in:
@@ -49,6 +49,24 @@ class TestVminVmaxProperties(unittest.TestCase):
|
||||
self.assertEqual(uop.vmin, 0)
|
||||
self.assertEqual(uop.vmax, 20) # should be 0
|
||||
|
||||
def test_vmin_vmax_and_with_negative_variable(self):
|
||||
# when mask doesn't have sign bit set, result is always non-negative
|
||||
x = UOp.variable('x', -100, 100, dtypes.int32)
|
||||
# 511 = 0x1FF, doesn't have sign bit set for int32
|
||||
uop = x & 511
|
||||
self.assertEqual(uop.vmin, 0)
|
||||
self.assertEqual(uop.vmax, 511)
|
||||
|
||||
# 0x7FFFFFFF is max positive int32, doesn't have sign bit
|
||||
uop = x & 0x7FFFFFFF
|
||||
self.assertEqual(uop.vmin, 0)
|
||||
self.assertEqual(uop.vmax, 0x7FFFFFFF)
|
||||
|
||||
# negative mask: the non-negative-mask bound does not apply to -1 (all bits set), so the full int32 range is expected
|
||||
uop = x & -1
|
||||
self.assertEqual(uop.vmin, dtypes.min(dtypes.int32))
|
||||
self.assertEqual(uop.vmax, dtypes.max(dtypes.int32))
|
||||
|
||||
def test_vmin_vmax_multiplication_with_variable(self):
|
||||
# vmin and vmax for multiplication with a variable
|
||||
x = UOp.variable('x', -3, 4)
|
||||
|
||||
@@ -749,7 +749,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
(s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
|
||||
if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
|
||||
if self.op is Ops.SUB: return s0_vmin-s1_vmax, s0_vmax-s1_vmin
|
||||
if self.op is Ops.AND and dtypes.is_int(self.dtype) and s1_vmin == s1_vmax >= 0 and s0_vmin >= 0: return min(0, s0_vmin), min(s0_vmax, s1_vmax)
|
||||
if self.op is Ops.AND and dtypes.is_int(self.dtype) and s1_vmin == s1_vmax >= 0:
|
||||
return 0, s1_vmax if s0_vmin < 0 else min(s0_vmax, s1_vmax)
|
||||
if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
|
||||
# SHL/SHR on consts only
|
||||
if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
|
||||
|
||||
Reference in New Issue
Block a user