mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
limit AND const min max to ints [pr] (#12918)
This commit is contained in:
@@ -40,15 +40,14 @@ class TestVminVmaxProperties(unittest.TestCase):
|
||||
self.assertEqual(uop.vmin, 0)
|
||||
self.assertEqual(uop.vmax, 5)
|
||||
|
||||
# this can be improved
|
||||
uop = x & 15
|
||||
self.assertEqual(uop.vmin, 0)
|
||||
self.assertEqual(uop.vmax, 15)
|
||||
|
||||
# this can be improved
|
||||
# TODO: this can be improved
|
||||
uop = x & 32
|
||||
self.assertEqual(uop.vmin, 0)
|
||||
self.assertEqual(uop.vmax, 20)
|
||||
self.assertEqual(uop.vmax, 20) # should be 0
|
||||
|
||||
def test_vmin_vmax_multiplication_with_variable(self):
|
||||
# vmin and vmax for multiplication with a variable
|
||||
|
||||
@@ -677,7 +677,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
(s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
|
||||
if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
|
||||
if self.op is Ops.SUB: return s0_vmin-s1_vmax, s0_vmax-s1_vmin
|
||||
if self.op is Ops.AND and s1_vmin == s1_vmax and s0_vmin >= 0 and s1_vmin >= 0: return min(0, s0_vmin), min(s0_vmax, s1_vmax)
|
||||
if self.op is Ops.AND and dtypes.is_int(self.dtype) and s1_vmin == s1_vmax >= 0 and s0_vmin >= 0: return min(0, s0_vmin), min(s0_vmax, s1_vmax)
|
||||
if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
|
||||
# SHL/SHR on consts only
|
||||
if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
|
||||
@@ -692,9 +692,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
|
||||
if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
|
||||
if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
|
||||
if self.dtype == dtypes.bool:
|
||||
if self.op is Ops.OR: return s0_vmin or s1_vmin, s0_vmax or s1_vmax
|
||||
if self.op is Ops.AND: return s0_vmin and s1_vmin, s0_vmax and s1_vmax
|
||||
if self.op is Ops.OR and self.dtype == dtypes.bool: return s0_vmin or s1_vmin, s0_vmax or s1_vmax
|
||||
if self.op is Ops.AND and self.dtype == dtypes.bool: return s0_vmin and s1_vmin, s0_vmax and s1_vmax
|
||||
# float has NAN issue and we use explicit NAN in transcendental
|
||||
if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax)
|
||||
# NOTE: returned UOp is assumed to be CONST
|
||||
|
||||
Reference in New Issue
Block a user