mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Updated predicate for atomic ops (#1619)
This commit is contained in:
@@ -399,13 +399,13 @@ struct AtomicCASOpConversion
|
||||
auto valElements = getTypeConverter()->unpackLLElements(
|
||||
loc, llVal, rewriter, op.getVal().getType());
|
||||
|
||||
auto TensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
auto valueTy = op.getResult().getType();
|
||||
auto TensorTy = valueTy.dyn_cast<RankedTensorType>();
|
||||
Type valueElemTy =
|
||||
TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
|
||||
: op.getResult().getType();
|
||||
: valueTy;
|
||||
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
|
||||
auto tid = tid_val();
|
||||
Value pred = icmp_eq(tid, i32_val(0));
|
||||
Value mask = getMask(valueTy, rewriter, loc);
|
||||
PTXBuilder ptxBuilderMemfence;
|
||||
auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
|
||||
memfence();
|
||||
@@ -425,7 +425,7 @@ struct AtomicCASOpConversion
|
||||
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
|
||||
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
|
||||
atom.global().o("cas").o("b32");
|
||||
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(pred);
|
||||
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
|
||||
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
|
||||
barrier();
|
||||
|
||||
@@ -434,7 +434,7 @@ struct AtomicCASOpConversion
|
||||
auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
|
||||
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
|
||||
st.shared().o("b32");
|
||||
st(dstOprStore, valOprStore).predicate(pred);
|
||||
st(dstOprStore, valOprStore).predicate(mask);
|
||||
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
|
||||
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
|
||||
barrier();
|
||||
@@ -483,10 +483,11 @@ struct AtomicRMWOpConversion
|
||||
maskElements = getTypeConverter()->unpackLLElements(
|
||||
loc, llMask, rewriter, op.getMask().getType());
|
||||
|
||||
auto tensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
auto valueTy = op.getResult().getType();
|
||||
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
|
||||
Type valueElemTy =
|
||||
tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
|
||||
: op.getResult().getType();
|
||||
: valueTy;
|
||||
const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
|
||||
auto elemsPerThread = getTotalElemsPerThread(val.getType());
|
||||
// vec = 1, numElements = 1 for scalar
|
||||
@@ -499,10 +500,7 @@ struct AtomicRMWOpConversion
|
||||
// mask
|
||||
numElems = tensorTy.getNumElements();
|
||||
}
|
||||
Value mask = int_val(1, 1);
|
||||
auto tid = tid_val();
|
||||
mask = and_(mask,
|
||||
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
|
||||
Value mask = getMask(valueTy, rewriter, loc);
|
||||
|
||||
auto vecTy = vec_ty(valueElemTy, vec);
|
||||
SmallVector<Value> resultVals(elemsPerThread);
|
||||
@@ -582,7 +580,6 @@ struct AtomicRMWOpConversion
|
||||
memfenc();
|
||||
auto ASMReturnTy = void_ty(ctx);
|
||||
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
|
||||
rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
|
||||
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
|
||||
auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
|
||||
Reference in New Issue
Block a user