[FRONTEND][BACKEND] Add acquire/release semantics for atomics (#1739)

This commit is contained in:
Philippe Tillet
2023-06-05 19:09:13 -07:00
committed by GitHub
parent 9c8d7c18b3
commit c52a91231a
9 changed files with 141 additions and 73 deletions

View File

@@ -406,11 +406,6 @@ struct AtomicCASOpConversion
: valueTy;
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
Value mask = getMask(valueTy, rewriter, loc);
PTXBuilder ptxBuilderMemfence;
auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
memfence();
auto ASMReturnTy = void_ty(ctx);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
@@ -424,7 +419,10 @@ struct AtomicCASOpConversion
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
atom.global().o("cas").o("b32");
std::string semStr;
llvm::raw_string_ostream os(semStr);
os << op.getSem();
atom.global().o(semStr).o("cas").o("b32");
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
barrier();
@@ -435,8 +433,8 @@ struct AtomicCASOpConversion
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
st.shared().o("b32");
st(dstOprStore, valOprStore).predicate(mask);
auto ASMReturnTy = void_ty(ctx);
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
barrier();
Value ret = load(atomPtr);
barrier();
@@ -464,7 +462,7 @@ struct AtomicRMWOpConversion
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();
//
auto atomicRmwAttr = op.getAtomicRmwOp();
Value val = op.getVal();
@@ -565,7 +563,10 @@ struct AtomicRMWOpConversion
default:
return failure();
}
atom.o(rmwOp).o(sTy);
std::string semStr;
llvm::raw_string_ostream os(semStr);
os << op.getSem();
atom.o(semStr).o(rmwOp).o(sTy);
if (tensorTy) {
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto retType = vec == 1 ? valueElemTy : vecTy;
@@ -575,11 +576,7 @@ struct AtomicRMWOpConversion
vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));
}
} else {
PTXBuilder ptxBuilderMemfence;
auto memfenc = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
memfenc();
auto ASMReturnTy = void_ty(ctx);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
if (op->user_begin() == op->user_end()) {