diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index acdb2c558..17eaa4ec8 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -389,7 +389,7 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
   // layout when opIdx == 1.
   return mfmaLayout.getWarpsPerCTA()[1] == 1 &&
          dotOperandLayout.getOpIdx() == 0 &&
-         dotOperandLayout.getKWidth() == 8 &&
+         dotOperandLayout.getKWidth() == 4 &&
          dotOperandLayout.getParent() == mfmaLayout &&
          mfmaLayout.getIsTransposed() &&
          (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp
index 047e1c6c1..0e215bd13 100644
--- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp
@@ -92,10 +92,12 @@ struct DotOpMFMAConversionHelper {
       return MatrixCoreType::FP32_FP32_FP32_FP32;
     if (elemTy.isBF16()) {
       auto dotOpEncoding = tensorTy.getEncoding().cast<DotOperandEncodingAttr>();
-      if (dotOpEncoding.getKWidth() == 8)
+      if (dotOpEncoding.getKWidth() == 4) {
         return MatrixCoreType::FP32_BF16_BF16_FP32_1K;
-      else
+      } else {
+        assert(dotOpEncoding.getKWidth() == 2);
         return MatrixCoreType::FP32_BF16_BF16_FP32;
+      }
     }
     if (elemTy.isInteger(8))
       return MatrixCoreType::INT32_INT8_INT8_INT32;
diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
index a094f5d3f..3f20337e8 100644
--- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
@@ -111,7 +111,7 @@ Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct(
     if (elemTy.isF32())
      return elemTy;
     if (elemTy.isInteger(16)) // aka BF16
-      return vec_ty(elemTy, dotOpLayout.getKWidth() / 2);
+      return vec_ty(elemTy, dotOpLayout.getKWidth());
     if (elemTy.isF16())
       return vec_ty(elemTy, 4);
     if (elemTy.isInteger(8))
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index bc55e1661..5632b0f6b 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -679,9 +679,9 @@ DotOperandEncodingAttr::getMFMAElemsPerInstr() const {
   int64_t nonKDim = mfmaEncoding.getNonKDim();
   int64_t kDim = getKWidth();
   if (getOpIdx() == 0)
-    return {nonKDim, kDim};
+    return {nonKDim, kDim*2};
   else
-    return {kDim, nonKDim};
+    return {kDim*2, nonKDim};
 }
 
 SmallVector
diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index 77aab6018..c485fdcd4 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -133,21 +133,22 @@ public:
   std::pair<int64_t, int64_t>
   chooseMfmaDimensions(triton::DotOp dot, int mfmaVersion) const {
     int64_t nonKDim = 32;
+    // number of matrix elements along k dim per thread in one mfma instruction
     int64_t kDim = -1;
     auto opType = dot.getA().getType().cast<RankedTensorType>();
     auto elemType = opType.getElementType();
     if (elemType.isF32())
-      kDim = 2;
+      kDim = 1;
     if (elemType.isF16())
-      kDim = 8;
+      kDim = 4;
     if (elemType.isBF16()) {
       if (mfmaVersion == 1)
-        kDim = 4;
+        kDim = 2;
       if (mfmaVersion == 2)
-        kDim = 8;
+        kDim = 4;
     }
     if (elemType.isInteger(8))
-      kDim = 8;
+      kDim = 4;
     assert(kDim != -1);
     return {nonKDim, kDim};
   }
diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py
index 9ef4e99bf..191e76f71 100644
--- a/python/test/unit/language/test_core_amd.py
+++ b/python/test/unit/language/test_core_amd.py
@@ -2398,11 +2398,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
     %17 = tt.addptr %16, %8 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
     %18 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16, #blocked>
     %19 = triton_gpu.convert_layout %18 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #shared1>
-    %20 = triton_gpu.convert_layout %19 : (tensor<32x32xf16, #shared1>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>>
+    %20 = triton_gpu.convert_layout %19 : (tensor<32x32xf16, #shared1>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>>
     %21 = tt.load %13 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16, #blocked>
     %22 = triton_gpu.convert_layout %21 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #shared2>
-    %23 = triton_gpu.convert_layout %22 : (tensor<32x32xf16, #shared2>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>>
-    %24 = tt.dot %20, %23, %cst {allowTF32 = false} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, kDim = 8, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=8}>> -> tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>>
+    %23 = triton_gpu.convert_layout %22 : (tensor<32x32xf16, #shared2>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>>
+    %24 = tt.dot %20, %23, %cst {allowTF32 = false} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, kDim = 8, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>> -> tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>>
     %25 = triton_gpu.convert_layout %24 : (tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>>) -> tensor<32x32xf32, #blocked>
     %26 = arith.truncf %25 : tensor<32x32xf32, #blocked> to tensor<32x32xf16, #blocked>
     tt.store %17, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf16, #blocked>
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 7460b1b79..0a56bd3cf 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -1187,8 +1187,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
 #shared0 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0]}>
 #mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA=[1,1], isTranspose=false}>
-#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma0, kWidth = 8}>
-#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma0, kWidth = 8}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma0, kWidth = 4}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma0, kWidth = 4}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
   // CHECK-LABEL: convert_dot_mfma
   tt.func @convert_dot_mfma(%A: tensor<32x32xf16, #blocked0>, %B: tensor<32x32xf16, #blocked0>) {
@@ -1361,8 +1361,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
 #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false}>
-#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = 8}>
-#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = 8}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = 4}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = 4}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: matmul_kernel_dot_operand_layout_gcn
   tt.func @matmul_kernel_dot_operand_layout_gcn(%ptr:!tt.ptr {tt.divisibility = 16 : i32},
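
For context, and not part of the patch above: after this change, kWidth on an MFMA dot-operand layout counts the K elements each thread holds per operand in one mfma instruction, and the per-instruction K extent of the operand tile is kWidth * 2, matching chooseMfmaDimensions and DotOperandEncodingAttr::getMFMAElemsPerInstr in the diff. The standalone C++ sketch below restates that arithmetic; the enum and helper names (ElemType, chooseKWidth, mfmaElemsPerInstr) are hypothetical and exist only for illustration.

// Standalone sketch of the kWidth convention adopted by this patch.
// Names here are hypothetical; only the numbers mirror the diff above.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <utility>

enum class ElemType { F32, F16, BF16, I8 };

// kWidth: K elements held by one thread, per operand, per mfma instruction,
// mirroring chooseMfmaDimensions in AccelerateMatmul.cpp.
int64_t chooseKWidth(ElemType ty, int mfmaVersion) {
  switch (ty) {
  case ElemType::F32:
    return 1;
  case ElemType::F16:
    return 4;
  case ElemType::BF16:
    return mfmaVersion == 2 ? 4 : 2; // mfma v1 uses the non-1K bf16 variant
  case ElemType::I8:
    return 4;
  }
  assert(false && "unhandled element type");
  return -1;
}

// Per-instruction tile of one dot operand; the K extent is kWidth * 2,
// mirroring DotOperandEncodingAttr::getMFMAElemsPerInstr after this change.
std::pair<int64_t, int64_t> mfmaElemsPerInstr(int64_t nonKDim, int64_t kWidth,
                                              int opIdx) {
  if (opIdx == 0)
    return {nonKDim, kWidth * 2}; // A operand: nonKDim x K
  return {kWidth * 2, nonKDim};   // B operand: K x nonKDim
}

int main() {
  int64_t kWidth = chooseKWidth(ElemType::F16, /*mfmaVersion=*/2); // -> 4
  auto tile = mfmaElemsPerInstr(/*nonKDim=*/32, kWidth, /*opIdx=*/0);
  std::printf("A operand per-instruction tile: %lld x %lld\n",
              (long long)tile.first, (long long)tile.second); // 32 x 8
  return 0;
}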