simple UOp lt/ge folding (#5657)

works if lhs is a DEFINE_VAR.
folds trivial x < -math.inf now, need to change SPECIAL to use DEFINE_VAR to fold more
This commit is contained in:
chenyu
2024-07-23 14:11:05 -04:00
committed by GitHub
parent b0fc5a4c6f
commit 199b3bf02b
3 changed files with 14 additions and 6 deletions

View File

@@ -30,8 +30,7 @@ def Variable(expr, nmin, nmax):
# TODO: fix DEFINE_VAR to not need this
class TempVar:
def __init__(self, x): self.expr = x
#return UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)), TempVar(expr))
return UOp(UOps.DEFINE_VAR, dtypes.int, tuple(), TempVar(expr))
return UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)), TempVar(expr))
class Node:
@staticmethod
def sum(ops): return functools.reduce(lambda x,y: x+y, ops)
@@ -73,7 +72,6 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "1")
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "1")
@unittest.expectedFailure
def test_lt(self):
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 77), 1, 1, "1")
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 9), 1, 1, "1")
@@ -272,11 +270,9 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)")
#self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
@unittest.expectedFailure
def test_ge_remove(self):
self.helper_test_variable(create_ge_node(Variable("a", 0, 6), 25), 0, 0, "0")
@unittest.expectedFailure
def test_lt_remove(self):
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), -3), 0, 0, "0")
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 3), 0, 1, "(a<3)")