mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge commit '36fc54b6f28168d3644808bfe299f1ba06a36272' into ifu230908-2
Conflicts: .gitignore bin/triton-translate.cpp include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td lib/Analysis/Utility.cpp lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp lib/Conversion/TritonGPUToLLVM/Utility.h lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp lib/Dialect/TritonGPU/IR/Dialect.cpp lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp lib/Target/LLVMIR/LLVMIRTranslation.cpp python/src/triton.cc python/test/unit/runtime/test_subproc.py python/triton/compiler/compiler.py python/triton/compiler/make_launcher.py python/triton/language/semantic.py python/triton/runtime/jit.py python/tutorials/06-fused-attention.py test/Conversion/triton_to_tritongpu.mlir test/Conversion/tritongpu_to_llvm.mlir test/TritonGPU/coalesce.mlir unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt
This commit is contained in:
@@ -1,10 +1,14 @@
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
#include "triton/Tools/Sys/GetEnv.hpp"
|
||||
#include <deque>
|
||||
|
||||
@@ -37,6 +41,51 @@ bool ReduceOpHelper::isFastReduction() {
|
||||
getParentOrder(getSrcLayout())[0];
|
||||
}
|
||||
|
||||
// Cases where distributed shared memory is not required in ConvertLayout:
|
||||
// (1) numCTAs == 1
|
||||
// (2) numCTAs > 1 but srcCTALayout == dstCTALayout
|
||||
// TODO: Case with SliceLayout as srcLayout and numCTAs > 1 is to be implemented
|
||||
// in the future
|
||||
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) {
|
||||
unsigned numCTAs = triton::gpu::getNumCTAs(srcLayout);
|
||||
assert(numCTAs == triton::gpu::getNumCTAs(dstLayout) &&
|
||||
"Invalid layout conversion: the numbers of CTAs of src and dst "
|
||||
"layouts are different");
|
||||
|
||||
// Case (1): Never use dsmem when numCTAs == 1
|
||||
if (numCTAs == 1)
|
||||
return false;
|
||||
|
||||
// Case where CTAsPerCGA of srcLayout in the sliced dim is not 1 is not
|
||||
// implemented yet
|
||||
if (auto sliceLayout = srcLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
|
||||
auto dim = sliceLayout.getDim();
|
||||
auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(sliceLayout.getParent());
|
||||
if (CTAsPerCGA[dim] != 1)
|
||||
assert(0 && "Layout conversion to be implemented");
|
||||
}
|
||||
|
||||
// Case where CTAsPerCGA of dstLayout in the sliced dim is not 1 is supported
|
||||
if (auto sliceLayout = dstLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
|
||||
auto dim = sliceLayout.getDim();
|
||||
auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(sliceLayout.getParent());
|
||||
if (CTAsPerCGA[dim] != 1)
|
||||
return true;
|
||||
}
|
||||
|
||||
// The above two branches make sure that it is legal to call getCTALayout of
|
||||
// srcLayout and dstLayout
|
||||
|
||||
// Case (2): Do not use dsmem when srcCTALayout == dstCTALayout
|
||||
auto srcCTALayout = triton::gpu::getCTALayout(srcLayout);
|
||||
auto dstCTALayout = triton::gpu::getCTALayout(dstLayout);
|
||||
if (srcCTALayout == dstCTALayout)
|
||||
return false;
|
||||
|
||||
// Dsmem access is required when srcCTALayout != dstCTALayout
|
||||
return true;
|
||||
}
|
||||
|
||||
unsigned ReduceOpHelper::getInterWarpSize() {
|
||||
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
|
||||
unsigned sizeIntraWarps = getIntraWarpSize();
|
||||
@@ -125,7 +174,7 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() {
|
||||
|
||||
unsigned bytesPerElem = 0;
|
||||
for (const auto &ty : srcElementTypes) {
|
||||
bytesPerElem += ty.getIntOrFloatBitWidth() / 8;
|
||||
bytesPerElem += ceil<unsigned>(ty.getIntOrFloatBitWidth(), 8);
|
||||
}
|
||||
return bytesPerElem * elems;
|
||||
}
|
||||
@@ -136,7 +185,7 @@ bool ReduceOpHelper::isSupportedLayout() {
|
||||
return true;
|
||||
}
|
||||
if (auto mmaLayout = srcLayout.dyn_cast<triton::gpu::MmaEncodingAttr>()) {
|
||||
if (mmaLayout.isAmpere()) {
|
||||
if (mmaLayout.isAmpere() || mmaLayout.isHopper()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -286,6 +335,8 @@ bool maybeSharedAllocationOp(Operation *op) {
|
||||
return dialect &&
|
||||
(dialect->getTypeID() ==
|
||||
mlir::TypeID::get<triton::gpu::TritonGPUDialect>() ||
|
||||
dialect->getTypeID() ==
|
||||
mlir::TypeID::get<triton::nvidia_gpu::TritonNvidiaGPUDialect>() ||
|
||||
dialect->getTypeID() == mlir::TypeID::get<triton::TritonDialect>() ||
|
||||
dialect->getTypeID() == mlir::TypeID::get<arith::ArithDialect>() ||
|
||||
dialect->getTypeID() == mlir::TypeID::get<tensor::TensorDialect>());
|
||||
@@ -294,6 +345,8 @@ bool maybeSharedAllocationOp(Operation *op) {
|
||||
// Returns true when `op` is one of the op kinds enumerated below:
// ExtractSlice, Trans, InsertSliceAsync, InsertSliceAsyncV2, StoreAsync or
// tensor::InsertSlice. What "may alias" implies is decided by the callers.
bool maybeAliasOp(Operation *op) {
  return isa<triton::gpu::ExtractSliceOp, triton::TransOp,
             triton::gpu::InsertSliceAsyncOp,
             triton::nvidia_gpu::InsertSliceAsyncV2Op,
             triton::nvidia_gpu::StoreAsyncOp, tensor::InsertSliceOp>(op);
}
|
||||
|
||||
@@ -303,7 +356,25 @@ bool supportMMA(triton::DotOp op, int version) {
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
|
||||
auto aElemTy = op.getA().getType().cast<RankedTensorType>().getElementType();
|
||||
auto bElemTy = op.getB().getType().cast<RankedTensorType>().getElementType();
|
||||
<<<<<<< HEAD
|
||||
|
||||
=======
|
||||
if (version == 3) {
|
||||
if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3"))
|
||||
return false;
|
||||
auto retType = op.getResult().getType().cast<RankedTensorType>();
|
||||
auto retShapePerCTA = triton::gpu::getShapePerCTA(retType);
|
||||
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
if (!(numWarps % 4 == 0 && retShapePerCTA[0] % 64 == 0 &&
|
||||
retShapePerCTA[1] % 8 == 0 &&
|
||||
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() ||
|
||||
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
|
||||
aElemTy.isF32()))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
if (aElemTy.isF32() && bElemTy.isF32()) {
|
||||
return (op.getAllowTF32() && version == 2) || version == 3;
|
||||
}
|
||||
@@ -345,25 +416,22 @@ bool supportMFMA(triton::DotOp op) {
|
||||
#endif
|
||||
|
||||
// Tell whether a DotOp support MMA by the operand type(either $a or $b).
// We cannot get both the operand types(in TypeConverter), here we assume the
// types of both the operands are identical here.
//
// `value` must be a ranked tensor (the type is cast, not checked) and
// `version` must be 1, 2 or 3 (asserted).
// Returns true for FP8, F16 and BF16 element types on every version, and for
// F32 and 8-bit integers only when version >= 2.
bool supportMMA(Value value, int version) {
  assert((version == 1 || version == 2 || version == 3) &&
         "Unexpected MMA layout version found");
  auto elemTy = value.getType().cast<RankedTensorType>().getElementType();
  // FP8 is not natively supported on all mma versions but it can always be
  // promoted to fp16 therefore we can always support it.
  bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
               elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
  return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
         (elemTy.isF32() && version >= 2) ||
         (elemTy.isInteger(8) && version >= 2);
}
|
||||
|
||||
// Returns the element type when `value` is a ranked tensor; for any other
// value its own type is returned unchanged.
Type getElementType(Value value) {
  Type valueTy = value.getType();
  auto rankedTy = valueTy.dyn_cast<RankedTensorType>();
  return rankedTy ? rankedTy.getElementType() : valueTy;
}
|
||||
|
||||
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
// dot_op<opIdx=0, parent=#mma> = #mma
|
||||
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
|
||||
@@ -378,6 +446,7 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
!srcTy.getElementType().isF32();
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
#ifdef USE_ROCM
|
||||
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
@@ -395,6 +464,18 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
|
||||
}
|
||||
#endif
|
||||
=======
|
||||
// Returns true when both tensor types carry an
// MmaEncoding<version=3, warpsPerCTA=[..., 1]> and hold the same total number
// of elements per thread.
// Both encodings are cast to MmaEncodingAttr unconditionally, so callers must
// only pass MMA-encoded tensor types.
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
  auto srcMma = srcTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
  auto dstMma = dstTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
  auto srcPerThread = triton::gpu::getTotalElemsPerThread(srcTy);
  auto dstPerThread = triton::gpu::getTotalElemsPerThread(dstTy);

  auto isV3WithUnitWarpDim1 = [](triton::gpu::MmaEncodingAttr mma) {
    return mma.getVersionMajor() == 3 && mma.getWarpsPerCTA()[1] == 1;
  };

  if (!isV3WithUnitWarpDim1(srcMma))
    return false;
  if (!isV3WithUnitWarpDim1(dstMma))
    return false;
  return srcPerThread == dstPerThread;
}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
bool isSingleValue(Value value) {
|
||||
// Don't consider load as expensive if it is loading a scalar.
|
||||
@@ -455,9 +536,11 @@ struct DFSState {
|
||||
SmallVector<Operation *, 16> topologicalCounts;
|
||||
DenseSet<Operation *> seen;
|
||||
|
||||
/// We mark each op as ready if all its operands are seen. If an op is ready,
|
||||
/// we add it to the queue. Otherwise, we keep adding its operands to the
|
||||
/// ancestors set.
|
||||
/// We mark each op as ready if all its operands and parent ops are seen. If
|
||||
/// an op is ready, we add it to the queue. Otherwise, we keep adding its
|
||||
/// operands to the ancestors set.
|
||||
/// We always want an op to be scheduled after all its parents to correctly
|
||||
/// handle cases with scf operations.
|
||||
void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph,
|
||||
SmallVector<Operation *, 4> &readyQueue) {
|
||||
bool ready = true;
|
||||
@@ -468,6 +551,14 @@ struct DFSState {
|
||||
ready = false;
|
||||
}
|
||||
}
|
||||
Operation *parent = op->getParentOp();
|
||||
while (parent) {
|
||||
if (!seen.count(parent)) {
|
||||
subGraph.push_back(parent);
|
||||
ready = false;
|
||||
}
|
||||
parent = parent->getParentOp();
|
||||
}
|
||||
if (ready)
|
||||
readyQueue.push_back(op);
|
||||
}
|
||||
@@ -615,4 +706,81 @@ std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
|
||||
return solver;
|
||||
}
|
||||
|
||||
// Resolves the triton::MakeTensorPtrOp behind `v`, where `op` is the operation
// that defines `v`.
//
//  - MakeTensorPtrOp: returned directly.
//  - AdvanceOp: the search continues from the pointer being advanced.
//  - RegionBranchOpInterface ops: the search continues from the operand of
//    the op's scf.yield that corresponds to the result number of `v`.
// Any other defining op is treated as unreachable.
static triton::MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) {
  if (auto makeTensorPtrOp = dyn_cast<triton::MakeTensorPtrOp>(op))
    return makeTensorPtrOp;

  if (auto advanceOp = dyn_cast<triton::AdvanceOp>(op))
    return getMakeTensorPtrOp(advanceOp.getPtr());

  if (isa<RegionBranchOpInterface>(op)) {
    auto idx = v.cast<OpResult>().getResultNumber();

    // Keep only the yields that terminate one of `op`'s own regions. The walk
    // also visits ops nested deeper, and a yield belonging to an inner
    // region op would be visited before `op`'s own terminator; its operands
    // do not correspond to `op`'s results.
    llvm::SmallVector<scf::YieldOp> yieldOps;
    op->walk([&](scf::YieldOp yieldOp) {
      if (yieldOp->getParentOp() == op)
        yieldOps.push_back(yieldOp);
    });
    if (yieldOps.empty())
      llvm::report_fatal_error(
          "Unable to getMakeTensorPtr(): region op has no scf.yield");

    // benzh@ if multi yields, all yields operand should come from same arg.
    Value newValue = yieldOps.front().getOperands()[idx];
    return getMakeTensorPtrOp(newValue);
  }

  llvm_unreachable("Unable to getMakeTensorPtr()");
}
|
||||
|
||||
// Traces `v` back to the triton::MakeTensorPtrOp that created it.
//
//  - If `v` is an op result, the defining op is handled by
//    getMakeTensorPtrOpImpl.
//  - If `v` is a block argument of an scf.for body, the search continues from
//    the matching loop operand.
//  - If `v` is an argument of a block that belongs directly to a
//    triton::FuncOp, the search continues from the operand forwarded by the
//    first cf.br / cf.cond_br (in walk order) that targets that block.
//  - For a block argument owned by any other op, the search continues from the
//    owner's operand with the same index.
triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v) {
  if (Operation *definingOp = v.getDefiningOp())
    return getMakeTensorPtrOpImpl(definingOp, v);

  // A value without a defining op is a block argument.
  BlockArgument arg = v.cast<BlockArgument>();
  unsigned argNum = arg.getArgNumber();
  Block *block = arg.getOwner();
  Operation *argOwner = block->getParentOp();

  if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
    // Maps the body argument index onto the loop operand list: the control
    // operands come first, and the offset of -1 accounts for the leading body
    // argument that has no init operand (the induction variable).
    return getMakeTensorPtrOp(
        forOp.getOperand(argNum + forOp.getNumControlOperands() - 1));
  }

  if (isa<mlir::triton::FuncOp>(argOwner)) {
    // Look for the first branch targeting `block`. Successors live in the same
    // region as the branch, so only the owning function is searched, and only
    // when this case is actually reached.
    Operation *branchOp = nullptr;
    bool viaTrueEdge = false;
    argOwner->walk([&](Operation *candidate) {
      if (auto br = dyn_cast<cf::BranchOp>(candidate)) {
        if (br.getDest() == block) {
          branchOp = candidate;
          return WalkResult::interrupt();
        }
      } else if (auto condBr = dyn_cast<cf::CondBranchOp>(candidate)) {
        // The true edge is preferred when both edges target `block`.
        if (condBr.getTrueDest() == block) {
          branchOp = candidate;
          viaTrueEdge = true;
          return WalkResult::interrupt();
        }
        if (condBr.getFalseDest() == block) {
          branchOp = candidate;
          viaTrueEdge = false;
          return WalkResult::interrupt();
        }
      }
      return WalkResult::advance();
    });

    // No predecessor branch (for example, an argument of the function's entry
    // block): there is nothing to trace back through.
    if (!branchOp)
      llvm::report_fatal_error("Unable to getMakeTensorPtr(): block argument "
                               "has no cf branch predecessor");

    if (auto br = dyn_cast<cf::BranchOp>(branchOp))
      return getMakeTensorPtrOp(br.getDestOperands()[argNum]);

    auto condBr = cast<cf::CondBranchOp>(branchOp);
    if (viaTrueEdge)
      return getMakeTensorPtrOp(condBr.getTrueDestOperands()[argNum]);
    return getMakeTensorPtrOp(condBr.getFalseDestOperands()[argNum]);
  }

  return getMakeTensorPtrOp(argOwner->getOperand(argNum));
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
Reference in New Issue
Block a user