[FA OPTIMIZATION] Keep results of FA dot operations in registers (#247)

* [WIP][FA OPTIMIZATION] Optimize chain dot

This commit optimizes chain dot operation by keeping
results of the first dot operation in registers.

* [FA OPTIMIZATION] Enable lowering pipeline for keeping result of chain dot in registers

* Move operand swapping in ttgir -> llir lowering phase

* Refactor emitMfmaOffsetForCTA function to be more readable

* Fix accidental change in 06-fused-attention.py

* Address review comments

* Fix rebase errors
This commit is contained in:
oplavsic
2023-07-12 22:25:55 +02:00
committed by GitHub
parent 4d0deef45f
commit d6e51fd221
17 changed files with 299 additions and 100 deletions

View File

@@ -126,6 +126,10 @@ template <typename T> T nextPowOf2(T n) {
return n + 1;
}
#ifdef USE_ROCM
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
#endif
/// Multi-root DAG topological sort.
/// Performs a topological sort of the Operation in the `toSort` SetVector.
/// Returns a topologically sorted SetVector.

View File

@@ -562,7 +562,8 @@ The data will be distributed between threads as follows:
let parameters = (
ins
"unsigned":$nonKDim,
ArrayRefParameter<"unsigned">:$warpsPerCTA
ArrayRefParameter<"unsigned">:$warpsPerCTA,
"bool":$isTransposed
);
let hasCustomAssemblyFormat = 1;

View File

@@ -17,6 +17,7 @@ using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::MfmaEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -63,6 +64,14 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
if (isMmaToDotShortcut(srcTy, dstTy))
return {};
#ifdef USE_ROCM
if (srcLayout.isa<MfmaEncodingAttr>() &&
srcLayout.dyn_cast<MfmaEncodingAttr>().getIsTransposed() &&
dstLayout.isa<DotOperandEncodingAttr>())
if (isMfmaToDotShortcut(srcTy, dstTy))
return {};
#endif
assert(srcLayout && dstLayout &&
"Unexpected layout in getScratchConfigForCvtLayout()");
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);

View File

@@ -195,6 +195,22 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
!srcTy.getElementType().isF32();
}
#ifdef USE_ROCM
// Returns true when an `mfma -> dot_operand` layout conversion can bypass the
// shared-memory round trip, i.e. the MFMA accumulator registers of a previous
// dot can be consumed directly as operand A of the next dot (chain-dot case).
// NOTE(review): both cast<>s assume the caller already verified that srcTy
// carries an MfmaEncodingAttr and dstTy a DotOperandEncodingAttr; cast<>
// asserts on a mismatch.
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
auto srcLayout = srcTy.getEncoding();
auto dstLayout = dstTy.getEncoding();
auto mfmaLayout = srcLayout.cast<triton::gpu::MfmaEncodingAttr>();
auto dotOperandLayout = dstLayout.cast<triton::gpu::DotOperandEncodingAttr>();
// TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
// improved. In addition, we can enable this shortcut for regular MFMA
// layout when opIdx == 1.
// Shortcut requires: one warp along N, operand A (opIdx 0) whose parent is
// exactly this MFMA layout, a transposed MFMA result, and f16 elements.
return mfmaLayout.getWarpsPerCTA()[1] == 1 &&
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getParent() == mfmaLayout &&
mfmaLayout.getIsTransposed() && srcTy.getElementType().isF16();
}
#endif
bool isSingleValue(Value value) {
// Don't consider load as expensive if it is loading a scalar.
if (auto tensorTy = value.getType().dyn_cast<RankedTensorType>())

View File

@@ -87,6 +87,12 @@ public:
dstLayout.isa<DotOperandEncodingAttr>()) {
return lowerMmaToDotOperand(op, adaptor, rewriter);
}
#ifdef USE_ROCM
if (srcLayout.isa<MfmaEncodingAttr>() &&
dstLayout.isa<DotOperandEncodingAttr>()) {
return lowerMfmaToDotOperand(op, adaptor, rewriter);
}
#endif
if (srcLayout.isa<SharedEncodingAttr>() &&
isaDistributedLayout(dstLayout)) {
return lowerSharedToDistributed(op, adaptor, rewriter);
@@ -205,48 +211,13 @@ private:
}
#ifdef USE_ROCM
if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
SmallVector<Value> mfmaColIdx(4);
SmallVector<Value> mfmaRowIdx(16);
Value threadId = getThreadId(rewriter, loc);
unsigned iWaveSize = triton::gpu::getWarpSize(layout);
Value warpSize = i32_val(iWaveSize);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
// TODO: fix the bug in MMAEncodingAttr document
SmallVector<Value> multiDimWarpId(2);
multiDimWarpId[0] = urem(warpId, i32_val(mfmaLayout.getWarpsPerCTA()[0]));
multiDimWarpId[1] = udiv(warpId, i32_val(mfmaLayout.getWarpsPerCTA()[0]));
Value _0 = i32_val(0);
Value _1 = i32_val(1);
Value _4 = i32_val(4);
Value _8 = i32_val(8);
Value _32 = i32_val(32);
multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 32));
multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 32));
Value halfOffset = select(icmp_uge(laneId, _32), _4, _0);
Value mfmaGroup32 = urem(laneId, _32);
Value rowWarpOffset = mul(multiDimWarpId[0], _32);
for (unsigned block = 0; block < 4; ++block) {
mfmaRowIdx[4 * block] = block == 0
? add(halfOffset, rowWarpOffset)
: add(mfmaRowIdx[4 * (block - 1)], _8);
for (int r = 1; r < 4; ++r) {
mfmaRowIdx[4 * block + r] = add(mfmaRowIdx[4 * block + r - 1], _1);
}
}
Value colWarpOffset = mul(multiDimWarpId[1], _32);
mfmaColIdx[0] = add(mfmaGroup32, colWarpOffset);
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, type);
SmallVector<SmallVector<unsigned>> offsets;
assert(rank == 2);
SmallVector<Value> multiDimOffset(rank);
multiDimOffset[0] = mfmaRowIdx[elemId % 16];
multiDimOffset[1] = mfmaColIdx[0];
multiDimOffset[0] = add(multiDimOffset[0],
i32_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(multiDimOffset[1],
i32_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
emitMfmaOffsetForCTA(mfmaLayout, offsets, multiDimCTAInRepId[0], multiDimCTAInRepId[1]);
multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0]));
multiDimOffset[1] = add(multiDimBase[1], i32_val(offsets[elemId][1]));
return multiDimOffset;
}
#endif
@@ -676,6 +647,45 @@ private:
return success();
}
#ifdef USE_ROCM
// Lower an `mfma -> dot_operand` convert_layout without a shared-memory
// round trip: repack each thread's accumulator values into 4-wide vectors so
// they can be fed straight into the next MFMA as a dot operand.
// Succeeds only for conversions accepted by isMfmaToDotShortcut; otherwise
// returns failure() so a generic lowering can take over.
LogicalResult
lowerMfmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
if (isMfmaToDotShortcut(srcTy, dstTy)) {
// get source values
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
rewriter, srcTy);
unsigned elems = getTotalElemsPerThread(srcTy);
Type elemTy =
this->getTypeConverter()->convertType(srcTy.getElementType());
// for the destination type, we need to pack values together
// so they can be consumed by tensor core operations
SmallVector<Value> vecVals;
SmallVector<Type> types;
auto elemSize = elemTy.getIntOrFloatBitWidth();
// TODO: Support types other than float16.
assert(type::isFloat(elemTy) && elemSize == 16);
unsigned vecSize = 4;
Type vecTy = vec_ty(elemTy, vecSize);
// NOTE(review): `types` is built here but not used afterwards in this
// function — verify it is needed, otherwise drop it.
types = SmallVector<Type>(elems / vecSize, vecTy);
// Assumes elems is a multiple of vecSize (4); vals[i + j] would read past
// the unpacked values otherwise — TODO confirm this invariant upstream.
for (unsigned i = 0; i < elems; i += vecSize) {
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned j = 0; j < vecSize; j++)
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
vecVals.push_back(packed);
}
Value view =
getTypeConverter()->packLLElements(loc, vecVals, rewriter, dstTy);
rewriter.replaceOp(op, view);
return success();
}
return failure();
}
#endif
// mma -> dot_operand
LogicalResult
lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,

View File

@@ -142,7 +142,9 @@ struct DotOpMFMAConversionHelper {
}
for (size_t k = 0; k < numRepK; k++) {
acc = generateMFMAOp(mfmaTy, ha[{m, k}], hb[{n, k}], acc);
acc = mfmaLayout.getIsTransposed()
? generateMFMAOp(mfmaTy, hb[{n, k}], ha[{m, k}], acc)
: generateMFMAOp(mfmaTy, ha[{m, k}], hb[{n, k}], acc);
}
for (unsigned v = 0; v < 16; ++v) {
fc[m * numRepN * 16 + n * 16 + v] =

View File

@@ -130,6 +130,7 @@ private:
writeIdx[axis] = udiv(index[axis], axisSizePerThread);
}
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
// TODO: Support MFMA transposed layout.
if (axis == 0) {
// Because warpTileSize = [32, 32] and threadsPerWarp = [2, 32], each 2
// rows in smem would correspond to a warp. The mapping
@@ -424,8 +425,17 @@ private:
// Reduce within warps
for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
SmallVector<Value> shfl(op.getNumOperands());
unsigned shuffleIdx = N;
#ifdef USE_ROCM
if (inMfma && inMfma.getIsTransposed()) {
assert(sizeIntraWarps == 2);
// Adjacent threads in the y dimension in the transposed MFMA layout are 32
// lanes apart: [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...].
shuffleIdx = 32;
}
#endif
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
shfl[i] = shflSync(loc, rewriter, acc[i], N);
shfl[i] = shflSync(loc, rewriter, acc[i], shuffleIdx);
}
accumulate(rewriter, *combineOp, acc, shfl, false);
}

View File

@@ -655,6 +655,34 @@ public:
llvm_unreachable("unsupported emitOffsetForLayout");
}
#ifdef USE_ROCM
// Appends to `offsets` the in-tile (row, col) offsets of every element a
// thread owns within one MFMA output tile, shifted by the CTA repetition at
// (ctaOffsetX, ctaOffsetY). Offsets are emitted group-major, 4 elements per
// group, matching the order the MFMA intrinsic returns its accumulator.
void emitMfmaOffsetForCTA(const MfmaEncodingAttr &mfmaLayout,
SmallVector<SmallVector<unsigned>> &offsets,
unsigned ctaOffsetX, unsigned ctaOffsetY) const {
// MFMA output tile consists of repeated "dot operand B" layout groups along
// row axis. This variable defines number of these groups.
const unsigned numGroups = 4;
const unsigned elemsPerThreadPerGroup = 4;
auto warpSize = getWarpSize(mfmaLayout);
assert(warpSize == 64);
auto shapePerCta = getShapePerCTA(mfmaLayout);
for (unsigned block = 0; block < numGroups; block++) {
// With a 64-lane warp each group starts 4 * 64 / 32 = 8 rows (columns when
// transposed) after the previous one.
unsigned rowOrColOffset = block * elemsPerThreadPerGroup * warpSize / 32;
for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) {
if (mfmaLayout.getIsTransposed()) {
// Transposed layout: a thread's elements run along the column axis.
offsets.push_back(
{ctaOffsetX * shapePerCta[0],
ctaOffsetY * shapePerCta[1] + elem + rowOrColOffset});
} else {
// Regular layout: a thread's elements run along the row axis.
offsets.push_back(
{ctaOffsetX * shapePerCta[0] + elem + rowOrColOffset,
ctaOffsetY * shapePerCta[1]});
}
}
}
}
#endif
// -----------------------------------------------------------------------
// Emit indices
// -----------------------------------------------------------------------
@@ -979,16 +1007,16 @@ private:
SmallVector<Value>
emitBaseIndexForMfmaLayout(Location loc, ConversionPatternRewriter &rewriter,
const MfmaEncodingAttr &mmaLayout,
const MfmaEncodingAttr &mfmaLayout,
RankedTensorType type) const {
auto shape = type.getShape();
auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
auto _warpsPerCTA = mfmaLayout.getWarpsPerCTA();
assert(_warpsPerCTA.size() == 2);
SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
i32_val(_warpsPerCTA[1])};
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(triton::gpu::getWarpSize(mmaLayout));
Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout));
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
@@ -1000,29 +1028,35 @@ private:
Value offWarp1 = mul(warpId1, i32_val(32));
SmallVector<Value> multiDimBase(2);
multiDimBase[0] = add(mul(i32_val(4), udiv(laneId, i32_val(32))), offWarp0);
multiDimBase[1] = add(urem(laneId, i32_val(32)), offWarp1);
if (mfmaLayout.getIsTransposed()) {
multiDimBase[1] =
add(mul(i32_val(4), udiv(laneId, i32_val(32))), offWarp1);
multiDimBase[0] = add(urem(laneId, i32_val(32)), offWarp0);
} else {
multiDimBase[0] =
add(mul(i32_val(4), udiv(laneId, i32_val(32))), offWarp0);
multiDimBase[1] = add(urem(laneId, i32_val(32)), offWarp1);
}
return multiDimBase;
}
SmallVector<SmallVector<unsigned>>
emitOffsetForMfmaLayout(const MfmaEncodingAttr &mmaLayout,
emitOffsetForMfmaLayout(const MfmaEncodingAttr &mfmaLayout,
RankedTensorType type) const {
auto tensorShape = type.getShape();
SmallVector<SmallVector<unsigned>> offsets;
auto shapePerCta = getShapePerCTA(mmaLayout);
const unsigned iterationCount = 4;
auto shapePerCta = getShapePerCTA(mfmaLayout);
for (unsigned i = 0; i < tensorShape[0]; i += shapePerCta[0]) {
for (unsigned j = 0; j < tensorShape[1]; j += shapePerCta[1]) {
unsigned rowOffset = 0;
for (unsigned k = 0; k < iterationCount; k++) {
for (unsigned l = 0; l < iterationCount; l++) {
offsets.push_back({i + l + rowOffset, j});
}
rowOffset += iterationCount * 2;
}
SmallVector<unsigned> numCTAPerDim(2);
for (unsigned d = 0; d < 2; ++d) {
unsigned inPerCTA = std::min<unsigned>(tensorShape[d], shapePerCta[d]);
numCTAPerDim[d] = ceil<unsigned>(tensorShape[d], inPerCTA);
}
for (unsigned i = 0; i < numCTAPerDim[0]; ++i) {
for (unsigned j = 0; j < numCTAPerDim[1]; ++j) {
emitMfmaOffsetForCTA(mfmaLayout, offsets, i, j);
}
}
return offsets;

View File

@@ -310,6 +310,9 @@ public:
// Preprocess
decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp);
#ifdef USE_ROCM
decomposeMfmaToDotOperand(mod, numWarps, threadsPerWarp);
#endif
decomposeBlockedToDotOperand(mod);
decomposeInsertSliceAsyncOp(mod);
@@ -464,6 +467,36 @@ private:
});
}
#ifdef USE_ROCM
// Decomposes every `mfma -> dot_op` convert_layout that cannot take the
// register shortcut (isMfmaToDotShortcut) into a two-step
// `mfma -> blocked -> dot_op` chain, which the generic lowering handles via
// shared memory. Shortcut-eligible conversions are left untouched so the
// chain-dot result stays in registers.
void decomposeMfmaToDotOperand(ModuleOp mod, int numWarps,
int threadsPerWarp) const {
// Replace `mfma -> dot_op` with `mfma -> blocked -> dot_op`
// unless certain conditions are met
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
OpBuilder builder(cvtOp);
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvtOp.getType().cast<RankedTensorType>();
auto srcMfma =
srcType.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
auto dstDotOp =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (srcMfma && dstDotOp && !isMfmaToDotShortcut(srcType, dstType)) {
// Intermediate blocked layout mirrors the MFMA layout's per-thread
// size and order so the first hop is a cheap conversion.
auto tmpType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(),
triton::gpu::BlockedEncodingAttr::get(
mod.getContext(), srcType.getShape(), getSizePerThread(srcMfma),
getOrder(srcMfma), numWarps, threadsPerWarp));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), dstType, tmp);
cvtOp.replaceAllUsesWith(newConvert.getResult());
cvtOp.erase();
}
});
}
#endif
void decomposeBlockedToDotOperand(ModuleOp mod) const {
// Replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
// because the codegen doesn't handle `blocked -> dot_op` directly

View File

@@ -87,7 +87,11 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
return {8, 4};
}
if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
return {2, 32};
if (mfmaLayout.getIsTransposed()) {
return {32, 2};
} else {
return {2, 32};
}
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parent = sliceLayout.getParent();
@@ -200,7 +204,11 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
llvm_unreachable("Unexpected mma version");
}
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
return {16, 1};
if (mfmaLayout.getIsTransposed()) {
return {1, 16};
} else {
return {16, 1};
}
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
auto parentLayout = dotLayout.getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
@@ -457,6 +465,17 @@ static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr,
return success();
}
// Extracts a boolean from `attr` into `value`; emits a parser error naming
// `desc` and returns failure() when the attribute is not a BoolAttr.
static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr,
                                        bool &value, StringRef desc) {
  if (auto boolAttr = attr.dyn_cast<BoolAttr>()) {
    value = boolAttr.getValue();
    return success();
  }
  parser.emitError(parser.getNameLoc(), "expected bool type in ") << desc;
  return failure();
}
// parse an array of integers
static LogicalResult parseIntArrayAttr(AsmParser &parser,
const NamedAttribute &attr,
@@ -481,6 +500,11 @@ static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
return parseIntAttrValue(parser, attr.getValue(), value, desc);
};
// Thin wrapper: unwrap the NamedAttribute's value and delegate to
// parseBoolAttrValue, mirroring the parseUInt helper above.
static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr,
bool &value, StringRef desc) {
return parseBoolAttrValue(parser, attr.getValue(), value, desc);
};
//===----------------------------------------------------------------------===//
// Attribute methods
//===----------------------------------------------------------------------===//
@@ -551,10 +575,17 @@ MfmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
assert(rank == 2 && "Unexpected rank of mma layout");
SmallVector<unsigned> elemsPerThread(rank);
unsigned elemsCol = ceil<unsigned>(shape[1], 32 * getWarpsPerCTA()[1]);
unsigned elemsRow = ceil<unsigned>(shape[0], 32 * getWarpsPerCTA()[0]) * 16;
elemsPerThread[0] = elemsRow;
elemsPerThread[1] = elemsCol;
if (getIsTransposed()) {
unsigned elemsCol = ceil<unsigned>(shape[1], 32 * getWarpsPerCTA()[1]) * 16;
unsigned elemsRow = ceil<unsigned>(shape[0], 32 * getWarpsPerCTA()[0]);
elemsPerThread[0] = elemsRow;
elemsPerThread[1] = elemsCol;
} else {
unsigned elemsCol = ceil<unsigned>(shape[1], 32 * getWarpsPerCTA()[1]);
unsigned elemsRow = ceil<unsigned>(shape[0], 32 * getWarpsPerCTA()[0]) * 16;
elemsPerThread[0] = elemsRow;
elemsPerThread[1] = elemsCol;
}
return elemsPerThread;
}
@@ -908,6 +939,7 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
unsigned nonKDim = 0;
SmallVector<unsigned, 2> warpsPerCTA;
bool isTransposed;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "nonKDim") {
@@ -917,18 +949,21 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
if (attr.getName() == "warpsPerCTA") {
if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
return {};
} else if (attr.getName() == "isTransposed") {
if (parseBool(parser, attr, isTransposed, "isTransposed").failed())
return {};
}
}
return parser.getChecked<MfmaEncodingAttr>(parser.getContext(), nonKDim,
warpsPerCTA);
warpsPerCTA, isTransposed);
}
void MfmaEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{"
<< "nonKDim = " << getNonKDim() << ", "
<< "warpsPerCTA = [" << getWarpsPerCTA() << "]"
<< "}>";
<< ", isTransposed = " << getIsTransposed() << "}>";
}
//===----------------------------------------------------------------------===//

View File

@@ -126,6 +126,18 @@ public:
BlockedToMFMA(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {}
// A "chain dot" is a dot whose slice (restricted to the same region)
// contains another DotOp — e.g. the two back-to-back matmuls of flash
// attention. Such dots get the transposed MFMA layout so the first dot's
// result can stay in registers.
bool isChainDot(triton::DotOp &dotOp) const {
  auto sameRegion = [&dotOp](Operation *candidate) {
    return candidate->getParentRegion() == dotOp->getParentRegion();
  };
  for (Operation *sliceOp : mlir::getSlice(dotOp, sameRegion)) {
    if (sliceOp != dotOp && isa<triton::DotOp>(sliceOp))
      return true;
  }
  return false;
}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
@@ -157,8 +169,9 @@ public:
auto warpsPerTile = warpsPerTileMI200(dotOp, retShape, numWarps);
mfmaEnc = triton::gpu::MfmaEncodingAttr::get(oldRetType.getContext(),
nonKDim, warpsPerTile);
bool isTransposed = isChainDot(dotOp);
mfmaEnc = triton::gpu::MfmaEncodingAttr::get(
oldRetType.getContext(), nonKDim, warpsPerTile, isTransposed);
auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc);
@@ -194,7 +207,6 @@ public:
return success();
}
};
#endif
class BlockedToMMA : public mlir::RewritePattern {

View File

@@ -48,6 +48,16 @@ public:
dstDotOp.getParent() == srcMmaEncoding))
return;
}
#ifdef USE_ROCM
if (auto srcMfmaEncoding =
srcEncoding.dyn_cast<triton::gpu::MfmaEncodingAttr>()) {
if (srcMfmaEncoding.getWarpsPerCTA()[1] == 1 &&
srcMfmaEncoding.getIsTransposed() &&
dstDotOp.getParent() == srcMfmaEncoding)
return;
}
#endif
auto tmpType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(),
triton::gpu::SharedEncodingAttr::get(

View File

@@ -1202,7 +1202,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
[(*shape, 2, False, False, epilogue, allow_tf32, dtype)
for shape in [(64, 64, 64), (32, 32, 32)]
for epilogue in ['none', 'trans', 'add-matrix']
for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax']
for allow_tf32 in [True, False]
for dtype in ['float16', 'float32']
if not (allow_tf32 and (dtype in ['float16']))] +
@@ -2128,12 +2128,13 @@ class MmaLayout:
class MfmaLayout:
def __init__(self, non_k_dim, warps_per_cta):
def __init__(self, non_k_dim, warps_per_cta, isTransposed):
# All fields are stored as strings, ready for interpolation into the
# textual #triton_gpu.mfma<...> attribute in __str__.
self.non_k_dim = str(non_k_dim)
self.warps_per_cta = str(warps_per_cta)
# str(True).lower() -> "true", matching MLIR's boolean literal syntax.
self.isTransposed = str(isTransposed).lower()
def __str__(self):
return f"#triton_gpu.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}}}>"
return f"#triton_gpu.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.isTransposed}}}>"
class BlockedLayout:
@@ -2238,8 +2239,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
if _get_warp_size() == 64:
layouts = [
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1]),
MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2]),
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], isTransposed=True),
MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], isTransposed=False),
]
shapes = [[128, 32], [128, 128], [32, 128], [64, 64]]
else:
@@ -2255,6 +2256,9 @@ else:
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("axis", [0, 1])
def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
if torch.version.hip is not None:
if src_layout.isTransposed and axis == 0:
pytest.skip("Reduce along axis 0 is not supported in transposed mfma layout")
rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1"
rdims_1d = f"{N}" if axis == 0 else f"{M}"
store_range = "%7" if axis == 0 else "%1"
@@ -2318,7 +2322,7 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
@pytest.mark.parametrize("shape", [(64, 64)])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1]), MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1])])
@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], isTransposed=False), MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], isTransposed = True)])
@pytest.mark.parametrize("dst_layout", [BlockedLayout([1, 4], [4, 16], [1, 1], [1, 0])])
def test_make_range(dtype, shape, src_layout, dst_layout, device='cuda'):
ir = f"""

View File

@@ -12,6 +12,14 @@ import triton.ops
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
def test_op(Z, H, N_CTX, D_HEAD, dtype):
capability = torch.cuda.get_device_capability()
if torch.version.hip is not None:
if dtype != torch.float16:
pytest.skip("Currently flash attention on AMD gpu is only supported in fp16.")
if D_HEAD < 32:
pytest.skip("D_HEAD < 32 is not supported. It will be enabled once smaller tile size is supported in MFMA pipeline.")
if D_HEAD > 64:
pytest.skip("D_HEAD > 64 is not supported. Currently it causes shared memory out of resource error.")
if capability[0] < 8:
pytest.skip("Flash attention only supported for compute capability < 80")
torch.manual_seed(20)
@@ -29,21 +37,26 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype):
p = torch.softmax(p.float(), dim=-1).to(dtype)
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
if torch.version.hip is None:
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# # triton implementation
tri_out = triton.ops.attention(q, k, v, sm_scale)
# print(ref_out)
# print(tri_out)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
if torch.version.hip is None:
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
atol = 1e-1 if dtype == torch.bfloat16 else 1e-2
torch.testing.assert_allclose(ref_out, tri_out, atol=atol, rtol=0)
torch.testing.assert_allclose(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_allclose(ref_dk, tri_dk, atol=atol, rtol=0)
torch.testing.assert_allclose(ref_dq, tri_dq, atol=atol, rtol=0)
if torch.version.hip is None:
torch.testing.assert_allclose(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_allclose(ref_dk, tri_dk, atol=atol, rtol=0)
torch.testing.assert_allclose(ref_dq, tri_dq, atol=atol, rtol=0)

View File

@@ -202,7 +202,10 @@ class _attention(torch.autograd.Function):
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
raise RuntimeError("Flash attention currently only supported for compute capability >= 80")
BLOCK = 128
if torch.version.hip is not None:
BLOCK = 64
else:
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv

View File

@@ -292,18 +292,21 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
p = torch.softmax(p.float(), dim=-1).half()
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
if torch.version.hip is None:
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# # triton implementation
tri_out = attention(q, k, v, sm_scale)
# print(ref_out)
# print(tri_out)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
if torch.version.hip is None:
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
if torch.version.hip is None:

View File

@@ -1186,7 +1186,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0]}>
#mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [1, 1]}>
// Attribute key must be spelled `isTransposed`: MfmaEncodingAttr's parser and
// printer both use that name, so a misspelled key is silently ignored and the
// flag is left unset.
#mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA=[1,1], isTransposed=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma0}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma0}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
@@ -1209,7 +1209,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2]}>
// Attribute key must be spelled `isTransposed` to match MfmaEncodingAttr's
// parser/printer; the misspelling `isTranspose` would never be matched.
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mfma_block
@@ -1359,7 +1359,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2]}>
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {