mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[ROCM] Fixed tt.atomic_rmw for f16 type. (#222)
* [ROCM] Fixed `tt.atomic_rmw` for f16 type. Fixed a bug where `AtomicRMWOp` failed to process packed f16 operands. * Address comment.
This commit is contained in:
@@ -658,25 +658,22 @@ struct AtomicRMWOpConversion
|
||||
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
|
||||
|
||||
auto vecTy = vec_ty(valueElemTy, vec);
|
||||
auto retType = vec == 1 ? valueElemTy : vecTy;
|
||||
SmallVector<Value> resultVals(elemsPerThread);
|
||||
const bool f16v2 = vec == 2 && valueElemTy.isF16();
|
||||
for (size_t i = 0; i < elemsPerThread; i += vec) {
|
||||
Value rmwVal = undef(vecTy);
|
||||
for (int ii = 0; ii < vec; ++ii) {
|
||||
Value iiVal = createIndexAttrConstant(
|
||||
rewriter, loc, getTypeConverter()->getIndexType(), ii);
|
||||
rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
|
||||
}
|
||||
|
||||
Value rmwPtr = ptrElements[i];
|
||||
// TODO: in case llMask is zero we can create only one branch for all
|
||||
// elemsPerThread.
|
||||
Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
|
||||
|
||||
Value undefVal = undef(valueElemTy);
|
||||
Value undefVal = undef(retType);
|
||||
// Build blocks to bypass the atomic instruction for ~rmwMask.
|
||||
auto *curBlock = rewriter.getInsertionBlock();
|
||||
auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint());
|
||||
auto *atomicBlock = rewriter.createBlock(
|
||||
curBlock->getParent(), std::next(Region::iterator(curBlock)));
|
||||
endBlock->addArgument({valueElemTy}, {loc});
|
||||
endBlock->addArgument({retType}, {loc});
|
||||
|
||||
rewriter.setInsertionPointToEnd(curBlock);
|
||||
rewriter.create<LLVM::CondBrOp>(loc, rmwMask, atomicBlock, endBlock,
|
||||
@@ -684,12 +681,23 @@ struct AtomicRMWOpConversion
|
||||
|
||||
rewriter.setInsertionPointToEnd(atomicBlock);
|
||||
auto maybeKind = matchAtomicOp(atomicRmwAttr);
|
||||
// TODO: use amdgpu.raw_buffer_atomic_fadd for MI-* series of AMD GPU
|
||||
// since it supports memref indexes (after moving triton to use memrefs).
|
||||
auto atom = rewriter.create<LLVM::AtomicRMWOp>(
|
||||
// TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient
|
||||
// atomics for MI-* series of AMD GPU.
|
||||
Value atom = rewriter.create<LLVM::AtomicRMWOp>(
|
||||
loc, *maybeKind, rmwPtr, valElements[i],
|
||||
LLVM::AtomicOrdering::monotonic, StringRef("agent"));
|
||||
rewriter.create<LLVM::BrOp>(loc, atom.getResult(), endBlock);
|
||||
LLVM::AtomicOrdering::monotonic, StringRef("agent")).getResult();
|
||||
|
||||
// NV for the f16v2 case generates one packed instruction. We have to
|
||||
// create two separate instructions since LLVM::AtomicRMWOp doesn't
|
||||
// support this. Can be optimized out with rocdl.raw.buffer.atomic.
|
||||
if (f16v2) {
|
||||
Value atom2 = rewriter.create<LLVM::AtomicRMWOp>(
|
||||
loc, *maybeKind, ptrElements[i+1], valElements[i + 1],
|
||||
LLVM::AtomicOrdering::monotonic, StringRef("agent")).getResult();
|
||||
auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0));
|
||||
atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult();
|
||||
}
|
||||
rewriter.create<LLVM::BrOp>(loc, atom, endBlock);
|
||||
|
||||
rewriter.setInsertionPointToStart(endBlock);
|
||||
Value retVal = endBlock->getArgument(0);
|
||||
|
||||
Reference in New Issue
Block a user