Merge branch 'triton-mlir' into ifu-231117

This commit is contained in:
jayfurmanek
2023-12-12 14:24:11 -06:00
committed by GitHub
17 changed files with 2574 additions and 209 deletions

View File

@@ -740,6 +740,24 @@ The data will be distributed between threads as follows:
[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ]
[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ]
[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ]
Example 3:
Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and nonKDim set to 4.
The data will be distributed between threads as follows (note that each element is duplicated in 16 threads):
M N -> wave 0 wave 2
| --------------------------/\-------------------------- ------------------------------/\------------------------------
V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ]
[ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ]
[ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ]
[ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ]
wave 1 wave 3
--------------------------/\-------------------------- ------------------------------/\------------------------------
[ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
[ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
[ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
[ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
}];
let parameters = (

View File

@@ -426,7 +426,7 @@ bool supportMMA(triton::DotOp op, int version) {
#ifdef USE_ROCM
static bool supportMFMAGranularity(int m, int n, int k) {
// these limitations are dtype dependent, in future we may relax them
const static std::pair<int, int> mfmaTypes[2] = {{32, 8}, {16, 16}};
const static std::pair<int, int> mfmaTypes[] = {{32, 8}, {16, 16}, {4, 64}};
for (const auto &mfmaType : mfmaTypes) {
auto [granularityMN, granularityK] = mfmaType;
if (m % granularityMN != 0 || n % granularityMN != 0)

View File

@@ -441,7 +441,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
const SharedMemoryObject &smemObj) {
auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
auto nonKDim = mfmaLayout.getNonKDim();
assert(nonKDim == 32 || nonKDim == 16);
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
@@ -587,7 +587,7 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
const SharedMemoryObject &smemObj) {
auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
auto nonKDim = mfmaLayout.getNonKDim();
assert(nonKDim == 32 || nonKDim == 16);
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
auto bTensorTy = tensor.getType().cast<RankedTensorType>();

View File

@@ -80,115 +80,86 @@ struct DotOpMFMAConversionHelper {
return rewriter.create<arith::TruncIOp>(loc, i32_ty, tid);
}
Value generateMFMAOp(MFMAInstrDescr mfmaDescr, Value valA, Value valB,
Value valC) const {
Value generateMFMA32Op(MatrixCoreType coreType, Value valA, Value valB,
Value valC) const {
auto resType = valC.getType();
Value zeroFlag = i32_val(0);
switch (mfmaDescr.coreType) {
switch (coreType) {
case MatrixCoreType::FP32_FP8_FP8_FP32:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_f32_16x16x32_fp8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP8_BF8_FP32:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_f32_16x16x32_fp8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF8_FP8_FP32:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_f32_16x16x32_bf8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF8_BF8_FP32:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_f32_16x16x32_bf8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP16_FP16_FP32:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_f32_16x16x16f16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
return rewriter.create<ROCDL::mfma_f32_32x32x8f16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_f32_32x32x8f16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF16_BF16_FP32:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_f32_16x16x8bf16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
return rewriter.create<ROCDL::mfma_f32_32x32x4bf16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_f32_32x32x4bf16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_f32_16x16x16bf16_1k>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
assert(mfmaDescr.size == 32);
return rewriter.create<ROCDL::mfma_f32_32x32x8bf16_1k>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_f32_32x32x8bf16_1k>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP32_FP32_FP32:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_f32_16x16x4f32>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
assert(mfmaDescr.size == 32);
return rewriter.create<ROCDL::mfma_f32_32x32x2f32>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_f32_32x32x2f32>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::INT32_INT8_INT8_INT32:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_i32_16x16x16i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
return rewriter.create<ROCDL::mfma_i32_32x32x8i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_i32_32x32x8i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::INT32_INT8_INT8_INT32_CDNA3:
if (mfmaDescr.size == 16) {
return rewriter.create<ROCDL::mfma_i32_16x16x32_i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
} else {
return rewriter.create<ROCDL::mfma_i32_32x32x16_i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
}
return rewriter.create<ROCDL::mfma_i32_32x32x16_i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP64_FP64_FP64_FP64:
return rewriter.create<ROCDL::mfma_f64_16x16x4f64>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
default:
llvm::report_fatal_error("MFMA 32x32 data type not supported");
}
}
Value generateMFMA16Op(MatrixCoreType coreType, Value valA, Value valB,
Value valC) const {
auto resType = valC.getType();
Value zeroFlag = i32_val(0);
switch (coreType) {
case MatrixCoreType::FP32_FP16_FP16_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x16f16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF16_BF16_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x8bf16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
return rewriter.create<ROCDL::mfma_f32_16x16x16bf16_1k>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP32_FP32_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x4f32>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::INT32_INT8_INT8_INT32:
return rewriter.create<ROCDL::mfma_i32_16x16x16i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP64_FP64_FP64_FP64:
assert(mfmaDescr.size == 16);
return rewriter.create<ROCDL::mfma_f64_16x16x4f64>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
@@ -197,6 +168,72 @@ struct DotOpMFMAConversionHelper {
}
}
/// Creates the ROCDL 4x4 MFMA operation that corresponds to \p coreType.
///
/// \param coreType element-type combination selecting the intrinsic
/// \param valA     first operand forwarded to the intrinsic
/// \param valB     second operand forwarded to the intrinsic
/// \param valC     accumulator operand; its type is reused as the result type
/// \return the value produced by the created operation
///
/// Handles fp16, both bf16 variants, fp32 and int8. Every other core type
/// terminates compilation through llvm::report_fatal_error.
Value generateMFMA4Op(MatrixCoreType coreType, Value valA, Value valB,
                      Value valC) const {
  // The result has exactly the accumulator's type.
  auto accTy = valC.getType();
  // The three trailing i32 operands are always the constant 0.
  // NOTE(review): presumably these are the MFMA modifier fields -- confirm
  // against the ROCDL dialect definition.
  Value zero = i32_val(0);
  switch (coreType) {
  case MatrixCoreType::FP32_FP16_FP16_FP32:
    return rewriter.create<ROCDL::mfma_f32_4x4x4f16>(
        loc, TypeRange{accTy}, ValueRange{valA, valB, valC, zero, zero, zero});
  case MatrixCoreType::FP32_BF16_BF16_FP32:
    return rewriter.create<ROCDL::mfma_f32_4x4x2bf16>(
        loc, TypeRange{accTy}, ValueRange{valA, valB, valC, zero, zero, zero});
  case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
    return rewriter.create<ROCDL::mfma_f32_4x4x4bf16_1k>(
        loc, TypeRange{accTy}, ValueRange{valA, valB, valC, zero, zero, zero});
  case MatrixCoreType::FP32_FP32_FP32_FP32:
    return rewriter.create<ROCDL::mfma_f32_4x4x1f32>(
        loc, TypeRange{accTy}, ValueRange{valA, valB, valC, zero, zero, zero});
  case MatrixCoreType::INT32_INT8_INT8_INT32:
    return rewriter.create<ROCDL::mfma_i32_4x4x4i8>(
        loc, TypeRange{accTy}, ValueRange{valA, valB, valC, zero, zero, zero});
  default:
    llvm::report_fatal_error("MFMA4 data type not supported");
  }
}
/// Selects the MFMA generator that matches the instruction's tile size.
///
/// \param mfmaDescr descriptor carrying the tile size (32, 16 or 4) and the
///                  matrix core type
/// \param valA      first operand forwarded to the size-specific generator
/// \param valB      second operand forwarded to the size-specific generator
/// \param valC      accumulator operand forwarded to the generator
/// \return the value returned by the size-specific generator
///
/// Any other tile size terminates through llvm::report_fatal_error.
Value generateMFMAOp(MFMAInstrDescr mfmaDescr, Value valA, Value valB,
                     Value valC) const {
  const auto tileSize = mfmaDescr.size;
  if (tileSize == 32)
    return generateMFMA32Op(mfmaDescr.coreType, valA, valB, valC);
  if (tileSize == 16)
    return generateMFMA16Op(mfmaDescr.coreType, valA, valB, valC);
  if (tileSize == 4)
    return generateMFMA4Op(mfmaDescr.coreType, valA, valB, valC);
  llvm::report_fatal_error("MFMA nonkDim size is not supported");
  // Not expected to be reached; kept so every path has a return statement.
  return Value();
}
/// Returns how many independent sub-matrices one MFMA instruction computes.
///
/// \param elementType element type of the dot operand; only checked (via
///                    asserts) in the 4x4 case
/// \param nonKDim     non-K tile size of the MFMA instruction
/// \return 1 for the 32x32 and 16x16 instructions, 16 for the 4x4 instruction
///
/// Any other \p nonKDim terminates through llvm::report_fatal_error.
int getNumSubmatrices(Type elementType, int nonKDim) const {
  switch (nonKDim) {
  case 32:
  case 16:
    return 1;
  case 4:
    assert(elementType.getIntOrFloatBitWidth() <= 32 &&
           "fp64 is not supported yet");
    // Parenthesised so the message is attached to the whole condition
    // instead of only to the right-hand operand of `||`.
    assert((elementType.getIntOrFloatBitWidth() != 8 ||
            elementType.isInteger(8)) &&
           "fp8 is not supported yet");
    return 16;
  default:
    llvm::report_fatal_error("unsupported nonKDim in MFMA dot");
  }
  // Not expected to be reached; kept so every path has a return statement.
  return -1;
}
// TODO unify this function with Utility.cpp:supportMFMATypes
static MatrixCoreType getMatrixCoreTypeFromDot(DotOp op) {
auto aOperandTy = op.getA().getType();
@@ -223,21 +260,21 @@ struct DotOpMFMAConversionHelper {
if (aElemTy.isBF16()) {
auto nonKDim = mfmaEncoding.getNonKDim();
auto kWidth = dotOpEncoding.getKWidth();
if ((nonKDim == 32 && kWidth == 4) || (nonKDim == 16 && kWidth == 4)) {
if ((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 4) {
return MatrixCoreType::FP32_BF16_BF16_FP32_1K;
} else {
assert((nonKDim == 32 && kWidth == 2) ||
(nonKDim == 16 && kWidth == 2));
(nonKDim == 16 && kWidth == 2) || (nonKDim == 4 && kWidth == 2));
return MatrixCoreType::FP32_BF16_BF16_FP32;
}
}
if (aElemTy.isInteger(8)) {
auto nonKDim = mfmaEncoding.getNonKDim();
auto kWidth = dotOpEncoding.getKWidth();
if ((nonKDim == 32 && kWidth == 8) || (nonKDim == 16 && kWidth == 8)) {
if ((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 8) {
return MatrixCoreType::INT32_INT8_INT8_INT32_CDNA3;
}
else {
} else {
assert((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 4);
return MatrixCoreType::INT32_INT8_INT8_INT32;
}
}
@@ -255,11 +292,81 @@ struct DotOpMFMAConversionHelper {
return descr;
}
/// Reduces and/or zeroes the sub-blocks stored in an MFMA accumulator.
///
/// The wave (64 lanes) is viewed as \p numSubBlocks groups of
/// `64 / numSubBlocks` consecutive lanes.
///
/// \param numSubBlocks    number of sub-blocks; must be a power of two in
///                        [1, 64]. With 1 the accumulator is returned as is.
/// \param acc             accumulator; must be a vector of 32-bit elements
/// \param reduceSubBlocks when true, every scalar is summed with the value
///                        shuffled in at stride subBlockSize, doubling the
///                        stride each step until it reaches the wave size
/// \param zeroSubBlocks   when true, lanes whose id within the wave is not
///                        below the current sub-block size get zeros
/// \return a new vector of the accumulator's type with the processed scalars
Value processSubBlocks(int numSubBlocks, Value acc, bool reduceSubBlocks,
                       bool zeroSubBlocks) const {
  constexpr int waveSize = 64;
  // 0 would divide by zero below; more than waveSize would give a zero
  // sub-block size and a reduction loop that never terminates.
  assert(numSubBlocks > 0 && numSubBlocks <= waveSize &&
         "numSubBlocks must be in [1, waveSize]");
  assert((numSubBlocks & (numSubBlocks - 1)) == 0 &&
         "numSubBlocks in not pow 2!");
  if (numSubBlocks == 1)
    return acc;
  int subBlockSize = waveSize / numSubBlocks;
  // Checked cast: a non-vector accumulator is a caller bug and must not turn
  // into a null dereference.
  auto vecTy = acc.getType().cast<VectorType>();
  auto elemType = vecTy.getElementType();
  assert(elemType.getIntOrFloatBitWidth() == 32);
  int numScalars = vecTy.getNumElements();
  // Unpack the vector so each scalar can be shuffled/selected on its own.
  std::vector<Value> accScalar(numScalars);
  for (int i = 0; i < numScalars; ++i)
    accScalar[i] = extract_element(elemType, acc, i32_val(i));

  if (reduceSubBlocks) {
    while (subBlockSize < waveSize) {
      for (int i = 0; i < numScalars; ++i) {
        Value other_acc =
            mlir::LLVM::shflSync(loc, rewriter, accScalar[i], subBlockSize);
        if (elemType.isInteger(32))
          accScalar[i] = add(accScalar[i], other_acc);
        else
          accScalar[i] = fadd(accScalar[i], other_acc);
      }
      subBlockSize *= 2;
    }
  }

  if (zeroSubBlocks) {
    // The lane id is only needed here, so it is not materialised for the
    // reduce-only path.
    Value laneId = getThreadId();
    laneId = and_(laneId, i32_val(waveSize - 1));
    Value zero;
    if (elemType.isInteger(32))
      zero = i32_val(0);
    else
      zero = f32_val(0.0);
    auto cond = icmp_ult(laneId, i32_val(subBlockSize));
    for (int i = 0; i < numScalars; ++i)
      accScalar[i] = select(cond, accScalar[i], zero);
  }

  // Repack the processed scalars into a vector of the original type.
  Value reducedAcc = undef(vecTy);
  for (int i = 0; i < numScalars; ++i)
    reducedAcc = insert_element(vecTy, reducedAcc, accScalar[i], i32_val(i));
  return reducedAcc;
}
/// @brief Sums the sub-block results held in an MFMA 4x4 accumulator.
///
/// An MFMA 4x4 instruction computes 16 matrix multiplications; this function
/// adds these 16 matrices together to obtain the final 4x4 matrix.
/// @param numSubBlocks number of sub-blocks stored in the accumulator
/// @param acc accumulator vector to reduce
/// @return the result of processSubBlocks with reduction enabled and zeroing
///         disabled
Value reduceSubBlocks(int numSubBlocks, Value acc) const {
  const bool doReduce = true;
  const bool doZero = false;
  return processSubBlocks(numSubBlocks, acc, doReduce, doZero);
}
/// @brief Clears the redundant accumulator copies in every sub-block but the
/// first.
///
/// In the mfma 4x4 layout a wave holds only 4 unique values (scalars or
/// vectors), stored in blocks of 4 consecutive threads, so the wave contains
/// 16 copies of them. 15 of these copies have to be zeroed before the
/// accumulator can be carried from one dot operation to the next.
/// @param numSubBlocks number of sub-blocks stored in the accumulator
/// @param acc accumulator vector to clean up
/// @return the result of processSubBlocks with zeroing enabled and reduction
///         disabled
Value zeroAuxiliarBlocks(int numSubBlocks, Value acc) const {
  const bool doReduce = false;
  const bool doZero = true;
  return processSubBlocks(numSubBlocks, acc, doReduce, doZero);
}
// Conduct the Dot conversion.
LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor) const {
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
auto nonKDim = mfmaLayout.getNonKDim();
assert(nonKDim == 32 || nonKDim == 16);
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
auto mfmaInstrDescr = getMatrixInstrDescr(op);
Value a = op.getA();
@@ -296,8 +403,10 @@ struct DotOpMFMAConversionHelper {
unsigned warpSize = triton::gpu::getWarpSize(mfmaLayout);
// compute number of output elements that each thread holds for one MFMA
// instruction
auto elemsPerVec = nonKDim * nonKDim / warpSize;
// instruction. subBlocks
const int subBlocks =
getNumSubmatrices(aTensorTy.getElementType(), nonKDim);
auto elemsPerVec = nonKDim * nonKDim * subBlocks / warpSize;
auto vecTy = vec_ty(dstElemTy, elemsPerVec);
for (int m = 0; m < numRepM; ++m) {
@@ -308,13 +417,14 @@ struct DotOpMFMAConversionHelper {
vecTy, acc, fc[m * numRepN * elemsPerVec + n * elemsPerVec + v],
i32_val(v));
}
acc = zeroAuxiliarBlocks(subBlocks, acc);
for (size_t k = 0; k < numRepK; k++) {
acc =
mfmaLayout.getIsTransposed()
? generateMFMAOp(mfmaInstrDescr, hb[{n, k}], ha[{m, k}], acc)
: generateMFMAOp(mfmaInstrDescr, ha[{m, k}], hb[{n, k}], acc);
}
acc = reduceSubBlocks(subBlocks, acc);
for (unsigned v = 0; v < elemsPerVec; ++v) {
fc[m * numRepN * elemsPerVec + n * elemsPerVec + v] =
extract_element(dstElemTy, acc, i32_val(v));

View File

@@ -775,7 +775,9 @@ public:
auto nonKDim = mfmaLayout.getNonKDim();
// MFMA output tile consists of repeated "dot operand B" layout groups along
// row axis. This variable defines number of these groups.
const unsigned numGroups = (nonKDim == 32 ? 4 : 1);
DenseMap<int, int> groups{{4, 1}, {16, 1}, {32, 4}};
unsigned numGroups = groups.at(nonKDim);
const unsigned elemsPerThreadPerGroup = 4;
auto warpSize = getWarpSize(mfmaLayout);
assert(warpSize == 64);
@@ -1193,7 +1195,12 @@ private:
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout));
Value laneId = urem(threadId, warpSize);
Value effectiveWarpSize = warpSize;
if (nonKDim == 4) {
const int uniqueValuesPerWarp = 4;
effectiveWarpSize = i32_val(uniqueValuesPerWarp);
}
Value laneId = urem(threadId, effectiveWarpSize);
Value warpId = udiv(threadId, warpSize);
Value warpId0 =

View File

@@ -2,6 +2,10 @@
#include "TypeConverter.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "triton/Dialect/NVGPU/IR/Dialect.h"
#if USE_ROCM
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#endif
namespace mlir {
namespace LLVM {
@@ -286,10 +290,20 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
#ifdef USE_ROCM
// On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on 32-bit dwords,
// so we need to promote the value to 32 bits here.
if (bits == 8) {
Value i32Val = sext(i32_ty, val);
Value result = commonShflSync(loc, rewriter, i32Val, i, strideInt, mode, clamp, laneId);
return trunc(i8_ty, result);
auto valType = val.getType();
if (!valType.isInteger(32) && bits <= 32) {
if (!valType.isIntOrIndex())
val = bitcast(val, int_ty(bits));
if (bits < 32)
val = sext(i32_ty, val);
val = commonShflSync(loc, rewriter, val, i, strideInt, mode, clamp, laneId);
if (bits < 32)
val = trunc(int_ty(bits), val);
if (!valType.isIntOrIndex())
val = bitcast(val, valType);
return val;
}
#endif
@@ -307,20 +321,12 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
}
#ifdef USE_ROCM
GCNBuilder builder;
auto permute = [&](Value lane, StringRef permuteInstStr) {
assert(permuteInstStr == "ds_permute_b32" ||
permuteInstStr == "ds_bpermute_b32");
auto bpermute = [&](Value lane) {
// Multiply the lane id by 4 to form the permute address. (More on permute instruction semantics:
// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf#page=180 )
Value byteOffset = i32_val(2);
Value permuteAddr = shl(lane, byteOffset);
auto shfl = builder.create(permuteInstStr.str());
auto dOpr = builder.newOperand("=v");
auto addrOpr = builder.newOperand(permuteAddr, "v");
auto aOpr = builder.newOperand(val, "v");
(*shfl)(dOpr, addrOpr, aOpr);
return rewriter.create<ROCDL::DsBpermuteOp>(loc, valType, permuteAddr, val);
};
switch (mode) {
@@ -334,39 +340,30 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)})
.getResult(0);
Value stride = i32_val(32);
Value lineId = add(threadId, stride);
permute(lineId, "ds_permute_b32");
Value lineId = xor_(threadId, stride);
return bpermute(lineId);
} else {
// This map facilitates the butterfly shuffle pattern for strides of 16 or
// less. The pattern stride is the key of the map.
DenseMap<short, unsigned int> masks{
{16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}};
auto shfl = builder.create("ds_swizzle_b32");
auto dOpr = builder.newOperand("=v");
auto aOpr = builder.newOperand(val, "v");
auto maskOpr =
builder.newConstantOperand("offset:" + std::to_string(masks[strideInt]));
(*shfl)(dOpr, aOpr, maskOpr);
Value offset = i32_val(masks[strideInt]);
return rewriter.create<ROCDL::DsSwizzleOp>(loc, valType, val, offset);
}
break;
case NVVM::ShflKind::up: {
Value mask = icmp_slt(laneId, i);
Value delta = sub(laneId, i);
Value index = select(mask, laneId, delta);
permute(index, "ds_bpermute_b32");
break;
return bpermute(index);
}
case NVVM::ShflKind::idx:
permute(i, "ds_bpermute_b32");
break;
return bpermute(i);
default:
assert(false && "Unsupported ShflKind");
break;
}
auto swait = builder.create("s_waitcnt lgkmcnt(0)");
(*swait)();
return builder.launch(rewriter, loc, val.getType(), true);
return Value();
#else
Type type = val.getType();
if (type != i32_ty) {

View File

@@ -240,15 +240,22 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
}
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
unsigned rows, cols;
if (mfmaLayout.getNonKDim() == 32) {
switch (mfmaLayout.getNonKDim()) {
case 32:
rows = 16;
cols = 1;
} else if (mfmaLayout.getNonKDim() == 16) {
break;
case 16:
rows = 4;
cols = 1;
} else
break;
case 4:
rows = 4;
cols = 1;
break;
default:
llvm_unreachable("Unexpected mfma non-k dim");
}
if (mfmaLayout.getIsTransposed()) {
return {cols, rows};
} else {
@@ -993,9 +1000,11 @@ SmallVector<int64_t>
DotOperandEncodingAttr::getMFMAElemsPerInstr() const {
auto mfmaEncoding = getParent().cast<MfmaEncodingAttr>();
int64_t nonKDim = mfmaEncoding.getNonKDim();
assert(nonKDim == 32 || nonKDim == 16);
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
int64_t kWidth = getKWidth();
int64_t kDim = kWidth * (nonKDim == 32 ? 2 : 4);
constexpr int waveSize = 64; // MFMA is used on wave64 architectures only
int kGroups = waveSize / nonKDim;
int64_t kDim = kWidth * kGroups;
if (getOpIdx() == 0)
return {nonKDim, kDim};
else

View File

@@ -105,9 +105,18 @@ public:
if (enforcedNonKDim != 0) {
nonKDim = enforcedNonKDim;
} else {
nonKDim = (resShape[0] < 32 || resShape[1] < 32) ? 16 : 32;
nonKDim = -1;
int minSize = std::min(resShape[0], resShape[1]);
if (minSize >= 32)
nonKDim = 32;
if (minSize >= 16 && minSize < 32)
nonKDim = 16;
if (minSize < 16)
nonKDim = 4;
assert(nonKDim != -1);
}
if (nonKDim == 32) {
switch (nonKDim) {
case 32:
if (elemType.isF32())
kDim = 2;
if (elemType.isF16())
@@ -130,7 +139,8 @@ public:
kDim = 8;
}
}
} else {
break;
case 16:
if (elemType.isF32())
kDim = 4;
if (elemType.isF16())
@@ -152,7 +162,25 @@ public:
else {
kDim = 16;
}
}
}
break;
case 4:
if (elemType.isF32())
kDim = 16;
if (elemType.isF16())
kDim = 64;
if (elemType.isBF16()) {
if (mfmaVersion == 1)
kDim = 32;
if (mfmaVersion >= 2)
kDim = 64;
}
if (elemType.isInteger(8)) {
kDim = 64;
}
break;
default:
llvm::report_fatal_error("unsupported nonKDim size in MFMA dot");
}
assert(kDim != -1);
assert(nonKDim != -1);
@@ -221,11 +249,19 @@ public:
auto kWidth = kDim;
// in mfma 32x32 case argument matrix groups elements in 2 groups
// in mfma 16x16 case argument matrix groups elements in 4 groups
if (nonKDim == 32) {
// in mfma 4x4 case argument matrix groups elements in 16 groups
switch (nonKDim) {
case 32:
kWidth /= 2;
} else {
assert(nonKDim == 16);
break;
case 16:
kWidth /= 4;
break;
case 4:
kWidth /= 16;
break;
default:
llvm::report_fatal_error("unsupported kDim in mfma dot");
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),

View File

@@ -0,0 +1,790 @@
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team
This kernel supports arbitrarily sized sequence lengths.
Extra Credits:
- Original flash attention paper (https://arxiv.org/abs/2205.14135)
- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
"""
import pytest
import torch
import triton
import triton.language as tl
# Element type used by this file: FP8 (e5m2fnuz) when the installed torch
# exposes it, float16 otherwise.
TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')
torch_dtype: tl.constexpr = torch.float8_e5m2fnuz if TORCH_HAS_FP8E5 else torch.float16
@triton.jit
def _attn_fwd_inner(
    acc, l_i, m_i, q,
    K_block_ptr, V_block_ptr,
    start_m,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,
    N_CTX,
    pre_load_v: tl.constexpr,
    padded_block: tl.constexpr,
    total_tokens: tl.constexpr,
):
    """Accumulate attention over a range of K/V blocks for one block of queries.

    Walks K/V in steps of BLOCK_N over [lo, hi) and updates the running row
    maximum ``m_i``, the running row sum ``l_i`` and the output accumulator
    ``acc``. The range depends on the mode:

    * ``STAGE == 1``: ``[0, start_m * BLOCK_M)``.
    * ``STAGE == 2``: the single block ``[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)``;
      entries with ``offs_m < start_n + offs_n`` are set to -inf.
    * otherwise, with ``padded_block``: the single block starting at ``N_CTX``;
      columns at or beyond ``total_tokens`` are set to -inf and loads are
      boundary-checked with zero padding.
    * otherwise: ``[0, N_CTX)``.

    ``pre_load_v`` only selects where V is loaded (before the QK dot or just
    before the PV dot). Exponentials use base 2, so ``q`` is presumably already
    scaled by log2(e) by the caller -- confirm at the call site.

    Returns the updated ``(acc, l_i, m_i)``.
    """
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
        K_block_ptr = tl.advance(K_block_ptr, (0, lo))
        V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    # N_CTX is the seqlen to the nearest block (round down).
    # So here, we are computing the elements for that last irregular block.
    # In the loop, we will mask the elements of BLOCK_N that do not exist.
    elif padded_block:
        lo, hi = N_CTX, N_CTX + BLOCK_N
        lo = tl.multiple_of(lo, BLOCK_N)
        K_block_ptr = tl.advance(K_block_ptr, (0, lo))
        V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    # causal = False
    else:
        lo, hi = 0, N_CTX
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # For padded blocks, we will overrun the tensor size if
        # we load all BLOCK_N. For others, the blocks are all within range.
        if padded_block:
            k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
        else:
            k = tl.load(K_block_ptr)
        # Optionally issue the V load early, before the QK dot.
        if pre_load_v:
            if padded_block:
                v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
            else:
                v = tl.load(V_block_ptr)
        # -- compute qk ----
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        if STAGE == 2:
            # Keep only entries whose query index is >= the key index.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = tl.where(mask, qk, float("-inf"))
        if padded_block:
            # Mask out key columns at or beyond total_tokens.
            boundary = tl.full([BLOCK_M], total_tokens, dtype=tl.float32)
            mask = (start_n + offs_n[None,:]) < boundary[:,None]
            qk = tl.where(mask, qk, float("-inf"))
        qk += tl.dot(q, k)
        # New running row maximum; scores are shifted by it before exp2.
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)
        # -- update output accumulator --
        # alpha rescales previously accumulated values to the new maximum.
        alpha = tl.math.exp2(m_i - m_ij)
        acc = acc * alpha[:, None]
        if not pre_load_v:
            if padded_block:
                v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
            else:
                v = tl.load(V_block_ptr)
        acc += tl.dot(p.to(v.dtype), v)
        # -- update m_i and l_i
        l_ij = tl.sum(p, 1)
        l_i = l_i * alpha + l_ij
        # update m_i and l_i
        m_i = m_ij
        # Step both block pointers to the next BLOCK_N keys/values.
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    return acc, l_i, m_i
@triton.jit
def _attn_fwd(
    Q, K, V, sm_scale, M, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H,
    N_CTX,
    BLOCK_DMODEL: tl.constexpr,
    STAGE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    pre_load_v: tl.constexpr,
    need_padding: tl.constexpr,
    extra_tokens_n: tl.constexpr,
):
    """Forward attention kernel for one block of BLOCK_M query rows.

    Program id 0 selects the query block (``start_m``); program id 1
    (``off_hz``) selects the slice addressed through ``stride_qh``
    (presumably a flattened batch*head index -- confirm against the launcher).

    The kernel loads and scales Q, accumulates over K/V with
    ``_attn_fwd_inner`` (off-band pass when ``STAGE & 1``, on-band pass when
    ``STAGE & 2``), normalises the accumulator by ``l_i`` and writes:

    * ``M``: ``m_i + log2(l_i)`` per query row, masked to rows below ``N_CTX``;
    * ``Out``: the normalised accumulator, boundary-checked on both axes.

    ``need_padding`` / ``extra_tokens_n`` describe the trailing partial K/V
    block; causal masking together with padding is rejected at compile time.
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    qvk_offset = off_hz * stride_qh
    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    # K is addressed transposed: (BLOCK_DMODEL, N_CTX).
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs
    q = tl.load(Q_block_ptr, boundary_check=(0,1), padding_option="zero")
    q = (q * qk_scale).to(q.dtype)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        # We don't currently support causal masking and padding.
        tl.static_assert((STAGE != 3) or not need_padding)
        # Equal to N_CTX when there are no extra tokens (presumably when N_CTX
        # is already a multiple of BLOCK_N -- confirm against the launcher).
        seqlen_aligned = N_CTX - extra_tokens_n
        if N_CTX >= BLOCK_N:
            acc, l_i, m_i = _attn_fwd_inner(
                acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                start_m,
                BLOCK_M, BLOCK_DMODEL, BLOCK_N,
                4 - STAGE, offs_m, offs_n,
                seqlen_aligned, pre_load_v,
                False, seqlen_aligned
            )
        tl.debug_barrier()
        if need_padding:
            # A sequence shorter than one block consists of the padded block only.
            if N_CTX < BLOCK_N:
                seqlen_aligned = 0
            acc, l_i, m_i = _attn_fwd_inner(
                acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                start_m,
                BLOCK_M, BLOCK_DMODEL, BLOCK_N,
                4 - STAGE, offs_m, offs_n,
                seqlen_aligned, pre_load_v,
                True, N_CTX,
            )
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for compiler to schedule the
        # two loops independently
        tl.debug_barrier()
        # padded_block / total_tokens are required by _attn_fwd_inner; the
        # on-band pass never handles a padded block, so pass False and N_CTX.
        acc, l_i, m_i = _attn_fwd_inner(
            acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
            start_m,
            BLOCK_M, BLOCK_DMODEL, BLOCK_N,
            2, offs_m, offs_n,
            N_CTX, pre_load_v,
            False, N_CTX,
        )
    # epilogue
    # write back m
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    # The last query block may extend past N_CTX; only rows that exist are
    # stored (a false mask entry blocks the store for that row).
    tl.store(m_ptrs, m_i + tl.math.log2(l_i), mask=offs_m < N_CTX)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1))
@triton.jit
def _attn_bwd_preprocess(O, DO,  #
                         NewDO, Delta,  #
                         BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,  #
                         ):
    """Backward-pass preprocessing for one block of BLOCK_M rows.

    Loads the [BLOCK_M, D_HEAD] tiles of ``O`` and ``DO`` (addressed row-major
    with row stride D_HEAD) as float32, stores the ``DO`` tile to ``NewDO``
    and stores the per-row sum of ``O * DO`` to ``Delta``.
    """
    row_idx = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    col_idx = tl.arange(0, D_HEAD)
    # Element offsets of this program's tile, shared by every tile access.
    tile_offs = row_idx[:, None] * D_HEAD + col_idx[None, :]
    # load
    out_tile = tl.load(O + tile_offs).to(tl.float32)
    grad_tile = tl.load(DO + tile_offs).to(tl.float32)
    row_dot = tl.sum(out_tile * grad_tile, axis=1)
    # write-back
    tl.store(NewDO + tile_offs, grad_tile)
    tl.store(Delta + row_idx, row_dot)
@triton.jit
def _bwd_kernel_dk_dv(
Q, K, V, sm_scale, Out, DO,
DK, DV,
L,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
# Backward kernel that accumulates dK and dV for one block of K/V positions.
# program_id(0) picks the K/V block (start_m), program_id(1) picks the slice
# addressed by off_hz * stride_qh (and off_hz * N_CTX for L and D).
# K and V are loaded once; the loop then walks over Q/DO blocks starting at
# row start_m * BLOCK_M and running up to N_CTX.
# NOTE(review): the where() below compares a BLOCK_N-sized row index against a
# BLOCK_M-sized column index on a qk tile produced from BLOCK_N-wide operands,
# so this appears to require BLOCK_M == BLOCK_N -- confirm against the launch site.
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# Q is consumed depending on block ID. Every block uses
# previous block offset by BLOCK_M x D_HEAD.
qvk_offset = off_hz * stride_qh
# qdo_offset additionally skips start_m * BLOCK_M rows, so the Q/DO block
# pointers below can start at offsets (0, 0).
qdo_offset = qvk_offset + start_m * BLOCK_M * stride_qm
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# NOTE(review): offs_d is not referenced anywhere else in this kernel.
offs_d = tl.arange(0, BLOCK_DMODEL)
# Initialize pointers to Q, K, V
Q_block_ptr = tl.make_block_ptr(
base=Q + qdo_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
# K and V are addressed with swapped strides, i.e. as (BLOCK_DMODEL, N_CTX) views.
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, start_m * BLOCK_M),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_vn, stride_vk),
offsets=(0, start_m * BLOCK_M),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
# DO reuses Q's strides and Q's row offset.
DO_block_ptr = tl.make_block_ptr(
base=DO + qdo_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
l_ptrs = L + off_hz * N_CTX
# 1.44269504 ~= log2(e): folding it into the scale lets the loop use exp2.
qk_scale = sm_scale * 1.44269504
# load k and v: they will stay in SRAM throughout
k = tl.load(K_block_ptr)
# scale K once up front and cast back to the dtype it was loaded with
k = (k * qk_scale).to(k.dtype)
v = tl.load(V_block_ptr)
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# This lower loop bound is because of the causal mask. We create a lower triangular
# result. The upper triangular is -inf (becomes 0 when we do e^x). As such, it can
# be ignored in the GEMM.
lo = start_m * BLOCK_M
hi = N_CTX
# loop over q, do
for start_n in range(lo, hi, BLOCK_N):
# row indices of the current Q/DO block, shaped as a column vector
offs_m_curr = offs_n[:, None] + start_n
# -- load q, do --
q = tl.load(Q_block_ptr)
do = tl.load(DO_block_ptr)
# -- compute qk ----
qk = tl.dot(q, k)
# keep entries whose row index is >= the column index; the rest become -inf
qk = tl.where(offs_m_curr >= offs_m[None, :], qk, float("-inf"))
l_i = tl.load(l_ptrs + offs_m_curr)
# recompute the probabilities from the per-row values stored in L
p = tl.math.exp2(qk - l_i)
# -- compute dv ----
dv += tl.dot(tl.trans(p.to(do.dtype)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
# the accumulator starts at -Di, so dp ends up as dot(do, v) - Di
dp = tl.zeros([BLOCK_N, BLOCK_M], dtype=tl.float32) - Di
dp += tl.dot(do, v)
# compute ds = p * (dp - delta[:, None])
ds = p * dp
# compute dk
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
# update pointers
Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_N, 0))
DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_N, 0))
# initialize pointers to output
DK_block_ptr = tl.make_block_ptr(
base=DK + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
DV_block_ptr = tl.make_block_ptr(
base=DV + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
# dk is multiplied by sm_scale at store time and cast to DK's element type.
tl.store(DK_block_ptr, (dk * sm_scale).to(DK.dtype.element_ty))
# NOTE(review): dv is cast to a hard-coded float16 rather than DV's element type,
# unlike the dk store above -- confirm this is intended.
tl.store(DV_block_ptr, dv.to(tl.float16))
@triton.jit
def _bwd_kernel_dq(
Q, K, V, sm_scale, Out, DO,
DQ,
L,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
# Backward kernel that accumulates dQ for one block of BLOCK_M query rows.
# program_id(0) picks the row block (start_m), program_id(1) picks the slice
# addressed by off_hz * stride_qh (and off_hz * N_CTX for L and D).
# Q, DO and the per-row L/D values are loaded once; the loop then walks over
# K/V blocks from column 0 up to (start_m + 1) * BLOCK_M.
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
qvk_offset = off_hz * stride_qh
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# NOTE(review): offs_d is not referenced anywhere else in this kernel.
offs_d = tl.arange(0, BLOCK_DMODEL)
# Initialize pointers to Q, K, V
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
# K and V are addressed with swapped strides, i.e. as (BLOCK_DMODEL, N_CTX) views.
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_vn, stride_vk),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
# DO reuses Q's strides and Q's row offset.
DO_block_ptr = tl.make_block_ptr(
base=DO + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
l_ptrs = L + off_hz * N_CTX
# 1.44269504 ~= log2(e): folding it into the scale lets the loop use exp2.
qk_scale = sm_scale * 1.44269504
# load q and do: they will stay in SRAM throughout
q = tl.load(Q_block_ptr)
# scale Q once up front and cast back to the dtype it was loaded with
q = (q * qk_scale).to(q.dtype)
do = tl.load(DO_block_ptr)
Di = tl.load(D_ptrs + offs_m)
l_i = tl.load(l_ptrs + offs_m)
dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# loop over k, v
# the upper bound stops at the end of this row block (causal masking below)
lo = 0
hi = (start_m + 1) * BLOCK_M
for start_n in range(lo, hi, BLOCK_N):
# -- load k, v --
k = tl.load(K_block_ptr)
v = tl.load(V_block_ptr)
# -- compute qk ----
qk = tl.dot(q, k)
# keep entries whose row index is >= the column index; the rest become -inf
qk = tl.where(offs_m[:, None] >= (offs_n[None, :] + start_n), qk, float("-inf"))
# recompute the probabilities from the per-row values stored in L
p = tl.math.exp2(qk - l_i[:, None])
# compute dp = dot(v, do)
# the accumulator starts at -Di, so dp ends up as dot(do, v) - Di
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, v)
# compute ds = p * (dp - delta[:, None])
ds = p * dp
# compute dq. Unfortunately we cannot avoid transpose here as this loop
# uses k both normal and transpose.
dq += tl.dot(ds.to(Q.dtype.element_ty), tl.trans(k))
# update pointers
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N))
# initialize pointers to output
DQ_block_ptr = tl.make_block_ptr(
base=DQ + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
# dq is multiplied by sm_scale at store time.
# NOTE(review): the cast is a hard-coded float16 rather than DQ's element type -- confirm.
tl.store(DQ_block_ptr, (dq * sm_scale).to(tl.float16))
empty = torch.empty(128, device="cuda")
# Autograd wrapper around the Triton attention kernels: forward() launches
# _attn_fwd, backward() launches the preprocessing kernel followed by either
# the fused _bwd_kernel or the split dk/dv + dq kernels.
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
# q is unpacked as a 4-D tensor; dim 2 is used as the sequence length and
# the last dim as the head size. split_kernel is only stored on ctx here
# and consumed by backward().
# shape constraints
_, _, seqlen, Lq = q.shape
Lk, Lv = k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
# For now we assume K and V seqlen = Q seqlen
assert seqlen == k.shape[-2] and seqlen == v.shape[-2]
# We've derived these previously from tuning the kernel
BLOCK_M = 256 if Lq == 128 else 128
BLOCK_N = 64 #128 if Lq == 128 else 64
waves_per_eu = 2 if Lq == 128 else 3
num_warps = 8 if Lq == 128 else 4
pre_load_v = False if Lq == 128 else True
# one program per BLOCK_M query rows along axis 0, one per (dim0 * dim1) slice along axis 1
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
# stage 3 selects the causal path, stage 1 the non-causal one
stage = 3 if causal else 1
# Compute if we need padding and how much
seqlen_k = k.shape[-2]
if seqlen_k < BLOCK_N:
need_padding = True
# here extra_tokens_n is the shortfall up to one full BLOCK_N
extra_tokens_n = BLOCK_N - seqlen_k
# This effectively ensures we do not slice across Q.
assert(grid[0] == 1)
elif seqlen_k % BLOCK_N:
need_padding = True
# here extra_tokens_n is the remainder past the last full BLOCK_N
extra_tokens_n = seqlen_k % BLOCK_N
else:
# We don't care if BLOCK_M isn't aligned, as we
# always boundary_check on Q and O
need_padding = False
extra_tokens_n = 0
o = torch.empty_like(q, dtype=v.dtype)
# M receives one float32 value per (slice, row); backward() reads it back as L
M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
_attn_fwd[grid](
q, k, v, sm_scale, M, o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1],
N_CTX=q.shape[2],
BLOCK_DMODEL=Lk,
STAGE=stage,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
waves_per_eu=waves_per_eu, pre_load_v=pre_load_v,
need_padding=need_padding, extra_tokens_n=extra_tokens_n,
num_stages=1, num_warps=num_warps
)
# stash everything backward() needs
ctx.save_for_backward(q, k, v, o, M)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
ctx.causal = causal
ctx.split_kernel = split_kernel
return o
@staticmethod
def backward(ctx, do):
# Returns (dq, dk, dv, None, None, None): gradients for q, k, v and None for
# the causal, sm_scale and split_kernel arguments of forward().
# configuration is not supported
# (i.e. split_kernel=True together with causal=False is rejected)
assert(not (ctx.split_kernel and not ctx.causal))
# backward block size: 64 on HIP builds of torch, 128 otherwise
if torch.version.hip is not None:
BLOCK = 64
else:
BLOCK = 128
q, k, v, o, L = ctx.saved_tensors
assert do.is_contiguous()
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
# NOTE(review): with the assert above this call only matters when asserts are disabled.
do = do.contiguous()
dq = torch.zeros_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
BATCH, N_HEAD, N_CTX = q.shape[:3]
delta = torch.empty_like(L)
do_scaled = torch.empty_like(do)
# Figure out what BLOCK size fwd used and adjust num_blocks accordingly.
# If the two are the same, we don't need this but the bwd pass block size
# is smaller than the fwd so we need this scaling to ensure we loop over all
# values and don't skip some blocks.
# Alternatively we could compute a new grid but this keeps it consistent
# with fwd and easier to reason about.
# NOTE(review): both divisions floor, so a sequence length that is not a multiple
# of the forward block size is rounded down here -- confirm this is handled.
block_scale = (q.shape[2] // ctx.grid[0]) // BLOCK
_attn_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, #
do_scaled, delta, #
BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL, #
)
if not ctx.split_kernel:
# fused path: one program per (dim0 * dim1) slice
# NOTE(review): _bwd_kernel is not defined in the visible part of this file.
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
block_scale * ctx.grid[0],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
CAUSAL=ctx.causal,
num_stages=1,
)
else :
# split path: dk/dv and dq are produced by two separate launches
# NOTE(review): dq was already allocated with zeros_like(q) above; this repeats it.
dq = torch.zeros_like(q)
_bwd_kernel_dk_dv[(block_scale * ctx.grid[0], ctx.grid[1])](
q, k, v, ctx.sm_scale,
o, do_scaled,
dk, dv,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
num_stages=1,
)
# the dq kernel reuses the forward grid and therefore runs with BLOCK_M = 2 * BLOCK
_bwd_kernel_dq[ctx.grid](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=2*BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4, waves_per_eu=1,
num_stages=1,
)
# print(h.asm["ttgir"])
return dq, dk, dv, None, None, None
attention = _attention.apply
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                         [(1, 40, 19, 128),
                          (4, 48, 1024, 64),
                          (4, 48, 997, 64),
                          (4, 48, 2048, 64),
                          (4, 48, 4096, 64),
                          (4, 48, 3989, 64),
                          (4, 48, 1024, 128),
                          (4, 48, 1021, 128),
                          (4, 48, 2048, 128),
                          (4, 48, 4096, 128),
                          (4, 16, 8192, 64),
                          (4, 16, 8080, 64),
                          #(4, 48, 16384, 64)
                          ])
@pytest.mark.parametrize('causal', [False])
def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
    """Compare the Triton attention forward pass against a PyTorch reference.

    q, k and v are drawn with a fixed seed as (Z, H, N_CTX, D_HEAD) tensors on
    the "cuda" device. When TORCH_HAS_FP8E5 is set, q and k are converted to
    ``torch_dtype`` before use. The reference is
    ``softmax(q @ k^T * sm_scale) @ v`` computed in half precision (softmax in
    float32), and both outputs must agree within atol=rtol=1e-2.
    """
    torch.manual_seed(20)
    q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    if TORCH_HAS_FP8E5:
        q = q.to(torch_dtype)
        k = k.to(torch_dtype)
    sm_scale = 0.5
    # reference implementation
    p = torch.matmul(q.half(), k.transpose(2, 3).half()) * sm_scale
    if causal:
        # The (N_CTX, N_CTX) lower-triangular mask is only needed for causal
        # runs, so it is not materialized for the non-causal configurations.
        M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
        p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    ref_out = torch.matmul(p, v)
    # triton implementation
    tri_out = attention(q, k, v, causal, sm_scale)
    # compare
    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2)
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                         [(4, 48, 1024, 64),
                          (4, 48, 2048, 64),
                          (4, 48, 4096, 64),
                          (1, 16, 8192, 64),
                          ])
def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    """Compare Triton attention gradients against PyTorch autograd.

    Runs the causal, split-kernel configuration: the forward output and the
    gradients with respect to q, k and v are checked against a reference
    built from matmul + softmax.
    """
    torch.manual_seed(20)
    causal = True
    shape = (Z, H, N_CTX, D_HEAD)
    q = torch.empty(shape, dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    k = torch.empty(shape, dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    v = torch.empty(shape, dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    sm_scale = 0.5
    split_kernel = True
    dout = torch.randn_like(q)

    # reference implementation
    tril_mask = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    scores = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if causal:
        scores[:, :, tril_mask == 0] = float("-inf")
    probs = torch.softmax(scores.float(), dim=-1).half()
    ref_out = torch.matmul(probs, v)
    ref_out.backward(dout)
    # grab each reference gradient, then clear it so the Triton run starts clean
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None

    # triton implementation
    tri_out = attention(q, k, v, causal, sm_scale, split_kernel)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None

    # compare
    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
    # The current block size for MI200 series is 64x64. This results in
    # larger differences in float results due to rounding, hence the looser
    # dv tolerance on HIP builds.
    dv_atol = 1e-2 if torch.version.hip is None else 5e-2
    torch.testing.assert_close(ref_dv, tri_dv, atol=dv_atol, rtol=0)
    torch.testing.assert_close(ref_dk, tri_dk, atol=5e-2, rtol=1e-2)
    torch.testing.assert_close(ref_dq, tri_dq, atol=5e-2, rtol=1e-2)
# Optional dependency: the flash-attn package is only used as a comparison
# line in the benchmark. FLASH_VER is read when the benchmark line names are
# built (f'Flash-{FLASH_VER}'), so it is assigned on both paths here.
# NOTE(review): FLASH_VER = 2 assumes flash_attn_qkvpacked_func is the
# flash-attn 2.x entry point -- confirm against the installed package.
try:
    from flash_attn.flash_attn_interface import \
        flash_attn_qkvpacked_func as flash_attn_func
    HAS_FLASH = True
    FLASH_VER = 2
except Exception:
    # Any ordinary import failure disables the comparison; KeyboardInterrupt
    # and SystemExit are deliberately allowed to propagate.
    HAS_FLASH = False
    FLASH_VER = None
# Benchmark configurations: one triton.testing.Benchmark per
# (mode, D_HEAD, causal) combination, sweeping (BATCH, H, N_CTX).
configs = []
for mode in ['fwd']:
    for D_HEAD in [128]:
        # bwd is not benchmarked with D_HEAD == 128
        if mode == 'bwd' and D_HEAD == 128:
            continue
        for causal in [False]:
            # bwd is only benchmarked with causal masking
            if mode == 'bwd' and not causal:
                continue
            configs.append(triton.testing.Benchmark(
                x_names=['BATCH', 'H', 'N_CTX'],
                x_vals=[(16, 16, 1024),
                        (8, 16, 2048),
                        (4, 16, 4096),
                        (2, 16, 8192),
                        (1, 16, 16384),
                        (2, 48, 1024),
                        (2, 48, 2048),
                        (2, 48, 4096),
                        (2, 48, 8192),
                        (2, 48, 16384),
                        (8, 16, 1989),
                        (4, 16, 4097),
                        (2, 16, 8122),
                        (1, 16, 16281),
                        (2, 48, 1021),
                        (2, 48, 2001),
                        (2, 48, 3996),
                        (2, 48, 8181),
                        ],
                line_arg='provider',
                line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
                line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
                styles=[('red', '-'), ('blue', '-')],
                # bench_flash_attention returns total_flops / ms * 1e-9, i.e. a
                # throughput in TFLOP/s rather than a time in milliseconds.
                ylabel='TFLOPS',
                plot_name=f'fused-attention-{mode}-d{D_HEAD}-causal={causal}',
                args={
                    'D_HEAD': D_HEAD,
                    'dtype': torch.float16,
                    'mode': mode,
                    'causal': causal})
            )
@triton.testing.perf_report(configs)
def bench_flash_attention(
    BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"
):
    """Benchmark one attention configuration and return its throughput.

    Args:
        BATCH, H, N_CTX, D_HEAD: tensor dimensions of the benchmarked problem.
        causal: whether causal masking is applied (forced to True for 'bwd').
        mode: "fwd" or "bwd".
        provider: "triton" for the kernels in this file, "flash" for flash-attn.
        dtype: dtype used to create the input tensors.
        device: device the input tensors are created on.

    Returns:
        ``total_flops / ms * 1e-9`` where ``ms`` is the time reported by
        ``triton.testing.do_bench`` (i.e. TFLOP/s).

    Raises:
        ValueError: if ``provider`` is neither "triton" nor "flash".
    """
    assert mode in ["fwd", "bwd"]
    warmup = 25
    rep = 100
    split_kernel = False
    # Bwd pass only supports causal=True right now
    if mode == 'bwd':
        causal = True
        split_kernel = True
    if provider == "triton":
        # Inputs honor the `device` argument (default "cuda"), matching the flash path.
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=True)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=True)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=True)
        if mode == "fwd":
            q = q.to(torch_dtype)
            k = k.to(torch_dtype)
        sm_scale = 1.3
        fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel)
        if mode == 'bwd':
            # time only the backward pass: run forward once, then replay backward
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    elif provider == "flash":
        qkv = torch.randn(
            (BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True
        )
        fn = lambda: flash_attn_func(qkv, causal=causal)
        if mode == "bwd":
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    else:
        # Previously an unknown provider fell through to an unbound `ms`.
        raise ValueError(f"unknown provider: {provider!r}")
    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    if mode == "bwd":
        total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
    return total_flops / ms * 1e-9
# only works on post-Ampere GPUs right now
bench_flash_attention.run(save_path=".", print_data=True)

View File

@@ -1644,7 +1644,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
('float8e4m3fnuz', 'float32'),
('float16', 'float32'),
('float32', 'float32')]
for non_k_dim in [0, 16, 32]
for non_k_dim in [0, 4, 16, 32]
if not (allow_tf32 and (in_dtype in ['float16']))] +
[(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype, non_k_dim)
@@ -1670,13 +1670,18 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
[64, 32, 32, 2],
[256, 32, 32, 2],
[256, 32, 32, 4],
[32, 8, 128, 4],
[8, 32, 128, 2],
[4, 32, 64, 4],
[32, 4, 64, 2],
[16, 4, 64, 8]
]
for allow_tf32 in [False, True]
for col_a in [True, False]
for col_b in [True, False]
for in_dtype in ['int8', 'bfloat16', 'float16', 'float32']
for out_dtype in [None]
for non_k_dim in [0, 16, 32]])
for non_k_dim in [0, 4, 16, 32]])
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim, device='cuda'):
capability = torch.cuda.get_device_capability()
@@ -1697,6 +1702,12 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
out_dtype = "int32"
if non_k_dim == 32 and (M < 32 or N < 32):
pytest.skip("incompatible non_k_dim == 32 with MN sizes")
if non_k_dim == 16 and (M < 16 or N < 16):
pytest.skip("incompatible non_k_dim == 16 with MN sizes")
if non_k_dim == 4 and (K < 64):
pytest.skip("incompatible non_k_dim == 4 with K size")
if non_k_dim == 4 and (M > 16 or N > 16):
pytest.skip("skipping lage matrices for non_k_dim == 4 to speedup testing")
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
@@ -2003,6 +2014,7 @@ def get_variant_golden(a, b):
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,NUM_STAGES', [
[64, 32, 128, 4, 64, 32, 64, 0],
[4, 16, 128, 4, 4, 16, 64, 1],
[64, 32, 128, 4, 64, 32, 64, 2]
])
def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, NUM_STAGES):

View File

@@ -1190,7 +1190,7 @@ def is_hip():
def mfma_supported_granularity(m, n, k) -> bool:
# todo make this gran_type matrix element type sensitive
for gran_type in [(32, 8), (16, 16)]:
for gran_type in [(32, 8), (16, 16), (4, 64)]:
granularity_mn, granularity_k = gran_type
if m % granularity_mn != 0 or n % granularity_mn != 0:
@@ -1261,11 +1261,15 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!"
assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!"
assert lhs.shape[1].value == rhs.shape[
0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
and rhs.shape[1].value >= 16, \
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
if _is_cuda(builder.target):
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
and rhs.shape[1].value >= 16, \
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
if is_hip():
assert lhs.shape[0].value >= 4 and lhs.shape[1].value >= 16 \
and rhs.shape[1].value >= 4, \
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 4!"
# hip for now converts fp8 to fp16 for mixed input
if is_hip():

View File

@@ -69,3 +69,26 @@ On some node, I saw the following runtime error
```
It's hard to reproduce the error. **Needs further investigation**
- https://github.com/ROCmSoftwarePlatform/frameworks-internal/issues/6011
# One config running script
`one_config.py` is a script that runs a single, given matmul config.
It is an interface to the `tune_gemm.py` functionality and can be used for debugging Triton.
### Usage
This script supports two methods to specify configuration parameters.
Variant 1: separate command-line arguments.
```bash
python one_config.py -m 256 -n 256 -k 256 --block_m 64 --block_n 64 --block_k 64 --group_m 1 --split_k 2 --num_warps 2 --num_stages 0 --waves_per_eu 0
```
Variant 2: a one-line config description.
This is the format in which the `tune_gemm.py` script prints configs:
```bash
python one_config.py --config_str M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0
```

View File

@@ -0,0 +1,78 @@
"""
Script for running one Matrix Multiplication kernel config at a time
"""
import argparse
import re
import sys
import tune_gemm
def parse_args():
    """Parse the command line of the one-config runner.

    Every integer option defaults to 0 and ``--config_str`` defaults to the
    empty string. Option abbreviations are disabled.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(
        prog="check corectness of particular config for tuning gemm script",
        allow_abbrev=False,
    )
    # All size / tuning knobs share the same type and default.
    int_flags = (
        "-m", "-n", "-k",
        "--block_m", "--block_n", "--block_k",
        "--group_m", "--split_k",
        "--num_warps", "--num_stages", "--waves_per_eu",
    )
    for flag in int_flags:
        cli.add_argument(flag, type=int, default=0)
    cli.add_argument("--config_str", type=str, default="", help="can take from gemm_tune.py script output, looks like M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0")
    return cli.parse_args()
def parse_config(cfg_str):
    """Convert a one-line config description into a config dictionary.

    The string is a sequence of ``<tag><integer>`` fields joined by ``_``,
    e.g. ``M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0``. Each tag is
    mapped to its full config key (``BM`` -> ``BLOCK_SIZE_M`` and so on).

    Args:
        cfg_str: the config description string.

    Returns:
        dict mapping full config key names to integer values, in the order
        the fields appear in ``cfg_str``.

    Raises:
        ValueError: if a field is not exactly a known tag followed by a
            decimal integer (unknown tag, missing number, empty field or
            trailing characters).
    """
    config_name = {"M": "M",
                   "N": "N",
                   "K": "K",
                   "BM": "BLOCK_SIZE_M",
                   "BN": "BLOCK_SIZE_N",
                   "BK": "BLOCK_SIZE_K",
                   "GM": "GROUP_SIZE_M",
                   "SK": "SPLIT_K",
                   "nW": "num_warps",
                   "nS": "num_stages",
                   "EU": "waves_per_eu",
                   }
    config = {}
    for val in cfg_str.split("_"):
        # fullmatch with '+' quantifiers: the whole field must be tag + digits,
        # so malformed input is reported instead of being partially accepted.
        match = re.fullmatch(r"([a-zA-Z]+)([0-9]+)", val)
        if match is None or match.group(1) not in config_name:
            raise ValueError(
                f"unrecognized config field {val!r} in config string {cfg_str!r}")
        config[config_name[match.group(1)]] = int(match.group(2))
    return config
def main():
    """Run the correctness check for the single config given on the command line.

    The config comes from ``--config_str`` when it is non-empty; otherwise it
    is assembled from the individual command-line options. The result is
    passed to ``tune_gemm.test_correctness`` with ``verbose=True``.
    """
    cli_args = parse_args()
    if not cli_args.config_str:
        # (config key, argparse attribute) pairs, in the order the keys are inserted
        key_to_attr = (
            ("M", "m"),
            ("N", "n"),
            ("K", "k"),
            ("BLOCK_SIZE_M", "block_m"),
            ("BLOCK_SIZE_N", "block_n"),
            ("BLOCK_SIZE_K", "block_k"),
            ("GROUP_SIZE_M", "group_m"),
            ("SPLIT_K", "split_k"),
            ("num_warps", "num_warps"),
            ("num_stages", "num_stages"),
            ("waves_per_eu", "waves_per_eu"),
        )
        cfg = {key: getattr(cli_args, attr) for key, attr in key_to_attr}
    else:
        cfg = parse_config(cli_args.config_str)
    tune_gemm.test_correctness(cfg["M"], cfg["N"], cfg["K"], cfg, verbose=True)
# Script entry point: main() has no return statement, so sys.exit() receives
# None and the process exits with status 0 unless an exception propagates.
if __name__ == "__main__":
sys.exit(main())

View File

@@ -54,6 +54,12 @@ def prune_configs(M, N, K, configs):
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
if mfma == 4 and BLOCK_SIZE_K < 64:
continue
# some layouts do not work properly when the
# number of elements per thread is less than 1
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
continue
SPLIT_K = config.get("SPLIT_K")
GROUP_M = config.get("GROUP_SIZE_M")
if BLOCK_SIZE_M < mfma or BLOCK_SIZE_N < mfma:
@@ -87,9 +93,12 @@ def need_split_k(SIZE_M, SIZE_N, SIZE_K):
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
def run_bash_command(commandstring):
proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout = subprocess.PIPE)
return proc.stdout.splitlines()
def run_bash_command(commandstring, capture=True):
if capture:
proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout = subprocess.PIPE)
return proc.stdout.splitlines()
proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash')
return None
def read_config(config):
@@ -113,7 +122,7 @@ def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False):
#M, K = a.shape
#K, N = b.shape
grid = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n}), {split_k}
print(f'config: matmul_kernel_{configStr}')
print(f'config: matmul_kernel_{configStr}', flush=True)
if warmup:
matmul_kernel_{configStr}.warmup(
torch.float16, torch.float16, torch.float16,
@@ -129,6 +138,7 @@ def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False):
waves_per_eu = {waves_per_eu},
grid=(1,)
)
return None
else:
matmul_kernel_{configStr}[grid](
a, b, c,
@@ -143,7 +153,7 @@ def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False):
num_stages = {num_stages},
waves_per_eu = {waves_per_eu}
)
return c
return c
def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, dtype):
#a = torch.randn((M, K), device='cuda', dtype=dtype)
@@ -151,13 +161,20 @@ def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, dtype):
#c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
try:
matmul_{configStr}(None, None, None, M, N, K, am, ak, bk, bn, cm, cn, True)
except Exception:
print(f'invalid config {configStr}')
return True
except Exception as e:
print(f'invalid config(compilation): {configStr}: ', e, flush=True)
return False
"""
return configStr, matmul_def_str
## Open {ngpus} files
## generated_kernelMNK-0.py, generated_kernelMNK-1.py, ..., generated_kernelMNK-{ngpus-1}.py
def generated_kernel_name(M, N, K, gpu_id):
return f"generated_kernel{M}-{N}-{K}-{gpu_id}.py"
## Open {len(gpus)} files
## generated_kernelM-N-K-{gpus[0]}.py, generated_kernelM-N-K-{gpus[1]}.py, ..., generated_kernelM-N-K-{gpus[-1]}.py
## and generate
## 1. matmul kernels of all configs
## 2. wrapper function matmul to invoke all the generated kernels
@@ -165,10 +182,11 @@ def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, dtype):
## 4. test_gemm to invoke
## 4.1 run try_config in parallel
## 4.2 matmul in a loop of 10 iterations
def generate_kernel(M, N, K, configs, ngpus):
def generate_kernel(M, N, K, configs, gpus):
filenames = []
for fi in range(ngpus):
filenames.append(f"generated_kernel{M}{N}{K}-{fi}.py")
ngpus = len(gpus)
for gpu_id in gpus:
filenames.append(generated_kernel_name(M, N, K, gpu_id))
f_kernel = [open(path, 'w') for path in filenames]
### write imports
@@ -185,7 +203,7 @@ import multiprocessing
### write definitions of matmul_kernel_xxx
### and matmul_xxx and try_config
with open("matmul_kernel.py") as file:
matmul_kernel_code = file.read();
matmul_kernel_code = file.read()
idx = 0
for config in configs:
file_idx = idx % ngpus
@@ -211,6 +229,8 @@ import multiprocessing
c.stride(0), c.stride(1), dtype)
if num_threads > 1:
results = []
config_names = []
"""
for fi in range(ngpus):
f_kernel[fi].write(test_gemm_pre_str + "\n")
@@ -219,23 +239,40 @@ import multiprocessing
idx = 0
for config in configs:
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
task_str = f" thread_pool.apply_async(try_config_{configStr}, args=task_args)\n"
task_str = f" results += [thread_pool.apply_async(try_config_{configStr}, args=task_args)]\n" + \
f" config_names += ['{configStr}']\n"
f_kernel[idx % ngpus].write(task_str)
idx += 1
threadpool_str = """
for fi in range(ngpus):
threadpool_str = """
failed_configs = []
for i in range(len(results)):
results[i].wait()
res = results[i].get()
if not res:
failed_configs += [config_names[i]]
thread_pool.close()
thread_pool.join()
else:"""
for fi in range(ngpus):
with open("{filename}.failed_configs", "w") as f:
for cfg in failed_configs:
f.write(cfg + "\\n")
else:
try:
with open("{filename}.failed_configs", "r") as f:
failed_configs = [cfg.strip() for cfg in f.readlines()]
except Exception:
failed_configs = []
""".format(filename = filenames[fi])
f_kernel[fi].write(threadpool_str)
# call all matmul_xxx functions
idx = 0
for config in configs:
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
matmul_call_str = f"""
for i in range(10):
d = matmul_{configStr}(a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))"""
if '{configStr}' not in failed_configs:
for i in range(10):
d = matmul_{configStr}(a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))"""
f_kernel[idx % ngpus].write(matmul_call_str + "\n")
idx += 1
# post string
@@ -267,30 +304,29 @@ def extract_kernel_time(M, N, K, config, gpuid):
return config, parsed_outputs
def profile_batch_kernels(M, N, K, gpuid):
def profile_batch_kernels(M, N, K, gpuid, verbose):
os.environ['ROCR_VISIBLE_DEVICES'] = str(gpuid)
run_bash_command(f"rocprof --stats -o results-{gpuid}.csv python generated_kernel{M}{N}{K}-{gpuid}.py")
run_bash_command(f"rocprof --stats -o results-{gpuid}.csv python {generated_kernel_name(M, N, K, gpuid)}", capture=(verbose < 2))
def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1):
def tune_gemm_config(M, N, K, configs, verbose=0, num_threads=16, gpus = [0]):
## Generate kernel out of all configs
generate_kernel(M, N, K, configs, ngpus)
generate_kernel(M, N, K, configs, gpus)
## remove any compiled kernel in the cache
run_bash_command("rm -rf ~/.triton/cache")
## precompile the kernels in parallel
## TODO: parameterize numThreads at this level
start_time = datetime.now()
for fi in range(ngpus):
run_bash_command(f"python generated_kernel{M}{N}{K}-{fi}.py -n 32")
for gpu_id in gpus:
run_bash_command(f"python {generated_kernel_name(M, N, K, gpu_id)} -n {num_threads}", capture=(verbose < 2))
compile_end = datetime.now()
compile_time = compile_end - start_time
if verbose:
print(f"compile time: {compile_time}")
print(f"compile time: {compile_time}", flush=True)
## profile generated kernels
running = [multiprocessing.Process(target=profile_batch_kernels, args=(M,N,K,fi)) for fi in range(ngpus)]
running = [multiprocessing.Process(target=profile_batch_kernels, args=(M,N,K,gpu_id,verbose)) for gpu_id in gpus]
for p in running:
p.start()
for p in running:
@@ -299,7 +335,7 @@ def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1)
profile_end = datetime.now()
profile_time = profile_end - compile_end
if verbose:
print(f"profile time: {profile_time}")
print(f"profile time: {profile_time}", flush=True)
## post process results.csv to get the best config and minTime
## TODO: process the file in parallel
@@ -308,8 +344,9 @@ def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1)
tasks = []
idx = 0
for config in configs:
file_idx = idx % ngpus
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, file_idx))]
file_idx = idx % len(gpus)
gpu_id = gpus[file_idx]
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, gpu_id))]
idx += 1
thread_pool.close()
thread_pool.join()
@@ -323,11 +360,11 @@ def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1)
bestConfig = config
else:
min_us = -1
print(f"invalid config: SIZE {M} {N} {K}: {config}")
print(f"invalid config(post processing): SIZE {M} {N} {K}: {config}", flush=True)
post_end = datetime.now()
post_time = post_end - profile_end
if verbose:
print(f"post procesing time: {post_time}")
print(f"post procesing time: {post_time}", flush=True)
return minTime, bestConfig, compile_time, profile_time, post_time
@@ -379,9 +416,9 @@ def test_correctness(M, N, K, config, verbose, datatype = torch.float16):
if verbose:
size_str = f'SIZE M: {M}, N: {N}, K: {K} '
if torch.allclose(triton_output, torch_output, atol=1e-1, rtol=rtol):
print(f'{size_str}')
print(f'{size_str} Correct')
else:
print(f'{size_str}')
print(f'{size_str} Incorrect')
def get_default_tuning_result_filename():
@@ -404,13 +441,16 @@ def parse_args():
parser.add_argument("-m", type=int, default=0)
parser.add_argument("-n", type=int, default=0)
parser.add_argument("-k", type=int, default=0)
parser.add_argument("--ngpus", type=int, default=1, help='number of GPUs used in the profiling step')
parser.add_argument("--ngpus", type=int, default=0, help='number of GPUs used in the profiling step')
parser.add_argument("--gpu_ids", type=lambda s: [int(id) for id in s.split(',')], default=[], help='list of gpu ids to use for tuning')
parser.add_argument("--gemm_size_file", type=str, default="", help='yaml file to indicate matrix size')
parser.add_argument("--tuning_results_file", type=str, default=get_default_tuning_result_filename(), help='yaml file to store tuning results')
parser.add_argument("--keep", action='store_true', default=False, help='keep generated files')
parser.add_argument("--compare", action='store_true', default=False, help="Whether check result correctness")
parser.add_argument("--compare_wo_tuning", action='store_true', default=False, help="Whether check result correctness")
parser.add_argument("--time_breakdown", action='store_true', default=False, help="Show detailed time breakdown of each step during the tuning")
parser.add_argument("--verbose", action='store_true', default=False, help="enables time_breakdown and additional logging messages")
parser.add_argument("--num_threads", type=int, default=16, help="number of threads to use for kernel compilation and post processing")
args = parser.parse_args()
return args
@@ -422,6 +462,16 @@ def main():
tuning_output_file = args.tuning_results_file
keepTmp = args.keep
ngpus = args.ngpus
gpu_ids = args.gpu_ids
if ngpus != 0 and gpu_ids:
print("--ngpus and --gpu_ids are mutually exclusive options")
return os.EX_USAGE
if ngpus == 0 and not gpu_ids:
ngpus = 1
if ngpus != 0:
gpus = range(ngpus)
if gpu_ids:
gpus = gpu_ids
mnks = []
## TODO: make it more robust to get user input
@@ -454,7 +504,7 @@ def main():
configs_full = get_full_tuning_space()
start_time = datetime.now()
print(f"Tuning starts at: {start_time}")
print(f"Tuning starts at: {start_time}", flush=True)
f_results = open(tuning_output_file, 'w')
for (M, N, K) in mnks:
@@ -466,7 +516,12 @@ def main():
print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True)
## The main tuning function for one gemm size
minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(M, N, K, pruned_configs, ngpus = ngpus, verbose=args.time_breakdown)
verbose_level = 0
if args.time_breakdown:
verbose_level = 1
if args.verbose:
verbose_level = 2
minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(M, N, K, pruned_configs, num_threads=args.num_threads, gpus = gpus, verbose=verbose_level)
## post processing the numbers
perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6)
@@ -475,10 +530,10 @@ def main():
formatted_tflops = "{:.3e}".format(tri_tflops)
else:
formatted_tflops = "{:.2f}".format(tri_tflops)
print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ")
print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ", flush=True)
bestConfig_compact_str, _ = gen_kernel_and_configStr_from_config(M, N, K, bestConfig)
print(f'best_config: {bestConfig_compact_str}', end=" ")
print(f'best_config: {bestConfig_compact_str}', end=" ", flush=True)
## write best config to tuning_results.yaml
sizeDict = {'M': M, 'N': N, 'K': K}
@@ -488,20 +543,22 @@ def main():
## remove generated files if asked to
if not keepTmp:
for fi in range(ngpus):
os.remove(f"generated_kernel{M}{N}{K}-{fi}.py")
for f in glob.glob(f"results-{fi}.*"):
for gpu_id in gpus:
generated_script = generated_kernel_name(M, N, K, gpu_id)
os.remove(generated_script)
os.remove(generated_script + ".failed_configs")
for f in glob.glob(f"results-{gpu_id}.*"):
os.remove(f)
## Check correctness if asked to
if args.compare:
print("correctness: ", end=" ")
print("correctness: ", end=" ", flush=True)
test_correctness(M, N, K, bestConfig, False)
else:
print("")
print("", flush=True)
end_local_time = datetime.now()
print(f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)")
print(f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)", flush=True)
f_results.close()

350
scripts/amd/plot_layout.py Executable file
View File

@@ -0,0 +1,350 @@
import argparse
import sys
import yaml
import os
import glob
import subprocess
def draw_preamble_cmd():
    """Return the LaTeX preamble written at the top of the generated .tex file.

    The text selects the ``standalone`` document class with the ``tikz``
    option, loads ``ifthen``, ``tikz``, several TikZ libraries and ``xparse``,
    and defines the expandable ``\\bitwiseXor`` command between
    ``\\ExplSyntaxOn`` and ``\\ExplSyntaxOff``.

    Returns:
        str: the preamble text, without a trailing newline.
    """
    preamble = '''\\documentclass[tikz, border=1mm, dvipsnames]{standalone}
\\usepackage{ifthen}
\\usepackage{tikz}
\\usetikzlibrary{arrows.meta,arrows}
\\usetikzlibrary{intersections}
\\usetikzlibrary{calc, quotes}
\\usetikzlibrary{patterns}
\\usepackage{xparse}
\\ExplSyntaxOn
\\NewExpandableDocumentCommand{\\bitwiseXor}{mm}
{
\\recuenco_bitwise_xor:nn { #1 } { #2 }
}
\\cs_new:Nn \\recuenco_bitwise_xor:nn
{
\\int_from_bin:e
{
\\__recuenco_bitwise_xor:ee { \\int_to_bin:n { #1 } } { \\int_to_bin:n { #2 } }
}
}
\\cs_generate_variant:Nn \\int_from_bin:n { e }
\\cs_new:Nn \\__recuenco_bitwise_xor:nn
{
\\__recuenco_bitwise_xor_binary:ee
{
\\prg_replicate:nn
{
\\int_max:nn { \\tl_count:n { #1 } } { \\tl_count:n { #2 } } - \\tl_count:n { #1 }
}
{ 0 }
#1
}
{
\\prg_replicate:nn
{
\\int_max:nn { \\tl_count:n { #1 } } { \\tl_count:n { #2 } } - \\tl_count:n { #2 }
}
{ 0 }
#2
}
}
\\cs_generate_variant:Nn \\__recuenco_bitwise_xor:nn { ee }
\\cs_new:Nn \\__recuenco_bitwise_xor_binary:nn
{
\\__recuenco_bitwise_xor_binary:w #1;#2;
}
\\cs_generate_variant:Nn \\__recuenco_bitwise_xor_binary:nn { ee }
\\cs_new:Npn \\__recuenco_bitwise_xor_binary:w #1#2;#3#4;
{
\\int_abs:n { #1-#3 }
\\tl_if_empty:nF { #2 } { \\__recuenco_bitwise_xor_binary:w #2;#4; }
}
\\ExplSyntaxOff'''
    return preamble
def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kpack):
    """Return the LaTeX document body for the 'dot' plot.

    The body first invokes ``\\drawDot`` with the tensor shape and layout
    parameters, then shifts the ``(C TL)`` coordinate to the right and invokes
    ``\\drawMFMAInstr`` for a zoomed-in view of the mfma instruction.

    Args:
        M, N, K: dot shapes, substituted verbatim into the macro arguments.
        mfmaNonKDim: mfma non-K dimension passed to both macros.
        warpsPerCTA: indexable with two entries; [0] and [1] are passed on.
        trans: 1 for the mfma.trans layout, 0 otherwise; also selects which
            operand gets which colour pair in the zoomed-in view.
        kpack: vector length, passed on and used in the horizontal offset of
            the zoomed-in view.

    Returns:
        str: text from ``\\begin{document}`` to ``\\end{document}``.
    """
    # Every LaTeX backslash is written as "\\" because this is a non-raw
    # f-string; a bare "\e" (as in "\elem") is an invalid escape sequence.
    return f'''\\begin{{document}}
\\begin{{tikzpicture}}
\\def\\scale{{1}}
\\def\\elem{{0.04}}
\\coordinate (C TL) at (0,0);
\\def\\opColorAL{{magenta}}
\\def\\opColorAR{{cyan}}
\\def\\opColorBL{{Maroon}}
\\def\\opColorBR{{BlueGreen}}
\\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}}{{{kpack}}}
\\coordinate (C TL) at ($(C TL)+({N}*\\elem+32*\\elem, 0)$);
\\def\\mfmaTrans{{{trans}}}
\\ifthenelse{{\\mfmaTrans=0}}{{
\\def\\opColorAL{{magenta}}
\\def\\opColorAR{{cyan}}
\\def\\opColorBL{{Maroon}}
\\def\\opColorBR{{BlueGreen}}
}}{{
\\def\\opColorBL{{magenta}}
\\def\\opColorBR{{cyan}}
\\def\\opColorAL{{Maroon}}
\\def\\opColorAR{{BlueGreen}}
}}
%% Draw zoomed in view of mfma
\\def\\elem{{.16}}
\\pgfmathsetmacro{{\\gap}}{{\\elem*5}}
\\pgfmathsetmacro{{\\nonTrans}}{{1-\\mfmaTrans}}
\\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+2*{kpack}*\\elem, 0)$);
\\drawMFMAInstr{{{mfmaNonKDim}}}{{{kpack}}}{{\\mfmaTrans}}
\\end{{tikzpicture}}
\\end{{document}}'''
def draw_blocked_layout_cmd(M, K, sizePerThread, threadsPerWarp, warpsPerCTA,
                            order):
    """Return the LaTeX document body for the 'blocked' plot.

    The body sets up ``\\scale``, ``\\elem`` and the ``(TL)`` coordinate and
    then issues a single ``\\drawBlockedTensor`` call with eight arguments.

    Args:
        M, K: tensor shape.
        sizePerThread: indexable; entries [0] and [1] are passed on.
        threadsPerWarp: indexable; only entry [0] is passed on.
        warpsPerCTA: indexable; entries [0] and [1] are passed on.
        order: indexable; only entry [0] is passed on.

    Returns:
        str: text from ``\\begin{document}`` to ``\\end{document}``.
    """
    macro_args = (M, K, sizePerThread[0], sizePerThread[1], threadsPerWarp[0],
                  warpsPerCTA[0], warpsPerCTA[1], order[0])
    # Each argument is wrapped in its own pair of braces: {a}{b}{c}...
    draw_call = "\\drawBlockedTensor" + "".join(f"{{{arg}}}"
                                                for arg in macro_args)
    body_lines = [
        "\\begin{document}",
        "\\begin{tikzpicture}",
        "\\def\\scale{1}",
        "\\def\\elem{0.06}",
        "\\coordinate (TL) at (0,0);",
        draw_call,
        "\\end{tikzpicture}",
        "\\end{document}",
    ]
    return "\n".join(body_lines)
def draw_lds_access_cmd(M, K, kpack, ldsLayout, ldsAccess, sizePerThread,
                        threadsPerWarp):
    """Return the LaTeX document body for the 'lds' plot.

    The body defines the TeX variables read by the drawing macros, then calls
    ``\\drawTensorLayoutGlobalMem`` followed by
    ``\\drawLDSLayoutTritonSwizzling`` at a lowered ``(TL)`` coordinate.

    Args:
        M, K: tensor shape.
        kpack: emitted as ``\\vec``.
        ldsLayout: 'swizzle' -> 1, 'padding' -> 2, anything else -> 0.
        ldsAccess: 'read' -> 1, 'write' -> 2, anything else -> 0.
        sizePerThread: indexable; [1] becomes ``\\sizePerThreadK`` and [0]
            becomes ``\\sizePerThreadM``.
        threadsPerWarp: indexable; [1] becomes ``\\threadsPerWarpK``.

    Returns:
        str: text from ``\\begin{document}`` to ``\\end{document}``.
    """
    # Translate the textual options into the numeric codes used on the TeX side.
    hasSwizzle = 1 if ldsLayout == 'swizzle' else (
        2 if ldsLayout == 'padding' else 0)
    accessMode = 1 if ldsAccess == 'read' else (
        2 if ldsAccess == 'write' else 0)

    body_lines = [
        "\\begin{document}",
        "\\begin{tikzpicture}",
        "\\def\\scale{1}",
        f"\\def\\M{{{M}}}",
        f"\\def\\K{{{K}}}",
        f"\\def\\vec{{{kpack}}}",
        f"\\def\\hasSwizzle{{{hasSwizzle}}}",
        f"\\def\\accessMode{{{accessMode}}}",
        f"\\def\\sizePerThreadK{{{sizePerThread[1]}}}",
        f"\\def\\sizePerThreadM{{{sizePerThread[0]}}}",
        f"\\def\\threadsPerWarpK{{{threadsPerWarp[1]}}}",
        "\\def\\elem{0.18}",
        "\\coordinate (TL) at (0,0);",
        "\\drawTensorLayoutGlobalMem",
        "\\coordinate (TL) at ($(TL)+(0, -24*\\elem-10*\\elem)$);",
        "\\drawLDSLayoutTritonSwizzling{\\hasSwizzle}{\\accessMode}",
        "\\end{tikzpicture}",
        "\\end{document}",
    ]
    return "\n".join(body_lines)
def draw_wmma_instr_cmd(waveSize):
    """Return the LaTeX document body for the 'wmma' plot.

    Args:
        waveSize: 32 selects mode 0; any other value selects mode 1. The mode
            is passed as the first argument of ``\\drawWMMAInstr``; the second
            argument is always 1.

    Returns:
        str: text from ``\\begin{document}`` to ``\\end{document}``.
    """
    wmma_mode = int(waveSize != 32)
    body_lines = [
        "\\begin{document}",
        "\\begin{tikzpicture}",
        "\\def\\scale{1}",
        "\\coordinate (C TL) at (0,0);",
        "\\def\\elem{0.25}",
        f"\\drawWMMAInstr{{{wmma_mode}}}{{1}}",
        "\\end{tikzpicture}",
        "\\end{document}",
    ]
    return "\n".join(body_lines)
def run_bash_command(commandstring):
    """Run ``commandstring`` through /bin/bash and return its stdout lines.

    Args:
        commandstring: shell command line, interpreted by bash.

    Returns:
        list[bytes]: captured standard output split into lines.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status (``check=True``).
    """
    completed = subprocess.run(
        commandstring,
        stdout=subprocess.PIPE,
        executable='/bin/bash',
        shell=True,
        check=True,
    )
    output_lines = completed.stdout.splitlines()
    return output_lines
def parse_args():
    """Parse the command line of the layout plotting script.

    Options fall into four groups: tensor shape and plot mode, blocked layout
    parameters, LDS access parameters, and the wmma wave size, plus the output
    name and the ``--keep`` switch. Abbreviated option names are rejected
    (``allow_abbrev=False``).

    Returns:
        argparse.Namespace: the parsed options read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        prog="Draw triton layouts",
        allow_abbrev=False,
    )
    ## tensor shapes
    # nargs=3 consumes three separate tokens, i.e. "-shape 32 128 64".
    parser.add_argument("-shape",
                        type=int,
                        nargs=3,
                        default=(32, 128, 64),
                        help='Tensor shape given as three integers: M N K')
    parser.add_argument("-plot",
                        type=str,
                        default="blocked",
                        choices=['blocked', 'dot', 'wmma', 'lds'],
                        help='choose plot mode')
    parser.add_argument(
        "-nonKDim",
        type=int,
        default=32,
        choices=[32],
        help='mfma instruction dim, only 32 is supported for now')
    ## blocked layout parameters
    parser.add_argument("-sizePerThread", type=int, nargs=2, default=(1, 4))
    parser.add_argument("-threadsPerWarp", type=int, nargs=2, default=(16, 4))
    parser.add_argument("-warpsPerCTA", type=int, nargs=2, default=(1, 4))
    parser.add_argument("-order", type=int, nargs=2, default=(1, 0))
    ## LDS access parameters
    parser.add_argument("-kpack",
                        type=int,
                        default=4,
                        choices=[4, 8],
                        help='vector length during LDS load, same as vec')
    parser.add_argument("-lds_layout",
                        type=str,
                        default="none",
                        choices=['swizzle', 'padding', 'none'],
                        help='choose the LDS data layout')
    parser.add_argument("-lds_access",
                        type=str,
                        default="none",
                        choices=['read', 'write', 'none'],
                        help='choose LDS access mode')
    ## wmma instruction layout parameter
    parser.add_argument("-wave_size",
                        type=int,
                        default=32,
                        choices=[32, 64],
                        help='choose the wmma instruction mode')
    parser.add_argument("-o",
                        type=str,
                        default="myplot",
                        help='output pdf file name (without suffix)')
    parser.add_argument("-mfmaTrans",
                        action='store_true',
                        default=False,
                        help='If set, then use mfma.trans layout')
    parser.add_argument("--keep",
                        action='store_true',
                        default=False,
                        help='If set, keep the generated .tex file')
    args = parser.parse_args()
    return args
def main():
    """Generate the LaTeX source for the requested plot and compile it to PDF.

    Steps:
      1. Parse the command line and print a summary of the chosen plot.
      2. For the 'blocked' and 'dot' modes, check with assertions that the
         tensor dimensions are non-zero multiples of the CTA shape.
      3. Write ``myplot.tex`` (preamble + contents of ``tikzplot.tex`` + the
         body of the selected plot) in the current working directory.
      4. Run ``pdflatex`` with the requested job name, then remove the
         ``.aux``/``.log`` files and, unless ``--keep`` was given, the
         generated ``myplot.tex`` and the ``./auto`` directory.

    ``tikzplot.tex`` is opened by relative path, so it must be present in the
    current working directory.

    Raises:
        AssertionError: if a tensor dimension does not fit the CTA shape.
        subprocess.CalledProcessError: if pdflatex exits with an error.
    """
    # Function-scope import: only the pdflatex invocation below needs it.
    import shlex

    args = parse_args()
    M, N, K = args.shape
    plot_mode = args.plot
    mfmaNonKDim = args.nonKDim
    kpack = args.kpack
    trans = 1 if args.mfmaTrans else 0
    ofilename = args.o
    keepSrc = args.keep

    ldsLayout = args.lds_layout
    ldsAccess = args.lds_access

    waveSize = args.wave_size

    sizePerThread = args.sizePerThread
    threadsPerWarp = args.threadsPerWarp
    warpsPerCTA = args.warpsPerCTA
    order = args.order

    CTAShape = []
    if plot_mode == 'blocked':
        print(f"Plotting tensor M={M},K={K} with blocked layout:")
        print(f"sizePerThread={sizePerThread}", end=" ")
        print(f"threadsPerWarp={threadsPerWarp}", end=" ")
        print(f"warpsPerCTA={warpsPerCTA}", end=" ")
        print(f"order={order}", end=" ")
        CTAShape.append(sizePerThread[0] * threadsPerWarp[0] * warpsPerCTA[0])
        CTAShape.append(sizePerThread[1] * threadsPerWarp[1] * warpsPerCTA[1])

    if plot_mode == 'dot':
        mfma_inst_str = "mfma_32x32x8f16" if mfmaNonKDim == 32 else "mfma_16x16x16f16"
        mfma_trans_str = ".trans" if trans else ""
        print(f"Plotting dot operation with shapes M={M},N={N},K={K}")
        print("MFMA: " + mfma_inst_str + mfma_trans_str, end=" ")
        print(f"warpsPerCTA={warpsPerCTA}", end=" ")
        CTAShape.append(32 * warpsPerCTA[0])
        CTAShape.append(32 * warpsPerCTA[1])

    # CTAShape is only populated for these two modes, so the shape checks
    # must stay behind the same conditions.
    if plot_mode == 'blocked' or plot_mode == 'dot':
        print(f"CTAShape={CTAShape}")
        assert M != 0 and CTAShape[
            0] <= M and M % CTAShape[0] == 0, "bad tensor dimension M"

    if plot_mode == 'blocked':
        assert K != 0 and CTAShape[
            1] <= K and K % CTAShape[1] == 0, "bad tensor dimension K"

    if plot_mode == 'dot':
        assert N != 0 and CTAShape[
            1] <= N and N % CTAShape[1] == 0, "bad tensor dimension N"
        assert K != 0 and K % (2 * kpack) == 0, "bad tensor dimension K"

    if plot_mode == 'lds':
        print(f"Plotting LDS access for tensor M={M},K={K} with vec={kpack}")
        if ldsAccess == 'write':
            print(
                f"sizePerThread={sizePerThread}, threadsPerWarp={threadsPerWarp}"
            )

    # Read the macro library before creating myplot.tex so that a missing
    # tikzplot.tex does not leave an empty output file behind.
    with open("tikzplot.tex") as file:
        tikz_code = file.read()

    # Build only the document body that the selected mode needs.
    plot_body = ""
    if plot_mode == 'blocked':
        plot_body = draw_blocked_layout_cmd(M, K, sizePerThread,
                                            threadsPerWarp, warpsPerCTA, order)
    elif plot_mode == 'dot':
        plot_body = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA,
                                        trans, kpack)
    elif plot_mode == 'lds':
        plot_body = draw_lds_access_cmd(M, K, kpack, ldsLayout, ldsAccess,
                                        sizePerThread, threadsPerWarp)
    elif plot_mode == 'wmma':
        plot_body = draw_wmma_instr_cmd(waveSize)

    with open("myplot.tex", 'w') as f_plot:
        f_plot.write(draw_preamble_cmd() + "\n")
        f_plot.write(tikz_code)
        f_plot.write(plot_body)

    # The job name goes through bash, so quote it; plain names are unchanged.
    run_bash_command(
        f"pdflatex -jobname {shlex.quote(ofilename)} myplot.tex")
    print(f"plot saved in {ofilename}.pdf")

    ## Remove aux files
    os.remove(f"{ofilename}.aux")
    os.remove(f"{ofilename}.log")
    if not keepSrc:
        os.remove("myplot.tex")
        run_bash_command("rm -rf ./auto")
# Run main() only when this file is executed as a script; whatever main()
# returns is handed to sys.exit() as the exit status.
if __name__ == '__main__':
    sys.exit(main())

874
scripts/amd/tikzplot.tex Executable file
View File

@@ -0,0 +1,874 @@
\newcommand{\drawBlockedWave}[5]{
%%
%% Draw the coverage of one wave with blocked layout:
%% one filled rectangle per thread plus an outline around the whole wave.
%%
%% Wave TL: pre defined top-left coordinate of the wave
%% \elem: pre defined variable; every element count is multiplied by it
%% to obtain a drawing length
%%
%% #1: sizePerThread[0] --> sizePerThreadM
%% #2: sizePerThread[1] --> sizePerThreadN
%% #3: threadsPerWarp[0] --> threadsPerWarpM
%% #4: threadsPerWarp[1] --> threadsPerWarpN
%% #5: fastest changing dim --> order
%% (stored in \order but not used inside this macro)
\pgfmathsetmacro{\sizePerThreadM}{#1}
\pgfmathsetmacro{\sizePerThreadN}{#2}
\pgfmathsetmacro{\threadsPerWarpM}{#3}
\pgfmathsetmacro{\threadsPerWarpN}{#4}
\pgfmathsetmacro{\order}{#5}
%% Number of elements covered by the wave along each dim
\pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM}
\pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN}
%% The loop always visits thread ids 0..63
\foreach \tid in {0,...,63}{
%% Thread ids are placed row by row, \threadsPerWarpN threads per row
\pgfmathsetmacro{\tidM}{int(\tid/\threadsPerWarpN)}
\pgfmathsetmacro{\tidN}{mod(\tid,\threadsPerWarpN)}
\coordinate (Thread TL) at ($(Wave TL)+(\tidN*\sizePerThreadN*\elem, -\tidM*\sizePerThreadM*\elem)$);
%% Blue percentage of the fill grows with the thread's row index
\pgfmathsetmacro{\ratio}{\tidM*10}
%% Thread 0 is highlighted in red
\ifthenelse{\tid = 0}{
\draw [line width = 0.01mm, fill=red] (Thread TL)
rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem);
}{
\draw [line width = 0.01mm, fill=blue!\ratio!white] (Thread TL)
rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem);
}
}
%% Outline of the whole wave
\draw (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem);
}
\newcommand{\drawBlockedCTA}[7]{
%%
%% Draw the coverage of one CTA with blocked layout:
%% a per-thread view of the wave at the CTA's top-left corner, a labelled
%% rectangle for every wave, and an outline around the CTA.
%%
%% CTA TL: pre defined top-left coordinate of the CTA
%% \elem: pre defined variable
%% \scale: pre defined variable, used to scale the wave labels
%%
%% #1: sizePerThread[0] --> sizePerThreadM
%% #2: sizePerThread[1] --> sizePerThreadN
%% #3: threadsPerWarp[0] --> threadsPerWarpM
%% #4: threadsPerWarp[1] --> threadsPerWarpN
%% #5: warpsPerCTA[0] --> warpsPerCTAM
%% #6: warpsPerCTA[1] --> warpsPerCTAN
%% #7: fastest changing dim --> order
\pgfmathsetmacro{\sizePerThreadM}{#1}
\pgfmathsetmacro{\sizePerThreadN}{#2}
\pgfmathsetmacro{\threadsPerWarpM}{#3}
\pgfmathsetmacro{\threadsPerWarpN}{#4}
\pgfmathsetmacro{\warpsPerCTAM}{#5}
\pgfmathsetmacro{\warpsPerCTAN}{#6}
\pgfmathsetmacro{\order}{#7}
\pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM}
\pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN}
\pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM}
\pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN}
\pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM*\warpsPerCTAN-1}
%% Per-thread view, drawn once at the top-left corner of the CTA
\coordinate (Wave TL) at (CTA TL);
\drawBlockedWave{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\order}
\foreach \waveId in {0,...,\maxWaveId}{
%% order=1: consecutive wave ids advance along N first, label is not rotated
%% otherwise: consecutive wave ids advance along M first, label is rotated by 90
\ifthenelse{\order=1}
{
\pgfmathsetmacro{\waveCoordM}{int(\waveId/\warpsPerCTAN)}
\pgfmathsetmacro{\waveCoordN}{mod(\waveId,\warpsPerCTAN)}
\pgfmathsetmacro{\rot}{0}
}{
\pgfmathsetmacro{\waveCoordM}{mod(\waveId,\warpsPerCTAM)}
\pgfmathsetmacro{\waveCoordN}{int(\waveId/\warpsPerCTAM)}
\pgfmathsetmacro{\rot}{90}
}
%% Rectangle of this wave with a "wave<id>" label on a white background
\coordinate (Wave TL) at ($(CTA TL)+(\waveCoordN*\waveSizeN*\elem, -\waveCoordM*\waveSizeM*\elem)$);
\draw [ultra thin] (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem)
node [pos=.5, scale=.6*\scale, inner sep=0, fill=white, rotate=\rot] {wave\waveId};
}
%% Outline of the whole CTA
\draw [thick] (CTA TL) rectangle ++(\CTASizeN*\elem, -\CTASizeM*\elem);
}
\newcommand{\drawBlockedTensor}[8]{
%%
%% Draw a tensor with blocked layout of the following parameters
%% sizePerThread[2]
%% threadsPerWarp[2]
%% warpsPerCTA[2]
%% order[2]
%%
%% TL: pre defined top-left coordinate of the tensor
%% \elem: pre defined variable
%% \scale: pre defined variable, used to scale the labels
%%
%% #1: tensorShape[0] --> M
%% #2: tensorShape[1] --> N (printed as "K=..." in the plot label)
%% #3: sizePerThread[0] --> sizePerThreadM
%% #4: sizePerThread[1] --> sizePerThreadN
%% #5: threadsPerWarp[0] --> threadsPerWarpM
%% Note that threadsPerWarp[1] is calculated by 64/threadsPerWarp[0]
%% #6: warpsPerCTA[0] --> warpsPerCTAM
%% #7: warpsPerCTA[1] --> warpsPerCTAN
%% #8: fastest changing dim --> order
\pgfmathsetmacro{\M}{#1}
\pgfmathsetmacro{\N}{#2}
\pgfmathsetmacro{\sizePerThreadM}{#3}
\pgfmathsetmacro{\sizePerThreadN}{#4}
\pgfmathsetmacro{\threadsPerWarpM}{#5}
\pgfmathsetmacro{\warpsPerCTAM}{#6}
\pgfmathsetmacro{\warpsPerCTAN}{#7}
\pgfmathsetmacro{\order}{#8}
\pgfmathsetmacro{\threadsPerWarpN}{64/\threadsPerWarpM}
\pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM}
\pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN}
%% Number of times the CTA tile is repeated along each dim of the tensor
\pgfmathsetmacro{\CTARepM}{\M/\CTASizeM}
\pgfmathsetmacro{\CTARepN}{\N/\CTASizeN}
\pgfmathsetmacro{\maxCTAId}{\CTARepM*\CTARepN-1}
%% Draw every CTA repetition, row by row
\foreach \ctaId in {0,...,\maxCTAId}{
\pgfmathsetmacro{\ctaCoordM}{int(\ctaId/\CTARepN)}
\pgfmathsetmacro{\ctaCoordN}{mod(\ctaId,\CTARepN)}
\coordinate (CTA TL) at ($(TL)+(\ctaCoordN*\CTASizeN*\elem, -\ctaCoordM*\CTASizeM*\elem)$);
\drawBlockedCTA{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\warpsPerCTAM}{\warpsPerCTAN}{\order}
}
%% Dimension labels on the left and top edges
\node [scale=.7*\scale, above, rotate=90] at ($(TL)+(0, -.5*\M*\elem)$) {M=\M};
\node [scale=.7*\scale, above] at ($(TL)+(.5*\N*\elem, 0)$) {K=\N};
%% Zoomed-in grid of the elements owned by one thread, placed above the
%% tensor and enlarged by \zoomR
\def\zoomR{1.5}
\coordinate (zoomin BL) at ($(TL)+(0, .3)$);
%% horizontal grid lines
\foreach \hl in {0,...,\sizePerThreadM}{
\draw ($(zoomin BL)+(0, \hl*\elem*\zoomR)$) -- ++(\sizePerThreadN*\elem*\zoomR,0);
}
%% vertical grid lines
\foreach \vl in {0,...,\sizePerThreadN}{
\draw ($(zoomin BL)+(\vl*\elem*\zoomR, 0)$) -- ++(0, \sizePerThreadM*\elem*\zoomR);
}
\node [scale=.6*\scale, left] at ($(zoomin BL)+(0, .5*\sizePerThreadM*\elem*\zoomR)$) {$t_0$};
\node [scale=.6*\scale, right] at ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, .5*\sizePerThreadM*\elem*\zoomR)$) {\sizePerThreadM$\times$\sizePerThreadN};
%% Dotted guide lines from the tensor's top-left corner to the zoomed-in grid
\draw [densely dotted] (TL) -- (zoomin BL);
\draw [densely dotted] ($(TL)+(\sizePerThreadN*\elem, 0)$) -- ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, 0)$);
%% Redraw the top-left sizePerThreadM x sizePerThreadN patch in red on top
\draw [fill=red] (TL) rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem);
}
\newcommand{\drawBlockMFMALayoutLarge}[2]{
%%
%% Draw a single block of MFMA_32x32x8xf16
%%
%% block TL: pre-defined top-left coordinate of the block
%% \elem: pre defined variable
%% \scale: pre defined variable, used to scale the thread labels
%%
%% #1: 1 for mfma.trans, 0 for normal mfma
%% #2: verbose. 1 means draw tid in each vec; 0 means draw nothing
\pgfmathsetmacro{\trans}{#1}
\pgfmathsetmacro{\nonTrans}{1-#1}
\pgfmathsetmacro{\verbose}{#2}
%% \trans and \nonTrans act as 0/1 multipliers that select whether an
%% offset is applied horizontally or vertically
%% Four repetitions, each offset by 2*4 elements
%% (horizontally for mfma.trans, vertically otherwise)
\foreach \iVec in {0,1,2,3} {
\coordinate (wave TL) at ($(block TL)+(\trans*\iVec*2*4*\elem, -\nonTrans*\iVec*2*4*\elem)$);
%% Two groups of 32 threads: group 0 in blue, group 1 in orange
\foreach \col/\tg in {blue/0,orange/1}{
\foreach \tid in {0,...,31} {
%% Color intensity grows with the thread id inside its group
\pgfmathsetmacro{\ratio}{\tid*2.5+15}
%% Each thread gets a 1x4 (normal) or 4x1 (mfma.trans) rectangle of elements
\ifthenelse{\verbose=0}{
\draw [line width=0.005mm, fill=\col!\ratio!white]
($(wave TL)+(\nonTrans*\tid*\elem+\tg*\trans*4*\elem, -\trans*\tid*\elem-\tg*\nonTrans*4*\elem)$)
rectangle ++(\nonTrans*\elem+\trans*4*\elem, -\nonTrans*4*\elem-\trans*\elem);
}{
%% Thread id printed in verbose mode: tid within the group + 32 * group
\pgfmathsetmacro{\drawTid}{int(\tid+\tg*32)}
\draw [line width=0.005mm, fill=\col!\ratio!white]
($(wave TL)+(\nonTrans*\tid*\elem+\tg*\trans*4*\elem, -\trans*\tid*\elem-\tg*\nonTrans*4*\elem)$)
rectangle ++(\nonTrans*\elem+\trans*4*\elem, -\nonTrans*4*\elem-\trans*\elem)
node [pos=.5, scale=.35*\scale, rotate=90*\nonTrans] {t\drawTid};
}
}
}
}
%% Outline of the 32x32 block
\draw [thick] (block TL) rectangle ++(32*\elem, -32*\elem);
}
\newcommand{\drawTensorMFMALayout}[6]{
%%
%% Draw a tensor with mfma layout.
%%
%% C TL: pre defined top-left coordinates of the tensor
%% TL: must already be defined; it is saved on entry, temporarily moved to
%% (C TL) and restored on exit
%% \elem, \scale: pre defined variables
%%
%% #1: M
%% #2: N
%% #3: MFMA nonKDim
%% #4: warpsPerCTA[0]
%% #5: warpsPerCTA[1]
%% #6: 1 for mfma.trans, 0 for normal mfma
\pgfmathsetmacro{\tensorShapeH}{#1}
\pgfmathsetmacro{\tensorShapeW}{#2}
\pgfmathsetmacro{\mfmaNonKDim}{#3}
\pgfmathsetmacro{\warpsPerCTAH}{#4}
\pgfmathsetmacro{\warpsPerCTAW}{#5}
\pgfmathsetmacro{\mfmaTrans}{#6}
\coordinate (old TL) at (TL);
\coordinate (TL) at (C TL);
%% Number of times the CTA tile is repeated along each dim of the tensor
\pgfmathsetmacro{\CTARepH}{\tensorShapeH/\mfmaNonKDim/\warpsPerCTAH}
\pgfmathsetmacro{\CTARepW}{\tensorShapeW/\mfmaNonKDim/\warpsPerCTAW}
\pgfmathsetmacro{\maxCTAId}{\CTARepH*\CTARepW-1}
\pgfmathsetmacro{\maxWaveId}{\warpsPerCTAH*\warpsPerCTAW-1}
\pgfmathsetmacro{\CTASizeH}{\warpsPerCTAH*\mfmaNonKDim}
\pgfmathsetmacro{\CTASizeW}{\warpsPerCTAW*\mfmaNonKDim}
%% CTA repetitions are visited row by row
\foreach \ctaId in {0,...,\maxCTAId}{
\pgfmathsetmacro{\ctaCoordH}{int(\ctaId/\CTARepW)}
\pgfmathsetmacro{\ctaCoordW}{mod(\ctaId,\CTARepW)}
\coordinate (CTA TL) at ($(TL)+(\ctaCoordW*\CTASizeW*\elem, -\ctaCoordH*\CTASizeH*\elem)$);
%% Draw a detailed view of wave0 in each CTA
\coordinate (block TL) at (CTA TL);
\drawBlockMFMALayoutLarge{\mfmaTrans}{0}
%% Wave ids advance along the width first
\foreach \waveId in {0,...,\maxWaveId}{
\pgfmathsetmacro{\waveCoordH}{int(\waveId/\warpsPerCTAW)}
\pgfmathsetmacro{\waveCoordW}{mod(\waveId,\warpsPerCTAW)}
\coordinate (block TL) at ($(CTA TL)+(\waveCoordW*\mfmaNonKDim*\elem, -\waveCoordH*\mfmaNonKDim*\elem)$);
%% Inside the loop, only draw a rectangle
\draw [ultra thin] (block TL) rectangle ++(\mfmaNonKDim*\elem, -\mfmaNonKDim*\elem)
node [scale=.7*\scale, pos=.5, fill=white, inner sep=0] {wave\waveId};
}
%% Draw the outline of each CTA rep
\draw [ultra thick] (CTA TL) rectangle ++(\CTASizeW*\elem, -\CTASizeH*\elem);
}
%% Restore the caller's TL
\coordinate (TL) at (old TL);
}
\newcommand{\drawMFMAOperand}[4]{
%%
%% Draw one mfma operand
%%
%% mfma op TL: pre defined coordinates of the top-left
%% \elem: pre defined variable
%% \scale: pre defined variable, used to scale the thread labels
%% \opColorAL, \opColorAR, \opColorBL, \opColorBR: pre defined colors of the
%% two thread groups of operand A and of operand B
%%
%% #1: mfmaNonKDim (stored in \nonKDim but not used inside this macro;
%% the loop below always visits 32 thread ids per group)
%% #2: kpack
%% #3: 0 for opA and 1 for opB
%% #4: verbose. 1 means draw tid in each vec; 0 means draw nothing
\pgfmathsetmacro{\nonKDim}{#1}
\pgfmathsetmacro{\kpack}{#2}
%% \opIdxA holds #3 (0 when drawing opA, 1 when drawing opB) and \opIdxB is
%% its complement; both act as 0/1 multipliers in the coordinates below
\pgfmathsetmacro{\opIdxA}{#3}
\pgfmathsetmacro{\opIdxB}{1-\opIdxA}
\pgfmathsetmacro{\verbose}{#4}
%% Pick the color pair of the operand being drawn
\ifthenelse{\opIdxA = 0}{
\def\opColorL{\opColorAL}
\def\opColorR{\opColorAR}
}{
\def\opColorL{\opColorBL}
\def\opColorR{\opColorBR}
}
%% Two groups of 32 threads, one color per group
\foreach \col/\tg in {\opColorL/0,\opColorR/1}{
\foreach \tid in {0,...,31} {
% \pgfmathsetmacro{\ratio}{\tid*2.5+15}
%% opA: each thread is a kpack-wide, 1-high rectangle, threads stacked downwards
%% opB: each thread is a 1-wide, kpack-high rectangle, threads placed left to right
\ifthenelse{\verbose=0}{
\draw [line width=0.005mm, fill=\col]
($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$)
rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA);
}{
%% Thread id printed in verbose mode: tid within the group + 32 * group
\pgfmathsetmacro{\drawTid}{int(\tid+\tg*32)}
\draw [line width=0.005mm, fill=\col]
($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$)
rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA)
node [pos=.5, scale=.35*\scale, rotate=90*\opIdxA] {t\drawTid};
}
}
}
}
\newcommand{\drawWaveOperand}[4]{
%%
%% Draw the part of the tensor that is one operand of the wave
%%
%% Op TL: pre defined coordinates of the top-left of the operand
%% \elem: pre defined variable
%%
%% #1: K
%% #2: mfmaNonKDim
%% #3: kpack
%% #4: 0 for opA and 1 for opB
%%
%% Side effect: (TL) is overwritten with (Op TL)
\pgfmathsetmacro{\K}{#1}
\pgfmathsetmacro{\nonKDim}{#2}
\pgfmathsetmacro{\kpack}{#3}
\pgfmathsetmacro{\opIdx}{#4}
\pgfmathsetmacro{\opIdxOther}{1-\opIdx}
\coordinate (TL) at (Op TL);
%% Each repetition covers 2*kpack elements along K
\pgfmathsetmacro{\numKRep}{\K/\kpack/2}
\pgfmathsetmacro{\maxKRepId}{\numKRep-1}
\foreach \repId in {0,...,\maxKRepId}{
%% Repetitions advance to the right for opA and downwards for opB
\coordinate (mfma op TL) at ($(TL)+(\repId*2*\kpack*\elem*\opIdxOther, -\repId*2*\kpack*\elem*\opIdx)$);
\drawMFMAOperand{\nonKDim}{\kpack}{\opIdx}{0}
%% Outline of this repetition: 2*kpack along K, nonKDim along the other dim
\draw [thick] (mfma op TL) rectangle
++(2*\kpack*\elem*\opIdxOther+\nonKDim*\opIdx*\elem, -\nonKDim*\opIdxOther*\elem-2*\kpack*\elem*\opIdx);
}
}
\newcommand{\drawDotOperands}[7]{
%%
%% Draw operand tensors of dot
%%
%% A TL and B TL: pre defined top-left coordinates of A and B tensor
%% \elem: pre defined variable
%%
%% #1: M
%% #2: N
%% #3: K
%% #4: MFMA nonKDim
%% #5: warpsPerCTA[0]
%% #6: warpsPerCTA[1]
%% #7: kpack
%%
%% Note: the per-wave tile size is written as the literal 32 below;
%% \mfmaNonKDim is only forwarded to \drawWaveOperand
\pgfmathsetmacro{\M}{#1}
\pgfmathsetmacro{\N}{#2}
\pgfmathsetmacro{\K}{#3}
\pgfmathsetmacro{\mfmaNonKDim}{#4}
\pgfmathsetmacro{\warpsPerCTAM}{#5}
\pgfmathsetmacro{\warpsPerCTAN}{#6}
\pgfmathsetmacro{\kpack}{#7}
%% operand A
%% CTA repetitions are stacked downwards along M
\pgfmathsetmacro{\CTARepM}{\M/\warpsPerCTAM/32}
\pgfmathsetmacro{\maxCTAIdM}{\CTARepM-1}
\pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM-1}
\foreach \ctaId in {0,...,\maxCTAIdM}{
\coordinate (CTA TL) at ($(A TL)+(0, -\ctaId*\warpsPerCTAM*32*\elem)$);
%% One K-wide, 32-high rectangle per wave
\foreach \waveId in {0,...,\maxWaveId}{
\coordinate (wave TL) at ($(CTA TL)+(0, -\waveId*32*\elem)$);
\draw [ultra thin] (wave TL) rectangle ++(\K*\elem, -32*\elem);
}
%% Only draw the detailed view of the first wave in CTA
\coordinate (Op TL) at (CTA TL);
\drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{0}
%% Draw the outline of each CTA rep
\draw [ultra thick] (CTA TL) rectangle ++(\K*\elem, -\warpsPerCTAM*32*\elem);
}
%% Outline of the whole A tensor
\draw [ultra thin] (A TL) rectangle ++(\K*\elem, -\M*\elem);
%% operand B
%% CTA repetitions are placed left to right along N
\pgfmathsetmacro{\CTARepN}{\N/\warpsPerCTAN/32}
\pgfmathsetmacro{\maxCTAIdN}{\CTARepN-1}
\pgfmathsetmacro{\maxWaveId}{\warpsPerCTAN-1}
\foreach \ctaId in {0,...,\maxCTAIdN}{
\coordinate (CTA TL) at ($(B TL)+(\ctaId*\warpsPerCTAN*32*\elem, 0)$);
%% One 32-wide, K-high rectangle per wave
\foreach \waveId in {0,...,\maxWaveId}{
\coordinate (wave TL) at ($(CTA TL)+(\waveId*32*\elem ,0)$);
\draw [ultra thin] (wave TL) rectangle ++(32*\elem, -\K*\elem);
}
%% Only draw the detailed view of the first wave in CTA
\coordinate (Op TL) at (CTA TL);
\drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{1}
%% Draw the outline of each CTA rep
\draw [ultra thick] (CTA TL) rectangle ++(\warpsPerCTAN*32*\elem, -\K*\elem);
}
%% Outline of the whole B tensor
\draw [ultra thin] (B TL) rectangle ++(\N*\elem, -\K*\elem);
}
\newcommand{\drawDot}[8]{
%%
%% Draw C = dot A, B
%% A is placed to the left of C and B above C
%%
%% C TL: pre defined top-left coordinates of the result tensor
%% \elem: pre defined variable
%% \scale: pre defined variable, used to scale the labels
%%
%% #1: M
%% #2: N
%% #3: K
%% #4: MFMA nonKDim
%% #5: warpsPerCTA[0]
%% #6: warpsPerCTA[1]
%% #7: 1 for mfma.trans, 0 for normal mfma
%% #8: kpack
\pgfmathsetmacro{\M}{#1}
\pgfmathsetmacro{\N}{#2}
\pgfmathsetmacro{\K}{#3}
\pgfmathsetmacro{\mfmaNonKDim}{#4}
\pgfmathsetmacro{\warpsPerCTAM}{#5}
\pgfmathsetmacro{\warpsPerCTAN}{#6}
\pgfmathsetmacro{\mfmaTrans}{#7}
\pgfmathsetmacro{\kpack}{#8}
%% \kdim = 2*kpack is the number printed by the kpack labels below
\pgfmathsetmacro{\kdim}{int(2*\kpack)}
%% Distance between C and each of its operands
\pgfmathsetmacro{\gap}{\elem*20}
\coordinate (A TL) at ($(C TL)+(-\gap-\K*\elem, 0)$);
\coordinate (B TL) at ($(C TL)+(0, \gap+\K*\elem)$);
%% The operands are drawn before the result tensor
\drawDotOperands{\M}{\N}{\K}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\kpack}
\drawTensorMFMALayout{\M}{\N}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\mfmaTrans}
%% Draw labels
\node [scale=\scale, above] at ($(A TL)+(.5*\K*\elem, 0)$) {K=\K};
\node [scale=\scale, above, rotate=90] at ($(A TL)+(0, -.5*\M*\elem)$) {M=\M};
\node [scale=\scale, above, rotate=90] at ($(B TL)+(0, -.5*\K*\elem)$) {K=\K};
\node [scale=\scale, above] at ($(B TL)+(.5*\N*\elem, 0)$) {N=\N};
\node [scale=\scale, above left] at (A TL) {A};
\node [scale=\scale, above left] at (B TL) {B};
\node [scale=\scale, above left] at (C TL) {C};
%% label nonKDim
\node [scale=.8*\scale, left] at ($(A TL)+(0, -.5*\mfmaNonKDim*\elem)$) {\mfmaNonKDim};
\node [scale=.8*\scale, above] at ($(B TL)+(.5*\mfmaNonKDim*\elem, 0)$) {\mfmaNonKDim};
%% label kpack
%% (the node is placed at offset kpack, i.e. the middle of a 2*kpack span)
\node [scale=.8*\scale, above] at ($(A TL)+(\kpack*\elem, 0)$) {\kdim};
\node [scale=.8*\scale, left] at ($(B TL)+(0, -\kpack*\elem)$) {\kdim};
}
%% Array of 16 color names.
%% The LDS/global memory drawing macros index it as \Colors[i] inside
%% \pgfmathsetmacro to pick a fill color.
\newcommand{\Colors}{{
"red",
"YellowGreen",
"blue",
"Maroon",
"orange",
"cyan",
"magenta",
"brown",
"teal",
"purple",
"gray",
"Green",
"BlueGreen",
"violet",
"olive",
"darkgray",
}}
\newcommand{\drawTensorLayoutGlobalMem}{
%%
%% Draw tensor layout in global memory without any swizzling
%%
%% TL: pre defined top-left coordinates of the tensor in global memory
%% \elem: pre defined variable
%% \scale: pre defined variable, used to scale the labels
%% \Colors: a pre defined array of 16 colors
%%
%% This macro takes no arguments; the following variables are expected
%% to be pre defined instead
%% \M: M
%% \K: K
%% \vec: number of elements in a group
\pgfmathsetmacro{\numVecK}{\K/\vec}
%% Detailed vecs are drawn for the first 16 rows only
\pgfmathsetmacro{\maxVecId}{16*\numVecK-1}
\pgfmathsetmacro{\drawM}{20}
%% Draw the outline of the tensor, but only \drawM (=20) rows of it
\draw (TL) rectangle ++(\K*\elem, -\drawM*\elem);
%% Draw detailed vec view of the tensor
\foreach \vecId in {0,...,\maxVecId}{
%% vec ids advance along K first
\pgfmathsetmacro{\vecCoordM}{int(\vecId/\numVecK)}
\pgfmathsetmacro{\vecCoordK}{mod(\vecId,\numVecK)}
\coordinate (vec TL) at ($(TL)+(\vecCoordK*\vec*\elem, -\vecCoordM*\elem)$);
%% The color depends on the vec's position along K and repeats every 16 vecs
\pgfmathsetmacro{\colorIdxK}{int(mod(\vecCoordK,16))}
%% \colorIdxM is computed but not used inside this macro
\pgfmathsetmacro{\colorIdxM}{mod(\vecCoordM,16)}
\pgfmathsetmacro{\vecColor}{\Colors[\colorIdxK]}
%% The color percentage drops by 40 for every 16 vecs along K
\pgfmathsetmacro{\ratio}{100-floor(\vecCoordK/16)*40}
%% Each vec is labelled with its row index
\draw [ultra thin, fill=\vecColor!\ratio!white] (vec TL) rectangle ++(\vec*\elem, -\elem)
node [pos=.5, scale=.6*\scale, white] {m\vecCoordM};
}
%% M and K dim
\node [scale=\scale, rotate=90, above] at ($(TL)+(0, -.5*\drawM*\elem-8*\elem)$) {M=\M};
\node [scale=.8*\scale, left] at ($(TL)+(0, -.5*16*\elem)$) {16};
\node [scale=\scale, above] at ($(TL)+(.5*\K*\elem, 0)$) {K=\K};
%% label for vecSize
%% An enlarged (by \vecR) row of \vec cells drawn above the tensor
\def\vecR{1.5}
\coordinate (vec TL) at ($(TL)+(-.25*\vec*\elem, 3*\elem*\vecR)$);
\pgfmathsetmacro{\maxVec}{\vec-1}
\foreach \vecId in {0,...,\maxVec}{
\draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR);
}
%% Dotted guide lines from the first vec of the tensor to the enlarged one
\draw [densely dotted] (TL) -- ($(vec TL)+(0, -\elem*\vecR)$);
\draw [densely dotted] ($(TL)+(\vec*\elem, 0)$) -- ($(vec TL)+(\vec*\elem*\vecR, -\elem*\vecR)$);
\node [scale=.8*\scale, above] at ($(vec TL)+(.5*\vec*\elem*\vecR, 0)$) {vec=\vec};
}
\newcommand{\drawLDSLayoutTritonSwizzling}[2]{
%%
%% Draw tensor layout in LDS with swizzling or padding
%%
%% TL: pre defined top-left coordinates at which the LDS view is drawn
%% \elem: pre defined variable
%% \Colors: a pre defined array of 16 colors
%% \scale: pre defined variable, used to scale all labels
%% \bitwiseXor: helper defined elsewhere in this file
%%
%% The following three variables are expected to be pre defined
%% \M: M
%% \K: K
%% \vec: number of elements in a group
%%
%% #1: hasSwizzle, 0 means no swizzling and no padding,
%% 1 means optimal swizzling
%% 2 means padding
%% #2: access mode, 0 means draw nothing, 1 means ds_read, 2 means ds_write
%% For ds_write access, the following variables are assumed to be pre defined
%% \sizePerThreadK
%% \sizePerThreadM
%% \threadsPerWarpK
\pgfmathsetmacro{\hasSwizzle}{#1}
\pgfmathsetmacro{\accessMode}{#2}
%% Number of vecs in one row (K elements) of the tensor
\pgfmathsetmacro{\numVecK}{\K/\vec}
%% Assuming fp16 data type
%% One LDS row is drawn as 64 elements wide (32 banks, 2 elements per bank)
\pgfmathsetmacro{\LDSK}{64}
%% Number of vecs in one LDS row
\pgfmathsetmacro{\numLDSVec}{\LDSK/\vec}
%% The swizzling phase advances once every \swizzleK elements
\pgfmathsetmacro{\swizzleK}{max(\LDSK, \K)}
%% Number of LDS rows needed for the whole M x K tensor (only used as a label)
\pgfmathsetmacro{\LDSM}{int(\M/\LDSK*\K)}
%% \maxVecId: id of the last vec to draw
%% \drawM: height, in elements, of the LDS outline
\ifthenelse{\accessMode = 2}{
%% \accessMode == 2, draw 8 rows
\pgfmathsetmacro{\maxVecId}{8*\numVecK-1}
\pgfmathsetmacro{\drawM}{8*\K/\LDSK+4}
}{
%% \accessMode == 0 or 1, draw 16 rows
\pgfmathsetmacro{\maxVecId}{16*\numVecK-1}
\pgfmathsetmacro{\drawM}{16*\K/\LDSK+4}
}
%% Parameters used for swizzling
\pgfmathsetmacro{\numVecSwizzleK}{\swizzleK/\vec}
%% perPhase = ceil(LDSK / K)
%% The number of the rows of the tensor that can share the same swizzling pattern
\pgfmathsetmacro{\perPhase}{ceil(\LDSK/\K)}
%% maxPhase: the total number of different swizzling patterns
\ifthenelse{\hasSwizzle=0}{
%% When swizzling is disabled
\pgfmathsetmacro{\maxPhase}{1}
}{
%% When vec is small enough, we want 16/perPhase different swizzling patterns
%% When vec is large, we can only have 64 / \vec different swizzling patterns at most
\pgfmathsetmacro{\maxPhase}{min(16/\perPhase,64/\vec)}
}
%% Draw the LDS
\draw (TL) rectangle ++(\LDSK*\elem, -\drawM*\elem);
%% Draw detailed vec view of LDS
\foreach \vecId in {0,...,\maxVecId}{
%% Row (M) and column (K) of this vec in the original tensor, in units of vecs
\pgfmathsetmacro{\vecCoordM}{int(\vecId/\numVecK)}
\pgfmathsetmacro{\vecCoordK}{int(mod(\vecId,\numVecK))}
%% Phase index before it is wrapped by \maxPhase
\pgfmathsetmacro{\rawPhase}{floor(\vecId/\numVecSwizzleK)}
%% vec color
%% The color is picked by the K coordinate; every 16 columns the tint is
%% lightened by 40 through \ratio
\pgfmathsetmacro{\colorIdxK}{int(mod(\vecCoordK,16))}
\pgfmathsetmacro{\colorIdxM}{mod(\vecCoordM,16)}
\pgfmathsetmacro{\ratio}{100-floor(\vecCoordK/16)*40}
\pgfmathsetmacro{\vecColor}{\Colors[\colorIdxK]}
%% old vec coordinates
\coordinate (vec TL) at ($(TL)+(\vecCoordK*\vec*\elem, -\vecCoordM*\elem)$);
%% new vec coordinates in LDS by swizzling
%% The following two conditions correspond to the relation between \LDSK and \K
\ifthenelse{\LDSK < \K}{
%% One tensor row spans several LDS rows
\pgfmathsetmacro{\vecLDSM}{\vecCoordM*\K/\LDSK+floor(\vecCoordK*\vec/\LDSK)}
\pgfmathsetmacro{\vecLDSK}{int(mod(\vecCoordK, \LDSK/\vec))}
}{
%% \perPhase tensor rows are packed side by side into one LDS row
\pgfmathsetmacro{\vecLDSM}{floor(\vecCoordM/\perPhase)}
\pgfmathsetmacro{\vecLDSK}{int(\vecCoordK+mod(\vecCoordM,\perPhase)*\numVecK)}
}
%%
%% Wrap the phase; \maxPhase is 1 when swizzling is disabled, so \phase is 0
\pgfmathsetmacro{\phase}{int(mod(\rawPhase, \maxPhase))}
%% Compute the swizzled col id
\pgfmathsetmacro{\vecLDSKSwizzled}{\bitwiseXor{\vecLDSK}{\phase}}
%% new vec coordinates in LDS by padding
%% One padding bank is inserted after every \numLDSVec vecs.
%% Each vec covers \vec/2 banks (2 elements per bank, see the fp16 assumption above)
\pgfmathsetmacro{\numPads}{floor(\vecId/\numLDSVec)}
\pgfmathsetmacro{\bankId}{\vec/2*\vecId+\numPads}
\pgfmathsetmacro{\vecPadM}{int(\bankId/32)}
\pgfmathsetmacro{\vecPadK}{int(mod(\bankId,32))}
\ifthenelse{\hasSwizzle = 2}{
%% vec coordinates by padding
\coordinate (new vec TL) at ($(TL)+(\vecPadK*2*\elem, -\vecPadM*\elem)$);
%% Bank id of the last bank covered by this vec; > 31 means it wraps to the next row
\pgfmathsetmacro{\tailBankId}{int(\vecPadK+\vec/2-1)}
}{
%% vec coordinates by swizzling
\coordinate (new vec TL) at ($(TL)+(\vecLDSKSwizzled*\vec*\elem, -\vecLDSM*\elem)$);
%% No wrapping without padding: 0 keeps the split branch below from being taken
\pgfmathsetmacro{\tailBankId}{0}
}
\ifthenelse{\hasSwizzle = 2 \AND \tailBankId > 31}{
%% The padded vec crosses the end of the row: draw it as two pieces,
%% \leftBanks banks at the end of this row and \nextBanks banks at the
%% start of the next row
\pgfmathsetmacro{\nextBanks}{\tailBankId-31}
\pgfmathsetmacro{\leftBanks}{\vec/2 - \nextBanks}
\draw [ultra thin, fill=\vecColor!\ratio!white] (new vec TL) rectangle ++(\leftBanks*2*\elem, -\elem)
node [pos=.5, scale=.6*\scale, white] {m\vecCoordM};
\draw [ultra thin, fill=\vecColor!\ratio!white] ($(TL)+(0, -\vecPadM*\elem-\elem)$)
rectangle ++(\nextBanks*2*\elem, -\elem) node [pos=.5, scale=.6*\scale, white] {m\vecCoordM};
}{
%% The vec fits in the row: draw it as a single rectangle labeled with its tensor row
\draw [ultra thin, fill=\vecColor!\ratio!white] (new vec TL) rectangle ++(\vec*\elem, -\elem)
node [pos=.5, scale=.6*\scale, white] {m\vecCoordM};
}
%% ds_read
%% Highlight the elements the first 16 threads access in the first cycle
%% This is used to visualize bank conflicts
\ifthenelse{\accessMode = 1}{
%% Only the first vec of every tensor row is marked (white box with a cross)
\ifthenelse{\vecCoordK = 0}{
\draw [fill=white] (new vec TL) rectangle ++(\elem, -\elem);
\draw (new vec TL) -- ++(\elem, -\elem);
\draw ($(new vec TL)+(0, -\elem)$) -- ++(\elem, \elem);
}{}
}{}
%% Draw ds_write pattern
\ifthenelse{\accessMode = 2}{
%% First compute the coverage of the first 16 threads
%% NOTE(review): \covK and \covM are not wrapped in int(), unlike \vecInThread;
%% confirm the \ifthenelse comparisons below behave as intended for these values
\pgfmathsetmacro{\covK}{min(16, \threadsPerWarpK)*\sizePerThreadK/\vec}
\pgfmathsetmacro{\covM}{ceil(16/\threadsPerWarpK)*\sizePerThreadM}
%% Check conditions for the first 16 threads
%% \vecInThread = 0 selects the first vec of each thread along K
\pgfmathsetmacro{\vecInThread}{int(mod(\vecCoordK, \sizePerThreadK/\vec))}
\ifthenelse{\vecInThread=0}{
\ifthenelse{\vecCoordK<\covK \AND \vecCoordM<\covM}{
\draw [fill=white] (new vec TL) rectangle ++(\elem, -\elem);
\draw (new vec TL) -- ++(\elem, -\elem);
\draw ($(new vec TL)+(0, -\elem)$) -- ++(\elem, \elem);
}{}
}{}
}{}
%% Label the phase of each line if swizzling is used
\ifthenelse{\hasSwizzle = 2}{}{
%% The label is attached to the vec that lands in the last column of the LDS row
\pgfmathsetmacro{\lastVecId}{int(64/\vec)-1}
\ifthenelse{\vecLDSKSwizzled = \lastVecId}{
\draw [ultra thin] ($(new vec TL)+(\vec*\elem, -.5*\elem)$) -- ++(\elem, 0)
node [scale=.6*\scale, right] {\phase};
}{}
}
}
%% Draw boundary of 32 banks
%% Assume fp16 data type
\foreach \bank in {0,...,31}{
\draw [ultra thin, gray] ($(TL)+(\bank*2*\elem, 0)$) -- ++(0, 2*\elem)
node [scale=.6*\scale, right, black] {\bank};
}
%% Closing boundary to the right of bank 31
\draw [ultra thin, gray] ($(TL)+(32*2*\elem, 0)$) -- ++(0, 2*\elem);
\node [scale=.6*\scale, left, black] at ($(TL)+(0, 2*\elem)$) {bank id};
\node [scale=\scale, above] at ($(TL)+(.5*\LDSK*\elem, 3*\elem)$) {LDS 32 banks};
\node [scale=\scale, rotate=90, above] at ($(TL)+(0, -.5*\drawM*\elem)$) {LDSM=\LDSM};
%% label phase if swizzling is used
\ifthenelse{\hasSwizzle = 2}{}{
\node [scale=.6*\scale, above right] at($(TL)+(32*2*\elem, 0)$) {phase};
}
}
\newcommand{\drawMFMAInstr}[3]{
%%
%% Draw layout of mfma instructions with tid labeled
%%
%% C TL: pre defined top-left coordinates of the output matrix
%% \elem: pre defined variable
%% \scale: pre defined variable, used to scale all labels
%%
%% #1: mfmaNonKDim
%% #2: kpack
%% #3: mfmaTrans
%%
%% The operands and the output are drawn by \drawMFMAOperand and
%% \drawBlockMFMALayoutLarge, which are defined elsewhere in this file and
%% pick up their position from the coordinates (mfma op TL) and (block TL).
\pgfmathsetmacro{\mfmaNonKDim}{#1}
\pgfmathsetmacro{\kpack}{#2}
\pgfmathsetmacro{\mfmaTrans}{#3}
\pgfmathsetmacro{\nonTrans}{1-#3}
\pgfmathsetmacro{\gap}{\elem*5}
%% Operand to the left of the output; it is pushed further left when
%% mfmaTrans is 0 to leave room for the vertical vec=4 box drawn below
\coordinate (mfma opA TL) at ($(C TL)+(-.5*\gap-1.2*\nonTrans*\gap-2*\kpack*\elem, 0)$);
\coordinate (mfma op TL) at (mfma opA TL);
\drawMFMAOperand{\mfmaNonKDim}{\kpack}{0}{1}
%% Operand above the output; it is pushed further up when mfmaTrans is 1
%% to leave room for the horizontal vec=4 box drawn below
\coordinate (mfma op TL) at ($(C TL)+(0, 1.5*\gap+.5*\mfmaTrans*\gap+2*\kpack*\elem)$);
\drawMFMAOperand{\mfmaNonKDim}{\kpack}{1}{1}
%% Output matrix
\coordinate (block TL) at (C TL);
\drawBlockMFMALayoutLarge{\mfmaTrans}{1}
%% Draw labels
%% \vecR: enlargement ratio of the zoomed-in vec boxes
\def\vecR{1.5}
%% Zoomed-in view of the first vec of the left operand: kpack boxes in a row
\coordinate (vec TL) at ($(mfma opA TL)+(-.25*\kpack*\elem, 3*\elem*\vecR)$);
\pgfmathsetmacro{\maxVec}{\kpack-1}
\foreach \vecId in {0,...,\maxVec}{
\draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR);
}
\draw [densely dotted] (mfma opA TL) -- ($(vec TL)+(0, -\elem*\vecR)$);
\draw [densely dotted] ($(mfma opA TL)+(\kpack*\elem, 0)$) -- ($(vec TL)+(\kpack*\elem*\vecR, -\elem*\vecR)$);
\node [scale=.8*\scale, above] at ($(vec TL)+(.5*\kpack*\elem*\vecR, 0)$) {vec=\kpack};
%% Zoomed-in view of the first vec of the top operand: kpack boxes in a column
\coordinate (vec TL) at ($(mfma op TL)+(-3*\elem*\vecR, .25*\kpack*\elem)$);
\foreach \vecId in {0,...,\maxVec}{
\draw ($(vec TL)+(0, -\vecId*\elem*\vecR)$) rectangle ++(\elem*\vecR, -\elem*\vecR);
}
\draw [densely dotted] (mfma op TL) -- ($(vec TL)+(\elem*\vecR,0)$);
\draw [densely dotted] ($(mfma op TL)+(0, -\kpack*\elem)$) -- ($(vec TL)+(\elem*\vecR, -\kpack*\elem*\vecR)$);
\node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*\kpack*\elem*\vecR)$) {vec=\kpack};
\node [scale=\scale, below] at ($(block TL)+(.5*\mfmaNonKDim*\elem,-\mfmaNonKDim*\elem)$) {outC};
%% Operand names and the zoomed-in view of the output vec depend on mfmaTrans
\ifthenelse{\mfmaTrans=0}{
%% mfmaTrans = 0: left operand is opA, top operand is opB,
%% and the vec=4 box of the output is drawn as a column to the left of the output
\node [scale=\scale, below] at ($(mfma opA TL)+(\kpack*\elem, -\mfmaNonKDim*\elem)$) {opA};
\node [scale=\scale, above] at (mfma op TL) {opB};
\coordinate (vec TL) at ($(block TL)+(-3*\elem-\elem*\vecR, .25*4*\elem)$);
\foreach \vecId in {0,1,2,3}{
\draw ($(vec TL)+(0, -\vecId*\elem*\vecR)$) rectangle ++(\elem*\vecR, -\elem*\vecR);
}
\draw [densely dotted] (block TL) -- ++(-3*\elem, .25*4*\elem);
\draw [densely dotted] ($(block TL)+(0, -4*\elem)$) -- ++(-3*\elem, -.25*4*\elem);
\node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*4*\elem*\vecR)$) {vec=4};
\node [scale=.8*\scale, above, align=center] at ($(block TL)+(16*\elem, 0)$) {mfmaLayout\\trans=False};
}{
%% mfmaTrans != 0: the operand names are swapped,
%% and the vec=4 box of the output is drawn as a row above the output
\node [scale=\scale, below] at ($(mfma opA TL)+(\kpack*\elem, -\mfmaNonKDim*\elem)$) {opB};
\node [scale=\scale, above] at (mfma op TL) {opA};
\coordinate (vec TL) at ($(block TL)+(-.25*4*\elem, 3*\elem+\elem*\vecR)$);
\foreach \vecId in {0,1,2,3}{
\draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR);
}
\draw [densely dotted] (block TL) -- ++(-.25*4*\elem, 3*\elem);
\draw [densely dotted] ($(block TL)+(4*\elem, 0)$) -- ++(.25*4*\elem, 3*\elem);
\node [scale=.8*\scale, above] at ($(vec TL)+(.5*4*\elem*\vecR, 0)$) {vec=4};
\node [scale=.8*\scale, above, align=center] at ($(block TL)+(16*\elem, 0)$) {mfmaLayout\\trans=True};
}
}
\newcommand{\drawWMMAOperand}[3]{
  %%
  %% Draw the layout of one operand of WMMA instruction
  %%
  %% #1: opIdx. 0 for opA, 1 for opB
  %% #2: verbose. 0 means draw the vecs without tid labels;
  %%     any other value means draw the tids in each vec
  %% #3: mode. 0 for w32, 1 for w64
  %%
  %% wmma op TL: pre defined top-left coordinates of the operand matrix
  %% \elem: pre defined element size
  %% \scale: pre defined variable, used to scale all labels
  %%
  %% The operand is drawn as 16 vecs of 16 elements each. For opA the vecs
  %% are rows stacked downwards; for opB they are columns placed left to
  %% right with the label rotated by 90 degrees.
  \pgfmathsetmacro{\isOpB}{#1}
  \pgfmathsetmacro{\isOpA}{1-\isOpB}
  \pgfmathsetmacro{\verbose}{#2}
  \pgfmathsetmacro{\isWLarge}{#3}
  \foreach \row in {0,...,15}{
    %% Tint gets stronger with the vec index
    \pgfmathsetmacro{\ratio}{\row*5+15}
    \coordinate (vec TL) at ($(wmma op TL)+(\row*\isOpB*\elem, -\row*\elem*\isOpA)$);
    %% Thread ids labeled on this vec: 2 of them for w32, 4 of them for w64
    \pgfmathsetmacro{\tidone}{int(\row+16)}
    \ifthenelse{\isWLarge=1}{
      \pgfmathsetmacro{\tidtwo}{int(\row+32)}
      \pgfmathsetmacro{\tidthree}{int(\row+48)}
      \def\tidLabel{t\row, t\tidone, t\tidtwo, t\tidthree}
    }{
      \def\tidLabel{t\row, t\tidone}
    }
    %% Honor the verbose argument: no tid labels when it is 0
    \ifthenelse{\verbose=0}{\def\tidLabel{}}{}
    \draw [line width=0.005mm, fill=brown!\ratio!white] (vec TL)
      rectangle ++(16*\elem*\isOpA+\elem*\isOpB, -\elem*\isOpA-16*\elem*\isOpB)
      node [scale=0.4*\scale, pos=.5, rotate=90*\isOpB] {\tidLabel};
  }
}
\newcommand{\drawWMMAResult}[2]{
  %%
  %% Draw layout of WMMA result tensor
  %%
  %% #1: verbose. 0 means draw the elements without tid labels;
  %%     any other value means draw the tid in each element
  %% #2: mode. 0 for w32, 1 for w64
  %%
  %% C TL: pre defined top-left coordinates of the result matrix
  %% \elem: pre defined element size
  %% \scale: pre defined variable, used to scale all labels
  %% \Colors: pre defined array of colors
  \pgfmathsetmacro{\verbose}{#1}
  \pgfmathsetmacro{\isWLarge}{#2}
  %% The result is a 16x16 tile, visited element by element in row-major order
  \pgfmathsetmacro{\numElem}{256}
  \pgfmathsetmacro{\maxElemId}{\numElem-1}
  \foreach \elemId in {0,...,\maxElemId}{
    %% figure out the rowID
    \pgfmathsetmacro{\rowId}{floor(\elemId/16)}
    %% figure out the colID
    \pgfmathsetmacro{\colId}{mod(\elemId,16)}
    %% figure out the tid: element id modulo 64 for w64, modulo 32 for w32
    %% \tid is the integer used in the label, \laneId is used to pick the color
    \ifthenelse{\isWLarge=1}{
      \pgfmathsetmacro{\tid}{int(mod(\elemId,64))}
      \pgfmathsetmacro{\laneId}{mod(\elemId,64)}
    }{
      \pgfmathsetmacro{\tid}{int(mod(\elemId,32))}
      \pgfmathsetmacro{\laneId}{mod(\elemId,32)}
    }
    %% figure out the color: one color per group of 16 lanes
    \pgfmathsetmacro{\colorId}{floor(\laneId/16)}
    \pgfmathsetmacro{\vecColor}{\Colors[\colorId]}
    %% Honor the verbose argument: no tid label when it is 0
    \ifthenelse{\verbose=0}{\def\tidLabel{}}{\def\tidLabel{t\tid}}
    %% Coordinate
    \coordinate (vec TL) at ($(C TL)+(\colId*\elem, -\rowId*\elem)$);
    \draw [line width=0.005mm, fill=\vecColor!60!white] (vec TL) rectangle ++(\elem, -\elem)
      node [scale=.4*\scale, pos=.5] {\tidLabel};
  }
}
\newcommand{\drawWMMAInstr}[2]{
  %%
  %% Draw wmma instruction layouts 16x16x16
  %%
  %% #1: mode. 0 for w32, 1 for w64
  %% #2: verbose. Forwarded to \drawWMMAOperand and \drawWMMAResult
  %%     (1 means draw tid in each vec; 0 means draw nothing)
  %%
  %% C TL: pre defined top-left coordinates of output matrix
  %% \elem: pre defined element size
  %% \scale: pre defined variable, used to scale all labels
  \pgfmathsetmacro{\isWLarge}{#1}
  \pgfmathsetmacro{\verbose}{#2}
  %% Spacing between the operands and the output
  \pgfmathsetmacro{\gap}{\elem*2}
  %% Operand A sits to the left of C. Its corner is saved in (wmma opA TL)
  %% because (wmma op TL) is reused for operand B right after.
  \coordinate (wmma op TL) at ($(C TL)+(-\gap-16*\elem, 0)$);
  \coordinate (wmma opA TL) at (wmma op TL);
  \drawWMMAOperand{0}{\verbose}{\isWLarge}
  %% Operand B sits above C
  \coordinate (wmma op TL) at ($(C TL)+(0, \gap+16*\elem)$);
  \drawWMMAOperand{1}{\verbose}{\isWLarge}
  %% The result gets the same verbose setting as the operands
  \drawWMMAResult{\verbose}{\isWLarge}
  %% labels
  %% From here on \gap is the offset of the dimension arrows from the matrices
  \pgfmathsetmacro{\gap}{\elem}
  \node [above left, scale=\scale] at (wmma opA TL) {A};
  \node [above left, scale=\scale] at (wmma op TL) {B};
  \node [above right, scale=\scale] at ($(C TL)+(16*\elem, 0)$) {C};
  %% A k dim: horizontal arrows above A
  \node [scale=.8*\scale] (k dim A) at ($(wmma opA TL)+(8*\elem,\gap)$) {16};
  \draw [->, >=stealth] (k dim A.west) -- ($(wmma opA TL)+(0, \gap)$);
  \draw [->, >=stealth] (k dim A.east) -- ($(wmma opA TL)+(16*\elem, \gap)$);
  %% B K dim: vertical arrows to the left of B
  \node [scale=.8*\scale, rotate=90] (k dim B) at ($(wmma op TL)+(-\gap, -8*\elem)$) {16};
  \draw [->, >=stealth] (k dim B.east) -- ($(wmma op TL)+(-\gap, 0)$);
  \draw [->, >=stealth] (k dim B.west) -- ($(wmma op TL)+(-\gap, -16*\elem)$);
  %% Horizontal arrows below C (node name: m dim)
  \node [scale=.8*\scale] (m dim) at ($(C TL)+(8*\elem,-16*\elem-\gap)$) {16};
  \draw [->, >=stealth] (m dim.west) -- ($(C TL)+(0, -16*\elem-\gap)$);
  \draw [->, >=stealth] (m dim.east) -- ($(C TL)+(16*\elem, -16*\elem-\gap)$);
  %% Vertical arrows to the right of C (node name: n dim)
  \node [scale=.8*\scale, rotate=-90] (n dim) at ($(C TL)+(16*\elem+\gap, -8*\elem)$) {16};
  \draw [->, >=stealth] (n dim.west) -- ($(C TL)+(16*\elem+\gap, 0)$);
  \draw [->, >=stealth] (n dim.east) -- ($(C TL)+(16*\elem+\gap, -16*\elem)$);
}

View File

@@ -1978,11 +1978,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// PTX: nvvm.shfl.sync bfly
// PTX: nvvm.barrier0
// GCN-COUNT-4: ds_swizzle_b32
// GCN-COUNT-4: rocdl.ds_swizzle %{{.*}} : (i32, i32) -> i32
// GCN: llvm.store
// GCN: rocdl.barrier
// GCN: llvm.load
// GCN-COUNT-2: ds_swizzle_b32
// GCN-COUNT-2: rocdl.ds_swizzle %{{.*}} : (i32, i32) -> i32
// GCN: llvm.store
// GCN: rocdl.barrier
// GCN: llvm.load