[OPTIMIZATION] Fix performance for attention backward path with mma v3 (#2411)

Support having a chain of mma ops with mixed sizes.
Serialize the calculation of the different blocks in backward attention to
work around a problem with ptxas and wgmma.
This commit is contained in:
Thomas Raoux
2023-09-28 10:29:08 -07:00
committed by GitHub
parent 1e093fbfff
commit 721bdebee1
11 changed files with 164 additions and 195 deletions

View File

@@ -141,16 +141,6 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape);
// Implement backward and forward slice that will go through scf blocks when
// yield or scf results are in the slice.
// Note that like the existing forward and backward slice this may add operations to
// the slice that are not actually dependent on the root because when a region
// is added to the slice in the forward slice all the operations of the region
// are added. We could implement a more accurate slice method by tracking value
// usage across scf regions.
void getBackwardSliceSCFAware(Operation *, SetVector<Operation *> *slices);
void getForwardSliceSCFAware(Value root, SetVector<Operation *> *slices);
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

View File

@@ -424,6 +424,23 @@ bool supportMMA(Value value, int version) {
(elemTy.isInteger(8) && version >= 2);
}
// Returns true when a layout conversion between the two MMA encodings can be
// lowered as a register-level shortcut (no cross-thread data movement): both
// encodings must be MMAv3 (versionMajor == 3) with warpsPerCTA = [..., 1] and
// share the same K instruction dimension (instrShape[2]).
static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
  auto src = srcEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>();
  auto dst = dstEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>();
  if (!src || !dst)
    return false;
  // when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
  // (the earlier null check makes re-testing src/dst here unnecessary)
  return src.getVersionMajor() == 3 && src.getWarpsPerCTA()[1] == 1 &&
         dst.getVersionMajor() == 3 && dst.getWarpsPerCTA()[1] == 1 &&
         src.getInstrShape()[2] == dst.getInstrShape()[2];
}
// Convenience overload: applies the mma->mma shortcut check to the encodings
// carried by the two ranked tensor types.
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding());
}
// For MMAV3 dotOperand layout matches mma operand for f16 case.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy) {
@@ -432,7 +449,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
auto mmaLayout = srcLayout.cast<triton::gpu::MmaEncodingAttr>();
auto dotOperandLayout = dstLayout.cast<triton::gpu::DotOperandEncodingAttr>();
return mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getParent() == mmaLayout &&
isMmaToMmaShortcut(dotOperandLayout.getParent(), srcLayout) &&
srcTy.getElementType().isF16();
}
@@ -452,17 +469,6 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
!srcTy.getElementType().isF32();
}
// An mma->mma layout conversion is a no-op shortcut when both layouts are
// MMAv3 with warpsPerCTA = [..., 1] and each thread already holds the same
// total number of elements under both layouts.
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
  auto srcMma = srcTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
  auto dstMma = dstTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
  // when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
  if (srcMma.getVersionMajor() != 3 || srcMma.getWarpsPerCTA()[1] != 1)
    return false;
  if (dstMma.getVersionMajor() != 3 || dstMma.getWarpsPerCTA()[1] != 1)
    return false;
  return triton::gpu::getTotalElemsPerThread(srcTy) ==
         triton::gpu::getTotalElemsPerThread(dstTy);
}
bool isSingleValue(Value value) {
// Don't consider load as expensive if it is loading a scalar.
if (auto tensorTy = value.getType().dyn_cast<RankedTensorType>())

View File

@@ -81,8 +81,7 @@ public:
// forwarding on mma->mma shortcut, lower distributed->distributed otherwise
if (srcLayout.isa<MmaEncodingAttr>() && dstLayout.isa<MmaEncodingAttr>()) {
if (isMmaToMmaShortcut(srcTy, dstTy)) {
rewriter.replaceOp(op, op.getSrc());
return success();
return lowerMmaToMma(op, adaptor, rewriter);
}
}
if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) {
@@ -963,6 +962,43 @@ private:
return failure();
}
// mma -> mma
// Lowers a convert_layout between two MMA layouts that satisfy the
// mma->mma shortcut. When both layouts keep the same total number of elements
// per thread the op is replaced by its source value unchanged; otherwise the
// source's per-thread values are re-packed into the destination's per-thread
// shape, filling positions the source does not cover with undef.
LogicalResult lowerMmaToMma(triton::gpu::ConvertLayoutOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
// Identical register footprint: forward the source SSA value directly.
if (triton::gpu::getTotalElemsPerThread(srcTy) ==
triton::gpu::getTotalElemsPerThread(dstTy)) {
rewriter.replaceOp(op, op.getSrc());
return success();
}
// get source values
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
rewriter, srcTy);
SmallVector<Value> retVals;
SmallVector<unsigned> dstElementPerThread =
triton::gpu::getElemsPerThread(dstTy);
SmallVector<unsigned> srcElementPerThread =
triton::gpu::getElemsPerThread(srcTy);
// Walk the destination per-thread grid (j over dim 0, i over dim 1) and pull
// the matching source element when it exists; otherwise pad with undef of the
// source element type.
// NOTE(review): index = i + j * srcElementPerThread[1] assumes both sides
// linearize per-thread elements row-major as [dim0][dim1] -- confirm against
// unpackLLElements/packLLElements ordering.
for (unsigned j = 0; j < dstElementPerThread[0]; j++) {
for (unsigned i = 0; i < dstElementPerThread[1]; i++) {
if (i >= srcElementPerThread[1] || j >= srcElementPerThread[0]) {
retVals.push_back(undef(vals[0].getType()));
continue;
}
unsigned index = i + j * srcElementPerThread[1];
retVals.push_back(vals[index]);
}
}
assert(retVals.size() == triton::gpu::getTotalElemsPerThread(dstTy));
Value view =
getTypeConverter()->packLLElements(loc, retVals, rewriter, dstTy);
rewriter.replaceOp(op, view);
return success();
}
// shared -> dot_operand if the result layout is mma
Value lowerSharedToDotOperandMMA(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,

View File

@@ -220,7 +220,10 @@ DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter,
llvm::SmallVector<Value> loadReg(ConversionPatternRewriter &rewriter,
Location loc,
const SmallVector<Value> &elements,
int startIndex, int numElements) {
int startIndex, int numElements,
Operation *insertBefore) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(insertBefore);
if (!elements[0].getType().isF16()) {
llvm::SmallVector<Value> mmaOut(numElements);
for (int i = 0; i < numElements; ++i)
@@ -351,9 +354,12 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
auto func = op->getParentOfType<LLVM::LLVMFuncOp>();
int numTMADescs =
func->getAttr(kAttrNumTMALoadDescsName).cast<IntegerAttr>().getInt();
Operation *startSequence = nullptr;
if (numTMADescs == 0)
rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);
rewriter.create<triton::nvgpu::WGMMAFenceOp>(loc);
startSequence = rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);
Operation *fenceOp = rewriter.create<triton::nvgpu::WGMMAFenceOp>(loc);
if (startSequence == nullptr)
startSequence = fenceOp;
// WGMMA fp8 -> fp32 accumulates in lower precision than fp32.
bool needsPartialAccumulator = isFP8(eltTypeA) &&
eltTypeC == triton::nvgpu::WGMMAEltType::f32 &&
@@ -362,7 +368,8 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
for (int m = 0; m < numRepM; ++m) {
for (int n = 0; n < numRepN; ++n) {
llvm::SmallVector<Value> mmaOut =
loadReg(rewriter, loc, fc, (m * numRepN + n) * accSize, accSize);
loadReg(rewriter, loc, fc, (m * numRepN + n) * accSize, accSize,
startSequence);
llvm::SmallVector<Type> elemTypes;
for (Value accEl : mmaOut)
elemTypes.push_back(accEl.getType());
@@ -379,8 +386,9 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
a = aLoader.smemLoad(m, k, rewriter, loc);
} else {
unsigned regASize = (instrShape[0] * instrShape[2]) / 32;
llvm::SmallVector<Value> regA = loadReg(
rewriter, loc, structA, (m * numRepK + k) * regASize, regASize);
llvm::SmallVector<Value> regA =
loadReg(rewriter, loc, structA, (m * numRepK + k) * regASize,
regASize, startSequence);
auto regATy = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(),
SmallVector<Type>(regA.size(), regA[0].getType()));

View File

@@ -1480,7 +1480,7 @@ struct TritonGPUInferLayoutInterface
auto dotOpEnc = operandEncoding.dyn_cast<DotOperandEncodingAttr>();
if (!operandEncoding.isa<SharedEncodingAttr>() &&
!(opIdx == 0 && dotOpEnc && dotOpEnc.getOpIdx() == 0 &&
dotOpEnc.getParent() == mmaRetEncoding)) {
dotOpEnc.getParent().isa<MmaEncodingAttr>())) {
return emitOptionalError(
location, "unexpected operand layout for MmaEncodingAttr v3");
}

