diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index da920e2da8..bad5eaa797 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -407,14 +407,10 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp: expr, is_upper, c = res bounds[expr][int(is_upper)] = c - # don't simplify any other gates, can lead to OOB, we substitute them back later - uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, dtype=u.dtype, arg=u) for u in uop.toposort() if u.op is Ops.INDEX})) - # simplify uop given that valid is True all_candidates = [] for i,(expr,v) in enumerate(bounds.items()): v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1]) - expr = expr.substitute(load_subs) # make sure expr appears in same form in the uop # try checking the whole clause all_candidates.append((expr, UOp.variable(f"fake{i}", v0, v1, expr.dtype))) @@ -438,8 +434,6 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp: # try all the valids together (but only the whole expressions) if (s_uop:=uop.substitute(sub_dict:=dict(all_candidates))) is not uop: uop = s_uop.simplify().substitute({newX:X for X,newX in sub_dict.items()}).simplify(full_symbolic=False) - # put the loads back in - uop = uop.substitute({v:k for k,v in load_subs.items()}) return uop def _valid_priority(v: UOp, valids:list[UOp]):