Revert "revert some UOp IDIV bound (#5863)" (#5871)

This reverts commit 0c8d202348.
This commit is contained in:
chenyu
2024-08-01 21:38:31 -04:00
committed by GitHub
parent df138bc558
commit f27f949a5d
3 changed files with 3 additions and 11 deletions

View File

@@ -223,9 +223,6 @@ jobs:
- if: ${{ matrix.task == 'onnx' }}
name: Test THREEFRY
run: PYTHONPATH=. THREEFRY=1 GPU=1 python3 -m pytest test/test_randomness.py test/test_jit.py --durations=20
- if: ${{ matrix.task == 'onnx' }}
name: Test UOP_IS_SYMBOLIC
run: PYTHONPATH=. UOP_IS_SYMBOLIC=1 GPU=1 python3 -m pytest -n=auto test/test_ops.py -k "conv and not (test_padded_conv3d or test_conv2d_bs_4_cin_3 or test_conv2d or test_conv2d_bs_4_cin_1 or test_strided_conv_transpose2d)" --durations=20
- if: ${{ matrix.task == 'onnx' }}
name: Run handcode_opt
run: PYTHONPATH=. MODEL=resnet GPU=1 DEBUG=1 BS=4 HALF=0 python3 examples/handcode_opt.py

View File

@@ -177,7 +177,6 @@ class TestSymbolic(unittest.TestCase):
def test_div_remove(self):
self.helper_test_variable(Variable("a", 0, 7) // 20, 0, 0, "0")
@unittest.expectedFailure
def test_div_min_max(self):
self.helper_test_variable(Variable("a", 0, 7) // 2, 0, 3, "(a//2)")
@@ -189,7 +188,6 @@ class TestSymbolic(unittest.TestCase):
def test_sum_div_remove(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 20, 0, 0, "0")
@unittest.expectedFailure
def test_sum_div_min_max(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
@@ -206,7 +204,6 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
@unittest.expectedFailure
def test_sum_div_no_factor(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
@@ -220,7 +217,6 @@ class TestSymbolic(unittest.TestCase):
def test_sum_div_const(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a")
@unittest.expectedFailure
def test_sum_div_const_big(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)")
@@ -252,7 +248,6 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,12), 0, 1, {"((a*-1)<-2)", "(!(a<3))"})
self.helper_test_variable(create_ge_node(Variable("a", 0, 5)*4,13), 0, 1, {"((a*-1)<-3)", "(!(a<4))"})
@unittest.expectedFailure
def test_div_div(self):
self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
@@ -318,7 +313,6 @@ class TestSymbolic(unittest.TestCase):
def test_mul_div_factor_mul(self):
self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)")
@unittest.expectedFailure
def test_mul_div_factor_div(self):
self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")
@@ -338,6 +332,7 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "((((-idx)+200)//(-4))+50)")
self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, 0, 25, "((((-idx)+201)//(-4))+50)")
# NOTE: tests are not correct in symbolic
# TODO: simplify the expression
def test_div_neg_all_range(self):
gidx = Variable("gidx", 0, 124)
@@ -347,6 +342,7 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "(((((-gidx)*8)+(-lidx)+1001)//(-4))+250)")
self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "(((((-gidx)*8)+(-lidx)+1002)//(-4))+250)")
# NOTE: tests are not correct in symbolic
def test_div_neg_then_neg(self):
# taken from arange opts
lidx0 = Variable("lidx0", 0, 7)

View File

@@ -119,8 +119,7 @@ class UOp:
return self.const(Lmin*Rmin), self.const(Lmax*Rmax)
if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.const(0), self.const(s1.arg-1)
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
# TODO: this breaks UOP_IS_SYMBOLIC=1 test/test_ops.py -k test_conv2d_bs_1_cin_1
# if s1.arg > 0: return self.const(s0.vmin.arg//s1.arg), self.const(s0.vmax.arg//s1.arg)
if s1.arg > 0: return self.const(s0.vmin.arg//s1.arg), self.const(s0.vmax.arg//s1.arg)
if s1.arg < 0: return self.const(-(s0.vmax.arg//-s1.arg)), self.const(-(s0.vmin.arg//-s1.arg))
if self.arg is BinaryOps.MAX: return self.const(max(s0.vmin.arg, s1.vmin.arg)), self.const(max(s0.vmax.arg, s1.vmax.arg))
if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, True), UOp.const(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \