mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
cond.where(True, False) is cond (#7733)
This commit is contained in:
@@ -455,6 +455,13 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
|
||||
self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
|
||||
|
||||
def test_where_removal(self):
|
||||
cond = Variable("a", 0, 3).lt(2)
|
||||
u1, u0 = cond.ufix(1), cond.ufix(0)
|
||||
self.helper_test_variable(cond, 0, 1, "(a<2)")
|
||||
self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)")
|
||||
self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)")
|
||||
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
MIN, MAX = 0, 10
|
||||
|
||||
@@ -1034,6 +1034,7 @@ symbolic_simple = PatternMatcher([
|
||||
((UPat.var("x") & UPat.var("x")), lambda x: x),
|
||||
((UPat.var("x") | UPat.var("x")), lambda x: x),
|
||||
(UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
|
||||
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
|
||||
# ** zero folding **
|
||||
(UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
|
||||
(UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
|
||||
|
||||
Reference in New Issue
Block a user