Files
ROCm/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
Philippe Tillet 52c146f66b [OPTIMIZER][BACKEND] significantly cleaner handling of mixed-precision kernels (#1949)
We currently have a very janky approach to optimizing mixed-precision
matmul workloads, where some layout combinations (e.g., NT matmul) are
explicitly pattern-matched to take a more optimized codepath. An attempt at
unifying all the codepaths to codegen cp.async failed, due to bugs in
SharedToDotOperandMMAv2.cpp.

This PR fixes said bugs, adds some assertions for SharedToDotOperandMMAv2
modes that aren't well supported, and greatly simplifies our handling of
element-wise operations between loads and conversions to DotOperand.
2023-07-28 10:29:42 -07:00

649 lines
27 KiB
C++

#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Membar.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetPlatform.hpp"
#include "ConvertLayoutOpToLLVM.h"
#include "DotOpToLLVM.h"
#include "ElementwiseOpToLLVM.h"
#include "LoadStoreOpToLLVM.h"
#include "ReduceOpToLLVM.h"
#include "ScanOpToLLVM.h"
#include "TritonGPUToLLVM.h"
#include "TypeConverter.h"
#include "ViewOpToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
namespace {
/// Conversion target used while lowering function-level ops (function
/// definitions, calls, returns). The Index and LLVM dialects are always
/// legal; exactly one GPU backend dialect is legal, selected by `isROCM`
/// (ROCDL for AMD, NVVM for NVIDIA).
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
public:
  explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx, bool isROCM)
      : ConversionTarget(ctx) {
    addLegalDialect<LLVM::LLVMDialect>();
    addLegalDialect<index::IndexDialect>();
    // Pick the backend intrinsic dialect for the current target.
    if (!isROCM)
      addLegalDialect<NVVM::NVVMDialect>();
    else
      addLegalDialect<ROCDL::ROCDLDialect>();
    // Unrealized casts are tolerated here; they are reconciled later.
    addLegalOp<mlir::UnrealizedConversionCastOp>();
  }
};
/// Lowers triton::ReturnOp to LLVM::ReturnOp.
///
/// Kernel entry points (functions carrying the "nvvm.kernel" attribute) must
/// return no operands; device functions return either a single value directly
/// or pack two-or-more results into an LLVM struct, mirroring how function
/// result types are packed during FuncOp conversion.
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
  using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
    if (funcOp->hasAttr("nvvm.kernel")) {
      // A GPU kernel: kernels are void-returning, so reject any operands.
      if (op.getNumOperands() > 0) {
        return rewriter.notifyMatchFailure(
            op, "Kernel functions do not support return with operands");
      }
      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
                                                  op->getAttrs());
    } else {
      // A device function
      LLVM::ReturnOp newOp;
      if (adaptor.getOperands().size() < 2) {
        // Single or no return value: return the (converted) operands as-is.
        newOp =
            rewriter.create<LLVM::ReturnOp>(op.getLoc(), adaptor.getOperands());
      } else {
        // Pack the results into a struct.
        auto packedResultsTy = this->getTypeConverter()->packFunctionResults(
            funcOp.getResultTypes());
        // Start from an undef struct and insert each result in order.
        Value packedResults =
            rewriter.create<LLVM::UndefOp>(op.getLoc(), packedResultsTy);
        // NOTE: `loc` is presumably read implicitly by the insert_val helper
        // macro below — confirm against the project's LLVM utility macros.
        auto loc = op.getLoc();
        for (auto it : llvm::enumerate(adaptor.getOperands())) {
          packedResults = insert_val(packedResultsTy, packedResults, it.value(),
                                     it.index());
        }
        newOp = rewriter.create<LLVM::ReturnOp>(op.getLoc(), packedResults);
      }
      // Preserve any attributes carried by the original return op.
      newOp->setAttrs(op->getAttrs());
      rewriter.replaceOp(op, newOp->getResults());
    }
    return success();
  }
};
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
struct FuncOpConversion : public FuncOpConversionBase {
  FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
                   ModuleAllocation &allocation, PatternBenefit benefit)
      : FuncOpConversionBase(converter, benefit), numWarps(numWarps),
        allocation(allocation) {}

  /// Clones `funcOp` with one extra trailing pointer argument (i8* in address
  /// space 3) that carries the caller's current shared-memory stack pointer.
  /// Only applied to non-root (device) functions; kernels take the shared
  /// memory base from the module-level global instead.
  triton::FuncOp amendFuncOp(triton::FuncOp funcOp,
                             ConversionPatternRewriter &rewriter) const {
    // Push back a variable that indicates the current stack pointer of shared
    // memory to the function arguments.
    auto loc = funcOp.getLoc();
    auto ctx = funcOp->getContext();
    // i8 pointer in address space 3 (shared memory).
    auto ptrTy = LLVM::LLVMPointerType::get(
        this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
    // 1. Modify the function type to add the new argument.
    auto funcTy = funcOp.getFunctionType();
    auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs());
    amendedInputTy.push_back(ptrTy);
    auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy,
                                           funcTy.getResults());
    // 2. Modify the argument attributes to add the new argument.
    SmallVector<NamedAttribute> amendedAttrs;
    filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs);
    auto amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs());
    // The new argument gets an empty attribute dictionary.
    amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx));
    amendedAttrs.push_back(rewriter.getNamedAttr(
        funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs)));
    // 3. Add a new argument to the region
    auto amendedFuncOp = rewriter.create<triton::FuncOp>(
        funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs);
    auto &region = funcOp.getBody();
    region.addArgument(ptrTy, loc);
    // Move the original body (now carrying the extra block argument) into the
    // amended function.
    rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(),
                                amendedFuncOp.end());
    return amendedFuncOp;
  }

  LogicalResult
  matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Device (non-root) functions are amended with a shared-memory pointer
    // argument; root (kernel) functions are converted as-is.
    auto amendedFuncOp = funcOp;
    if (!allocation.isRoot(funcOp))
      amendedFuncOp = amendFuncOp(funcOp, rewriter);
    auto newFuncOp = convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter);
    if (!newFuncOp) {
      return failure();
    }
    auto ctx = funcOp->getContext();
    if (allocation.isRoot(funcOp)) {
      // Set an attribute to indicate this function is a kernel entry.
      newFuncOp->setAttr("nvvm.kernel",
                         rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
    } else {
      // The noinline attribute will be used by the LLVM codegen to prevent
      // inlining.
      // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267
      newFuncOp.setPassthroughAttr(
          ArrayAttr::get(ctx, rewriter.getStringAttr("noinline")));
      // The amended clone has served its purpose; erase it.
      rewriter.eraseOp(amendedFuncOp);
    }
    // Set an attribute for maxntidx, it could be used in latter LLVM codegen
    // for `nvvm.annotation` metadata.
    // NOTE(review): 32 * numWarps assumes 32 threads per warp — confirm this
    // matches the module's threads-per-warp attribute on non-NVIDIA targets.
    newFuncOp->setAttr("nvvm.maxntid", rewriter.getI32ArrayAttr(32 * numWarps));
    // The call graph is updated by mapping the old function to the new one.
    allocation.mapFuncOp(funcOp, newFuncOp);
    rewriter.eraseOp(funcOp);
    return success();
  }

private:
  int numWarps{0};
  ModuleAllocation &allocation;
};
// CallOpInterfaceLowering is adapted from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485
/// Lowers triton::CallOp to LLVM::CallOp. In addition to the usual operand
/// promotion, the callee's shared-memory base pointer (derived from the
/// caller's, plus the call site's buffer offset) is appended as the last
/// operand, matching the extra argument added by FuncOpConversion.
struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
  CallOpConversion(LLVMTypeConverter &converter, int numWarps,
                   ModuleAllocation &allocation, PatternBenefit benefit)
      : ConvertOpToLLVMPattern<triton::CallOp>(converter, benefit),
        numWarps(numWarps), allocation(allocation) {}

  LogicalResult
  matchAndRewrite(triton::CallOp callOp,
                  typename triton::CallOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // 1) promote operands (appending the shared-memory pointer),
    // 2) emit the LLVM call (packing multiple results into a struct),
    // 3) record the new call in the allocation's call graph and unpack
    //    results for the original op's users.
    auto promotedOperands = promoteOperands(callOp, adaptor, rewriter);
    auto newCallOp =
        convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter);
    if (!newCallOp)
      return failure();
    allocation.mapCallOp(callOp, newCallOp);
    auto results = getCallOpResults(callOp, newCallOp, rewriter);
    rewriter.replaceOp(callOp, results);
    return success();
  }

private:
  /// Promotes the call operands for LLVM and appends the shared-memory
  /// pointer the callee expects as its last argument.
  SmallVector<Value, 4>
  promoteOperands(triton::CallOp callOp,
                  typename triton::CallOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const {
    // Get the last argument of the caller, which is the current stack pointer
    // of shared memory and append it to the operands of the callOp.
    // NOTE: `loc` is presumably read implicitly by the gep/i32_val helper
    // macros below — confirm against the project's LLVM utility macros.
    auto loc = callOp.getLoc();
    auto caller = callOp->getParentOfType<FunctionOpInterface>();
    auto ptrTy = LLVM::LLVMPointerType::get(
        this->getTypeConverter()->convertType(rewriter.getI8Type()),
        NVVM::kSharedMemorySpace);
    auto promotedOperands = this->getTypeConverter()->promoteOperands(
        callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
        adaptor.getOperands(), rewriter);
    auto base = allocation.getFunctionSharedMemoryBase(caller);
    auto *funcAllocation = allocation.getFuncData(caller);
    auto bufferId = funcAllocation->getBufferId(callOp);
    // function doesn't have a shared mem buffer: pass the caller's base along.
    if (bufferId == (size_t)-1) {
      promotedOperands.push_back(base);
      return promotedOperands;
    }
    // function has a shared mem buffer: offset the base by the call site's
    // allocated buffer offset.
    auto offset = funcAllocation->getOffset(bufferId);
    auto offsetValue = gep(ptrTy, base, i32_val(offset));
    promotedOperands.push_back(offsetValue);
    return promotedOperands;
  }

  /// Builds the LLVM::CallOp; multiple results are packed into a single
  /// struct result. Returns null if result packing fails.
  LLVM::CallOp
  convertCallOpToLLVMCallOp(triton::CallOp callOp,
                            ArrayRef<Value> promotedOperands,
                            ConversionPatternRewriter &rewriter) const {
    // Pack the result types into a struct.
    Type packedResult = nullptr;
    unsigned numResults = callOp.getNumResults();
    auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
    if (numResults != 0) {
      if (!(packedResult =
                this->getTypeConverter()->packFunctionResults(resultTypes)))
        return nullptr;
    }
    auto newCallOp = rewriter.create<LLVM::CallOp>(
        callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
        promotedOperands, callOp->getAttrs());
    return newCallOp;
  }

  /// Unpacks the (possibly struct-packed) call results back into the list of
  /// values expected by the original op's users.
  SmallVector<Value>
  getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp,
                   ConversionPatternRewriter &rewriter) const {
    auto numResults = callOp.getNumResults();
    SmallVector<Value> results;
    if (numResults < 2) {
      // If < 2 results, packing did not do anything and we can just return.
      results.append(newCallOp.result_begin(), newCallOp.result_end());
    } else {
      // Otherwise, it had been converted to an operation producing a structure.
      // Extract individual results from the structure and return them as list.
      results.reserve(numResults);
      for (unsigned i = 0; i < numResults; ++i) {
        results.push_back(rewriter.create<LLVM::ExtractValueOp>(
            callOp.getLoc(), newCallOp->getResult(0), i));
      }
    }
    return results;
  }

  int numWarps{0};
  ModuleAllocation &allocation;
};
/// Conversion target for the main TritonGPU-to-LLVM lowering: every op from
/// the Triton, TritonGPU, and GPU dialects must be rewritten away, while the
/// LLVM dialect plus one backend intrinsic dialect (ROCDL or NVVM, chosen by
/// `isROCM`) form the legal output.
class TritonLLVMConversionTarget : public ConversionTarget {
public:
  explicit TritonLLVMConversionTarget(MLIRContext &ctx, bool isROCM)
      : ConversionTarget(ctx) {
    // Everything still in these dialects after conversion is an error.
    addIllegalDialect<triton::TritonDialect>();
    addIllegalDialect<triton::gpu::TritonGPUDialect>();
    addIllegalDialect<mlir::gpu::GPUDialect>();
    // The lowered form: LLVM plus the target's intrinsic dialect.
    addLegalDialect<LLVM::LLVMDialect>();
    if (!isROCM)
      addLegalDialect<NVVM::NVVMDialect>();
    else
      addLegalDialect<ROCDL::ROCDLDialect>();
    // Unrealized casts are tolerated here; they are reconciled later.
    addLegalOp<mlir::UnrealizedConversionCastOp>();
  }
};
/// Main TritonGPU-to-LLVM lowering pass.
///
/// runOnOperation proceeds in stages over the module:
///   1. preprocessing rewrites that decompose layout conversions and
///      insert_slice_async ops the codegen cannot handle directly,
///   2. shared-memory allocation and membar (barrier) analysis,
///   3. lowering of function definitions,
///   4. shared-memory global initialization (must precede call lowering so
///      each call site can compute its callee's shared-memory base),
///   5. lowering of call/return ops,
///   6. lowering of all remaining ops via the pattern sets below.
class ConvertTritonGPUToLLVM
    : public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
