Mirror of https://github.com/ROCm/ROCm.git
[MFMA] Remove CTA related code from layout (#429)
This PR removes the CTALayout attribute from the MFMA layout, since CTA clusters are NVIDIA-specific and do not apply to AMD's MFMA encoding.
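In the textual IR this shrinks the MFMA encoding to its three remaining fields. Below is a minimal sketch of the new printed form, modeled on the MfmaLayout test helper updated in this diff (the GPU_DIALECT value is an assumption matching the triton_gpu prefix used throughout these tests):

GPU_DIALECT = "triton_gpu"  # assumption: dialect prefix used in the tests below


class MfmaLayout:
    # Mirrors the updated test helper in this diff: the CTA fields are gone.
    def __init__(self, non_k_dim, warps_per_cta, is_transposed):
        self.non_k_dim = str(non_k_dim)
        self.warps_per_cta = str(warps_per_cta)
        self.is_transposed = str(is_transposed).lower()

    def __str__(self):
        return (f"#{GPU_DIALECT}.mfma<{{nonKDim = {self.non_k_dim}, "
                f"warpsPerCTA = {self.warps_per_cta}, "
                f"isTransposed = {self.is_transposed}}}>")


# Before this PR the same attribute also carried CTAsPerCGA, CTASplitNum
# and CTAOrder; now it prints as:
print(MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False))
# #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>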
@@ -764,8 +764,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
     ins
     "unsigned":$nonKDim,
     ArrayRefParameter<"unsigned">:$warpsPerCTA,
-    "bool":$isTransposed,
-    "CTALayoutAttr":$CTALayout
+    "bool":$isTransposed
   );

   let hasCustomAssemblyFormat = 1;
@@ -772,7 +772,7 @@ private:
         srcType.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
     auto newMfmaEnc = triton::gpu::MfmaEncodingAttr::get(
         mod.getContext(), srcMfma.getNonKDim(), {warpsPerCtaX, warpsPerCtaY},
-        srcMfma.getIsTransposed(), srcMfma.getCTALayout());
+        srcMfma.getIsTransposed());

     auto newDstType = RankedTensorType::get(
         dstType.getShape(), dstType.getElementType(), dstType.getEncoding());
@@ -543,7 +543,7 @@ SmallVector<unsigned> getCTAsPerCGA(Attribute layout) {
     ref = mmaLayout.getCTALayout().getCTAsPerCGA();
 #ifdef USE_ROCM
   else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>())
-    ref = mfmaLayout.getCTALayout().getCTAsPerCGA();
+    return {1, 1};
 #endif
   else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>())
     return getCTAsPerCGA(dotLayout.getParent());
@@ -568,8 +568,8 @@ SmallVector<unsigned> getCTASplitNum(Attribute layout) {
                mmaLayout.getCTALayout().getCTASplitNum().end());
 #ifdef USE_ROCM
   } else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
-    res.assign(mfmaLayout.getCTALayout().getCTASplitNum().begin(),
-               mfmaLayout.getCTALayout().getCTASplitNum().end());
+    res.resize(2);
+    res[0] = res[1] = 1;
 #endif
   } else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
     res = getCTASplitNum(dotLayout.getParent());
@@ -598,7 +598,7 @@ SmallVector<unsigned> getCTAOrder(Attribute layout) {
     ref = mmaLayout.getCTALayout().getCTAOrder();
 #ifdef USE_ROCM
   } else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
-    ref = mfmaLayout.getCTALayout().getCTAOrder();
+    return {0, 1};
 #endif
   } else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
     return getCTAOrder(dotLayout.getParent());
@@ -673,8 +673,9 @@ unsigned getNumCTAs(Attribute layout) {
   else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>())
     CTAsPerCGA = mmaLayout.getCTALayout().getCTAsPerCGA();
 #ifdef USE_ROCM
-  else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>())
-    CTAsPerCGA = mfmaLayout.getCTALayout().getCTAsPerCGA();
+  else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
+    return 1;
+  }
 #endif
   else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>())
     return getNumCTAs(dotLayout.getParent());
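The four hunks above stop reading a stored CTALayout for MFMA and instead return fixed, trivial values. A small Python sketch of the post-change behavior (the dictionary and its name are illustrative, not from the codebase):

# CTA queries on an MFMA layout now have hard-coded answers:
MFMA_CTA_DEFAULTS = {
    "CTAsPerCGA": [1, 1],   # getCTAsPerCGA returns {1, 1}
    "CTASplitNum": [1, 1],  # getCTASplitNum sets res[0] = res[1] = 1
    "CTAOrder": [0, 1],     # getCTAOrder returns {0, 1}
    "numCTAs": 1,           # getNumCTAs returns 1
}
print(MFMA_CTA_DEFAULTS)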
@@ -1316,9 +1317,6 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
   unsigned nonKDim = 0;
   SmallVector<unsigned> warpsPerCTA;
   bool isTransposed;
-  SmallVector<unsigned> CTAsPerCGA;
-  SmallVector<unsigned> CTASplitNum;
-  SmallVector<unsigned> CTAOrder;

   for (const NamedAttribute &attr : dict) {
     if (attr.getName() == "nonKDim") {
@@ -1332,35 +1330,17 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
       if (parseBool(parser, attr, isTransposed, "isTransposed").failed())
         return {};
     }
-    if (attr.getName() == "CTAsPerCGA") {
-      if (parseIntArrayAttr(parser, attr, CTAsPerCGA, "CTAsPerCGA").failed())
-        return {};
-    }
-    if (attr.getName() == "CTASplitNum") {
-      if (parseIntArrayAttr(parser, attr, CTASplitNum, "CTASplitNum").failed())
-        return {};
-    }
-    if (attr.getName() == "CTAOrder") {
-      if (parseIntArrayAttr(parser, attr, CTAOrder, "CTAOrder").failed())
-        return {};
-    }
   }

-  auto CTALayout = CTALayoutAttr::get(parser.getContext(), CTAsPerCGA,
-                                      CTASplitNum, CTAOrder);
-
-  return parser.getChecked<MfmaEncodingAttr>(
-      parser.getContext(), nonKDim, warpsPerCTA, isTransposed, CTALayout);
+  return parser.getChecked<MfmaEncodingAttr>(parser.getContext(), nonKDim,
+                                             warpsPerCTA, isTransposed);
 }

 void MfmaEncodingAttr::print(AsmPrinter &printer) const {
   printer << "<{"
           << "nonKDim = " << getNonKDim() << ", "
           << "warpsPerCTA = [" << getWarpsPerCTA() << "], "
-          << "isTransposed = " << getIsTransposed() << ", "
-          << "CTAsPerCGA = [" << getCTALayout().getCTAsPerCGA() << "], "
-          << "CTASplitNum = [" << getCTALayout().getCTASplitNum() << "], "
-          << "CTAOrder = [" << getCTALayout().getCTAOrder() << "]}>";
+          << "isTransposed = " << getIsTransposed() << "}>";
 }

 //===----------------------------------------------------------------------===//
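A hedged round-trip check of what the slimmed-down parse/print pair handles: only nonKDim, warpsPerCTA and isTransposed survive. The regex below is purely illustrative, not the real MLIR parser:

import re

# Accepts the post-PR textual form of the MFMA encoding and nothing more.
MFMA_RE = re.compile(r"#triton_gpu\.mfma<\{nonKDim = (?P<nonKDim>\d+), "
                     r"warpsPerCTA = \[(?P<warpsPerCTA>[0-9, ]+)\], "
                     r"isTransposed = (?P<isTransposed>true|false)\}>")

attr = "#triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>"
print(MFMA_RE.match(attr).groupdict())
# {'nonKDim': '32', 'warpsPerCTA': '4, 1', 'isTransposed': 'false'}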
@@ -243,6 +243,9 @@ public:
       return failure();

     auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding());
+    assert(CTALayout.getCTAsPerCGA().size() == 2);
+    assert(CTALayout.getCTAsPerCGA()[0] == 1);
+    assert(CTALayout.getCTAsPerCGA()[1] == 1);

     // get MFMA encoding for the given number of warps
     auto retShape = oldRetType.getShape();
@@ -264,7 +267,7 @@ public:

     bool isTransposed = isChainDot(dotOp);
     mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim,
-                                         warpsPerTile, isTransposed, CTALayout);
+                                         warpsPerTile, isTransposed);

     auto newRetType =
         RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc);
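The rewrite above still reads the CTA layout of the incoming type, but only to assert it is trivial before building the CTA-free MFMA encoding. A one-to-one Python rendering of that guard (the helper name is hypothetical):

def assert_trivial_cta_layout(ctas_per_cga):
    # MFMA assumes a single CTA per CGA in both tensor dimensions,
    # which is what the three new asserts in the pass check.
    assert len(ctas_per_cga) == 2
    assert ctas_per_cga[0] == 1
    assert ctas_per_cga[1] == 1

assert_trivial_cta_layout([1, 1])  # the only configuration MFMA supports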
@@ -2700,16 +2700,13 @@ class MmaLayout:


 class MfmaLayout:

-    def __init__(self, non_k_dim, warps_per_cta, is_transposed, ctas_per_cga, cta_split_num, cta_order):
+    def __init__(self, non_k_dim, warps_per_cta, is_transposed):
         self.non_k_dim = str(non_k_dim)
         self.warps_per_cta = str(warps_per_cta)
         self.is_transposed = str(is_transposed).lower()
-        self.ctas_per_cga = str(ctas_per_cga)
-        self.cta_split_num = str(cta_split_num)
-        self.cta_order = str(cta_order)

     def __str__(self):
-        return f"#{GPU_DIALECT}.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.is_transposed}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
+        return f"#{GPU_DIALECT}.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.is_transposed}}}>"


 class BlockedLayout:
@@ -2776,7 +2773,7 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB):
     blocked = BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[8, 8], warps_per_cta=[4, 1], order=[1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
     shared_a = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeA else [1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
     shared_b = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeB else [1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
-    mfma = MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
+    mfma = MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False)

     ir = f"""
     #blocked = {blocked}
@@ -2938,8 +2935,8 @@ module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32,

     if torch.version.hip is not None and _get_warp_size() == 64:
         layouts = [
-            MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]),
-            MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]),
+            MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True),
+            MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], is_transposed=False),
         ]
         shapes = [[128, 32], [128, 128], [32, 128], [64, 64]]
     else:
@@ -3097,8 +3094,8 @@ def test_scan_layouts(M, N, src_layout, axis, device):

 @pytest.mark.parametrize("shape", [(64, 64)])
 @pytest.mark.parametrize("dtype", ['float16'])
-@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]),
-                                        MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0])])
+@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], is_transposed=False),
+                                        MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True)])
 @pytest.mark.parametrize("dst_layout", [BlockedLayout([1, 4], [4, 16], [1, 1], [1, 0], [1, 1], [1, 1], [0, 1])])
 def test_make_range(dtype, shape, src_layout, dst_layout, device='cuda'):
     ir = f"""
@@ -5,7 +5,7 @@
 !c_ty = f32
 #k_width = 8
 #non_k_dim = 32
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -25,7 +25,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 8
 #non_k_dim = 32
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -45,7 +45,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 8
 #non_k_dim = 32
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -65,7 +65,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 8
 #non_k_dim = 32
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -85,7 +85,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 4
 #non_k_dim = 32
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -105,7 +105,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 2
 #non_k_dim = 32
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -125,7 +125,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 4
 #non_k_dim = 32
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -145,7 +145,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 1
 #non_k_dim = 32
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -165,7 +165,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = i32
 #k_width = 4
 #non_k_dim = 32
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -185,7 +185,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = i32
 #k_width = 8
 #non_k_dim = 32
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -205,7 +205,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 8
 #non_k_dim = 16
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -225,7 +225,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 8
 #non_k_dim = 16
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -245,7 +245,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 8
 #non_k_dim = 16
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -265,7 +265,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 8
 #non_k_dim = 16
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -285,7 +285,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 4
 #non_k_dim = 16
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -305,7 +305,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 2
 #non_k_dim = 16
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -325,7 +325,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 4
 #non_k_dim = 16
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -345,7 +345,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 1
 #non_k_dim = 16
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -365,7 +365,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = i32
 #k_width = 4
 #non_k_dim = 16
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -385,7 +385,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = i32
 #k_width = 8
 #non_k_dim = 16
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -405,7 +405,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 4
 #non_k_dim = 4
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -425,7 +425,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 2
 #non_k_dim = 4
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -445,7 +445,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 4
 #non_k_dim = 4
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -465,7 +465,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = f32
 #k_width = 1
 #non_k_dim = 4
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -485,7 +485,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 !c_ty = i32
 #k_width = 4
 #non_k_dim = 4
-#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = #non_k_dim, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -5,7 +5,7 @@
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
   tt.func public @matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
     %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mfma>
@@ -1355,7 +1355,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :

 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #shared0 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA=[1,1], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma0 = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA=[1,1], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma0, kWidth = 4}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma0, kWidth = 4}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -1379,7 +1379,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 // -----

 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTranspose=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTranspose=false}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
 // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
 // CHECK-LABEL: convert_layout_mfma_block
@@ -1529,7 +1529,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :

 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false, CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [2, 2], isTransposed=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = 4}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = 4}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {