[MFMA] Change kWidth parameter semantics

This PR changes the kWidth semantics from "elements per instruction" to
"elements per thread per instruction" along the k axis.
Aleksandr Efimov
2023-09-11 18:01:57 +00:00
committed by Lixun Zhang
parent 10795d8fd3
commit d80cd2d374
7 changed files with 21 additions and 18 deletions
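For orientation, here is a minimal standalone C++ sketch (not part of the patch) of how the old and new kWidth values relate. It assumes the nonKDim = 32 configuration used throughout this PR, where a 64-lane wavefront places two thread groups along the k dimension of one MFMA instruction; the per-type values mirror the chooseMfmaDimensions hunk below.

// Illustration only (not part of the patch): the old kWidth counted elements
// per MFMA instruction along k; the new kWidth counts elements per thread per
// instruction. With nonKDim = 32 and a 64-lane wavefront, two thread groups
// cover the instruction's k dimension, so the new value is half the old one.
#include <cassert>
#include <cstdio>

int main() {
  const long nonKDim = 32;
  const long waveSize = 64;
  const long threadsAlongK = waveSize / nonKDim; // == 2

  struct Case { const char *elemTy; long oldKWidth, newKWidth; };
  const Case cases[] = {{"f32", 2, 1},
                        {"f16", 8, 4},
                        {"bf16 (mfmaVersion 1)", 4, 2},
                        {"bf16 (mfmaVersion 2)", 8, 4},
                        {"i8", 8, 4}};

  for (const Case &c : cases) {
    // The semantic change halves kWidth; the instruction shape is unchanged.
    assert(c.oldKWidth == c.newKWidth * threadsAlongK);
    std::printf("%-22s old kWidth = %ld, new kWidth = %ld\n", c.elemTy,
                c.oldKWidth, c.newKWidth);
  }
  return 0;
}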

@@ -389,7 +389,7 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
   // layout when opIdx == 1.
   return mfmaLayout.getWarpsPerCTA()[1] == 1 &&
          dotOperandLayout.getOpIdx() == 0 &&
-         dotOperandLayout.getKWidth() == 8 &&
+         dotOperandLayout.getKWidth() == 4 &&
          dotOperandLayout.getParent() == mfmaLayout &&
          mfmaLayout.getIsTransposed() &&
          (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());

@@ -92,10 +92,12 @@ struct DotOpMFMAConversionHelper {
       return MatrixCoreType::FP32_FP32_FP32_FP32;
     if (elemTy.isBF16()) {
       auto dotOpEncoding = tensorTy.getEncoding().cast<DotOperandEncodingAttr>();
-      if (dotOpEncoding.getKWidth() == 8)
+      if (dotOpEncoding.getKWidth() == 4) {
         return MatrixCoreType::FP32_BF16_BF16_FP32_1K;
-      else
+      } else {
+        assert(dotOpEncoding.getKWidth() == 2);
         return MatrixCoreType::FP32_BF16_BF16_FP32;
+      }
     }
     if (elemTy.isInteger(8))
       return MatrixCoreType::INT32_INT8_INT8_INT32;

@@ -111,7 +111,7 @@ Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct(
     if (elemTy.isF32())
       return elemTy;
     if (elemTy.isInteger(16)) // aka BF16
-      return vec_ty(elemTy, dotOpLayout.getKWidth() / 2);
+      return vec_ty(elemTy, dotOpLayout.getKWidth());
     if (elemTy.isF16())
       return vec_ty(elemTy, 4);
     if (elemTy.isInteger(8))

@@ -679,9 +679,9 @@ DotOperandEncodingAttr::getMFMAElemsPerInstr() const {
   int64_t nonKDim = mfmaEncoding.getNonKDim();
   int64_t kDim = getKWidth();
   if (getOpIdx() == 0)
-    return {nonKDim, kDim};
+    return {nonKDim, kDim*2};
   else
-    return {kDim, nonKDim};
+    return {kDim*2, nonKDim};
 }
 SmallVector<int64_t>

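As a worked example of the kDim*2 factor above: under the assumption that the factor of 2 is the two thread groups along k for nonKDim = 32, the per-instruction operand tile can be rebuilt from the new per-thread kWidth. The standalone helper below is hypothetical, written only to illustrate this relationship; it is not the attribute's actual API.

// Hypothetical standalone helper, for illustration only: rebuilds the
// per-instruction operand tile from the new per-thread kWidth, mirroring the
// kDim*2 logic in getMFMAElemsPerInstr.
#include <array>
#include <cstdio>

std::array<long, 2> mfmaElemsPerInstr(long nonKDim, long kWidth, int opIdx) {
  const long threadsAlongK = 64 / nonKDim; // two thread groups along k when nonKDim == 32
  const long kDim = kWidth * threadsAlongK;
  return opIdx == 0 ? std::array<long, 2>{nonKDim, kDim}  // A operand: nonKDim x K
                    : std::array<long, 2>{kDim, nonKDim}; // B operand: K x nonKDim
}

int main() {
  // f16 under the new semantics: kWidth = 4 per thread -> instruction K = 8,
  // i.e. a 32x8 A tile and an 8x32 B tile.
  auto a = mfmaElemsPerInstr(32, 4, 0);
  auto b = mfmaElemsPerInstr(32, 4, 1);
  std::printf("A tile: %ldx%ld, B tile: %ldx%ld\n", a[0], a[1], b[0], b[1]);
  return 0;
}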
@@ -133,21 +133,22 @@ public:
   std::pair<int64_t, int64_t> chooseMfmaDimensions(triton::DotOp dot, int mfmaVersion) const {
     int64_t nonKDim = 32;
+    // number of matrix elements along k dim per thread per one mfma instruction
     int64_t kDim = -1;
     auto opType = dot.getA().getType().cast<RankedTensorType>();
     auto elemType = opType.getElementType();
     if (elemType.isF32())
-      kDim = 2;
+      kDim = 1;
     if (elemType.isF16())
-      kDim = 8;
+      kDim = 4;
     if (elemType.isBF16()) {
       if (mfmaVersion == 1)
-        kDim = 4;
+        kDim = 2;
       if (mfmaVersion == 2)
-        kDim = 8;
+        kDim = 4;
     }
     if (elemType.isInteger(8))
-      kDim = 8;
+      kDim = 4;
     assert(kDim != -1);
     return {nonKDim, kDim};
   }

@@ -2398,11 +2398,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
     %17 = tt.addptr %16, %8 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
     %18 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16, #blocked>
     %19 = triton_gpu.convert_layout %18 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #shared1>
-    %20 = triton_gpu.convert_layout %19 : (tensor<32x32xf16, #shared1>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>>
+    %20 = triton_gpu.convert_layout %19 : (tensor<32x32xf16, #shared1>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>>
     %21 = tt.load %13 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16, #blocked>
     %22 = triton_gpu.convert_layout %21 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #shared2>
-    %23 = triton_gpu.convert_layout %22 : (tensor<32x32xf16, #shared2>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>>
-    %24 = tt.dot %20, %23, %cst {allowTF32 = false} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, kDim = 8, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>> -> tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>>
+    %23 = triton_gpu.convert_layout %22 : (tensor<32x32xf16, #shared2>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>>
+    %24 = tt.dot %20, %23, %cst {allowTF32 = false} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, kDim = 8, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>> -> tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>>
     %25 = triton_gpu.convert_layout %24 : (tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>>) -> tensor<32x32xf32, #blocked>
     %26 = arith.truncf %25 : tensor<32x32xf32, #blocked> to tensor<32x32xf16, #blocked>
     tt.store %17, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf16, #blocked>

@@ -1187,8 +1187,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
 #shared0 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0]}>
 #mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA=[1,1], isTranspose=false}>
-#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma0, kWidth = 8}>
-#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma0, kWidth = 8}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma0, kWidth = 4}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma0, kWidth = 4}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
   // CHECK-LABEL: convert_dot_mfma
   tt.func @convert_dot_mfma(%A: tensor<32x32xf16, #blocked0>, %B: tensor<32x32xf16, #blocked0>) {
@@ -1361,8 +1361,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
 #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false}>
-#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = 8}>
-#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = 8}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = 4}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = 4}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: matmul_kernel_dot_operand_layout_gcn
   tt.func @matmul_kernel_dot_operand_layout_gcn(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},