public:
  explicit ConvertTritonGPUToLLVM(int computeCapability, bool isROCM)
      : computeCapability(computeCapability), isROCM(isROCM) {}

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp mod = getOperation();
    mlir::LowerToLLVMOptions option(context);
    // Use 32-bit indices for the main lowering.
    option.overrideIndexBitwidth(32);
    TritonGPUToLLVMTypeConverter typeConverter(context, option);
    TritonLLVMConversionTarget target(*context, isROCM);
    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
    int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
    // Preprocess
    decomposeFp8e4b15Convert(mod);
    decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp);
    decomposeBlockedToDotOperand(mod);
    decomposeInsertSliceAsyncOp(mod);
    // Allocate shared memory and set barrier
    ModuleAllocation allocation(mod);
    ModuleMembarAnalysis membarPass(&allocation);
    membarPass.run();
    // Lower functions
    {
      // NOTE(review): this scoped converter keeps the default index bitwidth,
      // unlike the outer one which overrides it to 32 — confirm intentional.
      mlir::LowerToLLVMOptions option(context);
      TritonGPUToLLVMTypeConverter typeConverter(context, option);
      TritonLLVMFunctionConversionTarget funcTarget(*context, isROCM);
      RewritePatternSet funcPatterns(context);
      funcPatterns.add<FuncOpConversion>(typeConverter, numWarps, allocation,
                                         /*benefit=*/1);
      mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
                                                            funcPatterns);
      if (failed(
              applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
        return signalPassFailure();
    }
    // initSharedMemory is run before the conversion of call and ret ops,
    // because the call op has to know the shared memory base address of each
    // function
    initSharedMemory(allocation, typeConverter);
    // Convert call and ret ops
    {
      mlir::LowerToLLVMOptions option(context);
      TritonGPUToLLVMTypeConverter typeConverter(context, option);
      TritonLLVMFunctionConversionTarget funcTarget(*context, isROCM);
      RewritePatternSet funcPatterns(context);
      funcPatterns.add<CallOpConversion>(typeConverter, numWarps, allocation,
                                         /*benefit=*/1);
      funcPatterns.add<ReturnOpConversion>(typeConverter, /*benefit=*/1);
      if (failed(
              applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
        return signalPassFailure();
    }
    ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
    // Rewrite ops
    RewritePatternSet patterns(context);
    // TritonGPU lowering patterns
    OpBuilder::InsertPoint indexInsertPoint;
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
        &baseIndexCache, &indexCache, &indexInsertPoint};
    // TODO: enable index cache if there are multiple functions
    // (currently disabled in that case by nulling out the cache pointers).
    if (axisInfoAnalysis.getNumFunctions() > 1) {
      indexCacheInfo = {nullptr, nullptr, nullptr};
    }
    populateTritonGPUToLLVMPatterns(typeConverter, patterns, allocation,
                                    indexCacheInfo, /*benefit=*/1);
    populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, allocation,
                                          indexCacheInfo, /*benefit=*/1);
    populateDotOpToLLVMPatterns(typeConverter, patterns, allocation,
                                /*benefit=*/1);
    populateElementwiseOpToLLVMPatterns(typeConverter, patterns, /*benefit=*/1);
    populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis,
                                      allocation, indexCacheInfo,
                                      /*benefit=*/1);
    populateReduceOpToLLVMPatterns(typeConverter, patterns, allocation,
                                   indexCacheInfo, /*benefit=*/1);
    populateScanOpToLLVMPatterns(typeConverter, patterns, allocation,
                                 indexCacheInfo, /*benefit=*/1);
    populateViewOpToLLVMPatterns(typeConverter, patterns, /*benefit=*/1);
    // Native lowering patterns
    if (isROCM) {
      mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns,
                                                 mlir::gpu::amd::HIP);
    } else {
      mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
    }
    mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
                                                          patterns);
    if (failed(applyPartialConversion(mod, target, std::move(patterns))))
      return signalPassFailure();
  }

private:
  // Caches for per-(layout attribute, tensor type) index computations shared
  // by the lowering patterns above (only enabled for single-function modules).
  using IndexCacheKeyT = std::pair<Attribute, RankedTensorType>;
  DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
      baseIndexCache;
  DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
           CacheKeyDenseMapInfo>
      indexCache;

  int computeCapability{};
  bool isROCM{};

  /// Declares the zero-sized external `global_smem` i8 array in the shared
  /// memory address space, then gives every function a bitcast i8* to its
  /// shared-memory base: kernels take the global's address, device functions
  /// take their last argument (appended by FuncOpConversion). Also records
  /// the module's total shared-memory requirement as `triton_gpu.shared`.
  void initSharedMemory(ModuleAllocation &allocation,
                        TritonGPUToLLVMTypeConverter &typeConverter) {
    ModuleOp mod = getOperation();
    OpBuilder b(mod.getBodyRegion());
    auto ctx = mod.getContext();
    auto loc = mod.getLoc();
    auto elemTy = typeConverter.convertType(b.getIntegerType(8));
    // Set array size 0 and external linkage indicates that we use dynamic
    // shared allocation to allow a larger shared memory size for each kernel.
    auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0);
    auto global = b.create<LLVM::GlobalOp>(
        loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External,
        "global_smem", /*value=*/Attribute(), /*alignment=*/0,
        // Add ROCm support.
        static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace));
    mod.walk([&](FunctionOpInterface funcOp) {
      Value funcSmem;
      b.setInsertionPointToStart(&funcOp.getFunctionBody().front());
      if (allocation.isRoot(funcOp)) {
        // Kernel: address shared memory from the module-level global.
        funcSmem = b.create<LLVM::AddressOfOp>(loc, global);
      } else {
        // Device function: the shared-memory base is the last argument.
        funcSmem = funcOp.getArgument(funcOp.getNumArguments() - 1);
      }
      auto ptrTy =
          LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()),
                                     NVVM::NVVMMemorySpace::kSharedMemorySpace);
      funcSmem = b.create<LLVM::BitcastOp>(loc, ptrTy, funcSmem);
      allocation.setFunctionSharedMemoryValue(funcOp, funcSmem);
    });
    // Record the total shared-memory size (bytes, i32) on the module.
    mod->setAttr("triton_gpu.shared",
                 mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
                                        allocation.getSharedMemorySize()));
  }

  /// Rewrites fp8e4b15 layout conversions as fp8 -> fp16, a layout
  /// conversion in fp16, then fp16 -> fp8. Dot-operand encodings on either
  /// side are skipped (handled elsewhere).
  void decomposeFp8e4b15Convert(ModuleOp mod) const {
    mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
      OpBuilder builder(cvtOp);
      if (!getElementTypeOrSelf(cvtOp).isa<mlir::Float8E4M3B11FNUZType>())
        return;
      auto shape = cvtOp.getType().cast<RankedTensorType>().getShape();
      auto argEncoding =
          cvtOp.getOperand().getType().cast<RankedTensorType>().getEncoding();
      auto cvtEncoding = cvtOp.getType().cast<RankedTensorType>().getEncoding();
      // Skip conversions touching dot-operand layouts.
      if (argEncoding.isa<triton::gpu::DotOperandEncodingAttr>() ||
          cvtEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
        return;
      auto F16Ty = builder.getF16Type();
      auto newArgType = RankedTensorType::get(shape, F16Ty, argEncoding);
      auto newCvtType = RankedTensorType::get(shape, F16Ty, cvtEncoding);
      // fp8 -> fp16.
      auto newArg = builder.create<mlir::triton::FpToFpOp>(
          cvtOp.getLoc(), newArgType, cvtOp.getOperand());
      // Layout conversion in fp16.
      auto newCvt = builder.create<mlir::triton::gpu::ConvertLayoutOp>(
          cvtOp.getLoc(), newCvtType, newArg);
      // fp16 -> fp8 (back to the original result type).
      auto newRet = builder.create<mlir::triton::FpToFpOp>(
          cvtOp.getLoc(), cvtOp.getType(), newCvt.getResult());
      cvtOp.replaceAllUsesWith(newRet.getResult());
      cvtOp.erase();
    });
  }

  /// See the inline comment: routes mma -> dot_op conversions through an
  /// intermediate blocked layout when the direct shortcut is unsupported.
  void decomposeMmaToDotOperand(ModuleOp mod, int numWarps,
                                int threadsPerWarp) const {
    // Replace `mma -> dot_op` with `mma -> blocked -> dot_op`
    // unless certain conditions are met
    mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
      OpBuilder builder(cvtOp);
      auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
      auto dstType = cvtOp.getType().cast<RankedTensorType>();
      auto srcMma =
          srcType.getEncoding().dyn_cast<triton::gpu::MmaEncodingAttr>();
      auto dstDotOp =
          dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
      if (srcMma && dstDotOp && !isMmaToDotShortcut(srcType, dstType)) {
        // Intermediate blocked layout derived from the mma source layout.
        auto tmpType = RankedTensorType::get(
            dstType.getShape(), dstType.getElementType(),
            triton::gpu::BlockedEncodingAttr::get(
                mod.getContext(), srcType.getShape(), getSizePerThread(srcMma),
                getOrder(srcMma), numWarps, threadsPerWarp));
        auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
            cvtOp.getLoc(), tmpType, cvtOp.getOperand());
        auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
            cvtOp.getLoc(), dstType, tmp);
        cvtOp.replaceAllUsesWith(newConvert.getResult());
        cvtOp.erase();
      }
    });
  }

  /// See the inline comment: routes blocked -> dot_op conversions through an
  /// intermediate shared layout, which the codegen does support.
  void decomposeBlockedToDotOperand(ModuleOp mod) const {
    // Replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
    // because the codegen doesn't handle `blocked -> dot_op` directly
    mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
      OpBuilder builder(cvtOp);
      auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
      auto dstType = cvtOp.getType().cast<RankedTensorType>();
      auto srcBlocked =
          srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
      auto dstDotOp =
          dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
      if (srcBlocked && dstDotOp) {
        // Intermediate shared layout compatible with the dot operand.
        auto tmpType = RankedTensorType::get(
            dstType.getShape(), dstType.getElementType(),
            triton::gpu::SharedEncodingAttr::get(
                mod.getContext(), dstDotOp, srcType.getShape(),
                getOrder(srcBlocked), srcType.getElementType()));
        auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
            cvtOp.getLoc(), tmpType, cvtOp.getOperand());
        auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
            cvtOp.getLoc(), dstType, tmp);
        cvtOp.replaceAllUsesWith(newConvert.getResult());
        cvtOp.erase();
      }
    });
  }

  /// Decomposes insert_slice_async into load + insert_slice when the target
  /// compute capability (or the load's byte width) cannot use async copies,
  /// and fixes up async_commit_group/async_wait ops accordingly.
  void decomposeInsertSliceAsyncOp(ModuleOp mod) const {
    ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
    // TODO(Keren): This is a hacky knob that may cause performance regression
    // when decomposition has been performed. We should remove this knob once we
    // have thorough analysis on async wait. Currently, we decompose
    // `insert_slice_async` into `load` and `insert_slice` without knowing which
    // `async_wait` is responsible for the `insert_slice_async`. To guarantee
    // correctness, we blindly set the `async_wait` to wait for all async ops.
    //
    // There are two options to improve this:
    // 1. We can perform a dataflow analysis to find the `async_wait` that is
    // responsible for the `insert_slice_async` in the backend.
    // 2. We can modify the pipeline to perform the decomposition before the
    // `async_wait` is inserted. However, it is also risky because we don't know
    // the correct vectorized shape yet in the pipeline pass. Making the
    // pipeline pass aware of the vectorization could introduce additional
    // dependencies on the AxisInfoAnalysis and the Coalesce analysis.
    bool decomposed = false;
    // insert_slice_async %src, %dst, %idx, %mask, %other
    // =>
    // %tmp = load %src, %mask, %other
    // %res = insert_slice %tmp into %dst[%idx]
    mod.walk([&](triton::gpu::InsertSliceAsyncOp insertSliceAsyncOp) -> void {
      OpBuilder builder(insertSliceAsyncOp);
      // Get the vectorized load size
      auto src = insertSliceAsyncOp.getSrc();
      auto dst = insertSliceAsyncOp.getDst();
      auto mask = insertSliceAsyncOp.getMask();
      auto srcTy = src.getType().cast<RankedTensorType>();
      auto dstTy = dst.getType().cast<RankedTensorType>();
      auto srcBlocked =
          srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
      auto resSharedLayout =
          dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
      auto resElemTy = dstTy.getElementType();
      // Effective vector width: source contiguity, clamped by mask alignment.
      unsigned inVec = axisInfoAnalysis.getPtrContiguity(src);
      if (mask)
        inVec =
            std::min<unsigned>(axisInfoAnalysis.getMaskAlignment(mask), inVec);
      unsigned outVec = resSharedLayout.getVec();
      unsigned minVec = inVec;
      if (outVec > 1)
        minVec = std::min(outVec, inVec);
      auto maxBitWidth =
          std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
      auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
      auto bitWidth = std::min<unsigned>(maxBitWidth, vecBitWidth);
      auto byteWidth = bitWidth / 8;
      // If the load byte width is not eligible or the current compute
      // capability does not support async copy, then we do decompose
      if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth(
              computeCapability)
              .contains(byteWidth)) {
        return;
      }
      // load
      auto tmpTy =
          RankedTensorType::get(srcTy.getShape(), resElemTy, srcBlocked);
      auto loadOp = builder.create<triton::LoadOp>(
          insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.getSrc(),
          insertSliceAsyncOp.getMask(), insertSliceAsyncOp.getOther(),
          // TODO(Chenggang): confirm `boundaryCheck` and `padding`
          /*boundaryCheck=*/nullptr, /*padding=*/nullptr,
          insertSliceAsyncOp.getCache(), insertSliceAsyncOp.getEvict(),
          insertSliceAsyncOp.getIsVolatile());
      // insert_slice: full extent on every dim except the insertion axis,
      // which takes the original op's index as its offset.
      auto axis = insertSliceAsyncOp.getAxis();
      auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
      auto offsets = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(0));
      auto sizes = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(1));
      auto strides = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(1));
      offsets[axis] = insertSliceAsyncOp.getIndex();
      for (size_t i = 0; i < dstTy.getRank(); i++) {
        if (i != axis)
          sizes[i] = intAttr(dstTy.getShape()[i]);
      }
      auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
          insertSliceAsyncOp.getLoc(), loadOp, insertSliceAsyncOp.getDst(),
          offsets, sizes, strides);
      // Replace
      insertSliceAsyncOp.replaceAllUsesWith(insertSliceOp.getResult());
      insertSliceAsyncOp.erase();
      decomposed = true;
    });
    // Drop commit-group ops on targets without async-copy support.
    mod.walk([&](triton::gpu::AsyncCommitGroupOp asyncCommitGroupOp) -> void {
      if (!triton::gpu::AsyncCommitGroupOp::isSupported(computeCapability))
        asyncCommitGroupOp.erase();
    });
    mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
      if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability)) {
        // async wait is supported in Ampere and later
        asyncWaitOp.erase();
      } else if (decomposed) {
        // Wait for all previous async ops
        OpBuilder builder(asyncWaitOp);
        builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
        asyncWaitOp.erase();
      }
    });
  }
};
} // anonymous namespace
namespace mlir {
namespace triton {

/// Factory for the TritonGPU-to-LLVM conversion pass.
/// `computeCapability` gates hardware-dependent lowering decisions (e.g.
/// async-copy eligibility); `isROCM` selects the ROCDL backend instead of
/// NVVM.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonGPUToLLVMPass(int computeCapability, bool isROCM) {
  return std::make_unique<::ConvertTritonGPUToLLVM>(computeCapability, isROCM);
}

} // namespace triton
} // namespace mlir