View File

@@ -145,35 +145,6 @@ public:
}
}
// Picks the MMAv3 instruction N dimension (instrShape[1]) for `dotOp` so that
// every dot op connected to it through SCF-aware forward/backward slices uses
// the same N: the minimum over the candidate instruction shapes of all dots in
// the slice (seeded with `currN`). Results are memoized in `dotOpInstNs` for
// every dot found, so the slice walk runs once per connected group.
unsigned getMmaV3InstrN(tt::DotOp dotOp, unsigned currN) const {
auto type = dotOp.getResult().getType().cast<RankedTensorType>();
// Already carries an MMA encoding: nothing to reconcile, keep currN.
if (type.getEncoding().isa<MmaEncodingAttr>())
return currN;
// Memoized from a previous walk of this dot's connected group.
auto it = dotOpInstNs.find(dotOp.getOperation());
if (it != dotOpInstNs.end())
return it->second;
SetVector<Operation *> slices;
mlir::getForwardSliceSCFAware(dotOp.getResult(), &slices);
mlir::getBackwardSliceSCFAware(dotOp.getOperation(), &slices);
unsigned N = currN;
SmallVector<Operation *> dotOps;
for (Operation *iter : slices) {
if (auto nextDotOp = dyn_cast<tt::DotOp>(iter)) {
auto type = nextDotOp.getResult().getType().cast<RankedTensorType>();
auto AType = nextDotOp.getOperand(0).getType().cast<RankedTensorType>();
auto shapePerCTA = ttg::getShapePerCTA(type);
auto instrShape = mmaVersionToInstrShape(3, shapePerCTA, AType);
dotOps.push_back(iter);
// Keep the smallest N so one instruction shape fits every dot.
if (instrShape[1] < N)
N = instrShape[1];
}
}
// Record the shared N for all dots in the group, including dotOp itself.
for (Operation *dotOp : dotOps)
dotOpInstNs[dotOp] = N;
return N;
}
static Value getMMAv3Operand(Value v, mlir::PatternRewriter &rewriter,
int opIdx) {
Value arg = v;
@@ -232,9 +203,6 @@ public:
auto instrShape =
mmaVersionToInstrShape(versionMajor, retShapePerCTA, AType);
if (versionMajor == 3)
instrShape[1] = getMmaV3InstrN(dotOp, instrShape[1]);
// operands
Value a = dotOp.getA();
Value b = dotOp.getB();

View File

@@ -247,8 +247,10 @@ struct MMAV3UseRegOperand : public OpRewritePattern<triton::DotOp> {
return failure();
auto srcEncoding =
getEncoding(convertLhs.getSrc()).dyn_cast<MmaEncodingAttr>();
if (!srcEncoding || srcEncoding.getVersionMajor() != 3 ||
srcEncoding != getEncoding(dotOp.getResult()))
auto dstEncoding =
getEncoding(dotOp.getResult()).dyn_cast<MmaEncodingAttr>();
if (!srcEncoding || srcEncoding.getVersionMajor() != 3 || !dstEncoding ||
dstEncoding.getVersionMajor() != 3)
return failure();
// We currently only support convert from f16 mma to f16 dot operand as the
// other types require shuffling data across threads.

View File

@@ -217,7 +217,8 @@ static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
if (convertOp.getResult()
.getType()
.cast<RankedTensorType>()
.getEncoding() == encoding)
.getEncoding()
.isa<triton::gpu::MmaEncodingAttr>())
return true;
}
auto yield = dyn_cast<scf::YieldOp>(op);

View File

@@ -492,46 +492,6 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
return linear;
}
// Backward slice that also traverses scf loop boundaries: whenever the slice
// picks up an scf.for, the loop body's terminator is queued so operations
// feeding the loop's yielded values are included as well. Operations already
// in `slices` are filtered out to avoid re-walking them.
void getBackwardSliceSCFAware(Operation *op, SetVector<Operation *> *slices) {
  SmallVector<Operation *> worklist{op};
  while (!worklist.empty()) {
    Operation *current = worklist.pop_back_val();
    SetVector<Operation *> partialSlice;
    auto notVisited = [slices](Operation *candidate) {
      return slices->count(candidate) == 0;
    };
    mlir::getBackwardSlice(current, &partialSlice, notVisited);
    for (Operation *sliceOp : partialSlice)
      if (auto forOp = dyn_cast<scf::ForOp>(sliceOp))
        worklist.push_back(forOp.getBody()->getTerminator());
    slices->insert(partialSlice.begin(), partialSlice.end());
  }
}
// Forward slice that also traverses scf loop boundaries: whenever the slice
// reaches an scf.yield inside an scf.for, the loop's results are queued so
// users of the values produced by the loop are included as well. Operations
// already in `slices` are filtered out to avoid re-walking them.
void getForwardSliceSCFAware(Value root, SetVector<Operation *> *slices) {
  SmallVector<Value> worklist{root};
  while (!worklist.empty()) {
    Value current = worklist.pop_back_val();
    SetVector<Operation *> partialSlice;
    auto notVisited = [slices](Operation *candidate) {
      return slices->count(candidate) == 0;
    };
    mlir::getForwardSlice(current, &partialSlice, notVisited);
    for (Operation *sliceOp : partialSlice) {
      auto yieldOp = dyn_cast<scf::YieldOp>(sliceOp);
      if (!yieldOp)
        continue;
      if (auto forOp = yieldOp->getParentOfType<scf::ForOp>())
        worklist.append(forOp->getResults().begin(),
                        forOp->getResults().end());
    }
    slices->insert(partialSlice.begin(), partialSlice.end());
  }
}
namespace {
/// Detect dead arguments in scf.for op by assuming all the values are dead and

View File

@@ -323,106 +323,102 @@ def _attn_bwd(
# load scales
offs_k = tl.arange(0, BLOCK_DMODEL)
if (tl.program_id(1) == 0):
# THIS BLOCK DOES DK/DV/DR:
# THIS BLOCK DOES DK/DV/DR:
start_n = pid * BLOCK_N1
start_m = start_n
start_n = pid * BLOCK_N1
start_m = start_n
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
offs_n = start_n + tl.arange(0, BLOCK_N1)
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
offs_n = start_n + tl.arange(0, BLOCK_N1)
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
# load K and V: they stay in SRAM throughout the inner loop.
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
# load K and V: they stay in SRAM throughout the inner loop.
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
num_steps = BLOCK_N1 // MASK_BLOCK_M1
num_steps = BLOCK_N1 // MASK_BLOCK_M1
dk, dv = _attn_bwd_dkdv(dk, dv,
Q, k, v, sm_scale,
DO,
M, D,
stride_tok, stride_d,
H, N_CTX,
MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
start_n, start_m, num_steps,
MASK=True,
)
dk, dv = _attn_bwd_dkdv(dk, dv,
Q, k, v, sm_scale,
DO,
M, D,
stride_tok, stride_d,
H, N_CTX,
MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
start_n, start_m, num_steps,
MASK=True,
)
start_m += num_steps * MASK_BLOCK_M1
num_steps = (N_CTX - start_m) // BLOCK_M1
start_m += num_steps * MASK_BLOCK_M1
num_steps = (N_CTX - start_m) // BLOCK_M1
# Compute dK and dV for non-masked blocks.
dk, dv = _attn_bwd_dkdv(dk, dv,
Q, k, v, sm_scale,
DO,
M, D,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
start_n, start_m, num_steps,
MASK=False,
)
# Compute dK and dV for non-masked blocks.
dk, dv = _attn_bwd_dkdv(dk, dv,
Q, k, v, sm_scale,
DO,
M, D,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
start_n, start_m, num_steps,
MASK=False,
)
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
tl.store(dv_ptrs, dv)
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
tl.store(dv_ptrs, dv)
# Write back dK.
dk *= sm_scale
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
tl.store(dk_ptrs, dk)
# Write back dK.
dk *= sm_scale
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
tl.store(dk_ptrs, dk)
# THIS BLOCK DOES DQ:
start_m = pid * BLOCK_M2
end_n = start_m + BLOCK_M2
else:
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
offs_m = start_m + tl.arange(0, BLOCK_M2)
# THIS BLOCK DOES DQ:
start_m = pid * BLOCK_M2
end_n = start_m + BLOCK_M2
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
offs_m = start_m + tl.arange(0, BLOCK_M2)
m = tl.load(M + offs_m)
m = m[:, None]
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
m = tl.load(M + offs_m)
m = m[:, None]
# Compute dQ for masked (diagonal) blocks.
# NOTE: This code scans each row of QK^T backward (from right to left,
# but inside each call to _attn_bwd_dq, from left to right), but that's
# not due to anything important. I just wanted to reuse the loop
# structure for dK & dV above as much as possible.
num_steps = BLOCK_M2 // MASK_BLOCK_N2
dq = _attn_bwd_dq(
dq, q, K, V,
do, m, D,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
MASK=True,
)
end_n -= num_steps * MASK_BLOCK_N2
# stage 2
num_steps = end_n // BLOCK_N2
dq = _attn_bwd_dq(
dq, q, K, V,
do, m, D,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
start_m, end_n - num_steps * BLOCK_N2, num_steps,
MASK=False,
)
# Write back dQ.
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
dq *= LN2
tl.store(dq_ptrs, dq)
# Compute dQ for masked (diagonal) blocks.
# NOTE: This code scans each row of QK^T backward (from right to left,
# but inside each call to _attn_bwd_dq, from left to right), but that's
# not due to anything important. I just wanted to reuse the loop
# structure for dK & dV above as much as possible.
num_steps = BLOCK_M2 // MASK_BLOCK_N2
dq = _attn_bwd_dq(
dq, q, K, V,
do, m, D,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
MASK=True,
)
end_n -= num_steps * MASK_BLOCK_N2
# stage 2
num_steps = end_n // BLOCK_N2
dq = _attn_bwd_dq(
dq, q, K, V,
do, m, D,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
start_m, end_n - num_steps * BLOCK_N2, num_steps,
MASK=False,
)
# Write back dQ.
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
dq *= LN2
tl.store(dq_ptrs, dq)
empty = torch.empty(128, device="cuda")
@@ -491,7 +487,7 @@ class _attention(torch.autograd.Function):
BATCH, N_HEAD, N_CTX,
BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
grid = (N_CTX // BLOCK_N1, 2, BATCH * N_HEAD)
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
_attn_bwd[grid](
q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,
M, delta,

View File

@@ -1,6 +1,8 @@
// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s
// CHECK: #[[MMA:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16, 16]}>
// CHECK: #[[MMA1:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
// CHECK: #[[MMA2:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 32, 16]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
@@ -22,7 +24,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #blocked2>
// CHECK: scf.for
// CHECK: tt.dot {{.*}} -> tensor<128x16xf16, #[[MMA]]>
// CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA]]>
// CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]>
%115 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_0) -> (tensor<128x64xf16, #blocked1>) : i32 {
%172 = tt.dot %170, %171, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf16, #blocked>
%178 = triton_gpu.convert_layout %172 : (tensor<128x16xf16, #blocked>) -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
@@ -30,8 +32,8 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
scf.yield %180 : tensor<128x64xf16, #blocked1>
}
// CHECK: scf.for
// CHECK: tt.dot {{.*}} -> tensor<128x32xf16, #[[MMA]]>
// CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA]]>
// CHECK: tt.dot {{.*}} -> tensor<128x32xf16, #[[MMA2]]>
// CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]>
%149 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %115) -> (tensor<128x64xf16, #blocked1>) : i32 {
%166 = tt.dot %164, %165, %cst_2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #blocked2>
%172 = triton_gpu.convert_layout %166 : (tensor<128x32xf16, #blocked2>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>