[MFMA] Remove CTA related code from layout (#429)

This PR removes CTALayout attribute from MFMA layout, because it is NV specific.
This commit is contained in:
Alexander Efimov
2023-12-27 20:01:28 +03:00
committed by GitHub
parent 1e2fd0dd1a
commit 98589ac013
8 changed files with 52 additions and 73 deletions

View File

@@ -764,8 +764,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
ins
"unsigned":$nonKDim,
ArrayRefParameter<"unsigned">:$warpsPerCTA,
"bool":$isTransposed,
"CTALayoutAttr":$CTALayout
"bool":$isTransposed
);
let hasCustomAssemblyFormat = 1;

View File

@@ -772,7 +772,7 @@ private:
srcType.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
auto newMfmaEnc = triton::gpu::MfmaEncodingAttr::get(
mod.getContext(), srcMfma.getNonKDim(), {warpsPerCtaX, warpsPerCtaY},
srcMfma.getIsTransposed(), srcMfma.getCTALayout());
srcMfma.getIsTransposed());
auto newDstType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(), dstType.getEncoding());

View File

@@ -543,7 +543,7 @@ SmallVector<unsigned> getCTAsPerCGA(Attribute layout) {
ref = mmaLayout.getCTALayout().getCTAsPerCGA();
#ifdef USE_ROCM
else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>())
ref = mfmaLayout.getCTALayout().getCTAsPerCGA();
return {1, 1};
#endif
else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>())
return getCTAsPerCGA(dotLayout.getParent());
@@ -568,8 +568,8 @@ SmallVector<unsigned> getCTASplitNum(Attribute layout) {
mmaLayout.getCTALayout().getCTASplitNum().end());
#ifdef USE_ROCM
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
res.assign(mfmaLayout.getCTALayout().getCTASplitNum().begin(),
mfmaLayout.getCTALayout().getCTASplitNum().end());
res.resize(2);
res[0] = res[1] = 1;
#endif
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
res = getCTASplitNum(dotLayout.getParent());
@@ -598,7 +598,7 @@ SmallVector<unsigned> getCTAOrder(Attribute layout) {
ref = mmaLayout.getCTALayout().getCTAOrder();
#ifdef USE_ROCM
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
ref = mfmaLayout.getCTALayout().getCTAOrder();
return {0, 1};
#endif
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
return getCTAOrder(dotLayout.getParent());
@@ -673,8 +673,9 @@ unsigned getNumCTAs(Attribute layout) {
else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>())
CTAsPerCGA = mmaLayout.getCTALayout().getCTAsPerCGA();
#ifdef USE_ROCM
else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>())
CTAsPerCGA = mfmaLayout.getCTALayout().getCTAsPerCGA();
else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
return 1;
}
#endif
else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>())
return getNumCTAs(dotLayout.getParent());
@@ -1316,9 +1317,6 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
unsigned nonKDim = 0;
SmallVector<unsigned> warpsPerCTA;
bool isTransposed;
SmallVector<unsigned> CTAsPerCGA;
SmallVector<unsigned> CTASplitNum;
SmallVector<unsigned> CTAOrder;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "nonKDim") {
@@ -1332,35 +1330,17 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
if (parseBool(parser, attr, isTransposed, "isTransposed").failed())
return {};
}
if (attr.getName() == "CTAsPerCGA") {
if (parseIntArrayAttr(parser, attr, CTAsPerCGA, "CTAsPerCGA").failed())
return {};
}
if (attr.getName() == "CTASplitNum") {
if (parseIntArrayAttr(parser, attr, CTASplitNum, "CTASplitNum").failed())
return {};
}
if (attr.getName() == "CTAOrder") {
if (parseIntArrayAttr(parser, attr, CTAOrder, "CTAOrder").failed())
return {};
}
}
auto CTALayout = CTALayoutAttr::get(parser.getContext(), CTAsPerCGA,
CTASplitNum, CTAOrder);
return parser.getChecked<MfmaEncodingAttr>(
parser.getContext(), nonKDim, warpsPerCTA, isTransposed, CTALayout);
return parser.getChecked<MfmaEncodingAttr>(parser.getContext(), nonKDim,
warpsPerCTA, isTransposed);
}
void MfmaEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{"
<< "nonKDim = " << getNonKDim() << ", "
<< "warpsPerCTA = [" << getWarpsPerCTA() << "], "
<< "isTransposed = " << getIsTransposed() << ", "
<< "CTAsPerCGA = [" << getCTALayout().getCTAsPerCGA() << "], "
<< "CTASplitNum = [" << getCTALayout().getCTASplitNum() << "], "
<< "CTAOrder = [" << getCTALayout().getCTAOrder() << "]}>";
<< "isTransposed = " << getIsTransposed() << "}>";
}
//===----------------------------------------------------------------------===//

