mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FA OPTIMIZATION] Keep results of FA dot operations in registers (#247)
* [WIP][FA OPTIMIZATION] Optimize chain dot This commit optimizes chain dot operation by keeping results of the first dot operation in registers. * [FA OPTIMIZATION] Enable lowering pipeline for keeping result of chain dot in registers * Move operand swapping in ttgir -> llir lowering phase * Refactor emitMfmaOffsetForCTA function to be more readable * Fix accidental change in 06-fused-attention.py * Address review comments * Fix rebase errors
This commit is contained in:
@@ -126,6 +126,10 @@ template <typename T> T nextPowOf2(T n) {
|
||||
return n + 1;
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
|
||||
#endif
|
||||
|
||||
/// Multi-root DAG topological sort.
|
||||
/// Performs a topological sort of the Operation in the `toSort` SetVector.
|
||||
/// Returns a topologically sorted SetVector.
|
||||
|
||||
@@ -562,7 +562,8 @@ The data will be distributed between threads as follows:
|
||||
let parameters = (
|
||||
ins
|
||||
"unsigned":$nonKDim,
|
||||
ArrayRefParameter<"unsigned">:$warpsPerCTA
|
||||
ArrayRefParameter<"unsigned">:$warpsPerCTA,
|
||||
"bool":$isTransposed
|
||||
);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
@@ -17,6 +17,7 @@ using ::mlir::triton::gpu::getContigPerThread;
|
||||
using ::mlir::triton::gpu::getOrder;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::getSizePerThread;
|
||||
using ::mlir::triton::gpu::MfmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
using ::mlir::triton::gpu::SliceEncodingAttr;
|
||||
@@ -63,6 +64,14 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
if (isMmaToDotShortcut(srcTy, dstTy))
|
||||
return {};
|
||||
|
||||
#ifdef USE_ROCM
|
||||
if (srcLayout.isa<MfmaEncodingAttr>() &&
|
||||
srcLayout.dyn_cast<MfmaEncodingAttr>().getIsTransposed() &&
|
||||
dstLayout.isa<DotOperandEncodingAttr>())
|
||||
if (isMfmaToDotShortcut(srcTy, dstTy))
|
||||
return {};
|
||||
#endif
|
||||
|
||||
assert(srcLayout && dstLayout &&
|
||||
"Unexpected layout in getScratchConfigForCvtLayout()");
|
||||
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
|
||||
|
||||
@@ -195,6 +195,22 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
!srcTy.getElementType().isF32();
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
auto dstLayout = dstTy.getEncoding();
|
||||
auto mfmaLayout = srcLayout.cast<triton::gpu::MfmaEncodingAttr>();
|
||||
auto dotOperandLayout = dstLayout.cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
// TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
|
||||
// improved. In addition, we can enable this shortcut for regular MFMA
|
||||
// layout when opIdx == 1.
|
||||
return mfmaLayout.getWarpsPerCTA()[1] == 1 &&
|
||||
dotOperandLayout.getOpIdx() == 0 &&
|
||||
dotOperandLayout.getParent() == mfmaLayout &&
|
||||
mfmaLayout.getIsTransposed() && srcTy.getElementType().isF16();
|
||||
}
|
||||
#endif
|
||||
|
||||
bool isSingleValue(Value value) {
|
||||
// Don't consider load as expensive if it is loading a scalar.
|
||||
if (auto tensorTy = value.getType().dyn_cast<RankedTensorType>())
|
||||
|
||||
@@ -87,6 +87,12 @@ public:
|
||||
dstLayout.isa<DotOperandEncodingAttr>()) {
|
||||
return lowerMmaToDotOperand(op, adaptor, rewriter);
|
||||
}
|
||||
#ifdef USE_ROCM
|
||||
if (srcLayout.isa<MfmaEncodingAttr>() &&
|
||||
dstLayout.isa<DotOperandEncodingAttr>()) {
|
||||
return lowerMfmaToDotOperand(op, adaptor, rewriter);
|
||||
}
|
||||
#endif
|
||||
if (srcLayout.isa<SharedEncodingAttr>() &&
|
||||
isaDistributedLayout(dstLayout)) {
|
||||
return lowerSharedToDistributed(op, adaptor, rewriter);
|
||||
@@ -205,48 +211,13 @@ private:
|
||||
}
|
||||
#ifdef USE_ROCM
|
||||
if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
SmallVector<Value> mfmaColIdx(4);
|
||||
SmallVector<Value> mfmaRowIdx(16);
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
unsigned iWaveSize = triton::gpu::getWarpSize(layout);
|
||||
Value warpSize = i32_val(iWaveSize);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
// TODO: fix the bug in MMAEncodingAttr document
|
||||
SmallVector<Value> multiDimWarpId(2);
|
||||
multiDimWarpId[0] = urem(warpId, i32_val(mfmaLayout.getWarpsPerCTA()[0]));
|
||||
multiDimWarpId[1] = udiv(warpId, i32_val(mfmaLayout.getWarpsPerCTA()[0]));
|
||||
Value _0 = i32_val(0);
|
||||
Value _1 = i32_val(1);
|
||||
Value _4 = i32_val(4);
|
||||
Value _8 = i32_val(8);
|
||||
Value _32 = i32_val(32);
|
||||
multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 32));
|
||||
multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 32));
|
||||
Value halfOffset = select(icmp_uge(laneId, _32), _4, _0);
|
||||
Value mfmaGroup32 = urem(laneId, _32);
|
||||
Value rowWarpOffset = mul(multiDimWarpId[0], _32);
|
||||
for (unsigned block = 0; block < 4; ++block) {
|
||||
mfmaRowIdx[4 * block] = block == 0
|
||||
? add(halfOffset, rowWarpOffset)
|
||||
: add(mfmaRowIdx[4 * (block - 1)], _8);
|
||||
for (int r = 1; r < 4; ++r) {
|
||||
mfmaRowIdx[4 * block + r] = add(mfmaRowIdx[4 * block + r - 1], _1);
|
||||
}
|
||||
}
|
||||
Value colWarpOffset = mul(multiDimWarpId[1], _32);
|
||||
mfmaColIdx[0] = add(mfmaGroup32, colWarpOffset);
|
||||
|
||||
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, type);
|
||||
SmallVector<SmallVector<unsigned>> offsets;
|
||||
assert(rank == 2);
|
||||
SmallVector<Value> multiDimOffset(rank);
|
||||
|
||||
multiDimOffset[0] = mfmaRowIdx[elemId % 16];
|
||||
|
||||
multiDimOffset[1] = mfmaColIdx[0];
|
||||
multiDimOffset[0] = add(multiDimOffset[0],
|
||||
i32_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
|
||||
multiDimOffset[1] = add(multiDimOffset[1],
|
||||
i32_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
|
||||
emitMfmaOffsetForCTA(mfmaLayout, offsets, multiDimCTAInRepId[0], multiDimCTAInRepId[1]);
|
||||
multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0]));
|
||||
multiDimOffset[1] = add(multiDimBase[1], i32_val(offsets[elemId][1]));
|
||||
return multiDimOffset;
|
||||
}
|
||||
#endif
|
||||
@@ -676,6 +647,45 @@ private:
|
||||
return success();
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
LogicalResult
|
||||
lowerMfmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
|
||||
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
|
||||
if (isMfmaToDotShortcut(srcTy, dstTy)) {
|
||||
// get source values
|
||||
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
|
||||
rewriter, srcTy);
|
||||
unsigned elems = getTotalElemsPerThread(srcTy);
|
||||
Type elemTy =
|
||||
this->getTypeConverter()->convertType(srcTy.getElementType());
|
||||
// for the destination type, we need to pack values together
|
||||
// so they can be consumed by tensor core operations
|
||||
SmallVector<Value> vecVals;
|
||||
SmallVector<Type> types;
|
||||
auto elemSize = elemTy.getIntOrFloatBitWidth();
|
||||
// TODO: Support types other than float16.
|
||||
assert(type::isFloat(elemTy) && elemSize == 16);
|
||||
unsigned vecSize = 4;
|
||||
Type vecTy = vec_ty(elemTy, vecSize);
|
||||
types = SmallVector<Type>(elems / vecSize, vecTy);
|
||||
for (unsigned i = 0; i < elems; i += vecSize) {
|
||||
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||
for (unsigned j = 0; j < vecSize; j++)
|
||||
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
|
||||
vecVals.push_back(packed);
|
||||
}
|
||||
Value view =
|
||||
getTypeConverter()->packLLElements(loc, vecVals, rewriter, dstTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
#endif
|
||||
|
||||
// mma -> dot_operand
|
||||
LogicalResult
|
||||
lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||
|
||||
@@ -142,7 +142,9 @@ struct DotOpMFMAConversionHelper {
|
||||
}
|
||||
|
||||
for (size_t k = 0; k < numRepK; k++) {
|
||||
acc = generateMFMAOp(mfmaTy, ha[{m, k}], hb[{n, k}], acc);
|
||||
acc = mfmaLayout.getIsTransposed()
|
||||
? generateMFMAOp(mfmaTy, hb[{n, k}], ha[{m, k}], acc)
|
||||
: generateMFMAOp(mfmaTy, ha[{m, k}], hb[{n, k}], acc);
|
||||
}
|
||||
for (unsigned v = 0; v < 16; ++v) {
|
||||
fc[m * numRepN * 16 + n * 16 + v] =
|
||||
|
||||
@@ -130,6 +130,7 @@ private:
|
||||
writeIdx[axis] = udiv(index[axis], axisSizePerThread);
|
||||
}
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
// TODO: Support MFMA transposed layout.
|
||||
if (axis == 0) {
|
||||
// Because warpTileSize = [32, 32] and threadsPerWarp = [2, 32], each 2
|
||||
// rows in smem would correspond to a warp. The mapping
|
||||
@@ -424,8 +425,17 @@ private:
|
||||
// Reduce within warps
|
||||
for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
|
||||
SmallVector<Value> shfl(op.getNumOperands());
|
||||
unsigned shuffleIdx = N;
|
||||
#ifdef USE_ROCM
|
||||
if (inMfma && inMfma.getIsTransposed()) {
|
||||
assert(sizeIntraWarps == 2);
|
||||
// Adjecant threads in y dimension in transposed MFMA layout are 32
|
||||
// apart: [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...].
|
||||
shuffleIdx = 32;
|
||||
}
|
||||
#endif
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
shfl[i] = shflSync(loc, rewriter, acc[i], N);
|
||||
shfl[i] = shflSync(loc, rewriter, acc[i], shuffleIdx);
|
||||
}
|
||||
accumulate(rewriter, *combineOp, acc, shfl, false);
|
||||
}
|
||||
|
||||
@@ -655,6 +655,34 @@ public:
|
||||
llvm_unreachable("unsupported emitOffsetForLayout");
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
void emitMfmaOffsetForCTA(const MfmaEncodingAttr &mfmaLayout,
|
||||
SmallVector<SmallVector<unsigned>> &offsets,
|
||||
unsigned ctaOffsetX, unsigned ctaOffsetY) const {
|
||||
// MFMA output tile consists of repeated "dot operand B" layout groups along
|
||||
// row axis. This variable defines number of these groups.
|
||||
const unsigned numGroups = 4;
|
||||
const unsigned elemsPerThreadPerGroup = 4;
|
||||
auto warpSize = getWarpSize(mfmaLayout);
|
||||
assert(warpSize == 64);
|
||||
auto shapePerCta = getShapePerCTA(mfmaLayout);
|
||||
for (unsigned block = 0; block < numGroups; block++) {
|
||||
unsigned rowOrColOffset = block * elemsPerThreadPerGroup * warpSize / 32;
|
||||
for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) {
|
||||
if (mfmaLayout.getIsTransposed()) {
|
||||
offsets.push_back(
|
||||
{ctaOffsetX * shapePerCta[0],
|
||||
ctaOffsetY * shapePerCta[1] + elem + rowOrColOffset});
|
||||
} else {
|
||||
offsets.push_back(
|
||||
{ctaOffsetX * shapePerCta[0] + elem + rowOrColOffset,
|
||||
ctaOffsetY * shapePerCta[1]});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Emit indices
|
||||
// -----------------------------------------------------------------------
|
||||
@@ -979,16 +1007,16 @@ private:
|
||||
|
||||
SmallVector<Value>
|
||||
emitBaseIndexForMfmaLayout(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const MfmaEncodingAttr &mmaLayout,
|
||||
const MfmaEncodingAttr &mfmaLayout,
|
||||
RankedTensorType type) const {
|
||||
auto shape = type.getShape();
|
||||
auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
auto _warpsPerCTA = mfmaLayout.getWarpsPerCTA();
|
||||
assert(_warpsPerCTA.size() == 2);
|
||||
SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
|
||||
i32_val(_warpsPerCTA[1])};
|
||||
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = i32_val(triton::gpu::getWarpSize(mmaLayout));
|
||||
Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout));
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
@@ -1000,29 +1028,35 @@ private:
|
||||
Value offWarp1 = mul(warpId1, i32_val(32));
|
||||
|
||||
SmallVector<Value> multiDimBase(2);
|
||||
multiDimBase[0] = add(mul(i32_val(4), udiv(laneId, i32_val(32))), offWarp0);
|
||||
multiDimBase[1] = add(urem(laneId, i32_val(32)), offWarp1);
|
||||
if (mfmaLayout.getIsTransposed()) {
|
||||
multiDimBase[1] =
|
||||
add(mul(i32_val(4), udiv(laneId, i32_val(32))), offWarp1);
|
||||
multiDimBase[0] = add(urem(laneId, i32_val(32)), offWarp0);
|
||||
} else {
|
||||
multiDimBase[0] =
|
||||
add(mul(i32_val(4), udiv(laneId, i32_val(32))), offWarp0);
|
||||
multiDimBase[1] = add(urem(laneId, i32_val(32)), offWarp1);
|
||||
}
|
||||
return multiDimBase;
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>>
|
||||
emitOffsetForMfmaLayout(const MfmaEncodingAttr &mmaLayout,
|
||||
emitOffsetForMfmaLayout(const MfmaEncodingAttr &mfmaLayout,
|
||||
RankedTensorType type) const {
|
||||
|
||||
auto tensorShape = type.getShape();
|
||||
SmallVector<SmallVector<unsigned>> offsets;
|
||||
auto shapePerCta = getShapePerCTA(mmaLayout);
|
||||
const unsigned iterationCount = 4;
|
||||
auto shapePerCta = getShapePerCTA(mfmaLayout);
|
||||
|
||||
for (unsigned i = 0; i < tensorShape[0]; i += shapePerCta[0]) {
|
||||
for (unsigned j = 0; j < tensorShape[1]; j += shapePerCta[1]) {
|
||||
unsigned rowOffset = 0;
|
||||
for (unsigned k = 0; k < iterationCount; k++) {
|
||||
for (unsigned l = 0; l < iterationCount; l++) {
|
||||
offsets.push_back({i + l + rowOffset, j});
|
||||
}
|
||||
rowOffset += iterationCount * 2;
|
||||
}
|
||||
SmallVector<unsigned> numCTAPerDim(2);
|
||||
for (unsigned d = 0; d < 2; ++d) {
|
||||
unsigned inPerCTA = std::min<unsigned>(tensorShape[d], shapePerCta[d]);
|
||||
numCTAPerDim[d] = ceil<unsigned>(tensorShape[d], inPerCTA);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < numCTAPerDim[0]; ++i) {
|
||||
for (unsigned j = 0; j < numCTAPerDim[1]; ++j) {
|
||||
emitMfmaOffsetForCTA(mfmaLayout, offsets, i, j);
|
||||
}
|
||||
}
|
||||
return offsets;
|
||||
|
||||
@@ -310,6 +310,9 @@ public:
|
||||
|
||||
// Preprocess
|
||||
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp);
|
||||
#ifdef USE_ROCM
|
||||
decomposeMfmaToDotOperand(mod, numWarps, threadsPerWarp);
|
||||
#endif
|
||||
decomposeBlockedToDotOperand(mod);
|
||||
decomposeInsertSliceAsyncOp(mod);
|
||||
|
||||
@@ -464,6 +467,36 @@ private:
|
||||
});
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
void decomposeMfmaToDotOperand(ModuleOp mod, int numWarps,
|
||||
int threadsPerWarp) const {
|
||||
// Replace `mfma -> dot_op` with `mfma -> blocked -> dot_op`
|
||||
// unless certain conditions are met
|
||||
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
|
||||
OpBuilder builder(cvtOp);
|
||||
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvtOp.getType().cast<RankedTensorType>();
|
||||
auto srcMfma =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
|
||||
auto dstDotOp =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if (srcMfma && dstDotOp && !isMfmaToDotShortcut(srcType, dstType)) {
|
||||
auto tmpType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu::BlockedEncodingAttr::get(
|
||||
mod.getContext(), srcType.getShape(), getSizePerThread(srcMfma),
|
||||
getOrder(srcMfma), numWarps, threadsPerWarp));
|
||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), dstType, tmp);
|
||||
cvtOp.replaceAllUsesWith(newConvert.getResult());
|
||||
cvtOp.erase();
|
||||
}
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
void decomposeBlockedToDotOperand(ModuleOp mod) const {
|
||||
// Replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
|
||||
// because the codegen doesn't handle `blocked -> dot_op` directly
|
||||
|
||||
@@ -87,7 +87,11 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
|
||||
return {8, 4};
|
||||
}
|
||||
if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
return {2, 32};
|
||||
if (mfmaLayout.getIsTransposed()) {
|
||||
return {32, 2};
|
||||
} else {
|
||||
return {2, 32};
|
||||
}
|
||||
}
|
||||
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
auto parent = sliceLayout.getParent();
|
||||
@@ -200,7 +204,11 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
||||
llvm_unreachable("Unexpected mma version");
|
||||
}
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
return {16, 1};
|
||||
if (mfmaLayout.getIsTransposed()) {
|
||||
return {1, 16};
|
||||
} else {
|
||||
return {16, 1};
|
||||
}
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
auto parentLayout = dotLayout.getParent();
|
||||
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
|
||||
@@ -457,6 +465,17 @@ static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr,
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr,
|
||||
bool &value, StringRef desc) {
|
||||
auto boolAttr = attr.dyn_cast<BoolAttr>();
|
||||
if (!boolAttr) {
|
||||
parser.emitError(parser.getNameLoc(), "expected bool type in ") << desc;
|
||||
return failure();
|
||||
}
|
||||
value = boolAttr.getValue();
|
||||
return success();
|
||||
}
|
||||
|
||||
// parse an array of integers
|
||||
static LogicalResult parseIntArrayAttr(AsmParser &parser,
|
||||
const NamedAttribute &attr,
|
||||
@@ -481,6 +500,11 @@ static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
|
||||
return parseIntAttrValue(parser, attr.getValue(), value, desc);
|
||||
};
|
||||
|
||||
static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr,
|
||||
bool &value, StringRef desc) {
|
||||
return parseBoolAttrValue(parser, attr.getValue(), value, desc);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Attribute methods
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -551,10 +575,17 @@ MfmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
|
||||
assert(rank == 2 && "Unexpected rank of mma layout");
|
||||
|
||||
SmallVector<unsigned> elemsPerThread(rank);
|
||||
unsigned elemsCol = ceil<unsigned>(shape[1], 32 * getWarpsPerCTA()[1]);
|
||||
unsigned elemsRow = ceil<unsigned>(shape[0], 32 * getWarpsPerCTA()[0]) * 16;
|
||||
elemsPerThread[0] = elemsRow;
|
||||
elemsPerThread[1] = elemsCol;
|
||||
if (getIsTransposed()) {
|
||||
unsigned elemsCol = ceil<unsigned>(shape[1], 32 * getWarpsPerCTA()[1]) * 16;
|
||||
unsigned elemsRow = ceil<unsigned>(shape[0], 32 * getWarpsPerCTA()[0]);
|
||||
elemsPerThread[0] = elemsRow;
|
||||
elemsPerThread[1] = elemsCol;
|
||||
} else {
|
||||
unsigned elemsCol = ceil<unsigned>(shape[1], 32 * getWarpsPerCTA()[1]);
|
||||
unsigned elemsRow = ceil<unsigned>(shape[0], 32 * getWarpsPerCTA()[0]) * 16;
|
||||
elemsPerThread[0] = elemsRow;
|
||||
elemsPerThread[1] = elemsCol;
|
||||
}
|
||||
return elemsPerThread;
|
||||
}
|
||||
|
||||
@@ -908,6 +939,7 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
|
||||
unsigned nonKDim = 0;
|
||||
SmallVector<unsigned, 2> warpsPerCTA;
|
||||
bool isTransposed;
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "nonKDim") {
|
||||
@@ -917,18 +949,21 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (attr.getName() == "warpsPerCTA") {
|
||||
if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "isTransposed") {
|
||||
if (parseBool(parser, attr, isTransposed, "isTransposed").failed())
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<MfmaEncodingAttr>(parser.getContext(), nonKDim,
|
||||
warpsPerCTA);
|
||||
warpsPerCTA, isTransposed);
|
||||
}
|
||||
|
||||
void MfmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "nonKDim = " << getNonKDim() << ", "
|
||||
<< "warpsPerCTA = [" << getWarpsPerCTA() << "]"
|
||||
<< "}>";
|
||||
<< ", isTransposed = " << getIsTransposed() << "}>";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -126,6 +126,18 @@ public:
|
||||
BlockedToMFMA(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {}
|
||||
|
||||
bool isChainDot(triton::DotOp &dotOp) const {
|
||||
auto filter = [&dotOp](Operation *op) {
|
||||
return op->getParentRegion() == dotOp->getParentRegion();
|
||||
};
|
||||
auto slices = mlir::getSlice(dotOp, filter);
|
||||
for (Operation *op : slices) {
|
||||
if (isa<triton::DotOp>(op) && (op != dotOp))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
@@ -157,8 +169,9 @@ public:
|
||||
|
||||
auto warpsPerTile = warpsPerTileMI200(dotOp, retShape, numWarps);
|
||||
|
||||
mfmaEnc = triton::gpu::MfmaEncodingAttr::get(oldRetType.getContext(),
|
||||
nonKDim, warpsPerTile);
|
||||
bool isTransposed = isChainDot(dotOp);
|
||||
mfmaEnc = triton::gpu::MfmaEncodingAttr::get(
|
||||
oldRetType.getContext(), nonKDim, warpsPerTile, isTransposed);
|
||||
|
||||
auto newRetType =
|
||||
RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc);
|
||||
@@ -194,7 +207,6 @@ public:
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
class BlockedToMMA : public mlir::RewritePattern {
|
||||
|
||||
@@ -48,6 +48,16 @@ public:
|
||||
dstDotOp.getParent() == srcMmaEncoding))
|
||||
return;
|
||||
}
|
||||
#ifdef USE_ROCM
|
||||
if (auto srcMfmaEncoding =
|
||||
srcEncoding.dyn_cast<triton::gpu::MfmaEncodingAttr>()) {
|
||||
|
||||
if (srcMfmaEncoding.getWarpsPerCTA()[1] == 1 &&
|
||||
srcMfmaEncoding.getIsTransposed() &&
|
||||
dstDotOp.getParent() == srcMfmaEncoding)
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
auto tmpType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu::SharedEncodingAttr::get(
|
||||
|
||||
@@ -1202,7 +1202,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
|
||||
[(*shape, 2, False, False, epilogue, allow_tf32, dtype)
|
||||
for shape in [(64, 64, 64), (32, 32, 32)]
|
||||
for epilogue in ['none', 'trans', 'add-matrix']
|
||||
for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax']
|
||||
for allow_tf32 in [True, False]
|
||||
for dtype in ['float16', 'float32']
|
||||
if not (allow_tf32 and (dtype in ['float16']))] +
|
||||
@@ -2128,12 +2128,13 @@ class MmaLayout:
|
||||
|
||||
|
||||
class MfmaLayout:
|
||||
def __init__(self, non_k_dim, warps_per_cta):
|
||||
def __init__(self, non_k_dim, warps_per_cta, isTransposed):
|
||||
self.non_k_dim = str(non_k_dim)
|
||||
self.warps_per_cta = str(warps_per_cta)
|
||||
self.isTransposed = str(isTransposed).lower()
|
||||
|
||||
def __str__(self):
|
||||
return f"#triton_gpu.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}}}>"
|
||||
return f"#triton_gpu.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.isTransposed}}}>"
|
||||
|
||||
|
||||
class BlockedLayout:
|
||||
@@ -2238,8 +2239,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
|
||||
|
||||
if _get_warp_size() == 64:
|
||||
layouts = [
|
||||
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1]),
|
||||
MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2]),
|
||||
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], isTransposed=True),
|
||||
MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], isTransposed=False),
|
||||
]
|
||||
shapes = [[128, 32], [128, 128], [32, 128], [64, 64]]
|
||||
else:
|
||||
@@ -2255,6 +2256,9 @@ else:
|
||||
@pytest.mark.parametrize("src_layout", layouts)
|
||||
@pytest.mark.parametrize("axis", [0, 1])
|
||||
def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
|
||||
if torch.version.hip is not None:
|
||||
if src_layout.isTransposed and axis == 0:
|
||||
pytest.skip("Reduce along axis 0 is not supported in transposed mfma layout")
|
||||
rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1"
|
||||
rdims_1d = f"{N}" if axis == 0 else f"{M}"
|
||||
store_range = "%7" if axis == 0 else "%1"
|
||||
@@ -2318,7 +2322,7 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
|
||||
|
||||
@pytest.mark.parametrize("shape", [(64, 64)])
|
||||
@pytest.mark.parametrize("dtype", ['float16'])
|
||||
@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1]), MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1])])
|
||||
@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], isTransposed=False), MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], isTransposed = True)])
|
||||
@pytest.mark.parametrize("dst_layout", [BlockedLayout([1, 4], [4, 16], [1, 1], [1, 0])])
|
||||
def test_make_range(dtype, shape, src_layout, dst_layout, device='cuda'):
|
||||
ir = f"""
|
||||
|
||||
@@ -12,6 +12,14 @@ import triton.ops
|
||||
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, dtype):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if torch.version.hip is not None:
|
||||
if dtype != torch.float16:
|
||||
pytest.skip("Currently flash attention on AMD gpu is only supported in fp16.")
|
||||
if D_HEAD < 32:
|
||||
pytest.skip("D_HEAD < 32 is not supported. It will be enabled once smaller tile size is supported in MFMA pipeline.")
|
||||
if D_HEAD > 64:
|
||||
pytest.skip("D_HEAD > 64 is not supported. Currently it causes shared memory out of resource error.")
|
||||
|
||||
if capability[0] < 8:
|
||||
pytest.skip("Flash attention only supported for compute capability < 80")
|
||||
torch.manual_seed(20)
|
||||
@@ -29,21 +37,26 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype):
|
||||
p = torch.softmax(p.float(), dim=-1).to(dtype)
|
||||
# p = torch.exp(p)
|
||||
ref_out = torch.matmul(p, v)
|
||||
ref_out.backward(dout)
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
|
||||
if torch.version.hip is None:
|
||||
ref_out.backward(dout)
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
|
||||
# # triton implementation
|
||||
tri_out = triton.ops.attention(q, k, v, sm_scale)
|
||||
# print(ref_out)
|
||||
# print(tri_out)
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
if torch.version.hip is None:
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
atol = 1e-1 if dtype == torch.bfloat16 else 1e-2
|
||||
torch.testing.assert_allclose(ref_out, tri_out, atol=atol, rtol=0)
|
||||
torch.testing.assert_allclose(ref_dv, tri_dv, atol=atol, rtol=0)
|
||||
torch.testing.assert_allclose(ref_dk, tri_dk, atol=atol, rtol=0)
|
||||
torch.testing.assert_allclose(ref_dq, tri_dq, atol=atol, rtol=0)
|
||||
if torch.version.hip is None:
|
||||
torch.testing.assert_allclose(ref_dv, tri_dv, atol=atol, rtol=0)
|
||||
torch.testing.assert_allclose(ref_dk, tri_dk, atol=atol, rtol=0)
|
||||
torch.testing.assert_allclose(ref_dq, tri_dq, atol=atol, rtol=0)
|
||||
|
||||
@@ -202,7 +202,10 @@ class _attention(torch.autograd.Function):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
raise RuntimeError("Flash attention currently only supported for compute capability >= 80")
|
||||
BLOCK = 128
|
||||
if torch.version.hip is not None:
|
||||
BLOCK = 64
|
||||
else:
|
||||
BLOCK = 128
|
||||
# shape constraints
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
|
||||
@@ -292,18 +292,21 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
p = torch.softmax(p.float(), dim=-1).half()
|
||||
# p = torch.exp(p)
|
||||
ref_out = torch.matmul(p, v)
|
||||
ref_out.backward(dout)
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
|
||||
if torch.version.hip is None:
|
||||
ref_out.backward(dout)
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
# # triton implementation
|
||||
tri_out = attention(q, k, v, sm_scale)
|
||||
# print(ref_out)
|
||||
# print(tri_out)
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
if torch.version.hip is None:
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
|
||||
if torch.version.hip is None:
|
||||
|
||||
@@ -1186,7 +1186,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0]}>
|
||||
#mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [1, 1]}>
|
||||
#mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA=[1,1], isTranspose=false}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma0}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma0}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
@@ -1209,7 +1209,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// -----
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2]}>
|
||||
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTranspose=false}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
|
||||
// CHECK-LABEL: convert_layout_mfma_block
|
||||
@@ -1359,7 +1359,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
|
||||
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2]}>
|
||||
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
Reference in New Issue
Block a user