diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py
index b958b62a45..3b3f185b24 100644
--- a/test/unit/test_uop_symbolic.py
+++ b/test/unit/test_uop_symbolic.py
@@ -15,7 +15,7 @@ def check_uop_against_string(self, v:UOp, s:str):
if isinstance(s_eval, int) and v.dtype==dtypes.index: s_eval = UOp.const(dtypes.index, s_eval)
elif isinstance(s_eval, (bool, int, float)): s_eval = UOp.const(dtypes.from_py(s_eval), s_eval)
s_eval = graph_rewrite(s_eval, commutative, name="cannonicalize eval")
- self.assertIs(s_eval, v, f"eval did not match simplified: {s_eval} != {v} for {s}")
+ self.assertIs(s_eval, v, f"eval did not match simplified: {s_eval} != {v.render()} for {s}")
def Variable(name: str, min_val: ConstType, max_val: ConstType, dtype: DType=dtypes.index): return UOp.variable(name,min_val,max_val,dtype)
def uconst(val): return UOp.const(dtypes.index, val)
@@ -679,6 +679,10 @@ class TestSymbolic(unittest.TestCase):
b = Variable("b", 0, 3)
c = Variable("c", 0, 3)
d = Variable("d", -3, 3)
+ self.helper_test_variable((a<2), 0, 1, "(a<2)")
+ self.helper_test_variable((a<=2), 0, 1, "((21), 0, 1, "(1=1), 0, 1, "((a<1)!=True)")
self.helper_test_variable((a<1).ne(True), 0, 1, "((a<1)!=True)")
self.helper_test_variable((a+b<1).ne(True), 0, 1, "(((a+b)<1)!=True)")
self.helper_test_variable((a*3+b*4<1).ne(True), 0, 1, "(((a+b)<1)!=True)")
diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py
index 771dfa0467..fc8077147f 100644
--- a/tinygrad/uop/symbolic.py
+++ b/tinygrad/uop/symbolic.py
@@ -237,7 +237,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
(UPat.var("x", dtypes.index) 0
- # not x < 1 -> X > 0
+ # not x < 1 means x > 0
((UPat.var("x", dtypes.index)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
# a range mod its own upper bound is just the range
(UPat(Ops.RANGE, src=UPat.var("end"), name="r")%UPat.var("end"), lambda r,end: r),
@@ -262,15 +262,17 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
# ******** we take a small aside to "simplify_valid" to rewrite valids ********
-def parse_valid(valid:UOp) -> tuple[UOp, bool, int]|None:
+def parse_valid(v:UOp) -> tuple[UOp, bool, int]|None:
# if it's X <= c, returns X, True, c
# if it's X >= c, returns X, False, c
- # (X < c).ne(True) -> X >= c
- if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
- (s0:=valid.src[0]).op is Ops.CMPLT and dtypes.is_int(s0.src[0].dtype): return s0.src[0], False, int(s0.src[1].vmin)
- # X < c -> X <= c-1
- if valid.op is Ops.CMPLT and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, int((valid.src[1]).vmax)-1
+ if v.op is Ops.CMPNE and v.src[1].op is Ops.CONST and v.src[1].arg == 1 and (s0:=v.src[0]).op is Ops.CMPLT and dtypes.is_int(s0.src[0].dtype):
+ # (X < c).ne(True) -> X >= c
+ return s0.src[0], False, int(s0.src[1].vmin)
+ if v.op is Ops.CMPLT and dtypes.is_int(v.src[0].dtype):
+ # X < c -> X <= c-1
+ return v.src[0], True, int((v.src[1]).vmax)-1
+ # NOTE: v.src[1].op can be Ops.VCONST
return None
def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp: