From f2700ac58a090b7dcc89cb0276460d811abbc38a Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 24 Sep 2024 04:12:21 -0400 Subject: [PATCH] construct a candidate set to attempt valid idx rewrite (#6706) preparation for the brute force attempt for some valids --- tinygrad/codegen/uopgraph.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 0f119ede31..fde15df847 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -216,22 +216,24 @@ def simplify_valid_image_load(load:UOp, buf:UOp): if v[0] is not None and v[1] is not None and v[0] > v[1]: return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, valid.const_like(False))) + # every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the idx into the same output, we rewrite idx + candidates = [] if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(uop, BinaryOps.ADD)): # if the constraint is a simplex: X0 + X1 + ... 
> 0, we can check if all Xi > 0 simplify into the same output - to_check = [(Xi, UOp(UOps.DEFINE_VAR, Xi.dtype, (), ("fake", 1, Xi.vmax))) for Xi in _get_chain(uop, BinaryOps.ADD)] - else: - # try checking the whole clause - to_check = [(uop, UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1]))] + candidates.append([(Xi, UOp(UOps.DEFINE_VAR, Xi.dtype, (), ("fake", 1, Xi.vmax))) for Xi in _get_chain(uop, BinaryOps.ADD)]) + # try checking the whole clause + candidates.append([(uop, UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1]))]) - newidxs:List[List[UOp]] = [[], []] - for X,newX in to_check: - newidx = replace_uop(graph_rewrite(replace_uop(idx, X, newX), constant_folder), newX, X) - newidxs[0].append(newidx.src[0]) - newidxs[1].append(newidx.src[1]) + for candidate in candidates: + newidxs:List[List[UOp]] = [[], []] + for X,newX in candidate: + newidx = replace_uop(graph_rewrite(replace_uop(idx, X, newX), constant_folder), newX, X) + newidxs[0].append(newidx.src[0]) + newidxs[1].append(newidx.src[1]) - # if every branch in to_check gives the same simplified output, we can rewrite the idx - if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1])) - if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0])) + # if every branch in candidate gives the same simplified output, we can rewrite the idx + if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1])) + if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0])) # can drop valid if idx is out of bound when valid is False drop_stmt = []