Mirror of https://github.com/ROCm/ROCm.git (synced 2026-04-05 03:01:17 -04:00)
ROCM IFU: fix AtomicCASOpConversion segfault
@@ -981,12 +981,16 @@ private:
 namespace {
 void createBarrier(ConversionPatternRewriter &rewriter, Location loc,
                    int numCTAs) {
+#ifdef USE_ROCM
+  barrier();
+#else
   if (numCTAs == 1) {
     barrier();
   } else {
     rewriter.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
     rewriter.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
   }
+#endif
 }
 } // namespace

@@ -1008,6 +1012,7 @@ struct AtomicCASOpConversion
   LogicalResult
   matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // extract relevant info from Module
     auto loc = op.getLoc();
     MLIRContext *ctx = rewriter.getContext();
     Value ptr = op.getPtr();
@@ -1016,6 +1021,7 @@ struct AtomicCASOpConversion
     Value llCmp = adaptor.getCmp();
     Value llVal = adaptor.getVal();

+    // prep data by unpacking to get data ready
     auto ptrElements = getTypeConverter()->unpackLLElements(
         loc, llPtr, rewriter, op.getPtr().getType());
     auto cmpElements = getTypeConverter()->unpackLLElements(
@@ -1023,53 +1029,106 @@ struct AtomicCASOpConversion
     auto valElements = getTypeConverter()->unpackLLElements(
         loc, llVal, rewriter, op.getVal().getType());

-    auto TensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    // deal with tensor or scalar
+    auto valueTy = op.getResult().getType();
+    auto TensorTy = valueTy.dyn_cast<RankedTensorType>();
     Type valueElemTy =
         TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
-                 : op.getResult().getType();
-    auto tid = tid_val();
-    Value pred = icmp_eq(tid, i32_val(0));
-
-    Value casPtr = ptrElements[0];
-    Value casCmp = cmpElements[0];
-    Value casVal = valElements[0];
-
-    // Build blocks to bypass the atomic instruction for ~rmwMask.
-    auto *curBlock = rewriter.getInsertionBlock();
-    auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint());
-    auto *atomicBlock = rewriter.createBlock(
-        curBlock->getParent(), std::next(Region::iterator(curBlock)));
-
-    // Fill entry block with global memory barrier and conditional branch.
-    rewriter.setInsertionPointToEnd(curBlock);
-    Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
-    atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
-    rewriter.create<LLVM::CondBrOp>(loc, pred, atomicBlock, endBlock);
-
-    // Build main block with atomic_cmpxchg.
-    rewriter.setInsertionPointToEnd(atomicBlock);
-
-    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
-    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
-    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
-        loc, casPtr, casCmp, casVal, successOrdering,
-        failureOrdering, StringRef("agent"));
-    // Extract the new_loaded value from the pair.
-    Value newLoaded = extract_val(valueElemTy, cmpxchg, 0);
-
-    store(newLoaded, atomPtr);
-
-    rewriter.create<LLVM::BrOp>(loc, ValueRange(), endBlock);
-
-    // Build the last block: synced load from shared memory, exit.
-    rewriter.setInsertionPointToStart(endBlock);
-
-    GCNBuilder BuilderMemfenceLDS;
-    BuilderMemfenceLDS.create<>("s_waitcnt lgkmcnt(0)")->operator()();
-    BuilderMemfenceLDS.launch(rewriter, loc, void_ty(ctx));
-    barrier();
-    Value ret = load(atomPtr);
-    rewriter.replaceOp(op, {ret});
+                 : valueTy;
+    auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
+    auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType());
+    // vec = 1 for scalar
+    auto vec = getVectorSize(op.getPtr());
+    // tensor
+    if (TensorTy) {
+      auto valTy = op.getVal().getType().cast<RankedTensorType>();
+      vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
+    }
+
+    Value mask = getMask(valueTy, rewriter, loc);
+    auto vecTy = vec_ty(valueElemTy, vec);
+    SmallVector<Value> resultVals(elemsPerThread);
+
+    // atomic ops
+    for (size_t i = 0; i < elemsPerThread; i += vec) {
+      Value casVal = undef(vecTy);
+      for (int ii = 0; ii < vec; ++ii) {
+        Value iiVal = createIndexAttrConstant(
+            rewriter, loc, getTypeConverter()->getIndexType(), ii);
+        casVal = insert_element(vecTy, casVal, valElements[i + ii], iiVal);
+      }
+
+      Value casPtr = ptrElements[i];
+      Value casCmp = cmpElements[i];
+      casVal = valElements[i];
+
+      // use op
+      if (TensorTy) { // for tensor
+        auto retType = vec == 1 ? valueElemTy : vecTy;
+        // TODO: USE ATOMIC CAS OP on Tensor
+        auto successOrdering = LLVM::AtomicOrdering::acq_rel;
+        auto failureOrdering = LLVM::AtomicOrdering::monotonic;
+        auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
+            loc, casPtr, casCmp, casVal, successOrdering, failureOrdering,
+            StringRef("agent"));
+
+        // Extract the new_loaded value from the pair.
+        Value ret = extract_val(valueElemTy, cmpxchg, i);
+
+        for (int ii = 0; ii < vec; ++ii) {
+          resultVals[i + ii] =
+              vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));
+        }
+      } else { // for scalar
+        // Build blocks to bypass the atomic instruction for ~rmwMask.
+        auto *curBlock = rewriter.getInsertionBlock();
+        auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint());
+        auto *atomicBlock = rewriter.createBlock(
+            curBlock->getParent(), std::next(Region::iterator(curBlock)));
+
+        // Fill entry block with global memory barrier and conditional branch.
+        rewriter.setInsertionPointToEnd(curBlock);
+        Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
+        atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
+        auto tid = tid_val();
+        Value pred = icmp_eq(tid, i32_val(i));
+        rewriter.create<LLVM::CondBrOp>(loc, pred, atomicBlock, endBlock);
+
+        // Build main block with atomic_cmpxchg.
+        rewriter.setInsertionPointToEnd(atomicBlock);
+
+        auto successOrdering = LLVM::AtomicOrdering::acq_rel;
+        auto failureOrdering = LLVM::AtomicOrdering::monotonic;
+        auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
+            loc, casPtr, casCmp, casVal, successOrdering, failureOrdering,
+            StringRef("agent"));
+
+        // Extract the new_loaded value from the pair.
+        Value newLoaded = extract_val(valueElemTy, cmpxchg, 0);
+
+        store(newLoaded, atomPtr);
+
+        rewriter.create<LLVM::BrOp>(loc, ValueRange(), endBlock);
+
+        // Build the last block: synced load from shared memory, exit.
+        rewriter.setInsertionPointToStart(endBlock);
+
+        GCNBuilder BuilderMemfenceLDS;
+        BuilderMemfenceLDS.create<>("s_waitcnt lgkmcnt(0)")->operator()();
+        BuilderMemfenceLDS.launch(rewriter, loc, void_ty(ctx));
+        barrier();
+        Value ret = load(atomPtr);
+        rewriter.replaceOp(op, {ret});
+      }
+    }
+
+    // replace op
+    if (TensorTy) {
+      Type structTy = getTypeConverter()->convertType(TensorTy);
+      Value resultStruct = getTypeConverter()->packLLElements(
+          loc, resultVals, rewriter, structTy);
+      rewriter.replaceOp(op, {resultStruct});
+    }
     return success();
   }
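Also for orientation, a hypothetical HIP-style sketch (kernel and names are illustrative, not from this repository) of the control flow the scalar else-branch builds: one lane issues the compare-and-swap, stores the old value to LDS, and after a barrier every thread in the workgroup reads the broadcast result. The generated IR additionally waits on LDS traffic with s_waitcnt lgkmcnt(0) before the barrier, which a source-level __syncthreads() already guarantees.

#include <hip/hip_runtime.h>
#include <cstdio>

// Illustrative only: single-lane atomic CAS with the old value broadcast
// to the rest of the block through shared memory (LDS).
__global__ void scalar_cas_broadcast(unsigned *ptr, unsigned cmp, unsigned val,
                                     unsigned *out) {
  __shared__ unsigned slot;
  if (threadIdx.x == 0) {
    // Mirrors the pred-guarded atomic block (pred = icmp_eq(tid, i32_val(i))).
    slot = atomicCAS(ptr, cmp, val); // returns the old value at *ptr
  }
  __syncthreads();           // corresponds to the barrier() after the LDS store
  out[threadIdx.x] = slot;   // every thread reads the broadcast old value
}

int main() {
  unsigned host_word = 42, host_out[64] = {};
  unsigned *d_word = nullptr, *d_out = nullptr;
  hipMalloc((void **)&d_word, sizeof(unsigned));
  hipMalloc((void **)&d_out, sizeof(host_out));
  hipMemcpy(d_word, &host_word, sizeof(unsigned), hipMemcpyHostToDevice);
  scalar_cas_broadcast<<<1, 64>>>(d_word, 42u, 7u, d_out);
  hipMemcpy(host_out, d_out, sizeof(host_out), hipMemcpyDeviceToHost);
  // Every lane should report the same old value (42) observed by the CAS.
  std::printf("lane 0 saw old value %u\n", host_out[0]);
  hipFree(d_word);
  hipFree(d_out);
  return 0;
}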