Dot slicing pass (#440)

* First commit

* Implement DotSlicing pass.

* small fixes

* Support chained dot in DotSlicingPass (second GEMM in FA)

* Add lit test for FA dot slicing

---------

Co-authored-by: Ognjen Plavsic <ognjen.plavsic@luxoft.com>
Co-authored-by: Ognjen <oplavsic@luxoft.com>
This commit is contained in:
oplavsic
2024-01-16 21:25:10 +01:00
committed by GitHub
parent a819e48435
commit 760ac8441a
10 changed files with 725 additions and 11 deletions

View File

@@ -12,6 +12,7 @@ std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 3,
int computeCapability = 80);
std::unique_ptr<Pass> createTritonGPUStreamPipelinePass();
std::unique_ptr<Pass> createTritonAMDGPUDotSlicingPass(int sliceKTile = 0);
std::unique_ptr<Pass>
createTritonGPUAccelerateMatmulPass(int computeCapability = 80);

View File

@@ -49,6 +49,26 @@ def TritonGPUStreamPipeline : Pass<"tritongpu-stream-pipeline", "mlir::ModuleOp"
"mlir::arith::ArithDialect"];
}
def TritonAMDGPUDotSlicing: Pass<"tritonamdgpu-dot-slicing", "mlir::ModuleOp"> {
let summary = "'DotOp' instruction slicing";
let description = [{
Slice 'DotOp' instruction into multiple smaller 'DotOp' instructions
in order to improve scheduling and latency hiding.
}];
let constructor = "mlir::createTritonAMDGPUDotSlicingPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::arith::ArithDialect"];
let options = [
Option<"sliceKTile", "slice-k-tile",
"int32_t", /*default*/"0",
"slice size in k dimension">
];
}
def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
let summary = "prefetch";

View File

