[FA OPTIMIZATION] Keep results of FA dot operations in registers (#247)
* [WIP][FA OPTIMIZATION] Optimize chain dot

  This commit optimizes the chain dot operation by keeping the results of the
  first dot operation in registers.

* [FA OPTIMIZATION] Enable the lowering pipeline for keeping the result of a
  chain dot in registers
* Move operand swapping into the ttgir -> llir lowering phase
* Refactor the emitMfmaOffsetForCTA function to be more readable
* Fix an accidental change in 06-fused-attention.py
* Address review comments
* Fix rebase errors
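For context: the "chain dot" here is the flash-attention pattern in which the result of the first dot (Q x K^T) directly feeds the second dot against V. Below is a minimal plain-C++ sketch of that data flow (illustrative only; the names, tile sizes, and scalar loops are stand-ins for the real Triton kernel), with the intermediate tile held in local variables rather than staged through shared memory, which is what this change enables for MFMA results:

    #include <array>

    // Illustrative tile sizes only; real FA tiles are much larger.
    constexpr int M = 4, K = 4, N = 4;

    // Chain dot: acc = (Q x K^T) x V. The intermediate tile `qk` lives in a
    // local array ("registers"), never in a shared-memory scratch buffer --
    // that is the shortcut this commit enables for MFMA dot results.
    std::array<float, M * N> chainDot(const float *q, const float *k,
                                      const float *v) {
      float qk[M][K] = {}; // first dot result, kept in registers
      for (int m = 0; m < M; ++m)
        for (int kk = 0; kk < K; ++kk)
          for (int d = 0; d < K; ++d)
            qk[m][kk] += q[m * K + d] * k[kk * K + d];

      std::array<float, M * N> acc{}; // second dot consumes qk directly
      for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
          for (int kk = 0; kk < K; ++kk)
            acc[m * N + n] += qk[m][kk] * v[kk * N + n];
      return acc;
    }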
@@ -87,6 +87,12 @@ public:
         dstLayout.isa<DotOperandEncodingAttr>()) {
       return lowerMmaToDotOperand(op, adaptor, rewriter);
     }
+#ifdef USE_ROCM
+    if (srcLayout.isa<MfmaEncodingAttr>() &&
+        dstLayout.isa<DotOperandEncodingAttr>()) {
+      return lowerMfmaToDotOperand(op, adaptor, rewriter);
+    }
+#endif
     if (srcLayout.isa<SharedEncodingAttr>() &&
         isaDistributedLayout(dstLayout)) {
       return lowerSharedToDistributed(op, adaptor, rewriter);
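The new branch only pays off when the conversion is a register-level shortcut; the actual check lives in isMfmaToDotShortcut, which is called in the last hunk below rather than here. A toy sketch of what such a predicate plausibly verifies -- the struct shapes and field names below are assumptions for illustration, not the real MLIR encoding attributes:

    // Stand-in types; the real code uses MLIR encoding attributes.
    struct MfmaEncoding {
      unsigned warpsPerCTA[2];
    };
    struct DotOperandEncoding {
      unsigned opIdx;             // 0 = A operand, 1 = B operand
      const MfmaEncoding *parent; // the MMA layout this operand feeds
    };

    // Assumed shape of the check: the conversion is a pure register
    // reshuffle only when the dot operand's parent layout is the very
    // MFMA layout that produced the source tensor.
    bool isMfmaToDotShortcut(const MfmaEncoding &src,
                             const DotOperandEncoding &dst) {
      return dst.parent == &src && dst.opIdx == 0;
    }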
@@ -205,48 +211,13 @@ private:
     }
 #ifdef USE_ROCM
     if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
-      SmallVector<Value> mfmaColIdx(4);
-      SmallVector<Value> mfmaRowIdx(16);
-      Value threadId = getThreadId(rewriter, loc);
-      unsigned iWaveSize = triton::gpu::getWarpSize(layout);
-      Value warpSize = i32_val(iWaveSize);
-      Value laneId = urem(threadId, warpSize);
-      Value warpId = udiv(threadId, warpSize);
-      // TODO: fix the bug in MMAEncodingAttr document
-      SmallVector<Value> multiDimWarpId(2);
-      multiDimWarpId[0] = urem(warpId, i32_val(mfmaLayout.getWarpsPerCTA()[0]));
-      multiDimWarpId[1] = udiv(warpId, i32_val(mfmaLayout.getWarpsPerCTA()[0]));
-      Value _0 = i32_val(0);
-      Value _1 = i32_val(1);
-      Value _4 = i32_val(4);
-      Value _8 = i32_val(8);
-      Value _32 = i32_val(32);
-      multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 32));
-      multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 32));
-      Value halfOffset = select(icmp_uge(laneId, _32), _4, _0);
-      Value mfmaGroup32 = urem(laneId, _32);
-      Value rowWarpOffset = mul(multiDimWarpId[0], _32);
-      for (unsigned block = 0; block < 4; ++block) {
-        mfmaRowIdx[4 * block] = block == 0
-                                    ? add(halfOffset, rowWarpOffset)
-                                    : add(mfmaRowIdx[4 * (block - 1)], _8);
-        for (int r = 1; r < 4; ++r) {
-          mfmaRowIdx[4 * block + r] = add(mfmaRowIdx[4 * block + r - 1], _1);
-        }
-      }
-      Value colWarpOffset = mul(multiDimWarpId[1], _32);
-      mfmaColIdx[0] = add(mfmaGroup32, colWarpOffset);
-
+      auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, type);
+      SmallVector<SmallVector<unsigned>> offsets;
       assert(rank == 2);
       SmallVector<Value> multiDimOffset(rank);
-
-      multiDimOffset[0] = mfmaRowIdx[elemId % 16];
-
-      multiDimOffset[1] = mfmaColIdx[0];
-      multiDimOffset[0] = add(multiDimOffset[0],
-                              i32_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
-      multiDimOffset[1] = add(multiDimOffset[1],
-                              i32_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
+      emitMfmaOffsetForCTA(mfmaLayout, offsets, multiDimCTAInRepId[0], multiDimCTAInRepId[1]);
+      multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0]));
+      multiDimOffset[1] = add(multiDimBase[1], i32_val(offsets[elemId][1]));
      return multiDimOffset;
     }
 #endif
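The deleted block emitted IR that computes which (row, column) of a 32x32 MFMA output tile each lane of a 64-wide wavefront owns; the refactor centralizes that arithmetic in emitMfmaOffsetForCTA. The same per-lane mapping in plain, runnable C++ (warp offsets omitted; a sketch for sanity-checking the refactor, not the emitted IR itself):

    #include <cstdio>

    // Row/column ownership of one lane in a 32x32 MFMA result tile,
    // mirroring the deleted IR-emission code: each lane holds 16 values
    // (4 blocks of 4 consecutive rows), and blocks step by 8 rows.
    void mfma32x32LaneIndices(unsigned laneId, unsigned rowIdx[16],
                              unsigned &colIdx) {
      unsigned halfOffset = (laneId >= 32) ? 4 : 0; // upper half: 4 rows down
      colIdx = laneId % 32;                         // one column per lane
      for (unsigned block = 0; block < 4; ++block) {
        rowIdx[4 * block] =
            (block == 0) ? halfOffset : rowIdx[4 * (block - 1)] + 8;
        for (unsigned r = 1; r < 4; ++r)
          rowIdx[4 * block + r] = rowIdx[4 * block + r - 1] + 1;
      }
    }

    int main() {
      unsigned rows[16], col;
      mfma32x32LaneIndices(0, rows, col); // lane 0 -> rows 0-3,8-11,16-19,24-27
      for (unsigned r : rows)
        std::printf("%u ", r);
      std::printf("| col %u\n", col);
    }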
@@ -676,6 +647,45 @@ private:
     return success();
   }

+#ifdef USE_ROCM
+  LogicalResult
+  lowerMfmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+                        ConversionPatternRewriter &rewriter) const {
+    auto loc = op.getLoc();
+    auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
+    auto dstTy = op.getResult().getType().cast<RankedTensorType>();
+    if (isMfmaToDotShortcut(srcTy, dstTy)) {
+      // get source values
+      auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
+                                                       rewriter, srcTy);
+      unsigned elems = getTotalElemsPerThread(srcTy);
+      Type elemTy =
+          this->getTypeConverter()->convertType(srcTy.getElementType());
+      // for the destination type, we need to pack values together
+      // so they can be consumed by tensor core operations
+      SmallVector<Value> vecVals;
+      SmallVector<Type> types;
+      auto elemSize = elemTy.getIntOrFloatBitWidth();
+      // TODO: Support types other than float16.
+      assert(type::isFloat(elemTy) && elemSize == 16);
+      unsigned vecSize = 4;
+      Type vecTy = vec_ty(elemTy, vecSize);
+      types = SmallVector<Type>(elems / vecSize, vecTy);
+      for (unsigned i = 0; i < elems; i += vecSize) {
+        Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
+        for (unsigned j = 0; j < vecSize; j++)
+          packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
+        vecVals.push_back(packed);
+      }
+      Value view =
+          getTypeConverter()->packLLElements(loc, vecVals, rewriter, dstTy);
+      rewriter.replaceOp(op, view);
+      return success();
+    }
+    return failure();
+  }
+#endif
+
   // mma -> dot_operand
   LogicalResult
   lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
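The inner loop of lowerMfmaToDotOperand regroups the per-thread scalar f16 accumulator values into <4 x f16> vectors so the next dot can consume them directly from registers. The same repacking on plain data (a C++ sketch; uint16_t stands in for f16, and the helper name is invented for illustration):

    #include <array>
    #include <cstdint>
    #include <vector>

    using Half4 = std::array<uint16_t, 4>; // stand-in for LLVM <4 x f16>

    // Pack scalar half values into 4-wide vectors, as lowerMfmaToDotOperand
    // does with insert_element on LLVM vectors: the MFMA accumulator stays
    // in registers and is only regrouped into the dot-operand vector shape.
    std::vector<Half4> packHalves(const std::vector<uint16_t> &vals) {
      const unsigned vecSize = 4;
      std::vector<Half4> vecVals;
      for (std::size_t i = 0; i + vecSize <= vals.size(); i += vecSize) {
        Half4 packed{};                        // cf. the LLVM::UndefOp seed
        for (unsigned j = 0; j < vecSize; ++j) // cf. the insert_element chain
          packed[j] = vals[i + j];
        vecVals.push_back(packed);
      }
      return vecVals;
    }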