uops const fold rules to prevent tautological compare warnings (#4041)

* uops const fold rules to prevent tautological compare warnings

`bool < false` is false, `true < bool` is false, `a == a` is true, `a != a` is false

* not true for nan

* and nan does not work with llvm

* full truth table test

* revert a==a

* comments and indents
This commit is contained in:
chenyu
2024-04-02 16:45:58 -04:00
committed by GitHub
parent e879e16c48
commit 85edc493b0
2 changed files with 41 additions and 1 deletions

View File

@@ -76,6 +76,10 @@ constant_folder = PatternMatcher([
# x+-y -> x-y
({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "x"}, {"__name__": "my", "uop": UOps.ALU, "arg": UnaryOps.NEG})},
lambda x, my: UOp(UOps.ALU, x.dtype, (x, my.vin[0]), BinaryOps.SUB)),
# bool < False is always false, True < bool is always false
({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({}, {"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": False})}, lambda x: x),
({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": True}, {})},
lambda x: UOp.const(dtypes.bool, False)),
# a conditional with the same results either way is a noop, also fold const conditionals
({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({}, {"__name__": "val"}, {"__name__": "val"})}, lambda val: val),
({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"__name__": "gate", "uop": UOps.CONST}, {"__name__": "c0"}, {"__name__": "c1"})},