Merge branch 'triton-mlir' into ifu-231117

This commit is contained in:
jayfurmanek
2023-12-12 14:24:11 -06:00
committed by GitHub
17 changed files with 2574 additions and 209 deletions

View File

@@ -426,7 +426,7 @@ bool supportMMA(triton::DotOp op, int version) {
#ifdef USE_ROCM
static bool supportMFMAGranularity(int m, int n, int k) {
// these limitations are dtype dependent, in future we may relax them
const static std::pair<int, int> mfmaTypes[2] = {{32, 8}, {16, 16}};
const static std::pair<int, int> mfmaTypes[] = {{32, 8}, {16, 16}, {4, 64}};
for (const auto &mfmaType : mfmaTypes) {
auto [granularityMN, granularityK] = mfmaType;
if (m % granularityMN != 0 || n % granularityMN != 0)

View File

@@ -441,7 +441,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
const SharedMemoryObject &smemObj) {
auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
auto nonKDim = mfmaLayout.getNonKDim();
assert(nonKDim == 32 || nonKDim == 16);
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
@@ -587,7 +587,7 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
const SharedMemoryObject &smemObj) {
auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
auto nonKDim = mfmaLayout.getNonKDim();
assert(nonKDim == 32 || nonKDim == 16);
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
auto bTensorTy = tensor.getType().cast<RankedTensorType>();

View File

@@ -80,115 +80,86 @@ struct DotOpMFMAConversionHelper {
return rewriter.create<arith::TruncIOp>(loc, i32_ty, tid);
}
Value generateMFMAOp(MFMAInstrDescr mfmaDescr, Value valA, Value valB,
Value valC) const {
Value generateMFMA32Op(MatrixCoreType coreType, Value valA, Value valB,
Value valC) const {
auto resType = valC.getType();
Value zeroFlag = i32_val(0);
switch (mfmaDescr.coreType) {
switch (coreType) {
case MatrixCoreType::FP32_FP8_FP8_FP32:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_f32_16x16x32_fp8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP8_BF8_FP32:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_f32_16x16x32_fp8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF8_FP8_FP32:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_f32_16x16x32_bf8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF8_BF8_FP32:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_f32_16x16x32_bf8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP16_FP16_FP32:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_f32_16x16x16f16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
return rewriter.create<ROCDL::mfma_f32_32x32x8f16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_f32_32x32x8f16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF16_BF16_FP32:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_f32_16x16x8bf16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
return rewriter.create<ROCDL::mfma_f32_32x32x4bf16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_f32_32x32x4bf16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_f32_16x16x16bf16_1k>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
assert(mfmaDescr.size == 32);
return rewriter.create<ROCDL::mfma_f32_32x32x8bf16_1k>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_f32_32x32x8bf16_1k>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP32_FP32_FP32:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_f32_16x16x4f32>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
assert(mfmaDescr.size == 32);
return rewriter.create<ROCDL::mfma_f32_32x32x2f32>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_f32_32x32x2f32>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::INT32_INT8_INT8_INT32:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_i32_16x16x16i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
return rewriter.create<ROCDL::mfma_i32_32x32x8i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_i32_32x32x8i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::INT32_INT8_INT8_INT32_CDNA3:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_i32_16x16x32_i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
return rewriter.create<ROCDL::mfma_i32_32x32x16_i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_i32_32x32x16_i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP64_FP64_FP64_FP64:
return rewriter.create<ROCDL::mfma_f64_16x16x4f64>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
default:
llvm::report_fatal_error("MFMA 32x32 data type not supported");
}
}
Value generateMFMA16Op(MatrixCoreType coreType, Value valA, Value valB,
Value valC) const {
auto resType = valC.getType();
Value zeroFlag = i32_val(0);
switch (coreType) {
case MatrixCoreType::FP32_FP16_FP16_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x16f16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF16_BF16_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x8bf16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
return rewriter.create<ROCDL::mfma_f32_16x16x16bf16_1k>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP32_FP32_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x4f32>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::INT32_INT8_INT8_INT32:
return rewriter.create<ROCDL::mfma_i32_16x16x16i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP64_FP64_FP64_FP64:
assert(mfmaDescr.size == 16);
return rewriter.create<ROCDL::mfma_f64_16x16x4f64>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
@@ -197,6 +168,72 @@ struct DotOpMFMAConversionHelper {
}
}
/// Emits the ROCDL MFMA intrinsic for a 4x4 output tile.
/// \p valA and \p valB are the packed A/B operand values, \p valC is the
/// accumulator; returns the value produced by the intrinsic. Aborts for
/// data-type combinations that have no 4x4 MFMA instruction here.
Value generateMFMA4Op(MatrixCoreType coreType, Value valA, Value valB,
                      Value valC) const {
  auto accTy = valC.getType();
  // Trailing integer modifier operands of the intrinsic are all zero
  // (default MFMA mode, no broadcast).
  Value zero = i32_val(0);
  if (coreType == MatrixCoreType::FP32_FP16_FP16_FP32)
    return rewriter.create<ROCDL::mfma_f32_4x4x4f16>(
        loc, TypeRange{accTy},
        ValueRange{valA, valB, valC, zero, zero, zero});
  if (coreType == MatrixCoreType::FP32_BF16_BF16_FP32)
    return rewriter.create<ROCDL::mfma_f32_4x4x2bf16>(
        loc, TypeRange{accTy},
        ValueRange{valA, valB, valC, zero, zero, zero});
  if (coreType == MatrixCoreType::FP32_BF16_BF16_FP32_1K)
    return rewriter.create<ROCDL::mfma_f32_4x4x4bf16_1k>(
        loc, TypeRange{accTy},
        ValueRange{valA, valB, valC, zero, zero, zero});
  if (coreType == MatrixCoreType::FP32_FP32_FP32_FP32)
    return rewriter.create<ROCDL::mfma_f32_4x4x1f32>(
        loc, TypeRange{accTy},
        ValueRange{valA, valB, valC, zero, zero, zero});
  if (coreType == MatrixCoreType::INT32_INT8_INT8_INT32)
    return rewriter.create<ROCDL::mfma_i32_4x4x4i8>(
        loc, TypeRange{accTy},
        ValueRange{valA, valB, valC, zero, zero, zero});
  llvm::report_fatal_error("MFMA4 data type not supported");
}
/// Dispatches to the MFMA emitter matching the instruction descriptor's
/// non-K tile size (32, 16 or 4); aborts on any other size.
/// \param mfmaDescr instruction descriptor (tile size + matrix core type)
/// \param valA,valB packed operand values
/// \param valC accumulator value
/// \return the new accumulator value produced by the emitted intrinsic
Value generateMFMAOp(MFMAInstrDescr mfmaDescr, Value valA, Value valB,
                     Value valC) const {
  switch (mfmaDescr.size) {
  case 32:
    return generateMFMA32Op(mfmaDescr.coreType, valA, valB, valC);
  case 16:
    return generateMFMA16Op(mfmaDescr.coreType, valA, valB, valC);
  case 4:
    return generateMFMA4Op(mfmaDescr.coreType, valA, valB, valC);
  default:
    llvm::report_fatal_error("MFMA nonkDim size is not supported");
  }
  // Unreachable: every case above returns or aborts; kept to silence
  // missing-return warnings.
  return Value();
}
/// Returns how many independent sub-matrix products one MFMA instruction
/// of the given non-K tile size computes per wave. The 4x4 instructions
/// produce 16 independent sub-blocks that must later be combined (see
/// processSubBlocks); 32x32 and 16x16 produce a single block.
/// \param elementType element type of the A operand
/// \param nonKDim MFMA output tile size (32, 16 or 4)
int getNumSubmatrices(Type elementType, int nonKDim) const {
  switch (nonKDim) {
  case 32:
  case 16:
    return 1;
  case 4:
    assert(elementType.getIntOrFloatBitWidth() <= 32 &&
           "fp64 is not supported yet");
    // An 8-bit element must be an integer here: there is no fp8 4x4 MFMA.
    // Parenthesized so the diagnostic string applies to the whole
    // condition (the original `a || b && "msg"` bound it to one operand).
    assert((elementType.getIntOrFloatBitWidth() != 8 ||
            elementType.isInteger(8)) &&
           "fp8 is not supported yet");
    return 16;
  default:
    llvm::report_fatal_error("unsupported nonKDim in MFMA dot");
  }
  return -1; // unreachable
}
// TODO unify this function with Utility.cpp:supportMFMATypes
static MatrixCoreType getMatrixCoreTypeFromDot(DotOp op) {
auto aOperandTy = op.getA().getType();
@@ -223,21 +260,21 @@ struct DotOpMFMAConversionHelper {
if (aElemTy.isBF16()) {
auto nonKDim = mfmaEncoding.getNonKDim();
auto kWidth = dotOpEncoding.getKWidth();
if ((nonKDim == 32 && kWidth == 4) || (nonKDim == 16 && kWidth == 4)) {
if ((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 4) {
return MatrixCoreType::FP32_BF16_BF16_FP32_1K;
} else {
assert((nonKDim == 32 && kWidth == 2) ||
(nonKDim == 16 && kWidth == 2));
(nonKDim == 16 && kWidth == 2) || (nonKDim == 4 && kWidth == 2));
return MatrixCoreType::FP32_BF16_BF16_FP32;
}
}
if (aElemTy.isInteger(8)) {
auto nonKDim = mfmaEncoding.getNonKDim();
auto kWidth = dotOpEncoding.getKWidth();
if ((nonKDim == 32 && kWidth == 8) || (nonKDim == 16 && kWidth == 8)) {
if ((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 8) {
return MatrixCoreType::INT32_INT8_INT8_INT32_CDNA3;
}
else {
} else {
assert((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 4);
return MatrixCoreType::INT32_INT8_INT8_INT32;
}
}
@@ -255,11 +292,81 @@ struct DotOpMFMAConversionHelper {
return descr;
}
/// Post-processes an accumulator that holds several MFMA 4x4 sub-blocks.
///
/// The wave is split into `numSubBlocks` groups of lanes; each group holds
/// one copy of the sub-block result. Depending on the flags this routine
/// - sums corresponding scalars across sub-blocks via butterfly shuffles
///   (\p reduceSubBlocks), and/or
/// - zeroes every copy except the first sub-block's (\p zeroSubBlocks),
/// then reassembles the accumulator vector.
/// \param numSubBlocks number of sub-block copies; must be a power of two
/// \param acc vector accumulator of 32-bit int or float elements
/// \return the processed accumulator (unchanged if numSubBlocks == 1)
Value processSubBlocks(int numSubBlocks, Value acc, bool reduceSubBlocks,
                       bool zeroSubBlocks) const {
  assert((numSubBlocks & (numSubBlocks - 1)) == 0 &&
         "numSubBlocks is not pow 2!");
  if (numSubBlocks == 1)
    return acc;
  constexpr int waveSize = 64;
  int subBlockSize = waveSize / numSubBlocks;
  Value laneId = getThreadId();
  laneId = and_(laneId, i32_val(waveSize - 1));
  auto vecTy = dyn_cast<VectorType>(acc.getType());
  auto elemType = vecTy.getElementType();
  assert(elemType.getIntOrFloatBitWidth() == 32);
  int numScalars = vecTy.getNumElements();
  // Unpack the accumulator vector into scalars so each element can be
  // shuffled/selected independently.
  std::vector<Value> accScalar(numScalars);
  for (int i = 0; i < numScalars; ++i)
    accScalar[i] = extract_element(elemType, acc, i32_val(i));
  if (reduceSubBlocks) {
    // Butterfly reduction: double the stride each step until partial sums
    // from all sub-blocks have been combined across the wave.
    while (subBlockSize < waveSize) {
      for (int i = 0; i < numScalars; ++i) {
        Value other_acc =
            mlir::LLVM::shflSync(loc, rewriter, accScalar[i], subBlockSize);
        if (elemType.isInteger(32))
          accScalar[i] = add(accScalar[i], other_acc);
        else
          accScalar[i] = fadd(accScalar[i], other_acc);
      }
      subBlockSize *= 2;
    }
  }
  if (zeroSubBlocks) {
    Value zero;
    if (elemType.isInteger(32))
      zero = i32_val(0);
    else
      zero = f32_val(0.0);
    // Keep values only in the lanes of the first sub-block; zero the rest.
    auto cond = icmp_ult(laneId, i32_val(subBlockSize));
    for (int i = 0; i < numScalars; ++i)
      accScalar[i] = select(cond, accScalar[i], zero);
  }
  // Repack the processed scalars into a vector of the original type.
  Value reducedAcc = undef(vecTy);
  for (int i = 0; i < numScalars; ++i)
    reducedAcc = insert_element(vecTy, reducedAcc, accScalar[i], i32_val(i));
  return reducedAcc;
}
/// @brief MFMA 4x4 computes 16 independent matrix multiplications; this
/// function adds those 16 partial matrices together to get the final 4x4
/// result.
/// @param numSubBlocks number of sub-block copies held in the accumulator
/// @param acc vector accumulator containing the partial sub-block results
/// @return accumulator with all sub-blocks summed together
Value reduceSubBlocks(int numSubBlocks, Value acc) const {
  return processSubBlocks(numSubBlocks, acc, true, false);
}
/// @brief Zeroes out redundant values in all sub-blocks except the first one.
///
/// Every wave in the mfma 4x4 layout holds only 4 unique values (scalars or
/// vectors) in blocks of 4 consecutive threads. There are 16 copies of these
/// 4 values across all threads of the wave; 15 of the copies must be zeroed
/// out so the accumulator can be reused between dot operations.
/// @param numSubBlocks number of sub-block copies held in the accumulator
/// @param acc vector accumulator whose redundant copies should be zeroed
/// @return accumulator with every sub-block except the first zeroed out
Value zeroAuxiliarBlocks(int numSubBlocks, Value acc) const {
  return processSubBlocks(numSubBlocks, acc, false, true);
}
// Conduct the Dot conversion.
LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor) const {
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
auto nonKDim = mfmaLayout.getNonKDim();
assert(nonKDim == 32 || nonKDim == 16);
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
auto mfmaInstrDescr = getMatrixInstrDescr(op);
Value a = op.getA();
@@ -296,8 +403,10 @@ struct DotOpMFMAConversionHelper {
unsigned warpSize = triton::gpu::getWarpSize(mfmaLayout);
// compute number of output elements that each thread holds for one MFMA
// instruction
auto elemsPerVec = nonKDim * nonKDim / warpSize;
// instruction. subBlocks
const int subBlocks =
getNumSubmatrices(aTensorTy.getElementType(), nonKDim);
auto elemsPerVec = nonKDim * nonKDim * subBlocks / warpSize;
auto vecTy = vec_ty(dstElemTy, elemsPerVec);
for (int m = 0; m < numRepM; ++m) {
@@ -308,13 +417,14 @@ struct DotOpMFMAConversionHelper {
vecTy, acc, fc[m * numRepN * elemsPerVec + n * elemsPerVec + v],
i32_val(v));
}
acc = zeroAuxiliarBlocks(subBlocks, acc);
for (size_t k = 0; k < numRepK; k++) {
acc =
mfmaLayout.getIsTransposed()
? generateMFMAOp(mfmaInstrDescr, hb[{n, k}], ha[{m, k}], acc)
: generateMFMAOp(mfmaInstrDescr, ha[{m, k}], hb[{n, k}], acc);
}
acc = reduceSubBlocks(subBlocks, acc);
for (unsigned v = 0; v < elemsPerVec; ++v) {
fc[m * numRepN * elemsPerVec + n * elemsPerVec + v] =
extract_element(dstElemTy, acc, i32_val(v));

View File

@@ -775,7 +775,9 @@ public:
auto nonKDim = mfmaLayout.getNonKDim();
// MFMA output tile consists of repeated "dot operand B" layout groups along
// row axis. This variable defines number of these groups.
const unsigned numGroups = (nonKDim == 32 ? 4 : 1);
DenseMap<int, int> groups{{4, 1}, {16, 1}, {32, 4}};
unsigned numGroups = groups.at(nonKDim);
const unsigned elemsPerThreadPerGroup = 4;
auto warpSize = getWarpSize(mfmaLayout);
assert(warpSize == 64);
@@ -1193,7 +1195,12 @@ private:
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout));
Value laneId = urem(threadId, warpSize);
Value effectiveWarpSize = warpSize;
if (nonKDim == 4) {
const int uniqueValuesPerWarp = 4;
effectiveWarpSize = i32_val(uniqueValuesPerWarp);
}
Value laneId = urem(threadId, effectiveWarpSize);
Value warpId = udiv(threadId, warpSize);
Value warpId0 =

View File

@@ -2,6 +2,10 @@
#include "TypeConverter.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "triton/Dialect/NVGPU/IR/Dialect.h"
#if USE_ROCM
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#endif
namespace mlir {
namespace LLVM {
@@ -286,10 +290,20 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
#ifdef USE_ROCM
//On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on 32bit/dwords
//so we need to promote to 32 bits here.
if (bits == 8) {
Value i32Val = sext(i32_ty, val);
Value result = commonShflSync(loc, rewriter, i32Val, i, strideInt, mode, clamp, laneId);
return trunc(i8_ty, result);
auto valType = val.getType();
if (!valType.isInteger(32) && bits <= 32) {
if (!valType.isIntOrIndex())
val = bitcast(val, int_ty(bits));
if (bits < 32)
val = sext(i32_ty, val);
val = commonShflSync(loc, rewriter, val, i, strideInt, mode, clamp, laneId);
if (bits < 32)
val = trunc(int_ty(bits), val);
if (!valType.isIntOrIndex())
val = bitcast(val, valType);
return val;
}
#endif
@@ -307,20 +321,12 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
}
#ifdef USE_ROCM
GCNBuilder builder;
auto permute = [&](Value lane, StringRef permuteInstStr) {
assert(permuteInstStr == "ds_permute_b32" ||
permuteInstStr == "ds_bpermute_b32");
auto bpermute = [&](Value lane) {
// Multiply the lane id by 4. (More on permute instruction semantics:
// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf#page=180
Value byteOffset = i32_val(2);
Value permuteAddr = shl(lane, byteOffset);
auto shfl = builder.create(permuteInstStr.str());
auto dOpr = builder.newOperand("=v");
auto addrOpr = builder.newOperand(permuteAddr, "v");
auto aOpr = builder.newOperand(val, "v");
(*shfl)(dOpr, addrOpr, aOpr);
return rewriter.create<ROCDL::DsBpermuteOp>(loc, valType, permuteAddr, val);
};
switch (mode) {
@@ -334,39 +340,30 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)})
.getResult(0);
Value stride = i32_val(32);
Value lineId = add(threadId, stride);
permute(lineId, "ds_permute_b32");
Value lineId = xor_(threadId, stride);
return bpermute(lineId);
} else {
// This map facilitates the butterfly shuffle pattern for a stride less
// than 16. The pattern stride is the key of the map.
DenseMap<short, unsigned int> masks{
{16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}};
auto shfl = builder.create("ds_swizzle_b32");
auto dOpr = builder.newOperand("=v");
auto aOpr = builder.newOperand(val, "v");
auto maskOpr =
builder.newConstantOperand("offset:" + std::to_string(masks[strideInt]));
(*shfl)(dOpr, aOpr, maskOpr);
Value offset = i32_val(masks[strideInt]);
return rewriter.create<ROCDL::DsSwizzleOp>(loc, valType, val, offset);
}
break;
case NVVM::ShflKind::up: {
Value mask = icmp_slt(laneId, i);
Value delta = sub(laneId, i);
Value index = select(mask, laneId, delta);
permute(index, "ds_bpermute_b32");
break;
return bpermute(index);
}
case NVVM::ShflKind::idx:
permute(i, "ds_bpermute_b32");
break;
return bpermute(i);
default:
assert(false && "Unsupported ShflKind");
break;
}
auto swait = builder.create("s_waitcnt lgkmcnt(0)");
(*swait)();
return builder.launch(rewriter, loc, val.getType(), true);
return Value();
#else
Type type = val.getType();
if (type != i32_ty) {

View File

@@ -240,15 +240,22 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
}
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
unsigned rows, cols;
if (mfmaLayout.getNonKDim() == 32) {
switch (mfmaLayout.getNonKDim()) {
case 32:
rows = 16;
cols = 1;
} else if (mfmaLayout.getNonKDim() == 16) {
break;
case 16:
rows = 4;
cols = 1;
} else
break;
case 4:
rows = 4;
cols = 1;
break;
default:
llvm_unreachable("Unexpected mfma non-k dim");
}
if (mfmaLayout.getIsTransposed()) {
return {cols, rows};
} else {
@@ -993,9 +1000,11 @@ SmallVector<int64_t>
DotOperandEncodingAttr::getMFMAElemsPerInstr() const {
auto mfmaEncoding = getParent().cast<MfmaEncodingAttr>();
int64_t nonKDim = mfmaEncoding.getNonKDim();
assert(nonKDim == 32 || nonKDim == 16);
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
int64_t kWidth = getKWidth();
int64_t kDim = kWidth * (nonKDim == 32 ? 2 : 4);
constexpr int waveSize = 64; // MFMA is used on wave64 architectures only
int kGroups = waveSize / nonKDim;
int64_t kDim = kWidth * kGroups;
if (getOpIdx() == 0)
return {nonKDim, kDim};
else

View File

@@ -105,9 +105,18 @@ public:
if (enforcedNonKDim != 0) {
nonKDim = enforcedNonKDim;
} else {
nonKDim = (resShape[0] < 32 || resShape[1] < 32) ? 16 : 32;
nonKDim = -1;
int minSize = std::min(resShape[0], resShape[1]);
if (minSize >= 32)
nonKDim = 32;
if (minSize >= 16 && minSize < 32)
nonKDim = 16;
if (minSize < 16)
nonKDim = 4;
assert(nonKDim != -1);
}
if (nonKDim == 32) {
switch (nonKDim) {
case 32:
if (elemType.isF32())
kDim = 2;
if (elemType.isF16())
@@ -130,7 +139,8 @@ public:
kDim = 8;
}
}
} else {
break;
case 16:
if (elemType.isF32())
kDim = 4;
if (elemType.isF16())
@@ -152,7 +162,25 @@ public:
else {
kDim = 16;
}
}
}
break;
case 4:
if (elemType.isF32())
kDim = 16;
if (elemType.isF16())
kDim = 64;
if (elemType.isBF16()) {
if (mfmaVersion == 1)
kDim = 32;
if (mfmaVersion >= 2)
kDim = 64;
}
if (elemType.isInteger(8)) {
kDim = 64;
}
break;
default:
llvm::report_fatal_error("unsupported nonKDim size in MFMA dot");
}
assert(kDim != -1);
assert(nonKDim != -1);
@@ -221,11 +249,19 @@ public:
auto kWidth = kDim;
// in mfma 32x32 case argument matrix groups elements in 2 groups
// in mfma 16x16 case argument matrix groups elements in 4 groups
if (nonKDim == 32) {
// in mfma 4x4 case argument matrix groups elements in 16 groups
switch (nonKDim) {
case 32:
kWidth /= 2;
} else {
assert(nonKDim == 16);
break;
case 16:
kWidth /= 4;
break;
case 4:
kWidth /= 16;
break;
default:
llvm::report_fatal_error("unsupported kDim in mfma dot");
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),