[ROCM] Fixed tt.atomic_rmw for f16 type. (#222)

* [ROCM] Fixed `tt.atomic_rmw` for f16 type.

Fixed bug when `AtomicRMWOp` fails to process packed f16 operands.

* Address comment.
This commit is contained in:
Daniil Fukalov
2023-05-24 22:22:44 +02:00
committed by GitHub
parent 8a83986d0d
commit 1ee82e2a8e
2 changed files with 44 additions and 14 deletions

View File

@@ -658,25 +658,22 @@ struct AtomicRMWOpConversion
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
auto vecTy = vec_ty(valueElemTy, vec);
auto retType = vec == 1 ? valueElemTy : vecTy;
SmallVector<Value> resultVals(elemsPerThread);
const bool f16v2 = vec == 2 && valueElemTy.isF16();
for (size_t i = 0; i < elemsPerThread; i += vec) {
Value rmwVal = undef(vecTy);
for (int ii = 0; ii < vec; ++ii) {
Value iiVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), ii);
rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
}
Value rmwPtr = ptrElements[i];
// TODO: in case llMask is zero we can create only one branch for all
// elemsPerThread.
Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
Value undefVal = undef(valueElemTy);
Value undefVal = undef(retType);
// Build blocks to bypass the atomic instruction for ~rmwMask.
auto *curBlock = rewriter.getInsertionBlock();
auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint());
auto *atomicBlock = rewriter.createBlock(
curBlock->getParent(), std::next(Region::iterator(curBlock)));
endBlock->addArgument({valueElemTy}, {loc});
endBlock->addArgument({retType}, {loc});
rewriter.setInsertionPointToEnd(curBlock);
rewriter.create<LLVM::CondBrOp>(loc, rmwMask, atomicBlock, endBlock,
@@ -684,12 +681,23 @@ struct AtomicRMWOpConversion
rewriter.setInsertionPointToEnd(atomicBlock);
auto maybeKind = matchAtomicOp(atomicRmwAttr);
// TODO: use amdgpu.raw_buffer_atomic_fadd for MI-* series of AMD GPU
// since it supports memref indexes (after moving triton to use memrefs).
auto atom = rewriter.create<LLVM::AtomicRMWOp>(
// TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient
// atomics for MI-* series of AMD GPU.
Value atom = rewriter.create<LLVM::AtomicRMWOp>(
loc, *maybeKind, rmwPtr, valElements[i],
LLVM::AtomicOrdering::monotonic, StringRef("agent"));
rewriter.create<LLVM::BrOp>(loc, atom.getResult(), endBlock);
LLVM::AtomicOrdering::monotonic, StringRef("agent")).getResult();
// NV for the f16v2 case generates one packed instruction. We have to
// create two separate instructions since LLVM::AtomicRMWOp doesn't
// support this. Can be optimized out with rocdl.raw.buffer.atomic.
if (f16v2) {
Value atom2 = rewriter.create<LLVM::AtomicRMWOp>(
loc, *maybeKind, ptrElements[i+1], valElements[i + 1],
LLVM::AtomicOrdering::monotonic, StringRef("agent")).getResult();
auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0));
atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult();
}
rewriter.create<LLVM::BrOp>(loc, atom, endBlock);
rewriter.setInsertionPointToStart(endBlock);
Value retVal = endBlock->getArgument(0);