cond.where(True, False) is cond (#7733)

This commit is contained in:
chenyu
2024-11-16 09:44:17 -05:00
committed by GitHub
parent 40ae0e9115
commit e3105675fb
2 changed files with 8 additions and 0 deletions

View File

@@ -455,6 +455,13 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
def test_where_removal(self):
cond = Variable("a", 0, 3).lt(2)
u1, u0 = cond.ufix(1), cond.ufix(0)
self.helper_test_variable(cond, 0, 1, "(a<2)")
self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)")
self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
MIN, MAX = 0, 10

View File

@@ -1034,6 +1034,7 @@ symbolic_simple = PatternMatcher([
((UPat.var("x") & UPat.var("x")), lambda x: x),
((UPat.var("x") | UPat.var("x")), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
# ** zero folding **
(UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
(UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),