From e3105675fbfe1ae3f080fcec4ef2e281312c4d31 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 16 Nov 2024 09:44:17 -0500 Subject: [PATCH] cond.where(True, False) is cond (#7733) --- test/unit/test_uop_symbolic.py | 7 +++++++ tinygrad/ops.py | 1 + 2 files changed, 8 insertions(+) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 6b88f15216..fba3f2f9d6 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -455,6 +455,13 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)")) self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)")) + def test_where_removal(self): + cond = Variable("a", 0, 3).lt(2) + u1, u0 = cond.ufix(1), cond.ufix(0) + self.helper_test_variable(cond, 0, 1, "(a<2)") + self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)") + self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)") + class TestSymbolicNumeric(unittest.TestCase): def helper_test_numeric(self, f): MIN, MAX = 0, 10 diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 33efeebdc8..41d2911a8e 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1034,6 +1034,7 @@ symbolic_simple = PatternMatcher([ ((UPat.var("x") & UPat.var("x")), lambda x: x), ((UPat.var("x") | UPat.var("x")), lambda x: x), (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x), + (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x), # ** zero folding ** (UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),