diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py
index 7d37aa180d..93d65da486 100644
--- a/tinygrad/codegen/late/devectorizer.py
+++ b/tinygrad/codegen/late/devectorizer.py
@@ -10,6 +10,7 @@ from tinygrad.renderer import Renderer
 
 # ***** image load valid simplification *****
 
+@functools.cache
 def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
   # can drop valid if idx is out of bound when valid is False
   drop_stmt = []
@@ -187,12 +188,15 @@ def _do_image_fixup(dt:ImageDType, idx:UOp) -> tuple[UOp, UOp, int, int]:
   buf = idx.src[0]
   x, valid = idx.src[1].get_idx(), idx.src[1].get_valid()
   h, w = dt.shape[0], dt.shape[1]
-  if IMAGE == 1 and valid is not None:
-    h, w = max(ImageDType.valid_dims(dt), key=lambda hw:
-      # maximize number of valids removed
-      (len(_drop_valid_stmts(valid, idx:=uop_given_valid(valid, UOp.vectorize((x//4)%hw[1], x//(4*hw[1])).simplify()), *hw)),
-      # and minimize idx complexity (number of nodes)
-      -len(idx.gep(1).backward_slice)))
+  if IMAGE == 1:
+    # search for dims that drop the most valid statements
+    best_drop, cands = -1, []
+    for ch, cw in ImageDType.valid_dims(dt):
+      if (dropped:=len(_drop_valid_stmts(valid, cidx:=uop_given_valid(valid, UOp.vectorize((x//4)%cw, x//(4*cw))), ch, cw))) > best_drop:
+        best_drop, cands = dropped, [(ch, cw, cidx)]
+      elif dropped == best_drop: cands.append((ch, cw, cidx))
+    # and tiebreak with indexing complexity (ie. number of nodes)
+    h, w, _ = cands[0] if len(cands) == 1 else min(cands, key=lambda cand: len(cand[2].simplify().gep(1).backward_slice))
   buf = buf.replace(dtype=(dtypes.imageh if dt.itemsize == 2 else dtypes.imagef)((h, w, 4), w * 4 * dt.itemsize))
   oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), ((x // 4) % w, (x // (4*w))))
   return x, idx.replace(src=(buf, oidx.valid(valid))), w, h