diff --git a/test/test_linearizer.py b/test/test_linearizer.py index db11c260b9..c15ad8c2e2 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -96,7 +96,6 @@ class TestLinearizer(unittest.TestCase): ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE] assert len(ranges) == 1 # NOTE: it collapses now - @unittest.expectedFailure # TODO: investigate def test_two_nested_range_alt_indexing(self): a = Tensor([2, 2]).realize() out = a.reshape(2, 1).pad(((1, 1), (1, 1)), value=2).sum() diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 212813eae5..bdefa8e069 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -3,7 +3,7 @@ import math, operator, struct, functools from collections import defaultdict from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu from tinygrad.dtype import ConstType, dtypes, PtrDType, can_safe_cast, Invalid -from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, dedup +from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup from tinygrad.uop.decompositions import xpow from tinygrad.uop.divandmod import div_and_mod_symbolic @@ -363,11 +363,15 @@ pm_move_where_on_load = PatternMatcher([ (UPat.var("c1").where(0, UPat.var("buf").index(UPat.var("x"))), lambda c1,buf,x: where_on_load(c1.logical_not(),buf,x)), ]) +def gated_given_valid(cond:UOp, x:UOp, i:UOp) -> UOp|None: + # Skip if x contains DIV/MOD AND IMAGE mode is enabled -> image index e.g. openpilot + if IMAGE.value > 0 and x.op_in_backward_slice_with_self(Ops.IDIV, Ops.MOD): return None + return cond.where(uop_given_valid(cond, x, try_simplex=False), i) + pm_simplify_valid = PatternMatcher([ # simplify valid (UPat(Ops.AND, name="valid"), simplify_valid), - # TODO: this regressed openpilot, not having this regressed cifar - # (invalid_gate, lambda cond,x,i: cond.where(uop_given_valid(cond, x, try_simplex=False), i)), + (invalid_gate, gated_given_valid), ]) # this is symbolic 2.0