remove outdated index masking in lowerer [pr] (#10953)

* add assert to check idx is never replaced with const 0

* remove outdated index masking
This commit is contained in:
Ignacio Sica
2025-06-24 11:53:30 -03:00
committed by GitHub
parent cc32394b32
commit f15247d2d2

View File

@@ -122,8 +122,6 @@ def lower_load_store(ctx: IndexContext, x: UOp, buf: UOp):
reduce_input = x.src[1].src[0]
store_back = reduce_input.op is Ops.LOAD and cast(PtrDType, reduce_input.src[0].dtype).local
else: store_back = False
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[1].src else u for u in ctx.idxs])
if (not cast(PtrDType, buf.dtype).local) or store_back:
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
if oidx is not ridx: valid = valid * oidx.eq(0)