diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index b958b62a45..3b3f185b24 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -15,7 +15,7 @@ def check_uop_against_string(self, v:UOp, s:str): if isinstance(s_eval, int) and v.dtype==dtypes.index: s_eval = UOp.const(dtypes.index, s_eval) elif isinstance(s_eval, (bool, int, float)): s_eval = UOp.const(dtypes.from_py(s_eval), s_eval) s_eval = graph_rewrite(s_eval, commutative, name="cannonicalize eval") - self.assertIs(s_eval, v, f"eval did not match simplified: {s_eval} != {v} for {s}") + self.assertIs(s_eval, v, f"eval did not match simplified: {s_eval} != {v.render()} for {s}") def Variable(name: str, min_val: ConstType, max_val: ConstType, dtype: DType=dtypes.index): return UOp.variable(name,min_val,max_val,dtype) def uconst(val): return UOp.const(dtypes.index, val) @@ -679,6 +679,10 @@ class TestSymbolic(unittest.TestCase): b = Variable("b", 0, 3) c = Variable("c", 0, 3) d = Variable("d", -3, 3) + self.helper_test_variable((a<2), 0, 1, "(a<2)") + self.helper_test_variable((a<=2), 0, 1, "((21), 0, 1, "(1=1), 0, 1, "((a<1)!=True)") self.helper_test_variable((a<1).ne(True), 0, 1, "((a<1)!=True)") self.helper_test_variable((a+b<1).ne(True), 0, 1, "(((a+b)<1)!=True)") self.helper_test_variable((a*3+b*4<1).ne(True), 0, 1, "(((a+b)<1)!=True)") diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 771dfa0467..fc8077147f 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -237,7 +237,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ (UPat.var("x", dtypes.index) 0 - # not x < 1 -> X > 0 + # not x < 1 means x > 0 ((UPat.var("x", dtypes.index)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None), # a range mod its own upper bound is just the range (UPat(Ops.RANGE, src=UPat.var("end"), name="r")%UPat.var("end"), lambda r,end: r), @@ -262,15 +262,17 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ # ******** we take a small aside to "simplify_valid" to rewrite valids ******** -def parse_valid(valid:UOp) -> tuple[UOp, bool, int]|None: +def parse_valid(v:UOp) -> tuple[UOp, bool, int]|None: # if it's X <= c, returns X, True, c # if it's X >= c, returns X, False, c - # (X < c).ne(True) -> X >= c - if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \ - (s0:=valid.src[0]).op is Ops.CMPLT and dtypes.is_int(s0.src[0].dtype): return s0.src[0], False, int(s0.src[1].vmin) - # X < c -> X <= c-1 - if valid.op is Ops.CMPLT and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, int((valid.src[1]).vmax)-1 + if v.op is Ops.CMPNE and v.src[1].op is Ops.CONST and v.src[1].arg == 1 and (s0:=v.src[0]).op is Ops.CMPLT and dtypes.is_int(s0.src[0].dtype): + # (X < c).ne(True) -> X >= c + return s0.src[0], False, int(s0.src[1].vmin) + if v.op is Ops.CMPLT and dtypes.is_int(v.src[0].dtype): + # X < c -> X <= c-1 + return v.src[0], True, int((v.src[1]).vmax)-1 + # NOTE: v.src[1].op can be Ops.VCONST return None def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp: