apply more create_lt_node (#3597)

updated one in linearizer if condition, and various symbolic tests
This commit is contained in:
chenyu
2024-03-03 16:12:39 -05:00
committed by GitHub
parent bc562c4747
commit 968d109453
2 changed files with 31 additions and 31 deletions

View File

@@ -14,20 +14,20 @@ class TestSymbolic(unittest.TestCase):
self.assertEqual(v.max, m)
def test_ge(self):
self.helper_test_variable(Variable("a", 3, 8)>=77, 0, 0, "0")
self.helper_test_variable(Variable("a", 3, 8)>=9, 0, 0, "0")
self.helper_test_variable(Variable("a", 3, 8)>=8, 0, 1, "((a*-1)<-7)")
self.helper_test_variable(Variable("a", 3, 8)>=4, 0, 1, "((a*-1)<-3)")
self.helper_test_variable(Variable("a", 3, 8)>=3, 1, 1, "1")
self.helper_test_variable(Variable("a", 3, 8)>=2, 1, 1, "1")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 77), 0, 0, "0")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 9), 0, 0, "0")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 8), 0, 1, "((a*-1)<-7)")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 4), 0, 1, "((a*-1)<-3)")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "1")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "1")
def test_lt(self):
self.helper_test_variable(Variable("a", 3, 8)<77, 1, 1, "1")
self.helper_test_variable(Variable("a", 3, 8)<9, 1, 1, "1")
self.helper_test_variable(Variable("a", 3, 8)<8, 0, 1, "(a<8)")
self.helper_test_variable(Variable("a", 3, 8)<4, 0, 1, "(a<4)")
self.helper_test_variable(Variable("a", 3, 8)<3, 0, 0, "0")
self.helper_test_variable(Variable("a", 3, 8)<2, 0, 0, "0")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 77), 1, 1, "1")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 9), 1, 1, "1")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 8), 0, 1, "(a<8)")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 4), 0, 1, "(a<4)")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 3), 0, 0, "0")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 2), 0, 0, "0")
def test_ge_divides(self):
expr = create_lt_node(Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)
@@ -42,7 +42,7 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(expr//4, 0, 0, "0")
def test_lt_factors(self):
expr = Node.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512])
expr = create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256), 512)
self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)")
def test_div_becomes_num(self):
@@ -200,13 +200,13 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)")
#self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
def test_gt_remove(self):
self.helper_test_variable(Variable("a", 0, 6) >= 25, 0, 0, "0")
def test_ge_remove(self):
self.helper_test_variable(create_ge_node(Variable("a", 0, 6), 25), 0, 0, "0")
def test_lt_remove(self):
self.helper_test_variable(Variable("a", 0, 6) < -3, 0, 0, "0")
self.helper_test_variable(Variable("a", 0, 6) < 3, 0, 1, "(a<3)")
self.helper_test_variable(Variable("a", 0, 6) < 8, 1, 1, "1")
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), -3), 0, 0, "0")
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 3), 0, 1, "(a<3)")
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 8), 1, 1, "1")
def test_lt_sum_remove(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6) + 2, 3), 0, 1, "(a<1)")
@@ -401,19 +401,19 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
b = Variable("b", 6, 9)
c = Variable("c", 1, 10)
d = Variable("d", 5, 10)
# if the value is always the same, it folds to num
assert (a < b) == NumNode(1)
assert (b < a) == NumNode(0)
assert (d < a) == NumNode(0)
assert (a < a) == NumNode(0)
assert (a > a) == NumNode(0)
# if the comparison output is always the same, it folds to num
assert create_lt_node(a, b) == NumNode(1)
assert create_lt_node(b, a) == NumNode(0)
assert create_lt_node(d, a) == NumNode(0)
assert create_lt_node(a, a) == NumNode(0)
assert create_lt_node(a, a) == NumNode(0)
# if it remains as a LtNode, bool is always true and (min, max) == (0, 1)
assert isinstance((a < c), LtNode) and (a < c).min == 0 and (a < c).max == 1
assert a < c
assert isinstance((a > c), LtNode) and (a > c).min == 0 and (a > c).max == 1
a_lt_c = create_lt_node(a, c)
assert isinstance(a_lt_c, LtNode) and a_lt_c.min == 0 and a_lt_c.max == 1
assert a_lt_c
# same when comparing with a constant
assert a < 3 and (a < 3).min == 0 and (a < 3).max == 1
assert a > 3 and (a > 3).min == 0 and (a > 3).max == 1
a_lt_3 = create_lt_node(a, 3)
assert a_lt_3 and a_lt_3.min == 0 and a_lt_3.max == 1
def test_sumnode_mulnode_lt(self):
a = Variable("a", 1, 2)