minor update to uop_given_valid [pr] (#13243)

split from #13241
This commit is contained in:
chenyu
2025-11-12 16:03:18 -08:00
committed by GitHub
parent fe2876a6d8
commit f9851a852f

View File

@@ -430,10 +430,10 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
newuops = [uop.substitute({X:newX}) for X,newX in candidate]
if any(u is uop for u in newuops): continue # if any branch doesnt appear in uop, skip
newuops = [u.simplify().substitute({newX:X}).simplify(full_symbolic=False) for (X,newX),u in zip(candidate,newuops)]
if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
if all_same(newuops): uop = newuops[0]
elif uop.op is Ops.VECTORIZE and len(uop.src) == 2:
if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
elif all_same(newuops): uop = newuops[0]
# try all the valids together (but only the whole expressions)
if (s_uop:=uop.substitute(sub_dict:=dict(all_candidates))) is not uop: