From e7e70a3c95acd9fc245d491eeb8bfb077c45b766 Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Tue, 3 Mar 2026 20:53:50 -0800 Subject: [PATCH] simplify idx before counting backward_slice (#15117) --- tinygrad/codegen/late/devectorizer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index 230d5e357c..45c7e73546 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -189,7 +189,10 @@ def _do_image_fixup(dt:ImageDType, idx:UOp) -> tuple[UOp, UOp, int, int]: h, w = dt.shape[0], dt.shape[1] if IMAGE == 1 and valid is not None: h, w = max(ImageDType.valid_dims(dt), key=lambda hw: - (len(_drop_valid_stmts(valid, idx:=uop_given_valid(valid, UOp.vectorize((x//4)%hw[1], x//(4*hw[1]))), *hw)), -len(idx.backward_slice))) + # maximize number of valids removed + (len(_drop_valid_stmts(valid, idx:=uop_given_valid(valid, UOp.vectorize((x//4)%hw[1], x//(4*hw[1]))), *hw)), + # and minimize idx complexity (number of nodes) + -len(idx.simplify().backward_slice))) buf = buf.replace(dtype=(dtypes.imageh if dt.itemsize == 2 else dtypes.imagef)((h, w, 4), w * 4 * dt.itemsize)) oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), ((x // 4) % w, (x // (4*w)))) return x, idx.replace(src=(buf, oidx.valid(valid))), w, h