diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index 7a27cc83ca..02af8567c5 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -104,7 +104,9 @@ class TestValidIdxSimplification(unittest.TestCase): def test_simplify_valid_from_div(self): x = Variable("x", -100, 100) valid = ((x<0)&((100%x).cast(dtypes.bool))) - self.assertIsNone(simplify_valid(valid)) + # NOTE: this simplifies the (100%x) part somehow, still has two clauses + self.assertIsNotNone(simplify_valid(valid)) + self.assertEqual(len(list(valid.split_uop(Ops.AND))), 2) @unittest.expectedFailure # TODO: fix def test_from_merge_views(self): diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 7da4767712..11e04b2dce 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -459,8 +459,6 @@ def simplify_valid(valid:UOp) -> UOp|None: something_changed = False valids = list(valid.split_uop(Ops.AND)) for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)): - # TODO: root cause this and test_simplify_valid_from_div - if stmt.op is Ops.CAST: return None ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt) if ret[-1] is not stmt: something_changed = True return functools.reduce(operator.and_, ret) if something_changed else None