[MFMA] Reenable removed CDNA3 int and fp8 support (#424)

The MFMA4x4 PR accidentally removed support for `int8xint8 -> int32` and `fp8xfp8 -> fp32` dot on CDNA.
This PR re-enables it.
This commit is contained in:
Alexander Efimov
2023-12-14 13:06:28 +01:00
committed by GitHub
parent f2afd65e8c
commit 40e1dcaa53
2 changed files with 520 additions and 0 deletions

View File

@@ -139,6 +139,22 @@ struct DotOpMFMAConversionHelper {
auto resType = valC.getType();
Value zeroFlag = i32_val(0);
switch (coreType) {
case MatrixCoreType::FP32_FP8_FP8_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x32_fp8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP8_BF8_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x32_fp8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF8_FP8_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x32_bf8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF8_BF8_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x32_bf8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP16_FP16_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x16f16>(
loc, TypeRange{resType},
@@ -159,6 +175,10 @@ struct DotOpMFMAConversionHelper {
return rewriter.create<ROCDL::mfma_i32_16x16x16i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::INT32_INT8_INT8_INT32_CDNA3:
return rewriter.create<ROCDL::mfma_i32_16x16x32_i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP64_FP64_FP64_FP64:
return rewriter.create<ROCDL::mfma_f64_16x16x4f64>(
loc, TypeRange{resType},

View File

@@ -0,0 +1,500 @@
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=rocdl" 2>/dev/null | FileCheck --check-prefixes=CHECK,GCN %s

// Lowering tests: tt.dot on MFMA-encoded operands is converted to ROCDL mfma
// intrinsics. Each `// -----` module below pins one operand-type / tile-size
// combination to the expected intrinsic name and emission count for a
// 128x256 * 256x32 dot with a single warp.

// f8E4M3 x f8E4M3 -> f32 on 32x32 tiles: expects the CDNA3 fp8.fp8 intrinsic.
!a_ty = f8E4M3FNUZ
!b_ty = f8E4M3FNUZ
!c_ty = f32
#k_width = 8
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_32x32x16_fp8_fp8
tt.func @convert_dot_mfma_f32_32x32x16_fp8_fp8(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-64: rocdl.mfma.f32.32x32x16.fp8.fp8
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// f8E4M3 x f8E5M2 -> f32: mixed fp8/bf8 operands select the fp8.bf8 variant.
!a_ty = f8E4M3FNUZ
!b_ty = f8E5M2FNUZ
!c_ty = f32
#k_width = 8
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_32x32x16_fp8_bf8
tt.func @convert_dot_mfma_f32_32x32x16_fp8_bf8(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-64: rocdl.mfma.f32.32x32x16.fp8.bf8
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// f8E5M2 x f8E4M3 -> f32: operand order is significant (bf8.fp8, not fp8.bf8).
!a_ty = f8E5M2FNUZ
!b_ty = f8E4M3FNUZ
!c_ty = f32
#k_width = 8
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_32x32x16_bf8_fp8
tt.func @convert_dot_mfma_f32_32x32x16_bf8_fp8(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-64: rocdl.mfma.f32.32x32x16.bf8.fp8
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// f8E5M2 x f8E5M2 -> f32: both operands bf8.
!a_ty = f8E5M2FNUZ
!b_ty = f8E5M2FNUZ
!c_ty = f32
#k_width = 8
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_32x32x16_bf8_bf8
tt.func @convert_dot_mfma_f32_32x32x16_bf8_bf8(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-64: rocdl.mfma.f32.32x32x16.bf8.bf8
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// f16 x f16 -> f32 on 32x32 tiles (kWidth = 4).
!a_ty = f16
!b_ty = f16
!c_ty = f32
#k_width = 4
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_32x32x8f16
tt.func @convert_dot_mfma_f32_32x32x8f16(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-128: rocdl.mfma.f32.32x32x8f16
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// bf16 x bf16 -> f32, kWidth = 2: selects the original (non-1k) bf16 intrinsic.
!a_ty = bf16
!b_ty = bf16
!c_ty = f32
#k_width = 2
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_32x32x4bf16
tt.func @convert_dot_mfma_f32_32x32x4bf16(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-256: rocdl.mfma.f32.32x32x4bf16
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// bf16 x bf16 -> f32, kWidth = 4: selects the ".1k" bf16 intrinsic variant.
!a_ty = bf16
!b_ty = bf16
!c_ty = f32
#k_width = 4
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_32x32x8bf16_1k
tt.func @convert_dot_mfma_f32_32x32x8bf16_1k(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-128: rocdl.mfma.f32.32x32x8bf16.1k
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// f32 x f32 -> f32 on 32x32 tiles (kWidth = 1).
!a_ty = f32
!b_ty = f32
!c_ty = f32
#k_width = 1
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_32x32x2f32
tt.func @convert_dot_mfma_f32_32x32x2f32(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-512: rocdl.mfma.f32.32x32x2f32
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// i8 x i8 -> i32, kWidth = 4: pre-CDNA3 integer intrinsic (note i32 accumulator
// uses dense<0>, not a float splat).
!a_ty = i8
!b_ty = i8
!c_ty = i32
#k_width = 4
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_i32_32x32x8i8
tt.func @convert_dot_mfma_i32_32x32x8i8(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-128: rocdl.mfma.i32.32x32x8i8
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// i8 x i8 -> i32, kWidth = 8: the CDNA3 int8 path this commit re-enables.
!a_ty = i8
!b_ty = i8
!c_ty = i32
#k_width = 8
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_i32_32x32x16_i8
tt.func @convert_dot_mfma_i32_32x32x16_i8(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-64: rocdl.mfma.i32.32x32x16.i8
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// Same fp8/bf8 operand-pairing matrix as above, but on 16x16 tiles
// (mfma.f32.16x16x32.* — the CDNA3 intrinsics re-enabled by this commit).

// f8E4M3 x f8E4M3 -> f32, 16x16 tiles.
!a_ty = f8E4M3FNUZ
!b_ty = f8E4M3FNUZ
!c_ty = f32
#k_width = 8
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_16x16x32_fp8_fp8
tt.func @convert_dot_mfma_f32_16x16x32_fp8_fp8(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-128: rocdl.mfma.f32.16x16x32.fp8.fp8
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// f8E4M3 x f8E5M2 -> f32, 16x16 tiles.
!a_ty = f8E4M3FNUZ
!b_ty = f8E5M2FNUZ
!c_ty = f32
#k_width = 8
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_16x16x32_fp8_bf8
tt.func @convert_dot_mfma_f32_16x16x32_fp8_bf8(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-128: rocdl.mfma.f32.16x16x32.fp8.bf8
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// f8E5M2 x f8E4M3 -> f32, 16x16 tiles (operand order matters: bf8.fp8).
!a_ty = f8E5M2FNUZ
!b_ty = f8E4M3FNUZ
!c_ty = f32
#k_width = 8
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_16x16x32_bf8_fp8
tt.func @convert_dot_mfma_f32_16x16x32_bf8_fp8(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-128: rocdl.mfma.f32.16x16x32.bf8.fp8
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// f8E5M2 x f8E5M2 -> f32, 16x16 tiles.
!a_ty = f8E5M2FNUZ
!b_ty = f8E5M2FNUZ
!c_ty = f32
#k_width = 8
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_16x16x32_bf8_bf8
tt.func @convert_dot_mfma_f32_16x16x32_bf8_bf8(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-128: rocdl.mfma.f32.16x16x32.bf8.bf8
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// f16 x f16 -> f32 on 16x16 tiles (kWidth = 4).
!a_ty = f16
!b_ty = f16
!c_ty = f32
#k_width = 4
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_16x16x16f16
tt.func @convert_dot_mfma_f32_16x16x16f16(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-256: rocdl.mfma.f32.16x16x16f16
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// bf16 x bf16 -> f32, kWidth = 2: non-1k bf16 intrinsic on 16x16 tiles.
!a_ty = bf16
!b_ty = bf16
!c_ty = f32
#k_width = 2
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_16x16x8bf16
tt.func @convert_dot_mfma_f32_16x16x8bf16(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-512: rocdl.mfma.f32.16x16x8bf16
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// bf16 x bf16 -> f32, kWidth = 4: ".1k" bf16 intrinsic on 16x16 tiles.
!a_ty = bf16
!b_ty = bf16
!c_ty = f32
#k_width = 4
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_16x16x16bf16_1k
tt.func @convert_dot_mfma_f32_16x16x16bf16_1k(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-256: rocdl.mfma.f32.16x16x16bf16.1k
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// f32 x f32 -> f32 on 16x16 tiles (kWidth = 1).
!a_ty = f32
!b_ty = f32
!c_ty = f32
#k_width = 1
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_16x16x4f32
tt.func @convert_dot_mfma_f32_16x16x4f32(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-1024: rocdl.mfma.f32.16x16x4f32
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// i8 x i8 -> i32, kWidth = 4: pre-CDNA3 integer intrinsic on 16x16 tiles.
!a_ty = i8
!b_ty = i8
!c_ty = i32
#k_width = 4
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_i32_16x16x16i8
tt.func @convert_dot_mfma_i32_16x16x16i8(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-256: rocdl.mfma.i32.16x16x16i8
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// i8 x i8 -> i32, kWidth = 8: CDNA3 int8 path (re-enabled by this commit).
!a_ty = i8
!b_ty = i8
!c_ty = i32
#k_width = 8
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_i32_16x16x32_i8
tt.func @convert_dot_mfma_i32_16x16x32_i8(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-128: rocdl.mfma.i32.16x16x32.i8
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// nonKDim = 4 cases: the small MFMA4x4 tiles whose introduction originally
// broke the int8/fp8 paths tested above.

// f16 x f16 -> f32 on 4x4 tiles.
!a_ty = f16
!b_ty = f16
!c_ty = f32
#k_width = 4
#non_k_dim = 4
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_4x4x4f16
tt.func @convert_dot_mfma_f32_4x4x4f16(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-1024: rocdl.mfma.f32.4x4x4f16
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// bf16 x bf16 -> f32, kWidth = 2: non-1k bf16 intrinsic on 4x4 tiles.
!a_ty = bf16
!b_ty = bf16
!c_ty = f32
#k_width = 2
#non_k_dim = 4
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_4x4x2bf16
tt.func @convert_dot_mfma_f32_4x4x2bf16(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-2048: rocdl.mfma.f32.4x4x2bf16
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// bf16 x bf16 -> f32, kWidth = 4: ".1k" bf16 intrinsic on 4x4 tiles.
!a_ty = bf16
!b_ty = bf16
!c_ty = f32
#k_width = 4
#non_k_dim = 4
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_4x4x4bf16_1k
tt.func @convert_dot_mfma_f32_4x4x4bf16_1k(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-1024: rocdl.mfma.f32.4x4x4bf16.1k
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// f32 x f32 -> f32 on 4x4 tiles (kWidth = 1).
!a_ty = f32
!b_ty = f32
!c_ty = f32
#k_width = 1
#non_k_dim = 4
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_f32_4x4x1f32
tt.func @convert_dot_mfma_f32_4x4x1f32(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0.000000e+00> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-4096: rocdl.mfma.f32.4x4x1f32
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}
// -----
// i8 x i8 -> i32 on 4x4 tiles (i32 accumulator, dense<0> splat).
!a_ty = i8
!b_ty = i8
!c_ty = i32
#k_width = 4
#non_k_dim = 4
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mfma_i32_4x4x4i8
tt.func @convert_dot_mfma_i32_4x4x4i8(%a: tensor<128x256x!a_ty, #dot_operand_a>, %b: tensor<256x32x!b_ty, #dot_operand_b>) {
%cst_c = arith.constant dense<0> : tensor<128x32x!c_ty, #mfma>
// GCN-COUNT-1024: rocdl.mfma.i32.4x4x4i8
%D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x256x!a_ty, #dot_operand_a> * tensor<256x32x!b_ty, #dot_operand_b> -> tensor<128x32x!c_ty, #mfma>
tt.return
}
}