construct a candidate set to attempt valid idx rewrite (#6706)

preparation for the brute force attempt for some valids
This commit is contained in:
chenyu
2024-09-24 04:12:21 -04:00
committed by GitHub
parent 2be0b26a1f
commit f2700ac58a

View File

@@ -216,22 +216,24 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
if v[0] is not None and v[1] is not None and v[0] > v[1]:
return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, valid.const_like(False)))
# every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the idx into a same output, we rewrite idx
candidates = []
if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(uop, BinaryOps.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
to_check = [(Xi, UOp(UOps.DEFINE_VAR, Xi.dtype, (), ("fake", 1, Xi.vmax))) for Xi in _get_chain(uop, BinaryOps.ADD)]
else:
# try checking the whole clause
to_check = [(uop, UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1]))]
candidates.append([(Xi, UOp(UOps.DEFINE_VAR, Xi.dtype, (), ("fake", 1, Xi.vmax))) for Xi in _get_chain(uop, BinaryOps.ADD)])
# try checking the whole clause
candidates.append([(uop, UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1]))])
newidxs:List[List[UOp]] = [[], []]
for X,newX in to_check:
newidx = replace_uop(graph_rewrite(replace_uop(idx, X, newX), constant_folder), newX, X)
newidxs[0].append(newidx.src[0])
newidxs[1].append(newidx.src[1])
for candidate in candidates:
newidxs:List[List[UOp]] = [[], []]
for X,newX in candidate:
newidx = replace_uop(graph_rewrite(replace_uop(idx, X, newX), constant_folder), newX, X)
newidxs[0].append(newidx.src[0])
newidxs[1].append(newidx.src[1])
# if every branch in to_check gives the same simplified output, we can rewrite the idx
if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1]))
if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0]))
# if every branch in candidate gives the same simplified output, we can rewrite the idx
if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1]))
if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0]))
# can drop valid if idx is out of bound when valid is False
drop_stmt = []