mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge branch 'triton-mlir' into ifu-231117
This commit is contained in:
@@ -740,6 +740,24 @@ The data will be distributed between threads as follows:
|
||||
[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ]
|
||||
[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ]
|
||||
[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ]
|
||||
|
||||
Example 3:
|
||||
Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and nonKDim set to 4.
|
||||
The data will be distributed between threads as follows (note that each element is duplicated in 16 threads):
|
||||
|
||||
M N -> wave 0 wave 2
|
||||
| --------------------------/\-------------------------- ------------------------------/\------------------------------
|
||||
V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ]
|
||||
[ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ]
|
||||
[ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ]
|
||||
[ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ]
|
||||
|
||||
wave 1 wave 3
|
||||
--------------------------/\-------------------------- ------------------------------/\------------------------------
|
||||
[ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
|
||||
[ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
|
||||
[ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
|
||||
[ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
|
||||
@@ -426,7 +426,7 @@ bool supportMMA(triton::DotOp op, int version) {
|
||||
#ifdef USE_ROCM
|
||||
static bool supportMFMAGranularity(int m, int n, int k) {
|
||||
// these limitations are dtype dependent, in future we may relax them
|
||||
const static std::pair<int, int> mfmaTypes[2] = {{32, 8}, {16, 16}};
|
||||
const static std::pair<int, int> mfmaTypes[] = {{32, 8}, {16, 16}, {4, 64}};
|
||||
for (const auto &mfmaType : mfmaTypes) {
|
||||
auto [granularityMN, granularityK] = mfmaType;
|
||||
if (m % granularityMN != 0 || n % granularityMN != 0)
|
||||
|
||||
@@ -441,7 +441,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
const SharedMemoryObject &smemObj) {
|
||||
auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
|
||||
auto nonKDim = mfmaLayout.getNonKDim();
|
||||
assert(nonKDim == 32 || nonKDim == 16);
|
||||
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
|
||||
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
|
||||
|
||||
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
@@ -587,7 +587,7 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
|
||||
const SharedMemoryObject &smemObj) {
|
||||
auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
|
||||
auto nonKDim = mfmaLayout.getNonKDim();
|
||||
assert(nonKDim == 32 || nonKDim == 16);
|
||||
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
|
||||
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
|
||||
|
||||
auto bTensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
|
||||
@@ -80,115 +80,86 @@ struct DotOpMFMAConversionHelper {
|
||||
return rewriter.create<arith::TruncIOp>(loc, i32_ty, tid);
|
||||
}
|
||||
|
||||
Value generateMFMAOp(MFMAInstrDescr mfmaDescr, Value valA, Value valB,
|
||||
Value valC) const {
|
||||
Value generateMFMA32Op(MatrixCoreType coreType, Value valA, Value valB,
|
||||
Value valC) const {
|
||||
auto resType = valC.getType();
|
||||
Value zeroFlag = i32_val(0);
|
||||
switch (mfmaDescr.coreType) {
|
||||
switch (coreType) {
|
||||
case MatrixCoreType::FP32_FP8_FP8_FP32:
|
||||
if (mfmaDescr.size == 16) {
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x32_fp8_fp8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
} else {
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_fp8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
}
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_fp8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_FP8_BF8_FP32:
|
||||
if (mfmaDescr.size == 16) {
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x32_fp8_bf8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
} else {
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_bf8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
}
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_bf8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF8_FP8_FP32:
|
||||
if (mfmaDescr.size == 16) {
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x32_bf8_fp8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
} else {
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_fp8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
}
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_fp8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF8_BF8_FP32:
|
||||
if (mfmaDescr.size == 16) {
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x32_bf8_bf8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
} else {
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_bf8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
}
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_bf8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_FP16_FP16_FP32:
|
||||
if (mfmaDescr.size == 16) {
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x16f16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
} else {
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x8f16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
}
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x8f16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF16_BF16_FP32:
|
||||
if (mfmaDescr.size == 16) {
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x8bf16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
} else {
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x4bf16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
}
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x4bf16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
|
||||
if (mfmaDescr.size == 16) {
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x16bf16_1k>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
} else {
|
||||
assert(mfmaDescr.size == 32);
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x8bf16_1k>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
}
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x8bf16_1k>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_FP32_FP32_FP32:
|
||||
if (mfmaDescr.size == 16) {
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x4f32>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
} else {
|
||||
assert(mfmaDescr.size == 32);
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x2f32>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
}
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x2f32>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::INT32_INT8_INT8_INT32:
|
||||
if (mfmaDescr.size == 16) {
|
||||
return rewriter.create<ROCDL::mfma_i32_16x16x16i8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
} else {
|
||||
return rewriter.create<ROCDL::mfma_i32_32x32x8i8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
}
|
||||
return rewriter.create<ROCDL::mfma_i32_32x32x8i8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::INT32_INT8_INT8_INT32_CDNA3:
|
||||
if (mfmaDescr.size == 16) {
|
||||
return rewriter.create<ROCDL::mfma_i32_16x16x32_i8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
} else {
|
||||
return rewriter.create<ROCDL::mfma_i32_32x32x16_i8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
}
|
||||
return rewriter.create<ROCDL::mfma_i32_32x32x16_i8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP64_FP64_FP64_FP64:
|
||||
return rewriter.create<ROCDL::mfma_f64_16x16x4f64>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
default:
|
||||
llvm::report_fatal_error("MFMA 32x32 data type not supported");
|
||||
}
|
||||
}
|
||||
|
||||
Value generateMFMA16Op(MatrixCoreType coreType, Value valA, Value valB,
|
||||
Value valC) const {
|
||||
auto resType = valC.getType();
|
||||
Value zeroFlag = i32_val(0);
|
||||
switch (coreType) {
|
||||
case MatrixCoreType::FP32_FP16_FP16_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x16f16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF16_BF16_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x8bf16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x16bf16_1k>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_FP32_FP32_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x4f32>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::INT32_INT8_INT8_INT32:
|
||||
return rewriter.create<ROCDL::mfma_i32_16x16x16i8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP64_FP64_FP64_FP64:
|
||||
assert(mfmaDescr.size == 16);
|
||||
return rewriter.create<ROCDL::mfma_f64_16x16x4f64>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
@@ -197,6 +168,72 @@ struct DotOpMFMAConversionHelper {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates one 4x4 MFMA operation for the requested matrix core type.
///
/// \param coreType selects which ROCDL 4x4 MFMA operation is created.
/// \param valA first multiplicand operand of the operation.
/// \param valB second multiplicand operand of the operation.
/// \param valC accumulator operand; its type is also used as the result type.
/// \return the value produced by the created operation. Core types that have
///         no case below abort through llvm::report_fatal_error.
Value generateMFMA4Op(MatrixCoreType coreType, Value valA, Value valB,
                      Value valC) const {
  Type accTy = valC.getType();
  // Every operation below receives this constant as its last three operands.
  Value zero = i32_val(0);

  switch (coreType) {
  case MatrixCoreType::FP32_FP16_FP16_FP32:
    return rewriter.create<ROCDL::mfma_f32_4x4x4f16>(
        loc, TypeRange{accTy},
        ValueRange{valA, valB, valC, zero, zero, zero});
  case MatrixCoreType::FP32_BF16_BF16_FP32:
    return rewriter.create<ROCDL::mfma_f32_4x4x2bf16>(
        loc, TypeRange{accTy},
        ValueRange{valA, valB, valC, zero, zero, zero});
  case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
    return rewriter.create<ROCDL::mfma_f32_4x4x4bf16_1k>(
        loc, TypeRange{accTy},
        ValueRange{valA, valB, valC, zero, zero, zero});
  case MatrixCoreType::FP32_FP32_FP32_FP32:
    return rewriter.create<ROCDL::mfma_f32_4x4x1f32>(
        loc, TypeRange{accTy},
        ValueRange{valA, valB, valC, zero, zero, zero});
  case MatrixCoreType::INT32_INT8_INT8_INT32:
    return rewriter.create<ROCDL::mfma_i32_4x4x4i8>(
        loc, TypeRange{accTy},
        ValueRange{valA, valB, valC, zero, zero, zero});
  default:
    llvm::report_fatal_error("MFMA4 data type not supported");
  }
}
|
||||
|
||||
/// Creates the MFMA operation described by \p mfmaDescr.
///
/// The descriptor's `size` field picks the 32, 16 or 4 generator; its
/// `coreType` field is forwarded to that generator unchanged.
///
/// \param mfmaDescr instruction descriptor (`size` and `coreType` are read).
/// \param valA first multiplicand operand.
/// \param valB second multiplicand operand.
/// \param valC accumulator operand.
/// \return the value returned by the selected generator. Any other size
///         aborts through llvm::report_fatal_error.
Value generateMFMAOp(MFMAInstrDescr mfmaDescr, Value valA, Value valB,
                     Value valC) const {
  const auto instrSize = mfmaDescr.size;
  if (instrSize == 32)
    return generateMFMA32Op(mfmaDescr.coreType, valA, valB, valC);
  if (instrSize == 16)
    return generateMFMA16Op(mfmaDescr.coreType, valA, valB, valC);
  if (instrSize == 4)
    return generateMFMA4Op(mfmaDescr.coreType, valA, valB, valC);
  llvm::report_fatal_error("MFMA nonkDim size is not supported");
}
|
||||
|
||||
/// Returns how many sub-matrices are associated with one MFMA instruction of
/// the given non-K dimension: 1 for nonKDim 32 and 16, 16 for nonKDim 4.
///
/// \param elementType operand element type; only inspected (through asserts)
///        for nonKDim == 4, where widths above 32 bits and non-integer 8-bit
///        types are rejected.
/// \param nonKDim non-K dimension of the MFMA instruction (32, 16 or 4).
/// \return the number of sub-matrices. Any other nonKDim aborts through
///         llvm::report_fatal_error.
int getNumSubmatrices(Type elementType, int nonKDim) const {
  switch (nonKDim) {
  case 32:
  case 16:
    return 1;
  case 4:
    assert(elementType.getIntOrFloatBitWidth() <= 32 &&
           "fp64 is not supported yet");
    // Parenthesized so the message string is attached to the whole condition
    // rather than to its second operand only.
    assert((elementType.getIntOrFloatBitWidth() != 8 ||
            elementType.isInteger(8)) &&
           "fp8 is not supported yet");
    return 16;
  default:
    llvm::report_fatal_error("unsupported nonKDim in MFMA dot");
  }
}
|
||||
|
||||
// TODO unify this function with Utility.cpp:supportMFMATypes
|
||||
static MatrixCoreType getMatrixCoreTypeFromDot(DotOp op) {
|
||||
auto aOperandTy = op.getA().getType();
|
||||
@@ -223,21 +260,21 @@ struct DotOpMFMAConversionHelper {
|
||||
if (aElemTy.isBF16()) {
|
||||
auto nonKDim = mfmaEncoding.getNonKDim();
|
||||
auto kWidth = dotOpEncoding.getKWidth();
|
||||
if ((nonKDim == 32 && kWidth == 4) || (nonKDim == 16 && kWidth == 4)) {
|
||||
if ((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 4) {
|
||||
return MatrixCoreType::FP32_BF16_BF16_FP32_1K;
|
||||
} else {
|
||||
assert((nonKDim == 32 && kWidth == 2) ||
|
||||
(nonKDim == 16 && kWidth == 2));
|
||||
(nonKDim == 16 && kWidth == 2) || (nonKDim == 4 && kWidth == 2));
|
||||
return MatrixCoreType::FP32_BF16_BF16_FP32;
|
||||
}
|
||||
}
|
||||
if (aElemTy.isInteger(8)) {
|
||||
auto nonKDim = mfmaEncoding.getNonKDim();
|
||||
auto kWidth = dotOpEncoding.getKWidth();
|
||||
if ((nonKDim == 32 && kWidth == 8) || (nonKDim == 16 && kWidth == 8)) {
|
||||
if ((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 8) {
|
||||
return MatrixCoreType::INT32_INT8_INT8_INT32_CDNA3;
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
assert((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 4);
|
||||
return MatrixCoreType::INT32_INT8_INT8_INT32;
|
||||
}
|
||||
}
|
||||
@@ -255,11 +292,81 @@ struct DotOpMFMAConversionHelper {
|
||||
return descr;
|
||||
}
|
||||
|
||||
/// Optionally sums an accumulator across the sub-blocks of a wave and/or
/// zeroes it in every lane outside the first sub-block.
///
/// The 64 lanes of a wave are treated as \p numSubBlocks groups of
/// `64 / numSubBlocks` consecutive lanes.
///
/// \param numSubBlocks number of sub-blocks per wave; must be a power of two
///        in the range [1, 64]. For 1 the accumulator is returned untouched.
/// \param acc accumulator; must be a vector of 32-bit elements.
/// \param reduceSubBlocks when set, every scalar of \p acc is combined (`add`
///        for i32 elements, `fadd` otherwise) with the value obtained from
///        shflSync, using a stride that starts at the sub-block size and
///        doubles until it reaches the wave size.
/// \param zeroSubBlocks when set, every scalar is replaced by zero in lanes
///        whose id within the wave is not below the current sub-block size.
/// \return a vector of the same type as \p acc holding the processed scalars.
///
/// NOTE: the reduction loop advances the sub-block size up to the wave size,
/// so if both flags are set the zeroing comparison is made against the wave
/// size.
Value processSubBlocks(int numSubBlocks, Value acc, bool reduceSubBlocks,
                       bool zeroSubBlocks) const {
  constexpr int waveSize = 64;
  // 0 would divide by zero below; more sub-blocks than lanes would give a
  // sub-block size of 0 and a reduction loop that never terminates.
  assert(numSubBlocks > 0 && numSubBlocks <= waveSize &&
         "numSubBlocks is out of range");
  assert((numSubBlocks & (numSubBlocks - 1)) == 0 &&
         "numSubBlocks in not pow 2!");
  if (numSubBlocks == 1)
    return acc;
  int subBlockSize = waveSize / numSubBlocks;
  // Lane id inside the wave: thread id modulo the wave size.
  Value laneId = getThreadId();
  laneId = and_(laneId, i32_val(waveSize - 1));
  auto vecTy = dyn_cast<VectorType>(acc.getType());
  assert(vecTy && "accumulator is expected to be a vector");
  auto elemType = vecTy.getElementType();
  assert(elemType.getIntOrFloatBitWidth() == 32);
  int numScalars = vecTy.getNumElements();

  // Unpack the vector so each scalar can be shuffled / selected on its own.
  std::vector<Value> accScalar(numScalars);
  for (int i = 0; i < numScalars; ++i)
    accScalar[i] = extract_element(elemType, acc, i32_val(i));

  if (reduceSubBlocks) {
    while (subBlockSize < waveSize) {
      for (int i = 0; i < numScalars; ++i) {
        Value other_acc =
            mlir::LLVM::shflSync(loc, rewriter, accScalar[i], subBlockSize);
        if (elemType.isInteger(32))
          accScalar[i] = add(accScalar[i], other_acc);
        else
          accScalar[i] = fadd(accScalar[i], other_acc);
      }
      subBlockSize *= 2;
    }
  }
  if (zeroSubBlocks) {
    Value zero;
    if (elemType.isInteger(32))
      zero = i32_val(0);
    else
      zero = f32_val(0.0);
    // Keep the value only in lanes below the sub-block size.
    auto cond = icmp_ult(laneId, i32_val(subBlockSize));
    for (int i = 0; i < numScalars; ++i)
      accScalar[i] = select(cond, accScalar[i], zero);
  }

  // Pack the processed scalars back into a vector of the original type.
  Value reducedAcc = undef(vecTy);
  for (int i = 0; i < numScalars; ++i)
    reducedAcc = insert_element(vecTy, reducedAcc, accScalar[i], i32_val(i));
  return reducedAcc;
}
|
||||
|
||||
/// @brief MFMA 4x4 is computes 16 matrix mupliplications, this functions adds
|
||||
/// these 16 matrices to get final 4x4 matrix
|
||||
/// @param numSubBlocks
|
||||
/// @param acc
|
||||
/// @return
|
||||
Value reduceSubBlocks(int numSubBlocks, Value acc) const {
|
||||
return processSubBlocks(numSubBlocks, acc, true, false);
|
||||
}
|
||||
|
||||
/// @brief Zeroes out redundant values in all sub-blocks except first one
|
||||
///
|
||||
/// Every wave in mfma 4x4 layout holds only 4 unique values(scalar or
|
||||
/// vectors) in blocks of 4 consecutive threads, There are 16 copies of these
|
||||
/// 4 values across all threads of the wave. Need to zero out 15 copies to use
|
||||
/// accumulator between dot operations.
|
||||
/// @param numSubBlocks
|
||||
/// @param acc
|
||||
/// @return
|
||||
Value zeroAuxiliarBlocks(int numSubBlocks, Value acc) const {
|
||||
return processSubBlocks(numSubBlocks, acc, false, true);
|
||||
}
|
||||
|
||||
// Conduct the Dot conversion.
|
||||
LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor) const {
|
||||
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
|
||||
auto nonKDim = mfmaLayout.getNonKDim();
|
||||
assert(nonKDim == 32 || nonKDim == 16);
|
||||
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
|
||||
auto mfmaInstrDescr = getMatrixInstrDescr(op);
|
||||
|
||||
Value a = op.getA();
|
||||
@@ -296,8 +403,10 @@ struct DotOpMFMAConversionHelper {
|
||||
|
||||
unsigned warpSize = triton::gpu::getWarpSize(mfmaLayout);
|
||||
// compute number of output elements that each thread holds for one MFMA
|
||||
// instruction
|
||||
auto elemsPerVec = nonKDim * nonKDim / warpSize;
|
||||
// instruction. subBlocks is the number of sub-matrices reported by getNumSubmatrices.
|
||||
const int subBlocks =
|
||||
getNumSubmatrices(aTensorTy.getElementType(), nonKDim);
|
||||
auto elemsPerVec = nonKDim * nonKDim * subBlocks / warpSize;
|
||||
|
||||
auto vecTy = vec_ty(dstElemTy, elemsPerVec);
|
||||
for (int m = 0; m < numRepM; ++m) {
|
||||
@@ -308,13 +417,14 @@ struct DotOpMFMAConversionHelper {
|
||||
vecTy, acc, fc[m * numRepN * elemsPerVec + n * elemsPerVec + v],
|
||||
i32_val(v));
|
||||
}
|
||||
|
||||
acc = zeroAuxiliarBlocks(subBlocks, acc);
|
||||
for (size_t k = 0; k < numRepK; k++) {
|
||||
acc =
|
||||
mfmaLayout.getIsTransposed()
|
||||
? generateMFMAOp(mfmaInstrDescr, hb[{n, k}], ha[{m, k}], acc)
|
||||
: generateMFMAOp(mfmaInstrDescr, ha[{m, k}], hb[{n, k}], acc);
|
||||
}
|
||||
acc = reduceSubBlocks(subBlocks, acc);
|
||||
for (unsigned v = 0; v < elemsPerVec; ++v) {
|
||||
fc[m * numRepN * elemsPerVec + n * elemsPerVec + v] =
|
||||
extract_element(dstElemTy, acc, i32_val(v));
|
||||
|
||||
@@ -775,7 +775,9 @@ public:
|
||||
auto nonKDim = mfmaLayout.getNonKDim();
|
||||
// MFMA output tile consists of repeated "dot operand B" layout groups along
|
||||
// row axis. This variable defines number of these groups.
|
||||
const unsigned numGroups = (nonKDim == 32 ? 4 : 1);
|
||||
DenseMap<int, int> groups{{4, 1}, {16, 1}, {32, 4}};
|
||||
unsigned numGroups = groups.at(nonKDim);
|
||||
|
||||
const unsigned elemsPerThreadPerGroup = 4;
|
||||
auto warpSize = getWarpSize(mfmaLayout);
|
||||
assert(warpSize == 64);
|
||||
@@ -1193,7 +1195,12 @@ private:
|
||||
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout));
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value effectiveWarpSize = warpSize;
|
||||
if (nonKDim == 4) {
|
||||
const int uniqueValuesPerWarp = 4;
|
||||
effectiveWarpSize = i32_val(uniqueValuesPerWarp);
|
||||
}
|
||||
Value laneId = urem(threadId, effectiveWarpSize);
|
||||
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
Value warpId0 =
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
#include "TypeConverter.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "triton/Dialect/NVGPU/IR/Dialect.h"
|
||||
|
||||
#if USE_ROCM
|
||||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
||||
#endif
|
||||
namespace mlir {
|
||||
|
||||
namespace LLVM {
|
||||
@@ -286,10 +290,20 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
|
||||
#ifdef USE_ROCM
|
||||
//On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on 32bit/dwords
|
||||
//so we need to promote the value to 32 bits here.
|
||||
if (bits == 8) {
|
||||
Value i32Val = sext(i32_ty, val);
|
||||
Value result = commonShflSync(loc, rewriter, i32Val, i, strideInt, mode, clamp, laneId);
|
||||
return trunc(i8_ty, result);
|
||||
auto valType = val.getType();
|
||||
if (!valType.isInteger(32) && bits <= 32) {
|
||||
if (!valType.isIntOrIndex())
|
||||
val = bitcast(val, int_ty(bits));
|
||||
if (bits < 32)
|
||||
val = sext(i32_ty, val);
|
||||
|
||||
val = commonShflSync(loc, rewriter, val, i, strideInt, mode, clamp, laneId);
|
||||
|
||||
if (bits < 32)
|
||||
val = trunc(int_ty(bits), val);
|
||||
if (!valType.isIntOrIndex())
|
||||
val = bitcast(val, valType);
|
||||
return val;
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -307,20 +321,12 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
GCNBuilder builder;
|
||||
|
||||
auto permute = [&](Value lane, StringRef permuteInstStr) {
|
||||
assert(permuteInstStr == "ds_permute_b32" ||
|
||||
permuteInstStr == "ds_bpermute_b32");
|
||||
auto bpermute = [&](Value lane) {
|
||||
// Multiply the lane id by 4. (More on permute instruction semantics:
|
||||
// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf#page=180
|
||||
Value byteOffset = i32_val(2);
|
||||
Value permuteAddr = shl(lane, byteOffset);
|
||||
auto shfl = builder.create(permuteInstStr.str());
|
||||
auto dOpr = builder.newOperand("=v");
|
||||
auto addrOpr = builder.newOperand(permuteAddr, "v");
|
||||
auto aOpr = builder.newOperand(val, "v");
|
||||
(*shfl)(dOpr, addrOpr, aOpr);
|
||||
return rewriter.create<ROCDL::DsBpermuteOp>(loc, valType, permuteAddr, val);
|
||||
};
|
||||
|
||||
switch (mode) {
|
||||
@@ -334,39 +340,30 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
|
||||
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)})
|
||||
.getResult(0);
|
||||
Value stride = i32_val(32);
|
||||
Value lineId = add(threadId, stride);
|
||||
permute(lineId, "ds_permute_b32");
|
||||
Value lineId = xor_(threadId, stride);
|
||||
return bpermute(lineId);
|
||||
} else {
|
||||
// This map facilitates the butterfly shuffle pattern for a stride less
|
||||
// than 16. The pattern stride is the key of the map.
|
||||
DenseMap<short, unsigned int> masks{
|
||||
{16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}};
|
||||
auto shfl = builder.create("ds_swizzle_b32");
|
||||
auto dOpr = builder.newOperand("=v");
|
||||
auto aOpr = builder.newOperand(val, "v");
|
||||
auto maskOpr =
|
||||
builder.newConstantOperand("offset:" + std::to_string(masks[strideInt]));
|
||||
(*shfl)(dOpr, aOpr, maskOpr);
|
||||
Value offset = i32_val(masks[strideInt]);
|
||||
return rewriter.create<ROCDL::DsSwizzleOp>(loc, valType, val, offset);
|
||||
}
|
||||
break;
|
||||
case NVVM::ShflKind::up: {
|
||||
Value mask = icmp_slt(laneId, i);
|
||||
Value delta = sub(laneId, i);
|
||||
Value index = select(mask, laneId, delta);
|
||||
permute(index, "ds_bpermute_b32");
|
||||
break;
|
||||
return bpermute(index);
|
||||
}
|
||||
case NVVM::ShflKind::idx:
|
||||
permute(i, "ds_bpermute_b32");
|
||||
break;
|
||||
return bpermute(i);
|
||||
default:
|
||||
assert(false && "Unsupported ShflKind");
|
||||
break;
|
||||
}
|
||||
|
||||
auto swait = builder.create("s_waitcnt lgkmcnt(0)");
|
||||
(*swait)();
|
||||
return builder.launch(rewriter, loc, val.getType(), true);
|
||||
return Value();
|
||||
#else
|
||||
Type type = val.getType();
|
||||
if (type != i32_ty) {
|
||||
|
||||
@@ -240,15 +240,22 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
||||
}
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
unsigned rows, cols;
|
||||
if (mfmaLayout.getNonKDim() == 32) {
|
||||
switch (mfmaLayout.getNonKDim()) {
|
||||
case 32:
|
||||
rows = 16;
|
||||
cols = 1;
|
||||
} else if (mfmaLayout.getNonKDim() == 16) {
|
||||
break;
|
||||
case 16:
|
||||
rows = 4;
|
||||
cols = 1;
|
||||
} else
|
||||
break;
|
||||
case 4:
|
||||
rows = 4;
|
||||
cols = 1;
|
||||
break;
|
||||
default:
|
||||
llvm_unreachable("Unexpected mfma non-k dim");
|
||||
|
||||
}
|
||||
if (mfmaLayout.getIsTransposed()) {
|
||||
return {cols, rows};
|
||||
} else {
|
||||
@@ -993,9 +1000,11 @@ SmallVector<int64_t>
|
||||
DotOperandEncodingAttr::getMFMAElemsPerInstr() const {
|
||||
auto mfmaEncoding = getParent().cast<MfmaEncodingAttr>();
|
||||
int64_t nonKDim = mfmaEncoding.getNonKDim();
|
||||
assert(nonKDim == 32 || nonKDim == 16);
|
||||
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
|
||||
int64_t kWidth = getKWidth();
|
||||
int64_t kDim = kWidth * (nonKDim == 32 ? 2 : 4);
|
||||
constexpr int waveSize = 64; // MFMA is used on wave64 architectures only
|
||||
int kGroups = waveSize / nonKDim;
|
||||
int64_t kDim = kWidth * kGroups;
|
||||
if (getOpIdx() == 0)
|
||||
return {nonKDim, kDim};
|
||||
else
|
||||
|
||||
@@ -105,9 +105,18 @@ public:
|
||||
if (enforcedNonKDim != 0) {
|
||||
nonKDim = enforcedNonKDim;
|
||||
} else {
|
||||
nonKDim = (resShape[0] < 32 || resShape[1] < 32) ? 16 : 32;
|
||||
nonKDim = -1;
|
||||
int minSize = std::min(resShape[0], resShape[1]);
|
||||
if (minSize >= 32)
|
||||
nonKDim = 32;
|
||||
if (minSize >= 16 && minSize < 32)
|
||||
nonKDim = 16;
|
||||
if (minSize < 16)
|
||||
nonKDim = 4;
|
||||
assert(nonKDim != -1);
|
||||
}
|
||||
if (nonKDim == 32) {
|
||||
switch (nonKDim) {
|
||||
case 32:
|
||||
if (elemType.isF32())
|
||||
kDim = 2;
|
||||
if (elemType.isF16())
|
||||
@@ -130,7 +139,8 @@ public:
|
||||
kDim = 8;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
case 16:
|
||||
if (elemType.isF32())
|
||||
kDim = 4;
|
||||
if (elemType.isF16())
|
||||
@@ -152,7 +162,25 @@ public:
|
||||
else {
|
||||
kDim = 16;
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
case 4:
|
||||
if (elemType.isF32())
|
||||
kDim = 16;
|
||||
if (elemType.isF16())
|
||||
kDim = 64;
|
||||
if (elemType.isBF16()) {
|
||||
if (mfmaVersion == 1)
|
||||
kDim = 32;
|
||||
if (mfmaVersion >= 2)
|
||||
kDim = 64;
|
||||
}
|
||||
if (elemType.isInteger(8)) {
|
||||
kDim = 64;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
llvm::report_fatal_error("unsupported nonKDim size in MFMA dot");
|
||||
}
|
||||
assert(kDim != -1);
|
||||
assert(nonKDim != -1);
|
||||
@@ -221,11 +249,19 @@ public:
|
||||
auto kWidth = kDim;
|
||||
// in mfma 32x32 case argument matrix groups elements in 2 groups
|
||||
// in mfma 16x16 case argument matrix groups elements in 4 groups
|
||||
if (nonKDim == 32) {
|
||||
// in mfma 4x4 case argument matrix groups elements in 16 groups
|
||||
switch (nonKDim) {
|
||||
case 32:
|
||||
kWidth /= 2;
|
||||
} else {
|
||||
assert(nonKDim == 16);
|
||||
break;
|
||||
case 16:
|
||||
kWidth /= 4;
|
||||
break;
|
||||
case 4:
|
||||
kWidth /= 16;
|
||||
break;
|
||||
default:
|
||||
llvm::report_fatal_error("unsupported kDim in mfma dot");
|
||||
}
|
||||
auto newAType = RankedTensorType::get(
|
||||
oldAType.getShape(), oldAType.getElementType(),
|
||||
|
||||
790
python/perf-kernels/flash-attention-seqlen-padded.py
Normal file
790
python/perf-kernels/flash-attention-seqlen-padded.py
Normal file
@@ -0,0 +1,790 @@
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
|
||||
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
|
||||
Credits: OpenAI kernel team
|
||||
|
||||
This kernel supports arbitrarily sized sequence lengths.
|
||||
|
||||
Extra Credits:
|
||||
- Original flash attention paper (https://arxiv.org/abs/2205.14135)
|
||||
- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
|
||||
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# Use the float8_e5m2fnuz dtype when this torch build exposes it; otherwise
# fall back to float16.
TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')
torch_dtype: tl.constexpr = torch.float8_e5m2fnuz if TORCH_HAS_FP8E5 else torch.float16
|
||||
|
||||
@triton.jit
def _attn_fwd_inner(
    acc, l_i, m_i, q,
    K_block_ptr, V_block_ptr,
    start_m,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,
    N_CTX,
    pre_load_v: tl.constexpr,
    padded_block: tl.constexpr,
    total_tokens: tl.constexpr,
):
    """Inner K/V loop of the attention forward pass for one block of queries.

    Walks K and V in steps of BLOCK_N over a range chosen by STAGE and
    padded_block, and updates the running row maximum (m_i), the running row
    sum (l_i) and the output accumulator (acc).

    Range selection:
      * STAGE == 1:   [0, start_m * BLOCK_M)
      * STAGE == 2:   [start_m * BLOCK_M, (start_m + 1) * BLOCK_M), with a
                      mask offs_m >= start_n + offs_n applied to the scores
      * padded_block: [N_CTX, N_CTX + BLOCK_N), with boundary-checked loads
                      and a mask of positions below total_tokens
      * otherwise:    [0, N_CTX)

    pre_load_v selects whether V is loaded before or after the q.k product.
    Returns the updated (acc, l_i, m_i).
    """
    # Pick the range of key/value positions this stage is responsible for.
    if STAGE == 1:
        n_begin, n_end = 0, start_m * BLOCK_M
    elif STAGE == 2:
        n_begin, n_end = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        n_begin = tl.multiple_of(n_begin, BLOCK_M)
        K_block_ptr = tl.advance(K_block_ptr, (0, n_begin))
        V_block_ptr = tl.advance(V_block_ptr, (n_begin, 0))
    # Here N_CTX is the sequence length rounded down to a whole number of
    # blocks, so this branch handles the last, irregular block. Elements of
    # BLOCK_N that do not exist are masked inside the loop.
    elif padded_block:
        n_begin, n_end = N_CTX, N_CTX + BLOCK_N
        n_begin = tl.multiple_of(n_begin, BLOCK_N)
        K_block_ptr = tl.advance(K_block_ptr, (0, n_begin))
        V_block_ptr = tl.advance(V_block_ptr, (n_begin, 0))
    # Non-causal case.
    else:
        n_begin, n_end = 0, N_CTX

    # Iterate over K/V blocks and fold each one into the accumulator.
    for start_n in range(n_begin, n_end, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # A padded block would read past the tensor with a plain load, so it
        # is loaded with a boundary check; other blocks are fully in range.
        if padded_block:
            k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
        else:
            k = tl.load(K_block_ptr)
        if pre_load_v:
            if padded_block:
                v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
            else:
                v = tl.load(V_block_ptr)

        # Scores: start from zeros, put -inf at masked positions, add q.k.
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        if STAGE == 2:
            causal_mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = tl.where(causal_mask, qk, float("-inf"))
        if padded_block:
            token_limit = tl.full([BLOCK_M], total_tokens, dtype=tl.float32)
            in_range = (start_n + offs_n[None, :]) < token_limit[:, None]
            qk = tl.where(in_range, qk, float("-inf"))
        qk += tl.dot(q, k)

        m_new = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_new[:, None]
        probs = tl.math.exp2(qk)

        # Rescale the accumulator to the new running maximum.
        rescale = tl.math.exp2(m_i - m_new)
        acc = acc * rescale[:, None]
        if not pre_load_v:
            if padded_block:
                v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
            else:
                v = tl.load(V_block_ptr)
        acc += tl.dot(probs.to(v.dtype), v)

        # Update the running row sum and row maximum.
        row_sum = tl.sum(probs, 1)
        l_i = l_i * rescale + row_sum
        m_i = m_new

        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    return acc, l_i, m_i
|
||||
|
||||
|
||||
@triton.jit
def _attn_fwd(
    Q, K, V, sm_scale, M, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H,
    N_CTX,
    BLOCK_DMODEL: tl.constexpr,
    STAGE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    pre_load_v: tl.constexpr,
    need_padding: tl.constexpr,
    extra_tokens_n: tl.constexpr,
):
    """Forward attention kernel; each program handles one BLOCK_M-row tile of Q.

    Program axis 0 selects the row tile (``start_m``) and axis 1 selects the
    slice addressed by ``off_hz * stride_qh``.  The kernel loads the Q tile,
    pre-multiplies it by ``sm_scale * log2(e)``, lets ``_attn_fwd_inner``
    accumulate over K/V blocks, then writes the normalised accumulator to
    ``Out`` and ``m_i + log2(l_i)`` per row to ``M``.

    Parameters:
        Q, K, V, Out: base pointers of the input/output tensors.
        sm_scale: softmax scale applied to Q.
        M: base pointer of the per-row statistics, indexed as
            ``off_hz * N_CTX + row``.
        stride_*: element strides; only the ``*h``, ``*m``/``*n`` and ``*k``
            strides are read here.
        Z, H: not read in this body.
        N_CTX: sequence length, used as the bound for Q, K, V and Out.
        BLOCK_DMODEL, BLOCK_M, BLOCK_N: compile-time tile sizes.
        STAGE: compile-time stage bitmask (bit 0: off-band, bit 1: on-band).
        pre_load_v: forwarded to ``_attn_fwd_inner``.
        need_padding: whether a trailing, partially filled K/V block exists.
        extra_tokens_n: number of tokens subtracted from N_CTX to obtain the
            length handled by the unpadded pass.
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # NOTE(review): the Q head stride is used for K, V and Out as well, which
    # presumably relies on all four tensors sharing the same layout - confirm.
    qvk_offset = off_hz * stride_qh

    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    # K is addressed transposed, (BLOCK_DMODEL, N_CTX), so tl.dot(q, k) needs
    # no explicit transpose.
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize the running row maximum (m_i), the running normaliser (l_i)
    # and the output accumulator
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs
    q = tl.load(Q_block_ptr, boundary_check=(0,1), padding_option="zero")
    q = (q * qk_scale).to(q.dtype)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        # We don't currently support causal masking and padding.
        tl.static_assert((STAGE != 3) or not need_padding)
        # Length covered by the unpadded pass: N_CTX minus the extra tokens
        # (equal to N_CTX when extra_tokens_n is 0).
        seqlen_aligned = N_CTX - extra_tokens_n
        if N_CTX >= BLOCK_N:
            acc, l_i, m_i = _attn_fwd_inner(
                acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                start_m,
                BLOCK_M, BLOCK_DMODEL, BLOCK_N,
                4 - STAGE, offs_m, offs_n,
                seqlen_aligned, pre_load_v,
                False, seqlen_aligned
            )
        tl.debug_barrier()
        if need_padding:
            # When the whole sequence is shorter than one block there is no
            # aligned part at all; the padded pass starts from 0.
            if N_CTX < BLOCK_N:
                seqlen_aligned = 0
            acc, l_i, m_i = _attn_fwd_inner(
                acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                start_m,
                BLOCK_M, BLOCK_DMODEL, BLOCK_N,
                4 - STAGE, offs_m, offs_n,
                seqlen_aligned, pre_load_v,
                True, N_CTX,
            )
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for the compiler to schedule the
        # two loops independently
        tl.debug_barrier()
        # NOTE(review): this call passes two fewer arguments than the stage-1
        # calls; verify against the _attn_fwd_inner signature.
        acc, l_i, m_i = _attn_fwd_inner(
            acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
            start_m,
            BLOCK_M, BLOCK_DMODEL, BLOCK_N,
            2, offs_m, offs_n,
            N_CTX, pre_load_v,
        )
    # epilogue
    acc = acc / l_i[:, None]
    # write back m
    m_ptrs = M + off_hz * N_CTX + offs_m
    # The last row tile may extend past N_CTX.  Rows at or beyond N_CTX must
    # not be stored: M holds only N_CTX entries per off_hz slice, so an
    # unmasked store would spill into the next slice (or past the buffer).
    tl.store(m_ptrs, m_i + tl.math.log2(l_i), mask=offs_m < N_CTX)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1))
|
||||
|
||||
|
||||
@triton.jit
def _attn_bwd_preprocess(O, DO,  #
                         NewDO, Delta,  #
                         BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,  #
                         ):
    """Backward pre-pass: per-row ``sum(O * DO)`` plus a float32 copy of DO.

    Each program handles BLOCK_M consecutive rows of D_HEAD elements, with
    rows addressed as ``row * D_HEAD + col``.  The row-wise dot product of O
    and DO goes to ``Delta`` and the loaded DO values go to ``NewDO``.
    """
    row_ids = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    col_ids = tl.arange(0, D_HEAD)
    # One shared 2D offset grid serves both loads and the NewDO store.
    elem_offs = row_ids[:, None] * D_HEAD + col_ids[None, :]
    # load (accumulate in float32)
    out_vals = tl.load(O + elem_offs).to(tl.float32)
    grad_vals = tl.load(DO + elem_offs).to(tl.float32)
    row_dot = tl.sum(out_vals * grad_vals, axis=1)
    # write-back
    tl.store(NewDO + elem_offs, grad_vals)
    tl.store(Delta + row_ids, row_dot)
|
||||
|
||||
|
||||
@triton.jit
def _bwd_kernel_dk_dv(
    Q, K, V, sm_scale, Out, DO,
    DK, DV,
    L,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Backward kernel accumulating dK and dV for one block of K/V columns.

    Program axis 0 (``start_m``) selects the K/V block that starts at
    ``start_m * BLOCK_M``; axis 1 (``off_hz``) selects the slice addressed by
    ``off_hz * stride_qh``.  The K and V tiles are loaded once, then the
    kernel walks over Q/DO row blocks from ``start_m * BLOCK_M`` up to
    ``N_CTX``, accumulating ``dv`` and ``dk`` in float32.

    Parameters:
        Q, K, V, DO: input base pointers.  ``Out`` is not read here.
        DK, DV: output base pointers.
        L, D: per-row vectors indexed as ``off_hz * N_CTX + row``.
        sm_scale: softmax scale; applied to K (times log2(e)) before the dot
            product and to ``dk`` before it is stored.
        stride_*, Z, H, N_CTX: layout and size arguments.
        BLOCK_M, BLOCK_N, BLOCK_DMODEL: compile-time tile sizes.
            NOTE(review): the accumulators are (BLOCK_M, BLOCK_DMODEL) while
            the K/V tiles have BLOCK_N columns, so this presumably requires
            BLOCK_M == BLOCK_N - confirm with the launcher.
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # Q is consumed depending on block ID. Every block uses
    # previous block offset by BLOCK_M x D_HEAD.
    qvk_offset = off_hz * stride_qh
    qdo_offset = qvk_offset + start_m * BLOCK_M * stride_qm
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # Initialize pointers to Q, K, V
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qdo_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, start_m * BLOCK_M),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_vn, stride_vk),
        offsets=(0, start_m * BLOCK_M),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    # DO is addressed with the Q strides.
    DO_block_ptr = tl.make_block_ptr(
        base=DO + qdo_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    # pointer to row-wise quantities in value-like data
    D_ptrs = D + off_hz * N_CTX
    l_ptrs = L + off_hz * N_CTX
    qk_scale = sm_scale * 1.44269504
    # load k and v: they will stay in SRAM throughout
    k = tl.load(K_block_ptr)
    k = (k * qk_scale).to(k.dtype)
    v = tl.load(V_block_ptr)
    dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # This lower loop bound is because of the causal mask. We create a lower triangular
    # result. The upper triangular is -inf (becomes 0 when we do e^x). As such, it can
    # be ignored in the GEMM.
    lo = start_m * BLOCK_M
    hi = N_CTX
    # loop over q, do
    for start_n in range(lo, hi, BLOCK_N):
        offs_m_curr = offs_n[:, None] + start_n
        # -- load q, do --
        q = tl.load(Q_block_ptr)
        do = tl.load(DO_block_ptr)
        # -- compute qk ----
        qk = tl.dot(q, k)
        qk = tl.where(offs_m_curr >= offs_m[None, :], qk, float("-inf"))
        l_i = tl.load(l_ptrs + offs_m_curr)
        p = tl.math.exp2(qk - l_i)
        # -- compute dv ----
        dv += tl.dot(tl.trans(p.to(do.dtype)), do)
        # compute dp = dot(do, v) - Di
        Di = tl.load(D_ptrs + offs_m_curr)
        dp = tl.zeros([BLOCK_N, BLOCK_M], dtype=tl.float32) - Di
        dp += tl.dot(do, v)
        # compute ds = p * dp (the delta term is already folded into dp)
        ds = p * dp
        # compute dk
        dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
        # update pointers
        Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_N, 0))
        DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_N, 0))
    # initialize pointers to output
    DK_block_ptr = tl.make_block_ptr(
        base=DK + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_kn, stride_kk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    DV_block_ptr = tl.make_block_ptr(
        base=DV + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(DK_block_ptr, (dk * sm_scale).to(DK.dtype.element_ty))
    # Cast to the element type of the destination (as done for DK) instead of
    # a hard-coded float16, so non-fp16 outputs are not rounded through fp16.
    tl.store(DV_block_ptr, dv.to(DV.dtype.element_ty))
|
||||
|
||||
@triton.jit
def _bwd_kernel_dq(
    Q, K, V, sm_scale, Out, DO,
    DQ,
    L,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Backward kernel accumulating dQ for one BLOCK_M-row tile of Q.

    Program axis 0 (``start_m``) selects the Q/DO row tile; axis 1
    (``off_hz``) selects the slice addressed by ``off_hz * stride_qh``.  The
    Q and DO tiles and the matching ``L`` / ``D`` rows are loaded once, then
    the kernel walks over K/V column blocks from 0 up to
    ``(start_m + 1) * BLOCK_M`` and accumulates ``dq`` in float32.

    Parameters:
        Q, K, V, DO: input base pointers.  ``Out`` is not read here.
        DQ: output base pointer.
        L, D: per-row vectors indexed as ``off_hz * N_CTX + row``.
        sm_scale: softmax scale; applied to Q (times log2(e)) before the dot
            product and to ``dq`` before it is stored.
        stride_*, Z, H, N_CTX: layout and size arguments.
        BLOCK_M, BLOCK_N, BLOCK_DMODEL: compile-time tile sizes.
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    qvk_offset = off_hz * stride_qh
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # Initialize pointers to Q, K, V
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_vn, stride_vk),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    # DO is addressed with the Q strides.
    DO_block_ptr = tl.make_block_ptr(
        base=DO + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    # pointer to row-wise quantities in value-like data
    D_ptrs = D + off_hz * N_CTX
    l_ptrs = L + off_hz * N_CTX
    qk_scale = sm_scale * 1.44269504
    # load q and do: they will stay in SRAM throughout
    q = tl.load(Q_block_ptr)
    q = (q * qk_scale).to(q.dtype)
    do = tl.load(DO_block_ptr)
    Di = tl.load(D_ptrs + offs_m)
    l_i = tl.load(l_ptrs + offs_m)
    dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # loop over k, v; columns beyond the end of this row tile are fully
    # masked by the causal comparison below, so the loop stops there
    lo = 0
    hi = (start_m + 1) * BLOCK_M
    for start_n in range(lo, hi, BLOCK_N):
        # -- load k, v --
        k = tl.load(K_block_ptr)
        v = tl.load(V_block_ptr)
        # -- compute qk ----
        qk = tl.dot(q, k)
        qk = tl.where(offs_m[:, None] >= (offs_n[None, :] + start_n), qk, float("-inf"))
        p = tl.math.exp2(qk - l_i[:, None])
        # compute dp = dot(do, v) - Di
        dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
        dp += tl.dot(do, v)
        # compute ds = p * dp (the delta term is already folded into dp)
        ds = p * dp
        # compute dq. Unfortunately we cannot avoid transpose here as this loop
        # uses k both normal and transpose.
        dq += tl.dot(ds.to(Q.dtype.element_ty), tl.trans(k))
        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N))
    # initialize pointers to output
    DQ_block_ptr = tl.make_block_ptr(
        base=DQ + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    # Cast to the element type of the destination instead of a hard-coded
    # float16, so non-fp16 outputs are not rounded through fp16.
    tl.store(DQ_block_ptr, (dq * sm_scale).to(DQ.dtype.element_ty))
|
||||
|
||||
empty = torch.empty(128, device="cuda")
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
    """Autograd wrapper that launches the Triton attention kernels."""

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
        """Launch ``_attn_fwd`` and stash what ``backward`` needs on ``ctx``.

        Args:
            q, k, v: 4D tensors whose last dimension (head size) must match
                and be one of 16, 32, 64 or 128; dimension -2 (sequence
                length) must match across all three.
            causal: selects kernel STAGE 3 (True) or 1 (False).
            sm_scale: softmax scale forwarded to the kernel.
            split_kernel: stored on ``ctx``; selects the split dK/dV + dQ
                kernels in ``backward``.

        Returns:
            The attention output, shaped like ``q`` with the dtype of ``v``.
        """
        # shape constraints
        _, _, seqlen, Lq = q.shape
        Lk, Lv = k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        # For now we assume K and V seqlen = Q seqlen
        assert seqlen == k.shape[-2] and seqlen == v.shape[-2]

        # We've derived these previously from tuning the kernel
        BLOCK_M = 256 if Lq == 128 else 128
        BLOCK_N = 64 #128 if Lq == 128 else 64
        waves_per_eu = 2 if Lq == 128 else 3
        num_warps = 8 if Lq == 128 else 4
        pre_load_v = False if Lq == 128 else True
        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
        stage = 3 if causal else 1

        # Compute if we need padding and how much
        seqlen_k = k.shape[-2]

        if seqlen_k < BLOCK_N:
            need_padding = True
            extra_tokens_n = BLOCK_N - seqlen_k
            # This effectively ensures we do not slice across Q.
            assert(grid[0] == 1)
        elif seqlen_k % BLOCK_N:
            need_padding = True
            extra_tokens_n = seqlen_k % BLOCK_N
        else:
            # We don't care if BLOCK_M isn't aligned, as we
            # always boundary_check on Q and O
            need_padding = False
            extra_tokens_n = 0

        o = torch.empty_like(q, dtype=v.dtype)
        # One float32 statistic per (batch*head, row), written by the kernel.
        M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

        _attn_fwd[grid](
            q, k, v, sm_scale, M, o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1],
            N_CTX=q.shape[2],
            BLOCK_DMODEL=Lk,
            STAGE=stage,
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
            waves_per_eu=waves_per_eu, pre_load_v=pre_load_v,
            need_padding=need_padding, extra_tokens_n=extra_tokens_n,
            num_stages=1, num_warps=num_warps
        )

        ctx.save_for_backward(q, k, v, o, M)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        ctx.causal = causal
        ctx.split_kernel = split_kernel
        return o

    @staticmethod
    def backward(ctx, do):
        """Compute gradients w.r.t. q, k and v.

        Args:
            do: gradient of the forward output; must be contiguous and share
                its strides with q, k, v and o.

        Returns:
            ``(dq, dk, dv, None, None, None)`` - one entry per ``forward``
            input; the non-tensor inputs get ``None``.
        """
        # configuration is not supported
        assert(not (ctx.split_kernel and not ctx.causal))
        if torch.version.hip is not None:
            BLOCK = 64
        else:
            BLOCK = 128
        q, k, v, o, L = ctx.saved_tensors
        assert do.is_contiguous()
        assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
        do = do.contiguous()
        # dq is zero-initialised once here and shared by both kernel paths.
        dq = torch.zeros_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        delta = torch.empty_like(L)
        do_scaled = torch.empty_like(do)
        # Figure out what BLOCK size fwd used and adjust num_blocks accordingly.
        # If the two are the same, we don't need this but the bwd pass block size
        # is smaller than the fwd so we need this scaling to ensure we loop over all
        # values and don't skip some blocks.
        # Alternatively we could compute a new grid but this keeps it consistent
        # with fwd and easier to reason about.
        block_scale = (q.shape[2] // ctx.grid[0]) // BLOCK
        _attn_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
            o, do,  #
            do_scaled, delta,  #
            BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL,  #
        )
        if not ctx.split_kernel:
            _bwd_kernel[(ctx.grid[1],)](
                q, k, v, ctx.sm_scale,
                o, do_scaled,
                dq, dk, dv,
                L, delta,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                q.shape[0], q.shape[1], q.shape[2],
                block_scale * ctx.grid[0],
                BLOCK_M=BLOCK, BLOCK_N=BLOCK,
                BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
                CAUSAL=ctx.causal,
                num_stages=1,
            )
        else:
            _bwd_kernel_dk_dv[(block_scale * ctx.grid[0], ctx.grid[1])](
                q, k, v, ctx.sm_scale,
                o, do_scaled,
                dk, dv,
                L, delta,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                q.shape[0], q.shape[1], q.shape[2],
                BLOCK_M=BLOCK, BLOCK_N=BLOCK,
                BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
                num_stages=1,
            )
            _bwd_kernel_dq[ctx.grid](
                q, k, v, ctx.sm_scale,
                o, do_scaled,
                dq,
                L, delta,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                q.shape[0], q.shape[1], q.shape[2],
                BLOCK_M=2*BLOCK, BLOCK_N=BLOCK,
                BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4, waves_per_eu=1,
                num_stages=1,
            )
        # print(h.asm["ttgir"])
        return dq, dk, dv, None, None, None
|
||||
|
||||
attention = _attention.apply
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                         [(1, 40, 19, 128),
                          (4, 48, 1024, 64),
                          (4, 48, 997, 64),
                          (4, 48, 2048, 64),
                          (4, 48, 4096, 64),
                          (4, 48, 3989, 64),
                          (4, 48, 1024, 128),
                          (4, 48, 1021, 128),
                          (4, 48, 2048, 128),
                          (4, 48, 4096, 128),
                          (4, 16, 8192, 64),
                          (4, 16, 8080, 64),
                          #(4, 48, 16384, 64)
                          ])
@pytest.mark.parametrize('causal', [False])
def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
    """Compare the Triton forward pass against a plain PyTorch reference.

    The shape list includes sequence lengths that are not multiples of the
    kernel block sizes (19, 997, 1021, 3989, 8080), which exercises the
    padded path.
    """
    torch.manual_seed(20)
    q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    if TORCH_HAS_FP8E5:
        q = q.to(torch_dtype)
        k = k.to(torch_dtype)
    sm_scale = 0.5
    # reference implementation
    p = torch.matmul(q.half(), k.transpose(2, 3).half()) * sm_scale
    if causal:
        # Build the N_CTX x N_CTX lower-triangular mask only when it is used;
        # for the largest shapes it costs hundreds of MiB.
        M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
        p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    ref_out = torch.matmul(p, v)
    # triton implementation
    tri_out = attention(q, k, v, causal, sm_scale)
    # compare
    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                         [(4, 48, 1024, 64),
                          (4, 48, 2048, 64),
                          (4, 48, 4096, 64),
                          (1, 16, 8192, 64),
                          ])
def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    """Check the causal, split-kernel backward pass against PyTorch autograd."""
    torch.manual_seed(20)
    causal = True

    def fresh_input():
        # Inputs are drawn in q, k, v order so the RNG stream is unchanged.
        return torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()

    def take_grad(tensor):
        # Snapshot the accumulated gradient and reset it for the next pass.
        grad = tensor.grad.clone()
        tensor.grad = None
        return grad

    q, k, v = fresh_input(), fresh_input(), fresh_input()

    sm_scale = 0.5
    split_kernel = True
    dout = torch.randn_like(q)

    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if causal:
        p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
    ref_dv = take_grad(v)
    ref_dk = take_grad(k)
    ref_dq = take_grad(q)

    # triton implementation
    tri_out = attention(q, k, v, causal, sm_scale, split_kernel)
    tri_out.backward(dout)
    tri_dv = take_grad(v)
    tri_dk = take_grad(k)
    tri_dq = take_grad(q)

    # compare
    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
    # The current block size for MI200 series is 64x64. This results in
    # larger differences in float results due to rounding.
    dv_atol = 1e-2 if torch.version.hip is None else 5e-2
    torch.testing.assert_close(ref_dv, tri_dv, atol=dv_atol, rtol=0)
    torch.testing.assert_close(ref_dk, tri_dk, atol=5e-2, rtol=1e-2)
    torch.testing.assert_close(ref_dq, tri_dq, atol=5e-2, rtol=1e-2)
|
||||
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import \
|
||||
flash_attn_qkvpacked_func as flash_attn_func
|
||||
HAS_FLASH = True
|
||||
except BaseException:
|
||||
HAS_FLASH = False
|
||||
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = []
|
||||
for mode in ['fwd']:
|
||||
for D_HEAD in [128]:
|
||||
if mode == 'bwd' and D_HEAD == 128:
|
||||
continue
|
||||
for causal in [False]:
|
||||
if mode == 'bwd' and causal == False:
|
||||
continue
|
||||
configs.append(triton.testing.Benchmark(
|
||||
x_names=['BATCH', 'H','N_CTX'],
|
||||
x_vals=[(16, 16, 1024),
|
||||
(8, 16, 2048),
|
||||
(4, 16, 4096),
|
||||
(2, 16, 8192),
|
||||
(1, 16, 16384),
|
||||
(2, 48, 1024),
|
||||
(2, 48, 2048),
|
||||
(2, 48, 4096),
|
||||
(2, 48, 8192),
|
||||
(2, 48, 16384),
|
||||
(8, 16, 1989),
|
||||
(4, 16, 4097),
|
||||
(2, 16, 8122),
|
||||
(1, 16, 16281),
|
||||
(2, 48, 1021),
|
||||
(2, 48, 2001),
|
||||
(2, 48, 3996),
|
||||
(2, 48, 8181),
|
||||
],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-{mode}-d{D_HEAD}-causal={causal}',
|
||||
args={
|
||||
'D_HEAD': D_HEAD,
|
||||
'dtype': torch.float16,
|
||||
'mode': mode,
|
||||
'causal': causal})
|
||||
)
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
def bench_flash_attention(
    BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"
):
    """Benchmark one attention configuration and return its throughput.

    Args:
        BATCH, H, N_CTX, D_HEAD: problem shape.
        causal: whether causal masking is used (forced to True for 'bwd').
        mode: 'fwd' or 'bwd'.
        provider: 'triton' or 'flash'.
        dtype: dtype of the randomly generated inputs.
        device: device on which the inputs are allocated.

    Returns:
        ``total_flops / ms * 1e-9`` where ``ms`` is the time reported by
        ``triton.testing.do_bench``.
    """
    assert mode in ["fwd", "bwd"]
    warmup = 25
    rep = 100
    split_kernel = False
    # Bwd pass only supports causal=True right now
    if mode == 'bwd':
        causal = True
        split_kernel = True
    if provider == "triton":
        # Allocate on the requested device (default "cuda"), matching the
        # flash branch, instead of a hard-coded "cuda".
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=True)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=True)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=True)
        if mode == "fwd":
            q = q.to(torch_dtype)
            k = k.to(torch_dtype)
        sm_scale = 1.3
        fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    if provider == "flash":
        qkv = torch.randn(
            (BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True
        )
        fn = lambda: flash_attn_func(qkv, causal=causal)
        if mode == "bwd":
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    # Two matmuls (QK^T and PV), each 2 * BATCH * H * N_CTX^2 * D_HEAD flops.
    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    if mode == "bwd":
        total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
    return total_flops / ms * 1e-9
|
||||
|
||||
|
||||
# only works on post-Ampere GPUs right now
|
||||
bench_flash_attention.run(save_path=".", print_data=True)
|
||||
@@ -1644,7 +1644,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
('float8e4m3fnuz', 'float32'),
|
||||
('float16', 'float32'),
|
||||
('float32', 'float32')]
|
||||
for non_k_dim in [0, 16, 32]
|
||||
for non_k_dim in [0, 4, 16, 32]
|
||||
if not (allow_tf32 and (in_dtype in ['float16']))] +
|
||||
|
||||
[(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype, non_k_dim)
|
||||
@@ -1670,13 +1670,18 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
[64, 32, 32, 2],
|
||||
[256, 32, 32, 2],
|
||||
[256, 32, 32, 4],
|
||||
[32, 8, 128, 4],
|
||||
[8, 32, 128, 2],
|
||||
[4, 32, 64, 4],
|
||||
[32, 4, 64, 2],
|
||||
[16, 4, 64, 8]
|
||||
]
|
||||
for allow_tf32 in [False, True]
|
||||
for col_a in [True, False]
|
||||
for col_b in [True, False]
|
||||
for in_dtype in ['int8', 'bfloat16', 'float16', 'float32']
|
||||
for out_dtype in [None]
|
||||
for non_k_dim in [0, 16, 32]])
|
||||
for non_k_dim in [0, 4, 16, 32]])
|
||||
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim, device='cuda'):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
|
||||
@@ -1697,6 +1702,12 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
out_dtype = "int32"
|
||||
if non_k_dim == 32 and (M < 32 or N < 32):
|
||||
pytest.skip("incompatible non_k_dim == 32 with MN sizes")
|
||||
if non_k_dim == 16 and (M < 16 or N < 16):
|
||||
pytest.skip("incompatible non_k_dim == 16 with MN sizes")
|
||||
if non_k_dim == 4 and (K < 64):
|
||||
pytest.skip("incompatible non_k_dim == 4 with K size")
|
||||
if non_k_dim == 4 and (M > 16 or N > 16):
|
||||
pytest.skip("skipping lage matrices for non_k_dim == 4 to speedup testing")
|
||||
|
||||
if capability[0] < 7:
|
||||
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
||||
@@ -2003,6 +2014,7 @@ def get_variant_golden(a, b):
|
||||
|
||||
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,NUM_STAGES', [
|
||||
[64, 32, 128, 4, 64, 32, 64, 0],
|
||||
[4, 16, 128, 4, 4, 16, 64, 1],
|
||||
[64, 32, 128, 4, 64, 32, 64, 2]
|
||||
])
|
||||
def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, NUM_STAGES):
|
||||
|
||||
@@ -1190,7 +1190,7 @@ def is_hip():
|
||||
|
||||
def mfma_supported_granularity(m, n, k) -> bool:
|
||||
# todo make this gran_type matrix element type sensitive
|
||||
for gran_type in [(32, 8), (16, 16)]:
|
||||
for gran_type in [(32, 8), (16, 16), (4, 64)]:
|
||||
granularity_mn, granularity_k = gran_type
|
||||
|
||||
if m % granularity_mn != 0 or n % granularity_mn != 0:
|
||||
@@ -1261,11 +1261,15 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
|
||||
|
||||
assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!"
|
||||
assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!"
|
||||
assert lhs.shape[1].value == rhs.shape[
|
||||
0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
|
||||
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
|
||||
and rhs.shape[1].value >= 16, \
|
||||
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
|
||||
assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
|
||||
if _is_cuda(builder.target):
|
||||
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
|
||||
and rhs.shape[1].value >= 16, \
|
||||
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
|
||||
if is_hip():
|
||||
assert lhs.shape[0].value >= 4 and lhs.shape[1].value >= 16 \
|
||||
and rhs.shape[1].value >= 4, \
|
||||
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 4!"
|
||||
|
||||
# hip for now converts fp8 to fp16 for mixed input
|
||||
if is_hip():
|
||||
|
||||
@@ -69,3 +69,26 @@ On some node, I saw the following runtime error
|
||||
```
|
||||
It's hard to reproduce the error. **Needs further investigation**
|
||||
- https://github.com/ROCmSoftwarePlatform/frameworks-internal/issues/6011
|
||||
|
||||
# One config running script
|
||||
|
||||
`one_config.py` is a script that runs one given matmul config.
|
||||
It is a thin interface to the `tune_gemm.py` functionality and can be used for debugging Triton.
|
||||
|
||||
### Usage
|
||||
|
||||
This script supports two methods to specify configuration parameters.
|
||||
|
||||
Variant 1: separate command-line arguments.
|
||||
|
||||
```bash
|
||||
python one_config.py -m 256 -n 256 -k 256 --block_m 64 --block_n 64 --block_k 64 --group_m 1 --split_k 2 --num_warps 2 --num_stages 0 --waves_per_eu 0
|
||||
```
|
||||
|
||||
Variant 2: one-line config description.
|
||||
This is the format in which the `tune_gemm.py` script prints configs.
|
||||
|
||||
```bash
|
||||
python one_config.py --config_str M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0
|
||||
```
|
||||
|
||||
|
||||
78
scripts/amd/gemm/one_config.py
Normal file
78
scripts/amd/gemm/one_config.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Script for running one Matrix Multiplication kernel config at a time
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import sys
|
||||
import tune_gemm
|
||||
|
||||
def parse_args():
    """Parse the command line of the one-config runner.

    Every numeric option is an ``int`` defaulting to 0 and ``--config_str``
    is a string defaulting to "".  Option abbreviations are disabled.

    Returns:
        argparse.Namespace with attributes ``m``, ``n``, ``k``, ``block_m``,
        ``block_n``, ``block_k``, ``group_m``, ``split_k``, ``num_warps``,
        ``num_stages``, ``waves_per_eu`` and ``config_str``.
    """
    parser = argparse.ArgumentParser(
        # The sentence belongs in `description`; `prog` is the program name
        # printed in the usage line and should keep its default (argv[0]).
        description="check correctness of particular config for tuning gemm script",
        allow_abbrev=False,
    )

    int_options = (
        "-m", "-n", "-k",
        "--block_m", "--block_n", "--block_k",
        "--group_m", "--split_k",
        "--num_warps", "--num_stages", "--waves_per_eu",
    )
    for option in int_options:
        parser.add_argument(option, type=int, default=0)

    parser.add_argument("--config_str", type=str, default="", help="can take from tune_gemm.py script output, looks like M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0")
    args = parser.parse_args()

    return args
|
||||
|
||||
|
||||
def parse_config(cfg_str):
    """Convert a compact one-line config description into a config dict.

    Args:
        cfg_str: underscore-separated tokens, each a short alphabetic name
            immediately followed by a decimal integer, e.g.
            ``M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0``.

    Returns:
        dict mapping the long parameter name (``"M"``, ``"BLOCK_SIZE_M"``,
        ``"num_warps"``, ...) to its integer value.

    Raises:
        ValueError: if a token is not of the form ``<name><integer>`` or its
            name is not one of the known short names.
    """
    config_name = {"M": "M",
                   "N": "N",
                   "K": "K",
                   "BM": "BLOCK_SIZE_M",
                   "BN": "BLOCK_SIZE_N",
                   "BK": "BLOCK_SIZE_K",
                   "GM": "GROUP_SIZE_M",
                   "SK": "SPLIT_K",
                   "nW": "num_warps",
                   "nS": "num_stages",
                   "EU": "waves_per_eu",
                   }
    config = {}
    for token in cfg_str.split("_"):
        # fullmatch with '+' quantifiers: the whole token must be
        # letters-then-digits. A search with '*' quantifiers matches anything
        # (even the empty string), which hides malformed input.
        match = re.fullmatch("([a-zA-Z]+)([0-9]+)", token)
        if match is None or match.group(1) not in config_name:
            raise ValueError(
                f"invalid config token '{token}' in '{cfg_str}'; expected one of "
                f"{sorted(config_name)} followed by an integer")
        config[config_name[match.group(1)]] = int(match.group(2))
    return config
|
||||
|
||||
|
||||
def main():
    """Run the correctness check for the single config given on the command line.

    The config comes from ``--config_str`` when that option is non-empty,
    otherwise it is assembled from the separate integer options. The result
    is handed to ``tune_gemm.test_correctness`` with verbose output enabled.
    """
    args = parse_args()
    if not args.config_str:
        # Variant 1: every parameter was passed as its own option.
        config = dict(
            M=args.m,
            N=args.n,
            K=args.k,
            BLOCK_SIZE_M=args.block_m,
            BLOCK_SIZE_N=args.block_n,
            BLOCK_SIZE_K=args.block_k,
            GROUP_SIZE_M=args.group_m,
            SPLIT_K=args.split_k,
            num_warps=args.num_warps,
            num_stages=args.num_stages,
            waves_per_eu=args.waves_per_eu,
        )
    else:
        # Variant 2: compact one-line description.
        config = parse_config(args.config_str)
    tune_gemm.test_correctness(config["M"], config["N"], config["K"], config, verbose=True)


if __name__ == "__main__":
    sys.exit(main())
|
||||
@@ -54,6 +54,12 @@ def prune_configs(M, N, K, configs):
|
||||
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
|
||||
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
|
||||
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
|
||||
if mfma == 4 and BLOCK_SIZE_K < 64:
|
||||
continue
|
||||
# some layouts could not work properly in case
|
||||
# number of elements per thread is less than 1
|
||||
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
|
||||
continue
|
||||
SPLIT_K = config.get("SPLIT_K")
|
||||
GROUP_M = config.get("GROUP_SIZE_M")
|
||||
if BLOCK_SIZE_M < mfma or BLOCK_SIZE_N < mfma:
|
||||
@@ -87,9 +93,12 @@ def need_split_k(SIZE_M, SIZE_N, SIZE_K):
|
||||
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
|
||||
|
||||
|
||||
def run_bash_command(commandstring):
|
||||
proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout = subprocess.PIPE)
|
||||
return proc.stdout.splitlines()
|
||||
def run_bash_command(commandstring, capture=True):
    """Run ``commandstring`` through /bin/bash and optionally capture stdout.

    Args:
        commandstring: shell command line to execute.
        capture: when true, stdout is piped and returned; when false, the
            child's stdout is left untouched and nothing is returned.

    Returns:
        The captured stdout split into lines (bytes) when ``capture`` is
        true, otherwise None.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status (``check=True``).
    """
    common = dict(shell=True, check=True, executable='/bin/bash')
    if not capture:
        subprocess.run(commandstring, **common)
        return None
    completed = subprocess.run(commandstring, stdout=subprocess.PIPE, **common)
    return completed.stdout.splitlines()
|
||||
|
||||
|
||||
def read_config(config):
|
||||
@@ -113,7 +122,7 @@ def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False):
|
||||
#M, K = a.shape
|
||||
#K, N = b.shape
|
||||
grid = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n}), {split_k}
|
||||
print(f'config: matmul_kernel_{configStr}')
|
||||
print(f'config: matmul_kernel_{configStr}', flush=True)
|
||||
if warmup:
|
||||
matmul_kernel_{configStr}.warmup(
|
||||
torch.float16, torch.float16, torch.float16,
|
||||
@@ -129,6 +138,7 @@ def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False):
|
||||
waves_per_eu = {waves_per_eu},
|
||||
grid=(1,)
|
||||
)
|
||||
return None
|
||||
else:
|
||||
matmul_kernel_{configStr}[grid](
|
||||
a, b, c,
|
||||
@@ -143,7 +153,7 @@ def matmul_{configStr}(a, b, c, M, N, K, am, ak, bk, bn, cm, cn, warmup=False):
|
||||
num_stages = {num_stages},
|
||||
waves_per_eu = {waves_per_eu}
|
||||
)
|
||||
return c
|
||||
return c
|
||||
|
||||
def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, dtype):
|
||||
#a = torch.randn((M, K), device='cuda', dtype=dtype)
|
||||
@@ -151,13 +161,20 @@ def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, dtype):
|
||||
#c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
|
||||
try:
|
||||
matmul_{configStr}(None, None, None, M, N, K, am, ak, bk, bn, cm, cn, True)
|
||||
except Exception:
|
||||
print(f'invalid config {configStr}')
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f'invalid config(compilation): {configStr}: ', e, flush=True)
|
||||
return False
|
||||
"""
|
||||
return configStr, matmul_def_str
|
||||
|
||||
## Open {ngpus} files
|
||||
## generated_kernelMNK-0.py, generated_kernelMNK-1.py, ..., generated_kernelMNK-{ngpus-1}.py
|
||||
|
||||
def generated_kernel_name(M, N, K, gpu_id):
    """Return the file name of the generated kernel script for size M-N-K and ``gpu_id``."""
    return "generated_kernel{}-{}-{}-{}.py".format(M, N, K, gpu_id)
|
||||
|
||||
|
||||
## Open {len(gpus)} files
|
||||
## generated_kernelM-N-K-{gpus[0]}.py, generated_kernelM-N-K-{gpus[1]}.py, ..., generated_kernelM-N-K-{gpus[-1]}.py
|
||||
## and generate
|
||||
## 1. matmul kernels of all configs
|
||||
## 2. wrapper function matmul to invoke all the generated kernels
|
||||
@@ -165,10 +182,11 @@ def try_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, dtype):
|
||||
## 4. test_gemm to invoke
|
||||
## 4.1 run try_config in parallel
|
||||
## 4.2 matmul in a loop of 10 iterations
|
||||
def generate_kernel(M, N, K, configs, ngpus):
|
||||
def generate_kernel(M, N, K, configs, gpus):
|
||||
filenames = []
|
||||
for fi in range(ngpus):
|
||||
filenames.append(f"generated_kernel{M}{N}{K}-{fi}.py")
|
||||
ngpus = len(gpus)
|
||||
for gpu_id in gpus:
|
||||
filenames.append(generated_kernel_name(M, N, K, gpu_id))
|
||||
f_kernel = [open(path, 'w') for path in filenames]
|
||||
|
||||
### write imports
|
||||
@@ -185,7 +203,7 @@ import multiprocessing
|
||||
### write definitions of matmul_kernel_xxx
|
||||
### and matmul_xxx and try_config
|
||||
with open("matmul_kernel.py") as file:
|
||||
matmul_kernel_code = file.read();
|
||||
matmul_kernel_code = file.read()
|
||||
idx = 0
|
||||
for config in configs:
|
||||
file_idx = idx % ngpus
|
||||
@@ -211,6 +229,8 @@ import multiprocessing
|
||||
c.stride(0), c.stride(1), dtype)
|
||||
|
||||
if num_threads > 1:
|
||||
results = []
|
||||
config_names = []
|
||||
"""
|
||||
for fi in range(ngpus):
|
||||
f_kernel[fi].write(test_gemm_pre_str + "\n")
|
||||
@@ -219,23 +239,40 @@ import multiprocessing
|
||||
idx = 0
|
||||
for config in configs:
|
||||
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
|
||||
task_str = f" thread_pool.apply_async(try_config_{configStr}, args=task_args)\n"
|
||||
task_str = f" results += [thread_pool.apply_async(try_config_{configStr}, args=task_args)]\n" + \
|
||||
f" config_names += ['{configStr}']\n"
|
||||
f_kernel[idx % ngpus].write(task_str)
|
||||
idx += 1
|
||||
|
||||
threadpool_str = """
|
||||
for fi in range(ngpus):
|
||||
threadpool_str = """
|
||||
failed_configs = []
|
||||
for i in range(len(results)):
|
||||
results[i].wait()
|
||||
res = results[i].get()
|
||||
if not res:
|
||||
failed_configs += [config_names[i]]
|
||||
thread_pool.close()
|
||||
thread_pool.join()
|
||||
else:"""
|
||||
for fi in range(ngpus):
|
||||
with open("{filename}.failed_configs", "w") as f:
|
||||
for cfg in failed_configs:
|
||||
f.write(cfg + "\\n")
|
||||
else:
|
||||
try:
|
||||
with open("{filename}.failed_configs", "r") as f:
|
||||
failed_configs = [cfg.strip() for cfg in f.readlines()]
|
||||
except Exception:
|
||||
failed_configs = []
|
||||
""".format(filename = filenames[fi])
|
||||
f_kernel[fi].write(threadpool_str)
|
||||
# call all matmul_xxx functions
|
||||
idx = 0
|
||||
for config in configs:
|
||||
configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, config)
|
||||
matmul_call_str = f"""
|
||||
for i in range(10):
|
||||
d = matmul_{configStr}(a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))"""
|
||||
if '{configStr}' not in failed_configs:
|
||||
for i in range(10):
|
||||
d = matmul_{configStr}(a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))"""
|
||||
f_kernel[idx % ngpus].write(matmul_call_str + "\n")
|
||||
idx += 1
|
||||
# post string
|
||||
@@ -267,30 +304,29 @@ def extract_kernel_time(M, N, K, config, gpuid):
|
||||
return config, parsed_outputs
|
||||
|
||||
|
||||
def profile_batch_kernels(M, N, K, gpuid):
|
||||
def profile_batch_kernels(M, N, K, gpuid, verbose):
|
||||
os.environ['ROCR_VISIBLE_DEVICES'] = str(gpuid)
|
||||
run_bash_command(f"rocprof --stats -o results-{gpuid}.csv python generated_kernel{M}{N}{K}-{gpuid}.py")
|
||||
run_bash_command(f"rocprof --stats -o results-{gpuid}.csv python {generated_kernel_name(M, N, K, gpuid)}", capture=(verbose < 2))
|
||||
|
||||
|
||||
def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1):
|
||||
def tune_gemm_config(M, N, K, configs, verbose=0, num_threads=16, gpus = [0]):
|
||||
## Generate kernel out of all configs
|
||||
generate_kernel(M, N, K, configs, ngpus)
|
||||
generate_kernel(M, N, K, configs, gpus)
|
||||
|
||||
## remove any compiled kernel in the cache
|
||||
run_bash_command("rm -rf ~/.triton/cache")
|
||||
|
||||
## precompile the kernels in parallel
|
||||
## TODO: parameterize numThreads at this level
|
||||
start_time = datetime.now()
|
||||
for fi in range(ngpus):
|
||||
run_bash_command(f"python generated_kernel{M}{N}{K}-{fi}.py -n 32")
|
||||
for gpu_id in gpus:
|
||||
run_bash_command(f"python {generated_kernel_name(M, N, K, gpu_id)} -n {num_threads}", capture=(verbose < 2))
|
||||
compile_end = datetime.now()
|
||||
compile_time = compile_end - start_time
|
||||
if verbose:
|
||||
print(f"compile time: {compile_time}")
|
||||
print(f"compile time: {compile_time}", flush=True)
|
||||
|
||||
## profile generated kernels
|
||||
running = [multiprocessing.Process(target=profile_batch_kernels, args=(M,N,K,fi)) for fi in range(ngpus)]
|
||||
running = [multiprocessing.Process(target=profile_batch_kernels, args=(M,N,K,gpu_id,verbose)) for gpu_id in gpus]
|
||||
for p in running:
|
||||
p.start()
|
||||
for p in running:
|
||||
@@ -299,7 +335,7 @@ def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1)
|
||||
profile_end = datetime.now()
|
||||
profile_time = profile_end - compile_end
|
||||
if verbose:
|
||||
print(f"profile time: {profile_time}")
|
||||
print(f"profile time: {profile_time}", flush=True)
|
||||
|
||||
## post process results.csv to get the best config and minTime
|
||||
## TODO: process the file in parallel
|
||||
@@ -308,8 +344,9 @@ def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1)
|
||||
tasks = []
|
||||
idx = 0
|
||||
for config in configs:
|
||||
file_idx = idx % ngpus
|
||||
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, file_idx))]
|
||||
file_idx = idx % len(gpus)
|
||||
gpu_id = gpus[file_idx]
|
||||
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, gpu_id))]
|
||||
idx += 1
|
||||
thread_pool.close()
|
||||
thread_pool.join()
|
||||
@@ -323,11 +360,11 @@ def tune_gemm_config(M, N, K, configs, verbose=False, num_threads=16, ngpus = 1)
|
||||
bestConfig = config
|
||||
else:
|
||||
min_us = -1
|
||||
print(f"invalid config: SIZE {M} {N} {K}: {config}")
|
||||
print(f"invalid config(post processing): SIZE {M} {N} {K}: {config}", flush=True)
|
||||
post_end = datetime.now()
|
||||
post_time = post_end - profile_end
|
||||
if verbose:
|
||||
print(f"post procesing time: {post_time}")
|
||||
print(f"post procesing time: {post_time}", flush=True)
|
||||
return minTime, bestConfig, compile_time, profile_time, post_time
|
||||
|
||||
|
||||
@@ -379,9 +416,9 @@ def test_correctness(M, N, K, config, verbose, datatype = torch.float16):
|
||||
if verbose:
|
||||
size_str = f'SIZE M: {M}, N: {N}, K: {K} '
|
||||
if torch.allclose(triton_output, torch_output, atol=1e-1, rtol=rtol):
|
||||
print(f'{size_str}✅')
|
||||
print(f'{size_str} Correct✅')
|
||||
else:
|
||||
print(f'{size_str}❌')
|
||||
print(f'{size_str} Incorrect❌')
|
||||
|
||||
|
||||
def get_default_tuning_result_filename():
|
||||
@@ -404,13 +441,16 @@ def parse_args():
|
||||
parser.add_argument("-m", type=int, default=0)
|
||||
parser.add_argument("-n", type=int, default=0)
|
||||
parser.add_argument("-k", type=int, default=0)
|
||||
parser.add_argument("--ngpus", type=int, default=1, help='number of GPUs used in the profiling step')
|
||||
parser.add_argument("--ngpus", type=int, default=0, help='number of GPUs used in the profiling step')
|
||||
parser.add_argument("--gpu_ids", type=lambda s: [int(id) for id in s.split(',')], default=[], help='list of gpu ids to use for tuning')
|
||||
parser.add_argument("--gemm_size_file", type=str, default="", help='yaml file to indicate matrix size')
|
||||
parser.add_argument("--tuning_results_file", type=str, default=get_default_tuning_result_filename(), help='yaml file to store tuning results')
|
||||
parser.add_argument("--keep", action='store_true', default=False, help='keep generated files')
|
||||
parser.add_argument("--compare", action='store_true', default=False, help="Whether check result correctness")
|
||||
parser.add_argument("--compare_wo_tuning", action='store_true', default=False, help="Whether check result correctness")
|
||||
parser.add_argument("--time_breakdown", action='store_true', default=False, help="Show detailed time breakdown of each step during the tuning")
|
||||
parser.add_argument("--verbose", action='store_true', default=False, help="enables time_breakdown and additional logging messages")
|
||||
parser.add_argument("--num_threads", type=int, default=16, help="number of threads to use for kernel compilation and post processing")
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
@@ -422,6 +462,16 @@ def main():
|
||||
tuning_output_file = args.tuning_results_file
|
||||
keepTmp = args.keep
|
||||
ngpus = args.ngpus
|
||||
gpu_ids = args.gpu_ids
|
||||
if ngpus != 0 and gpu_ids:
|
||||
print("--ngpus and --gpu_ids are mutually exclusive options")
|
||||
return os.EX_USAGE
|
||||
if ngpus == 0 and not gpu_ids:
|
||||
ngpus = 1
|
||||
if ngpus != 0:
|
||||
gpus = range(ngpus)
|
||||
if gpu_ids:
|
||||
gpus = gpu_ids
|
||||
|
||||
mnks = []
|
||||
## TODO: make it more robust to get user input
|
||||
@@ -454,7 +504,7 @@ def main():
|
||||
configs_full = get_full_tuning_space()
|
||||
|
||||
start_time = datetime.now()
|
||||
print(f"Tuning starts at: {start_time}")
|
||||
print(f"Tuning starts at: {start_time}", flush=True)
|
||||
|
||||
f_results = open(tuning_output_file, 'w')
|
||||
for (M, N, K) in mnks:
|
||||
@@ -466,7 +516,12 @@ def main():
|
||||
print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True)
|
||||
|
||||
## The main tuning function for one gemm size
|
||||
minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(M, N, K, pruned_configs, ngpus = ngpus, verbose=args.time_breakdown)
|
||||
verbose_level = 0
|
||||
if args.time_breakdown:
|
||||
verbose_level = 1
|
||||
if args.verbose:
|
||||
verbose_level = 2
|
||||
minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config(M, N, K, pruned_configs, num_threads=args.num_threads, gpus = gpus, verbose=verbose_level)
|
||||
|
||||
## post processing the numbers
|
||||
perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6)
|
||||
@@ -475,10 +530,10 @@ def main():
|
||||
formatted_tflops = "{:.3e}".format(tri_tflops)
|
||||
else:
|
||||
formatted_tflops = "{:.2f}".format(tri_tflops)
|
||||
print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ")
|
||||
print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ", flush=True)
|
||||
|
||||
bestConfig_compact_str, _ = gen_kernel_and_configStr_from_config(M, N, K, bestConfig)
|
||||
print(f'best_config: {bestConfig_compact_str}', end=" ")
|
||||
print(f'best_config: {bestConfig_compact_str}', end=" ", flush=True)
|
||||
|
||||
## write best config to tuning_results.yaml
|
||||
sizeDict = {'M': M, 'N': N, 'K': K}
|
||||
@@ -488,20 +543,22 @@ def main():
|
||||
|
||||
## remove generated files if asked to
|
||||
if not keepTmp:
|
||||
for fi in range(ngpus):
|
||||
os.remove(f"generated_kernel{M}{N}{K}-{fi}.py")
|
||||
for f in glob.glob(f"results-{fi}.*"):
|
||||
for gpu_id in gpus:
|
||||
generated_script = generated_kernel_name(M, N, K, gpu_id)
|
||||
os.remove(generated_script)
|
||||
os.remove(generated_script + ".failed_configs")
|
||||
for f in glob.glob(f"results-{gpu_id}.*"):
|
||||
os.remove(f)
|
||||
|
||||
## Check correctness if asked to
|
||||
if args.compare:
|
||||
print("correctness: ", end=" ")
|
||||
print("correctness: ", end=" ", flush=True)
|
||||
test_correctness(M, N, K, bestConfig, False)
|
||||
else:
|
||||
print("")
|
||||
print("", flush=True)
|
||||
|
||||
end_local_time = datetime.now()
|
||||
print(f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)")
|
||||
print(f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)", flush=True)
|
||||
|
||||
f_results.close()
|
||||
|
||||
|
||||
350
scripts/amd/plot_layout.py
Executable file
350
scripts/amd/plot_layout.py
Executable file
@@ -0,0 +1,350 @@
|
||||
import argparse
|
||||
import sys
|
||||
import yaml
|
||||
import os
|
||||
import glob
|
||||
import subprocess
|
||||
|
||||
|
||||
def draw_preamble_cmd():
    """Return the LaTeX preamble shared by every generated plot.

    The preamble selects the ``standalone`` document class, loads TikZ and the
    TikZ libraries it names, and defines an expandable ``\\bitwiseXor`` command
    with expl3: both operands are converted to binary, left-padded with zeros
    to equal length, and combined digit by digit via ``|a - b|``.
    """
    # Raw string: backslashes are LaTeX control sequences, not Python escapes.
    return r'''\documentclass[tikz, border=1mm, dvipsnames]{standalone}
\usepackage{ifthen}
\usepackage{tikz}
\usetikzlibrary{arrows.meta,arrows}
\usetikzlibrary{intersections}
\usetikzlibrary{calc, quotes}
\usetikzlibrary{patterns}
\usepackage{xparse}

\ExplSyntaxOn
\NewExpandableDocumentCommand{\bitwiseXor}{mm}
{
\recuenco_bitwise_xor:nn { #1 } { #2 }
}

\cs_new:Nn \recuenco_bitwise_xor:nn
{
\int_from_bin:e
{
\__recuenco_bitwise_xor:ee { \int_to_bin:n { #1 } } { \int_to_bin:n { #2 } }
}
}
\cs_generate_variant:Nn \int_from_bin:n { e }

\cs_new:Nn \__recuenco_bitwise_xor:nn
{
\__recuenco_bitwise_xor_binary:ee
{
\prg_replicate:nn
{
\int_max:nn { \tl_count:n { #1 } } { \tl_count:n { #2 } } - \tl_count:n { #1 }
}
{ 0 }
#1
}
{
\prg_replicate:nn
{
\int_max:nn { \tl_count:n { #1 } } { \tl_count:n { #2 } } - \tl_count:n { #2 }
}
{ 0 }
#2
}
}
\cs_generate_variant:Nn \__recuenco_bitwise_xor:nn { ee }

\cs_new:Nn \__recuenco_bitwise_xor_binary:nn
{
\__recuenco_bitwise_xor_binary:w #1;#2;
}
\cs_generate_variant:Nn \__recuenco_bitwise_xor_binary:nn { ee }

\cs_new:Npn \__recuenco_bitwise_xor_binary:w #1#2;#3#4;
{
\int_abs:n { #1-#3 }
\tl_if_empty:nF { #2 } { \__recuenco_bitwise_xor_binary:w #2;#4; }
}

\ExplSyntaxOff'''
|
||||
|
||||
|
||||
def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kpack):
    """Return the LaTeX document body that draws a dot operation and one MFMA.

    The text first invokes ``\\drawDot`` with the tensor sizes, the MFMA
    dimension, both ``warpsPerCTA`` entries, ``trans`` and ``kpack``. It then
    moves the ``C TL`` coordinate to the right, swaps the operand colours when
    ``trans`` is non-zero, and invokes ``\\drawMFMAInstr`` for a zoomed-in view.
    Both macros are defined outside this function (presumably in
    tikzplot.tex, which the caller prepends -- confirm).

    Args:
        M, N, K: tensor sizes substituted into the macro calls.
        mfmaNonKDim: MFMA instruction dimension.
        warpsPerCTA: indexable with at least two entries.
        trans: 0 or 1; written to ``\\mfmaTrans``.
        kpack: vector length, used in the macro calls and the x offset.

    Returns:
        str: LaTeX source from ``\\begin{document}`` to ``\\end{document}``.
    """
    # Every LaTeX backslash is doubled, including the two \elem in the first
    # coordinate shift: a bare "\e" is an invalid Python escape sequence.
    return f'''\\begin{{document}}
\\begin{{tikzpicture}}
\\def\\scale{{1}}
\\def\\elem{{0.04}}
\\coordinate (C TL) at (0,0);
\\def\\opColorAL{{magenta}}
\\def\\opColorAR{{cyan}}
\\def\\opColorBL{{Maroon}}
\\def\\opColorBR{{BlueGreen}}
\\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}}{{{kpack}}}

\\coordinate (C TL) at ($(C TL)+({N}*\\elem+32*\\elem, 0)$);
\\def\\mfmaTrans{{{trans}}}
\\ifthenelse{{\\mfmaTrans=0}}{{
\\def\\opColorAL{{magenta}}
\\def\\opColorAR{{cyan}}
\\def\\opColorBL{{Maroon}}
\\def\\opColorBR{{BlueGreen}}
}}{{
\\def\\opColorBL{{magenta}}
\\def\\opColorBR{{cyan}}
\\def\\opColorAL{{Maroon}}
\\def\\opColorAR{{BlueGreen}}
}}
%% Draw zoomed in view of mfma
\\def\\elem{{.16}}
\\pgfmathsetmacro{{\\gap}}{{\\elem*5}}
\\pgfmathsetmacro{{\\nonTrans}}{{1-\\mfmaTrans}}
\\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+2*{kpack}*\\elem, 0)$);
\\drawMFMAInstr{{{mfmaNonKDim}}}{{{kpack}}}{{\\mfmaTrans}}

\\end{{tikzpicture}}
\\end{{document}}'''
|
||||
|
||||
|
||||
def draw_blocked_layout_cmd(M, K, sizePerThread, threadsPerWarp, warpsPerCTA,
                            order):
    """Return the LaTeX document body that draws a tensor with blocked layout.

    The body consists of a single ``\\drawBlockedTensor`` call inside a
    tikzpicture. The macro receives, in this order: M, K, both
    ``sizePerThread`` entries, ``threadsPerWarp[0]``, both ``warpsPerCTA``
    entries and ``order[0]``. ``threadsPerWarp[1]`` and ``order[1]`` are not
    passed on.
    """
    macro_args = (
        M, K,
        sizePerThread[0], sizePerThread[1],
        threadsPerWarp[0],
        warpsPerCTA[0], warpsPerCTA[1],
        order[0],
    )
    tensor_call = "\\drawBlockedTensor" + "".join(
        "{{{}}}".format(value) for value in macro_args)
    tex_lines = [
        "\\begin{document}",
        "\\begin{tikzpicture}",
        "\\def\\scale{1}",
        "\\def\\elem{0.06}",
        "\\coordinate (TL) at (0,0);",
        tensor_call,
        "\\end{tikzpicture}",
        "\\end{document}",
    ]
    return "\n".join(tex_lines)
|
||||
|
||||
|
||||
def draw_lds_access_cmd(M, K, kpack, ldsLayout, ldsAccess, sizePerThread,
                        threadsPerWarp):
    """Return the LaTeX document body that draws an LDS access pattern.

    ``ldsLayout`` is encoded as ``\\hasSwizzle``: 'swizzle' -> 1,
    'padding' -> 2, anything else -> 0. ``ldsAccess`` is encoded as
    ``\\accessMode``: 'read' -> 1, 'write' -> 2, anything else -> 0.
    The remaining arguments are written into ``\\def`` lines that the
    ``\\drawTensorLayoutGlobalMem`` and ``\\drawLDSLayoutTritonSwizzling``
    macros (defined elsewhere) presumably read.
    """
    swizzle_codes = {'swizzle': 1, 'padding': 2}
    access_codes = {'read': 1, 'write': 2}
    hasSwizzle = swizzle_codes.get(ldsLayout, 0)
    accessMode = access_codes.get(ldsAccess, 0)

    size_k, size_m = sizePerThread[1], sizePerThread[0]
    threads_k = threadsPerWarp[1]

    return f'''\\begin{{document}}
\\begin{{tikzpicture}}
\\def\\scale{{1}}
\\def\\M{{{M}}}
\\def\\K{{{K}}}
\\def\\vec{{{kpack}}}
\\def\\hasSwizzle{{{hasSwizzle}}}
\\def\\accessMode{{{accessMode}}}

\\def\\sizePerThreadK{{{size_k}}}
\\def\\sizePerThreadM{{{size_m}}}
\\def\\threadsPerWarpK{{{threads_k}}}

\\def\\elem{{0.18}}
\\coordinate (TL) at (0,0);
\\drawTensorLayoutGlobalMem
\\coordinate (TL) at ($(TL)+(0, -24*\\elem-10*\\elem)$);
\\drawLDSLayoutTritonSwizzling{{\\hasSwizzle}}{{\\accessMode}}
\\end{{tikzpicture}}
\\end{{document}}'''
|
||||
|
||||
|
||||
def draw_wmma_instr_cmd(waveSize):
    """Return the LaTeX document body that draws one WMMA instruction.

    The first argument of ``\\drawWMMAInstr`` is 0 when ``waveSize`` equals 32
    and 1 for any other value; the second argument is always 1.
    """
    if waveSize == 32:
        wmma_mode = 0
    else:
        wmma_mode = 1
    tex_lines = (
        "\\begin{document}",
        "\\begin{tikzpicture}",
        "\\def\\scale{1}",
        "\\coordinate (C TL) at (0,0);",
        "\\def\\elem{0.25}",
        "\\drawWMMAInstr{" + str(wmma_mode) + "}{1}",
        "\\end{tikzpicture}",
        "\\end{document}",
    )
    return "\n".join(tex_lines)
|
||||
|
||||
|
||||
def run_bash_command(commandstring):
    """Run ``commandstring`` through /bin/bash and return its stdout lines.

    Returns:
        list of bytes objects, one per line of the command's stdout.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    # check_output pipes stdout and checks the exit status, like
    # run(..., check=True, stdout=PIPE).stdout.
    captured = subprocess.check_output(commandstring,
                                       shell=True,
                                       executable='/bin/bash')
    return captured.splitlines()
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command line of the layout plotting script.

    Option groups:
      * ``-shape`` / ``-plot`` / ``-nonKDim``: tensor shape, plot mode and
        mfma instruction dimension;
      * ``-sizePerThread`` / ``-threadsPerWarp`` / ``-warpsPerCTA`` /
        ``-order``: blocked layout parameters, two integers each;
      * ``-kpack`` / ``-lds_layout`` / ``-lds_access``: LDS access parameters;
      * ``-wave_size``: wmma instruction mode;
      * ``-o`` / ``-mfmaTrans`` / ``--keep``: output name and switches.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    # "Draw triton layouts" describes the tool; as ``prog`` it would replace
    # the script name in the usage line.
    parser = argparse.ArgumentParser(
        description="Draw triton layouts",
        allow_abbrev=False,
    )
    ## tensor shapes
    # nargs=3 consumes three separate arguments, so the help shows them
    # space-separated rather than comma-separated.
    parser.add_argument("-shape",
                        type=int,
                        nargs=3,
                        default=(32, 128, 64),
                        help='Tensor shape given as three space-separated integers: M N K')
    parser.add_argument("-plot",
                        type=str,
                        default="blocked",
                        choices=['blocked', 'dot', 'wmma', 'lds'],
                        help='choose plot mode')
    parser.add_argument(
        "-nonKDim",
        type=int,
        default=32,
        choices=[32],
        help='mfma instruction dim, only 32 is supported for now')
    ## blocked layout parameters
    parser.add_argument("-sizePerThread", type=int, nargs=2, default=(1, 4))
    parser.add_argument("-threadsPerWarp", type=int, nargs=2, default=(16, 4))
    parser.add_argument("-warpsPerCTA", type=int, nargs=2, default=(1, 4))
    parser.add_argument("-order", type=int, nargs=2, default=(1, 0))
    ## LDS access parameters
    parser.add_argument("-kpack",
                        type=int,
                        default=4,
                        choices=[4, 8],
                        help='vector length during LDS load, same as vec')
    parser.add_argument("-lds_layout",
                        type=str,
                        default="none",
                        choices=['swizzle', 'padding', 'none'],
                        help='choose the LDS data layout')
    parser.add_argument("-lds_access",
                        type=str,
                        default="none",
                        choices=['read', 'write', 'none'],
                        help='choose LDS access mode')
    ## wmma instruction layout parameter
    parser.add_argument("-wave_size",
                        type=int,
                        default=32,
                        choices=[32, 64],
                        help='choose the wmma instruction mode')

    parser.add_argument("-o",
                        type=str,
                        default="myplot",
                        help='output pdf file name (without suffix)')
    parser.add_argument("-mfmaTrans",
                        action='store_true',
                        default=False,
                        help='If set, then use mfma.trans layout')
    parser.add_argument("--keep",
                        action='store_true',
                        default=False,
                        help='If set, keep the generated .tex file')

    args = parser.parse_args()

    return args
|
||||
|
||||
|
||||
def main():
    """Generate ``myplot.tex`` for the requested plot mode and compile it to PDF.

    Steps:
      1. read the options and print a summary of the selected mode;
      2. for 'blocked' and 'dot', check the tensor dimensions against the
         CTA shape with assertions;
      3. write the preamble, the contents of ``tikzplot.tex`` and the
         mode-specific document body to ``myplot.tex``;
      4. run ``pdflatex`` with the job name given by ``-o``;
      5. delete the ``.aux`` / ``.log`` files and, unless ``--keep`` is set,
         ``myplot.tex`` and the ``./auto`` directory.
    """
    args = parse_args()

    M, N, K = args.shape
    plot_mode = args.plot
    mfmaNonKDim = args.nonKDim
    kpack = args.kpack
    trans = 1 if args.mfmaTrans else 0
    ofilename = args.o
    keepSrc = args.keep

    ldsLayout = args.lds_layout
    ldsAccess = args.lds_access
    waveSize = args.wave_size

    sizePerThread = args.sizePerThread
    threadsPerWarp = args.threadsPerWarp
    warpsPerCTA = args.warpsPerCTA
    order = args.order

    # The plot modes are mutually exclusive, so each one reports and
    # validates its own parameters in a single branch.
    if plot_mode == 'blocked':
        print(f"Plotting tensor M={M},K={K} with blocked layout:")
        print(f"sizePerThread={sizePerThread}", end=" ")
        print(f"threadsPerWarp={threadsPerWarp}", end=" ")
        print(f"warpsPerCTA={warpsPerCTA}", end=" ")
        print(f"order={order}", end=" ")
        CTAShape = [
            sizePerThread[0] * threadsPerWarp[0] * warpsPerCTA[0],
            sizePerThread[1] * threadsPerWarp[1] * warpsPerCTA[1],
        ]
        print(f"CTAShape={CTAShape}")
        assert M != 0 and CTAShape[
            0] <= M and M % CTAShape[0] == 0, "bad tensor dimension M"
        assert K != 0 and CTAShape[
            1] <= K and K % CTAShape[1] == 0, "bad tensor dimension K"
    elif plot_mode == 'dot':
        if mfmaNonKDim == 32:
            mfma_inst_str = "mfma_32x32x8f16"
        else:
            mfma_inst_str = "mfma_16x16x16f16"
        mfma_trans_str = ".trans" if trans else ""
        print(f"Plotting dot operation with shapes M={M},N={N},K={K}")
        print("MFMA: " + mfma_inst_str + mfma_trans_str, end=" ")
        print(f"warpsPerCTA={warpsPerCTA}", end=" ")
        CTAShape = [32 * warpsPerCTA[0], 32 * warpsPerCTA[1]]
        print(f"CTAShape={CTAShape}")
        assert M != 0 and CTAShape[
            0] <= M and M % CTAShape[0] == 0, "bad tensor dimension M"
        assert N != 0 and CTAShape[
            1] <= N and N % CTAShape[1] == 0, "bad tensor dimension N"
        assert K != 0 and K % (2 * kpack) == 0, "bad tensor dimension K"
    elif plot_mode == 'lds':
        print(f"Plotting LDS access for tensor M={M},K={K} with vec={kpack}")
        if ldsAccess == 'write':
            print(
                f"sizePerThread={sizePerThread}, threadsPerWarp={threadsPerWarp}"
            )

    # One builder per plot mode; only the selected one is evaluated.
    body_builders = {
        'blocked':
        lambda: draw_blocked_layout_cmd(M, K, sizePerThread, threadsPerWarp,
                                        warpsPerCTA, order),
        'dot':
        lambda: draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans,
                                    kpack),
        'lds':
        lambda: draw_lds_access_cmd(M, K, kpack, ldsLayout, ldsAccess,
                                    sizePerThread, threadsPerWarp),
        'wmma':
        lambda: draw_wmma_instr_cmd(waveSize),
    }

    with open("myplot.tex", 'w') as tex_out:
        with open("tikzplot.tex") as tikz_src:
            tikz_code = tikz_src.read()

        tex_out.write(draw_preamble_cmd() + "\n")
        tex_out.write(tikz_code)
        build_body = body_builders.get(plot_mode)
        if build_body is not None:
            tex_out.write(build_body())

    # myplot.tex is closed (and therefore flushed) before pdflatex reads it.
    run_bash_command(f"pdflatex -jobname {ofilename} myplot.tex")
    print(f"plot saved in {ofilename}.pdf")

    ## Remove the auxiliary files written by pdflatex
    for suffix in ("aux", "log"):
        os.remove(f"{ofilename}.{suffix}")
    if not keepSrc:
        os.remove("myplot.tex")
        run_bash_command("rm -rf ./auto")


if __name__ == '__main__':
    sys.exit(main())
|
||||
874
scripts/amd/tikzplot.tex
Executable file
874
scripts/amd/tikzplot.tex
Executable file
@@ -0,0 +1,874 @@
|
||||
\newcommand{\drawBlockedWave}[5]{
|
||||
%%
|
||||
%% Draw a wave coverage with blocked layout
|
||||
%%
|
||||
%% Wave TL: pre defined top-left coordinate of the wave
|
||||
%% \elem: pre defined variable
|
||||
%%
|
||||
%% #1: sizePerThread[0] --> sizePerThreadM
|
||||
%% #2: sizePerThread[1] --> sizePerThreadN
|
||||
%% #3: threadsPerWarp[0] --> threadsPerWarpM
|
||||
%% #4: threadsPerWarp[1] --> threadsPerWarpN
|
||||
%% #5: fastest changing dim --> order
|
||||
|
||||
\pgfmathsetmacro{\sizePerThreadM}{#1}
|
||||
\pgfmathsetmacro{\sizePerThreadN}{#2}
|
||||
\pgfmathsetmacro{\threadsPerWarpM}{#3}
|
||||
\pgfmathsetmacro{\threadsPerWarpN}{#4}
|
||||
\pgfmathsetmacro{\order}{#5}
|
||||
|
||||
\pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM}
|
||||
\pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN}
|
||||
|
||||
\foreach \tid in {0,...,63}{
|
||||
\pgfmathsetmacro{\tidM}{int(\tid/\threadsPerWarpN)}
|
||||
\pgfmathsetmacro{\tidN}{mod(\tid,\threadsPerWarpN)}
|
||||
\coordinate (Thread TL) at ($(Wave TL)+(\tidN*\sizePerThreadN*\elem, -\tidM*\sizePerThreadM*\elem)$);
|
||||
\pgfmathsetmacro{\ratio}{\tidM*10}
|
||||
|
||||
\ifthenelse{\tid = 0}{
|
||||
\draw [line width = 0.01mm, fill=red] (Thread TL)
|
||||
rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem);
|
||||
}{
|
||||
\draw [line width = 0.01mm, fill=blue!\ratio!white] (Thread TL)
|
||||
rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem);
|
||||
}
|
||||
}
|
||||
\draw (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem);
|
||||
}
|
||||
|
||||
\newcommand{\drawBlockedCTA}[7]{
|
||||
%%
|
||||
%% Draw a CTA coverage with blocked layout
|
||||
%%
|
||||
%% CTA TL: pre defined top-left coordinate of the CTA
|
||||
%% \elem: pre defined variable
|
||||
%%
|
||||
%% #1: sizePerThread[0] --> sizePerThreadM
|
||||
%% #2: sizePerThread[1] --> sizePerThreadN
|
||||
%% #3: threadsPerWarp[0] --> threadsPerWarpM
|
||||
%% #4: threadsPerWarp[1] --> threadsPerWarpN
|
||||
%% #5: warpsPerCTA[0] --> warpsPerCTAM
|
||||
%% #6: warpsPerCTA[1] --> warpsPerCTAN
|
||||
%% #7: fastest changing dim --> order
|
||||
|
||||
\pgfmathsetmacro{\sizePerThreadM}{#1}
|
||||
\pgfmathsetmacro{\sizePerThreadN}{#2}
|
||||
\pgfmathsetmacro{\threadsPerWarpM}{#3}
|
||||
\pgfmathsetmacro{\threadsPerWarpN}{#4}
|
||||
\pgfmathsetmacro{\warpsPerCTAM}{#5}
|
||||
\pgfmathsetmacro{\warpsPerCTAN}{#6}
|
||||
\pgfmathsetmacro{\order}{#7}
|
||||
|
||||
\pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM}
|
||||
\pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN}
|
||||
\pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM}
|
||||
\pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN}
|
||||
|
||||
\pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM*\warpsPerCTAN-1}
|
||||
|
||||
\coordinate (Wave TL) at (CTA TL);
|
||||
\drawBlockedWave{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\order}
|
||||
\foreach \waveId in {0,...,\maxWaveId}{
|
||||
\ifthenelse{\order=1}
|
||||
{
|
||||
\pgfmathsetmacro{\waveCoordM}{int(\waveId/\warpsPerCTAN)}
|
||||
\pgfmathsetmacro{\waveCoordN}{mod(\waveId,\warpsPerCTAN)}
|
||||
\pgfmathsetmacro{\rot}{0}
|
||||
}{
|
||||
\pgfmathsetmacro{\waveCoordM}{mod(\waveId,\warpsPerCTAM)}
|
||||
\pgfmathsetmacro{\waveCoordN}{int(\waveId/\warpsPerCTAM)}
|
||||
\pgfmathsetmacro{\rot}{90}
|
||||
}
|
||||
|
||||
\coordinate (Wave TL) at ($(CTA TL)+(\waveCoordN*\waveSizeN*\elem, -\waveCoordM*\waveSizeM*\elem)$);
|
||||
\draw [ultra thin] (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem)
|
||||
node [pos=.5, scale=.6*\scale, inner sep=0, fill=white, rotate=\rot] {wave\waveId};
|
||||
}
|
||||
|
||||
\draw [thick] (CTA TL) rectangle ++(\CTASizeN*\elem, -\CTASizeM*\elem);
|
||||
}
|
||||
|
||||
\newcommand{\drawBlockedTensor}[8]{
|
||||
%%
|
||||
%% Draw a tensor with blocked layout of the following parameters
|
||||
%% sizePerThread[2]
|
||||
%% threadsPerWarp[2]
|
||||
%% warpsPerCTA[2]
|
||||
%% order[2]
|
||||
%%
|
||||
%% TL: pre defined top-left coordinate of the tensor
|
||||
%% \elem: pre defined variable
|
||||
%%
|
||||
%% #1: tensorShape[0] --> M
|
||||
%% #2: tensorShape[1] --> N
|
||||
%% #3: sizePerThread[0] --> sizePerThreadM
|
||||
%% #4: sizePerThread[1] --> sizePerThreadN
|
||||
%% #5: threadsPerWarp[0] --> threadsPerWarpM
|
||||
%% Note that threadsPerWarp[1] is calculated by 64/threadsPerWarp[0]
|
||||
%% #6: warpsPerCTA[0] --> warpsPerCTAM
|
||||
%% #7: warpsPerCTA[1] --> warpsPerCTAN
|
||||
%% #8: fastest changing dim --> order
|
||||
|
||||
\pgfmathsetmacro{\M}{#1}
|
||||
\pgfmathsetmacro{\N}{#2}
|
||||
\pgfmathsetmacro{\sizePerThreadM}{#3}
|
||||
\pgfmathsetmacro{\sizePerThreadN}{#4}
|
||||
\pgfmathsetmacro{\threadsPerWarpM}{#5}
|
||||
\pgfmathsetmacro{\warpsPerCTAM}{#6}
|
||||
\pgfmathsetmacro{\warpsPerCTAN}{#7}
|
||||
\pgfmathsetmacro{\order}{#8}
|
||||
|
||||
\pgfmathsetmacro{\threadsPerWarpN}{64/\threadsPerWarpM}
|
||||
\pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM}
|
||||
\pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN}
|
||||
\pgfmathsetmacro{\CTARepM}{\M/\CTASizeM}
|
||||
\pgfmathsetmacro{\CTARepN}{\N/\CTASizeN}
|
||||
\pgfmathsetmacro{\maxCTAId}{\CTARepM*\CTARepN-1}
|
||||
|
||||
\foreach \ctaId in {0,...,\maxCTAId}{
|
||||
\pgfmathsetmacro{\ctaCoordM}{int(\ctaId/\CTARepN)}
|
||||
\pgfmathsetmacro{\ctaCoordN}{mod(\ctaId,\CTARepN)}
|
||||
\coordinate (CTA TL) at ($(TL)+(\ctaCoordN*\CTASizeN*\elem, -\ctaCoordM*\CTASizeM*\elem)$);
|
||||
\drawBlockedCTA{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\warpsPerCTAM}{\warpsPerCTAN}{\order}
|
||||
}
|
||||
|
||||
\node [scale=.7*\scale, above, rotate=90] at ($(TL)+(0, -.5*\M*\elem)$) {M=\M};
|
||||
\node [scale=.7*\scale, above] at ($(TL)+(.5*\N*\elem, 0)$) {K=\N};
|
||||
|
||||
\def\zoomR{1.5}
|
||||
\coordinate (zoomin BL) at ($(TL)+(0, .3)$);
|
||||
|
||||
\foreach \hl in {0,...,\sizePerThreadM}{
|
||||
\draw ($(zoomin BL)+(0, \hl*\elem*\zoomR)$) -- ++(\sizePerThreadN*\elem*\zoomR,0);
|
||||
}
|
||||
\foreach \vl in {0,...,\sizePerThreadN}{
|
||||
\draw ($(zoomin BL)+(\vl*\elem*\zoomR, 0)$) -- ++(0, \sizePerThreadM*\elem*\zoomR);
|
||||
}
|
||||
|
||||
\node [scale=.6*\scale, left] at ($(zoomin BL)+(0, .5*\sizePerThreadM*\elem*\zoomR)$) {$t_0$};
|
||||
\node [scale=.6*\scale, right] at ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, .5*\sizePerThreadM*\elem*\zoomR)$) {\sizePerThreadM$\times$\sizePerThreadN};
|
||||
|
||||
\draw [densely dotted] (TL) -- (zoomin BL);
|
||||
\draw [densely dotted] ($(TL)+(\sizePerThreadN*\elem, 0)$) -- ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, 0)$);
|
||||
\draw [fill=red] (TL) rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem);
|
||||
}
|
||||
|
||||
\newcommand{\drawBlockMFMALayoutLarge}[2]{
|
||||
%%
|
||||
%% Draw a single block of MFMA_32x32x8xf16
|
||||
%%
|
||||
%% block TL: pre-defined top-left coordinate of the block
|
||||
%% \elem: pre defined variable
|
||||
%%
|
||||
%% #1: 1 for mfma.trans, 0 for normal mfma
|
||||
%% #2: verbose. 1 means draw tid in each vec; 0 means draw nothing
|
||||
|
||||
\pgfmathsetmacro{\trans}{#1}
|
||||
\pgfmathsetmacro{\nonTrans}{1-#1}
|
||||
\pgfmathsetmacro{\verbose}{#2}
|
||||
\foreach \iVec in {0,1,2,3} {
|
||||
\coordinate (wave TL) at ($(block TL)+(\trans*\iVec*2*4*\elem, -\nonTrans*\iVec*2*4*\elem)$);
|
||||
\foreach \col/\tg in {blue/0,orange/1}{
|
||||
\foreach \tid in {0,...,31} {
|
||||
\pgfmathsetmacro{\ratio}{\tid*2.5+15}
|
||||
\ifthenelse{\verbose=0}{
|
||||
\draw [line width=0.005mm, fill=\col!\ratio!white]
|
||||
($(wave TL)+(\nonTrans*\tid*\elem+\tg*\trans*4*\elem, -\trans*\tid*\elem-\tg*\nonTrans*4*\elem)$)
|
||||
rectangle ++(\nonTrans*\elem+\trans*4*\elem, -\nonTrans*4*\elem-\trans*\elem);
|
||||
}{
|
||||
\pgfmathsetmacro{\drawTid}{int(\tid+\tg*32)}
|
||||
\draw [line width=0.005mm, fill=\col!\ratio!white]
|
||||
($(wave TL)+(\nonTrans*\tid*\elem+\tg*\trans*4*\elem, -\trans*\tid*\elem-\tg*\nonTrans*4*\elem)$)
|
||||
rectangle ++(\nonTrans*\elem+\trans*4*\elem, -\nonTrans*4*\elem-\trans*\elem)
|
||||
node [pos=.5, scale=.35*\scale, rotate=90*\nonTrans] {t\drawTid};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
\draw [thick] (block TL) rectangle ++(32*\elem, -32*\elem);
|
||||
}
|
||||
|
||||
|
||||
\newcommand{\drawTensorMFMALayout}[6]{
|
||||
%%
|
||||
%% Draw a tensor with mfma layout.
|
||||
%%
|
||||
%% C TL: pre defined top-left coordinates of the tensor
|
||||
%%
|
||||
%% #1: M
|
||||
%% #2: N
|
||||
%% #3: MFMA nonKDim
|
||||
%% #4: warpsPerCTA[0]
|
||||
%% #5: warpsPerCTA[1]
|
||||
%% #6: 1 for mfma.trans, 0 for normal mfma
|
||||
|
||||
\pgfmathsetmacro{\tensorShapeH}{#1}
|
||||
\pgfmathsetmacro{\tensorShapeW}{#2}
|
||||
\pgfmathsetmacro{\mfmaNonKDim}{#3}
|
||||
\pgfmathsetmacro{\warpsPerCTAH}{#4}
|
||||
\pgfmathsetmacro{\warpsPerCTAW}{#5}
|
||||
\pgfmathsetmacro{\mfmaTrans}{#6}
|
||||
|
||||
\coordinate (old TL) at (TL);
|
||||
\coordinate (TL) at (C TL);
|
||||
|
||||
|
||||
\pgfmathsetmacro{\CTARepH}{\tensorShapeH/\mfmaNonKDim/\warpsPerCTAH}
|
||||
\pgfmathsetmacro{\CTARepW}{\tensorShapeW/\mfmaNonKDim/\warpsPerCTAW}
|
||||
\pgfmathsetmacro{\maxCTAId}{\CTARepH*\CTARepW-1}
|
||||
\pgfmathsetmacro{\maxWaveId}{\warpsPerCTAH*\warpsPerCTAW-1}
|
||||
\pgfmathsetmacro{\CTASizeH}{\warpsPerCTAH*\mfmaNonKDim}
|
||||
\pgfmathsetmacro{\CTASizeW}{\warpsPerCTAW*\mfmaNonKDim}
|
||||
|
||||
|
||||
\foreach \ctaId in {0,...,\maxCTAId}{
|
||||
\pgfmathsetmacro{\ctaCoordH}{int(\ctaId/\CTARepW)}
|
||||
\pgfmathsetmacro{\ctaCoordW}{mod(\ctaId,\CTARepW)}
|
||||
\coordinate (CTA TL) at ($(TL)+(\ctaCoordW*\CTASizeW*\elem, -\ctaCoordH*\CTASizeH*\elem)$);
|
||||
%% Draw a detailed view of wave0 in each CTA
|
||||
\coordinate (block TL) at (CTA TL);
|
||||
\drawBlockMFMALayoutLarge{\mfmaTrans}{0}
|
||||
|
||||
\foreach \waveId in {0,...,\maxWaveId}{
|
||||
\pgfmathsetmacro{\waveCoordH}{int(\waveId/\warpsPerCTAW)}
|
||||
\pgfmathsetmacro{\waveCoordW}{mod(\waveId,\warpsPerCTAW)}
|
||||
\coordinate (block TL) at ($(CTA TL)+(\waveCoordW*\mfmaNonKDim*\elem, -\waveCoordH*\mfmaNonKDim*\elem)$);
|
||||
%% Inside the loop, only draw a rectangle
|
||||
\draw [ultra thin] (block TL) rectangle ++(\mfmaNonKDim*\elem, -\mfmaNonKDim*\elem)
|
||||
node [scale=.7*\scale, pos=.5, fill=white, inner sep=0] {wave\waveId};
|
||||
}
|
||||
|
||||
%% Draw the outline of each CTA rep
|
||||
\draw [ultra thick] (CTA TL) rectangle ++(\CTASizeW*\elem, -\CTASizeH*\elem);
|
||||
}
|
||||
|
||||
\coordinate (TL) at (old TL);
|
||||
}
|
||||
|
||||
\newcommand{\drawMFMAOperand}[4]{
|
||||
%%
|
||||
%% Draw one mfma operand
|
||||
%%
|
||||
%% mfma op TL: pre defined coordinates of the top-left
|
||||
%% \elem: pre defined variable
|
||||
%%
|
||||
%% #1: mfmaNonKDim
|
||||
%% #2: kpack
|
||||
%% #3: 0 for opA and 1 for opB
|
||||
%% #4: verbose. 1 means draw tid in each vec; 0 means draw nothing
|
||||
|
||||
\pgfmathsetmacro{\nonKDim}{#1}
|
||||
\pgfmathsetmacro{\kpack}{#2}
|
||||
\pgfmathsetmacro{\opIdxA}{#3}
|
||||
\pgfmathsetmacro{\opIdxB}{1-\opIdxA}
|
||||
\pgfmathsetmacro{\verbose}{#4}
|
||||
|
||||
\ifthenelse{\opIdxA = 0}{
|
||||
\def\opColorL{\opColorAL}
|
||||
\def\opColorR{\opColorAR}
|
||||
}{
|
||||
\def\opColorL{\opColorBL}
|
||||
\def\opColorR{\opColorBR}
|
||||
}
|
||||
|
||||
\foreach \col/\tg in {\opColorL/0,\opColorR/1}{
|
||||
\foreach \tid in {0,...,31} {
|
||||
% \pgfmathsetmacro{\ratio}{\tid*2.5+15}
|
||||
\ifthenelse{\verbose=0}{
|
||||
\draw [line width=0.005mm, fill=\col]
|
||||
($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$)
|
||||
rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA);
|
||||
}{
|
||||
\pgfmathsetmacro{\drawTid}{int(\tid+\tg*32)}
|
||||
\draw [line width=0.005mm, fill=\col]
|
||||
($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$)
|
||||
rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA)
|
||||
node [pos=.5, scale=.35*\scale, rotate=90*\opIdxA] {t\drawTid};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
\newcommand{\drawWaveOperand}[4]{
|
||||
%%
|
||||
%% Draw the part of the tensor that is one operand of the wave
|
||||
%%
|
||||
%% Op TL: pre defined coordinates of the top-left of the operand
|
||||
%% \elem: pre defined variable
|
||||
%%
|
||||
%% #1: K
|
||||
%% #2: mfmaNonKDim
|
||||
%% #3: kpack
|
||||
%% #4: 0 for opA and 1 for opB
|
||||
|
||||
\pgfmathsetmacro{\K}{#1}
|
||||
\pgfmathsetmacro{\nonKDim}{#2}
|
||||
\pgfmathsetmacro{\kpack}{#3}
|
||||
\pgfmathsetmacro{\opIdx}{#4}
|
||||
\pgfmathsetmacro{\opIdxOther}{1-\opIdx}
|
||||
|
||||
\coordinate (TL) at (Op TL);
|
||||
|
||||
\pgfmathsetmacro{\numKRep}{\K/\kpack/2}
|
||||
\pgfmathsetmacro{\maxKRepId}{\numKRep-1}
|
||||
|
||||
\foreach \repId in {0,...,\maxKRepId}{
|
||||
\coordinate (mfma op TL) at ($(TL)+(\repId*2*\kpack*\elem*\opIdxOther, -\repId*2*\kpack*\elem*\opIdx)$);
|
||||
\drawMFMAOperand{\nonKDim}{\kpack}{\opIdx}{0}
|
||||
\draw [thick] (mfma op TL) rectangle
|
||||
++(2*\kpack*\elem*\opIdxOther+\nonKDim*\opIdx*\elem, -\nonKDim*\opIdxOther*\elem-2*\kpack*\elem*\opIdx);
|
||||
}
|
||||
}
|
||||
|
||||
\newcommand{\drawDotOperands}[7]{
  %%
  %% Draw the two operand tensors of dot: A (M x K) and B (K x N)
  %%
  %% A TL and B TL: pre defined top-left coordinates of A and B tensor
  %% \elem: pre defined variable
  %%
  %% #1: M
  %% #2: N
  %% #3: K
  %% #4: MFMA nonKDim
  %% #5: warpsPerCTA[0]
  %% #6: warpsPerCTA[1]
  %% #7: kpack
  %%
  %% The wave tiles are \mfmaNonKDim rows (A) / columns (B) wide, i.e. the
  %% same step that \drawTensorMFMALayout uses for the result tensor, so
  %% the operand tiling lines up with the tiling of C.
  %% NOTE: the detailed per-thread view drawn through \drawWaveOperand
  %% still iterates over 32 lanes per thread group inside
  %% \drawMFMAOperand, so that part is only exact for nonKDim = 32.

  \pgfmathsetmacro{\M}{#1}
  \pgfmathsetmacro{\N}{#2}
  \pgfmathsetmacro{\K}{#3}
  \pgfmathsetmacro{\mfmaNonKDim}{#4}
  \pgfmathsetmacro{\warpsPerCTAM}{#5}
  \pgfmathsetmacro{\warpsPerCTAN}{#6}
  \pgfmathsetmacro{\kpack}{#7}

  %% operand A: CTA reps and waves are stacked along M
  \pgfmathsetmacro{\CTARepM}{\M/\warpsPerCTAM/\mfmaNonKDim}
  \pgfmathsetmacro{\maxCTAIdM}{\CTARepM-1}
  \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM-1}
  \foreach \ctaId in {0,...,\maxCTAIdM}{
    \coordinate (CTA TL) at ($(A TL)+(0, -\ctaId*\warpsPerCTAM*\mfmaNonKDim*\elem)$);
    %% Outline of the rows owned by each wave
    \foreach \waveId in {0,...,\maxWaveId}{
      \coordinate (wave TL) at ($(CTA TL)+(0, -\waveId*\mfmaNonKDim*\elem)$);
      \draw [ultra thin] (wave TL) rectangle ++(\K*\elem, -\mfmaNonKDim*\elem);
    }
    %% Only draw the detailed view of the first wave in CTA
    \coordinate (Op TL) at (CTA TL);
    \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{0}

    %% Draw the outline of each CTA rep
    \draw [ultra thick] (CTA TL) rectangle ++(\K*\elem, -\warpsPerCTAM*\mfmaNonKDim*\elem);
  }
  \draw [ultra thin] (A TL) rectangle ++(\K*\elem, -\M*\elem);


  %% operand B: CTA reps and waves are laid out along N
  \pgfmathsetmacro{\CTARepN}{\N/\warpsPerCTAN/\mfmaNonKDim}
  \pgfmathsetmacro{\maxCTAIdN}{\CTARepN-1}
  \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAN-1}
  \foreach \ctaId in {0,...,\maxCTAIdN}{
    \coordinate (CTA TL) at ($(B TL)+(\ctaId*\warpsPerCTAN*\mfmaNonKDim*\elem, 0)$);
    %% Outline of the columns owned by each wave
    \foreach \waveId in {0,...,\maxWaveId}{
      \coordinate (wave TL) at ($(CTA TL)+(\waveId*\mfmaNonKDim*\elem ,0)$);
      \draw [ultra thin] (wave TL) rectangle ++(\mfmaNonKDim*\elem, -\K*\elem);
    }
    %% Only draw the detailed view of the first wave in CTA
    \coordinate (Op TL) at (CTA TL);
    \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{1}

    %% Draw the outline of each CTA rep
    \draw [ultra thick] (CTA TL) rectangle ++(\warpsPerCTAN*\mfmaNonKDim*\elem, -\K*\elem);
  }
  \draw [ultra thin] (B TL) rectangle ++(\N*\elem, -\K*\elem);
}
|
||||
|
||||
|
||||
\newcommand{\drawDot}[8]{
|
||||
%%
|
||||
%% Draw C = dot A, B
|
||||
%%
|
||||
%% C TL: pre defined top-left coordinates of the result tensor
|
||||
%% \elem: pre defined variable
|
||||
%%
|
||||
%% #1: M
|
||||
%% #2: N
|
||||
%% #3: K
|
||||
%% #4: MFMA nonKDim
|
||||
%% #5: warpsPerCTA[0]
|
||||
%% #6: warpsPerCTA[1]
|
||||
%% #7: 1 for mfma.trans, 0 for normal mfma
|
||||
%% #8: kpack
|
||||
|
||||
\pgfmathsetmacro{\M}{#1}
|
||||
\pgfmathsetmacro{\N}{#2}
|
||||
\pgfmathsetmacro{\K}{#3}
|
||||
\pgfmathsetmacro{\mfmaNonKDim}{#4}
|
||||
\pgfmathsetmacro{\warpsPerCTAM}{#5}
|
||||
\pgfmathsetmacro{\warpsPerCTAN}{#6}
|
||||
\pgfmathsetmacro{\mfmaTrans}{#7}
|
||||
\pgfmathsetmacro{\kpack}{#8}
|
||||
\pgfmathsetmacro{\kdim}{int(2*\kpack)}
|
||||
|
||||
\pgfmathsetmacro{\gap}{\elem*20}
|
||||
\coordinate (A TL) at ($(C TL)+(-\gap-\K*\elem, 0)$);
|
||||
\coordinate (B TL) at ($(C TL)+(0, \gap+\K*\elem)$);
|
||||
|
||||
\drawDotOperands{\M}{\N}{\K}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\kpack}
|
||||
|
||||
\drawTensorMFMALayout{\M}{\N}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\mfmaTrans}
|
||||
|
||||
%% Draw labels
|
||||
\node [scale=\scale, above] at ($(A TL)+(.5*\K*\elem, 0)$) {K=\K};
|
||||
\node [scale=\scale, above, rotate=90] at ($(A TL)+(0, -.5*\M*\elem)$) {M=\M};
|
||||
|
||||
\node [scale=\scale, above, rotate=90] at ($(B TL)+(0, -.5*\K*\elem)$) {K=\K};
|
||||
\node [scale=\scale, above] at ($(B TL)+(.5*\N*\elem, 0)$) {N=\N};
|
||||
|
||||
\node [scale=\scale, above left] at (A TL) {A};
|
||||
\node [scale=\scale, above left] at (B TL) {B};
|
||||
\node [scale=\scale, above left] at (C TL) {C};
|
||||
|
||||
%% label nonKDim
|
||||
\node [scale=.8*\scale, left] at ($(A TL)+(0, -.5*\mfmaNonKDim*\elem)$) {\mfmaNonKDim};
|
||||
\node [scale=.8*\scale, above] at ($(B TL)+(.5*\mfmaNonKDim*\elem, 0)$) {\mfmaNonKDim};
|
||||
%% label the K extent of one mfma operand (\kdim = 2*kpack)
|
||||
\node [scale=.8*\scale, above] at ($(A TL)+(\kpack*\elem, 0)$) {\kdim};
|
||||
\node [scale=.8*\scale, left] at ($(B TL)+(0, -\kpack*\elem)$) {\kdim};
|
||||
}
|
||||
|
||||
\newcommand{\Colors}{{
|
||||
"red",
|
||||
"YellowGreen",
|
||||
"blue",
|
||||
"Maroon",
|
||||
"orange",
|
||||
"cyan",
|
||||
"magenta",
|
||||
"brown",
|
||||
"teal",
|
||||
"purple",
|
||||
"gray",
|
||||
"Green",
|
||||
"BlueGreen",
|
||||
"violet",
|
||||
"olive",
|
||||
"darkgray",
|
||||
}}
|
||||
|
||||
\newcommand{\drawTensorLayoutGlobalMem}{
|
||||
%%
|
||||
%% Draw tensor layout in global memory without any swizzling
|
||||
%%
|
||||
%% TL: pre defined top-left coordinates of the tensor in global memory
|
||||
%% \elem: pre defined variable
|
||||
%% \Colors: a pre defined array of 16 colors
|
||||
%%
|
||||
%% The following variables are also expected to be pre defined
|
||||
%% \M: M
|
||||
%% \K: K
|
||||
%% \vec: number of elements in a group
|
||||
|
||||
\pgfmathsetmacro{\numVecK}{\K/\vec}
|
||||
\pgfmathsetmacro{\maxVecId}{16*\numVecK-1}
|
||||
\pgfmathsetmacro{\drawM}{20}
|
||||
|
||||
%% Draw the tensor outline (\drawM rows tall); the detailed vec view below covers 16 rows
|
||||
\draw (TL) rectangle ++(\K*\elem, -\drawM*\elem);
|
||||
%% Draw detailed vec view of the tensor
|
||||
\foreach \vecId in {0,...,\maxVecId}{
|
||||
|
||||
\pgfmathsetmacro{\vecCoordM}{int(\vecId/\numVecK)}
|
||||
\pgfmathsetmacro{\vecCoordK}{mod(\vecId,\numVecK)}
|
||||
\coordinate (vec TL) at ($(TL)+(\vecCoordK*\vec*\elem, -\vecCoordM*\elem)$);
|
||||
|
||||
\pgfmathsetmacro{\colorIdxK}{int(mod(\vecCoordK,16))}
|
||||
\pgfmathsetmacro{\colorIdxM}{mod(\vecCoordM,16)}
|
||||
\pgfmathsetmacro{\vecColor}{\Colors[\colorIdxK]}
|
||||
\pgfmathsetmacro{\ratio}{100-floor(\vecCoordK/16)*40}
|
||||
|
||||
\draw [ultra thin, fill=\vecColor!\ratio!white] (vec TL) rectangle ++(\vec*\elem, -\elem)
|
||||
node [pos=.5, scale=.6*\scale, white] {m\vecCoordM};
|
||||
|
||||
}
|
||||
%% M and K dim
|
||||
\node [scale=\scale, rotate=90, above] at ($(TL)+(0, -.5*\drawM*\elem-8*\elem)$) {M=\M};
|
||||
\node [scale=.8*\scale, left] at ($(TL)+(0, -.5*16*\elem)$) {16};
|
||||
\node [scale=\scale, above] at ($(TL)+(.5*\K*\elem, 0)$) {K=\K};
|
||||
%% label for vecSize
|
||||
\def\vecR{1.5}
|
||||
\coordinate (vec TL) at ($(TL)+(-.25*\vec*\elem, 3*\elem*\vecR)$);
|
||||
\pgfmathsetmacro{\maxVec}{\vec-1}
|
||||
\foreach \vecId in {0,...,\maxVec}{
|
||||
\draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR);
|
||||
}
|
||||
\draw [densely dotted] (TL) -- ($(vec TL)+(0, -\elem*\vecR)$);
|
||||
\draw [densely dotted] ($(TL)+(\vec*\elem, 0)$) -- ($(vec TL)+(\vec*\elem*\vecR, -\elem*\vecR)$);
|
||||
\node [scale=.8*\scale, above] at ($(vec TL)+(.5*\vec*\elem*\vecR, 0)$) {vec=\vec};
|
||||
}
|
||||
|
||||
|
||||
|
||||
\newcommand{\drawLDSLayoutTritonSwizzling}[2]{
|
||||
%%
|
||||
%% Draw tensor layout in LDS with swizzling
|
||||
%%
|
||||
%% TL: pre defined top-left coordinates of the LDS drawing
|
||||
%% \elem: pre defined variable
|
||||
%% \Colors: a pre defined array of 16 colors
|
||||
%%
|
||||
%% The following three variables are expected to be pre defined
|
||||
%% \M: M
|
||||
%% \K: K
|
||||
%% \vec: number of elements in a group
|
||||
%%
|
||||
%% #1: hasSwizzle, 0 means no swizzling and no padding,
|
||||
%% 1 means optimal swizzling
|
||||
%% 2 means padding
|
||||
%% #2: access mode, 0 means draw nothing, 1 means ds_read, 2 means ds_write
|
||||
%% For ds_write access, the following variables are assumed to be pre defined
|
||||
%% \sizePerThreadK
|
||||
%% \sizePerThreadM
|
||||
%% \threadsPerWarpK
|
||||
|
||||
\pgfmathsetmacro{\hasSwizzle}{#1}
|
||||
\pgfmathsetmacro{\accessMode}{#2}
|
||||
\pgfmathsetmacro{\numVecK}{\K/\vec}
|
||||
|
||||
%% Assuming fp16 data type
|
||||
\pgfmathsetmacro{\LDSK}{64}
|
||||
\pgfmathsetmacro{\numLDSVec}{\LDSK/\vec}
|
||||
\pgfmathsetmacro{\swizzleK}{max(\LDSK, \K)}
|
||||
\pgfmathsetmacro{\LDSM}{int(\M/\LDSK*\K)}
|
||||
|
||||
\ifthenelse{\accessMode = 2}{
|
||||
%% \accessMode == 2, draw 8 rows
|
||||
\pgfmathsetmacro{\maxVecId}{8*\numVecK-1}
|
||||
\pgfmathsetmacro{\drawM}{8*\K/\LDSK+4}
|
||||
}{
|
||||
%% \accessMode == 0 or 1, draw 16 rows
|
||||
\pgfmathsetmacro{\maxVecId}{16*\numVecK-1}
|
||||
\pgfmathsetmacro{\drawM}{16*\K/\LDSK+4}
|
||||
}
|
||||
|
||||
%% Parameters used for swizzling
|
||||
\pgfmathsetmacro{\numVecSwizzleK}{\swizzleK/\vec}
|
||||
%% perPhase = ceil(LDSK / K)
|
||||
%% The number of the rows of the tensor that can share the same swizzling pattern
|
||||
\pgfmathsetmacro{\perPhase}{ceil(\LDSK/\K)}
|
||||
%% maxPhase: the total number of different swizzling patterns
|
||||
\ifthenelse{\hasSwizzle=0}{
|
||||
%% When swizzling is disabled
|
||||
\pgfmathsetmacro{\maxPhase}{1}
|
||||
}{
|
||||
%% When vec is small enough, we want 16/perPhase different swizzling patterns
|
||||
%% When vec is large, we can only have 64 / \vec different swizzling patterns at most
|
||||
\pgfmathsetmacro{\maxPhase}{min(16/\perPhase,64/\vec)}
|
||||
}
|
||||
|
||||
%% Draw the LDS
|
||||
\draw (TL) rectangle ++(\LDSK*\elem, -\drawM*\elem);
|
||||
|
||||
%% Draw detailed vec view of LDS
|
||||
\foreach \vecId in {0,...,\maxVecId}{
|
||||
\pgfmathsetmacro{\vecCoordM}{int(\vecId/\numVecK)}
|
||||
\pgfmathsetmacro{\vecCoordK}{int(mod(\vecId,\numVecK))}
|
||||
\pgfmathsetmacro{\rawPhase}{floor(\vecId/\numVecSwizzleK)}
|
||||
%% vec color
|
||||
\pgfmathsetmacro{\colorIdxK}{int(mod(\vecCoordK,16))}
|
||||
\pgfmathsetmacro{\colorIdxM}{mod(\vecCoordM,16)}
|
||||
\pgfmathsetmacro{\ratio}{100-floor(\vecCoordK/16)*40}
|
||||
\pgfmathsetmacro{\vecColor}{\Colors[\colorIdxK]}
|
||||
|
||||
%% old vec coordinates
|
||||
\coordinate (vec TL) at ($(TL)+(\vecCoordK*\vec*\elem, -\vecCoordM*\elem)$);
|
||||
|
||||
%% new vec coordinates in LDS by swizzling
|
||||
%% The following two conditions correspond to the relation between \LDSK and \K
|
||||
\ifthenelse{\LDSK < \K}{
|
||||
\pgfmathsetmacro{\vecLDSM}{\vecCoordM*\K/\LDSK+floor(\vecCoordK*\vec/\LDSK)}
|
||||
\pgfmathsetmacro{\vecLDSK}{int(mod(\vecCoordK, \LDSK/\vec))}
|
||||
}{
|
||||
\pgfmathsetmacro{\vecLDSM}{floor(\vecCoordM/\perPhase)}
|
||||
\pgfmathsetmacro{\vecLDSK}{int(\vecCoordK+mod(\vecCoordM,\perPhase)*\numVecK)}
|
||||
}
|
||||
%%
|
||||
\pgfmathsetmacro{\phase}{int(mod(\rawPhase, \maxPhase))}
|
||||
%% Compute the swizzled col id
|
||||
\pgfmathsetmacro{\vecLDSKSwizzled}{\bitwiseXor{\vecLDSK}{\phase}}
|
||||
|
||||
%% new vec coordinates in LDS by padding
|
||||
\pgfmathsetmacro{\numPads}{floor(\vecId/\numLDSVec)}
|
||||
\pgfmathsetmacro{\bankId}{\vec/2*\vecId+\numPads}
|
||||
\pgfmathsetmacro{\vecPadM}{int(\bankId/32)}
|
||||
\pgfmathsetmacro{\vecPadK}{int(mod(\bankId,32))}
|
||||
|
||||
\ifthenelse{\hasSwizzle = 2}{
|
||||
%% vec coordinates by padding
|
||||
\coordinate (new vec TL) at ($(TL)+(\vecPadK*2*\elem, -\vecPadM*\elem)$);
|
||||
\pgfmathsetmacro{\tailBankId}{int(\vecPadK+\vec/2-1)}
|
||||
}{
|
||||
%% vec coordinates by swizzling
|
||||
\coordinate (new vec TL) at ($(TL)+(\vecLDSKSwizzled*\vec*\elem, -\vecLDSM*\elem)$);
|
||||
\pgfmathsetmacro{\tailBankId}{0}
|
||||
}
|
||||
|
||||
\ifthenelse{\hasSwizzle = 2 \AND \tailBankId > 31}{
|
||||
\pgfmathsetmacro{\nextBanks}{\tailBankId-31}
|
||||
\pgfmathsetmacro{\leftBanks}{\vec/2 - \nextBanks}
|
||||
\draw [ultra thin, fill=\vecColor!\ratio!white] (new vec TL) rectangle ++(\leftBanks*2*\elem, -\elem)
|
||||
node [pos=.5, scale=.6*\scale, white] {m\vecCoordM};
|
||||
\draw [ultra thin, fill=\vecColor!\ratio!white] ($(TL)+(0, -\vecPadM*\elem-\elem)$)
|
||||
rectangle ++(\nextBanks*2*\elem, -\elem) node [pos=.5, scale=.6*\scale, white] {m\vecCoordM};
|
||||
}{
|
||||
\draw [ultra thin, fill=\vecColor!\ratio!white] (new vec TL) rectangle ++(\vec*\elem, -\elem)
|
||||
node [pos=.5, scale=.6*\scale, white] {m\vecCoordM};
|
||||
}
|
||||
|
||||
%% ds_read
|
||||
%% Highlight the elements the first 16 threads access in the first cycle
|
||||
%% This is used to visualize bank conflicts
|
||||
\ifthenelse{\accessMode = 1}{
|
||||
\ifthenelse{\vecCoordK = 0}{
|
||||
\draw [fill=white] (new vec TL) rectangle ++(\elem, -\elem);
|
||||
\draw (new vec TL) -- ++(\elem, -\elem);
|
||||
\draw ($(new vec TL)+(0, -\elem)$) -- ++(\elem, \elem);
|
||||
}{}
|
||||
}{}
|
||||
|
||||
%% Draw ds_write pattern
|
||||
\ifthenelse{\accessMode = 2}{
|
||||
%% First compute the coverage of the first 16 threads
|
||||
\pgfmathsetmacro{\covK}{min(16, \threadsPerWarpK)*\sizePerThreadK/\vec}
|
||||
\pgfmathsetmacro{\covM}{ceil(16/\threadsPerWarpK)*\sizePerThreadM}
|
||||
%% Check conditions for the first 16 threads
|
||||
\pgfmathsetmacro{\vecInThread}{int(mod(\vecCoordK, \sizePerThreadK/\vec))}
|
||||
\ifthenelse{\vecInThread=0}{
|
||||
\ifthenelse{\vecCoordK<\covK \AND \vecCoordM<\covM}{
|
||||
\draw [fill=white] (new vec TL) rectangle ++(\elem, -\elem);
|
||||
\draw (new vec TL) -- ++(\elem, -\elem);
|
||||
\draw ($(new vec TL)+(0, -\elem)$) -- ++(\elem, \elem);
|
||||
}{}
|
||||
}{}
|
||||
}{}
|
||||
|
||||
%% Label the phase of each line if swizzling is used
|
||||
\ifthenelse{\hasSwizzle = 2}{}{
|
||||
\pgfmathsetmacro{\lastVecId}{int(64/\vec)-1}
|
||||
\ifthenelse{\vecLDSKSwizzled = \lastVecId}{
|
||||
\draw [ultra thin] ($(new vec TL)+(\vec*\elem, -.5*\elem)$) -- ++(\elem, 0)
|
||||
node [scale=.6*\scale, right] {\phase};
|
||||
}{}
|
||||
}
|
||||
}
|
||||
|
||||
%% Draw boundary of 32 banks
|
||||
%% Assume fp16 data type
|
||||
\foreach \bank in {0,...,31}{
|
||||
\draw [ultra thin, gray] ($(TL)+(\bank*2*\elem, 0)$) -- ++(0, 2*\elem)
|
||||
node [scale=.6*\scale, right, black] {\bank};
|
||||
}
|
||||
\draw [ultra thin, gray] ($(TL)+(32*2*\elem, 0)$) -- ++(0, 2*\elem);
|
||||
\node [scale=.6*\scale, left, black] at ($(TL)+(0, 2*\elem)$) {bank id};
|
||||
|
||||
\node [scale=\scale, above] at ($(TL)+(.5*\LDSK*\elem, 3*\elem)$) {LDS 32 banks};
|
||||
\node [scale=\scale, rotate=90, above] at ($(TL)+(0, -.5*\drawM*\elem)$) {LDSM=\LDSM};
|
||||
|
||||
%% label phase if swizzling is used
|
||||
\ifthenelse{\hasSwizzle = 2}{}{
|
||||
\node [scale=.6*\scale, above right] at($(TL)+(32*2*\elem, 0)$) {phase};
|
||||
}
|
||||
}
|
||||
|
||||
\newcommand{\drawMFMAInstr}[3]{
|
||||
%%
|
||||
%% Draw layout of mfma instructions with tid labeled
|
||||
%%
|
||||
%% C TL: pre defined top-left coordinates of the output matrix
|
||||
%% \elem: pre defined variable
|
||||
%%
|
||||
%% #1: mfmaNonKDim
|
||||
%% #2: kpack
|
||||
%% #3: mfmaTrans
|
||||
\pgfmathsetmacro{\mfmaNonKDim}{#1}
|
||||
\pgfmathsetmacro{\kpack}{#2}
|
||||
\pgfmathsetmacro{\mfmaTrans}{#3}
|
||||
\pgfmathsetmacro{\nonTrans}{1-#3}
|
||||
|
||||
\pgfmathsetmacro{\gap}{\elem*5}
|
||||
\coordinate (mfma opA TL) at ($(C TL)+(-.5*\gap-1.2*\nonTrans*\gap-2*\kpack*\elem, 0)$);
|
||||
\coordinate (mfma op TL) at (mfma opA TL);
|
||||
\drawMFMAOperand{\mfmaNonKDim}{\kpack}{0}{1}
|
||||
\coordinate (mfma op TL) at ($(C TL)+(0, 1.5*\gap+.5*\mfmaTrans*\gap+2*\kpack*\elem)$);
|
||||
\drawMFMAOperand{\mfmaNonKDim}{\kpack}{1}{1}
|
||||
|
||||
\coordinate (block TL) at (C TL);
|
||||
\drawBlockMFMALayoutLarge{\mfmaTrans}{1}
|
||||
|
||||
%% Draw labels
|
||||
\def\vecR{1.5}
|
||||
\coordinate (vec TL) at ($(mfma opA TL)+(-.25*\kpack*\elem, 3*\elem*\vecR)$);
|
||||
\pgfmathsetmacro{\maxVec}{\kpack-1}
|
||||
\foreach \vecId in {0,...,\maxVec}{
|
||||
\draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR);
|
||||
}
|
||||
\draw [densely dotted] (mfma opA TL) -- ($(vec TL)+(0, -\elem*\vecR)$);
|
||||
\draw [densely dotted] ($(mfma opA TL)+(\kpack*\elem, 0)$) -- ($(vec TL)+(\kpack*\elem*\vecR, -\elem*\vecR)$);
|
||||
\node [scale=.8*\scale, above] at ($(vec TL)+(.5*\kpack*\elem*\vecR, 0)$) {vec=\kpack};
|
||||
|
||||
\coordinate (vec TL) at ($(mfma op TL)+(-3*\elem*\vecR, .25*\kpack*\elem)$);
|
||||
\foreach \vecId in {0,...,\maxVec}{
|
||||
\draw ($(vec TL)+(0, -\vecId*\elem*\vecR)$) rectangle ++(\elem*\vecR, -\elem*\vecR);
|
||||
}
|
||||
\draw [densely dotted] (mfma op TL) -- ($(vec TL)+(\elem*\vecR,0)$);
|
||||
\draw [densely dotted] ($(mfma op TL)+(0, -\kpack*\elem)$) -- ($(vec TL)+(\elem*\vecR, -\kpack*\elem*\vecR)$);
|
||||
\node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*\kpack*\elem*\vecR)$) {vec=\kpack};
|
||||
|
||||
\node [scale=\scale, below] at ($(block TL)+(.5*\mfmaNonKDim*\elem,-\mfmaNonKDim*\elem)$) {outC};
|
||||
\ifthenelse{\mfmaTrans=0}{
|
||||
\node [scale=\scale, below] at ($(mfma opA TL)+(\kpack*\elem, -\mfmaNonKDim*\elem)$) {opA};
|
||||
\node [scale=\scale, above] at (mfma op TL) {opB};
|
||||
\coordinate (vec TL) at ($(block TL)+(-3*\elem-\elem*\vecR, .25*4*\elem)$);
|
||||
\foreach \vecId in {0,1,2,3}{
|
||||
\draw ($(vec TL)+(0, -\vecId*\elem*\vecR)$) rectangle ++(\elem*\vecR, -\elem*\vecR);
|
||||
}
|
||||
\draw [densely dotted] (block TL) -- ++(-3*\elem, .25*4*\elem);
|
||||
\draw [densely dotted] ($(block TL)+(0, -4*\elem)$) -- ++(-3*\elem, -.25*4*\elem);
|
||||
\node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*4*\elem*\vecR)$) {vec=4};
|
||||
\node [scale=.8*\scale, above, align=center] at ($(block TL)+(16*\elem, 0)$) {mfmaLayout\\trans=False};
|
||||
}{
|
||||
\node [scale=\scale, below] at ($(mfma opA TL)+(\kpack*\elem, -\mfmaNonKDim*\elem)$) {opB};
|
||||
\node [scale=\scale, above] at (mfma op TL) {opA};
|
||||
\coordinate (vec TL) at ($(block TL)+(-.25*4*\elem, 3*\elem+\elem*\vecR)$);
|
||||
\foreach \vecId in {0,1,2,3}{
|
||||
\draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR);
|
||||
}
|
||||
\draw [densely dotted] (block TL) -- ++(-.25*4*\elem, 3*\elem);
|
||||
\draw [densely dotted] ($(block TL)+(4*\elem, 0)$) -- ++(.25*4*\elem, 3*\elem);
|
||||
\node [scale=.8*\scale, above] at ($(vec TL)+(.5*4*\elem*\vecR, 0)$) {vec=4};
|
||||
\node [scale=.8*\scale, above, align=center] at ($(block TL)+(16*\elem, 0)$) {mfmaLayout\\trans=True};
|
||||
}
|
||||
}
|
||||
|
||||
\newcommand{\drawWMMAOperand}[3]{
  %%
  %% Draw the layout of one operand of WMMA instruction as 16 shaded
  %% strips, each one optionally labelled with the thread ids attached to it.
  %%
  %% #1: opIdx. 0 for opA, 1 for opB
  %% #2: verbose. 1 means draw tid in each vec; 0 means leave the vec unlabelled
  %% #3: mode. 0 for w32, 1 for w64
  %%
  %% wmma op TL: pre defined top-left coordinates of the operand matrix
  %% \elem: pre defined element size
  %% \scale: pre defined text scale

  \pgfmathsetmacro{\isOpB}{#1}
  \pgfmathsetmacro{\isOpA}{1-\isOpB}
  \pgfmathsetmacro{\verbose}{#2}
  \pgfmathsetmacro{\isWLarge}{#3}

  \foreach \row in {0,...,15}{
    %% shade grows with the strip index: 15 for strip 0 up to 90 for strip 15
    \pgfmathsetmacro{\ratio}{\row*5+15}
    %% opA strips are stacked downwards, opB strips are placed side by side
    \coordinate (vec TL) at ($(wmma op TL)+(\row*\isOpB*\elem, -\row*\elem*\isOpA)$);

    %% Build the tid label: two tids per strip in w32 mode, four in w64 mode
    \pgfmathsetmacro{\tidone}{int(\row+16)}
    \ifthenelse{\isWLarge=1}{
      \pgfmathsetmacro{\tidtwo}{int(\row+32)}
      \pgfmathsetmacro{\tidthree}{int(\row+48)}
      \def\tidLabel{t\row, t\tidone, t\tidtwo, t\tidthree}
    }{
      \def\tidLabel{t\row, t\tidone}
    }
    %% verbose=0: keep the coloured strip but suppress the tid text
    \ifthenelse{\verbose=0}{
      \def\tidLabel{}
    }{}

    %% A strip is 16x1 elements for opA and 1x16 for opB; the label is
    %% rotated by 90 degrees for opB so that it runs along the strip
    \draw [line width=0.005mm, fill=brown!\ratio!white] (vec TL)
      rectangle ++(16*\elem*\isOpA+\elem*\isOpB, -\elem*\isOpA-16*\elem*\isOpB)
      node [scale=0.4*\scale, pos=.5, rotate=90*\isOpB] {\tidLabel};
  }
}
|
||||
|
||||
\newcommand{\drawWMMAResult}[2]{
  %%
  %% Draw layout of WMMA result tensor as a 16x16 grid of single-element
  %% cells, each one coloured by (and optionally labelled with) a thread id.
  %%
  %% #1: verbose. 1 means draw tid in each vec; 0 means leave the vec unlabelled
  %% #2: mode. 0 for w32, 1 for w64
  %%
  %% C TL: pre defined top-left coordinates of output matrix
  %% \elem: pre defined element size
  %% \scale: pre defined text scale
  %% \Colors: pre defined array of colors, indexed by floor(tid/16)

  \pgfmathsetmacro{\verbose}{#1}
  \pgfmathsetmacro{\isWLarge}{#2}

  \pgfmathsetmacro{\numElem}{256}
  \pgfmathsetmacro{\maxElemId}{\numElem-1}

  \foreach \elemId in {0,...,\maxElemId}{
    %% elements are enumerated row by row, 16 per row
    \pgfmathsetmacro{\rowId}{floor(\elemId/16)}
    \pgfmathsetmacro{\colId}{mod(\elemId,16)}

    %% tid wraps around every 64 elements in w64 mode, every 32 in w32 mode
    \ifthenelse{\isWLarge=1}{
      \pgfmathsetmacro{\tid}{int(mod(\elemId,64))}
    }{
      \pgfmathsetmacro{\tid}{int(mod(\elemId,32))}
    }

    %% one color per group of 16 consecutive tids
    \pgfmathsetmacro{\colorId}{floor(\tid/16)}
    \pgfmathsetmacro{\vecColor}{\Colors[\colorId]}

    %% verbose=0: keep the coloured cell but suppress the tid text
    \ifthenelse{\verbose=0}{
      \def\tidLabel{}
    }{
      \def\tidLabel{t\tid}
    }

    %% Coordinate
    \coordinate (vec TL) at ($(C TL)+(\colId*\elem, -\rowId*\elem)$);
    \draw [line width=0.005mm, fill=\vecColor!60!white] (vec TL) rectangle ++(\elem, -\elem)
      node [scale=.4*\scale, pos=.5] {\tidLabel};
  }
}
|
||||
|
||||
\newcommand{\drawWMMAInstr}[2]{
  %%
  %% Draw wmma instruction layouts 16x16x16: operand A to the left of C,
  %% operand B above C, the result C, and a "16" marker along each extent.
  %%
  %% #1: mode. 0 for w32, 1 for w64
  %% #2: verbose. 1 means draw tid in each vec; 0 means draw nothing
  %%
  %% C TL: pre defined top-left coordinates of output matrix
  %% \elem: pre defined element size
  %% \scale: pre defined text scale

  \pgfmathsetmacro{\isWLarge}{#1}
  \pgfmathsetmacro{\verbose}{#2}

  %% operands are separated from C by two elements
  \pgfmathsetmacro{\gap}{\elem*2}

  %% opA: left of C, top edges aligned
  \coordinate (wmma op TL) at ($(C TL)+(-\gap-16*\elem, 0)$);
  \coordinate (wmma opA TL) at (wmma op TL);
  \drawWMMAOperand{0}{\verbose}{\isWLarge}

  %% opB: above C, left edges aligned. From here on (wmma op TL) is opB's
  %% corner, while (wmma opA TL) still remembers opA's corner
  \coordinate (wmma op TL) at ($(C TL)+(0, \gap+16*\elem)$);
  \drawWMMAOperand{1}{\verbose}{\isWLarge}

  %% result: forward the caller's verbose flag, as done for both operands
  \drawWMMAResult{\verbose}{\isWLarge}

  %% labels (the extent markers below are offset by a single element)
  \pgfmathsetmacro{\gap}{\elem}
  \node [above left, scale=\scale] at (wmma opA TL) {A};
  \node [above left, scale=\scale] at (wmma op TL) {B};
  \node [above right, scale=\scale] at ($(C TL)+(16*\elem, 0)$) {C};

  %% A k dim: marker above opA, spanning its width
  \node [scale=.8*\scale] (k dim A) at ($(wmma opA TL)+(8*\elem,\gap)$) {16};
  \draw [->, >=stealth] (k dim A.west) -- ($(wmma opA TL)+(0, \gap)$);
  \draw [->, >=stealth] (k dim A.east) -- ($(wmma opA TL)+(16*\elem, \gap)$);

  %% B k dim: marker left of opB, spanning its height
  \node [scale=.8*\scale, rotate=90] (k dim B) at ($(wmma op TL)+(-\gap, -8*\elem)$) {16};
  \draw [->, >=stealth] (k dim B.east) -- ($(wmma op TL)+(-\gap, 0)$);
  \draw [->, >=stealth] (k dim B.west) -- ($(wmma op TL)+(-\gap, -16*\elem)$);

  %% C horizontal extent: marker below C, spanning its width.
  %% NOTE(review): the node is named (m dim) although it runs along the
  %% columns of C; the name is kept as is, and the printed value is 16 for
  %% both extents, so the picture is unaffected.
  \node [scale=.8*\scale] (m dim) at ($(C TL)+(8*\elem,-16*\elem-\gap)$) {16};
  \draw [->, >=stealth] (m dim.west) -- ($(C TL)+(0, -16*\elem-\gap)$);
  \draw [->, >=stealth] (m dim.east) -- ($(C TL)+(16*\elem, -16*\elem-\gap)$);

  %% C vertical extent: marker right of C, spanning its height.
  %% NOTE(review): the node is named (n dim) although it runs along the
  %% rows of C; see the note on (m dim).
  \node [scale=.8*\scale, rotate=-90] (n dim) at ($(C TL)+(16*\elem+\gap, -8*\elem)$) {16};
  \draw [->, >=stealth] (n dim.west) -- ($(C TL)+(16*\elem+\gap, 0)$);
  \draw [->, >=stealth] (n dim.east) -- ($(C TL)+(16*\elem+\gap, -16*\elem)$);
}
|
||||
@@ -1978,11 +1978,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
|
||||
// PTX: nvvm.shfl.sync bfly
|
||||
// PTX: nvvm.barrier0
|
||||
|
||||
// GCN-COUNT-4: ds_swizzle_b32
|
||||
// GCN-COUNT-4: rocdl.ds_swizzle %{{.*}} : (i32, i32) -> i32
|
||||
// GCN: llvm.store
|
||||
// GCN: rocdl.barrier
|
||||
// GCN: llvm.load
|
||||
// GCN-COUNT-2: ds_swizzle_b32
|
||||
// GCN-COUNT-2: rocdl.ds_swizzle %{{.*}} : (i32, i32) -> i32
|
||||
// GCN: llvm.store
|
||||
// GCN: rocdl.barrier
|
||||
// GCN: llvm.load
|
||||
|
||||
Reference in New Issue
Block a user