ROCM IFU: fix AtomicCASOpConversion segfault

This commit is contained in:
Michael Melesse
2023-12-12 17:40:31 -06:00
parent a42ac260aa
commit 6efc013e46
3 changed files with 112 additions and 46 deletions

View File

@@ -981,12 +981,16 @@ private:
namespace {
/// Emit the synchronization appropriate for the current target backend.
///
/// On ROCm there is no CTA-cluster concept, so a plain block barrier is
/// always sufficient. On the NVIDIA path, a single-CTA launch also only
/// needs a block barrier, while multi-CTA launches must synchronize the
/// whole cluster via ClusterArrive/ClusterWait.
void createBarrier(ConversionPatternRewriter &rewriter, Location loc,
                   int numCTAs) {
#ifndef USE_ROCM
  if (numCTAs == 1) {
    barrier();
  } else {
    // Arrive (non-relaxed) then wait so every CTA in the cluster syncs.
    rewriter.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
    rewriter.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
  }
#else
  // ROCm: cluster ops do not exist; fall back to a block barrier.
  barrier();
#endif
}
} // namespace
@@ -1008,6 +1012,7 @@ struct AtomicCASOpConversion
LogicalResult
matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// extract relevant info from Module
auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();
Value ptr = op.getPtr();
@@ -1016,6 +1021,7 @@ struct AtomicCASOpConversion
Value llCmp = adaptor.getCmp();
Value llVal = adaptor.getVal();
// prep data by unpacking to get data ready
auto ptrElements = getTypeConverter()->unpackLLElements(
loc, llPtr, rewriter, op.getPtr().getType());
auto cmpElements = getTypeConverter()->unpackLLElements(
@@ -1023,53 +1029,106 @@ struct AtomicCASOpConversion
auto valElements = getTypeConverter()->unpackLLElements(
loc, llVal, rewriter, op.getVal().getType());
auto TensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
// deal with tensor or scalar
auto valueTy = op.getResult().getType();
auto TensorTy = valueTy.dyn_cast<RankedTensorType>();
Type valueElemTy =
TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
: op.getResult().getType();
auto tid = tid_val();
Value pred = icmp_eq(tid, i32_val(0));
: valueTy;
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType());
// vec = 1 for scalar
auto vec = getVectorSize(op.getPtr());
// tensor
if (TensorTy) {
auto valTy = op.getVal().getType().cast<RankedTensorType>();
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
}
Value casPtr = ptrElements[0];
Value casCmp = cmpElements[0];
Value casVal = valElements[0];
Value mask = getMask(valueTy, rewriter, loc);
auto vecTy = vec_ty(valueElemTy, vec);
SmallVector<Value> resultVals(elemsPerThread);
// Build blocks to bypass the atomic instruction for ~rmwMask.
auto *curBlock = rewriter.getInsertionBlock();
auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint());
auto *atomicBlock = rewriter.createBlock(
curBlock->getParent(), std::next(Region::iterator(curBlock)));
// atomic ops
for (size_t i = 0; i < elemsPerThread; i += vec) {
Value casVal = undef(vecTy);
for (int ii = 0; ii < vec; ++ii) {
Value iiVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), ii);
casVal = insert_element(vecTy, casVal, valElements[i + ii], iiVal);
}
// Fill entry block with global memory barrier and conditional branch.
rewriter.setInsertionPointToEnd(curBlock);
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
rewriter.create<LLVM::CondBrOp>(loc, pred, atomicBlock, endBlock);
Value casPtr = ptrElements[i];
Value casCmp = cmpElements[i];
casVal = valElements[i];
// Build main block with atomic_cmpxchg.
rewriter.setInsertionPointToEnd(atomicBlock);
// use op
if (TensorTy) { // for tensor
auto retType = vec == 1 ? valueElemTy : vecTy;
// TODO: USE ATOMIC CAS OP on Tensor
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
loc, casPtr, casCmp, casVal, successOrdering, failureOrdering,
StringRef("agent"));
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
loc, casPtr, casCmp, casVal, successOrdering,
failureOrdering, StringRef("agent"));
// Extract the new_loaded value from the pair.
Value newLoaded = extract_val(valueElemTy, cmpxchg, 0);
// Extract the new_loaded value from the pair.
Value ret = extract_val(valueElemTy, cmpxchg, i);
for (int ii = 0; ii < vec; ++ii) {
resultVals[i + ii] =
vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));
}
} else { // for scalar
// Build blocks to bypass the atomic instruction for ~rmwMask.
auto *curBlock = rewriter.getInsertionBlock();
auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint());
auto *atomicBlock = rewriter.createBlock(
curBlock->getParent(), std::next(Region::iterator(curBlock)));
store(newLoaded, atomPtr);
// Fill entry block with global memory barrier and conditional branch.
rewriter.setInsertionPointToEnd(curBlock);
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
auto tid = tid_val();
Value pred = icmp_eq(tid, i32_val(i));
rewriter.create<LLVM::CondBrOp>(loc, pred, atomicBlock, endBlock);
rewriter.create<LLVM::BrOp>(loc, ValueRange(), endBlock);
// Build main block with atomic_cmpxchg.
rewriter.setInsertionPointToEnd(atomicBlock);
// Build the last block: synced load from shared memory, exit.
rewriter.setInsertionPointToStart(endBlock);
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
loc, casPtr, casCmp, casVal, successOrdering, failureOrdering,
StringRef("agent"));
GCNBuilder BuilderMemfenceLDS;
BuilderMemfenceLDS.create<>("s_waitcnt lgkmcnt(0)")->operator()();
BuilderMemfenceLDS.launch(rewriter, loc, void_ty(ctx));
barrier();
Value ret = load(atomPtr);
rewriter.replaceOp(op, {ret});
// Extract the new_loaded value from the pair.
Value newLoaded = extract_val(valueElemTy, cmpxchg, 0);
store(newLoaded, atomPtr);
rewriter.create<LLVM::BrOp>(loc, ValueRange(), endBlock);
// Build the last block: synced load from shared memory, exit.
rewriter.setInsertionPointToStart(endBlock);
GCNBuilder BuilderMemfenceLDS;
BuilderMemfenceLDS.create<>("s_waitcnt lgkmcnt(0)")->operator()();
BuilderMemfenceLDS.launch(rewriter, loc, void_ty(ctx));
barrier();
Value ret = load(atomPtr);
rewriter.replaceOp(op, {ret});
}
}
// replace op
if (TensorTy) {
Type structTy = getTypeConverter()->convertType(TensorTy);
Value resultStruct = getTypeConverter()->packLLElements(
loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, {resultStruct});
}
return success();
}