@@ -13,6 +13,7 @@ add_mlir_dialect_library(TritonGPUTransforms
RemoveLayoutConversions.cpp
ReorderInstructions.cpp
StreamPipeline.cpp
DotSlicing.cpp
TritonGPUConversion.cpp
Utility.cpp

View File

@@ -0,0 +1,464 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
//===----------------------------------------------------------------------===//
// This file implements the dot slicing pass. It slices along the k dimension
// in a dot operation (for dot A, B with shape(A) = (m,k), shape(B) = (k,n)).
// The slice size along the k dimension is set by 'sliceKTile'.
// The algorithm includes three main steps:
//
// 1) Modify load instruction layouts for dot operands.
// - This ensures tensor slices have the same layout as the original tensor,
// allowing slicing without data exchange between threads.
//
// 2) Slice the dot operands.
// - Here, the original dot operands are sliced, starting from the load
// instruction. This helps in hiding latency in global memory transactions.
// Currently, the algorithm assumes the original tensor undergoes only
// elementwise operations before being used as a dot operand. In other
// words, the algorithm expects only elementwise operations and convert
// layout operations to occur between the load and dot instructions.
// However, it can be extended to handle other operations if necessary.
//
// 3) Slice the dot operation along the k dimension.
// - This involves calculating the number of slices from 'sliceKTile', and
// then creating a new dot operation for each slice using the operands
// sliced in the previous step. Sliced dots are concatenated so the result
// of each dot is used in the next, leveraging the slicing along the k dim.
//===----------------------------------------------------------------------===//
using namespace mlir;
namespace tt = triton;
namespace ttg = triton::gpu;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
namespace {
// Returns true if `op` is one of the elementwise operations the slicing
// algorithm knows how to clone with a sliced result type: arith/math
// elementwise ops (including compares and select), Triton pointer/cast ops,
// or a pure extern elementwise op.
bool isElementwiseOp(Operation *op) {
  // Arith dialect elementwise ops, including cmp/select.
  if (llvm::isa<arith::AddFOp, arith::AddIOp, arith::AndIOp, arith::CeilDivSIOp,
                arith::CeilDivUIOp, arith::DivFOp, arith::DivSIOp,
                arith::DivUIOp, arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp,
                arith::FloorDivSIOp, arith::FPToSIOp, arith::FPToUIOp,
                arith::MaximumFOp, arith::MaxSIOp, arith::MaxUIOp,
                arith::MinimumFOp, arith::MinSIOp, arith::MinUIOp,
                arith::MulFOp, arith::MulIOp, arith::NegFOp, arith::OrIOp,
                arith::RemFOp, arith::RemSIOp, arith::RemUIOp, arith::ShLIOp,
                arith::ShRSIOp, arith::ShRUIOp, arith::SIToFPOp, arith::SubFOp,
                arith::SubIOp, arith::TruncFOp, arith::TruncIOp,
                arith::UIToFPOp, arith::XOrIOp, arith::CmpIOp, arith::CmpFOp,
                arith::SelectOp>(op))
    return true;
  // Math dialect elementwise ops.
  if (llvm::isa<math::AbsFOp, math::AbsIOp, math::AtanOp, math::Atan2Op,
                math::CeilOp, math::CopySignOp, math::CosOp, math::SinOp,
                math::CountLeadingZerosOp, math::CountTrailingZerosOp,
                math::CtPopOp, math::ErfOp, math::ExpOp, math::Exp2Op,
                math::ExpM1Op, math::FloorOp, math::FmaOp, math::LogOp,
                math::Log10Op, math::Log1pOp, math::Log2Op, math::PowFOp,
                math::RsqrtOp, math::SqrtOp, math::TanhOp>(op))
    return true;
  // Triton pointer and cast ops behave elementwise on tensors.
  if (llvm::isa<tt::IntToPtrOp, tt::PtrToIntOp, tt::BitcastOp, tt::FpToFpOp,
                tt::AddPtrOp>(op))
    return true;
  // Extern elementwise calls qualify only when marked pure.
  if (auto externOp = dyn_cast<tt::ExternElementwiseOp>(op))
    return externOp.getPure();
  return false;
}
} // anonymous namespace
struct TritonAMDGPUDotSlicingPass
    : public TritonAMDGPUDotSlicingBase<TritonAMDGPUDotSlicingPass> {
  TritonAMDGPUDotSlicingPass() = default;
  explicit TritonAMDGPUDotSlicingPass(int sliceKTile) {
    this->sliceKTile = sliceKTile;
  }

  // Find user of the currOp that affects dotOperand calculation, i.e. the
  // user whose forward slice contains `dotOperand` (or that is `dotOperand`
  // itself). We assume here that there is only one such user.
  Operation *getUserThatAffectsDotOperand(Operation *currOp,
                                          Operation *dotOperand) {
    SetVector<Operation *> forwardSlices;
    SmallVector<Operation *> usersThatAffectDot;
    for (auto *user : currOp->getUsers()) {
      forwardSlices.clear();
      getForwardSlice(user, &forwardSlices);
      if (user == dotOperand) {
        usersThatAffectDot.push_back(user);
        continue;
      }
      // SetVector gives O(1) membership lookup; no linear std::find needed.
      if (forwardSlices.count(dotOperand)) {
        usersThatAffectDot.push_back(user);
      }
    }
    assert(usersThatAffectDot.size() == 1);
    return usersThatAffectDot[0];
  }

  // Build the `loopIter`-th k-dimension slice (of size `sliceSizeK`) of dot
  // operand `operandIdx`. Slicing starts at `firstOpToSlice` (the defining
  // load, or the operand producer for chained dots) and clones every op on
  // the path to the dot operand with a sliced result type. The original
  // (to-become-dead) ops are recorded in `eraseOps` on the first iteration
  // only. Returns the sliced value to use as the dot operand.
  Value getSlicedDotOperand(Operation *firstOpToSlice, tt::DotOp dotOp,
                            int operandIdx, int loopIter, int sliceSizeK,
                            OpBuilder &builder,
                            SmallVector<Operation *> &eraseOps) {
    auto ptrTensor = firstOpToSlice->getOperand(0);
    auto ptrTensorType = ptrTensor.getType().cast<RankedTensorType>();
    auto dotOperand = dotOp.getOperand(operandIdx);
    // This is the dot *result* type, shape (m, n): dim 0 supplies A's m
    // extent, dim 1 supplies B's n extent below.
    auto dotResultTy = dotOp.getType().cast<RankedTensorType>();

    SmallVector<int64_t> sliceSizes;
    SmallVector<int64_t> sliceOffsets;
    SmallVector<int64_t> sliceStrides{1, 1};
    if (operandIdx == 0) {
      // A operand (m, k): keep full m, slice along k (dim 1).
      sliceSizes.push_back(dotResultTy.getShape()[0]);
      sliceSizes.push_back(sliceSizeK);
      sliceOffsets.push_back(0);
      sliceOffsets.push_back(loopIter * sliceSizeK);
    } else {
      assert(operandIdx == 1);
      // B operand (k, n): slice along k (dim 0), keep full n.
      sliceSizes.push_back(sliceSizeK);
      sliceSizes.push_back(dotResultTy.getShape()[1]);
      sliceOffsets.push_back(loopIter * sliceSizeK);
      sliceOffsets.push_back(0);
    }
    auto viewPtr = builder.create<ttg::ViewSliceOp>(
        dotOp.getLoc(),
        RankedTensorType::get(sliceSizes, ptrTensorType.getElementType(),
                              ptrTensorType.getEncoding()),
        ptrTensor, ValueRange({}), ValueRange({}), ValueRange({}),
        sliceOffsets, sliceSizes, sliceStrides);

    // Begin with the load instruction and proceed to slice the operations
    // along the execution path of the dotOperand.
    IRMapping mapping;
    mapping.map(ptrTensor, viewPtr);
    Operation *currOp = firstOpToSlice;
    Operation *slicedOp = nullptr;
    while (true) {
      // Original unsliced ops become dead once all slices exist; collect
      // them only once, on the first slice iteration.
      if (loopIter == 0) {
        eraseOps.push_back(currOp);
      }
      slicedOp = builder.clone(*currOp, mapping);
      // The 'load', 'convert_layout', and 'elementwise' operations each have
      // one result. This limitation can be removed if necessary.
      assert(currOp->getNumResults() == 1);
      // Convert the operation's results to sliced types (sliced shape, same
      // element type and encoding as the original result).
      for (auto [currRes, slicedRes] :
           llvm::zip(currOp->getResults(), slicedOp->getResults())) {
        auto slicedType = RankedTensorType::get(
            viewPtr.getType().cast<RankedTensorType>().getShape(),
            currRes.getType().cast<RankedTensorType>().getElementType(),
            currRes.getType().cast<RankedTensorType>().getEncoding());
        slicedRes.setType(slicedType);
      }
      mapping.map(currOp, slicedOp);
      if (currOp == dotOperand.getDefiningOp()) {
        break;
      }
      assert(llvm::isa<tt::LoadOp>(currOp) ||
             llvm::isa<ttg::ConvertLayoutOp>(currOp) ||
             isElementwiseOp(currOp));
      // If currOp has more than one user, proceed with the one that is "on a
      // path" of dot operand calculation. We expect there is only one such
      // user.
      auto currOpUser =
          getUserThatAffectsDotOperand(currOp, dotOperand.getDefiningOp());
      // The currOpUser operation can have multiple operands, such as in any
      // binary elementwise op. In such cases, we slice all of the operands
      // using the same sliceOffsets, sliceSizes, and sliceStrides. This
      // approach is valid only under the assumption that currOpUser is an
      // elementwise operation. For non-elementwise operations with multiple
      // operands, slicing should potentially be handled differently.
      for (auto operandVal : currOpUser->getOperands()) {
        auto nonSlicedOperand = operandVal.getDefiningOp();
        if (nonSlicedOperand == currOp) {
          continue;
        }
        auto nonSlicedOperandTy = nonSlicedOperand->getResults()[0]
                                      .getType()
                                      .cast<RankedTensorType>();
        auto slicedTy = RankedTensorType::get(
            sliceSizes, nonSlicedOperandTy.getElementType(),
            nonSlicedOperandTy.getEncoding());
        auto slicedOperand = builder.create<ttg::ViewSliceOp>(
            nonSlicedOperand->getLoc(), slicedTy,
            nonSlicedOperand->getResults()[0], ValueRange({}), ValueRange({}),
            ValueRange({}), sliceOffsets, sliceSizes, sliceStrides);
        mapping.map(nonSlicedOperand->getResults()[0], slicedOperand);
      }
      currOp = currOpUser;
    }
    // The path is expected to terminate in a convert_layout that produces
    // the dot-operand encoding.
    assert(llvm::isa<ttg::ConvertLayoutOp>(slicedOp));
    return slicedOp->getResults()[0];
  }

  // Rebuild `type` with the same shape and element type but a new encoding.
  static Type getNewType(Type type, Attribute encoding) {
    RankedTensorType tensorType = type.cast<RankedTensorType>();
    return RankedTensorType::get(tensorType.getShape(),
                                 tensorType.getElementType(), encoding);
  }

  // Rewrite `op` to produce results with `encoding`, inserting convert_layout
  // ops around it, then erase the original op.
  // Same as coalesceOp function in Coalesce.cpp.
  void convertLayout(Attribute encoding, Operation *op) {
    OpBuilder builder(op);
    // Convert operands
    // For load/store with tensor pointers, we don't have to change the
    // operands' type, we do this by changing the outputs' type of
    // `make_tensor_ptr`
    SmallVector<Value, 4> newArgs;
    for (auto operand : op->getOperands()) {
      auto tensorType = operand.getType().dyn_cast<RankedTensorType>();
      if (tensorType &&
          !tensorType.getEncoding().isa<ttg::SharedEncodingAttr>()) {
        Type newType = getNewType(tensorType, encoding);
        newArgs.push_back(builder.create<ttg::ConvertLayoutOp>(
            op->getLoc(), newType, operand));
      } else {
        newArgs.push_back(operand);
      }
    }
    // Convert output types
    SmallVector<Type, 4> newTypes;
    for (auto t : op->getResultTypes()) {
      bool isAsync = isa<ttg::InsertSliceAsyncOp>(op);
      newTypes.push_back(isAsync ? t : getNewType(t, encoding));
    }
    // Construct new op with the new encoding
    Operation *newOp =
        builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
                       newTypes, op->getAttrs());
    // Cast the results back to the original layout
    for (size_t i = 0; i < op->getNumResults(); i++) {
      Value newResult = newOp->getResult(i);
      if (newTypes[i] != op->getResultTypes()[i]) {
        newResult = builder.create<ttg::ConvertLayoutOp>(
            op->getLoc(), op->getResult(i).getType(), newResult);
      }
      op->getResult(i).replaceAllUsesWith(newResult);
    }
    op->erase();
  }

  // Adjust a blocked layout so that k-dim slices of size `sliceKTile` can be
  // taken without data exchange between threads. The k dimension is dim 1
  // for the A operand (operandIdx == 0) and dim 0 for the B operand
  // (operandIdx == 1), i.e. index `1 - operandIdx`.
  // Return true if layout was changed, else return false.
  bool setBlockedLayout(Operation *firstOpToSlice, ArrayRef<long> shape,
                        ttg::BlockedEncodingAttr blockedEncoding,
                        int operandIdx) {
    auto shapePerCTA = ttg::getShapePerCTATile(blockedEncoding, shape);
    auto sizePerThread = blockedEncoding.getSizePerThread();
    auto threadsPerWarp = blockedEncoding.getThreadsPerWarp();
    auto warpsPerCTA = blockedEncoding.getWarpsPerCTA();
    ModuleOp mod = getOperation();
    // clang-format off
    //
    // Current layout can be used for slicing as is.
    // Example: sliceKTile = 32, slicing along dim 1 (A operand)
    // Layout: #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]>
    //
    // clang-format on
    if (this->sliceKTile % shapePerCTA[1 - operandIdx] == 0) {
      return false;
      // clang-format off
      //
      // Current layout can be used for slicing only by setting warpsPerCTA to 1
      // along slicing dim.
      // Example: sliceKTile = 32, slicing along dim 1 (A operand)
      // Layout: #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0]>
      // NewLayout: #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]>
      //
      // clang-format on
    } else if (this->sliceKTile % (shapePerCTA[1 - operandIdx] /
                                   warpsPerCTA[1 - operandIdx]) ==
               0) {
      SmallVector<unsigned> newWarpsPerCTA(2, warpsPerCTA[0] * warpsPerCTA[1]);
      newWarpsPerCTA[1 - operandIdx] = 1;
      auto newBlockedEncoding = ttg::BlockedEncodingAttr::get(
          mod.getContext(), sizePerThread, threadsPerWarp, newWarpsPerCTA,
          blockedEncoding.getOrder(), blockedEncoding.getCTALayout());
      convertLayout(newBlockedEncoding, firstOpToSlice);
      // clang-format off
      //
      // Current layout can be used for slicing by setting warpsPerCTA to 1
      // along slicing dim and changing ThreadsPerWarp parameter.
      // Example: sliceKTile = 32, slicing along dim 1 (A operand)
      // Layout: #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]>
      // NewLayout: #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]>
      //
      // clang-format on
    } else if (this->sliceKTile % sizePerThread[1 - operandIdx] == 0) {
      // BUGFIX: divisibility must be checked along the k dimension
      // (index 1 - operandIdx), matching sizePerThread[1 - operandIdx] used
      // in the newThreadsPerWarp computation below. The original code
      // checked sizePerThread[operandIdx] (the non-k dimension).
      SmallVector<unsigned> newWarpsPerCTA(2, warpsPerCTA[0] * warpsPerCTA[1]);
      newWarpsPerCTA[1 - operandIdx] = 1;
      SmallVector<unsigned> newThreadsPerWarp(2, 1);
      // One warp then covers exactly sliceKTile elements along k:
      // sizePerThread[k] * newThreadsPerWarp[k] == sliceKTile.
      newThreadsPerWarp[operandIdx] =
          (threadsPerWarp[0] * threadsPerWarp[1]) /
          (this->sliceKTile / sizePerThread[1 - operandIdx]);
      newThreadsPerWarp[1 - operandIdx] =
          this->sliceKTile / sizePerThread[1 - operandIdx];
      auto newBlockedEncoding = ttg::BlockedEncodingAttr::get(
          mod.getContext(), sizePerThread, newThreadsPerWarp, newWarpsPerCTA,
          blockedEncoding.getOrder(), blockedEncoding.getCTALayout());
      convertLayout(newBlockedEncoding, firstOpToSlice);
      // In other cases, the sizePerThread parameter would need to be changed,
      // which can affect coalescing and thus potentially decrease performance.
    } else {
      assert(false && "Unexpected layout in DotSlicing pass.");
    }
    return true;
  }

  // Dispatch layout adjustment based on the encoding of `firstOpToSlice`'s
  // input tensor. Return true if layout was changed, else return false.
  bool setLayoutForSlicing(Operation *firstOpToSlice, int operandIdx) {
    auto firstOpToSliceTy =
        firstOpToSlice->getOperand(0).getType().cast<RankedTensorType>();
    auto srcShape = firstOpToSliceTy.getShape();
    auto encoding = firstOpToSliceTy.getEncoding();
    if (auto blockedEncoding = dyn_cast<ttg::BlockedEncodingAttr>(encoding)) {
      return setBlockedLayout(firstOpToSlice, srcShape, blockedEncoding,
                              operandIdx);
    } else if (auto mfmaEncoding = dyn_cast<ttg::MfmaEncodingAttr>(encoding)) {
      auto shapePerCTA = ttg::getShapePerCTATile(mfmaEncoding, srcShape);
      // TODO: Implement changing of mfma layout in case it is not suitable for
      // slicing (similar as in setBlockedLayout).
      assert(this->sliceKTile % shapePerCTA[1] == 0);
    } else {
      assert(false && "Unsupported layout in setLayoutForSlicing.");
    }
    return false;
  }

  // Find the tt.load that (transitively) produces dot operand `operandIdx`.
  tt::LoadOp getLoadInst(tt::DotOp dotOp, int operandIdx) {
    auto dotOperand = dotOp.getOperand(operandIdx);
    SmallVector<tt::LoadOp> loadOpsVec;
    getOperation()->walk([&](tt::LoadOp loadOp) {
      SetVector<Operation *> forwardSlices;
      getForwardSlice((Operation *)loadOp, &forwardSlices);
      if (forwardSlices.count(dotOperand.getDefiningOp())) {
        loadOpsVec.push_back(loadOp);
      }
    });
    // Currently, we expect the dot operand to depend only on one tensor
    // from global memory (applicable for dot ops that don't depend on other
    // dot ops). This condition can be lifted if necessary.
    assert(loadOpsVec.size() == 1);
    return loadOpsVec[0];
  }

  // Return true if the given dot operand transitively depends on another
  // tt.dot (e.g. the second GEMM in flash attention).
  bool dependsOnPreviousDot(tt::DotOp dotOp, int operandIdx) {
    SetVector<Operation *> bwdSlices;
    Operation *operand = dotOp.getOperand(operandIdx).getDefiningOp();
    // Seems like getBackwardSlice(dotOp, bwdSlices, filter) doesn't work
    // properly. Do it manually.
    getBackwardSlice(operand, &bwdSlices);
    return llvm::any_of(bwdSlices,
                        [](Operation *op) { return isa<tt::DotOp>(op); });
  }

  // Slice only when slicing is enabled (sliceKTile != 0) and the requested
  // slice is strictly smaller than the full k extent of the A operand.
  bool shouldSliceDot(tt::DotOp dotOp) {
    auto dotOperand = dotOp.getOperand(0);
    auto dotATy = dotOperand.getType().cast<RankedTensorType>();
    auto kDim = dotATy.getShape()[1];
    return this->sliceKTile != 0 && this->sliceKTile != kDim;
  }

  // Erase the now-dead original (unsliced) ops in reverse creation order so
  // users are removed before their producers; ops that still have uses are
  // kept.
  void dotSlicingDCE(ArrayRef<Operation *> eraseOps) {
    for (Operation *opToErase : llvm::reverse(eraseOps)) {
      assert(opToErase);
      // Operation::use_empty() checks all results at once.
      if (!opToErase->use_empty()) {
        continue;
      }
      opToErase->erase();
    }
  }

  // Slicing starts at the load feeding the operand, unless the operand comes
  // from a previous dot (chained GEMM), in which case it starts at the
  // operand's defining op.
  Operation *getFirstOpToSlice(tt::DotOp dotOp, int operandIdx) {
    if (dependsOnPreviousDot(dotOp, operandIdx)) {
      return dotOp.getOperand(operandIdx).getDefiningOp();
    }
    return getLoadInst(dotOp, operandIdx);
  }

  void runOnOperation() override {
    getOperation()->walk([&](tt::DotOp dotOp) {
      if (!shouldSliceDot(dotOp)) {
        return;
      }
      OpBuilder builder(dotOp);
      SmallVector<Operation *> eraseOps;
      auto dotResTy = dotOp.getType().cast<RankedTensorType>();
      auto dotOperand = dotOp.getOperand(0);
      auto dotATy = dotOperand.getType().cast<RankedTensorType>();
      auto dotAShape = dotATy.getShape();
      // NOTE(review): assumes k is a multiple of sliceKTile; a trailing
      // partial slice is not handled.
      int64_t numSlices = dotAShape[1] / this->sliceKTile;
      Value slicedAcc = dotOp.getOperand(2);
      auto firstOpToSliceA = getFirstOpToSlice(dotOp, 0);
      auto firstOpToSliceB = getFirstOpToSlice(dotOp, 1);
      // Changing the layout erases and recreates the op, so re-resolve the
      // slicing start point afterwards.
      if (setLayoutForSlicing(firstOpToSliceA, /*operandIdx*/ 0)) {
        firstOpToSliceA = getFirstOpToSlice(dotOp, /*operandIdx*/ 0);
      }
      if (setLayoutForSlicing(firstOpToSliceB, /*operandIdx*/ 1)) {
        firstOpToSliceB = getFirstOpToSlice(dotOp, /*operandIdx*/ 1);
      }
      // Chain the sliced dots: each slice accumulates into the previous
      // slice's result, starting from the original accumulator.
      for (int i = 0; i < numSlices; i++) {
        auto slicedOperandA = getSlicedDotOperand(
            firstOpToSliceA, dotOp, 0, i, this->sliceKTile, builder, eraseOps);
        auto slicedOperandB = getSlicedDotOperand(
            firstOpToSliceB, dotOp, 1, i, this->sliceKTile, builder, eraseOps);
        auto slicedDot = builder.create<tt::DotOp>(
            dotOp.getLoc(), dotResTy, slicedOperandA, slicedOperandB,
            slicedAcc, dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
        slicedAcc = slicedDot;
      }
      eraseOps.push_back((Operation *)dotOp);
      dotOp.replaceAllUsesWith(slicedAcc);
      dotSlicingDCE(eraseOps);
    });
  }
};
// Factory used by the pass registry and the Python bindings. `sliceKTile` is
// the slice size along the k dimension (0 disables slicing).
std::unique_ptr<Pass> mlir::createTritonAMDGPUDotSlicingPass(int sliceKTile) {
  auto pass = std::make_unique<TritonAMDGPUDotSlicingPass>(sliceKTile);
  return pass;
}

View File

@@ -1861,6 +1861,10 @@ void init_triton_ir(py::module &&m) {
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUStreamPipelinePass());
})
.def("add_tritonamdgpu_dot_slicing_pass",
[](mlir::PassManager &self, int slice_k_tile) {
self.addPass(mlir::createTritonAMDGPUDotSlicingPass(slice_k_tile));
})
.def("add_tritongpu_prefetch_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUPrefetchPass());

View File

@@ -100,7 +100,7 @@ def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, target):
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization,
enable_persistent, optimize_epilogue, matrix_inst_type):
enable_persistent, optimize_epilogue, matrix_inst_type, slice_k_tile):
is_cuda = _is_cuda(target)
if is_cuda:
capability = target.capability
@@ -123,6 +123,7 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, e
pm.add_tritongpu_remove_layout_conversions_pass()
if optimize_epilogue:
pm.add_tritongpu_optimize_epilogue_pass()
pm.add_tritonamdgpu_dot_slicing_pass(slice_k_tile)
pm.add_tritongpu_optimize_dot_operands_pass()
if num_stages == 0 and is_hip() and target["matrix_core_version"] != 0:
pm.add_tritongpu_stream_pipeline_pass()
@@ -273,6 +274,7 @@ def make_hash(fn, target, env_vars, device_backend, **kwargs):
num_ctas = kwargs.get("num_ctas", 1)
num_stages = kwargs.get("num_stages", 3)
waves_per_eu = kwargs.get("waves_per_eu", 0)
slice_k_tile = kwargs.get("slice_k_tile", 0)
matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0);
enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
enable_persistent = kwargs.get("enable_persistent", False)
@@ -282,7 +284,7 @@ def make_hash(fn, target, env_vars, device_backend, **kwargs):
sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
configs_key = [get_conf_key(conf) for conf in configs]
env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{slice_k_tile}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
return hashlib.md5(key.encode("utf-8")).hexdigest()
assert isinstance(fn, str)
ignore_version = kwargs.get('ignore_version', False)
@@ -414,6 +416,7 @@ def compile(fn, **kwargs):
num_ctas = kwargs.get("num_ctas", 1)
num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, capability=capability))
waves_per_eu = kwargs.get("waves_per_eu", 0)
slice_k_tile = kwargs.get("slice_k_tile", 0)
matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0)
enable_fp_fusion = kwargs.get("enable_fp_fusion", True)
# TODO[shuhaoj]: Default should be to enable warp specialization once possible
@@ -453,7 +456,7 @@ def compile(fn, **kwargs):
if is_cuda:
stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir(
ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info,
enable_warp_specialization, enable_persistent, optimize_epilogue))
enable_warp_specialization, enable_persistent, optimize_epilogue, slice_k_tile))
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos))
add_cuda_stages(target, extern_libs, stages)
@@ -472,12 +475,13 @@ def compile(fn, **kwargs):
other["optimize_epilogue"] = optimize_epilogue
other["tma_infos"] = tma_infos
other["waves_per_eu"] = waves_per_eu
other["slice_k_tile"] = slice_k_tile
other["matrix_instr_nonkdim"] = matrix_instr_nonkdim
_device_backend.add_stages(target, extern_libs, stages, other)
elif device_type == "xpu":
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, slice_k_tile))
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
_device_backend.add_stages(arch, extern_libs, stages)
@@ -556,6 +560,7 @@ def compile(fn, **kwargs):
"num_ctas": num_ctas,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu,
"slice_k_tile": slice_k_tile,
"matrix_instr_nonkdim": matrix_instr_nonkdim,
"enable_warp_specialization": enable_warp_specialization,
"enable_persistent": enable_persistent,
@@ -689,6 +694,7 @@ class CompiledKernel:
self.num_ctas = metadata["num_ctas"]
self.num_stages = metadata["num_stages"]
self.waves_per_eu = metadata["waves_per_eu"]
self.slice_k_tile = metadata["slice_k_tile"]
self.clusterDims = metadata["clusterDims"]
if "tensormaps_info" in metadata:
self.tensormaps_info = metadata["tensormaps_info"]

View File

@@ -351,6 +351,7 @@ class JITFunction(KernelInterface[T]):
num_ctas,
num_stages,
waves_per_eu,
slice_k_tile,
matrix_instr_nonkdim,
enable_warp_specialization,
enable_fp_fusion,
@@ -363,7 +364,7 @@ class JITFunction(KernelInterface[T]):
name = self.fn.__name__
module = self.fn.__module__
arg_reprs = ', '.join([f'{param.name}: {ty}' for param, ty in zip(self.params, key[1])])
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs}), enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, slice_k_tile={slice_k_tile}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs}), enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
key = str(key)
class LegacyCompiler:
@@ -381,6 +382,7 @@ class JITFunction(KernelInterface[T]):
num_ctas=num_ctas,
num_stages=num_stages,
waves_per_eu=waves_per_eu,
slice_k_tile=slice_k_tile,
enable_warp_specialization=enable_warp_specialization,
enable_fp_fusion=enable_fp_fusion,
extern_libs=extern_libs,
@@ -427,6 +429,7 @@ class JITFunction(KernelInterface[T]):
num_ctas = get_special_arg("num_ctas", 1)
num_stages = get_special_arg("num_stages")
waves_per_eu = get_special_arg("waves_per_eu", 0)
slice_k_tile = get_special_arg("slice_k_tile", 0)
matrix_instr_nonkdim = get_special_arg("matrix_instr_nonkdim", 0)
enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
enable_fp_fusion = get_special_arg("enable_fp_fusion", True)
@@ -503,6 +506,7 @@ class JITFunction(KernelInterface[T]):
num_ctas,
num_stages,
waves_per_eu,
slice_k_tile,
matrix_instr_nonkdim,
enable_warp_specialization,
enable_fp_fusion,
@@ -539,6 +543,7 @@ class JITFunction(KernelInterface[T]):
num_ctas,
num_stages,
waves_per_eu,
slice_k_tile,
matrix_instr_nonkdim,
enable_warp_specialization,
enable_fp_fusion,
@@ -556,6 +561,7 @@ class JITFunction(KernelInterface[T]):
num_ctas=num_ctas,
num_stages=num_stages,
waves_per_eu=waves_per_eu,
slice_k_tile=slice_k_tile,
matrix_instr_nonkdim=matrix_instr_nonkdim,
enable_warp_specialization=enable_warp_specialization,
enable_fp_fusion=enable_fp_fusion,

View File

@@ -449,10 +449,11 @@ class HIPBackend(BaseBackend):
optimize_epilogue = other["optimize_epilogue"]
tma_infos = other["tma_infos"]
waves_per_eu = other["waves_per_eu"]
slice_k_tile = other["slice_k_tile"]
matrix_instr_nonkdim = other["matrix_instr_nonkdim"]
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_instr_nonkdim))
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_instr_nonkdim, slice_k_tile))
stages["llir"] = (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos, waves_per_eu))

View File

@@ -83,11 +83,13 @@ def _attn_fwd_inner(acc, l_i, m_i, q,
# re-tuning.
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': True}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
],
key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
)

View File

@@ -0,0 +1,209 @@
// RUN: triton-opt %s -split-input-file --tritonamdgpu-dot-slicing=slice-k-tile=32 | FileCheck %s
// CHECK: #[[SLICE_V_LAYOUT:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
// CHECK: #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
// CHECK: #blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
// CHECK: #[[SLICE_Q_LAYOUT:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
// CHECK: #[[SLICE_K_LAYOUT:.+]] = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
// CHECK: %[[Q_VIEW_SLICE_1:.+]] = triton_gpu.view_slice %[[Q_PTR:.+]][0, 0] [128, 32] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_Q_LAYOUT]]> to tensor<128x32x!tt.ptr<f16, 1>, #[[SLICE_Q_LAYOUT]]>
// CHECK: %[[LOAD_Q_1:.+]] = tt.load %[[Q_VIEW_SLICE_1]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #[[SLICE_Q_LAYOUT]]>
// CHECK: %[[K_VIEW_SLICE_1:.+]] = triton_gpu.view_slice %[[K_PTR:.+]][0, 0] [32, 128] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_K_LAYOUT]]> to tensor<32x128x!tt.ptr<f16, 1>, #[[SLICE_K_LAYOUT]]>
// CHECK: %[[LOAD_K_1:.+]] = tt.load %[[K_VIEW_SLICE_1]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #[[SLICE_K_LAYOUT]]>
// CHECK: %[[QK_DOT_1:.+]] = tt.dot %[[QK_DOT_ARG_1:.+]], %[[QK_DOT_ARG_2:.+]], %[[QK_DOT_ARG_3:.+]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
// CHECK: %[[Q_VIEW_SLICE_2:.+]] = triton_gpu.view_slice %[[Q_PTR:.+]][0, 32] [128, 32] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_Q_LAYOUT]]> to tensor<128x32x!tt.ptr<f16, 1>, #[[SLICE_Q_LAYOUT]]>
// CHECK: %[[LOAD_Q_2:.+]] = tt.load %[[Q_VIEW_SLICE_2]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #[[SLICE_Q_LAYOUT]]>
// CHECK: %[[K_VIEW_SLICE_2:.+]] = triton_gpu.view_slice %[[K_PTR:.+]][32, 0] [32, 128] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_K_LAYOUT]]> to tensor<32x128x!tt.ptr<f16, 1>, #[[SLICE_K_LAYOUT]]>
// CHECK: %[[LOAD_K_2:.+]] = tt.load %[[K_VIEW_SLICE_2]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #[[SLICE_K_LAYOUT]]>
// CHECK: %[[QK_DOT_2:.+]] = tt.dot %[[QK_DOT_ARG_2:.+]], %[[QK_DOT_ARG_2:.+]], %[[QK_DOT_1]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
// CHECK: %[[Q_VIEW_SLICE_3:.+]] = triton_gpu.view_slice %[[Q_PTR:.+]][0, 64] [128, 32] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_Q_LAYOUT]]> to tensor<128x32x!tt.ptr<f16, 1>, #[[SLICE_Q_LAYOUT]]>
// CHECK: %[[LOAD_Q_3:.+]] = tt.load %[[Q_VIEW_SLICE_3]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #[[SLICE_Q_LAYOUT]]>
// CHECK: %[[K_VIEW_SLICE_3:.+]] = triton_gpu.view_slice %[[K_PTR:.+]][64, 0] [32, 128] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_K_LAYOUT]]> to tensor<32x128x!tt.ptr<f16, 1>, #[[SLICE_K_LAYOUT]]>
// CHECK: %[[LOAD_K_3:.+]] = tt.load %[[K_VIEW_SLICE_3]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #[[SLICE_K_LAYOUT]]>
// CHECK: %[[QK_DOT_3:.+]] = tt.dot %{{.*}}, %{{.*}}, %[[QK_DOT_2]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
// CHECK: %[[Q_VIEW_SLICE_4:.+]] = triton_gpu.view_slice %[[Q_PTR:.+]][0, 96] [128, 32] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_Q_LAYOUT]]> to tensor<128x32x!tt.ptr<f16, 1>, #[[SLICE_Q_LAYOUT]]>
// CHECK: %[[LOAD_Q_4:.+]] = tt.load %[[Q_VIEW_SLICE_4]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #[[SLICE_Q_LAYOUT]]>
// CHECK: %[[K_VIEW_SLICE_4:.+]] = triton_gpu.view_slice %[[K_PTR:.+]][96, 0] [32, 128] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_K_LAYOUT]]> to tensor<32x128x!tt.ptr<f16, 1>, #[[SLICE_K_LAYOUT]]>
// CHECK: %[[LOAD_K_4:.+]] = tt.load %[[K_VIEW_SLICE_4]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #[[SLICE_K_LAYOUT]]>
// CHECK: %[[QK_DOT_4:.+]] = tt.dot %{{.*}}, %{{.*}}, %[[QK_DOT_3]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
// CHECK: %[[QK_VIEW_SLICE_1:.+]] = triton_gpu.view_slice %[[QK_TENSOR:.+]][0, 0] [128, 32] [1, 1] : tensor<128x128xf16, #mfma> to tensor<128x32xf16, #mfma>
// CHECK: %[[QK_DOT_OP_1:.+]] = triton_gpu.convert_layout %[[QK_VIEW_SLICE_1]] : (tensor<128x32xf16, #mfma>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
// CHECK: %[[V_VIEW_SLICE_1:.+]] = triton_gpu.view_slice %[[V_TENSOR_PTR:.+]][0, 0] [32, 128] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_V_LAYOUT]]> to tensor<32x128x!tt.ptr<f16, 1>, #[[SLICE_V_LAYOUT]]>
// CHECK: %[[LOAD_V_1:.+]] = tt.load %[[V_VIEW_SLICE_1]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #[[SLICE_V_LAYOUT]]>
// CHECK: %[[V_DOT_OP_1:.+]] = triton_gpu.convert_layout %[[LOAD_V_1]] : (tensor<32x128xf16, #[[SLICE_V_LAYOUT]]>) -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
// CHECK: %[[QKV_DOT_1:.+]] = tt.dot %[[QK_DOT_OP_1]], %[[V_DOT_OP_1]], %[[QKV_DOT_ACC_1:.+]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
// CHECK: %[[QK_VIEW_SLICE_2:.+]] = triton_gpu.view_slice %[[QK_TENSOR:.+]][0, 32] [128, 32] [1, 1] : tensor<128x128xf16, #mfma> to tensor<128x32xf16, #mfma>
// CHECK: %[[QK_DOT_OP_2:.+]] = triton_gpu.convert_layout %[[QK_VIEW_SLICE_2]] : (tensor<128x32xf16, #mfma>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
// CHECK: %[[V_VIEW_SLICE_2:.+]] = triton_gpu.view_slice %[[V_TENSOR_PTR:.+]][32, 0] [32, 128] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_V_LAYOUT]]> to tensor<32x128x!tt.ptr<f16, 1>, #[[SLICE_V_LAYOUT]]>
// CHECK: %[[LOAD_V_2:.+]] = tt.load %[[V_VIEW_SLICE_2]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #[[SLICE_V_LAYOUT]]>
// CHECK: %[[V_DOT_OP_2:.+]] = triton_gpu.convert_layout %[[LOAD_V_2]] : (tensor<32x128xf16, #[[SLICE_V_LAYOUT]]>) -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
// CHECK: %[[QKV_DOT_2:.+]] = tt.dot %[[QK_DOT_OP_2]], %[[V_DOT_OP_2]], %[[QKV_DOT_1]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
// CHECK: %[[QK_VIEW_SLICE_3:.+]] = triton_gpu.view_slice %[[QK_TENSOR:.+]][0, 64] [128, 32] [1, 1] : tensor<128x128xf16, #mfma> to tensor<128x32xf16, #mfma>
// CHECK: %[[QK_DOT_OP_3:.+]] = triton_gpu.convert_layout %[[QK_VIEW_SLICE_3]] : (tensor<128x32xf16, #mfma>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
// CHECK: %[[V_VIEW_SLICE_3:.+]] = triton_gpu.view_slice %[[V_TENSOR_PTR:.+]][64, 0] [32, 128] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_V_LAYOUT]]> to tensor<32x128x!tt.ptr<f16, 1>, #[[SLICE_V_LAYOUT]]>
// CHECK: %[[LOAD_V_3:.+]] = tt.load %[[V_VIEW_SLICE_3]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #[[SLICE_V_LAYOUT]]>
// CHECK: %[[V_DOT_OP_3:.+]] = triton_gpu.convert_layout %[[LOAD_V_3]] : (tensor<32x128xf16, #[[SLICE_V_LAYOUT]]>) -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
// CHECK: %[[QKV_DOT_3:.+]] = tt.dot %[[QK_DOT_OP_3]], %[[V_DOT_OP_3]], %[[QKV_DOT_2]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
// CHECK: %[[QK_VIEW_SLICE_4:.+]] = triton_gpu.view_slice %[[QK_TENSOR:.+]][0, 96] [128, 32] [1, 1] : tensor<128x128xf16, #mfma> to tensor<128x32xf16, #mfma>
// CHECK: %[[QK_DOT_OP_4:.+]] = triton_gpu.convert_layout %[[QK_VIEW_SLICE_4]] : (tensor<128x32xf16, #mfma>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
// CHECK: %[[V_VIEW_SLICE_4:.+]] = triton_gpu.view_slice %[[V_TENSOR_PTR:.+]][96, 0] [32, 128] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_V_LAYOUT]]> to tensor<32x128x!tt.ptr<f16, 1>, #[[SLICE_V_LAYOUT]]>
// CHECK: %[[LOAD_V_4:.+]] = tt.load %[[V_VIEW_SLICE_4]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #[[SLICE_V_LAYOUT]]>
// CHECK: %[[V_DOT_OP_4:.+]] = triton_gpu.convert_layout %[[LOAD_V_4]] : (tensor<32x128xf16, #[[SLICE_V_LAYOUT]]>) -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
// CHECK: %[[QKV_DOT_4:.+]] = tt.dot %[[QK_DOT_OP_4]], %[[V_DOT_OP_4]], %[[QKV_DOT_3]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = true}>
module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @_attn_fwd_0d1d2d34d5d6de7de8de9c10de11de12de13c14de15de16de17c18de19de20de21c2223de24de(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg12: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<1.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
%cst_0 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mfma>
%c0_i64 = arith.constant 0 : i64
%c128_i64 = arith.constant 128 : i64
%cst_2 = arith.constant 1.44269502 : f32
%c0_i32 = arith.constant 0 : i32
%c128_i32 = arith.constant 128 : i32
%0 = tt.get_program_id x : i32
%1 = tt.get_program_id y : i32
%2 = arith.muli %1, %arg7 : i32
%3 = tt.addptr %arg0, %2 : !tt.ptr<f16, 1>, i32
%4 = arith.muli %0, %c128_i32 : i32
%5 = arith.extsi %arg8 : i32 to i64
%6 = arith.extsi %4 : i32 to i64
%7 = tt.addptr %arg2, %2 : !tt.ptr<f16, 1>, i32
%8 = arith.extsi %arg14 : i32 to i64
%9 = tt.addptr %arg1, %2 : !tt.ptr<f16, 1>, i32
%10 = arith.extsi %arg11 : i32 to i64
%11 = tt.addptr %arg5, %2 : !tt.ptr<f16, 1>, i32
%12 = arith.extsi %arg17 : i32 to i64
%13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked2>
%22 = tt.splat %4 : (i32) -> tensor<128xi32, #blocked2>
%23 = arith.addi %22, %21 : tensor<128xi32, #blocked2>
%24 = arith.mulf %arg3, %cst_2 : f32
%25 = tt.splat %6 : (i64) -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%26 = tt.splat %6 : (i64) -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%27 = arith.extsi %13 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%28 = arith.extsi %14 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%29 = arith.extsi %15 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%30 = arith.extsi %16 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%31 = arith.extsi %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%32 = arith.extsi %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%33 = arith.extsi %19 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%34 = arith.extsi %20 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%35 = arith.addi %25, %27 : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%36 = arith.addi %26, %28 : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%37 = tt.expand_dims %35 {axis = 1 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi64, #blocked>
%38 = tt.expand_dims %36 {axis = 1 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi64, #blocked>
%39 = tt.splat %5 : (i64) -> tensor<128x1xi64, #blocked>
%40 = arith.muli %37, %39 : tensor<128x1xi64, #blocked>
%41 = tt.splat %3 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked>
%42 = tt.addptr %41, %40 : tensor<128x1x!tt.ptr<f16, 1>, #blocked>, tensor<128x1xi64, #blocked>
%43 = tt.broadcast %42 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<128x128x!tt.ptr<f16, 1>, #blocked>
%44 = tt.expand_dims %29 {axis = 0 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi64, #blocked>
%45 = tt.expand_dims %30 {axis = 0 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi64, #blocked>
%46 = tt.expand_dims %31 {axis = 0 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi64, #blocked>
%47 = tt.broadcast %44 : (tensor<1x128xi64, #blocked>) -> tensor<128x128xi64, #blocked>
%48 = tt.broadcast %45 : (tensor<1x128xi64, #blocked>) -> tensor<128x128xi64, #blocked>
%49 = tt.broadcast %46 : (tensor<1x128xi64, #blocked>) -> tensor<128x128xi64, #blocked>
%50 = tt.addptr %43, %47 : tensor<128x128x!tt.ptr<f16, 1>, #blocked>, tensor<128x128xi64, #blocked>
%51 = tt.load %50 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #blocked>
%52 = tt.splat %24 : (f32) -> tensor<128x128xf32, #blocked>
%53 = arith.extf %51 : tensor<128x128xf16, #blocked> to tensor<128x128xf32, #blocked>
%54 = arith.mulf %53, %52 : tensor<128x128xf32, #blocked>
%55 = arith.truncf %54 : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked>
%56 = tt.expand_dims %32 {axis = 1 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi64, #blocked1>
%57 = tt.splat %9 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked1>
%58 = tt.addptr %57, %56 : tensor<128x1x!tt.ptr<f16, 1>, #blocked1>, tensor<128x1xi64, #blocked1>
%59 = tt.broadcast %58 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked1>) -> tensor<128x128x!tt.ptr<f16, 1>, #blocked1>
%60 = tt.splat %10 : (i64) -> tensor<1x128xi64, #blocked1>
%61 = tt.splat %8 : (i64) -> tensor<128x1xi64, #blocked>
%62 = tt.splat %7 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked>
%63:5 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg22 = %cst_1, %arg23 = %cst, %arg24 = %cst_0, %arg25 = %c0_i64, %arg26 = %c0_i64) -> (tensor<128x128xf32, #mfma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>, i64, i64) : i32 {
%82 = tt.splat %arg26 : (i64) -> tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%83 = arith.addi %82, %33 : tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%84 = tt.expand_dims %83 {axis = 0 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi64, #blocked1>
%85 = arith.muli %84, %60 : tensor<1x128xi64, #blocked1>
%86 = tt.broadcast %85 : (tensor<1x128xi64, #blocked1>) -> tensor<128x128xi64, #blocked1>
%87 = tt.addptr %59, %86 : tensor<128x128x!tt.ptr<f16, 1>, #blocked1>, tensor<128x128xi64, #blocked1>
%88 = tt.load %87 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #blocked1>
%89 = triton_gpu.convert_layout %55 : (tensor<128x128xf16, #blocked>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
%90 = triton_gpu.convert_layout %88 : (tensor<128x128xf16, #blocked1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
%91 = tt.dot %89, %90, %cst_1 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
%92 = "tt.reduce"(%91) <{axis = 1 : i32}> ({
^bb0(%arg27: f32, %arg28: f32):
%120 = arith.maximumf %arg27, %arg28 : f32
tt.reduce.return %120 : f32
}) : (tensor<128x128xf32, #mfma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
%93 = arith.maximumf %arg24, %92 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
%94 = tt.expand_dims %93 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>) -> tensor<128x1xf32, #mfma>
%95 = tt.broadcast %94 : (tensor<128x1xf32, #mfma>) -> tensor<128x128xf32, #mfma>
%96 = arith.subf %91, %95 : tensor<128x128xf32, #mfma>
%97 = tt.extern_elementwise %96 {libname = "libdevice", libpath = "/triton/python/triton/language/../third_party/hip/lib/bitcode/cuda2gcn.bc", pure = true, symbol = "__nv_exp2f"} : (tensor<128x128xf32, #mfma>) -> tensor<128x128xf32, #mfma>
%98 = arith.subf %arg24, %93 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
%99 = tt.extern_elementwise %98 {libname = "libdevice", libpath = "/triton/python/triton/language/../third_party/hip/lib/bitcode/cuda2gcn.bc", pure = true, symbol = "__nv_exp2f"} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
%100 = tt.expand_dims %99 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>) -> tensor<128x1xf32, #mfma>
%101 = tt.broadcast %100 : (tensor<128x1xf32, #mfma>) -> tensor<128x128xf32, #mfma>
%102 = arith.mulf %arg22, %101 : tensor<128x128xf32, #mfma>
%103 = tt.splat %arg25 : (i64) -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%104 = arith.addi %103, %34 : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%105 = tt.expand_dims %104 {axis = 1 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi64, #blocked>
%106 = arith.muli %105, %61 : tensor<128x1xi64, #blocked>
%107 = tt.addptr %62, %106 : tensor<128x1x!tt.ptr<f16, 1>, #blocked>, tensor<128x1xi64, #blocked>
%108 = tt.broadcast %107 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<128x128x!tt.ptr<f16, 1>, #blocked>
%109 = tt.addptr %108, %48 : tensor<128x128x!tt.ptr<f16, 1>, #blocked>, tensor<128x128xi64, #blocked>
%110 = tt.load %109 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #blocked>
%111 = arith.truncf %97 : tensor<128x128xf32, #mfma> to tensor<128x128xf16, #mfma>
%112 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #mfma>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
%113 = triton_gpu.convert_layout %110 : (tensor<128x128xf16, #blocked>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
%114 = tt.dot %112, %113, %102 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
%115 = "tt.reduce"(%97) <{axis = 1 : i32}> ({
^bb0(%arg27: f32, %arg28: f32):
%120 = arith.addf %arg27, %arg28 : f32
tt.reduce.return %120 : f32
}) : (tensor<128x128xf32, #mfma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
%116 = arith.mulf %arg23, %99 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
%117 = arith.addf %116, %115 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
%118 = arith.addi %arg25, %c128_i64 : i64
%119 = arith.addi %arg26, %c128_i64 : i64
scf.yield %114, %117, %93, %118, %119 : tensor<128x128xf32, #mfma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>, i64, i64
}
%64 = tt.expand_dims %63#1 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>) -> tensor<128x1xf32, #mfma>
%65 = tt.broadcast %64 : (tensor<128x1xf32, #mfma>) -> tensor<128x128xf32, #mfma>
%66 = arith.divf %63#0, %65 : tensor<128x128xf32, #mfma>
%67 = arith.muli %1, %arg20 : i32
%68 = tt.addptr %arg4, %67 : !tt.ptr<f32, 1>, i32
%69 = tt.splat %68 : (!tt.ptr<f32, 1>) -> tensor<128x!tt.ptr<f32, 1>, #blocked2>
%70 = tt.addptr %69, %23 : tensor<128x!tt.ptr<f32, 1>, #blocked2>, tensor<128xi32, #blocked2>
%71 = tt.extern_elementwise %63#1 {libname = "libdevice", libpath = "/triton/python/triton/language/../third_party/hip/lib/bitcode/cuda2gcn.bc", pure = true, symbol = "__nv_log2f"} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
%72 = arith.addf %63#2, %71 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
%73 = triton_gpu.convert_layout %72 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>) -> tensor<128xf32, #blocked2>
tt.store %70, %73 {cache = 1 : i32, evict = 1 : i32} : tensor<128xf32, #blocked2>
%74 = arith.truncf %66 : tensor<128x128xf32, #mfma> to tensor<128x128xf16, #mfma>
%75 = tt.splat %12 : (i64) -> tensor<128x1xi64, #blocked>
%76 = arith.muli %38, %75 : tensor<128x1xi64, #blocked>
%77 = tt.splat %11 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked>
%78 = tt.addptr %77, %76 : tensor<128x1x!tt.ptr<f16, 1>, #blocked>, tensor<128x1xi64, #blocked>
%79 = tt.broadcast %78 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<128x128x!tt.ptr<f16, 1>, #blocked>
%80 = tt.addptr %79, %49 : tensor<128x128x!tt.ptr<f16, 1>, #blocked>, tensor<128x128xi64, #blocked>
%81 = triton_gpu.convert_layout %80 : (tensor<128x128x!tt.ptr<f16, 1>, #blocked>) -> tensor<128x128x!tt.ptr<f16, 1>, #mfma>
tt.store %81, %74 {cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #mfma>
tt.return
}
}