Revert folded_upcast to upcast in the lowerer to find a test case

This commit is contained in:
Mesozoic Egg
2025-01-04 02:09:15 +08:00
parent abdbe09d2f
commit 225bb6e801

View File

@@ -129,7 +129,7 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
def lower_load_store(ctx: IndexContext, x: UOp):
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
idx, valid = folded_upcast(idx), folded_upcast(valid)
idx, valid = upcast(idx), upcast(valid)
buf = x.src[0]
if x.op is Ops.LOAD:
barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
@@ -142,7 +142,7 @@ def lower_load_store(ctx: IndexContext, x: UOp):
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
if store_back:
idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
idx = folded_upcast(idx)
idx = upcast(idx)
if (not cast(PtrDType, x.src[0].dtype).local) or store_back:
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
if oidx is not ridx: valid = valid * oidx.eq(0)
@@ -150,7 +150,7 @@ def lower_load_store(ctx: IndexContext, x: UOp):
pm_lowerer = PatternMatcher([
(UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
(UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: folded_upcast(x.st_arg.to_indexed_uops(ctx.idxs)[1])),
(UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: upcast(x.st_arg.to_indexed_uops(ctx.idxs)[1])),
# rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),