From 0d326a48b81def467bf4c7bd163d05afbee40890 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 20 Feb 2024 09:06:55 -0500 Subject: [PATCH] fix LtNode simplification when lhs and rhs contain same variables (#3451) * fix LtNode simplification when lhs and rhs contain same variables `(Variable("a", 1, 5) < Variable("a", 1, 5))` should eval to `NumNode(0)` * fix with less perf impact --- test/unit/test_symbolic.py | 8 +++++--- tinygrad/shape/symbolic.py | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index ce5d26619a..9c54a52c6b 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -413,9 +413,11 @@ class TestSymbolicSymbolicOps(unittest.TestCase): c = Variable("c", 1, 10) d = Variable("d", 5, 10) # if the value is always the same, it folds to num - assert (a < b) == 1 - assert (b < a) == 0 - assert (d < a) == 0 + assert (a < b) == NumNode(1) + assert (b < a) == NumNode(0) + assert (d < a) == NumNode(0) + assert (a < a) == NumNode(0) + assert (a > a) == NumNode(0) # if it remains as a LtNode, bool is always true and (min, max) == (0, 1) assert isinstance((a < c), LtNode) and (a < c).min == 0 and (a < c).max == 1 assert a < c diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 044772e192..92f1b53677 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -168,6 +168,7 @@ class OpNode(Node): class LtNode(OpNode): def get_bounds(self) -> Tuple[int, int]: + if self.a == self.b: return (0, 0) if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1) return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1) def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: