mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
cond.where(True, False) is cond (#7733)
This commit is contained in:
@@ -455,6 +455,13 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
|
||||
self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
|
||||
|
||||
def test_where_removal(self):
|
||||
cond = Variable("a", 0, 3).lt(2)
|
||||
u1, u0 = cond.ufix(1), cond.ufix(0)
|
||||
self.helper_test_variable(cond, 0, 1, "(a<2)")
|
||||
self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)")
|
||||
self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)")
|
||||
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
MIN, MAX = 0, 10
|
||||
|
||||
Reference in New Issue
Block a user