[OPTIMIZER] cleaned, renamed and simplified some optimization passes (#1232)

This shouldn't actually change the behavior of Triton -- only clean things up.
Philippe Tillet
2023-02-22 13:54:55 -08:00
committed by GitHub
parent ba0198326e
commit 0ec277efc5
18 changed files with 599 additions and 652 deletions

View File

@@ -25,6 +25,8 @@ def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
let results = (outs TT_Tensor:$result);
let hasCanonicalizeMethod = 1;
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
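An assumption about standard MLIR ODS behavior, not something this diff spells out: `hasCanonicalizeMethod = 1` makes TableGen declare a static hook on the generated op class, which this commit now implements in Dialect.cpp (see below).
// With `hasCanonicalizeMethod = 1`, TableGen emits (sketch):
//   static mlir::LogicalResult canonicalize(ConvertLayoutOp op,
//                                           mlir::PatternRewriter &rewriter);
// on ConvertLayoutOp, and getCanonicalizationPatterns() wraps this hook in a
// RewritePattern so the generic canonicalizer pass picks it up.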

View File

@@ -6,7 +6,9 @@
namespace mlir {
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
// TODO(Keren): prefetch pass not working yet
std::unique_ptr<Pass>
createTritonGPUAccelerateMatmulPass(int computeCapability = 80);
std::unique_ptr<Pass> createTritonGPUPrefetchPass();
std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
@@ -17,10 +19,12 @@ std::unique_ptr<Pass> createTritonGPUReorderInstructionsPass();
std::unique_ptr<Pass> createTritonGPUDecomposeConversionsPass();
std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
std::unique_ptr<Pass> createTritonGPURemoveLayoutConversionsPass();
std::unique_ptr<Pass> createTritonGPUVerifier();
std::unique_ptr<Pass> createTritonGPUFuseTranspositionsPass();
std::unique_ptr<Pass> createTritonGPUUpdateMmaForVoltaPass();
/// Generate the code for registering passes.

View File

@@ -7,7 +7,8 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
let summary = "pipeline";
let description = [{
Unroll loops to hide global memory -> shared memory latency.
Replace `LoadOp`s in loops with `InsertSliceAsyncOp` instructions that asynchronously fetch the data
needed by the next iteration.
}];
let constructor = "mlir::createTritonGPUPipelinePass()";
@@ -27,7 +28,8 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
let summary = "prefetch";
let description = [{
Prefetch operands (a and b) of tt.dot into shared memory to hide shared memory -> register latency.
Decompose `DotOp` instructions in loops into several finer-grained `DotOp`s
whose operands can be constructed at the end of the previous iteration.
}];
let constructor = "mlir::createTritonGPUPrefetchPass()";
@@ -37,6 +39,41 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
"mlir::arith::ArithDialect"];
}
def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> {
let summary = "accelerate matmul";
let description = [{
Optimize the input/output layouts of `dot` instructions to make them compatible with hardware accelerators
(e.g., Nvidia tensor cores)
}];
let constructor = "mlir::createTritonGPUAccelerateMatmulPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
def TritonGPUFuseTranspositions : Pass<"tritongpu-fuse-transposition", "mlir::ModuleOp"> {
let summary = "fuse transpositions";
let description = [{
Re-arrange the layouts of tensors used as matrix multiplication operands so as to promote the use of
hardware-accelerated transpositions.
}];
let constructor = "mlir::createTritonGPUFuseTranspositionsPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
}
def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
let summary = "coalesce";
@@ -49,26 +86,16 @@ def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
let summary = "combine triton gpu ops";
def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions", "mlir::ModuleOp"> {
let summary = "remove superfluous layout conversions";
let description = [{
convert_layout(convert_layout(%src, #LAYOUT_0), #LAYOUT_1) =>
convert_layout(%src, #LAYOUT_1)
convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT
}];
let constructor = "mlir::createTritonGPUCombineOpsPass()";
let constructor = "mlir::createTritonGPURemoveLayoutConversionsPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
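A hand-written instance of the first rewrite above (the shared-encoding guards implemented in the canonicalizer still apply):
// %y = convert_layout (convert_layout %x : #blocked -> #shared) : #shared -> #mma
// simplifies to
// %y = convert_layout %x : #blocked -> #mma
// and convert_layout %x : #L -> #L simplifies to just %x.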
def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> {
@@ -95,19 +122,6 @@ def TritonGPUDecomposeConversions: Pass<"tritongpu-decompose-conversions", "mlir
"mlir::triton::TritonDialect"];
}
def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
let summary = "canonicalize scf.ForOp ops";
let description = [{
This implements some optimizations that are missing in the standard scf.ForOp
canonicalizer.
}];
let constructor = "mlir::createTritonGPUCanonicalizeLoopsPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
def UpdateMmaForVolta : Pass<"tritongpu-update-mma-for-volta", "mlir::ModuleOp"> {
let summary = "Update mma encodings for Volta";

View File

@@ -790,6 +790,141 @@ struct TritonGPUInferLayoutInterface
}
};
//===----------------------------------------------------------------------===//
// Canonicalizer
//===----------------------------------------------------------------------===//
LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
PatternRewriter &rewriter) {
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
auto srcType = op.getOperand().getType().cast<RankedTensorType>();
auto dstType = op.getType().cast<RankedTensorType>();
if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>() &&
srcType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return mlir::failure();
// convert to the same layout -- we can delete
if (op->getResultTypes() == op->getOperandTypes()) {
rewriter.replaceOp(op, op->getOperands());
return mlir::success();
}
Operation *arg = op->getOperand(0).getDefiningOp();
// block argument
if (!arg)
return mlir::failure();
// cvt(view) -> view
if (auto view = dyn_cast<triton::ViewOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::ViewOp>(op, op->getResult(0).getType(),
view.getResult());
return mlir::success();
}
// cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
if (alloc_tensor) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
op, op->getResult(0).getType());
return mlir::success();
}
// cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
if (insert_slice) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
// Ensure that the new insert_slice op is placed in the same place as the
// old insert_slice op. Otherwise, the new insert_slice op may be placed
// after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(insert_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, insert_slice.getDst());
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
op, newType, insert_slice.getSrc(), newArg.getResult(),
insert_slice.getIndex(), insert_slice.getMask(),
insert_slice.getOther(), insert_slice.getCache(),
insert_slice.getEvict(), insert_slice.getIsVolatile(),
insert_slice.getAxis());
return mlir::success();
}
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
if (extract_slice) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
auto origType =
extract_slice.getSource().getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(),
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
auto origResType = op->getResult(0).getType().cast<RankedTensorType>();
auto resType = RankedTensorType::get(
origResType.getShape(), origResType.getElementType(),
extract_slice.getType().cast<RankedTensorType>().getEncoding());
// Ensure that the new extract_slice op is placed in the same place as the
// old extract_slice op. Otherwise, the new extract_slice op may be placed
// after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(extract_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, extract_slice.getSource());
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
op, resType, newArg.getResult(), extract_slice.offsets(),
extract_slice.sizes(), extract_slice.strides(),
extract_slice.static_offsets(), extract_slice.static_sizes(),
extract_slice.static_strides());
return mlir::success();
}
// cvt(cvt(x, type1), type2) -> cvt(x, type2)
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
if (arg->getOperand(0).getDefiningOp() &&
!isSharedEncoding(arg->getOperand(0)) &&
isSharedEncoding(op.getOperand()) &&
!isSharedEncoding(op.getResult())) {
return mlir::failure();
}
if (isSharedEncoding(op.getOperand()) && isSharedEncoding(op.getResult())) {
return mlir::failure();
}
auto srcType = op.getOperand().getType().cast<RankedTensorType>();
auto srcShared =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
if (srcShared && srcShared.getVec() > 1)
return mlir::failure();
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, op->getResultTypes().front(), arg->getOperand(0));
return mlir::success();
}
// cvt(type1, splat(type2, x)) -> splat(type1, x)
if (auto splat = llvm::dyn_cast<triton::SplatOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::SplatOp>(op, op->getResultTypes(),
splat.getSrc());
return mlir::success();
}
// cvt(type1, make_range(type2, x)) -> make_range(type1, x)
if (auto range = llvm::dyn_cast<triton::MakeRangeOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
op, op->getResultTypes(), range.getStart(), range.getEnd());
return mlir::success();
}
// cvt(type, constant) -> constant
if (auto cst = llvm::dyn_cast<arith::ConstantOp>(arg))
if (auto ret = cst.getValue().dyn_cast<SplatElementsAttr>()) {
auto newRet = SplatElementsAttr::get(op->getResultTypes().front(),
ret.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newRet);
return mlir::success();
}
return mlir::failure();
}
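One concrete instance of the last rule, written out by hand (the layouts are illustrative):
// cvt(type, constant) -> constant, e.g.:
//   %cst = arith.constant dense<0.0> : tensor<128x64xf32, #blocked>
//   %y   = triton_gpu.convert_layout %cst : #blocked -> #mma
// folds to a single constant carrying the target layout:
//   %y   = arith.constant dense<0.0> : tensor<128x64xf32, #mma>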
//===----------------------------------------------------------------------===//
void TritonGPUDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST

View File

@@ -0,0 +1,212 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <memory>
using namespace mlir;
namespace {
using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
using triton::gpu::SliceEncodingAttr;
int computeCapabilityToMMAVersion(int computeCapability) {
if (computeCapability < 70) {
return 0;
} else if (computeCapability < 80) {
return 1;
} else if (computeCapability < 90) {
return 2;
} else if (computeCapability < 100) {
// FIXME: temporarily add this to pass unit tests
return 2;
} else {
assert(false && "computeCapability > 100 not supported");
return 3;
}
}
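Read off the branches above (the sm_90 case is temporary, per the FIXME):
// computeCapabilityToMMAVersion(60) == 0  // pre-Volta: no MMA, plain FMA
// computeCapabilityToMMAVersion(70) == 1  // MMAv1
// computeCapabilityToMMAVersion(80) == 2  // MMAv2
// computeCapabilityToMMAVersion(90) == 2  // temporary, see FIXME above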
SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
if (version == 1)
return {16, 16};
else if (version == 2)
return {16, 8};
else {
assert(false && "version not supported");
return {0, 0};
}
}
SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
int numWarps) {
// Set a default value that ensures product of wpt equals numWarps
return {static_cast<unsigned>(numWarps), 1};
}
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
if (llvm::find_if(slices, [](Operation *op) {
return isa<triton::DotOp>(op);
}) != slices.end())
return {(unsigned)numWarps, 1};
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
bool changed = false;
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
do {
changed = false;
if (ret[0] * ret[1] >= numWarps)
break;
if (shape[0] / shapePerWarp[0] / ret[0] >=
shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
if (ret[0] < shape[0] / shapePerWarp[0]) {
ret[0] *= 2;
} else
ret[1] *= 2;
} else {
ret[1] *= 2;
}
} while (true);
return ret;
}
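A hand-traced run of warpsPerTileV2 (values computed by hand from the loop above, not taken from the commit):
// shape = {128, 256}, numWarps = 8, shapePerWarp = {16, 8}:
//   ret evolves {1,1} -> {1,2} -> {2,2} -> {2,4}; stops once 2 * 4 >= numWarps.
// i.e. the 8 warps tile the 128x256 output as 2 warps along M and 4 along N.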
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding
public:
BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
computeCapability(computeCapability) {}
static SmallVector<unsigned, 2> getWarpsPerTile(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int version, int numWarps) {
switch (version) {
case 1:
return warpsPerTileV1(shape, numWarps);
case 2:
return warpsPerTileV2(dotOp, shape, numWarps);
default:
llvm_unreachable("unsupported MMA version");
}
}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dotOp = cast<triton::DotOp>(op);
// TODO: Check data-types and SM compatibility
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
if (!oldRetType.getEncoding() ||
oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return failure();
// for FMA, should retain the blocked layout.
int versionMajor = computeCapabilityToMMAVersion(computeCapability);
if (!supportMMA(dotOp, versionMajor))
return failure();
// get MMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
auto warpsPerTile =
getWarpsPerTile(dotOp, retShape, versionMajor, numWarps);
triton::gpu::MmaEncodingAttr mmaEnc;
if (versionMajor == 1) {
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, numWarps, mmaV1Counter++);
} else if (versionMajor == 2) {
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, 0 /*versionMinor*/,
warpsPerTile);
} else {
llvm_unreachable("Mma layout only supports versionMajor in {1, 2}");
}
auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(), mmaEnc);
// convert accumulator
auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
oldAcc.getLoc(), newRetType, oldAcc);
Value a = dotOp.getA();
Value b = dotOp.getB();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto oldAOrder = oldAType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
auto oldBOrder = oldBType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
Attribute isMMAv1RowA;
Attribute isMMAv1RowB;
if (versionMajor == 1) {
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(
oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(
oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32());
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, oldRetType, newDot.getResult());
return success();
}
};
} // namespace
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
class TritonGPUAccelerateMatmulPass
: public TritonGPUAccelerateMatmulBase<TritonGPUAccelerateMatmulPass> {
public:
TritonGPUAccelerateMatmulPass() = default;
TritonGPUAccelerateMatmulPass(int computeCapability) {
this->computeCapability = computeCapability;
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();
mlir::RewritePatternSet patterns(context);
patterns.add<::BlockedToMMA>(context, computeCapability);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();
}
}
};
std::unique_ptr<Pass>
mlir::createTritonGPUAccelerateMatmulPass(int computeCapability) {
return std::make_unique<TritonGPUAccelerateMatmulPass>(computeCapability);
}
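The net effect of BlockedToMMA on the IR, sketched as comments (the layout names are illustrative):
// before: %d = tt.dot %a, %b, %c                 // result in #blocked
// after:  %c2 = convert_layout %c  : #blocked -> #mma
//         %a2 = convert_layout %a  : #dot_operand<parent=#blocked>
//                                      -> #dot_operand<parent=#mma>
//         %b2 = convert_layout %b  : likewise, with opIdx = 1
//         %d2 = tt.dot %a2, %b2, %c2             // result in #mma
//         %d  = convert_layout %d2 : #mma -> #blocked  // original type kept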

View File

@@ -1,22 +1,18 @@
set(LLVM_TARGET_DEFINITIONS Combine.td)
mlir_tablegen(TritonGPUCombine.inc -gen-rewriters)
add_public_tablegen_target(TritonGPUCombineIncGen)
add_mlir_dialect_library(TritonGPUTransforms
AccelerateMatmul.cpp
Coalesce.cpp
CanonicalizeLoops.cpp
Combine.cpp
DecomposeConversions.cpp
FuseTranspositions.cpp
Pipeline.cpp
Prefetch.cpp
RemoveLayoutConversions.cpp
ReorderInstructions.cpp
DecomposeConversions.cpp
TritonGPUConversion.cpp
UpdateMmaForVolta.cpp
Utility.cpp
DEPENDS
TritonGPUTransformsIncGen
TritonGPUCombineIncGen
LINK_LIBS PUBLIC
TritonIR

View File

@@ -1,55 +0,0 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
namespace {
struct CanonicalizePass
: public TritonGPUCanonicalizeLoopsBase<CanonicalizePass> {
CanonicalizePass() = default;
void runOnOperation() override {
// The canonicalize pass may have created dead code that
// standard scf.for canonicalization (as of LLVM 14) cannot
// handle -- for example, iteration arguments that carry the
// pointers of synchronous loads whose results are
// discarded.
// The following piece of code is a workaround to
// very crudely remove dead code, by making an iteration
// argument yield itself if it is not used to create
// side effects anywhere.
getOperation()->walk([&](scf::ForOp forOp) -> void {
for (size_t i = 0; i < forOp.getNumResults(); ++i) {
// condition 1: no other iter arguments depend on it
SetVector<Operation *> fwdSlice;
mlir::getForwardSlice(forOp.getRegionIterArgs()[i], &fwdSlice);
Operation *yieldOp = forOp.getBody()->getTerminator();
bool noOtherDependency = std::all_of(
yieldOp->operand_begin(), yieldOp->operand_end(), [&](Value arg) {
return arg == yieldOp->getOperand(i) ||
!fwdSlice.contains(arg.getDefiningOp());
});
// condition 2: final value is not used after the loop
auto retVal = forOp.getResult(i);
bool noUserAfterLoop = retVal.getUsers().empty();
// yielding the region iter arg will cause loop canonicalization
// to clean up the dead code
if (noOtherDependency && noUserAfterLoop) {
yieldOp->setOperand(i, forOp.getRegionIterArgs()[i]);
}
}
});
}
};
} // anonymous namespace
std::unique_ptr<Pass> mlir::createTritonGPUCanonicalizeLoopsPass() {
return std::make_unique<CanonicalizePass>();
}
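In comment form, the workaround this (now-deleted) pass applied, sketched by hand:
// Given a loop result that is never used after the loop, and an iter arg that
// no other yielded value depends on:
//   %r = scf.for ... iter_args(%p = %init) { ...; scf.yield %p_next }
// the pass rewrote the terminator to yield the region iter arg itself:
//   %r = scf.for ... iter_args(%p = %init) { ...; scf.yield %p }
// which makes the whole chain trivially dead, so the standard scf.for
// canonicalizer can delete it.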

View File

@@ -1,8 +0,0 @@
#ifndef TRITONGPU_PATTERNS
#define TRITONGPU_PATTERNS
include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td"
include "triton/Dialect/Triton/IR/TritonOps.td"
include "mlir/IR/PatternBase.td"
#endif

View File

@@ -0,0 +1,153 @@
#include "Utility.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <memory>
using namespace mlir;
namespace {
using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
using triton::gpu::SliceEncodingAttr;
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if (auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if (auto srcSharedLayout =
srcType.getEncoding()
.dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return failure();
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return failure();
if (!dstDotOperandLayout.getIsMMAv1Row())
return failure();
bool isMMAv1Row =
dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(),
dstDotOperandLayout.getParent(), newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(), newDstEncoding);
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newDstType, cvt.getOperand());
rewriter.replaceOp(op, newCvt.getResult());
return success();
}
};
// convert(trans(convert(arg)))
// x = convert_layout arg: #distributed -> #shared_x
// y = trans x: #shared_x -> #shared_y
// z = convert_layout y: #shared_y -> #dot_operand
class ConvertTransConvert : public mlir::RewritePattern {
public:
ConvertTransConvert(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}
LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
auto tmpOp =
dyn_cast_or_null<triton::TransOp>(dstOp.getSrc().getDefiningOp());
if (!tmpOp)
return mlir::failure();
auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
tmpOp.getSrc().getDefiningOp());
if (!srcOp)
return mlir::failure();
auto arg = srcOp.getSrc();
auto X = tmpOp.getSrc();
// types
auto argType = arg.getType().cast<RankedTensorType>();
auto XType = X.getType().cast<RankedTensorType>();
auto ZType = dstOp.getResult().getType().cast<RankedTensorType>();
// encodings
auto argEncoding = argType.getEncoding();
auto XEncoding =
XType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
auto ZEncoding =
ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!ZEncoding)
return mlir::failure();
// new X encoding
auto newXOrder = triton::gpu::getOrder(argEncoding);
auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
getContext(), ZEncoding, XType.getShape(), newXOrder,
XType.getElementType());
auto newXType = RankedTensorType::get(XType.getShape(),
XType.getElementType(), newXEncoding);
if (XEncoding == newXEncoding)
return mlir::failure();
auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(),
newXType, arg);
auto newY = rewriter.create<triton::TransOp>(tmpOp.getLoc(), newX);
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType,
newY);
return mlir::success();
}
};
} // namespace
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
class TritonGPUFuseTranspositionsPass
: public TritonGPUFuseTranspositionsBase<TritonGPUFuseTranspositionsPass> {
public:
TritonGPUFuseTranspositionsPass() = default;
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();
mlir::PassManager pm(m.getContext());
pm.addPass(mlir::createCanonicalizerPass());
if (pm.run(m).failed())
signalPassFailure();
mlir::RewritePatternSet patterns(context);
patterns.add<OptimizeConvertToDotOperand>(context);
patterns.add<ConvertTransConvert>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
signalPassFailure();
if (fixupLoops(m).failed())
signalPassFailure();
}
};
std::unique_ptr<Pass> mlir::createTritonGPUFuseTranspositionsPass() {
return std::make_unique<TritonGPUFuseTranspositionsPass>();
}
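For reference, a minimal standalone invocation of the new pass (a sketch; in-tree it is scheduled from the Python bindings via add_tritongpu_fuse_transpositions_pass, shown later in this diff):
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

static mlir::LogicalResult fuseTranspositions(mlir::ModuleOp mod) {
  mlir::PassManager pm(mod.getContext());
  pm.addPass(mlir::createTritonGPUFuseTranspositionsPass());
  return pm.run(mod); // rewrites convert->trans->convert chains as above
}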

View File

@@ -22,7 +22,6 @@
using namespace mlir;
namespace {
#include "TritonGPUCombine.inc"
using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
@@ -139,132 +138,7 @@ public:
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
return mlir::failure();
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto dstType = convert.getType().cast<RankedTensorType>();
if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>() &&
srcType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return mlir::failure();
// convert to the same layout -- we can delete
if (op->getResultTypes() == op->getOperandTypes()) {
rewriter.replaceOp(op, op->getOperands());
return mlir::success();
}
Operation *arg = op->getOperand(0).getDefiningOp();
// block argument
if (!arg)
return mlir::failure();
// cvt(view) -> view
if (auto view = dyn_cast<triton::ViewOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::ViewOp>(
op, op->getResult(0).getType(), view.getResult());
return mlir::success();
}
// cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
if (alloc_tensor) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
op, op->getResult(0).getType());
return mlir::success();
}
// cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
if (insert_slice) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
// Ensure that the new insert_slice op is placed in the same place as the
// old insert_slice op. Otherwise, the new insert_slice op may be placed
// after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(insert_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, insert_slice.getDst());
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
op, newType, insert_slice.getSrc(), newArg.getResult(),
insert_slice.getIndex(), insert_slice.getMask(),
insert_slice.getOther(), insert_slice.getCache(),
insert_slice.getEvict(), insert_slice.getIsVolatile(),
insert_slice.getAxis());
return mlir::success();
}
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
if (extract_slice) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
auto origType =
extract_slice.getSource().getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(),
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
auto origResType = op->getResult(0).getType().cast<RankedTensorType>();
auto resType = RankedTensorType::get(
origResType.getShape(), origResType.getElementType(),
extract_slice.getType().cast<RankedTensorType>().getEncoding());
// Ensure that the new extract_slice op is placed in the same place as the
// old extract_slice op. Otherwise, the new extract_slice op may be placed
// after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(extract_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, extract_slice.getSource());
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
op, resType, newArg.getResult(), extract_slice.offsets(),
extract_slice.sizes(), extract_slice.strides(),
extract_slice.static_offsets(), extract_slice.static_sizes(),
extract_slice.static_strides());
return mlir::success();
}
// cvt(cvt(x, type1), type2) -> cvt(x, type2)
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
if (arg->getOperand(0).getDefiningOp() &&
!isSharedEncoding(arg->getOperand(0)) &&
isSharedEncoding(convert.getOperand()) &&
!isSharedEncoding(convert.getResult())) {
return mlir::failure();
}
if (isSharedEncoding(convert.getOperand()) &&
isSharedEncoding(convert.getResult())) {
return mlir::failure();
}
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto srcShared =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
if (srcShared && srcShared.getVec() > 1)
return mlir::failure();
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, op->getResultTypes().front(), arg->getOperand(0));
return mlir::success();
}
// cvt(type1, splat(type2, x)) -> splat(type1, x)
if (auto splat = llvm::dyn_cast<triton::SplatOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::SplatOp>(op, op->getResultTypes(),
splat.getSrc());
return mlir::success();
}
// cvt(type1, make_range(type2, x)) -> make_range(type1, x)
if (auto range = llvm::dyn_cast<triton::MakeRangeOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
op, op->getResultTypes(), range.getStart(), range.getEnd());
return mlir::success();
}
// cvt(type, constant) -> constant
if (auto cst = llvm::dyn_cast<arith::ConstantOp>(arg))
if (auto ret = cst.getValue().dyn_cast<SplatElementsAttr>()) {
auto newRet = SplatElementsAttr::get(op->getResultTypes().front(),
ret.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newRet);
return mlir::success();
}
return mlir::failure();
return ConvertLayoutOp::canonicalize(convert, rewriter);
}
};
@@ -568,9 +442,9 @@ public:
};
//
class FoldConvertAndReduce : public mlir::RewritePattern {
class RematerializeForward : public mlir::RewritePattern {
public:
explicit FoldConvertAndReduce(mlir::MLIRContext *context)
explicit RematerializeForward(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}
@@ -837,390 +711,6 @@ public:
}
};
// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------
class RematerializeForward : public mlir::RewritePattern {
public:
explicit RematerializeForward(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
2, context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *_cvtOp,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(_cvtOp);
auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
if (!forOp)
return mlir::failure();
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
SetVector<Operation *> cvtSlices;
auto filter = [&](Operation *op) {
return isInLoop(op) &&
!isa<triton::LoadOp, triton::StoreOp, triton::AtomicRMWOp,
triton::AtomicCASOp>(op) &&
!isa<triton::DotOp>(op) && !isa<scf::YieldOp>(op) &&
!isa<triton::gpu::ConvertLayoutOp>(op);
};
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
if (cvtSlices.empty())
return failure();
for (Operation *op : cvtSlices) {
if (!isa<triton::ViewOp, triton::CatOp>(op) &&
!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
!op->hasTrait<mlir::OpTrait::Elementwise>() &&
!isa<triton::StoreOp>(op))
return failure();
for (Value arg : op->getOperands()) {
Operation *argOp = arg.getDefiningOp();
if (argOp && (argOp != cvt) &&
!isa<arith::ConstantOp, triton::SplatOp, triton::MakeRangeOp>(
argOp)) {
return failure();
}
}
}
// Otherwise, we push the conversion forward
// since we'll be able to move it out of
// the loop once it reaches the yield op
pushConversionForward(cvt, cvtSlices, rewriter);
return success();
}
};
// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------
namespace {
int computeCapabilityToMMAVersion(int computeCapability) {
if (computeCapability < 70) {
return 0;
} else if (computeCapability < 80) {
return 1;
} else if (computeCapability < 90) {
return 2;
} else if (computeCapability < 100) {
// FIXME: temporarily add this to pass unit tests
return 2;
} else {
assert(false && "computeCapability > 100 not supported");
return 3;
}
}
SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
if (version == 1)
return {16, 16};
else if (version == 2)
return {16, 8};
else {
assert(false && "version not supported");
return {0, 0};
}
}
SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
int numWarps) {
// Set a default value and ensure product of wpt equals numWarps
return {static_cast<unsigned>(numWarps), 1};
}
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
if (llvm::find_if(slices, [](Operation *op) {
return isa<triton::DotOp>(op);
}) != slices.end())
return {(unsigned)numWarps, 1};
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
bool changed = false;
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
do {
changed = false;
if (ret[0] * ret[1] >= numWarps)
break;
if (shape[0] / shapePerWarp[0] / ret[0] >=
shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
if (ret[0] < shape[0] / shapePerWarp[0]) {
ret[0] *= 2;
} else
ret[1] *= 2;
} else {
ret[1] *= 2;
}
} while (true);
return ret;
}
} // namespace
class OptimizeBlockedToShared : public mlir::RewritePattern {
public:
explicit OptimizeBlockedToShared(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
auto dstSharedLayout =
dstType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
if (!srcBlockedLayout || !dstSharedLayout)
return failure();
if (srcBlockedLayout.getOrder() == dstSharedLayout.getOrder())
return failure();
// For now only works if single use is transpose
// TODO: rematerialize #shared uses
auto users = op->getUsers();
if (std::distance(users.begin(), users.end()) != 1 ||
!isa<triton::TransOp>(*users.begin()))
return failure();
auto tmpShared = triton::gpu::SharedEncodingAttr::get(
op->getContext(), dstSharedLayout.getVec(),
dstSharedLayout.getPerPhase(), dstSharedLayout.getMaxPhase(),
srcBlockedLayout.getOrder());
auto tmpType = RankedTensorType::get(srcType.getShape(),
srcType.getElementType(), tmpShared);
auto tmpCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), tmpType, cvt.getOperand());
auto newDstType = RankedTensorType::get(
users.begin()->getResultTypes()[0].cast<RankedTensorType>().getShape(),
srcType.getElementType(), dstSharedLayout);
auto newTrans = rewriter.create<triton::TransOp>(op->getLoc(), newDstType,
tmpCvt.getResult());
rewriter.replaceOp(*users.begin(), newTrans.getResult());
return success();
}
};
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if (auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if (auto srcSharedLayout =
srcType.getEncoding()
.dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return failure();
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return failure();
if (!dstDotOperandLayout.getIsMMAv1Row())
return failure();
bool isMMAv1Row =
dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(),
dstDotOperandLayout.getParent(), newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(), newDstEncoding);
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newDstType, cvt.getOperand());
rewriter.replaceOp(op, newCvt.getResult());
return success();
}
};
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding
public:
BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
computeCapability(computeCapability) {}
static SmallVector<unsigned, 2> getWarpsPerTile(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int version, int numWarps) {
switch (version) {
case 1:
return warpsPerTileV1(shape, numWarps);
case 2:
return warpsPerTileV2(dotOp, shape, numWarps);
default:
assert(false && "not supported version");
return {0, 0};
}
}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dotOp = cast<triton::DotOp>(op);
// TODO: Check data-types and SM compatibility
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
if (!oldRetType.getEncoding() ||
oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return failure();
// for FMA, should retain the blocked layout.
int versionMajor = computeCapabilityToMMAVersion(computeCapability);
if (!supportMMA(dotOp, versionMajor))
return failure();
// get MMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
auto warpsPerTile =
getWarpsPerTile(dotOp, retShape, versionMajor, numWarps);
triton::gpu::MmaEncodingAttr mmaEnc;
if (versionMajor == 1) {
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, numWarps, mmaV1Counter++);
} else if (versionMajor == 2) {
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, 0 /*versionMinor*/,
warpsPerTile);
} else {
assert(false && "Mma layout only support versionMajor of 1 or 2");
}
auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(), mmaEnc);
// convert accumulator
auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
oldAcc.getLoc(), newRetType, oldAcc);
Value a = dotOp.getA();
Value b = dotOp.getB();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto oldAOrder = oldAType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
auto oldBOrder = oldBType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
Attribute isMMAv1RowA;
Attribute isMMAv1RowB;
if (versionMajor == 1) {
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(
oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(
oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32());
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, oldRetType, newDot.getResult());
return success();
}
};
// Convert + trans + convert
// x = convert_layout distributed -> #shared_x
// y = trans x -> #shared_y
// z = convert_layout y -> #dot_operand
class ConvertTransConvert : public mlir::RewritePattern {
public:
ConvertTransConvert(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}
LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
auto tmpOp =
dyn_cast_or_null<triton::TransOp>(dstOp.getSrc().getDefiningOp());
if (!tmpOp)
return mlir::failure();
auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
tmpOp.getSrc().getDefiningOp());
if (!srcOp)
return mlir::failure();
auto arg = srcOp.getSrc();
auto X = tmpOp.getSrc();
// types
auto argType = arg.getType().cast<RankedTensorType>();
auto XType = X.getType().cast<RankedTensorType>();
auto ZType = dstOp.getResult().getType().cast<RankedTensorType>();
// encodings
auto argEncoding = argType.getEncoding();
auto XEncoding =
XType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
auto ZEncoding =
ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!ZEncoding)
return mlir::failure();
// new X encoding
auto newXOrder = triton::gpu::getOrder(argEncoding);
auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
getContext(), ZEncoding, XType.getShape(), newXOrder,
XType.getElementType());
auto newXType = RankedTensorType::get(XType.getShape(),
XType.getElementType(), newXEncoding);
if (XEncoding == newXEncoding)
return mlir::failure();
auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(),
newXType, arg);
auto newY = rewriter.create<triton::TransOp>(tmpOp.getLoc(), newX);
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType,
newY);
return mlir::success();
}
};
//
class ConvertDotConvert : public mlir::RewritePattern {
public:
@@ -1272,31 +762,25 @@ public:
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
class TritonGPUCombineOpsPass
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
class TritonGPURemoveLayoutConversionsPass
: public TritonGPURemoveLayoutConversionsBase<
TritonGPURemoveLayoutConversionsPass> {
public:
TritonGPUCombineOpsPass() = default;
TritonGPUCombineOpsPass(int computeCapability) {
this->computeCapability = computeCapability;
}
TritonGPURemoveLayoutConversionsPass() = default;
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();
mlir::RewritePatternSet patterns(context);
patterns.add<OptimizeBlockedToShared>(context);
patterns.add<OptimizeConvertToDotOperand>(context);
patterns.add<SimplifyConversion>(context);
patterns.add<SimplifyReduceCvt>(context);
patterns.add<FoldConvertAndReduce>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);
patterns.add<RematerializeForward>(context);
patterns.add<MoveConvertOutOfLoop>(context);
patterns.add<MoveConvertOutOfIf>(context);
patterns.add<BlockedToMMA>(context, computeCapability);
patterns.add<ConvertTransConvert>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<ConvertDotConvert>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
@@ -1309,7 +793,6 @@ public:
}
};
std::unique_ptr<Pass>
mlir::createTritonGPUCombineOpsPass(int computeCapability) {
return std::make_unique<TritonGPUCombineOpsPass>(computeCapability);
std::unique_ptr<Pass> mlir::createTritonGPURemoveLayoutConversionsPass() {
return std::make_unique<TritonGPURemoveLayoutConversionsPass>();
}

View File

@@ -1375,7 +1375,7 @@ void init_triton_ir(py::module &&m) {
.def(
"add_sccp_pass",
[](mlir::PassManager &self) { self.addPass(mlir::createSCCPPass()); })
.def("add_coalesce_pass",
.def("add_tritongpu_coalesce_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUCoalescePass());
})
@@ -1414,10 +1414,18 @@ void init_triton_ir(py::module &&m) {
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUPrefetchPass());
})
.def("add_tritongpu_combine_pass",
.def("add_tritongpu_accelerate_matmul_pass",
[](mlir::PassManager &self, int computeCapability) {
self.addPass(
mlir::createTritonGPUCombineOpsPass(computeCapability));
mlir::createTritonGPUAccelerateMatmulPass(computeCapability));
})
.def("add_tritongpu_fuse_transpositions_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUFuseTranspositionsPass());
})
.def("add_tritongpu_remove_layout_conversions_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPURemoveLayoutConversionsPass());
})
.def("add_tritongpu_update_mma_for_volta_pass",
[](mlir::PassManager &self) {

View File

@@ -975,32 +975,32 @@ def ast_to_ttir(fn, signature, specialization, constants):
return optimize_triton_ir(mod)
def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
def ttir_to_ttgir(mod, num_warps):
pm = _triton.ir.pass_manager(mod.context)
pm.add_convert_triton_to_tritongpu_pass(num_warps)
pm.run(mod)
return mod
def optimize_ttgir(mod, num_stages, compute_capability):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_coalesce_pass()
# The combine pass converts blocked layout to mma layout
# for dot ops so that pipeline can get shared memory swizzled correctly.
pm.add_tritongpu_combine_pass(compute_capability)
pm.add_tritongpu_coalesce_pass()
pm.add_tritongpu_accelerate_matmul_pass(compute_capability)
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_fuse_transpositions_pass()
pm.add_tritongpu_pipeline_pass(num_stages)
# Prefetch must be done after pipeline pass because pipeline pass
# extracts slices from the original tensor.
pm.add_tritongpu_prefetch_pass()
pm.add_canonicalizer_pass()
pm.add_cse_pass()
pm.add_tritongpu_combine_pass(compute_capability)
pm.add_licm_pass()
pm.add_tritongpu_combine_pass(compute_capability)
pm.add_cse_pass()
pm.add_tritongpu_fuse_transpositions_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_decompose_conversions_pass()
if compute_capability // 10 == 7:
# The update_mma_for_volta pass computes extra information for the MMA encoding, specifically for MMAv1
# NOTE: this pass should be placed after all passes that modify the mma layout
pm.add_tritongpu_update_mma_for_volta_pass()
pm.add_tritongpu_reorder_instructions_pass()
pm.add_cse_pass()
pm.add_symbol_dce_pass()
pm.add_tritongpu_reorder_instructions_pass()
pm.run(mod)
return mod
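Expressed directly with the C++ pass manager that these bindings wrap (a sketch of the first half of the pipeline; `mod`, `computeCapability`, and `numStages` are placeholders, and the pass creators are the ones declared in Passes.h earlier in this diff):
// Equivalent of the start of optimize_ttgir, on the C++ side:
mlir::PassManager pm(mod.getContext());
pm.addPass(mlir::createTritonGPUCoalescePass());
pm.addPass(mlir::createTritonGPUAccelerateMatmulPass(computeCapability));
pm.addPass(mlir::createTritonGPURemoveLayoutConversionsPass());
pm.addPass(mlir::createTritonGPUFuseTranspositionsPass());
pm.addPass(mlir::createTritonGPUPipelinePass(numStages));
pm.addPass(mlir::createTritonGPUPrefetchPass());
// ... canonicalizer / CSE / LICM and the remaining passes follow as above.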
@@ -1565,7 +1565,7 @@ def compile(fn, **kwargs):
"ttir": (lambda path: parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
"llir": (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
"ptx": (lambda path: Path(path).read_text(),

View File

@@ -42,7 +42,8 @@ if __name__ == '__main__':
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
# triton-ir -> triton-gpu-ir
module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3, compute_capability=args.sm)
module = triton.compiler.ttir_to_ttgir(module, num_warps=4)
module = triton.compiler.optimize_ttgir(module, num_stages=3, compute_capability=args.sm)
if args.target == 'triton-gpu-ir':
print(module.str())
sys.exit(0)

View File

@@ -223,6 +223,7 @@ class _attention(torch.autograd.Function):
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=2,
)
# print(h.asm["ttgir"])
ctx.save_for_backward(q, k, v, o, L, m)
ctx.grid = grid
@@ -260,6 +261,7 @@ class _attention(torch.autograd.Function):
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
num_stages=1,
)
# print(h.asm["ttgir"])
return dq, dk, dv, None

View File

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce -canonicalize | FileCheck %s
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce | FileCheck %s
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>

View File

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -tritongpu-combine 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file -tritongpu-remove-layout-conversions 2>&1 | FileCheck %s
#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>

View File

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu -tritongpu-combine -tritongpu-pipeline=num-stages=3 -tritongpu-combine -test-print-allocation 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu -tritongpu-remove-layout-conversions -tritongpu-pipeline=num-stages=3 -test-print-allocation 2>&1 | FileCheck %s
// CHECK: offset = 0, size = 49152
// CHECK: offset = 49152, size = 49152

View File

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -tritongpu-combine -tritongpu-update-mma-for-volta 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file -tritongpu-fuse-transposition -tritongpu-update-mma-for-volta 2>&1 | FileCheck %s
// -----