improve idiv _min_max (#8066)

for the cases that the we don't know the exact bounds, we might still know the sign. with this, can remove some resolve for symbolic shapetracker
This commit is contained in:
chenyu
2024-12-05 23:02:16 -05:00
committed by GitHub
parent 13b954f22c
commit e7d5fe4a32
2 changed files with 16 additions and 3 deletions

View File

@@ -527,6 +527,15 @@ class TestSymbolic(unittest.TestCase):
# not combining # TODO: can combine if one is identity element const
self.helper_test_variable(aa+ab, 0, 6, "((a if (x<2) else b)+(a if (x<2) else 0))")
def test_symbolic_div(self):
# from symbolic arange
a = Variable("a", 1, 10)
denominator = ((a*-2)+1)
numerator = (((((a*2)+-1)*2)+1)*a)
self.helper_test_variable(denominator, -19, -1, "((a*-2)+1)")
self.helper_test_variable(numerator, 3, 390, "(a*((a*4)+-1))")
self.helper_test_variable((numerator//denominator)<=0, 1, 1, "True")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
MIN, MAX = 0, 10

View File

@@ -425,9 +425,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1
if self.op is Ops.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST
if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
if self.op is Ops.IDIV:
if s1_vmin == s1_vmax: # min/max are equal in a CONST
if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
# don't know exact bounds, but know the sign
if (s0_vmax <= 0 and s1_vmin < 0) or (s0_vmin >= 0 and s1_vmin > 0): return 0, dtypes.max(self.dtype)
if (s0_vmax <= 0 and s1_vmin > 0) or (s0_vmin >= 0 and s1_vmin < 0): return dtypes.min(self.dtype), 0
if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))