mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Dot slicing pass (#440)
* First commit * Implement DotSlicing pass. * small fixes * Support chained dot in DotSlicingPass (second GEMM in FA) * Add lit test for FA dot slicing --------- Co-authored-by: Ognjen Plavsic <ognjen.plavsic@luxoft.com> Co-authored-by: Ognjen <oplavsic@luxoft.com>
This commit is contained in:
@@ -12,6 +12,7 @@ std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 3,
|
||||
int computeCapability = 80);
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUStreamPipelinePass();
|
||||
std::unique_ptr<Pass> createTritonAMDGPUDotSlicingPass(int sliceKTile = 0);
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
createTritonGPUAccelerateMatmulPass(int computeCapability = 80);
|
||||
|
||||
@@ -49,6 +49,26 @@ def TritonGPUStreamPipeline : Pass<"tritongpu-stream-pipeline", "mlir::ModuleOp"
|
||||
"mlir::arith::ArithDialect"];
|
||||
}
|
||||
|
||||
def TritonAMDGPUDotSlicing: Pass<"tritonamdgpu-dot-slicing", "mlir::ModuleOp"> {
  let summary = "'DotOp' instruction slicing";

  let description = [{
    Slice 'DotOp' instruction into multiple smaller 'DotOp' instructions
    in order to improve scheduling and latency hiding.
  }];

  let constructor = "mlir::createTritonAMDGPUDotSlicingPass()";

  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::arith::ArithDialect"];

  // Slice size along the k dimension; 0 disables the pass.
  let options = [
    Option<"sliceKTile", "slice-k-tile",
           "int32_t", /*default*/"0",
           "slice size in k dimension">
  ];
}
|
||||
|
||||
def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
|
||||
let summary = "prefetch";
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ add_mlir_dialect_library(TritonGPUTransforms
|
||||
RemoveLayoutConversions.cpp
|
||||
ReorderInstructions.cpp
|
||||
StreamPipeline.cpp
|
||||
DotSlicing.cpp
|
||||
TritonGPUConversion.cpp
|
||||
Utility.cpp
|
||||
|
||||
|
||||
464
lib/Dialect/TritonGPU/Transforms/DotSlicing.cpp
Normal file
464
lib/Dialect/TritonGPU/Transforms/DotSlicing.cpp
Normal file
@@ -0,0 +1,464 @@
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// This file implements the dot slicing pass. It slices along the k dimension
|
||||
// in a dot operation (for dot A, B with shape(A) = (m,k), shape(B) = (k,n)).
|
||||
// The slice size along the k dimension is set by 'sliceKTile'.
|
||||
// The algorithm includes three main steps:
|
||||
//
|
||||
// 1) Modify load instruction layouts for dot operands.
|
||||
// - This ensures tensor slices have the same layout as the original tensor,
|
||||
// allowing slicing without data exchange between threads.
|
||||
//
|
||||
// 2) Slice the dot operands.
|
||||
// - Here, the original dot operands are sliced, starting from the load
|
||||
// instruction. This helps in hiding latency in global memory transactions.
|
||||
// Currently, the algorithm assumes the original tensor undergoes only
|
||||
// elementwise operations before being used as a dot operand. In other
|
||||
// words, the algorithm expects only elementwise operations and convert
|
||||
// layout operations to occur between the load and dot instructions.
|
||||
// However, It can be extended to handle other operations if necessary.
|
||||
//
|
||||
// 3) Slice the dot operation along the k dimension.
|
||||
// - This involves calculating the number of slices from 'sliceKTile', and
|
||||
// then creating a new dot operation for each slice using the operands
|
||||
// sliced in the previous step. Sliced dots are concatenated so the result
|
||||
// of each dot is used in the next, leveraging the slicing along the k dim.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
using namespace mlir;
|
||||
namespace tt = triton;
|
||||
namespace ttg = triton::gpu;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
namespace {
|
||||
|
||||
bool isElementwiseOp(Operation *op) {
|
||||
if (llvm::isa<arith::AddFOp, arith::AddIOp, arith::AndIOp, arith::CeilDivSIOp,
|
||||
arith::CeilDivUIOp, arith::DivFOp, arith::DivSIOp,
|
||||
arith::DivUIOp, arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp,
|
||||
arith::FloorDivSIOp, arith::FPToSIOp, arith::FPToUIOp,
|
||||
arith::MaximumFOp, arith::MaxSIOp, arith::MaxUIOp,
|
||||
arith::MinimumFOp, arith::MinSIOp, arith::MinUIOp,
|
||||
arith::MulFOp, arith::MulIOp, arith::NegFOp, arith::OrIOp,
|
||||
arith::RemFOp, arith::RemSIOp, arith::RemUIOp, arith::ShLIOp,
|
||||
arith::ShRSIOp, arith::ShRUIOp, arith::SIToFPOp, arith::SubFOp,
|
||||
arith::SubIOp, arith::TruncFOp, arith::TruncIOp,
|
||||
arith::UIToFPOp, arith::XOrIOp>(op))
|
||||
return true;
|
||||
if (llvm::isa<math::AbsFOp, math::AbsIOp, math::AtanOp, math::Atan2Op,
|
||||
math::CeilOp, math::CopySignOp, math::CosOp, math::SinOp,
|
||||
math::CountLeadingZerosOp, math::CountTrailingZerosOp,
|
||||
math::CtPopOp, math::ErfOp, math::ExpOp, math::Exp2Op,
|
||||
math::ExpM1Op, math::FloorOp, math::FmaOp, math::LogOp,
|
||||
math::Log10Op, math::Log1pOp, math::Log2Op, math::PowFOp,
|
||||
math::RsqrtOp, math::SqrtOp, math::TanhOp>(op))
|
||||
return true;
|
||||
if (llvm::isa<tt::IntToPtrOp, tt::PtrToIntOp, tt::BitcastOp, tt::FpToFpOp,
|
||||
tt::AddPtrOp>(op))
|
||||
return true;
|
||||
if (auto externElementwiseOp = dyn_cast<tt::ExternElementwiseOp>(op))
|
||||
return externElementwiseOp.getPure();
|
||||
if (llvm::isa<arith::CmpIOp, arith::CmpFOp, arith::SelectOp>(op))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
struct TritonAMDGPUDotSlicingPass
|
||||
: public TritonAMDGPUDotSlicingBase<TritonAMDGPUDotSlicingPass> {
|
||||
TritonAMDGPUDotSlicingPass() = default;
|
||||
|
||||
TritonAMDGPUDotSlicingPass(int sliceKTile) { this->sliceKTile = sliceKTile; }
|
||||
|
||||
// Find user of the currOp that affects dotOperand calculation.
|
||||
// We assume here that there is only one such user.
|
||||
Operation *getUserThatAffectsDotOperand(Operation *currOp,
|
||||
Operation *dotOperand) {
|
||||
SetVector<Operation *> forwardSlices;
|
||||
SmallVector<Operation *> usersThatAffectDot;
|
||||
for (auto *user : currOp->getUsers()) {
|
||||
forwardSlices.clear();
|
||||
getForwardSlice(user, &forwardSlices);
|
||||
|
||||
if (user == dotOperand) {
|
||||
usersThatAffectDot.push_back(user);
|
||||
continue;
|
||||
}
|
||||
if (std::find(forwardSlices.begin(), forwardSlices.end(), dotOperand) !=
|
||||
forwardSlices.end()) {
|
||||
usersThatAffectDot.push_back(user);
|
||||
}
|
||||
}
|
||||
assert(usersThatAffectDot.size() == 1);
|
||||
return usersThatAffectDot[0];
|
||||
}
|
||||
|
||||
// Build slice number `loopIter` (of width `sliceSizeK` along k) of dot
// operand `operandIdx`, by cloning the chain of ops from `firstOpToSlice`
// down to the operand with all tensors replaced by their k-slices.
// On the first iteration the original (now redundant) ops are collected
// into `eraseOps` for later cleanup. Returns the sliced operand value.
Value getSlicedDotOperand(Operation *firstOpToSlice, tt::DotOp dotOp,
                          int operandIdx, int loopIter, int sliceSizeK,
                          OpBuilder builder,
                          SmallVector<Operation *> &eraseOps) {
  Value ptrTensor = firstOpToSlice->getOperand(0);
  auto ptrTensorType = ptrTensor.getType().cast<RankedTensorType>();
  Value dotOperand = dotOp.getOperand(operandIdx);
  // Note: this is the dot *result* type; its dim 0 (m) / dim 1 (n) supply
  // the non-k extents for the A / B slices respectively.
  auto dotOperandTy = dotOp.getType().cast<RankedTensorType>();

  // Slice geometry: A (operandIdx == 0) is sliced along dim 1, B along
  // dim 0. Strides are always unit.
  SmallVector<int64_t> sliceStrides{1, 1};
  SmallVector<int64_t> sliceSizes;
  SmallVector<int64_t> sliceOffsets;
  if (operandIdx == 0) {
    sliceSizes = {dotOperandTy.getShape()[0], sliceSizeK};
    sliceOffsets = {0, loopIter * sliceSizeK};
  } else {
    assert(operandIdx == 1);
    sliceSizes = {sliceSizeK, dotOperandTy.getShape()[1]};
    sliceOffsets = {loopIter * sliceSizeK, 0};
  }

  auto viewPtr = builder.create<ttg::ViewSliceOp>(
      dotOp.getLoc(),
      RankedTensorType::get(sliceSizes, ptrTensorType.getElementType(),
                            ptrTensorType.getEncoding()),
      ptrTensor, ValueRange({}), ValueRange({}), ValueRange({}), sliceOffsets,
      sliceSizes, sliceStrides);

  // Walk from the load towards the dot operand, cloning each op with its
  // operands remapped to the already-sliced values.
  IRMapping mapping;
  mapping.map(ptrTensor, viewPtr);

  Operation *currOp = firstOpToSlice;
  Operation *slicedOp = nullptr;
  while (true) {
    // Record originals for deletion only once (first slice iteration).
    if (loopIter == 0)
      eraseOps.push_back(currOp);
    slicedOp = builder.clone(*currOp, mapping);

    // Ops on this path (load / convert_layout / elementwise) are expected
    // to produce a single result; lift this if other ops must be handled.
    assert(currOp->getNumResults() == 1);
    // Rewrite the clone's result types to the sliced shape, keeping the
    // original element type and encoding.
    for (auto [origRes, newRes] :
         llvm::zip(currOp->getResults(), slicedOp->getResults())) {
      auto slicedResTy = RankedTensorType::get(
          viewPtr.getType().cast<RankedTensorType>().getShape(),
          origRes.getType().cast<RankedTensorType>().getElementType(),
          origRes.getType().cast<RankedTensorType>().getEncoding());
      newRes.setType(slicedResTy);
    }

    mapping.map(currOp, slicedOp);
    if (currOp == dotOperand.getDefiningOp())
      break;
    assert(llvm::isa<tt::LoadOp>(currOp) ||
           llvm::isa<ttg::ConvertLayoutOp>(currOp) || isElementwiseOp(currOp));

    // When currOp has several users, continue with the unique one on the
    // dot-operand path.
    Operation *nextOp =
        getUserThatAffectsDotOperand(currOp, dotOperand.getDefiningOp());

    // nextOp may have extra operands (e.g. the second input of a binary
    // elementwise op). Slice each of them with the same offsets / sizes /
    // strides — valid only because nextOp is assumed elementwise; other
    // multi-operand ops would need different handling.
    for (Value operandVal : nextOp->getOperands()) {
      Operation *defOp = operandVal.getDefiningOp();
      if (defOp == currOp)
        continue;
      auto defResTy =
          defOp->getResults()[0].getType().cast<RankedTensorType>();
      auto slicedTy = RankedTensorType::get(
          sliceSizes, defResTy.getElementType(), defResTy.getEncoding());
      auto slicedOperand = builder.create<ttg::ViewSliceOp>(
          defOp->getLoc(), slicedTy, defOp->getResults()[0], ValueRange({}),
          ValueRange({}), ValueRange({}), sliceOffsets, sliceSizes,
          sliceStrides);
      mapping.map(defOp->getResults()[0], slicedOperand);
    }

    currOp = nextOp;
  }

  // The chain must terminate in the convert_layout feeding the dot.
  assert(llvm::isa<ttg::ConvertLayoutOp>(slicedOp));
  return slicedOp->getResults()[0];
}
|
||||
|
||||
// Rebuild `type` as a RankedTensorType with the same shape and element type
// but the given `encoding`.
static Type getNewType(Type type, Attribute encoding) {
  auto tensorTy = type.cast<RankedTensorType>();
  return RankedTensorType::get(tensorTy.getShape(), tensorTy.getElementType(),
                               encoding);
}
|
||||
|
||||
// Same as coalesceOp function in Coalesce.cpp.
|
||||
void convertLayout(Attribute encoding, Operation *op) {
|
||||
OpBuilder builder(op);
|
||||
// Convert operands
|
||||
// For load/store with tensor pointers, we don't have to change the
|
||||
// operands' type, we do this by changing the outputs' type of
|
||||
// `make_tensor_ptr`
|
||||
SmallVector<Value, 4> newArgs;
|
||||
for (auto operand : op->getOperands()) {
|
||||
auto tensorType = operand.getType().dyn_cast<RankedTensorType>();
|
||||
if (tensorType &&
|
||||
!tensorType.getEncoding().isa<ttg::SharedEncodingAttr>()) {
|
||||
Type newType = getNewType(tensorType, encoding);
|
||||
newArgs.push_back(builder.create<ttg::ConvertLayoutOp>(
|
||||
op->getLoc(), newType, operand));
|
||||
} else {
|
||||
newArgs.push_back(operand);
|
||||
}
|
||||
}
|
||||
|
||||
// Convert output types
|
||||
SmallVector<Type, 4> newTypes;
|
||||
for (auto t : op->getResultTypes()) {
|
||||
bool isAsync = isa<ttg::InsertSliceAsyncOp>(op);
|
||||
newTypes.push_back(isAsync ? t : getNewType(t, encoding));
|
||||
}
|
||||
|
||||
// Construct new op with the new encoding
|
||||
Operation *newOp =
|
||||
builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
|
||||
newTypes, op->getAttrs());
|
||||
|
||||
// Cast the results back to the original layout
|
||||
for (size_t i = 0; i < op->getNumResults(); i++) {
|
||||
Value newResult = newOp->getResult(i);
|
||||
if (newTypes[i] != op->getResultTypes()[i]) {
|
||||
newResult = builder.create<ttg::ConvertLayoutOp>(
|
||||
op->getLoc(), op->getResult(i).getType(), newResult);
|
||||
}
|
||||
op->getResult(i).replaceAllUsesWith(newResult);
|
||||
}
|
||||
op->erase();
|
||||
}
|
||||
|
||||
// Return true if layout was changed, else return false.
|
||||
bool setBlockedLayout(Operation *firstOpToSlice, ArrayRef<long> shape,
|
||||
ttg::BlockedEncodingAttr blockedEncoding,
|
||||
int operandIdx) {
|
||||
auto shapePerCTA = ttg::getShapePerCTATile(blockedEncoding, shape);
|
||||
auto sizePerThread = blockedEncoding.getSizePerThread();
|
||||
auto threadsPerWarp = blockedEncoding.getThreadsPerWarp();
|
||||
auto warpsPerCTA = blockedEncoding.getWarpsPerCTA();
|
||||
ModuleOp mod = getOperation();
|
||||
|
||||
// clang-format off
|
||||
//
|
||||
// Current layout can be used for slicing as is.
|
||||
// Example: sliceKTile = 32, slicing along dim 1 (A operand)
|
||||
// Layout: #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]>
|
||||
//
|
||||
// clang-format on
|
||||
if (this->sliceKTile % shapePerCTA[1 - operandIdx] == 0) {
|
||||
return false;
|
||||
// clang-format off
|
||||
//
|
||||
// Current layout can be used for slicing only by setting warpsPerCTA to 1
|
||||
// along slicing dim.
|
||||
// Example: sliceKTile = 32, slicing along y dim (A operand)
|
||||
// Layout: #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0]>
|
||||
// NewLayout: #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]>
|
||||
//
|
||||
// clang-format on
|
||||
} else if (this->sliceKTile % (shapePerCTA[1 - operandIdx] /
|
||||
warpsPerCTA[1 - operandIdx]) ==
|
||||
0) {
|
||||
SmallVector<unsigned> newWarpsPerCTA(2, warpsPerCTA[0] * warpsPerCTA[1]);
|
||||
newWarpsPerCTA[1 - operandIdx] = 1;
|
||||
auto newBlockedEncoding = ttg::BlockedEncodingAttr::get(
|
||||
mod.getContext(), sizePerThread, threadsPerWarp, newWarpsPerCTA,
|
||||
blockedEncoding.getOrder(), blockedEncoding.getCTALayout());
|
||||
convertLayout(newBlockedEncoding, firstOpToSlice);
|
||||
// clang-format off
|
||||
//
|
||||
// Current layout can be used for slicing by setting warpsPerCTA to 1
|
||||
// along slicing dim and changing ThreadsPerWarp parameter.
|
||||
// Example: sliceKTile = 32, slicing along y dim (A operand)
|
||||
// Layout: #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]>
|
||||
// NewLayout: #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]>
|
||||
//
|
||||
// clang-format on
|
||||
} else if (this->sliceKTile % sizePerThread[operandIdx] == 0) {
|
||||
SmallVector<unsigned> newWarpsPerCTA(2, warpsPerCTA[0] * warpsPerCTA[1]);
|
||||
newWarpsPerCTA[1 - operandIdx] = 1;
|
||||
SmallVector<unsigned> newThreadsPerWarp(2, 1);
|
||||
newThreadsPerWarp[operandIdx] =
|
||||
(threadsPerWarp[0] * threadsPerWarp[1]) /
|
||||
(this->sliceKTile / sizePerThread[1 - operandIdx]);
|
||||
newThreadsPerWarp[1 - operandIdx] =
|
||||
this->sliceKTile / sizePerThread[1 - operandIdx];
|
||||
|
||||
auto newBlockedEncoding = ttg::BlockedEncodingAttr::get(
|
||||
mod.getContext(), sizePerThread, newThreadsPerWarp, newWarpsPerCTA,
|
||||
blockedEncoding.getOrder(), blockedEncoding.getCTALayout());
|
||||
convertLayout(newBlockedEncoding, firstOpToSlice);
|
||||
// In other cases, the sizePerThread parameter would need to be changed,
|
||||
// which can affect coalescing and thus potentially decrease performance.
|
||||
} else {
|
||||
assert(false && "Unexpected layout in DotSlicing pass.");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Return true if layout was changed, else return false.
|
||||
bool setLayoutForSlicing(Operation *firstOpToSlice, int operandIdx) {
|
||||
auto firstOpToSliceTy =
|
||||
firstOpToSlice->getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto srcShape = firstOpToSliceTy.getShape();
|
||||
auto encoding = firstOpToSliceTy.getEncoding();
|
||||
|
||||
if (auto blockedEncoding = dyn_cast<ttg::BlockedEncodingAttr>(encoding)) {
|
||||
return setBlockedLayout(firstOpToSlice, srcShape, blockedEncoding,
|
||||
operandIdx);
|
||||
} else if (auto mfmaEncoding = dyn_cast<ttg::MfmaEncodingAttr>(encoding)) {
|
||||
auto shapePerCTA = ttg::getShapePerCTATile(mfmaEncoding, srcShape);
|
||||
// TODO: Implement changing of mfma layout in case it is not suitable for
|
||||
// slicing (similar as in setBlockedLayout).
|
||||
assert(this->sliceKTile % shapePerCTA[1] == 0);
|
||||
} else {
|
||||
assert(false && "Unsupported layout in setLayoutForSlicing.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Find the tt::LoadOp whose result (transitively) feeds operand `operandIdx`
// of `dotOp`. Asserts that exactly one such load exists.
tt::LoadOp getLoadInst(tt::DotOp dotOp, int operandIdx) {
  Operation *operandDef = dotOp.getOperand(operandIdx).getDefiningOp();
  SmallVector<tt::LoadOp> loadOpsVec;

  getOperation()->walk([&](tt::LoadOp loadOp) {
    SetVector<Operation *> forwardSlices;
    getForwardSlice((Operation *)loadOp, &forwardSlices);
    // SetVector::count is O(1); no need for a linear std::find scan.
    if (forwardSlices.count(operandDef))
      loadOpsVec.push_back(loadOp);
  });

  // Currently, we expect the dot operand to depend only on one tensor
  // from global memory (applicable for dot ops that don't depend on other
  // dot ops). This condition can be lifted if necessary.
  assert(loadOpsVec.size() == 1);
  return loadOpsVec[0];
}
|
||||
|
||||
// Returns true iff operand `operandIdx` of `dotOp` transitively depends on
// another tt::DotOp (e.g. the second GEMM in flash attention).
bool dependsOnPreviousDot(tt::DotOp dotOp, int operandIdx) {
  SetVector<Operation *> bwdSlices;
  Operation *operand = dotOp.getOperand(operandIdx).getDefiningOp();
  // Seems like getBackwardSlice(dotOp, bwdSlices, filter) doesn't work
  // properly. Gather the full backward slice and filter manually instead.
  getBackwardSlice(operand, &bwdSlices);
  // Any dot in the backward slice means this operand is dot-dependent; no
  // need to materialize the filtered list just to test emptiness.
  return llvm::any_of(bwdSlices,
                      [](Operation *op) { return isa<tt::DotOp>(op); });
}
|
||||
|
||||
// Slicing applies only for a nontrivial slice size: skip when sliceKTile is
// 0 (pass disabled) or already equal to the full k dimension.
bool shouldSliceDot(tt::DotOp dotOp) {
  auto dotATy = dotOp.getOperand(0).getType().cast<RankedTensorType>();
  // k is dim 1 of the A operand (shape(A) = (m, k)).
  auto kDim = dotATy.getShape()[1];
  return this->sliceKTile != 0 && this->sliceKTile != kDim;
}
|
||||
|
||||
// Erase the collected original ops in reverse (use-before-def) order,
// skipping any op whose results still have live uses (e.g. values shared
// with unsliced computation).
void dotSlicingDCE(ArrayRef<Operation *> eraseOps) {
  for (Operation *opToErase : llvm::reverse(eraseOps)) {
    assert(opToErase);
    // Operation::use_empty() checks all results at once; the hand-rolled
    // per-result scan (with no early exit) is unnecessary.
    if (opToErase->use_empty())
      opToErase->erase();
  }
}
|
||||
|
||||
// Chained dots are sliced starting from the operand's defining op; otherwise
// slicing starts from the global-memory load feeding the operand.
Operation *getFirstOpToSlice(tt::DotOp dotOp, int operandIdx) {
  if (dependsOnPreviousDot(dotOp, operandIdx))
    return dotOp.getOperand(operandIdx).getDefiningOp();
  return getLoadInst(dotOp, operandIdx);
}
|
||||
|
||||
void runOnOperation() override {
|
||||
getOperation()->walk([&](tt::DotOp dotOp) {
|
||||
if (!shouldSliceDot(dotOp)) {
|
||||
return;
|
||||
}
|
||||
|
||||
OpBuilder builder(dotOp);
|
||||
SmallVector<Operation *> eraseOps;
|
||||
|
||||
auto dotResTy = dotOp.getType().cast<RankedTensorType>();
|
||||
auto dotOperand = dotOp.getOperand(0);
|
||||
auto dotATy = dotOperand.getType().cast<RankedTensorType>();
|
||||
auto dotAShape = dotATy.getShape();
|
||||
int64_t numSlices = dotAShape[1] / this->sliceKTile;
|
||||
Value slicedAcc = dotOp.getOperand(2);
|
||||
|
||||
auto firstOpToSliceA = getFirstOpToSlice(dotOp, 0);
|
||||
auto firstOpToSliceB = getFirstOpToSlice(dotOp, 1);
|
||||
|
||||
if (setLayoutForSlicing(firstOpToSliceA, /*operandIdx*/ 0)) {
|
||||
firstOpToSliceA = getFirstOpToSlice(dotOp, /*operandIdx*/ 0);
|
||||
}
|
||||
|
||||
if (setLayoutForSlicing(firstOpToSliceB, /*operandIdx*/ 1)) {
|
||||
firstOpToSliceB = getFirstOpToSlice(dotOp, /*operandIdx*/ 1);
|
||||
}
|
||||
|
||||
for (int i = 0; i < numSlices; i++) {
|
||||
auto slicedOperandA = getSlicedDotOperand(
|
||||
firstOpToSliceA, dotOp, 0, i, this->sliceKTile, builder, eraseOps);
|
||||
auto slicedOperandB = getSlicedDotOperand(
|
||||
firstOpToSliceB, dotOp, 1, i, this->sliceKTile, builder, eraseOps);
|
||||
|
||||
auto slicedDot = builder.create<tt::DotOp>(
|
||||
dotOp.getLoc(), dotResTy, slicedOperandA, slicedOperandB, slicedAcc,
|
||||
dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
|
||||
slicedAcc = slicedDot;
|
||||
}
|
||||
|
||||
eraseOps.push_back((Operation *)dotOp);
|
||||
dotOp.replaceAllUsesWith(slicedAcc);
|
||||
dotSlicingDCE(eraseOps);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/// Factory for the AMD GPU dot-slicing pass; `sliceKTile` is the slice width
/// along the k dimension (0 disables slicing).
std::unique_ptr<Pass> mlir::createTritonAMDGPUDotSlicingPass(int sliceKTile) {
  auto pass = std::make_unique<TritonAMDGPUDotSlicingPass>(sliceKTile);
  return pass;
}
|
||||
@@ -1861,6 +1861,10 @@ void init_triton_ir(py::module &&m) {
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createTritonGPUStreamPipelinePass());
|
||||
})
|
||||
.def("add_tritonamdgpu_dot_slicing_pass",
|
||||
[](mlir::PassManager &self, int slice_k_tile) {
|
||||
self.addPass(mlir::createTritonAMDGPUDotSlicingPass(slice_k_tile));
|
||||
})
|
||||
.def("add_tritongpu_prefetch_pass",
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createTritonGPUPrefetchPass());
|
||||
|
||||
@@ -100,7 +100,7 @@ def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, target):
|
||||
|
||||
|
||||
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization,
|
||||
enable_persistent, optimize_epilogue, matrix_inst_type):
|
||||
enable_persistent, optimize_epilogue, matrix_inst_type, slice_k_tile):
|
||||
is_cuda = _is_cuda(target)
|
||||
if is_cuda:
|
||||
capability = target.capability
|
||||
@@ -123,6 +123,7 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, e
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
if optimize_epilogue:
|
||||
pm.add_tritongpu_optimize_epilogue_pass()
|
||||
pm.add_tritonamdgpu_dot_slicing_pass(slice_k_tile)
|
||||
pm.add_tritongpu_optimize_dot_operands_pass()
|
||||
if num_stages == 0 and is_hip() and target["matrix_core_version"] != 0:
|
||||
pm.add_tritongpu_stream_pipeline_pass()
|
||||
@@ -273,6 +274,7 @@ def make_hash(fn, target, env_vars, device_backend, **kwargs):
|
||||
num_ctas = kwargs.get("num_ctas", 1)
|
||||
num_stages = kwargs.get("num_stages", 3)
|
||||
waves_per_eu = kwargs.get("waves_per_eu", 0)
|
||||
slice_k_tile = kwargs.get("slice_k_tile", 0)
|
||||
matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0);
|
||||
enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
|
||||
enable_persistent = kwargs.get("enable_persistent", False)
|
||||
@@ -282,7 +284,7 @@ def make_hash(fn, target, env_vars, device_backend, **kwargs):
|
||||
sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
|
||||
configs_key = [get_conf_key(conf) for conf in configs]
|
||||
env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
|
||||
key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
|
||||
key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{slice_k_tile}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
|
||||
return hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
assert isinstance(fn, str)
|
||||
ignore_version = kwargs.get('ignore_version', False)
|
||||
@@ -414,6 +416,7 @@ def compile(fn, **kwargs):
|
||||
num_ctas = kwargs.get("num_ctas", 1)
|
||||
num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, capability=capability))
|
||||
waves_per_eu = kwargs.get("waves_per_eu", 0)
|
||||
slice_k_tile = kwargs.get("slice_k_tile", 0)
|
||||
matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0)
|
||||
enable_fp_fusion = kwargs.get("enable_fp_fusion", True)
|
||||
# TODO[shuhaoj]: Default should be to enable warp specialization once possible
|
||||
@@ -453,7 +456,7 @@ def compile(fn, **kwargs):
|
||||
if is_cuda:
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir(
|
||||
ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info,
|
||||
enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
enable_warp_specialization, enable_persistent, optimize_epilogue, slice_k_tile))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos))
|
||||
add_cuda_stages(target, extern_libs, stages)
|
||||
@@ -472,12 +475,13 @@ def compile(fn, **kwargs):
|
||||
other["optimize_epilogue"] = optimize_epilogue
|
||||
other["tma_infos"] = tma_infos
|
||||
other["waves_per_eu"] = waves_per_eu
|
||||
other["slice_k_tile"] = slice_k_tile
|
||||
other["matrix_instr_nonkdim"] = matrix_instr_nonkdim
|
||||
|
||||
_device_backend.add_stages(target, extern_libs, stages, other)
|
||||
elif device_type == "xpu":
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, slice_k_tile))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
|
||||
_device_backend.add_stages(arch, extern_libs, stages)
|
||||
@@ -556,6 +560,7 @@ def compile(fn, **kwargs):
|
||||
"num_ctas": num_ctas,
|
||||
"num_stages": num_stages,
|
||||
"waves_per_eu": waves_per_eu,
|
||||
"slice_k_tile": slice_k_tile,
|
||||
"matrix_instr_nonkdim": matrix_instr_nonkdim,
|
||||
"enable_warp_specialization": enable_warp_specialization,
|
||||
"enable_persistent": enable_persistent,
|
||||
@@ -689,6 +694,7 @@ class CompiledKernel:
|
||||
self.num_ctas = metadata["num_ctas"]
|
||||
self.num_stages = metadata["num_stages"]
|
||||
self.waves_per_eu = metadata["waves_per_eu"]
|
||||
self.slice_k_tile = metadata["slice_k_tile"]
|
||||
self.clusterDims = metadata["clusterDims"]
|
||||
if "tensormaps_info" in metadata:
|
||||
self.tensormaps_info = metadata["tensormaps_info"]
|
||||
|
||||
@@ -351,6 +351,7 @@ class JITFunction(KernelInterface[T]):
|
||||
num_ctas,
|
||||
num_stages,
|
||||
waves_per_eu,
|
||||
slice_k_tile,
|
||||
matrix_instr_nonkdim,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
@@ -363,7 +364,7 @@ class JITFunction(KernelInterface[T]):
|
||||
name = self.fn.__name__
|
||||
module = self.fn.__module__
|
||||
arg_reprs = ', '.join([f'{param.name}: {ty}' for param, ty in zip(self.params, key[1])])
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs}), enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, slice_k_tile={slice_k_tile}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs}), enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
|
||||
key = str(key)
|
||||
|
||||
class LegacyCompiler:
|
||||
@@ -381,6 +382,7 @@ class JITFunction(KernelInterface[T]):
|
||||
num_ctas=num_ctas,
|
||||
num_stages=num_stages,
|
||||
waves_per_eu=waves_per_eu,
|
||||
slice_k_tile=slice_k_tile,
|
||||
enable_warp_specialization=enable_warp_specialization,
|
||||
enable_fp_fusion=enable_fp_fusion,
|
||||
extern_libs=extern_libs,
|
||||
@@ -427,6 +429,7 @@ class JITFunction(KernelInterface[T]):
|
||||
num_ctas = get_special_arg("num_ctas", 1)
|
||||
num_stages = get_special_arg("num_stages")
|
||||
waves_per_eu = get_special_arg("waves_per_eu", 0)
|
||||
slice_k_tile = get_special_arg("slice_k_tile", 0)
|
||||
matrix_instr_nonkdim = get_special_arg("matrix_instr_nonkdim", 0)
|
||||
enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
|
||||
enable_fp_fusion = get_special_arg("enable_fp_fusion", True)
|
||||
@@ -503,6 +506,7 @@ class JITFunction(KernelInterface[T]):
|
||||
num_ctas,
|
||||
num_stages,
|
||||
waves_per_eu,
|
||||
slice_k_tile,
|
||||
matrix_instr_nonkdim,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
@@ -539,6 +543,7 @@ class JITFunction(KernelInterface[T]):
|
||||
num_ctas,
|
||||
num_stages,
|
||||
waves_per_eu,
|
||||
slice_k_tile,
|
||||
matrix_instr_nonkdim,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
@@ -556,6 +561,7 @@ class JITFunction(KernelInterface[T]):
|
||||
num_ctas=num_ctas,
|
||||
num_stages=num_stages,
|
||||
waves_per_eu=waves_per_eu,
|
||||
slice_k_tile=slice_k_tile,
|
||||
matrix_instr_nonkdim=matrix_instr_nonkdim,
|
||||
enable_warp_specialization=enable_warp_specialization,
|
||||
enable_fp_fusion=enable_fp_fusion,
|
||||
|
||||
3
python/triton/third_party/hip/hip_backend.py
vendored
3
python/triton/third_party/hip/hip_backend.py
vendored
@@ -449,10 +449,11 @@ class HIPBackend(BaseBackend):
|
||||
optimize_epilogue = other["optimize_epilogue"]
|
||||
tma_infos = other["tma_infos"]
|
||||
waves_per_eu = other["waves_per_eu"]
|
||||
slice_k_tile = other["slice_k_tile"]
|
||||
matrix_instr_nonkdim = other["matrix_instr_nonkdim"]
|
||||
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_instr_nonkdim))
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_instr_nonkdim, slice_k_tile))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos, waves_per_eu))
|
||||
|
||||
|
||||
@@ -83,11 +83,13 @@ def _attn_fwd_inner(acc, l_i, m_i, q,
|
||||
# re-tuning.
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': True}, num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
|
||||
],
|
||||
key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
|
||||
)
|
||||
|
||||
209
test/TritonGPU/dot-slicing.mlir
Normal file
209
test/TritonGPU/dot-slicing.mlir
Normal file
@@ -0,0 +1,209 @@
|
||||
// RUN: triton-opt %s -split-input-file --tritonamdgpu-dot-slicing=slice-k-tile=32 | FileCheck %s
|
||||
|
||||
// CHECK: #[[SLICE_V_LAYOUT:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
// CHECK: #blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
// CHECK: #blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
|
||||
// CHECK: #[[SLICE_Q_LAYOUT:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
// CHECK: #[[SLICE_K_LAYOUT:.+]] = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
|
||||
// CHECK: %[[Q_VIEW_SLICE_1:.+]] = triton_gpu.view_slice %[[Q_PTR:.+]][0, 0] [128, 32] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_Q_LAYOUT]]> to tensor<128x32x!tt.ptr<f16, 1>, #[[SLICE_Q_LAYOUT]]>
|
||||
// CHECK: %[[LOAD_Q_1:.+]] = tt.load %[[Q_VIEW_SLICE_1]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #[[SLICE_Q_LAYOUT]]>
|
||||
// CHECK: %[[K_VIEW_SLICE_1:.+]] = triton_gpu.view_slice %[[K_PTR:.+]][0, 0] [32, 128] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_K_LAYOUT]]> to tensor<32x128x!tt.ptr<f16, 1>, #[[SLICE_K_LAYOUT]]>
|
||||
// CHECK: %[[LOAD_K_1:.+]] = tt.load %[[K_VIEW_SLICE_1]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #[[SLICE_K_LAYOUT]]>
|
||||
// CHECK: %[[QK_DOT_1:.+]] = tt.dot %[[QK_DOT_ARG_1:.+]], %[[QK_DOT_ARG_2:.+]], %[[QK_DOT_ARG_3:.+]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
|
||||
|
||||
// CHECK: %[[Q_VIEW_SLICE_2:.+]] = triton_gpu.view_slice %[[Q_PTR:.+]][0, 32] [128, 32] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_Q_LAYOUT]]> to tensor<128x32x!tt.ptr<f16, 1>, #[[SLICE_Q_LAYOUT]]>
|
||||
// CHECK: %[[LOAD_Q_2:.+]] = tt.load %[[Q_VIEW_SLICE_2]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #[[SLICE_Q_LAYOUT]]>
|
||||
// CHECK: %[[K_VIEW_SLICE_2:.+]] = triton_gpu.view_slice %[[K_PTR:.+]][32, 0] [32, 128] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_K_LAYOUT]]> to tensor<32x128x!tt.ptr<f16, 1>, #[[SLICE_K_LAYOUT]]>
|
||||
// CHECK: %[[LOAD_K_2:.+]] = tt.load %[[K_VIEW_SLICE_2]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #[[SLICE_K_LAYOUT]]>
|
||||
// CHECK: %[[QK_DOT_2:.+]] = tt.dot %[[QK_DOT_ARG_2:.+]], %[[QK_DOT_ARG_2:.+]], %[[QK_DOT_1]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
|
||||
|
||||
// CHECK: %[[Q_VIEW_SLICE_3:.+]] = triton_gpu.view_slice %[[Q_PTR:.+]][0, 64] [128, 32] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_Q_LAYOUT]]> to tensor<128x32x!tt.ptr<f16, 1>, #[[SLICE_Q_LAYOUT]]>
|
||||
// CHECK: %[[LOAD_Q_3:.+]] = tt.load %[[Q_VIEW_SLICE_3]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #[[SLICE_Q_LAYOUT]]>
|
||||
// CHECK: %[[K_VIEW_SLICE_3:.+]] = triton_gpu.view_slice %[[K_PTR:.+]][64, 0] [32, 128] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_K_LAYOUT]]> to tensor<32x128x!tt.ptr<f16, 1>, #[[SLICE_K_LAYOUT]]>
|
||||
// CHECK: %[[LOAD_K_3:.+]] = tt.load %[[K_VIEW_SLICE_3]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #[[SLICE_K_LAYOUT]]>
|
||||
// CHECK: %[[QK_DOT_3:.+]] = tt.dot %[[QK_DOT_ARG_3:.+]], %[[QK_DOT_ARG_3:.+]], %[[QK_DOT_2]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
|
||||
|
||||
// CHECK: %[[Q_VIEW_SLICE_4:.+]] = triton_gpu.view_slice %[[Q_PTR:.+]][0, 96] [128, 32] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_Q_LAYOUT]]> to tensor<128x32x!tt.ptr<f16, 1>, #[[SLICE_Q_LAYOUT]]>
|
||||
// CHECK: %[[LOAD_Q_4:.+]] = tt.load %[[Q_VIEW_SLICE_4]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #[[SLICE_Q_LAYOUT]]>
|
||||
// CHECK: %[[K_VIEW_SLICE_4:.+]] = triton_gpu.view_slice %[[K_PTR:.+]][96, 0] [32, 128] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_K_LAYOUT]]> to tensor<32x128x!tt.ptr<f16, 1>, #[[SLICE_K_LAYOUT]]>
|
||||
// CHECK: %[[LOAD_K_4:.+]] = tt.load %[[K_VIEW_SLICE_4]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #[[SLICE_K_LAYOUT]]>
|
||||
// CHECK: %[[QK_DOT_4:.+]] = tt.dot %[[QK_DOT_ARG_4:.+]], %[[QK_DOT_ARG_4:.+]], %[[QK_DOT_3]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
|
||||
|
||||
// CHECK: %[[QK_VIEW_SLICE_1:.+]] = triton_gpu.view_slice %[[QK_TENSOR:.+]][0, 0] [128, 32] [1, 1] : tensor<128x128xf16, #mfma> to tensor<128x32xf16, #mfma>
|
||||
// CHECK: %[[QK_DOT_OP_1:.+]] = triton_gpu.convert_layout %[[QK_VIEW_SLICE_1]] : (tensor<128x32xf16, #mfma>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
|
||||
// CHECK: %[[V_VIEW_SLICE_1:.+]] = triton_gpu.view_slice %[[V_TENSOR_PTR:.+]][0, 0] [32, 128] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_V_LAYOUT]]> to tensor<32x128x!tt.ptr<f16, 1>, #[[SLICE_V_LAYOUT]]>
|
||||
// CHECK: %[[LOAD_V_1:.+]] = tt.load %[[V_VIEW_SLICE_1]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #[[SLICE_V_LAYOUT]]>
|
||||
// CHECK: %[[V_DOT_OP_1:.+]] = triton_gpu.convert_layout %[[LOAD_V_1]] : (tensor<32x128xf16, #[[SLICE_V_LAYOUT]]>) -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
|
||||
// CHECK: %[[QKV_DOT_1:.+]] = tt.dot %[[QK_DOT_OP_1]], %[[V_DOT_OP_1]], %[[QKV_DOT_ACC_1:.+]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
|
||||
|
||||
// CHECK: %[[QK_VIEW_SLICE_2:.+]] = triton_gpu.view_slice %[[QK_TENSOR:.+]][0, 32] [128, 32] [1, 1] : tensor<128x128xf16, #mfma> to tensor<128x32xf16, #mfma>
|
||||
// CHECK: %[[QK_DOT_OP_2:.+]] = triton_gpu.convert_layout %[[QK_VIEW_SLICE_2]] : (tensor<128x32xf16, #mfma>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
|
||||
// CHECK: %[[V_VIEW_SLICE_2:.+]] = triton_gpu.view_slice %[[V_TENSOR_PTR:.+]][32, 0] [32, 128] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_V_LAYOUT]]> to tensor<32x128x!tt.ptr<f16, 1>, #[[SLICE_V_LAYOUT]]>
|
||||
// CHECK: %[[LOAD_V_2:.+]] = tt.load %[[V_VIEW_SLICE_2]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #[[SLICE_V_LAYOUT]]>
|
||||
// CHECK: %[[V_DOT_OP_2:.+]] = triton_gpu.convert_layout %[[LOAD_V_2]] : (tensor<32x128xf16, #[[SLICE_V_LAYOUT]]>) -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
|
||||
// CHECK: %[[QKV_DOT_2:.+]] = tt.dot %[[QK_DOT_OP_2]], %[[V_DOT_OP_2]], %[[QKV_DOT_1]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
|
||||
|
||||
// CHECK: %[[QK_VIEW_SLICE_3:.+]] = triton_gpu.view_slice %[[QK_TENSOR:.+]][0, 64] [128, 32] [1, 1] : tensor<128x128xf16, #mfma> to tensor<128x32xf16, #mfma>
|
||||
// CHECK: %[[QK_DOT_OP_3:.+]] = triton_gpu.convert_layout %[[QK_VIEW_SLICE_3]] : (tensor<128x32xf16, #mfma>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
|
||||
// CHECK: %[[V_VIEW_SLICE_3:.+]] = triton_gpu.view_slice %[[V_TENSOR_PTR:.+]][64, 0] [32, 128] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_V_LAYOUT]]> to tensor<32x128x!tt.ptr<f16, 1>, #[[SLICE_V_LAYOUT]]>
|
||||
// CHECK: %[[LOAD_V_3:.+]] = tt.load %[[V_VIEW_SLICE_3]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #[[SLICE_V_LAYOUT]]>
|
||||
// CHECK: %[[V_DOT_OP_3:.+]] = triton_gpu.convert_layout %[[LOAD_V_3]] : (tensor<32x128xf16, #[[SLICE_V_LAYOUT]]>) -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
|
||||
// CHECK: %[[QKV_DOT_3:.+]] = tt.dot %[[QK_DOT_OP_3]], %[[V_DOT_OP_3]], %[[QKV_DOT_2]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
|
||||
|
||||
// CHECK: %[[QK_VIEW_SLICE_4:.+]] = triton_gpu.view_slice %[[QK_TENSOR:.+]][0, 96] [128, 32] [1, 1] : tensor<128x128xf16, #mfma> to tensor<128x32xf16, #mfma>
|
||||
// CHECK: %[[QK_DOT_OP_4:.+]] = triton_gpu.convert_layout %[[QK_VIEW_SLICE_4]] : (tensor<128x32xf16, #mfma>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
|
||||
// CHECK: %[[V_VIEW_SLICE_4:.+]] = triton_gpu.view_slice %[[V_TENSOR_PTR:.+]][96, 0] [32, 128] [1, 1] : tensor<128x128x!tt.ptr<f16, 1>, #[[SLICE_V_LAYOUT]]> to tensor<32x128x!tt.ptr<f16, 1>, #[[SLICE_V_LAYOUT]]>
|
||||
// CHECK: %[[LOAD_V_4:.+]] = tt.load %[[V_VIEW_SLICE_4]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #[[SLICE_V_LAYOUT]]>
|
||||
// CHECK: %[[V_DOT_OP_4:.+]] = triton_gpu.convert_layout %[[LOAD_V_4]] : (tensor<32x128xf16, #[[SLICE_V_LAYOUT]]>) -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
|
||||
// CHECK: %[[QKV_DOT_4:.+]] = tt.dot %[[QK_DOT_OP_4]], %[[V_DOT_OP_4]], %[[QKV_DOT_3]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
|
||||
#mfma = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = true}>
|
||||
module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
|
||||
tt.func public @_attn_fwd_0d1d2d34d5d6de7de8de9c10de11de12de13c14de15de16de17c18de19de20de21c2223de24de(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg12: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
|
||||
%cst = arith.constant dense<1.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
|
||||
%cst_0 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
|
||||
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mfma>
|
||||
%c0_i64 = arith.constant 0 : i64
|
||||
%c128_i64 = arith.constant 128 : i64
|
||||
%cst_2 = arith.constant 1.44269502 : f32
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%c128_i32 = arith.constant 128 : i32
|
||||
%0 = tt.get_program_id x : i32
|
||||
%1 = tt.get_program_id y : i32
|
||||
%2 = arith.muli %1, %arg7 : i32
|
||||
%3 = tt.addptr %arg0, %2 : !tt.ptr<f16, 1>, i32
|
||||
%4 = arith.muli %0, %c128_i32 : i32
|
||||
%5 = arith.extsi %arg8 : i32 to i64
|
||||
%6 = arith.extsi %4 : i32 to i64
|
||||
%7 = tt.addptr %arg2, %2 : !tt.ptr<f16, 1>, i32
|
||||
%8 = arith.extsi %arg14 : i32 to i64
|
||||
%9 = tt.addptr %arg1, %2 : !tt.ptr<f16, 1>, i32
|
||||
%10 = arith.extsi %arg11 : i32 to i64
|
||||
%11 = tt.addptr %arg5, %2 : !tt.ptr<f16, 1>, i32
|
||||
%12 = arith.extsi %arg17 : i32 to i64
|
||||
%13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked2>
|
||||
%22 = tt.splat %4 : (i32) -> tensor<128xi32, #blocked2>
|
||||
%23 = arith.addi %22, %21 : tensor<128xi32, #blocked2>
|
||||
%24 = arith.mulf %arg3, %cst_2 : f32
|
||||
%25 = tt.splat %6 : (i64) -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%26 = tt.splat %6 : (i64) -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%27 = arith.extsi %13 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%28 = arith.extsi %14 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%29 = arith.extsi %15 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%30 = arith.extsi %16 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%31 = arith.extsi %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
|
||||
%32 = arith.extsi %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%33 = arith.extsi %19 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%34 = arith.extsi %20 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%35 = arith.addi %25, %27 : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%36 = arith.addi %26, %28 : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%37 = tt.expand_dims %35 {axis = 1 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi64, #blocked>
|
||||
%38 = tt.expand_dims %36 {axis = 1 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi64, #blocked>
|
||||
%39 = tt.splat %5 : (i64) -> tensor<128x1xi64, #blocked>
|
||||
%40 = arith.muli %37, %39 : tensor<128x1xi64, #blocked>
|
||||
%41 = tt.splat %3 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked>
|
||||
%42 = tt.addptr %41, %40 : tensor<128x1x!tt.ptr<f16, 1>, #blocked>, tensor<128x1xi64, #blocked>
|
||||
%43 = tt.broadcast %42 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<128x128x!tt.ptr<f16, 1>, #blocked>
|
||||
%44 = tt.expand_dims %29 {axis = 0 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi64, #blocked>
|
||||
%45 = tt.expand_dims %30 {axis = 0 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi64, #blocked>
|
||||
%46 = tt.expand_dims %31 {axis = 0 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi64, #blocked>
|
||||
%47 = tt.broadcast %44 : (tensor<1x128xi64, #blocked>) -> tensor<128x128xi64, #blocked>
|
||||
%48 = tt.broadcast %45 : (tensor<1x128xi64, #blocked>) -> tensor<128x128xi64, #blocked>
|
||||
%49 = tt.broadcast %46 : (tensor<1x128xi64, #blocked>) -> tensor<128x128xi64, #blocked>
|
||||
%50 = tt.addptr %43, %47 : tensor<128x128x!tt.ptr<f16, 1>, #blocked>, tensor<128x128xi64, #blocked>
|
||||
%51 = tt.load %50 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #blocked>
|
||||
%52 = tt.splat %24 : (f32) -> tensor<128x128xf32, #blocked>
|
||||
%53 = arith.extf %51 : tensor<128x128xf16, #blocked> to tensor<128x128xf32, #blocked>
|
||||
%54 = arith.mulf %53, %52 : tensor<128x128xf32, #blocked>
|
||||
%55 = arith.truncf %54 : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked>
|
||||
%56 = tt.expand_dims %32 {axis = 1 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi64, #blocked1>
|
||||
%57 = tt.splat %9 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked1>
|
||||
%58 = tt.addptr %57, %56 : tensor<128x1x!tt.ptr<f16, 1>, #blocked1>, tensor<128x1xi64, #blocked1>
|
||||
%59 = tt.broadcast %58 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked1>) -> tensor<128x128x!tt.ptr<f16, 1>, #blocked1>
|
||||
%60 = tt.splat %10 : (i64) -> tensor<1x128xi64, #blocked1>
|
||||
%61 = tt.splat %8 : (i64) -> tensor<128x1xi64, #blocked>
|
||||
%62 = tt.splat %7 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked>
|
||||
%63:5 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg22 = %cst_1, %arg23 = %cst, %arg24 = %cst_0, %arg25 = %c0_i64, %arg26 = %c0_i64) -> (tensor<128x128xf32, #mfma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>, i64, i64) : i32 {
|
||||
%82 = tt.splat %arg26 : (i64) -> tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%83 = arith.addi %82, %33 : tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%84 = tt.expand_dims %83 {axis = 0 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi64, #blocked1>
|
||||
%85 = arith.muli %84, %60 : tensor<1x128xi64, #blocked1>
|
||||
%86 = tt.broadcast %85 : (tensor<1x128xi64, #blocked1>) -> tensor<128x128xi64, #blocked1>
|
||||
%87 = tt.addptr %59, %86 : tensor<128x128x!tt.ptr<f16, 1>, #blocked1>, tensor<128x128xi64, #blocked1>
|
||||
%88 = tt.load %87 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #blocked1>
|
||||
%89 = triton_gpu.convert_layout %55 : (tensor<128x128xf16, #blocked>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
|
||||
%90 = triton_gpu.convert_layout %88 : (tensor<128x128xf16, #blocked1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
|
||||
%91 = tt.dot %89, %90, %cst_1 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
|
||||
%92 = "tt.reduce"(%91) <{axis = 1 : i32}> ({
|
||||
^bb0(%arg27: f32, %arg28: f32):
|
||||
%120 = arith.maximumf %arg27, %arg28 : f32
|
||||
tt.reduce.return %120 : f32
|
||||
}) : (tensor<128x128xf32, #mfma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
|
||||
%93 = arith.maximumf %arg24, %92 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
|
||||
%94 = tt.expand_dims %93 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>) -> tensor<128x1xf32, #mfma>
|
||||
%95 = tt.broadcast %94 : (tensor<128x1xf32, #mfma>) -> tensor<128x128xf32, #mfma>
|
||||
%96 = arith.subf %91, %95 : tensor<128x128xf32, #mfma>
|
||||
%97 = tt.extern_elementwise %96 {libname = "libdevice", libpath = "/triton/python/triton/language/../third_party/hip/lib/bitcode/cuda2gcn.bc", pure = true, symbol = "__nv_exp2f"} : (tensor<128x128xf32, #mfma>) -> tensor<128x128xf32, #mfma>
|
||||
%98 = arith.subf %arg24, %93 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
|
||||
%99 = tt.extern_elementwise %98 {libname = "libdevice", libpath = "/triton/python/triton/language/../third_party/hip/lib/bitcode/cuda2gcn.bc", pure = true, symbol = "__nv_exp2f"} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
|
||||
%100 = tt.expand_dims %99 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>) -> tensor<128x1xf32, #mfma>
|
||||
%101 = tt.broadcast %100 : (tensor<128x1xf32, #mfma>) -> tensor<128x128xf32, #mfma>
|
||||
%102 = arith.mulf %arg22, %101 : tensor<128x128xf32, #mfma>
|
||||
%103 = tt.splat %arg25 : (i64) -> tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%104 = arith.addi %103, %34 : tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%105 = tt.expand_dims %104 {axis = 1 : i32} : (tensor<128xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128x1xi64, #blocked>
|
||||
%106 = arith.muli %105, %61 : tensor<128x1xi64, #blocked>
|
||||
%107 = tt.addptr %62, %106 : tensor<128x1x!tt.ptr<f16, 1>, #blocked>, tensor<128x1xi64, #blocked>
|
||||
%108 = tt.broadcast %107 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<128x128x!tt.ptr<f16, 1>, #blocked>
|
||||
%109 = tt.addptr %108, %48 : tensor<128x128x!tt.ptr<f16, 1>, #blocked>, tensor<128x128xi64, #blocked>
|
||||
%110 = tt.load %109 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #blocked>
|
||||
%111 = arith.truncf %97 : tensor<128x128xf32, #mfma> to tensor<128x128xf16, #mfma>
|
||||
%112 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #mfma>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
|
||||
%113 = triton_gpu.convert_layout %110 : (tensor<128x128xf16, #blocked>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
|
||||
%114 = tt.dot %112, %113, %102 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<128x128xf32, #mfma>
|
||||
%115 = "tt.reduce"(%97) <{axis = 1 : i32}> ({
|
||||
^bb0(%arg27: f32, %arg28: f32):
|
||||
%120 = arith.addf %arg27, %arg28 : f32
|
||||
tt.reduce.return %120 : f32
|
||||
}) : (tensor<128x128xf32, #mfma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
|
||||
%116 = arith.mulf %arg23, %99 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
|
||||
%117 = arith.addf %116, %115 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
|
||||
%118 = arith.addi %arg25, %c128_i64 : i64
|
||||
%119 = arith.addi %arg26, %c128_i64 : i64
|
||||
scf.yield %114, %117, %93, %118, %119 : tensor<128x128xf32, #mfma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>, i64, i64
|
||||
}
|
||||
%64 = tt.expand_dims %63#1 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>) -> tensor<128x1xf32, #mfma>
|
||||
%65 = tt.broadcast %64 : (tensor<128x1xf32, #mfma>) -> tensor<128x128xf32, #mfma>
|
||||
%66 = arith.divf %63#0, %65 : tensor<128x128xf32, #mfma>
|
||||
%67 = arith.muli %1, %arg20 : i32
|
||||
%68 = tt.addptr %arg4, %67 : !tt.ptr<f32, 1>, i32
|
||||
%69 = tt.splat %68 : (!tt.ptr<f32, 1>) -> tensor<128x!tt.ptr<f32, 1>, #blocked2>
|
||||
%70 = tt.addptr %69, %23 : tensor<128x!tt.ptr<f32, 1>, #blocked2>, tensor<128xi32, #blocked2>
|
||||
%71 = tt.extern_elementwise %63#1 {libname = "libdevice", libpath = "/triton/python/triton/language/../third_party/hip/lib/bitcode/cuda2gcn.bc", pure = true, symbol = "__nv_log2f"} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
|
||||
%72 = arith.addf %63#2, %71 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>
|
||||
%73 = triton_gpu.convert_layout %72 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mfma}>>) -> tensor<128xf32, #blocked2>
|
||||
tt.store %70, %73 {cache = 1 : i32, evict = 1 : i32} : tensor<128xf32, #blocked2>
|
||||
%74 = arith.truncf %66 : tensor<128x128xf32, #mfma> to tensor<128x128xf16, #mfma>
|
||||
%75 = tt.splat %12 : (i64) -> tensor<128x1xi64, #blocked>
|
||||
%76 = arith.muli %38, %75 : tensor<128x1xi64, #blocked>
|
||||
%77 = tt.splat %11 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked>
|
||||
%78 = tt.addptr %77, %76 : tensor<128x1x!tt.ptr<f16, 1>, #blocked>, tensor<128x1xi64, #blocked>
|
||||
%79 = tt.broadcast %78 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked>) -> tensor<128x128x!tt.ptr<f16, 1>, #blocked>
|
||||
%80 = tt.addptr %79, %49 : tensor<128x128x!tt.ptr<f16, 1>, #blocked>, tensor<128x128xi64, #blocked>
|
||||
%81 = triton_gpu.convert_layout %80 : (tensor<128x128x!tt.ptr<f16, 1>, #blocked>) -> tensor<128x128x!tt.ptr<f16, 1>, #mfma>
|
||||
tt.store %81, %74 {cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #mfma>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user