[MFMA] Change kWidth parameter semantics
This PR changes the kWidth semantics from "elements per instruction" to "elements per thread per instruction" along the k axis.
committed by Lixun Zhang
parent 10795d8fd3
commit d80cd2d374

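Concretely: under the old convention, kWidth counted how many elements one MFMA instruction consumes along k; under the new convention it counts how many of those elements a single thread holds. For the 32x32 MFMA shapes used here, the k slice of each operand row (or column, for operand B) is split across two threads, so every kWidth in the codebase is halved. The sketch below is a minimal standalone illustration, not part of the patch; the enum and function names are hypothetical, and the per-instruction numbers are the old chooseMfmaDimensions values from the diff further down (assuming mfmaVersion 2 for bf16).

// Standalone sketch (hypothetical names): old vs. new kWidth for 32x32 MFMA.
#include <cassert>
#include <cstdio>

enum class ElemTy { F32, F16, BF16, I8 };

// Old semantics: elements consumed along k by one MFMA instruction.
int kWidthPerInstr(ElemTy ty) {
  switch (ty) {
  case ElemTy::F32:  return 2;
  case ElemTy::F16:  return 8;
  case ElemTy::BF16: return 8; // assuming mfmaVersion == 2
  case ElemTy::I8:   return 8;
  }
  return -1;
}

// New semantics: elements held by one thread along k per MFMA instruction.
// Two threads cover each row/column of the k slice, hence the division by 2.
int kWidthPerThread(ElemTy ty) { return kWidthPerInstr(ty) / 2; }

int main() {
  const ElemTy tys[] = {ElemTy::F32, ElemTy::F16, ElemTy::BF16, ElemTy::I8};
  for (ElemTy ty : tys) {
    assert(kWidthPerThread(ty) * 2 == kWidthPerInstr(ty));
    std::printf("old kWidth = %d -> new kWidth = %d\n",
                kWidthPerInstr(ty), kWidthPerThread(ty));
  }
  return 0;
}
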
@@ -389,7 +389,7 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
   // layout when opIdx == 1.
   return mfmaLayout.getWarpsPerCTA()[1] == 1 &&
          dotOperandLayout.getOpIdx() == 0 &&
-         dotOperandLayout.getKWidth() == 8 &&
+         dotOperandLayout.getKWidth() == 4 &&
          dotOperandLayout.getParent() == mfmaLayout &&
          mfmaLayout.getIsTransposed() &&
          (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());

@@ -92,10 +92,12 @@ struct DotOpMFMAConversionHelper {
       return MatrixCoreType::FP32_FP32_FP32_FP32;
     if (elemTy.isBF16()) {
       auto dotOpEncoding = tensorTy.getEncoding().cast<DotOperandEncodingAttr>();
-      if (dotOpEncoding.getKWidth() == 8)
+      if (dotOpEncoding.getKWidth() == 4) {
        return MatrixCoreType::FP32_BF16_BF16_FP32_1K;
-      else
+      } else {
+        assert(dotOpEncoding.getKWidth() == 2);
        return MatrixCoreType::FP32_BF16_BF16_FP32;
+      }
    }
    if (elemTy.isInteger(8))
      return MatrixCoreType::INT32_INT8_INT8_INT32;

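For bf16 this means the matrix core variant can now be read directly off the per-thread kWidth: 4 selects the MFMA-v2 "_1K" bf16 instructions and 2 the original MFMA-v1 bf16 instructions, and the new assert documents that no other value is expected. A minimal standalone sketch of that mapping (hypothetical function name, enum values copied from the hunk above):

// Standalone sketch (hypothetical name): bf16 matrix core selection by the
// new per-thread kWidth. kWidth == 4 -> MFMA v2 ("_1K"), kWidth == 2 -> MFMA v1.
#include <cassert>

enum class MatrixCoreType { FP32_BF16_BF16_FP32, FP32_BF16_BF16_FP32_1K };

MatrixCoreType bf16MatrixCoreFor(int kWidth) {
  if (kWidth == 4)
    return MatrixCoreType::FP32_BF16_BF16_FP32_1K;
  assert(kWidth == 2); // only the two bf16 generations are expected
  return MatrixCoreType::FP32_BF16_BF16_FP32;
}

int main() {
  assert(bf16MatrixCoreFor(4) == MatrixCoreType::FP32_BF16_BF16_FP32_1K);
  assert(bf16MatrixCoreFor(2) == MatrixCoreType::FP32_BF16_BF16_FP32);
  return 0;
}
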
@@ -111,7 +111,7 @@ Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct(
   if (elemTy.isF32())
     return elemTy;
   if (elemTy.isInteger(16)) // aka BF16
-    return vec_ty(elemTy, dotOpLayout.getKWidth() / 2);
+    return vec_ty(elemTy, dotOpLayout.getKWidth());
   if (elemTy.isF16())
     return vec_ty(elemTy, 4);
   if (elemTy.isInteger(8))

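Note that the packed vector type for bf16 operands (carried as i16 in the LLVM struct) does not change, only the formula that derives it: the old code divided the per-instruction kWidth by 2 to get the per-thread count, while the new code can use kWidth directly. A tiny standalone check of that arithmetic (values assume bf16 on MFMA v2):

// Standalone check: same vector width (vector<4 x i16>), different formula.
#include <cassert>

int main() {
  const int oldKWidth = 8; // per instruction (old semantics)
  const int newKWidth = 4; // per thread (new semantics)
  assert(oldKWidth / 2 == newKWidth); // both yield a 4-element i16 vector
  return 0;
}
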
@@ -679,9 +679,9 @@ DotOperandEncodingAttr::getMFMAElemsPerInstr() const {
   int64_t nonKDim = mfmaEncoding.getNonKDim();
   int64_t kDim = getKWidth();
   if (getOpIdx() == 0)
-    return {nonKDim, kDim};
+    return {nonKDim, kDim*2};
   else
-    return {kDim, nonKDim};
+    return {kDim*2, nonKDim};
 }

 SmallVector<int64_t>

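Because kWidth is now a per-thread quantity, getMFMAElemsPerInstr has to scale it back up by the two threads that share the k slice in order to report the per-instruction operand tile. A standalone sketch of the same computation (hypothetical free function), with the f16 case worked out:

// Standalone sketch (hypothetical helper): per-instruction operand shape
// recovered from the per-thread kWidth.
#include <array>
#include <cassert>
#include <cstdint>

std::array<int64_t, 2> mfmaElemsPerInstr(int64_t nonKDim, int64_t kWidth,
                                         int opIdx) {
  // Two threads share each row/column along k, so per-instruction k = kWidth*2.
  if (opIdx == 0)
    return {nonKDim, kWidth * 2}; // operand A: M x K tile
  return {kWidth * 2, nonKDim};   // operand B: K x N tile
}

int main() {
  // f16 with the new kWidth = 4: operand A of one 32x32x8 MFMA is a 32x8 tile.
  auto a = mfmaElemsPerInstr(/*nonKDim=*/32, /*kWidth=*/4, /*opIdx=*/0);
  assert(a[0] == 32 && a[1] == 8);
  return 0;
}
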
@@ -133,21 +133,22 @@ public:

   std::pair<int64_t, int64_t> chooseMfmaDimensions(triton::DotOp dot, int mfmaVersion) const {
     int64_t nonKDim = 32;
+    // number of matrix elements along k dim per thread in one mfma instruction
     int64_t kDim = -1;
     auto opType = dot.getA().getType().cast<RankedTensorType>();
     auto elemType = opType.getElementType();
     if (elemType.isF32())
-      kDim = 2;
+      kDim = 1;
     if (elemType.isF16())
-      kDim = 8;
+      kDim = 4;
     if (elemType.isBF16()) {
       if (mfmaVersion == 1)
-        kDim = 4;
+        kDim = 2;
       if (mfmaVersion == 2)
-        kDim = 8;
+        kDim = 4;
     }
     if (elemType.isInteger(8))
-      kDim = 8;
+      kDim = 4;
     assert(kDim != -1);
     return {nonKDim, kDim};
   }

@@ -2398,11 +2398,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
     %17 = tt.addptr %16, %8 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
     %18 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16, #blocked>
     %19 = triton_gpu.convert_layout %18 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #shared1>
-    %20 = triton_gpu.convert_layout %19 : (tensor<32x32xf16, #shared1>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>>
+    %20 = triton_gpu.convert_layout %19 : (tensor<32x32xf16, #shared1>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>>
     %21 = tt.load %13 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16, #blocked>
     %22 = triton_gpu.convert_layout %21 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #shared2>
-    %23 = triton_gpu.convert_layout %22 : (tensor<32x32xf16, #shared2>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>>
-    %24 = tt.dot %20, %23, %cst {allowTF32 = false} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, kDim = 8, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>> -> tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>>
+    %23 = triton_gpu.convert_layout %22 : (tensor<32x32xf16, #shared2>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>>
+    %24 = tt.dot %20, %23, %cst {allowTF32 = false} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, kDim = 8, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>> -> tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>>
     %25 = triton_gpu.convert_layout %24 : (tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>>) -> tensor<32x32xf32, #blocked>
     %26 = arith.truncf %25 : tensor<32x32xf32, #blocked> to tensor<32x32xf16, #blocked>
     tt.store %17, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf16, #blocked>

@@ -1187,8 +1187,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
 #shared0 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0]}>
 #mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA=[1,1], isTranspose=false}>
-#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma0, kWidth = 8}>
-#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma0, kWidth = 8}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma0, kWidth = 4}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma0, kWidth = 4}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
   // CHECK-LABEL: convert_dot_mfma
   tt.func @convert_dot_mfma(%A: tensor<32x32xf16, #blocked0>, %B: tensor<32x32xf16, #blocked0>) {

@@ -1361,8 +1361,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
 #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false}>
-#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = 8}>
-#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = 8}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = 4}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = 4}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: matmul_kernel_dot_operand_layout_gcn
   tt.func @matmul_kernel_dot_operand_layout_gcn(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},