diff --git a/test/test_const_folding.py b/test/test_const_folding.py index fd1900106d..604a2cc084 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -1,4 +1,4 @@ -import unittest +import unittest, math from tinygrad import Tensor, Device from tinygrad.engine.schedule import create_schedule from tinygrad.features.multi import MultiLazyBuffer @@ -131,5 +131,41 @@ class TestMultiConstFolding(unittest.TestCase): _check_ast_count(0, t ** 1) _check_ast_count(0, 1 ** t) +class TestTautologicalCompare(unittest.TestCase): + # without const folding, these would have triggered -Wtautological-compare in clang + def test_lt_false(self): + # bool < False is always false + np.testing.assert_equal((Tensor([True, False]) < False).numpy(), [False, False]) + + def test_true_lt(self): + # True < bool is always false + np.testing.assert_equal((True < Tensor([True, False])).numpy(), [False, False]) + + def test_truth_table(self): + np.testing.assert_equal((Tensor(False) < Tensor(False)).numpy(), False) + np.testing.assert_equal((Tensor(False) < Tensor(True)).numpy(), True) + np.testing.assert_equal((Tensor(True) < Tensor(False)).numpy(), False) + np.testing.assert_equal((Tensor(True) < Tensor(True)).numpy(), False) + + @unittest.skip("not implemented yet") + def test_a_eq_a(self): + # self eq is always true for int or bool + a = Tensor([1, 2, 3]) + np.testing.assert_equal((a == a).numpy(), [True, True, True]) + + # not true for nan + a = Tensor([math.nan, 1.0, 2.0]) + np.testing.assert_equal((a == a).numpy(), [False, True, True]) + + @unittest.skip("not implemented yet") + def test_a_ne_a(self): + # self not eq is always false for int or bool + a = Tensor([1, 2, 3]) + np.testing.assert_equal((a != a).numpy(), [False, False, False]) + + # not true for nan + a = Tensor([math.nan, 1.0, 2.0]) + np.testing.assert_equal((a != a).numpy(), [True, False, False]) + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 285ea0eeb0..09156c7a87 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -76,6 +76,10 @@ constant_folder = PatternMatcher([ # x+-y -> x-y ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "x"}, {"__name__": "my", "uop": UOps.ALU, "arg": UnaryOps.NEG})}, lambda x, my: UOp(UOps.ALU, x.dtype, (x, my.vin[0]), BinaryOps.SUB)), + # bool < False is always false, True < bool is always false + ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({}, {"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": False})}, lambda x: x), + ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": True}, {})}, + lambda x: UOp.const(dtypes.bool, False)), # a conditional with the same results either way is a noop, also fold const conditionals ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({}, {"__name__": "val"}, {"__name__": "val"})}, lambda val: val), ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"__name__": "gate", "uop": UOps.CONST}, {"__name__": "c0"}, {"__name__": "c1"})},