View File

@@ -243,6 +243,9 @@ public:
return failure();
auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding());
assert(CTALayout.getCTAsPerCGA().size() == 2);
assert(CTALayout.getCTAsPerCGA()[0] == 1);
assert(CTALayout.getCTAsPerCGA()[1] == 1);
// get MFMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
@@ -264,7 +267,7 @@ public:
bool isTransposed = isChainDot(dotOp);
mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim,
warpsPerTile, isTransposed, CTALayout);
warpsPerTile, isTransposed);
auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc);

View File

@@ -2700,16 +2700,13 @@ class MmaLayout:
class MfmaLayout:
def __init__(self, non_k_dim, warps_per_cta, is_transposed, ctas_per_cga, cta_split_num, cta_order):
def __init__(self, non_k_dim, warps_per_cta, is_transposed):
self.non_k_dim = str(non_k_dim)
self.warps_per_cta = str(warps_per_cta)
self.is_transposed = str(is_transposed).lower()
self.ctas_per_cga = str(ctas_per_cga)
self.cta_split_num = str(cta_split_num)
self.cta_order = str(cta_order)
def __str__(self):
return f"#{GPU_DIALECT}.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.is_transposed}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
return f"#{GPU_DIALECT}.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.is_transposed}}}>"
class BlockedLayout:
@@ -2776,7 +2773,7 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB):
blocked = BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[8, 8], warps_per_cta=[4, 1], order=[1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
shared_a = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeA else [1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
shared_b = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeB else [1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
mfma = MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
mfma = MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False)
ir = f"""
#blocked = {blocked}
@@ -2938,8 +2935,8 @@ module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32,
if torch.version.hip is not None and _get_warp_size() == 64:
layouts = [
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]),
MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]),
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True),
MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], is_transposed=False),
]
shapes = [[128, 32], [128, 128], [32, 128], [64, 64]]
else:
@@ -3097,8 +3094,8 @@ def test_scan_layouts(M, N, src_layout, axis, device):
@pytest.mark.parametrize("shape", [(64, 64)])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]),
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0])])
@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], is_transposed=False),
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True)])
@pytest.mark.parametrize("dst_layout", [BlockedLayout([1, 4], [4, 16], [1, 1], [1, 0], [1, 1], [1, 1], [0, 1])])
def test_make_range(dtype, shape, src_layout, dst_layout, device='cuda'):
ir = f"""

View File

@@ -5,7 +5,7 @@
!c_ty = f32
#k_width = 8
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -25,7 +25,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 8
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -45,7 +45,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 8
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -65,7 +65,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 8
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -85,7 +85,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 4
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -105,7 +105,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 2
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -125,7 +125,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 4
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -145,7 +145,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 1
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -165,7 +165,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = i32
#k_width = 4
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -185,7 +185,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = i32
#k_width = 8
#non_k_dim = 32
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -205,7 +205,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 8
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -225,7 +225,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 8
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -245,7 +245,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 8
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -265,7 +265,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 8
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -285,7 +285,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 4
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -305,7 +305,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 2
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -325,7 +325,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 4
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -345,7 +345,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 1
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -365,7 +365,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = i32
#k_width = 4
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -385,7 +385,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = i32
#k_width = 8
#non_k_dim = 16
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -405,7 +405,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 4
#non_k_dim = 4
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -425,7 +425,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 2
#non_k_dim = 4
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -445,7 +445,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 4
#non_k_dim = 4
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -465,7 +465,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = f32
#k_width = 1
#non_k_dim = 4
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -485,7 +485,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
!c_ty = i32
#k_width = 4
#non_k_dim = 4
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {

View File

@@ -5,7 +5,7 @@
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mfma>

View File

@@ -1355,7 +1355,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA=[1,1], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma0, kWidth = 4}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma0, kWidth = 4}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -1379,7 +1379,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTranspose=false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mfma_block
@@ -1529,7 +1529,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = 4}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = 4}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {