diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index f588a471b..bdce88ba8 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -764,8 +764,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129, ins "unsigned":$nonKDim, ArrayRefParameter<"unsigned">:$warpsPerCTA, - "bool":$isTransposed, - "CTALayoutAttr":$CTALayout + "bool":$isTransposed ); let hasCustomAssemblyFormat = 1; diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index 104401e64..a395c2180 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -772,7 +772,7 @@ private: srcType.getEncoding().dyn_cast(); auto newMfmaEnc = triton::gpu::MfmaEncodingAttr::get( mod.getContext(), srcMfma.getNonKDim(), {warpsPerCtaX, warpsPerCtaY}, - srcMfma.getIsTransposed(), srcMfma.getCTALayout()); + srcMfma.getIsTransposed()); auto newDstType = RankedTensorType::get( dstType.getShape(), dstType.getElementType(), dstType.getEncoding()); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index f284e2f97..bd0e6e36a 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -543,7 +543,7 @@ SmallVector getCTAsPerCGA(Attribute layout) { ref = mmaLayout.getCTALayout().getCTAsPerCGA(); #ifdef USE_ROCM else if (auto mfmaLayout = layout.dyn_cast()) - ref = mfmaLayout.getCTALayout().getCTAsPerCGA(); + return {1, 1}; #endif else if (auto dotLayout = layout.dyn_cast()) return getCTAsPerCGA(dotLayout.getParent()); @@ -568,8 +568,8 @@ SmallVector getCTASplitNum(Attribute layout) { mmaLayout.getCTALayout().getCTASplitNum().end()); #ifdef USE_ROCM } else if (auto mfmaLayout = layout.dyn_cast()) { - res.assign(mfmaLayout.getCTALayout().getCTASplitNum().begin(), - mfmaLayout.getCTALayout().getCTASplitNum().end()); + res.resize(2); + res[0] = res[1] = 1; #endif } else if (auto dotLayout = layout.dyn_cast()) { res = getCTASplitNum(dotLayout.getParent()); @@ -598,7 +598,7 @@ SmallVector getCTAOrder(Attribute layout) { ref = mmaLayout.getCTALayout().getCTAOrder(); #ifdef USE_ROCM } else if (auto mfmaLayout = layout.dyn_cast()) { - ref = mfmaLayout.getCTALayout().getCTAOrder(); + return {0, 1}; #endif } else if (auto dotLayout = layout.dyn_cast()) { return getCTAOrder(dotLayout.getParent()); @@ -673,8 +673,9 @@ unsigned getNumCTAs(Attribute layout) { else if (auto mmaLayout = layout.dyn_cast()) CTAsPerCGA = mmaLayout.getCTALayout().getCTAsPerCGA(); #ifdef USE_ROCM - else if (auto mfmaLayout = layout.dyn_cast()) - CTAsPerCGA = mfmaLayout.getCTALayout().getCTAsPerCGA(); + else if (auto mfmaLayout = layout.dyn_cast()) { + return 1; + } #endif else if (auto dotLayout = layout.dyn_cast()) return getNumCTAs(dotLayout.getParent()); @@ -1316,9 +1317,6 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) { unsigned nonKDim = 0; SmallVector warpsPerCTA; bool isTransposed; - SmallVector CTAsPerCGA; - SmallVector CTASplitNum; - SmallVector CTAOrder; for (const NamedAttribute &attr : dict) { if (attr.getName() == "nonKDim") { @@ -1332,35 +1330,17 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) { if (parseBool(parser, attr, isTransposed, "isTransposed").failed()) return {}; } - if (attr.getName() == "CTAsPerCGA") { - if (parseIntArrayAttr(parser, attr, CTAsPerCGA, "CTAsPerCGA").failed()) - return {}; - } - if (attr.getName() == "CTASplitNum") { - if (parseIntArrayAttr(parser, attr, CTASplitNum, "CTASplitNum").failed()) - return {}; - } - if (attr.getName() == "CTAOrder") { - if (parseIntArrayAttr(parser, attr, CTAOrder, "CTAOrder").failed()) - return {}; - } } - auto CTALayout = CTALayoutAttr::get(parser.getContext(), CTAsPerCGA, - CTASplitNum, CTAOrder); - - return parser.getChecked( - parser.getContext(), nonKDim, warpsPerCTA, isTransposed, CTALayout); + return parser.getChecked(parser.getContext(), nonKDim, + warpsPerCTA, isTransposed); } void MfmaEncodingAttr::print(AsmPrinter &printer) const { printer << "<{" << "nonKDim = " << getNonKDim() << ", " << "warpsPerCTA = [" << getWarpsPerCTA() << "], " - << "isTransposed = " << getIsTransposed() << ", " - << "CTAsPerCGA = [" << getCTALayout().getCTAsPerCGA() << "], " - << "CTASplitNum = [" << getCTALayout().getCTASplitNum() << "], " - << "CTAOrder = [" << getCTALayout().getCTAOrder() << "]}>"; + << "isTransposed = " << getIsTransposed() << "}>"; } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp index 91280d608..b4554a73a 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp @@ -243,6 +243,9 @@ public: return failure(); auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding()); + assert(CTALayout.getCTAsPerCGA().size() == 2); + assert(CTALayout.getCTAsPerCGA()[0] == 1); + assert(CTALayout.getCTAsPerCGA()[1] == 1); // get MFMA encoding for the given number of warps auto retShape = oldRetType.getShape(); @@ -264,7 +267,7 @@ public: bool isTransposed = isChainDot(dotOp); mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim, - warpsPerTile, isTransposed, CTALayout); + warpsPerTile, isTransposed); auto newRetType = RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc); diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py index 1bae88783..314021c9a 100644 --- a/python/test/unit/language/test_core_amd.py +++ b/python/test/unit/language/test_core_amd.py @@ -2700,16 +2700,13 @@ class MmaLayout: class MfmaLayout: - def __init__(self, non_k_dim, warps_per_cta, is_transposed, ctas_per_cga, cta_split_num, cta_order): + def __init__(self, non_k_dim, warps_per_cta, is_transposed): self.non_k_dim = str(non_k_dim) self.warps_per_cta = str(warps_per_cta) self.is_transposed = str(is_transposed).lower() - self.ctas_per_cga = str(ctas_per_cga) - self.cta_split_num = str(cta_split_num) - self.cta_order = str(cta_order) def __str__(self): - return f"#{GPU_DIALECT}.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.is_transposed}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + return f"#{GPU_DIALECT}.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.is_transposed}}}>" class BlockedLayout: @@ -2776,7 +2773,7 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB): blocked = BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[8, 8], warps_per_cta=[4, 1], order=[1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1]) shared_a = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeA else [1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1]) shared_b = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeB else [1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1]) - mfma = MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1]) + mfma = MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False) ir = f""" #blocked = {blocked} @@ -2938,8 +2935,8 @@ module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, if torch.version.hip is not None and _get_warp_size() == 64: layouts = [ - MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]), - MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]), + MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True), + MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], is_transposed=False), ] shapes = [[128, 32], [128, 128], [32, 128], [64, 64]] else: @@ -3097,8 +3094,8 @@ def test_scan_layouts(M, N, src_layout, axis, device): @pytest.mark.parametrize("shape", [(64, 64)]) @pytest.mark.parametrize("dtype", ['float16']) -@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]), - MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0])]) +@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], is_transposed=False), + MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True)]) @pytest.mark.parametrize("dst_layout", [BlockedLayout([1, 4], [4, 16], [1, 1], [1, 0], [1, 1], [1, 1], [0, 1])]) def test_make_range(dtype, shape, src_layout, dst_layout, device='cuda'): ir = f""" diff --git a/test/Conversion/AMDGPU/mfma_variants.mlir b/test/Conversion/AMDGPU/mfma_variants.mlir index a9362a9e5..d5891b494 100644 --- a/test/Conversion/AMDGPU/mfma_variants.mlir +++ b/test/Conversion/AMDGPU/mfma_variants.mlir @@ -5,7 +5,7 @@ !c_ty = f32 #k_width = 8 #non_k_dim = 32 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -25,7 +25,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 8 #non_k_dim = 32 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -45,7 +45,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 8 #non_k_dim = 32 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -65,7 +65,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 8 #non_k_dim = 32 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -85,7 +85,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 4 #non_k_dim = 32 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -105,7 +105,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 2 #non_k_dim = 32 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -125,7 +125,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 4 #non_k_dim = 32 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -145,7 +145,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 1 #non_k_dim = 32 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -165,7 +165,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = i32 #k_width = 4 #non_k_dim = 32 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -185,7 +185,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = i32 #k_width = 8 #non_k_dim = 32 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -205,7 +205,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 8 #non_k_dim = 16 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -225,7 +225,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 8 #non_k_dim = 16 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -245,7 +245,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 8 #non_k_dim = 16 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -265,7 +265,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 8 #non_k_dim = 16 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -285,7 +285,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 4 #non_k_dim = 16 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -305,7 +305,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 2 #non_k_dim = 16 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -325,7 +325,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 4 #non_k_dim = 16 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -345,7 +345,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 1 #non_k_dim = 16 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -365,7 +365,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = i32 #k_width = 4 #non_k_dim = 16 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -385,7 +385,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = i32 #k_width = 8 #non_k_dim = 16 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -405,7 +405,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 4 #non_k_dim = 4 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -425,7 +425,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 2 #non_k_dim = 4 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -445,7 +445,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 4 #non_k_dim = 4 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -465,7 +465,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = f32 #k_width = 1 #non_k_dim = 4 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -485,7 +485,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : !c_ty = i32 #k_width = 4 #non_k_dim = 4 -#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { diff --git a/test/Conversion/minimize_alloc.mlir b/test/Conversion/minimize_alloc.mlir index b08c474fb..8ec4bd6c4 100644 --- a/test/Conversion/minimize_alloc.mlir +++ b/test/Conversion/minimize_alloc.mlir @@ -5,7 +5,7 @@ #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { tt.func public @matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mfma> diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 8e2fd68c2..df23a4f61 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1355,7 +1355,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #shared0 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA=[1,1], isTranspose=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma0, kWidth = 4}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma0, kWidth = 4}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { @@ -1379,7 +1379,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTranspose=false}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_mfma_block @@ -1529,7 +1529,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = 4}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = 4}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {