mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND][BACKEND] Add acquire/release semantics for atomics (#1739)
This commit is contained in:
@@ -406,11 +406,6 @@ struct AtomicCASOpConversion
|
||||
: valueTy;
|
||||
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
|
||||
Value mask = getMask(valueTy, rewriter, loc);
|
||||
PTXBuilder ptxBuilderMemfence;
|
||||
auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
|
||||
memfence();
|
||||
auto ASMReturnTy = void_ty(ctx);
|
||||
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
|
||||
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
|
||||
@@ -424,7 +419,10 @@ struct AtomicCASOpConversion
|
||||
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
|
||||
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
|
||||
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
|
||||
atom.global().o("cas").o("b32");
|
||||
std::string semStr;
|
||||
llvm::raw_string_ostream os(semStr);
|
||||
os << op.getSem();
|
||||
atom.global().o(semStr).o("cas").o("b32");
|
||||
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
|
||||
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
|
||||
barrier();
|
||||
@@ -435,8 +433,8 @@ struct AtomicCASOpConversion
|
||||
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
|
||||
st.shared().o("b32");
|
||||
st(dstOprStore, valOprStore).predicate(mask);
|
||||
auto ASMReturnTy = void_ty(ctx);
|
||||
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
|
||||
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
|
||||
barrier();
|
||||
Value ret = load(atomPtr);
|
||||
barrier();
|
||||
@@ -464,7 +462,7 @@ struct AtomicRMWOpConversion
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
|
||||
//
|
||||
auto atomicRmwAttr = op.getAtomicRmwOp();
|
||||
|
||||
Value val = op.getVal();
|
||||
@@ -565,7 +563,10 @@ struct AtomicRMWOpConversion
|
||||
default:
|
||||
return failure();
|
||||
}
|
||||
atom.o(rmwOp).o(sTy);
|
||||
std::string semStr;
|
||||
llvm::raw_string_ostream os(semStr);
|
||||
os << op.getSem();
|
||||
atom.o(semStr).o(rmwOp).o(sTy);
|
||||
if (tensorTy) {
|
||||
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
|
||||
auto retType = vec == 1 ? valueElemTy : vecTy;
|
||||
@@ -575,11 +576,7 @@ struct AtomicRMWOpConversion
|
||||
vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));
|
||||
}
|
||||
} else {
|
||||
PTXBuilder ptxBuilderMemfence;
|
||||
auto memfenc = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
|
||||
memfenc();
|
||||
auto ASMReturnTy = void_ty(ctx);
|
||||
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
|
||||
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
|
||||
auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
|
||||
if (op->user_begin() == op->user_end()) {
|
||||
|
||||
Reference in New Issue
Block a user