no load substitute in uop_given_valid [pr] (#13333)

This commit is contained in:
chenyu
2025-11-18 11:47:58 -05:00
committed by GitHub
parent 05294bc648
commit 805de27e07

View File

@@ -407,14 +407,10 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
expr, is_upper, c = res
bounds[expr][int(is_upper)] = c
# don't simplify any other gates, can lead to OOB, we substitute them back later
uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, dtype=u.dtype, arg=u) for u in uop.toposort() if u.op is Ops.INDEX}))
# simplify uop given that valid is True
all_candidates = []
for i,(expr,v) in enumerate(bounds.items()):
v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
expr = expr.substitute(load_subs) # make sure expr appears in same form in the uop
# try checking the whole clause
all_candidates.append((expr, UOp.variable(f"fake{i}", v0, v1, expr.dtype)))
@@ -438,8 +434,6 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
# try all the valids together (but only the whole expressions)
if (s_uop:=uop.substitute(sub_dict:=dict(all_candidates))) is not uop:
uop = s_uop.simplify().substitute({newX:X for X,newX in sub_dict.items()}).simplify(full_symbolic=False)
# put the loads back in
uop = uop.substitute({v:k for k,v in load_subs.items()})
return uop
def _valid_priority(v: UOp, valids:list[UOp]):