UOp bound for div negative number (#5808)

This commit is contained in:
chenyu
2024-07-31 02:10:23 -04:00
committed by GitHub
parent bcbd925001
commit 2e087ca8e4
3 changed files with 26 additions and 18 deletions

View File

@@ -265,14 +265,6 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "(idx//4)")
self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, -1, 24, "(((3+idx)//4)+-1)")
def test_div_neg_all_range(self):
gidx = Variable("gidx", 0, 124)
lidx = Variable("lidx", 0, 7)
self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 0, 250, "(((1+lidx)//4)+(gidx*2))")
self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 249, "((gidx*2)+(lidx//4))")
self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, -1, 249, "(((3+lidx)//4)+(gidx*2)+-1)")
self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, -1, 249, "(((2+lidx)//4)+(gidx*2)+-1)")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
# TODO: why are the negative tests broken? (even if we did support negative variables)

View File

@@ -325,20 +325,34 @@ class TestSymbolic(unittest.TestCase):
def test_div_into_mod(self):
self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")
@unittest.expectedFailure
# TODO: simplify the expression
def test_div_neg_cancel(self):
self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 0, 25, "((1+idx)//4)")
self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "(idx//4)")
self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, -1, 24, "(((3+idx)//4)+-1)")
self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 1, 26, "((((-idx)+199)//(-4))+50)")
self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "((((-idx)+200)//(-4))+50)")
self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, 0, 25, "((((-idx)+201)//(-4))+50)")
@unittest.expectedFailure
# NOTE: tests are not correct in symbolic
# TODO: simplify the expression
def test_div_neg_all_range(self):
gidx = Variable("gidx", 0, 124)
lidx = Variable("lidx", 0, 7)
self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 0, 250, "(((1+lidx)//4)+(gidx*2))")
self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 249, "((gidx*2)+(lidx//4))")
self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, -1, 249, "(((3+lidx)//4)+(gidx*2)+-1)")
self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, -1, 249, "(((2+lidx)//4)+(gidx*2)+-1)")
self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 1, 250, "(((((-gidx)*8)+(-lidx)+999)//(-4))+250)")
self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 250, "(((((-gidx)*8)+(-lidx)+1000)//(-4))+250)")
self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "(((((-gidx)*8)+(-lidx)+1001)//(-4))+250)")
self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "(((((-gidx)*8)+(-lidx)+1002)//(-4))+250)")
# NOTE: tests are not correct in symbolic
def test_div_neg_then_neg(self):
# taken from arange opts
lidx0 = Variable("lidx0", 0, 7)
lidx1 = Variable("lidx1", 0, 7)
alu2 = -lidx0-lidx1
self.helper_test_variable((((alu2+14)//(-32))+4), 4, 4, "4")
self.helper_test_variable(-(((alu2+14)//(-32))+4), -4, -4, "(-4)")
self.helper_test_variable((((alu2+134)//(-32))+4), 0, 1, "((((-lidx0)+(-lidx1)+134)//(-32))+4)")
self.helper_test_variable((((alu2+142)//(-32))+4), 0, 0, "0")
self.helper_test_variable((((alu2+150)//(-32))+4), 0, 0, "0")
self.helper_test_variable((((alu2+158)//(-32))+4), 0, 0, "0")
@unittest.skip("not supported on uops yet")
class TestSymbolicNumeric(unittest.TestCase):

View File

@@ -117,7 +117,9 @@ class UOp:
Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
return self.const(Lmin*Rmin), self.const(Lmax*Rmax)
if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.const(0), self.const(s1.arg-1)
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST and s1.arg > 0: return self.const(s0.vmin.arg//s1.arg), self.const(s0.vmax.arg//s1.arg)
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
if s1.arg > 0: return self.const(s0.vmin.arg//s1.arg), self.const(s0.vmax.arg//s1.arg)
if s1.arg < 0: return self.const(-(s0.vmax.arg//-s1.arg)), self.const(-(s0.vmin.arg//-s1.arg))
if self.arg is BinaryOps.MAX: return self.const(max(s0.vmin.arg, s1.vmin.arg)), self.const(max(s0.vmax.arg, s1.vmax.arg))
if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, True), UOp.const(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \
(UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None)