[pr] move has_valid into pm_lowerer (#8308)

* [pr] move has_valid into pm_lowerer

* simpler
This commit is contained in:
leopf
2024-12-18 18:05:18 +01:00
committed by GitHub
parent 69eb55a529
commit c5ae66215a

View File

@@ -106,12 +106,10 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
def lower_load_store(ctx: IndexContext, x: UOp):
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
# TODO: check has_valid in UPat, not here
has_valid = valid.op is not Ops.CONST or valid.arg is not True
buf = x.src[0]
if x.op is Ops.LOAD:
barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid if has_valid else None),) + barrier)
return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.ASSIGN:
reduce_input = x.src[2].src[1].src[1] if x.src[2].src[1].src[1] is not x.src[2].src[0] else x.src[2].src[1].src[0]
@@ -122,14 +120,14 @@ def lower_load_store(ctx: IndexContext, x: UOp):
if (not cast(PtrDType, x.src[0].dtype).local) or store_back:
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
if oidx is not ridx: valid = valid * oidx.eq(0)
has_valid = valid.op is not Ops.CONST or valid.arg is not True
return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid if has_valid else None), x.src[2]))
return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid), x.src[2]))
pm_lowerer = PatternMatcher([
(UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
(UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
# rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
])
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))