remove unmasked valid after swizzles (#9377)

This commit is contained in:
qazal
2025-03-07 17:43:16 +02:00
committed by GitHub
parent 088d86691b
commit dc89dae994

View File

@@ -359,6 +359,8 @@ fix_kernel_ops = PatternMatcher([
# remove CONTIGUOUS/DEVICE
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
# remove unmasked valid
(UPat.where(UPat(Ops.VALID, name="valid"), UPat.cvar("x"), UPat()), lambda valid,x: x if all(v.mask is None for v in valid.st.views) else None),
# no ImageDType after load
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly