mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER] cleaned, renamed and simplified some optimization passes (#1232)
This shouldn't actually change the behavior of Triton -- only clean things up.
This commit is contained in:
@@ -25,6 +25,8 @@ def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let hasCanonicalizeMethod = 1;
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@
|
||||
namespace mlir {
|
||||
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
|
||||
|
||||
// TODO(Keren): prefetch pass not working yet
|
||||
std::unique_ptr<Pass>
|
||||
createTritonGPUAccelerateMatmulPass(int computeCapability = 80);
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUPrefetchPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
|
||||
@@ -17,10 +19,12 @@ std::unique_ptr<Pass> createTritonGPUReorderInstructionsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUDecomposeConversionsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
|
||||
std::unique_ptr<Pass> createTritonGPURemoveLayoutConversionsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUVerifier();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUFuseTranspositionsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUUpdateMmaForVoltaPass();
|
||||
|
||||
/// Generate the code for registering passes.
|
||||
|
||||
@@ -7,7 +7,8 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
|
||||
let summary = "pipeline";
|
||||
|
||||
let description = [{
|
||||
Unroll loops to hide global memory -> shared memory latency.
|
||||
Replace `LoadOp` in loops by `InsertSliceAsyncOp` instructions that asynchronously construct the data
|
||||
needed at the next iteration
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUPipelinePass()";
|
||||
@@ -27,7 +28,8 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
|
||||
let summary = "prefetch";
|
||||
|
||||
let description = [{
|
||||
Prefetch operands (a and b) of tt.dot into shared memory to hide shared memory -> register latency.
|
||||
Decompose `DotOp` instructions in loops into several finer-grained `DotOp`
|
||||
that may have their operands constructed at the end of the previous iteration
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUPrefetchPass()";
|
||||
@@ -37,6 +39,41 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
|
||||
"mlir::arith::ArithDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> {
|
||||
let summary = "accelerate matmul";
|
||||
|
||||
let description = [{
|
||||
Optimize the input/output layouts of `dot` instructions to make them compatible with hardware accelerators
|
||||
(e.g., Nvidia tensor cores)
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUAccelerateMatmulPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">
|
||||
];
|
||||
|
||||
}
|
||||
|
||||
def TritonGPUFuseTranspositions : Pass<"tritongpu-fuse-transposition", "mlir::ModuleOp"> {
|
||||
let summary = "fuse transpositions";
|
||||
|
||||
let description = [{
|
||||
Re-arrange the layouts of tensors used as matrix multiplication operands so as to promote the use of
|
||||
hardware-accelerated transpositions.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUFuseTranspositionsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
|
||||
let summary = "coalesce";
|
||||
|
||||
@@ -49,26 +86,16 @@ def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
|
||||
let summary = "combine triton gpu ops";
|
||||
def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions", "mlir::ModuleOp"> {
|
||||
let summary = "remove superfluous layout conversions";
|
||||
|
||||
let description = [{
|
||||
convert_layout(convert_layout(%src, #LAYOUT_0), #LAYOUT_1) =>
|
||||
convert_layout(%src, #LAYOUT_1)
|
||||
|
||||
convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUCombineOpsPass()";
|
||||
let constructor = "mlir::createTritonGPURemoveLayoutConversionsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> {
|
||||
@@ -95,19 +122,6 @@ def TritonGPUDecomposeConversions: Pass<"tritongpu-decompose-conversions", "mlir
|
||||
"mlir::triton::TritonDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
|
||||
let summary = "canonicalize scf.ForOp ops";
|
||||
|
||||
let description = [{
|
||||
This implements some optimizations that are missing in the standard scf.ForOp
|
||||
canonicalizer.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUCanonicalizeLoopsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
|
||||
}
|
||||
|
||||
def UpdateMmaForVolta : Pass<"tritongpu-update-mma-for-volta", "mlir::ModuleOp"> {
|
||||
let summary = "Update mma encodings for Volta";
|
||||
|
||||
|
||||
@@ -790,6 +790,141 @@ struct TritonGPUInferLayoutInterface
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Canonicalizer
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Canonicalization hook for `convert_layout`.
///
/// Tries a fixed sequence of local rewrites that either remove the conversion
/// or fold it into the operation producing its operand. Returns success() as
/// soon as one rewrite has been applied and failure() when none matches.
LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
                                            PatternRewriter &rewriter) {
  auto operandTy = op.getOperand().getType().cast<RankedTensorType>();
  auto resultTy = op.getType().cast<RankedTensorType>();

  // Conversions from an MMA layout to a dot-operand layout are deliberately
  // left untouched (heuristic to accommodate fused attention).
  if (resultTy.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>() &&
      operandTy.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
    return mlir::failure();

  // Source and destination types agree: the conversion is a no-op.
  if (op->getResultTypes() == op->getOperandTypes()) {
    rewriter.replaceOp(op, op->getOperands());
    return mlir::success();
  }

  // Everything below inspects the producer; block arguments have none.
  Operation *producer = op->getOperand(0).getDefiningOp();
  if (producer == nullptr)
    return mlir::failure();

  // cvt(view) -> view
  if (auto viewOp = dyn_cast<triton::ViewOp>(producer)) {
    rewriter.replaceOpWithNewOp<triton::ViewOp>(op, op->getResult(0).getType(),
                                                viewOp.getResult());
    return mlir::success();
  }

  // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
  if (isa<triton::gpu::AllocTensorOp>(producer)) {
    if (!isSharedEncoding(op->getResult(0)))
      return mlir::failure();
    rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
        op, op->getResult(0).getType());
    return mlir::success();
  }

  // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
  if (auto insertOp = dyn_cast<triton::gpu::InsertSliceAsyncOp>(producer)) {
    if (!isSharedEncoding(op->getResult(0)))
      return mlir::failure();
    // The replacement is built at the position of the old insert_slice so
    // that it cannot end up after the async_wait op, which is not allowed.
    OpBuilder::InsertionGuard restoreInsertionPoint(rewriter);
    rewriter.setInsertionPoint(insertOp);
    auto convertedDst = rewriter.create<triton::gpu::ConvertLayoutOp>(
        op->getLoc(), resultTy, insertOp.getDst());
    rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
        op, resultTy, insertOp.getSrc(), convertedDst.getResult(),
        insertOp.getIndex(), insertOp.getMask(), insertOp.getOther(),
        insertOp.getCache(), insertOp.getEvict(), insertOp.getIsVolatile(),
        insertOp.getAxis());
    return mlir::success();
  }

  // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
  if (auto extractOp = dyn_cast<tensor::ExtractSliceOp>(producer)) {
    if (!isSharedEncoding(op->getResult(0)))
      return mlir::failure();
    // Type of the converted source: source shape/element type with the
    // encoding requested by this conversion.
    auto sliceSrcTy = extractOp.getSource().getType().cast<RankedTensorType>();
    auto convertedSrcTy =
        RankedTensorType::get(sliceSrcTy.getShape(),
                              sliceSrcTy.getElementType(),
                              resultTy.getEncoding());
    // Type of the new slice: this op's shape/element type with the encoding
    // of the old extract_slice result.
    auto newSliceTy = RankedTensorType::get(
        resultTy.getShape(), resultTy.getElementType(),
        extractOp.getType().cast<RankedTensorType>().getEncoding());
    // Same placement constraint as for insert_slice above: stay at the old
    // extract_slice position so the new op is not moved past async_wait.
    OpBuilder::InsertionGuard restoreInsertionPoint(rewriter);
    rewriter.setInsertionPoint(extractOp);
    auto convertedSrc = rewriter.create<triton::gpu::ConvertLayoutOp>(
        op->getLoc(), convertedSrcTy, extractOp.getSource());
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        op, newSliceTy, convertedSrc.getResult(), extractOp.offsets(),
        extractOp.sizes(), extractOp.strides(), extractOp.static_offsets(),
        extractOp.static_sizes(), extractOp.static_strides());
    return mlir::success();
  }

  // cvt(cvt(x, type1), type2) -> cvt(x, type2)
  if (llvm::isa<triton::gpu::ConvertLayoutOp>(producer)) {
    Value innermost = producer->getOperand(0);
    const bool middleIsShared = isSharedEncoding(op.getOperand());
    const bool resultIsShared = isSharedEncoding(op.getResult());
    // Keep a "non-shared -> shared -> non-shared" chain when the innermost
    // value is produced by an operation.
    if (innermost.getDefiningOp() && !isSharedEncoding(innermost) &&
        middleIsShared && !resultIsShared)
      return mlir::failure();
    // Keep shared -> shared conversions.
    if (middleIsShared && resultIsShared)
      return mlir::failure();
    // Keep the chain when the intermediate shared layout has vec > 1.
    auto middleShared =
        operandTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
    if (middleShared && middleShared.getVec() > 1)
      return mlir::failure();
    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
        op, op->getResultTypes().front(), innermost);
    return mlir::success();
  }

  // cvt(type1, splat(type2, x)) -> splat(type1, x)
  if (auto splatOp = llvm::dyn_cast<triton::SplatOp>(producer)) {
    rewriter.replaceOpWithNewOp<triton::SplatOp>(op, op->getResultTypes(),
                                                 splatOp.getSrc());
    return mlir::success();
  }

  // cvt(type1, make_range(type2, x)) -> make_range(type1, x)
  if (auto rangeOp = llvm::dyn_cast<triton::MakeRangeOp>(producer)) {
    rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
        op, op->getResultTypes(), rangeOp.getStart(), rangeOp.getEnd());
    return mlir::success();
  }

  // cvt(type, constant) -> constant (splat constants only)
  if (auto constantOp = llvm::dyn_cast<arith::ConstantOp>(producer)) {
    if (auto splatValue = constantOp.getValue().dyn_cast<SplatElementsAttr>()) {
      auto retyped =
          SplatElementsAttr::get(op->getResultTypes().front(),
                                 splatValue.getSplatValue<Attribute>());
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retyped);
      return mlir::success();
    }
  }

  return mlir::failure();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void TritonGPUDialect::initialize() {
|
||||
addAttributes<
|
||||
#define GET_ATTRDEF_LIST
|
||||
|
||||
212
lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Normal file
212
lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Normal file
@@ -0,0 +1,212 @@
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include <memory>
|
||||
|
||||
using namespace mlir;
|
||||
namespace {
|
||||
using triton::DotOp;
|
||||
using triton::gpu::ConvertLayoutOp;
|
||||
using triton::gpu::DotOperandEncodingAttr;
|
||||
using triton::gpu::MmaEncodingAttr;
|
||||
using triton::gpu::SliceEncodingAttr;
|
||||
|
||||
/// Maps a device compute capability (e.g. 70, 75, 80, 86) to the MMA
/// encoding major version used for `dot` lowering.
///
/// \param computeCapability compute capability encoded as major * 10 + minor.
/// \return 0 below 70 (no MMA), 1 for [70, 80), 2 for [80, 100). Values of 100
///         and above are rejected by an assertion; in builds where assertions
///         are compiled out the function returns 3.
int computeCapabilityToMMAVersion(int computeCapability) {
  if (computeCapability < 70)
    return 0;
  if (computeCapability < 80)
    return 1;
  if (computeCapability < 90)
    return 2;
  if (computeCapability < 100) {
    // FIXME: temporarily map the 9x range to version 2 to pass unit tests
    return 2;
  }
  // The message matches the actual bound: 100 itself is rejected as well.
  assert(false && "computeCapability >= 100 not supported");
  return 3;
}
|
||||
|
||||
/// Returns the two-element per-warp shape associated with an MMA major
/// version: {16, 16} for version 1 and {16, 8} for version 2. Any other
/// version trips an assertion (and yields {0, 0} when assertions are off).
SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
  switch (version) {
  case 1:
    return {16, 16};
  case 2:
    return {16, 8};
  default:
    assert(false && "version not supported");
    return {0, 0};
  }
}
|
||||
|
||||
/// Warp arrangement for MMA v1: every warp is placed along dimension 0.
/// `shape` is accepted for signature parity with the v2 variant but is not
/// consulted; the result is a default whose product equals `numWarps`.
SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
                                        int numWarps) {
  const auto warpsAlongDim0 = static_cast<unsigned>(numWarps);
  SmallVector<unsigned, 2> warpsPerTile = {warpsAlongDim0, 1};
  return warpsPerTile;
}
|
||||
|
||||
/// Chooses how `numWarps` warps are distributed over the two dimensions of
/// the result tile of an MMA v2 `dot`.
///
/// \param dotOp    the dot whose result layout is being chosen.
/// \param shape    shape of the dot result; indices 0 and 1 are read.
/// \param numWarps number of warps to distribute.
/// \return {numWarps, 1} when another `tt.dot` appears in the forward slice
///         of this dot's result; otherwise a pair grown by repeated doubling
///         until its product is at least `numWarps`.
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
                                        const ArrayRef<int64_t> shape,
                                        int numWarps) {
  // Same conversion the mixed signed/unsigned comparison performed
  // implicitly; made explicit here.
  const auto totalWarps = static_cast<unsigned>(numWarps);

  // If this dot feeds another dot, keep all warps along dimension 0.
  SetVector<Operation *> slices;
  mlir::getForwardSlice(dotOp.getResult(), &slices);
  if (llvm::find_if(slices, [](Operation *op) {
        return isa<triton::DotOp>(op);
      }) != slices.end())
    return {totalWarps, 1};

  SmallVector<unsigned, 2> ret = {1, 1};
  SmallVector<int64_t, 2> shapePerWarp = {16, 8};
  // TODO (@daadaada): double-check.
  // original logic in
  // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
  // seems buggy for shape = [32, 16] ?
  while (ret[0] * ret[1] < totalWarps) {
    // Dimension 0 is doubled only when it has at least as many per-warp
    // tiles left as dimension 1 AND it has not yet covered shape[0];
    // in every other case dimension 1 is doubled.
    const bool dim0NeedsWarps = shape[0] / shapePerWarp[0] / ret[0] >=
                                shape[1] / (shapePerWarp[1] * 2) / ret[1];
    const bool dim0CanGrow = ret[0] < shape[0] / shapePerWarp[0];
    if (dim0NeedsWarps && dim0CanGrow)
      ret[0] *= 2;
    else
      ret[1] *= 2;
  }
  return ret;
}
|
||||
|
||||
class BlockedToMMA : public mlir::RewritePattern {
|
||||
int computeCapability;
|
||||
mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding
|
||||
|
||||
public:
|
||||
BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
|
||||
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
|
||||
computeCapability(computeCapability) {}
|
||||
|
||||
static SmallVector<unsigned, 2> getWarpsPerTile(triton::DotOp dotOp,
|
||||
const ArrayRef<int64_t> shape,
|
||||
int version, int numWarps) {
|
||||
switch (version) {
|
||||
case 1:
|
||||
return warpsPerTileV1(shape, numWarps);
|
||||
case 2:
|
||||
return warpsPerTileV2(dotOp, shape, numWarps);
|
||||
default:
|
||||
llvm_unreachable("unsupported MMA version");
|
||||
}
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto dotOp = cast<triton::DotOp>(op);
|
||||
// TODO: Check data-types and SM compatibility
|
||||
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
|
||||
if (!oldRetType.getEncoding() ||
|
||||
oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
|
||||
return failure();
|
||||
|
||||
// for FMA, should retain the blocked layout.
|
||||
int versionMajor = computeCapabilityToMMAVersion(computeCapability);
|
||||
if (!supportMMA(dotOp, versionMajor))
|
||||
return failure();
|
||||
|
||||
// get MMA encoding for the given number of warps
|
||||
auto retShape = oldRetType.getShape();
|
||||
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
|
||||
auto warpsPerTile =
|
||||
getWarpsPerTile(dotOp, retShape, versionMajor, numWarps);
|
||||
triton::gpu::MmaEncodingAttr mmaEnc;
|
||||
if (versionMajor == 1) {
|
||||
mmaEnc = triton::gpu::MmaEncodingAttr::get(
|
||||
oldRetType.getContext(), versionMajor, numWarps, mmaV1Counter++);
|
||||
} else if (versionMajor == 2) {
|
||||
mmaEnc = triton::gpu::MmaEncodingAttr::get(
|
||||
oldRetType.getContext(), versionMajor, 0 /*versionMinor*/,
|
||||
warpsPerTile);
|
||||
} else {
|
||||
llvm_unreachable("Mma layout only supports versionMajor in {1, 2}");
|
||||
}
|
||||
auto newRetType =
|
||||
RankedTensorType::get(retShape, oldRetType.getElementType(), mmaEnc);
|
||||
|
||||
// convert accumulator
|
||||
auto oldAcc = dotOp.getOperand(2);
|
||||
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
oldAcc.getLoc(), newRetType, oldAcc);
|
||||
Value a = dotOp.getA();
|
||||
Value b = dotOp.getB();
|
||||
auto oldAType = a.getType().cast<RankedTensorType>();
|
||||
auto oldBType = b.getType().cast<RankedTensorType>();
|
||||
auto oldAOrder = oldAType.getEncoding()
|
||||
.cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.getParent()
|
||||
.cast<triton::gpu::BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
auto oldBOrder = oldBType.getEncoding()
|
||||
.cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.getParent()
|
||||
.cast<triton::gpu::BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
Attribute isMMAv1RowA;
|
||||
Attribute isMMAv1RowB;
|
||||
if (versionMajor == 1) {
|
||||
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
|
||||
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
|
||||
}
|
||||
|
||||
auto newAType = RankedTensorType::get(
|
||||
oldAType.getShape(), oldAType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(
|
||||
oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
|
||||
auto newBType = RankedTensorType::get(
|
||||
oldBType.getShape(), oldBType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(
|
||||
oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));
|
||||
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
|
||||
auto newDot = rewriter.create<triton::DotOp>(
|
||||
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32());
|
||||
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
|
||||
op, oldRetType, newDot.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
class TritonGPUAccelerateMatmulPass
|
||||
: public TritonGPUAccelerateMatmulBase<TritonGPUAccelerateMatmulPass> {
|
||||
public:
|
||||
TritonGPUAccelerateMatmulPass() = default;
|
||||
TritonGPUAccelerateMatmulPass(int computeCapability) {
|
||||
this->computeCapability = computeCapability;
|
||||
}
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp m = getOperation();
|
||||
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
patterns.add<::BlockedToMMA>(context, computeCapability);
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
mlir::createTritonGPUAccelerateMatmulPass(int computeCapability) {
|
||||
return std::make_unique<TritonGPUAccelerateMatmulPass>(computeCapability);
|
||||
}
|
||||
@@ -1,22 +1,18 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Combine.td)
|
||||
mlir_tablegen(TritonGPUCombine.inc -gen-rewriters)
|
||||
add_public_tablegen_target(TritonGPUCombineIncGen)
|
||||
|
||||
add_mlir_dialect_library(TritonGPUTransforms
|
||||
AccelerateMatmul.cpp
|
||||
Coalesce.cpp
|
||||
CanonicalizeLoops.cpp
|
||||
Combine.cpp
|
||||
DecomposeConversions.cpp
|
||||
FuseTranspositions.cpp
|
||||
Pipeline.cpp
|
||||
Prefetch.cpp
|
||||
RemoveLayoutConversions.cpp
|
||||
ReorderInstructions.cpp
|
||||
DecomposeConversions.cpp
|
||||
TritonGPUConversion.cpp
|
||||
UpdateMmaForVolta.cpp
|
||||
Utility.cpp
|
||||
|
||||
DEPENDS
|
||||
TritonGPUTransformsIncGen
|
||||
TritonGPUCombineIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
TritonIR
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
namespace {
|
||||
|
||||
struct CanonicalizePass
|
||||
: public TritonGPUCanonicalizeLoopsBase<CanonicalizePass> {
|
||||
CanonicalizePass() = default;
|
||||
|
||||
void runOnOperation() override {
|
||||
|
||||
// Canonicalize pass may have created dead code that
|
||||
// standard scf.for canonicalization cannot handle
|
||||
// as of LLVM 14. For example, the iteration arguments
|
||||
// for the pointer of the synchronous loads that are
|
||||
// discarded.
|
||||
// The following piece of code is a workaround to
|
||||
// very crudely remove dead code, by making an iteration
|
||||
// argument yield itself if it is not used to create
|
||||
// side effects anywhere.
|
||||
getOperation()->walk([&](scf::ForOp forOp) -> void {
|
||||
for (size_t i = 0; i < forOp.getNumResults(); ++i) {
|
||||
// condition 1: no other iter arguments depend on it
|
||||
SetVector<Operation *> fwdSlice;
|
||||
mlir::getForwardSlice(forOp.getRegionIterArgs()[i], &fwdSlice);
|
||||
Operation *yieldOp = forOp.getBody()->getTerminator();
|
||||
bool noOtherDependency = std::all_of(
|
||||
yieldOp->operand_begin(), yieldOp->operand_end(), [&](Value arg) {
|
||||
return arg == yieldOp->getOperand(i) ||
|
||||
!fwdSlice.contains(arg.getDefiningOp());
|
||||
});
|
||||
// condition 2: final value is not used after the loop
|
||||
auto retVal = forOp.getResult(i);
|
||||
bool noUserAfterLoop = retVal.getUsers().empty();
|
||||
// yielding the region iter arg will cause loop canonicalization
|
||||
// to clean up the dead code
|
||||
if (noOtherDependency && noUserAfterLoop) {
|
||||
yieldOp->setOperand(i, forOp.getRegionIterArgs()[i]);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUCanonicalizeLoopsPass() {
|
||||
return std::make_unique<CanonicalizePass>();
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
#ifndef TRITONGPU_PATTERNS
|
||||
#define TRITONGPU_PATTERNS
|
||||
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td"
|
||||
include "triton/Dialect/Triton/IR/TritonOps.td"
|
||||
include "mlir/IR/PatternBase.td"
|
||||
|
||||
#endif
|
||||
153
lib/Dialect/TritonGPU/Transforms/FuseTranspositions.cpp
Normal file
153
lib/Dialect/TritonGPU/Transforms/FuseTranspositions.cpp
Normal file
@@ -0,0 +1,153 @@
|
||||
#include "Utility.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include <memory>
|
||||
|
||||
using namespace mlir;
|
||||
namespace {
|
||||
using triton::DotOp;
|
||||
using triton::gpu::ConvertLayoutOp;
|
||||
using triton::gpu::DotOperandEncodingAttr;
|
||||
using triton::gpu::MmaEncodingAttr;
|
||||
using triton::gpu::SliceEncodingAttr;
|
||||
|
||||
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
|
||||
public:
|
||||
explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
|
||||
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
|
||||
// order
|
||||
ArrayRef<unsigned> order;
|
||||
if (auto srcBlockedLayout =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
|
||||
order = srcBlockedLayout.getOrder();
|
||||
else if (auto srcSharedLayout =
|
||||
srcType.getEncoding()
|
||||
.dyn_cast<triton::gpu::SharedEncodingAttr>())
|
||||
order = srcSharedLayout.getOrder();
|
||||
else
|
||||
return failure();
|
||||
// dot operand output
|
||||
auto dstDotOperandLayout =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if (!dstDotOperandLayout)
|
||||
return failure();
|
||||
if (!dstDotOperandLayout.getIsMMAv1Row())
|
||||
return failure();
|
||||
bool isMMAv1Row =
|
||||
dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
|
||||
return failure();
|
||||
|
||||
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
|
||||
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
op->getContext(), dstDotOperandLayout.getOpIdx(),
|
||||
dstDotOperandLayout.getParent(), newIsRow);
|
||||
auto newDstType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(), newDstEncoding);
|
||||
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), newDstType, cvt.getOperand());
|
||||
rewriter.replaceOp(op, newCvt.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// convert(trans(convert(arg)))
|
||||
// x = convert_layout arg: #distributed -> #shared_x
|
||||
// y = trans x: #shared_x -> #shared_y
|
||||
// z = convert_layout y: #shared_y -> #dot_operand
|
||||
class ConvertTransConvert : public mlir::RewritePattern {
|
||||
|
||||
public:
|
||||
ConvertTransConvert(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto tmpOp =
|
||||
dyn_cast_or_null<triton::TransOp>(dstOp.getSrc().getDefiningOp());
|
||||
if (!tmpOp)
|
||||
return mlir::failure();
|
||||
auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
|
||||
tmpOp.getSrc().getDefiningOp());
|
||||
if (!srcOp)
|
||||
return mlir::failure();
|
||||
auto arg = srcOp.getSrc();
|
||||
auto X = tmpOp.getSrc();
|
||||
// types
|
||||
auto argType = arg.getType().cast<RankedTensorType>();
|
||||
auto XType = X.getType().cast<RankedTensorType>();
|
||||
auto ZType = dstOp.getResult().getType().cast<RankedTensorType>();
|
||||
// encodings
|
||||
auto argEncoding = argType.getEncoding();
|
||||
auto XEncoding =
|
||||
XType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
|
||||
auto ZEncoding =
|
||||
ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if (!ZEncoding)
|
||||
return mlir::failure();
|
||||
// new X encoding
|
||||
auto newXOrder = triton::gpu::getOrder(argEncoding);
|
||||
auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
|
||||
getContext(), ZEncoding, XType.getShape(), newXOrder,
|
||||
XType.getElementType());
|
||||
auto newXType = RankedTensorType::get(XType.getShape(),
|
||||
XType.getElementType(), newXEncoding);
|
||||
if (XEncoding == newXEncoding)
|
||||
return mlir::failure();
|
||||
|
||||
auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(),
|
||||
newXType, arg);
|
||||
auto newY = rewriter.create<triton::TransOp>(tmpOp.getLoc(), newX);
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType,
|
||||
newY);
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
class TritonGPUFuseTranspositionsPass
|
||||
: public TritonGPUFuseTranspositionsBase<TritonGPUFuseTranspositionsPass> {
|
||||
public:
|
||||
TritonGPUFuseTranspositionsPass() = default;
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp m = getOperation();
|
||||
|
||||
mlir::PassManager pm(m.getContext());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
auto ret = pm.run(m);
|
||||
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
patterns.add<OptimizeConvertToDotOperand>(context);
|
||||
patterns.add<ConvertTransConvert>(context);
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
|
||||
signalPassFailure();
|
||||
if (fixupLoops(m).failed())
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUFuseTranspositionsPass() {
|
||||
return std::make_unique<TritonGPUFuseTranspositionsPass>();
|
||||
}
|
||||
@@ -22,7 +22,6 @@
|
||||
|
||||
using namespace mlir;
|
||||
namespace {
|
||||
#include "TritonGPUCombine.inc"
|
||||
using triton::DotOp;
|
||||
using triton::gpu::ConvertLayoutOp;
|
||||
using triton::gpu::DotOperandEncodingAttr;
|
||||
@@ -139,132 +138,7 @@ public:
|
||||
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
|
||||
return mlir::failure();
|
||||
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
// we don't handle conversions to DotOperandEncodingAttr
|
||||
// this is a heuristics to accommodate fused attention
|
||||
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = convert.getType().cast<RankedTensorType>();
|
||||
if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>() &&
|
||||
srcType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
|
||||
return mlir::failure();
|
||||
// convert to the same layout -- we can delete
|
||||
if (op->getResultTypes() == op->getOperandTypes()) {
|
||||
rewriter.replaceOp(op, op->getOperands());
|
||||
return mlir::success();
|
||||
}
|
||||
Operation *arg = op->getOperand(0).getDefiningOp();
|
||||
// block argument
|
||||
if (!arg)
|
||||
return mlir::failure();
|
||||
// cvt(view) -> view
|
||||
if (auto view = dyn_cast<triton::ViewOp>(arg)) {
|
||||
rewriter.replaceOpWithNewOp<triton::ViewOp>(
|
||||
op, op->getResult(0).getType(), view.getResult());
|
||||
return mlir::success();
|
||||
}
|
||||
// cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
|
||||
auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
|
||||
if (alloc_tensor) {
|
||||
if (!isSharedEncoding(op->getResult(0))) {
|
||||
return mlir::failure();
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
|
||||
op, op->getResult(0).getType());
|
||||
return mlir::success();
|
||||
}
|
||||
// cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
|
||||
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
|
||||
if (insert_slice) {
|
||||
if (!isSharedEncoding(op->getResult(0))) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
|
||||
// Ensure that the new insert_slice op is placed in the same place as the
|
||||
// old insert_slice op. Otherwise, the new insert_slice op may be placed
|
||||
// after the async_wait op, which is not allowed.
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(insert_slice);
|
||||
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), newType, insert_slice.getDst());
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
|
||||
op, newType, insert_slice.getSrc(), newArg.getResult(),
|
||||
insert_slice.getIndex(), insert_slice.getMask(),
|
||||
insert_slice.getOther(), insert_slice.getCache(),
|
||||
insert_slice.getEvict(), insert_slice.getIsVolatile(),
|
||||
insert_slice.getAxis());
|
||||
return mlir::success();
|
||||
}
|
||||
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
|
||||
auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
|
||||
if (extract_slice) {
|
||||
if (!isSharedEncoding(op->getResult(0))) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto origType =
|
||||
extract_slice.getSource().getType().cast<RankedTensorType>();
|
||||
auto newType = RankedTensorType::get(
|
||||
origType.getShape(), origType.getElementType(),
|
||||
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
|
||||
auto origResType = op->getResult(0).getType().cast<RankedTensorType>();
|
||||
auto resType = RankedTensorType::get(
|
||||
origResType.getShape(), origResType.getElementType(),
|
||||
extract_slice.getType().cast<RankedTensorType>().getEncoding());
|
||||
// Ensure that the new extract_slice op is placed in the same place as the
|
||||
// old extract_slice op. Otherwise, the new extract_slice op may be placed
|
||||
// after the async_wait op, which is not allowed.
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(extract_slice);
|
||||
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), newType, extract_slice.getSource());
|
||||
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
|
||||
op, resType, newArg.getResult(), extract_slice.offsets(),
|
||||
extract_slice.sizes(), extract_slice.strides(),
|
||||
extract_slice.static_offsets(), extract_slice.static_sizes(),
|
||||
extract_slice.static_strides());
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// cvt(cvt(x, type1), type2) -> cvt(x, type2)
|
||||
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
|
||||
if (arg->getOperand(0).getDefiningOp() &&
|
||||
!isSharedEncoding(arg->getOperand(0)) &&
|
||||
isSharedEncoding(convert.getOperand()) &&
|
||||
!isSharedEncoding(convert.getResult())) {
|
||||
return mlir::failure();
|
||||
}
|
||||
if (isSharedEncoding(convert.getOperand()) &&
|
||||
isSharedEncoding(convert.getResult())) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
|
||||
auto srcShared =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
if (srcShared && srcShared.getVec() > 1)
|
||||
return mlir::failure();
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
|
||||
op, op->getResultTypes().front(), arg->getOperand(0));
|
||||
return mlir::success();
|
||||
}
|
||||
// cvt(type1, splat(type2, x)) -> splat(type1, x)
|
||||
if (auto splat = llvm::dyn_cast<triton::SplatOp>(arg)) {
|
||||
rewriter.replaceOpWithNewOp<triton::SplatOp>(op, op->getResultTypes(),
|
||||
splat.getSrc());
|
||||
return mlir::success();
|
||||
}
|
||||
// cvt(type1, make_range(type2, x)) -> make_range(type1, x)
|
||||
if (auto range = llvm::dyn_cast<triton::MakeRangeOp>(arg)) {
|
||||
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
|
||||
op, op->getResultTypes(), range.getStart(), range.getEnd());
|
||||
return mlir::success();
|
||||
}
|
||||
// cvt(type, constant) -> constant
|
||||
if (auto cst = llvm::dyn_cast<arith::ConstantOp>(arg))
|
||||
if (auto ret = cst.getValue().dyn_cast<SplatElementsAttr>()) {
|
||||
auto newRet = SplatElementsAttr::get(op->getResultTypes().front(),
|
||||
ret.getSplatValue<Attribute>());
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newRet);
|
||||
return mlir::success();
|
||||
}
|
||||
return mlir::failure();
|
||||
return ConvertLayoutOp::canonicalize(convert, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -568,9 +442,9 @@ public:
|
||||
};
|
||||
|
||||
//
|
||||
class FoldConvertAndReduce : public mlir::RewritePattern {
|
||||
class RematerializeForward : public mlir::RewritePattern {
|
||||
public:
|
||||
explicit FoldConvertAndReduce(mlir::MLIRContext *context)
|
||||
explicit RematerializeForward(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
@@ -837,390 +711,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
class RematerializeForward : public mlir::RewritePattern {
|
||||
public:
|
||||
explicit RematerializeForward(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
2, context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *_cvtOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cvt = cast<triton::gpu::ConvertLayoutOp>(_cvtOp);
|
||||
auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
|
||||
if (!forOp)
|
||||
return mlir::failure();
|
||||
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
|
||||
|
||||
SetVector<Operation *> cvtSlices;
|
||||
auto filter = [&](Operation *op) {
|
||||
return isInLoop(op) &&
|
||||
!isa<triton::LoadOp, triton::StoreOp, triton::AtomicRMWOp,
|
||||
triton::AtomicCASOp>(op) &&
|
||||
!isa<triton::DotOp>(op) && !isa<scf::YieldOp>(op) &&
|
||||
!isa<triton::gpu::ConvertLayoutOp>(op);
|
||||
};
|
||||
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
|
||||
if (cvtSlices.empty())
|
||||
return failure();
|
||||
|
||||
for (Operation *op : cvtSlices) {
|
||||
if (!isa<triton::ViewOp, triton::CatOp>(op) &&
|
||||
!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
|
||||
!op->hasTrait<mlir::OpTrait::Elementwise>() &&
|
||||
!isa<triton::StoreOp>(op))
|
||||
return failure();
|
||||
for (Value arg : op->getOperands()) {
|
||||
Operation *argOp = arg.getDefiningOp();
|
||||
if (argOp && (argOp != cvt) &&
|
||||
!isa<arith::ConstantOp, triton::SplatOp, triton::MakeRangeOp>(
|
||||
argOp)) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, we push the conversion forward
|
||||
// since we'll be able to move it out of
|
||||
// the loop once it reaches the yield op
|
||||
pushConversionForward(cvt, cvtSlices, rewriter);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
namespace {
|
||||
// Maps a compute capability number to the major MMA version used when building
// MMA encodings:
//   capability <  70        -> 0
//   70 <= capability <  80  -> 1
//   80 <= capability < 100  -> 2
// Capabilities >= 100 are not supported: an assertion fires, and when
// assertions are compiled out the function returns 3.
int computeCapabilityToMMAVersion(int computeCapability) {
  if (computeCapability < 70)
    return 0;
  if (computeCapability < 80)
    return 1;
  // FIXME: the [90, 100) range temporarily maps to version 2 as well so that
  // the unit tests pass.
  if (computeCapability < 100)
    return 2;
  // The rejected range starts at 100 itself, so the message says ">=".
  assert(false && "computeCapability >= 100 not supported");
  return 3;
}
|
||||
|
||||
// Returns the two-element per-warp shape for an MMA version: {16, 16} for
// version 1 and {16, 8} for version 2. Any other version trips an assertion
// (and yields {0, 0} when assertions are disabled).
SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
  switch (version) {
  case 1:
    return {16, 16};
  case 2:
    return {16, 8};
  default:
    assert(false && "version not supported");
    return {0, 0};
  }
}
|
||||
|
||||
// Warps-per-tile for MMA version 1. `shape` is not consulted: the result is
// the fixed default {numWarps, 1}, whose product equals numWarps.
SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
                                        int numWarps) {
  SmallVector<unsigned, 2> warpsPerTile = {static_cast<unsigned>(numWarps), 1};
  return warpsPerTile;
}
|
||||
|
||||
// Warps-per-tile for MMA version 2.
//
// If the forward slice of the dot result contains another dot op, the result
// is {numWarps, 1}. Otherwise the function starts from {1, 1} and keeps
// doubling one of the two entries until their product reaches numWarps.
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
                                        const ArrayRef<int64_t> shape,
                                        int numWarps) {
  SetVector<Operation *> forwardSlice;
  mlir::getForwardSlice(dotOp.getResult(), &forwardSlice);
  for (Operation *sliceOp : forwardSlice) {
    if (isa<triton::DotOp>(sliceOp))
      return {static_cast<unsigned>(numWarps), 1};
  }

  // Per-warp tile extents used by the heuristic below.
  const int64_t rowsPerWarp = 16;
  const int64_t colsPerWarp = 8;

  SmallVector<unsigned, 2> warps = {1, 1};
  // TODO (@daadaada): double-check.
  // original logic in
  // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
  // seems buggy for shape = [32, 16] ?
  while (warps[0] * warps[1] < numWarps) {
    const int64_t rowTiles = shape[0] / rowsPerWarp;
    const bool rowsDominate =
        rowTiles / warps[0] >= shape[1] / (colsPerWarp * 2) / warps[1];
    // Double dim 0 only while it dominates and still has row tiles left;
    // in every other case double dim 1.
    if (rowsDominate && warps[0] < rowTiles)
      warps[0] *= 2;
    else
      warps[1] *= 2;
  }
  return warps;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class OptimizeBlockedToShared : public mlir::RewritePattern {
|
||||
public:
|
||||
explicit OptimizeBlockedToShared(mlir::MLIRContext *context)
|
||||
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
|
||||
auto srcBlockedLayout =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
||||
auto dstSharedLayout =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
if (!srcBlockedLayout || !dstSharedLayout)
|
||||
return failure();
|
||||
if (srcBlockedLayout.getOrder() == dstSharedLayout.getOrder())
|
||||
return failure();
|
||||
// For now only works if single use is transpose
|
||||
// TODO: rematerialize #shared uses
|
||||
auto users = op->getUsers();
|
||||
if (std::distance(users.begin(), users.end()) != 1 ||
|
||||
!isa<triton::TransOp>(*users.begin()))
|
||||
return failure();
|
||||
|
||||
auto tmpShared = triton::gpu::SharedEncodingAttr::get(
|
||||
op->getContext(), dstSharedLayout.getVec(),
|
||||
dstSharedLayout.getPerPhase(), dstSharedLayout.getMaxPhase(),
|
||||
srcBlockedLayout.getOrder());
|
||||
auto tmpType = RankedTensorType::get(srcType.getShape(),
|
||||
srcType.getElementType(), tmpShared);
|
||||
auto tmpCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), tmpType, cvt.getOperand());
|
||||
|
||||
auto newDstType = RankedTensorType::get(
|
||||
users.begin()->getResultTypes()[0].cast<RankedTensorType>().getShape(),
|
||||
srcType.getElementType(), dstSharedLayout);
|
||||
|
||||
auto newTrans = rewriter.create<triton::TransOp>(op->getLoc(), newDstType,
|
||||
tmpCvt.getResult());
|
||||
|
||||
rewriter.replaceOp(*users.begin(), newTrans.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
|
||||
public:
|
||||
explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
|
||||
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
|
||||
// order
|
||||
ArrayRef<unsigned> order;
|
||||
if (auto srcBlockedLayout =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
|
||||
order = srcBlockedLayout.getOrder();
|
||||
else if (auto srcSharedLayout =
|
||||
srcType.getEncoding()
|
||||
.dyn_cast<triton::gpu::SharedEncodingAttr>())
|
||||
order = srcSharedLayout.getOrder();
|
||||
else
|
||||
return failure();
|
||||
// dot operand output
|
||||
auto dstDotOperandLayout =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if (!dstDotOperandLayout)
|
||||
return failure();
|
||||
if (!dstDotOperandLayout.getIsMMAv1Row())
|
||||
return failure();
|
||||
bool isMMAv1Row =
|
||||
dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
|
||||
return failure();
|
||||
|
||||
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
|
||||
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
op->getContext(), dstDotOperandLayout.getOpIdx(),
|
||||
dstDotOperandLayout.getParent(), newIsRow);
|
||||
auto newDstType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(), newDstEncoding);
|
||||
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), newDstType, cvt.getOperand());
|
||||
rewriter.replaceOp(op, newCvt.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class BlockedToMMA : public mlir::RewritePattern {
|
||||
int computeCapability;
|
||||
mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding
|
||||
|
||||
public:
|
||||
BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
|
||||
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
|
||||
computeCapability(computeCapability) {}
|
||||
|
||||
static SmallVector<unsigned, 2> getWarpsPerTile(triton::DotOp dotOp,
|
||||
const ArrayRef<int64_t> shape,
|
||||
int version, int numWarps) {
|
||||
switch (version) {
|
||||
case 1:
|
||||
return warpsPerTileV1(shape, numWarps);
|
||||
case 2:
|
||||
return warpsPerTileV2(dotOp, shape, numWarps);
|
||||
default:
|
||||
assert(false && "not supported version");
|
||||
return {0, 0};
|
||||
}
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto dotOp = cast<triton::DotOp>(op);
|
||||
// TODO: Check data-types and SM compatibility
|
||||
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
|
||||
if (!oldRetType.getEncoding() ||
|
||||
oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
|
||||
return failure();
|
||||
|
||||
// for FMA, should retain the blocked layout.
|
||||
int versionMajor = computeCapabilityToMMAVersion(computeCapability);
|
||||
if (!supportMMA(dotOp, versionMajor))
|
||||
return failure();
|
||||
|
||||
// get MMA encoding for the given number of warps
|
||||
auto retShape = oldRetType.getShape();
|
||||
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
|
||||
auto warpsPerTile =
|
||||
getWarpsPerTile(dotOp, retShape, versionMajor, numWarps);
|
||||
triton::gpu::MmaEncodingAttr mmaEnc;
|
||||
if (versionMajor == 1) {
|
||||
mmaEnc = triton::gpu::MmaEncodingAttr::get(
|
||||
oldRetType.getContext(), versionMajor, numWarps, mmaV1Counter++);
|
||||
} else if (versionMajor == 2) {
|
||||
mmaEnc = triton::gpu::MmaEncodingAttr::get(
|
||||
oldRetType.getContext(), versionMajor, 0 /*versionMinor*/,
|
||||
warpsPerTile);
|
||||
} else {
|
||||
assert(false && "Mma layout only support versionMajor of 1 or 2");
|
||||
}
|
||||
auto newRetType =
|
||||
RankedTensorType::get(retShape, oldRetType.getElementType(), mmaEnc);
|
||||
|
||||
// convert accumulator
|
||||
auto oldAcc = dotOp.getOperand(2);
|
||||
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
oldAcc.getLoc(), newRetType, oldAcc);
|
||||
Value a = dotOp.getA();
|
||||
Value b = dotOp.getB();
|
||||
auto oldAType = a.getType().cast<RankedTensorType>();
|
||||
auto oldBType = b.getType().cast<RankedTensorType>();
|
||||
auto oldAOrder = oldAType.getEncoding()
|
||||
.cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.getParent()
|
||||
.cast<triton::gpu::BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
auto oldBOrder = oldBType.getEncoding()
|
||||
.cast<triton::gpu::DotOperandEncodingAttr>()
|
||||
.getParent()
|
||||
.cast<triton::gpu::BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
Attribute isMMAv1RowA;
|
||||
Attribute isMMAv1RowB;
|
||||
if (versionMajor == 1) {
|
||||
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
|
||||
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
|
||||
}
|
||||
|
||||
auto newAType = RankedTensorType::get(
|
||||
oldAType.getShape(), oldAType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(
|
||||
oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
|
||||
auto newBType = RankedTensorType::get(
|
||||
oldBType.getShape(), oldBType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(
|
||||
oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));
|
||||
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
|
||||
auto newDot = rewriter.create<triton::DotOp>(
|
||||
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32());
|
||||
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
|
||||
op, oldRetType, newDot.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Convert + trans + convert
|
||||
// x = convert_layout distributed -> #shared_x
|
||||
// y = trans x -> #shared_y
|
||||
// z = convert_layout y -> #dot_operand
|
||||
class ConvertTransConvert : public mlir::RewritePattern {
|
||||
|
||||
public:
|
||||
ConvertTransConvert(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto tmpOp =
|
||||
dyn_cast_or_null<triton::TransOp>(dstOp.getSrc().getDefiningOp());
|
||||
if (!tmpOp)
|
||||
return mlir::failure();
|
||||
auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
|
||||
tmpOp.getSrc().getDefiningOp());
|
||||
if (!srcOp)
|
||||
return mlir::failure();
|
||||
auto arg = srcOp.getSrc();
|
||||
auto X = tmpOp.getSrc();
|
||||
// types
|
||||
auto argType = arg.getType().cast<RankedTensorType>();
|
||||
auto XType = X.getType().cast<RankedTensorType>();
|
||||
auto ZType = dstOp.getResult().getType().cast<RankedTensorType>();
|
||||
// encodings
|
||||
auto argEncoding = argType.getEncoding();
|
||||
auto XEncoding =
|
||||
XType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
|
||||
auto ZEncoding =
|
||||
ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if (!ZEncoding)
|
||||
return mlir::failure();
|
||||
// new X encoding
|
||||
auto newXOrder = triton::gpu::getOrder(argEncoding);
|
||||
auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
|
||||
getContext(), ZEncoding, XType.getShape(), newXOrder,
|
||||
XType.getElementType());
|
||||
auto newXType = RankedTensorType::get(XType.getShape(),
|
||||
XType.getElementType(), newXEncoding);
|
||||
if (XEncoding == newXEncoding)
|
||||
return mlir::failure();
|
||||
|
||||
auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(),
|
||||
newXType, arg);
|
||||
auto newY = rewriter.create<triton::TransOp>(tmpOp.getLoc(), newX);
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType,
|
||||
newY);
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
class ConvertDotConvert : public mlir::RewritePattern {
|
||||
public:
|
||||
@@ -1272,31 +762,25 @@ public:
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
class TritonGPUCombineOpsPass
|
||||
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
|
||||
class TritonGPURemoveLayoutConversionsPass
|
||||
: public TritonGPURemoveLayoutConversionsBase<
|
||||
TritonGPURemoveLayoutConversionsPass> {
|
||||
public:
|
||||
TritonGPUCombineOpsPass() = default;
|
||||
TritonGPUCombineOpsPass(int computeCapability) {
|
||||
this->computeCapability = computeCapability;
|
||||
}
|
||||
TritonGPURemoveLayoutConversionsPass() = default;
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp m = getOperation();
|
||||
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
|
||||
patterns.add<OptimizeBlockedToShared>(context);
|
||||
patterns.add<OptimizeConvertToDotOperand>(context);
|
||||
patterns.add<SimplifyConversion>(context);
|
||||
patterns.add<SimplifyReduceCvt>(context);
|
||||
patterns.add<FoldConvertAndReduce>(context);
|
||||
patterns.add<DecomposeDotOperand>(context);
|
||||
patterns.add<RematerializeBackward>(context);
|
||||
patterns.add<RematerializeForward>(context);
|
||||
patterns.add<MoveConvertOutOfLoop>(context);
|
||||
patterns.add<MoveConvertOutOfIf>(context);
|
||||
patterns.add<BlockedToMMA>(context, computeCapability);
|
||||
patterns.add<ConvertTransConvert>(context);
|
||||
patterns.add<DecomposeDotOperand>(context);
|
||||
patterns.add<ConvertDotConvert>(context);
|
||||
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
|
||||
@@ -1309,7 +793,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
mlir::createTritonGPUCombineOpsPass(int computeCapability) {
|
||||
return std::make_unique<TritonGPUCombineOpsPass>(computeCapability);
|
||||
std::unique_ptr<Pass> mlir::createTritonGPURemoveLayoutConversionsPass() {
|
||||
return std::make_unique<TritonGPURemoveLayoutConversionsPass>();
|
||||
}
|
||||
@@ -1375,7 +1375,7 @@ void init_triton_ir(py::module &&m) {
|
||||
.def(
|
||||
"add_sccp_pass",
|
||||
[](mlir::PassManager &self) { self.addPass(mlir::createSCCPPass()); })
|
||||
.def("add_coalesce_pass",
|
||||
.def("add_tritongpu_coalesce_pass",
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createTritonGPUCoalescePass());
|
||||
})
|
||||
@@ -1414,10 +1414,18 @@ void init_triton_ir(py::module &&m) {
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createTritonGPUPrefetchPass());
|
||||
})
|
||||
.def("add_tritongpu_combine_pass",
|
||||
.def("add_tritongpu_accelerate_matmul_pass",
|
||||
[](mlir::PassManager &self, int computeCapability) {
|
||||
self.addPass(
|
||||
mlir::createTritonGPUCombineOpsPass(computeCapability));
|
||||
mlir::createTritonGPUAccelerateMatmulPass(computeCapability));
|
||||
})
|
||||
.def("add_tritongpu_fuse_transpositions_pass",
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createTritonGPUFuseTranspositionsPass());
|
||||
})
|
||||
.def("add_tritongpu_remove_layout_conversions_pass",
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createTritonGPURemoveLayoutConversionsPass());
|
||||
})
|
||||
.def("add_tritongpu_update_mma_for_volta_pass",
|
||||
[](mlir::PassManager &self) {
|
||||
|
||||
@@ -975,32 +975,32 @@ def ast_to_ttir(fn, signature, specialization, constants):
|
||||
return optimize_triton_ir(mod)
|
||||
|
||||
|
||||
def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
|
||||
def ttir_to_ttgir(mod, num_warps):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps)
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def optimize_ttgir(mod, num_stages, compute_capability):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_coalesce_pass()
|
||||
# The combine pass converts blocked layout to mma layout
|
||||
# for dot ops so that pipeline can get shared memory swizzled correctly.
|
||||
pm.add_tritongpu_combine_pass(compute_capability)
|
||||
pm.add_tritongpu_coalesce_pass()
|
||||
pm.add_tritongpu_accelerate_matmul_pass(compute_capability)
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
pm.add_tritongpu_fuse_transpositions_pass()
|
||||
pm.add_tritongpu_pipeline_pass(num_stages)
|
||||
# Prefetch must be done after pipeline pass because pipeline pass
|
||||
# extracts slices from the original tensor.
|
||||
pm.add_tritongpu_prefetch_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_tritongpu_combine_pass(compute_capability)
|
||||
pm.add_licm_pass()
|
||||
pm.add_tritongpu_combine_pass(compute_capability)
|
||||
pm.add_cse_pass()
|
||||
pm.add_tritongpu_fuse_transpositions_pass()
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
pm.add_tritongpu_decompose_conversions_pass()
|
||||
if compute_capability // 10 == 7:
|
||||
# The update_mma_for_volta pass helps to compute some information for MMA encoding specifically for MMAv1
|
||||
# NOTE: this pass should be placed after all the passes that modify the mma layout
|
||||
pm.add_tritongpu_update_mma_for_volta_pass()
|
||||
pm.add_tritongpu_reorder_instructions_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_symbol_dce_pass()
|
||||
pm.add_tritongpu_reorder_instructions_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
@@ -1565,7 +1565,7 @@ def compile(fn, **kwargs):
|
||||
"ttir": (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
|
||||
"ttgir": (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
|
||||
"llir": (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, capability)),
|
||||
"ptx": (lambda path: Path(path).read_text(),
|
||||
|
||||
@@ -42,7 +42,8 @@ if __name__ == '__main__':
|
||||
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
|
||||
|
||||
# triton-ir -> triton-gpu-ir
|
||||
module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3, compute_capability=args.sm)
|
||||
module = triton.compiler.ttir_to_ttgir(module, num_warps=4)
|
||||
module = triton.compiler.optimize_ttgir(module, num_stages=3, compute_capability=args.sm)
|
||||
if args.target == 'triton-gpu-ir':
|
||||
print(module.str())
|
||||
sys.exit(0)
|
||||
|
||||
@@ -223,6 +223,7 @@ class _attention(torch.autograd.Function):
|
||||
BLOCK_DMODEL=Lk, num_warps=num_warps,
|
||||
num_stages=2,
|
||||
)
|
||||
# print(h.asm["ttgir"])
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
ctx.grid = grid
|
||||
@@ -260,6 +261,7 @@ class _attention(torch.autograd.Function):
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
||||
num_stages=1,
|
||||
)
|
||||
# print(h.asm["ttgir"])
|
||||
return dq, dk, dv, None
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce -canonicalize | FileCheck %s
|
||||
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce | FileCheck %s
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: triton-opt %s -split-input-file -tritongpu-combine 2>&1 | FileCheck %s
|
||||
// RUN: triton-opt %s -split-input-file -tritongpu-remove-layout-conversions 2>&1 | FileCheck %s
|
||||
|
||||
#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu -tritongpu-combine -tritongpu-pipeline=num-stages=3 -tritongpu-combine -test-print-allocation 2>&1 | FileCheck %s
|
||||
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu -tritongpu-remove-layout-conversions -tritongpu-pipeline=num-stages=3 -test-print-allocation 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: offset = 0, size = 49152
|
||||
// CHECK: offset = 49152, size = 49152
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: triton-opt %s -split-input-file -tritongpu-combine -tritongpu-update-mma-for-volta 2>&1 | FileCheck %s
|
||||
// RUN: triton-opt %s -split-input-file -tritongpu-fuse-transposition -tritongpu-update-mma-for-volta 2>&1 | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
Reference in New Issue
Block a user