[FA OPTIMIZATION] Keep results of FA dot operations in registers (#247)

* [WIP][FA OPTIMIZATION] Optimize chain dot

This commit optimizes the chain dot operation by keeping the
results of the first dot operation in registers, so the second
dot can consume them without a round trip through shared memory
(see the scalar sketch after this list).

* [FA OPTIMIZATION] Enable the lowering pipeline that keeps the result of the chain dot in registers

* Move operand swapping to the ttgir -> llir lowering phase

* Refactor the emitMfmaOffsetForCTA function to be more readable

* Fix accidental change in 06-fused-attention.py

* Address review comments

* Fix rebase errors
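
For context, a minimal scalar sketch of the "chain dot" pattern this
change targets: in flash attention the result of the first dot
(p = q @ k^T) immediately feeds a second dot (out = p @ v). On the GPU,
keeping p in registers between the two MFMA-lowered dots avoids the
shared-memory round trip. Names and shapes below are illustrative only,
not Triton's API.

#include <vector>

// p = q @ k^T, then out = p @ v; in the optimized kernel, p stays in
// registers between the two dots instead of being written out.
void chainDot(const std::vector<float> &q, // M x K
              const std::vector<float> &k, // N x K
              const std::vector<float> &v, // N x D
              std::vector<float> &out,     // M x D
              int M, int N, int K, int D) {
  std::vector<float> p(M * N, 0.0f);
  out.assign(M * D, 0.0f);
  for (int m = 0; m < M; ++m)   // first dot: p = q @ k^T
    for (int n = 0; n < N; ++n)
      for (int kk = 0; kk < K; ++kk)
        p[m * N + n] += q[m * K + kk] * k[n * K + kk];
  for (int m = 0; m < M; ++m)   // second dot: out = p @ v
    for (int d = 0; d < D; ++d)
      for (int n = 0; n < N; ++n)
        out[m * D + d] += p[m * N + n] * v[n * D + d];
}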
Author: oplavsic
Date: 2023-07-12 22:25:55 +02:00
Committed by: GitHub
Parent: 4d0deef45f
Commit: d6e51fd221
17 changed files with 299 additions and 100 deletions

@@ -87,6 +87,12 @@ public:
         dstLayout.isa<DotOperandEncodingAttr>()) {
       return lowerMmaToDotOperand(op, adaptor, rewriter);
     }
+#ifdef USE_ROCM
+    if (srcLayout.isa<MfmaEncodingAttr>() &&
+        dstLayout.isa<DotOperandEncodingAttr>()) {
+      return lowerMfmaToDotOperand(op, adaptor, rewriter);
+    }
+#endif
     if (srcLayout.isa<SharedEncodingAttr>() &&
         isaDistributedLayout(dstLayout)) {
       return lowerSharedToDistributed(op, adaptor, rewriter);
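
The new branch above only fires when the conversion qualifies as a
register-to-register shortcut. A hedged sketch of what a predicate like
isMfmaToDotShortcut plausibly checks, inferred from the assertions
visible later in this diff (the actual definition lives elsewhere in
the tree and may differ):

// Sketch only, not the real helper: require an MFMA-encoded source, a
// dot-operand destination whose parent is that same MFMA layout, and an
// f16 element type (matching the assert in lowerMfmaToDotOperand below).
bool isMfmaToDotShortcutSketch(RankedTensorType srcTy,
                               RankedTensorType dstTy) {
  auto mfmaLayout = srcTy.getEncoding().dyn_cast<MfmaEncodingAttr>();
  auto dotOpLayout = dstTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
  if (!mfmaLayout || !dotOpLayout)
    return false;
  return dotOpLayout.getParent() == mfmaLayout &&
         srcTy.getElementType().isF16();
}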
@@ -205,48 +211,13 @@ private:
     }
 #ifdef USE_ROCM
     if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
-      SmallVector<Value> mfmaColIdx(4);
-      SmallVector<Value> mfmaRowIdx(16);
-      Value threadId = getThreadId(rewriter, loc);
-      unsigned iWaveSize = triton::gpu::getWarpSize(layout);
-      Value warpSize = i32_val(iWaveSize);
-      Value laneId = urem(threadId, warpSize);
-      Value warpId = udiv(threadId, warpSize);
-      // TODO: fix the bug in MMAEncodingAttr document
-      SmallVector<Value> multiDimWarpId(2);
-      multiDimWarpId[0] = urem(warpId, i32_val(mfmaLayout.getWarpsPerCTA()[0]));
-      multiDimWarpId[1] = udiv(warpId, i32_val(mfmaLayout.getWarpsPerCTA()[0]));
-      Value _0 = i32_val(0);
-      Value _1 = i32_val(1);
-      Value _4 = i32_val(4);
-      Value _8 = i32_val(8);
-      Value _32 = i32_val(32);
-      multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 32));
-      multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 32));
-      Value halfOffset = select(icmp_uge(laneId, _32), _4, _0);
-      Value mfmaGroup32 = urem(laneId, _32);
-      Value rowWarpOffset = mul(multiDimWarpId[0], _32);
-      for (unsigned block = 0; block < 4; ++block) {
-        mfmaRowIdx[4 * block] = block == 0
-                                    ? add(halfOffset, rowWarpOffset)
-                                    : add(mfmaRowIdx[4 * (block - 1)], _8);
-        for (int r = 1; r < 4; ++r) {
-          mfmaRowIdx[4 * block + r] = add(mfmaRowIdx[4 * block + r - 1], _1);
-        }
-      }
-      Value colWarpOffset = mul(multiDimWarpId[1], _32);
-      mfmaColIdx[0] = add(mfmaGroup32, colWarpOffset);
+      auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, type);
+      SmallVector<SmallVector<unsigned>> offsets;
       assert(rank == 2);
       SmallVector<Value> multiDimOffset(rank);
-      multiDimOffset[0] = mfmaRowIdx[elemId % 16];
-      multiDimOffset[1] = mfmaColIdx[0];
-      multiDimOffset[0] = add(multiDimOffset[0],
-                              i32_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
-      multiDimOffset[1] = add(multiDimOffset[1],
-                              i32_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
+      emitMfmaOffsetForCTA(mfmaLayout, offsets, multiDimCTAInRepId[0], multiDimCTAInRepId[1]);
+      multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0]));
+      multiDimOffset[1] = add(multiDimBase[1], i32_val(offsets[elemId][1]));
       return multiDimOffset;
     }
 #endif
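
The deleted block computed per-element row/column indices as runtime
LLVM values; the replacement splits that work into a lane-dependent
base (emitBaseIndexForLayout) plus compile-time per-element offsets
from emitMfmaOffsetForCTA. A minimal sketch of what such a helper might
emit, assuming the 32x32 tile geometry implied by the deleted code (per
lane: 4 groups of 4 consecutive rows, groups 8 rows apart, a single
column); the real helper is defined elsewhere in this commit:

// Sketch only. ctaOffsetX/ctaOffsetY correspond to multiDimCTAInRepId[0]
// and [1]; 32x32 is the assumed MFMA tile shape per CTA repetition.
void emitMfmaOffsetForCTASketch(SmallVector<SmallVector<unsigned>> &offsets,
                                unsigned ctaOffsetX, unsigned ctaOffsetY) {
  constexpr unsigned shapePerCTA[2] = {32, 32};
  for (unsigned block = 0; block < 4; ++block) // 4 row groups per lane
    for (unsigned r = 0; r < 4; ++r)           // 4 consecutive rows each
      offsets.push_back({ctaOffsetX * shapePerCTA[0] + block * 8 + r,
                         ctaOffsetY * shapePerCTA[1]});
}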
@@ -676,6 +647,45 @@ private:
     return success();
   }
 
+#ifdef USE_ROCM
+  LogicalResult
+  lowerMfmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+                        ConversionPatternRewriter &rewriter) const {
+    auto loc = op.getLoc();
+    auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
+    auto dstTy = op.getResult().getType().cast<RankedTensorType>();
+    if (isMfmaToDotShortcut(srcTy, dstTy)) {
+      // get source values
+      auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
+                                                       rewriter, srcTy);
+      unsigned elems = getTotalElemsPerThread(srcTy);
+      Type elemTy =
+          this->getTypeConverter()->convertType(srcTy.getElementType());
+      // for the destination type, we need to pack values together
+      // so they can be consumed by tensor core operations
+      SmallVector<Value> vecVals;
+      SmallVector<Type> types;
+      auto elemSize = elemTy.getIntOrFloatBitWidth();
+      // TODO: Support types other than float16.
+      assert(type::isFloat(elemTy) && elemSize == 16);
+      unsigned vecSize = 4;
+      Type vecTy = vec_ty(elemTy, vecSize);
+      types = SmallVector<Type>(elems / vecSize, vecTy);
+      for (unsigned i = 0; i < elems; i += vecSize) {
+        Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
+        for (unsigned j = 0; j < vecSize; j++)
+          packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
+        vecVals.push_back(packed);
+      }
+      Value view =
+          getTypeConverter()->packLLElements(loc, vecVals, rewriter, dstTy);
+      rewriter.replaceOp(op, view);
+      return success();
+    }
+    return failure();
+  }
+#endif
+
   // mma -> dot_operand
   LogicalResult
   lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
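
For reference, the packing loop in lowerMfmaToDotOperand above regroups
four f16 values per thread into one 4-wide vector, the operand
granularity AMD's MFMA instructions consume; this is what lets the
first dot's output feed the second dot directly from registers. A
standalone sketch of the same regrouping in plain C++ (the names and
the uint16_t stand-in for f16 are illustrative):

#include <array>
#include <cstdint>
#include <vector>

using F16Bits = std::uint16_t;          // stand-in for an f16 bit pattern
using Vec4F16 = std::array<F16Bits, 4>; // stand-in for vec_ty(elemTy, 4)

// Mirrors the insert_element loop above: every 4 consecutive per-thread
// scalars become one packed vector value.
std::vector<Vec4F16> packForMfma(const std::vector<F16Bits> &vals) {
  constexpr unsigned vecSize = 4;
  std::vector<Vec4F16> vecVals;
  for (std::size_t i = 0; i + vecSize <= vals.size(); i += vecSize) {
    Vec4F16 packed{};
    for (unsigned j = 0; j < vecSize; ++j)
      packed[j] = vals[i + j];
    vecVals.push_back(packed);
  }
  return vecVals;
}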