mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
* Select mfma dimensions and instruction from static table * Extend mfmaLayout to include version and instrShape * Simplify generateMFMAOp by searching the mfma instruction in the table * Fix getNonKDim() and non_k_dim * Break instrShape into MDim and NDim
1121 lines
45 KiB
C++
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
|
|
|
#include "mlir/Analysis/DataFlowFramework.h"
|
|
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
|
|
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
|
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
|
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
|
|
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
|
|
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
|
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
|
#include "mlir/Dialect/Index/IR/IndexDialect.h"
|
|
#include "mlir/Dialect/Index/IR/IndexOps.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "triton/Analysis/Allocation.h"
|
|
#include "triton/Analysis/AxisInfo.h"
|
|
#include "triton/Analysis/Membar.h"
|
|
#include "triton/Dialect/NVGPU/IR/Dialect.h"
|
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
#ifndef USE_ROCM
|
|
#else
|
|
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
|
#endif
|
|
#include "triton/Tools/Sys/GetPlatform.hpp"
|
|
|
|
#include "BarrierOpToLLVM.h"
|
|
#include "ClusterOpsToLLVM.h"
|
|
#include "ConvertLayoutOpToLLVM.h"
|
|
#include "DotOpToLLVM.h"
|
|
#include "ElementwiseOpToLLVM.h"
|
|
#include "LoadStoreOpToLLVM.h"
|
|
#include "ReduceOpToLLVM.h"
|
|
#include "RegReallocOpToLLVM.h"
|
|
#include "ScanOpToLLVM.h"
|
|
#include "TensorPtrOpsToLLVM.h"
|
|
#include "TritonGPUToLLVM.h"
|
|
#include "TritonGPUToLLVMBase.h"
|
|
#include "TypeConverter.h"
|
|
#include "ViewOpToLLVM.h"
|
|
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
|
|
|
namespace mlir {

namespace triton {

// Pull in the TableGen-generated pass declarations (the
// ConvertTritonGPUToLLVMBase CRTP base class and its option plumbing) for
// exactly this pass, selected by the GEN_PASS_DEF_* macro.
#define GEN_PASS_DEF_CONVERTTRITONGPUTOLLVM

#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"

} // namespace triton

} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::triton;
|
|
namespace ttng = mlir::triton::nvidia_gpu;
|
|
|
|
namespace {
|
|
|
|
// pass ws related named attrs.
|
|
static void addWSNamedAttrs(Operation *op,
|
|
ArrayRef<mlir::NamedAttribute> attrs) {
|
|
for (const NamedAttribute attr : attrs)
|
|
if (attr.getName() == "async_agent" || attr.getName() == "agent.mutex_role")
|
|
op->setAttr(attr.getName(), attr.getValue());
|
|
}
|
|
|
|
#ifdef USE_ROCM
|
|
constexpr int LDSSize = 65536;
|
|
constexpr int kPtrBitWidth = 64;
|
|
#endif
|
|
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
|
|
public:
|
|
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx, Target target)
|
|
: ConversionTarget(ctx) {
|
|
addLegalDialect<index::IndexDialect>();
|
|
addLegalDialect<LLVM::LLVMDialect>();
|
|
switch (target) {
|
|
case Target::NVVM:
|
|
addLegalDialect<NVVM::NVVMDialect>();
|
|
break;
|
|
case Target::ROCDL:
|
|
addLegalDialect<ROCDL::ROCDLDialect>();
|
|
addLegalDialect<mlir::scf::SCFDialect>();
|
|
break;
|
|
}
|
|
addLegalOp<mlir::UnrealizedConversionCastOp>();
|
|
}
|
|
};
|
|
|
|
class FoldSplatMaskInInsertAsync : public mlir::RewritePattern {
|
|
|
|
public:
|
|
FoldSplatMaskInInsertAsync(mlir::MLIRContext *context)
|
|
: mlir::RewritePattern(
|
|
triton::nvidia_gpu::InsertSliceAsyncV2Op::getOperationName(), 1,
|
|
context) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(mlir::Operation *op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto insertOp = cast<triton::nvidia_gpu::InsertSliceAsyncV2Op>(op);
|
|
if (!insertOp.getMask())
|
|
return failure();
|
|
auto splatOp = insertOp.getMask().getDefiningOp<triton::SplatOp>();
|
|
if (!splatOp)
|
|
return failure();
|
|
rewriter.updateRootInPlace(insertOp, [&]() {
|
|
insertOp.getMaskMutable().assign(splatOp->getOperand(0));
|
|
});
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
// Lower triton::ReturnOp to LLVM::ReturnOp. Kernel entry points must return
// nothing; device functions with two or more results have them packed into a
// single LLVM struct value.
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
  using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
    if (funcOp->hasAttr("nvvm.kernel")) {
      // A GPU kernel: returning a value is an error.
      if (op.getNumOperands() > 0) {
        return rewriter.notifyMatchFailure(
            op, "Kernel functions do not support return with operands");
      }
      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
                                                  op->getAttrs());
      return success();
    }

    // A device function.
    auto loc = op.getLoc();
    LLVM::ReturnOp newOp;
    if (adaptor.getOperands().size() < 2) {
      // Zero or one result: return it directly.
      newOp = rewriter.create<LLVM::ReturnOp>(loc, adaptor.getOperands());
    } else {
      // Multiple results: pack them into one struct value.
      auto packedResultsTy = this->getTypeConverter()->packFunctionResults(
          funcOp.getResultTypes());
      Value packedResults =
          rewriter.create<LLVM::UndefOp>(loc, packedResultsTy);
      unsigned slot = 0;
      for (Value operand : adaptor.getOperands()) {
        packedResults =
            insert_val(packedResultsTy, packedResults, operand, slot);
        ++slot;
      }
      newOp = rewriter.create<LLVM::ReturnOp>(loc, packedResults);
    }
    newOp->setAttrs(op->getAttrs());
    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};
|
|
|
|
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
struct FuncOpConversion : public FuncOpConversionBase {
  FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
                   ModuleAllocation &allocation, PatternBenefit benefit)
      : FuncOpConversionBase(converter, benefit), numWarps(numWarps),
        allocation(allocation) {}

  // Rebuild `funcOp` with one extra trailing i8* (address space 3) argument
  // that carries the caller's current shared-memory stack pointer. Only used
  // for non-root (device) functions.
  triton::FuncOp amendFuncOp(triton::FuncOp funcOp,
                             ConversionPatternRewriter &rewriter) const {
    // Push back a variable that indicates the current stack pointer of shared
    // memory to the function arguments.
    auto loc = funcOp.getLoc();
    auto ctx = funcOp->getContext();
    // Address space 3 is shared memory.
    auto ptrTy = LLVM::LLVMPointerType::get(
        this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
    // 1. Modify the function type to add the new argument.
    auto funcTy = funcOp.getFunctionType();
    auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs());
    amendedInputTy.push_back(ptrTy);
    auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy,
                                           funcTy.getResults());
    // 2. Modify the argument attributes to add the new argument.
    SmallVector<NamedAttribute> amendedAttrs;
    filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs);
    // NOTE(review): getAllArgAttrs() is assumed non-null here — confirm every
    // caller reaches this with argument attributes materialized.
    auto amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs());
    amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx));
    amendedAttrs.push_back(rewriter.getNamedAttr(
        funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs)));
    // 3. Add a new argument to the region
    auto amendedFuncOp = rewriter.create<triton::FuncOp>(
        funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs);
    auto &region = funcOp.getBody();
    region.addArgument(ptrTy, loc);
    rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(),
                                amendedFuncOp.end());
    return amendedFuncOp;
  }

  // Convert triton::FuncOp to LLVM::LLVMFuncOp: tag kernels vs device
  // functions, and append one i8* argument per TMA load/store so the runtime
  // can pass TMA descriptors in global memory.
  LogicalResult
  matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Prevent LLVM's inliner to inline this function
    auto amendedFuncOp = funcOp;
    if (!allocation.isRoot(funcOp))
      amendedFuncOp = amendFuncOp(funcOp, rewriter);

    // Collect TMA informations: one descriptor argument will be appended per
    // TMA load (InsertSliceAsyncV2Op) and per TMA store (StoreAsyncOp).
    unsigned numTMALoad = 0;
    funcOp.walk(
        [&numTMALoad](triton::nvidia_gpu::InsertSliceAsyncV2Op insertSliceOp) {
          numTMALoad++;
        });
    unsigned numTMAStore = 0;
    funcOp.walk([&numTMAStore](triton::nvidia_gpu::StoreAsyncOp storeAsyncOp) {
      numTMAStore++;
    });
    unsigned numTMA = numTMALoad + numTMAStore;

    auto newFuncOp = convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter);
    if (!newFuncOp) {
      return failure();
    }

    auto ctx = funcOp->getContext();

    if (allocation.isRoot(funcOp)) {
      // Set an attribute to indicate this function is a kernel entry.
      newFuncOp->setAttr("nvvm.kernel",
                         rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
    } else {
      // The noinline attribute will be used by the LLVM codegen to prevent
      // inlining.
      // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267
      newFuncOp.setPassthroughAttr(
          ArrayAttr::get(ctx, rewriter.getStringAttr("noinline")));
      rewriter.eraseOp(amendedFuncOp);
    }
#ifndef USE_ROCM
    // Set an attribute for maxntidx, it could be used in latter LLVM codegen
    // for `nvvm.annotation` metadata.
    newFuncOp->setAttr("nvvm.maxntid", rewriter.getI32ArrayAttr(32 * numWarps));
#endif
    // The call graph is updated by mapping the old function to the new one.
    allocation.mapFuncOp(funcOp, newFuncOp);

    // Append arguments to receive TMADesc in global memory in the runtime
    // (address space 1 = global).
    auto i8PtrTy = LLVM::LLVMPointerType::get(
        this->getTypeConverter()->convertType(rewriter.getI8Type()), 1);
    auto numArgs = newFuncOp.getBody().front().getNumArguments();
    auto funcTy = newFuncOp.getFunctionType().cast<LLVM::LLVMFunctionType>();
    SmallVector<Type> newInputsTy(funcTy.getParams().begin(),
                                  funcTy.getParams().end());
    for (unsigned i = 0; i < numTMA; ++i) {
      newFuncOp.getBody().front().addArgument(i8PtrTy, funcOp.getLoc());
      newInputsTy.push_back(i8PtrTy);
    }
    newFuncOp.setType(
        LLVM::LLVMFunctionType::get(funcTy.getReturnType(), newInputsTy));
    // required by AxisInfoAnalysis
    for (unsigned i = 0; i < numTMA; ++i) {
      newFuncOp.setArgAttr(numArgs + i, "tt.divisibility",
                           rewriter.getIntegerAttr(i32_ty, 1));
    }

    // Record the TMA descriptor counts on the function for later stages.
    newFuncOp->setAttr(kAttrNumTMALoadDescsName,
                       rewriter.getIntegerAttr(i32_ty, numTMALoad));
    newFuncOp->setAttr(kAttrNumTMAStoreDescsName,
                       rewriter.getIntegerAttr(i32_ty, numTMAStore));

    rewriter.eraseOp(funcOp);
    return success();
  }

private:
  // Number of warps per CTA; used for the maxntid annotation on NVIDIA.
  int numWarps{0};
  // Module-wide shared-memory allocation; updated with the func mapping.
  ModuleAllocation &allocation;
};
|
|
|
|
// CallOpInterfaceLowering is adapted from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485
// Lowers triton::CallOp to LLVM::CallOp, appending the callee's
// shared-memory base pointer as an extra trailing operand (matching the extra
// argument added by FuncOpConversion::amendFuncOp).
struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
  CallOpConversion(LLVMTypeConverter &converter, int numWarps,
                   ModuleAllocation &allocation, PatternBenefit benefit)
      : ConvertOpToLLVMPattern<triton::CallOp>(converter, benefit),
        numWarps(numWarps), allocation(allocation) {}

  LogicalResult
  matchAndRewrite(triton::CallOp callOp,
                  typename triton::CallOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto promotedOperands = promoteOperands(callOp, adaptor, rewriter);
    auto newCallOp =
        convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter);
    if (!newCallOp)
      return failure();
    // Keep the allocation's call graph in sync with the rewritten op.
    allocation.mapCallOp(callOp, newCallOp);
    auto results = getCallOpResults(callOp, newCallOp, rewriter);
    rewriter.replaceOp(callOp, results);
    return success();
  }

private:
  // Promote operands to their LLVM forms and append the shared-memory
  // pointer for the callee: either the caller's base pointer, or the base
  // offset by the buffer assigned to this call site.
  SmallVector<Value, 4>
  promoteOperands(triton::CallOp callOp,
                  typename triton::CallOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const {
    // Get the last argument of the caller, which is the current stack pointer
    // of shared memory and append it to the operands of the callOp.
    auto loc = callOp.getLoc();
    auto caller = callOp->getParentOfType<FunctionOpInterface>();
    auto ptrTy = LLVM::LLVMPointerType::get(
        this->getTypeConverter()->convertType(rewriter.getI8Type()),
        NVVM::kSharedMemorySpace);
    auto promotedOperands = this->getTypeConverter()->promoteOperands(
        callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
        adaptor.getOperands(), rewriter);
    auto base = allocation.getFunctionSharedMemoryBase(caller);
    auto *funcAllocation = allocation.getFuncData(caller);
    auto bufferId = funcAllocation->getBufferId(callOp);
    // function doesn't have a shared mem buffer
    if (bufferId == (size_t)-1) {
      promotedOperands.push_back(base);
      return promotedOperands;
    }
    // function has a shared mem buffer: pass base + offset.
    auto offset = funcAllocation->getOffset(bufferId);
    auto offsetValue = gep(ptrTy, base, i32_val(offset));
    promotedOperands.push_back(offsetValue);
    return promotedOperands;
  }

  // Create the LLVM::CallOp. Multiple results are packed into one struct
  // return type; returns nullptr if the result types cannot be packed.
  LLVM::CallOp
  convertCallOpToLLVMCallOp(triton::CallOp callOp,
                            ArrayRef<Value> promotedOperands,
                            ConversionPatternRewriter &rewriter) const {
    // Pack the result types into a struct.
    Type packedResult = nullptr;
    unsigned numResults = callOp.getNumResults();
    auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());

    if (numResults != 0) {
      if (!(packedResult =
                this->getTypeConverter()->packFunctionResults(resultTypes)))
        return nullptr;
    }
    auto newCallOp = rewriter.create<LLVM::CallOp>(
        callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
        promotedOperands, callOp->getAttrs());
    return newCallOp;
  }

  // Unpack the call's results back into the SSA values the original callOp
  // produced (extracting from the packed struct when there were >= 2).
  SmallVector<Value>
  getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp,
                   ConversionPatternRewriter &rewriter) const {
    auto numResults = callOp.getNumResults();
    SmallVector<Value> results;
    if (numResults < 2) {
      // If < 2 results, packing did not do anything and we can just return.
      results.append(newCallOp.result_begin(), newCallOp.result_end());
    } else {
      // Otherwise, it had been converted to an operation producing a structure.
      // Extract individual results from the structure and return them as list.
      results.reserve(numResults);
      for (unsigned i = 0; i < numResults; ++i) {
        results.push_back(rewriter.create<LLVM::ExtractValueOp>(
            callOp.getLoc(), newCallOp->getResult(0), i));
      }
    }
    return results;
  }

  // NOTE(review): numWarps appears unused in this pattern — kept for
  // interface symmetry with FuncOpConversion; confirm before removing.
  int numWarps{0};
  // Module-wide shared-memory allocation; provides per-call buffer offsets.
  ModuleAllocation &allocation;
};
|
|
|
|
class TritonLLVMConversionTarget : public ConversionTarget {
|
|
public:
|
|
explicit TritonLLVMConversionTarget(MLIRContext &ctx, Target target)
|
|
: ConversionTarget(ctx) {
|
|
addLegalDialect<LLVM::LLVMDialect>();
|
|
switch (target) {
|
|
case Target::NVVM:
|
|
addLegalDialect<NVVM::NVVMDialect>();
|
|
break;
|
|
case Target::ROCDL:
|
|
addLegalDialect<ROCDL::ROCDLDialect>();
|
|
addLegalDialect<mlir::scf::SCFDialect>();
|
|
break;
|
|
}
|
|
addLegalDialect<mlir::triton::nvgpu::NVGPUDialect>();
|
|
addIllegalDialect<triton::TritonDialect>();
|
|
addIllegalDialect<triton::gpu::TritonGPUDialect>();
|
|
addIllegalDialect<triton::nvidia_gpu::TritonNvidiaGPUDialect>();
|
|
addIllegalDialect<mlir::gpu::GPUDialect>();
|
|
addLegalOp<mlir::UnrealizedConversionCastOp>();
|
|
}
|
|
};
|
|
|
|
struct ConvertTritonGPUToLLVM
|
|
: public triton::impl::ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
|
|
using ConvertTritonGPUToLLVMBase<
|
|
ConvertTritonGPUToLLVM>::ConvertTritonGPUToLLVMBase;
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<triton::nvgpu::NVGPUDialect, LLVM::LLVMDialect,
|
|
NVVM::NVVMDialect, ROCDL::ROCDLDialect>();
|
|
}
|
|
|
|
// Construct the pass with an explicit compute capability, lowering target
// (NVVM or ROCDL), and an optional caller-owned sink for TMA metadata
// (may be null; see runOnOperation for the fallback).
ConvertTritonGPUToLLVM(int32_t computeCapability, Target target,
                       mlir::triton::gpu::TMAMetadataTy *tmaMetadata)
    : ConvertTritonGPUToLLVMBase({computeCapability, target}),
      tmaMetadata(tmaMetadata) {}
|
|
|
|
// Main driver: preprocess layout conversions the backend cannot lower
// directly, allocate shared memory and insert barriers, lower functions,
// then calls/returns, and finally apply all op-level conversion patterns.
void runOnOperation() override {
  MLIRContext *context = &getContext();
  ModuleOp mod = getOperation();
  mlir::LowerToLLVMOptions option(context);
  // Triton uses 32-bit indices throughout the lowering.
  option.overrideIndexBitwidth(32);
  TritonGPUToLLVMTypeConverter typeConverter(context, option);
  TritonLLVMConversionTarget convTarget(*context, target);
  int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
  int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
  int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);

  // Preprocess: decompose conversions that have no direct lowering.
  decomposeFp8e4b15Convert(mod);
  decomposeMmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs);
#ifdef USE_ROCM
  decomposeMfmaToDotOperand(mod, numWarps, threadsPerWarp, numCTAs);
  reduceCvtOpLDSUsage(mod);
#endif
  decomposeBlockedToDotOperand(mod);
  decomposeInsertSliceAsyncOp(mod);
  decomposeMixedModeDotOp(mod);

  // Allocate shared memory and set barrier
  ModuleAllocation allocation(mod);
  ModuleMembarAnalysis membarPass(&allocation);
  membarPass.run();

  /* Get tensorPtrMap before conversion: maps each TMA load/store op to the
     tt.make_tensor_ptr that produced its block pointer. */
  TensorPtrMapT tensorPtrMap;
  mod.walk([&tensorPtrMap](
               mlir::triton::nvidia_gpu::InsertSliceAsyncV2Op insertOp) {
    auto src = insertOp.getSrc();
    auto ptrTy = src.getType().dyn_cast<triton::PointerType>();
    if (ptrTy && ptrTy.getPointeeType().isa<RankedTensorType>()) {
      auto makeTensorPtrOp = getMakeTensorPtrOp(insertOp.getSrc());
      tensorPtrMap[insertOp.getOperation()] = makeTensorPtrOp;
    }
  });

  mod.walk([&tensorPtrMap](mlir::triton::nvidia_gpu::StoreAsyncOp storeOp) {
    auto dst = storeOp.getDst();
    auto ptrTy = dst.getType().dyn_cast<triton::PointerType>();
    if (ptrTy && ptrTy.getPointeeType().isa<RankedTensorType>()) {
      auto makeTensorPtrOp = getMakeTensorPtrOp(storeOp.getDst());
      tensorPtrMap[storeOp.getOperation()] = makeTensorPtrOp;
    }
  });

  // Hack: cleanup — fold splatted masks on TMA loads before lowering.
  {
    RewritePatternSet patterns(context);
    patterns.add<FoldSplatMaskInInsertAsync>(context);
    SmallVector<Operation *> insertSlices;
    mod.walk([&insertSlices](triton::nvidia_gpu::InsertSliceAsyncV2Op op) {
      insertSlices.push_back(op);
    });
    if (applyOpPatternsAndFold(insertSlices, std::move(patterns)).failed())
      signalPassFailure();
  }

  // Lower functions
  {
    mlir::LowerToLLVMOptions option(context);
    TritonGPUToLLVMTypeConverter typeConverter(context, option);
    TritonLLVMFunctionConversionTarget funcTarget(*context, target);
    RewritePatternSet funcPatterns(context);
    funcPatterns.add<FuncOpConversion>(typeConverter, numWarps, allocation,
                                       /*benefit=*/1);
    mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
                                                          funcPatterns);
    if (failed(
            applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
      return signalPassFailure();
  }

  // initSharedMemory is run before the conversion of call and ret ops,
  // because the call op has to know the shared memory base address of each
  // function
  initSharedMemory(allocation, typeConverter);

  // Convert call and ret ops
  {
    mlir::LowerToLLVMOptions option(context);
    TritonGPUToLLVMTypeConverter typeConverter(context, option);
    TritonLLVMFunctionConversionTarget funcTarget(*context, target);
    RewritePatternSet funcPatterns(context);
    funcPatterns.add<CallOpConversion>(typeConverter, numWarps, allocation,
                                       /*benefit=*/1);
    funcPatterns.add<ReturnOpConversion>(typeConverter, /*benefit=*/1);
    if (failed(
            applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
      return signalPassFailure();
  }

  ModuleAxisInfoAnalysis axisInfoAnalysis(mod);

  // Emit logics to get threadId/blockIds/linearized clusterCTAId etc. and
  // cache the values. The reason to do it here is that cluster_ctaid is
  // currently implemented via inline asm, and thus cannot be CSEed.
  // clusterCTAId will be emitted only when numCTAs is larger than 1, and
  // other values will be DCEed if not used hereafter.
  // NOTE(review): isWarpSpecialization is computed but not referenced in the
  // remainder of this function — confirm whether it is still needed.
  bool isWarpSpecialization =
      ttng::TritonNvidiaGPUDialect::getWSSupportedAttr(mod);
  OpBuilder::InsertPoint indexInsertPoint;
  ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
      &baseIndexCache, &indexCache, &indexInsertPoint};
  // TODO: enable index cache if there are multiple functions
  if (axisInfoAnalysis.getNumFunctions() > 1) {
    indexCacheInfo = {nullptr, nullptr, nullptr};
  }

  // tmaMetadata is absent in a triton-opt unit test, in this case, create a
  // local one and dump it after this pass is done.
  mlir::triton::gpu::TMAMetadataTy tmaMetaDataDebug;
  if (tmaMetadata == nullptr)
    tmaMetadata = &tmaMetaDataDebug;

  RewritePatternSet patterns(context);

  // Helper lambdas bundling the different populate-function signatures.
  auto populatePatterns1 = [&](auto populateFunc) {
    populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
                 allocation, indexCacheInfo,
                 /*benefit*/ 10);
  };

  auto populatePatterns2 = [&](auto populateFunc) {
    populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
                 allocation, /*benefit*/ 10);
  };

  auto populatePatterns3 = [&](auto populateFunc) {
    populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
                 allocation, indexCacheInfo, tmaMetadata, &tensorPtrMap,
                 /*benefit*/ 10);
  };

  auto populatePatterns4 = [&](auto populateFunc) {
    populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis,
                 allocation, indexCacheInfo, computeCapability,
                 /*benefit*/ 10);
  };

  populatePatterns1(populateTritonGPUToLLVMPatterns);
  populatePatterns1(populateConvertLayoutOpToLLVMPatterns);
  populatePatterns2(populateDotOpToLLVMPatterns);
  populatePatterns4(populateElementwiseOpToLLVMPatterns);
  populatePatterns3(populateLoadStoreOpToLLVMPatterns);
  populatePatterns4(populateReduceOpToLLVMPatterns);
  populatePatterns1(populateScanOpToLLVMPatterns);
  populatePatterns2(populateViewOpToLLVMPatterns);
  populatePatterns2(populateBarrierOpToLLVMPatterns);
  populatePatterns2(populateTensorPtrOpsToLLVMPatterns);
  populatePatterns2(populateClusterOpsToLLVMPatterns);
  populatePatterns2(populateRegReallocOpToLLVMPatterns);

  // TODO(thomas): this should probably be done in a separate step to not
  // interfere with our own lowering of arith ops. Add arith/math's patterns
  // to help convert scalar expression to LLVM.
  mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
  mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);

  // Native lowering patterns
  switch (target) {
  case Target::NVVM:
    mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
    break;
  case Target::ROCDL:
    mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns,
                                               mlir::gpu::amd::HIP);
    break;
  }

  mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
                                                        patterns);
  if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
    return signalPassFailure();

  // Fold CTAId when there is only 1 CTA.
  if (numCTAs == 1) {
    mod.walk([](triton::nvgpu::ClusterCTAIdOp id) {
      OpBuilder b(id);
      Value zero = LLVM::createConstantI32(id->getLoc(), b, 0);
      id.replaceAllUsesWith(zero);
    });
  }
}
|
|
|
|
private:
  // Cache of emitted base-index computations, shared with conversion
  // patterns through IndexCacheInfo (see runOnOperation). Presumably keyed
  // by (layout, shape) — confirm against IndexCacheKeyT's definition.
  DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
      baseIndexCache;
  // Cache of full per-element index vectors, keyed the same way.
  DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
           CacheKeyDenseMapInfo>
      indexCache;
  // TMA metadata sink. May point to caller-owned storage (ctor argument) or
  // be redirected to a pass-local buffer in runOnOperation for unit tests.
  mlir::triton::gpu::TMAMetadataTy *tmaMetadata = nullptr;
|
|
|
|
// Create the module-level external `global_smem` array that models dynamic
// shared memory, and record in `allocation` each function's shared-memory
// base pointer (the global's address for kernels, the last function argument
// for device functions).
void initSharedMemory(ModuleAllocation &allocation,
                      TritonGPUToLLVMTypeConverter &typeConverter) {
  ModuleOp mod = getOperation();
  OpBuilder b(mod.getBodyRegion());
  auto ctx = mod.getContext();
  auto loc = mod.getLoc();
  auto elemTy = typeConverter.convertType(b.getIntegerType(8));
  // Set array size 0 and external linkage indicates that we use dynamic
  // shared allocation to allow a larger shared memory size for each kernel.
  auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0);
  auto global = b.create<LLVM::GlobalOp>(
      loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External,
      "global_smem", /*value=*/Attribute(), /*alignment=*/0,
      // Add ROCm support.
      static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace));
  mod.walk([&](FunctionOpInterface funcOp) {
    Value funcSmem;
    b.setInsertionPointToStart(&funcOp.getFunctionBody().front());
    if (allocation.isRoot(funcOp)) {
      // Kernel entry: shared memory starts at the global's address.
      funcSmem = b.create<LLVM::AddressOfOp>(loc, global);
    } else {
      // Device function: the base pointer was appended as the trailing
      // argument (see FuncOpConversion::amendFuncOp).
      funcSmem = funcOp.getArgument(funcOp.getNumArguments() - 1);
    }
    auto ptrTy =
        LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()),
                                   NVVM::NVVMMemorySpace::kSharedMemorySpace);
    funcSmem = b.create<LLVM::BitcastOp>(loc, ptrTy, funcSmem);
    allocation.setFunctionSharedMemoryValue(funcOp, funcSmem);
  });
  // Publish the total shared-memory size on the module for later consumers.
  mod->setAttr("triton_gpu.shared",
               mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
                                      allocation.getSharedMemorySize()));
}
|
|
|
|
// Rewrite layout conversions whose element type is fp8 (E4M3B11FNUZ or
// E4M3FN) into fp8->fp16 cast, layout conversion in fp16, then fp16->fp8
// cast. Conversions involving dot-operand encodings are left untouched.
void decomposeFp8e4b15Convert(ModuleOp mod) const {
  mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
    OpBuilder builder(cvtOp);
    // Only fp8 element types are decomposed.
    if (!getElementTypeOrSelf(cvtOp)
             .isa<mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3FNType>())
      return;
    auto shape = cvtOp.getType().cast<RankedTensorType>().getShape();
    auto argEncoding =
        cvtOp.getOperand().getType().cast<RankedTensorType>().getEncoding();
    auto cvtEncoding = cvtOp.getType().cast<RankedTensorType>().getEncoding();
    // Skip conversions to/from dot-operand layouts.
    if (argEncoding.isa<triton::gpu::DotOperandEncodingAttr>() ||
        cvtEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
      return;
    auto F16Ty = builder.getF16Type();

    // Same shapes and encodings, but fp16 element type for the middle hop.
    auto newArgType = RankedTensorType::get(shape, F16Ty, argEncoding);
    auto newCvtType = RankedTensorType::get(shape, F16Ty, cvtEncoding);
    auto newArg = builder.create<mlir::triton::FpToFpOp>(
        cvtOp.getLoc(), newArgType, cvtOp.getOperand());
    addWSNamedAttrs(newArg, cvtOp->getAttrs());
    auto newCvt = builder.create<mlir::triton::gpu::ConvertLayoutOp>(
        cvtOp.getLoc(), newCvtType, newArg);
    addWSNamedAttrs(newCvt, cvtOp->getAttrs());
    auto newRet = builder.create<mlir::triton::FpToFpOp>(
        cvtOp.getLoc(), cvtOp.getType(), newCvt.getResult());
    addWSNamedAttrs(newRet, cvtOp->getAttrs());
    cvtOp.replaceAllUsesWith(newRet.getResult());
    cvtOp.erase();
  });
}
|
|
|
|
// Decompose conversions from an MMA layout directly to a dot-operand layout.
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps, int threadsPerWarp,
                              int numCTAs) const {
  // Replace `mma -> dot_op` with `mma -> blocked -> dot_op`
  // unless certain conditions are met
  mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
    OpBuilder builder(cvtOp);
    auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
    auto dstType = cvtOp.getType().cast<RankedTensorType>();
    auto srcMma =
        srcType.getEncoding().dyn_cast<triton::gpu::MmaEncodingAttr>();
    auto dstDotOp =
        dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
    // Decompose only when the direct mma->dot_op shortcut is unavailable.
    if (srcMma && dstDotOp && !isMmaToDotShortcut(srcType, dstType)) {
      // Intermediate blocked layout sized from the source mma layout.
      auto tmpType = RankedTensorType::get(
          dstType.getShape(), dstType.getElementType(),
          triton::gpu::BlockedEncodingAttr::get(
              mod.getContext(), srcType.getShape(), getSizePerThread(srcMma),
              getOrder(srcMma), numWarps, threadsPerWarp, numCTAs));
      auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
          cvtOp.getLoc(), tmpType, cvtOp.getOperand());
      addWSNamedAttrs(tmp, cvtOp->getAttrs());
      auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
          cvtOp.getLoc(), dstType, tmp);
      addWSNamedAttrs(newConvert, cvtOp->getAttrs());
      cvtOp.replaceAllUsesWith(newConvert.getResult());
      cvtOp.erase();
    }
  });
}
|
|
|
|
#ifdef USE_ROCM
|
|
// Decompose conversions from an AMD MFMA layout directly to a dot-operand
// layout.
void decomposeMfmaToDotOperand(ModuleOp mod, int numWarps, int threadsPerWarp,
                               int numCTAs) const {
  // Replace `mfma -> dot_op` with `mfma -> blocked -> dot_op`
  // unless certain conditions are met
  mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
    OpBuilder builder(cvtOp);
    auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
    auto dstType = cvtOp.getType().cast<RankedTensorType>();
    auto srcMfma =
        srcType.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
    auto dstDotOp =
        dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
    // Decompose only when the direct mfma->dot_op shortcut is unavailable.
    if (srcMfma && dstDotOp && !isMfmaToDotShortcut(srcType, dstType)) {
      // Intermediate blocked layout sized from the source mfma layout.
      auto tmpType = RankedTensorType::get(
          dstType.getShape(), dstType.getElementType(),
          triton::gpu::BlockedEncodingAttr::get(
              mod.getContext(), srcType.getShape(), getSizePerThread(srcMfma),
              getOrder(srcMfma), numWarps, threadsPerWarp, numCTAs));
      auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
          cvtOp.getLoc(), tmpType, cvtOp.getOperand());
      // Propagate warp-specialization attrs onto the decomposed ops, for
      // consistency with decomposeMmaToDotOperand (the NVIDIA counterpart).
      addWSNamedAttrs(tmp, cvtOp->getAttrs());
      auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
          cvtOp.getLoc(), dstType, tmp);
      addWSNamedAttrs(newConvert, cvtOp->getAttrs());
      cvtOp.replaceAllUsesWith(newConvert.getResult());
      cvtOp.erase();
    }
  });
}
|
|
|
|
// Estimate the LDS scratch size, in bytes, needed to lower `cvtOp`.
int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp &cvtOp) const {
  unsigned inVec = 0;
  unsigned outVec = 0;
  auto smemShape = getScratchConfigForCvtLayout(cvtOp, inVec, outVec);
  // Element count of the scratch buffer = product of its shape.
  unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
                                   std::multiplies{});
  auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
  // Pointer elements occupy kPtrBitWidth bits; everything else at least a
  // byte.
  unsigned bitsPerElem;
  if (srcType.getElementType().isa<triton::PointerType>()) {
    bitsPerElem = kPtrBitWidth;
  } else {
    bitsPerElem = std::max<int>(8, srcType.getElementTypeBitWidth());
  }
  return elems * bitsPerElem / 8;
}
|
|
|
|
// True iff `x` is a nonzero power of two.
bool isPowerOfTwo(unsigned x) const {
  if (x == 0)
    return false;
  return (x & (x - 1)) == 0;
}
|
|
|
// Enumerate every ordered pair (a, b) of powers of two with a * b == n.
// `n` must itself be a power of two. Both orientations (a, b) and (b, a)
// are produced for each factorization, without duplicates.
//
// Fix vs. the previous version: the exponent and the factors are computed
// with exact integer shifts instead of floating-point log2()/pow(), and the
// square-root pair (2^(x/2), 2^(x/2)) is no longer emitted twice when the
// exponent x is even.
std::vector<std::pair<int, int>> factorizePowerOf2(int n) const {
  assert(isPowerOfTwo(n));
  // Integer log2 of n.
  int x = 0;
  for (int m = n; m > 1; m >>= 1)
    ++x;
  std::vector<std::pair<int, int>> pairs;

  for (int i = 0; i <= x / 2; ++i) {
    int j = x - i;
    int a = 1 << i;
    int b = 1 << j;
    pairs.push_back({a, b});
    // When i == j the two orientations coincide; emit the pair only once.
    if (i != j)
      pairs.push_back({b, a});
  }

  return pairs;
}
|
|
|
|
// Build the two-step replacement for a cvt(mfma -> blocked) op:
// cvt(mfma -> mfma') followed by cvt(mfma' -> blocked), where mfma' is the
// source mfma layout with its warpsPerCTA replaced by `warpsPerCta`.
// Returns {intermediate cvt, epilogue cvt}.
std::pair<triton::gpu::ConvertLayoutOp, triton::gpu::ConvertLayoutOp>
createNewConvertOps(ModuleOp &mod, OpBuilder &builder,
                    triton::gpu::ConvertLayoutOp &cvtOp,
                    std::pair<unsigned, unsigned> warpsPerCta) const {
  unsigned warpsPerCtaX = warpsPerCta.first;
  unsigned warpsPerCtaY = warpsPerCta.second;
  auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
  auto dstType = cvtOp.getType().cast<RankedTensorType>();

  auto srcMfma =
      srcType.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
  // Same mfma layout as the source except for the requested warp grid.
  auto newMfmaEnc = triton::gpu::MfmaEncodingAttr::get(
      mod.getContext(), srcMfma.getVersionMajor(), srcMfma.getVersionMinor(),
      {warpsPerCtaX, warpsPerCtaY}, srcMfma.getMDim(), srcMfma.getNDim(),
      srcMfma.getIsTransposed());

  // NOTE(review): newDstType copies dstType's shape/element/encoding
  // verbatim and looks redundant — confirm whether dstType could be reused
  // directly.
  auto newDstType = RankedTensorType::get(
      dstType.getShape(), dstType.getElementType(), dstType.getEncoding());
  auto newSrcType = RankedTensorType::get(
      srcType.getShape(), srcType.getElementType(), newMfmaEnc);

  auto tmpCvt = builder.create<triton::gpu::ConvertLayoutOp>(
      cvtOp.getLoc(), newSrcType, cvtOp.getOperand());
  auto newEpilogueCvt = builder.create<triton::gpu::ConvertLayoutOp>(
      cvtOp.getLoc(), newDstType, tmpCvt);

  return std::make_pair(tmpCvt, newEpilogueCvt);
}
|
|
|
|
// Try to reduce LDS usage of cvt(mfma->blocked) op by changing the shape of
|
|
// WarpsPerCta attribute in mfma layout. The implicit LDS usage of
|
|
// cvt(mfma->blocked) op depends on the number of warps per CTA that mfma
|
|
// layout uses along x dimension and block layout uses across y dimension.
|
|
//
|
|
// clang-format off
|
|
//
|
|
// LDS usage of this op is roughly calculated as:
|
|
// LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layout)[1] * sizeof(data_type)
|
|
// LDS_USAGE = warpsPerCTA(mfma_layout)[0] * warpsPerCta(blocked_layout)[1] * C,
|
|
// where C = 32 * sizePerWarp(blocked_layout)[1] * threadsPerWarp(blocked_layout)[1] * sizeof(data_type)
|
|
//
|
|
// clang-format on
|
|
//
|
|
// When LDS_USAGE exceeds the size of LDS, try to lower LDS usage by
|
|
// decomposing cvt(mfma->blocked) op into 2 conversions: cvt(mfma->mfma_tmp)
|
|
// and cvt(mfma_tmp->blocked), where mfma_tmp has WarpsPerCta attribute that
|
|
// minimizes uses of LDS for these conversions.
|
|
  void reduceCvtOpLDSUsage(ModuleOp mod) const {
    mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
      OpBuilder builder(cvtOp);

      auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
      auto dstType = cvtOp.getType().cast<RankedTensorType>();

      auto srcMfma =
          srcType.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
      auto dstBlocked =
          dstType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();

      // Only cvt(mfma -> blocked) ops are candidates for decomposition.
      if (!srcMfma || !dstBlocked) {
        return;
      }

      // Nothing to do if the conversion already fits in LDS.
      auto currLDSUsage = getCvtOpLDSUsage(cvtOp);
      if (currLDSUsage <= LDSSize) {
        return;
      }

      unsigned numWarps =
          srcMfma.getWarpsPerCTA()[0] * srcMfma.getWarpsPerCTA()[1];

      triton::gpu::ConvertLayoutOp tmpCvt;
      triton::gpu::ConvertLayoutOp newEpilogueCvt;

      // Find all possible shapes of WarpsPerCTA by finding all possible
      // factorizations of numWarps. Pick shape for which both conversions in
      // decomposition use LDS less than LDSSize and for which sum of LDS usage
      // is minimal. If no such shape exists, do not decompose.
      unsigned minLDSUsage = 2 * LDSSize;
      int minIdx = -1;
      auto factorizedNumWarps = factorizePowerOf2(numWarps);

      for (int i = 0; i < factorizedNumWarps.size(); i++) {
        auto warpsPerCTAPair = factorizedNumWarps[i];
        // Tentatively materialize the decomposition just to measure its LDS
        // cost; the ops are erased again right below.
        std::tie(tmpCvt, newEpilogueCvt) =
            createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair);

        int tmpCvtLDS = getCvtOpLDSUsage(tmpCvt);
        int newCvtLDS = getCvtOpLDSUsage(newEpilogueCvt);
        if (tmpCvtLDS <= LDSSize && newCvtLDS <= LDSSize) {
          int LDSUsage = tmpCvtLDS + newCvtLDS;
          if (LDSUsage < minLDSUsage) {
            minLDSUsage = LDSUsage;
            minIdx = i;
          }
        }
        // Erase in reverse creation order: the epilogue cvt uses tmpCvt.
        newEpilogueCvt.erase();
        tmpCvt.erase();
      }

      // No candidate fit within LDS; keep the original conversion.
      if (minIdx == -1) {
        return;
      }

      assert(minIdx >= 0 && minIdx < factorizedNumWarps.size());
      // Re-create the best decomposition and swap it in for cvtOp.
      auto warpsPerCTAPair = factorizedNumWarps[minIdx];
      std::tie(tmpCvt, newEpilogueCvt) =
          createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair);

      cvtOp.replaceAllUsesWith(newEpilogueCvt.getResult());
      cvtOp.erase();
    });
  }
|
|
|
|
#endif
|
|
|
|
void decomposeBlockedToDotOperand(ModuleOp mod) const {
|
|
// Replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
|
|
// because the codegen doesn't handle `blocked -> dot_op` directly
|
|
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
|
|
OpBuilder builder(cvtOp);
|
|
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
|
|
auto dstType = cvtOp.getType().cast<RankedTensorType>();
|
|
auto srcBlocked =
|
|
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
|
auto dstDotOp =
|
|
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
|
if (srcBlocked && dstDotOp) {
|
|
auto tmpType = RankedTensorType::get(
|
|
dstType.getShape(), dstType.getElementType(),
|
|
triton::gpu::SharedEncodingAttr::get(
|
|
mod.getContext(), dstDotOp, srcType.getShape(),
|
|
srcBlocked.getOrder(), srcBlocked.getCTALayout(),
|
|
srcType.getElementType()));
|
|
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
|
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
|
addWSNamedAttrs(tmp, cvtOp->getAttrs());
|
|
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
|
cvtOp.getLoc(), dstType, tmp);
|
|
addWSNamedAttrs(newConvert, cvtOp->getAttrs());
|
|
cvtOp.replaceAllUsesWith(newConvert.getResult());
|
|
cvtOp.erase();
|
|
}
|
|
});
|
|
}
|
|
|
|
  // Rewrites `insert_slice_async` ops that the target cannot execute as real
  // async copies into a plain `load` + `tensor.insert_slice`, then cleans up
  // the now-redundant async group/wait ops accordingly.
  void decomposeInsertSliceAsyncOp(ModuleOp mod) const {
    ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
    // TODO(Keren): This is a hacky knob that may cause performance regression
    // when decomposition has been performed. We should remove this knob once we
    // have thorough analysis on async wait. Currently, we decompose
    // `insert_slice_async` into `load` and `insert_slice` without knowing which
    // `async_wait` is responsible for the `insert_slice_async`. To guarantee
    // correctness, we blindly set the `async_wait` to wait for all async ops.
    //
    // There are two options to improve this:
    // 1. We can perform a dataflow analysis to find the `async_wait` that is
    // responsible for the `insert_slice_async` in the backend.
    // 2. We can modify the pipeline to perform the decomposition before the
    // `async_wait` is inserted. However, it is also risky because we don't know
    // the correct vectorized shape yet in the pipeline pass. Making the
    // pipeline pass aware of the vectorization could introduce additional
    // dependencies on the AxisInfoAnalysis and the Coalesce analysis.
    bool decomposed = false;
    // insert_slice_async %src, %dst, %idx, %mask, %other
    // =>
    // %tmp = load %src, %mask, %other
    // %res = insert_slice %tmp into %dst[%idx]
    mod.walk([&](triton::gpu::InsertSliceAsyncOp insertSliceAsyncOp) -> void {
      OpBuilder builder(insertSliceAsyncOp);

      // Get the vectorized load size
      auto src = insertSliceAsyncOp.getSrc();
      auto dst = insertSliceAsyncOp.getDst();
      auto mask = insertSliceAsyncOp.getMask();
      auto srcTy = src.getType().cast<RankedTensorType>();
      auto dstTy = dst.getType().cast<RankedTensorType>();
      auto srcLayout = srcTy.getEncoding();
      assert((srcLayout.isa<BlockedEncodingAttr, SliceEncodingAttr>() &&
              "Unexpected srcLayout"));

      auto resSharedLayout =
          dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
      auto resElemTy = dstTy.getElementType();
      // The achievable vector width is limited by the source pointer
      // contiguity and, when a mask is present, by the mask alignment.
      unsigned inVec = axisInfoAnalysis.getPtrContiguity(src);
      if (mask)
        inVec =
            std::min<unsigned>(axisInfoAnalysis.getMaskAlignment(mask), inVec);
      unsigned outVec = resSharedLayout.getVec();
      unsigned minVec = inVec;
      if (outVec > 1)
        minVec = std::min(outVec, inVec);
      // maxBitWidth is at least 128 bits, or one element if a single element
      // is wider than that; the final bitWidth is clamped to it.
      auto maxBitWidth =
          std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
      auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
      auto bitWidth = std::min<unsigned>(maxBitWidth, vecBitWidth);
      auto byteWidth = bitWidth / 8;

      // If the load byte width is not eligible or the current compute
      // capability does not support async copy, then we do decompose
#ifndef USE_ROCM
      // On NVIDIA targets, keep the async op when the hardware can issue an
      // async copy of this byte width. (On ROCm every op is decomposed.)
      if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth(
              computeCapability)
              .contains(byteWidth)) {
        return;
      }
#endif

      // load
      auto tmpTy =
          RankedTensorType::get(srcTy.getShape(), resElemTy, srcLayout);
      auto loadOp = builder.create<triton::LoadOp>(
          insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.getSrc(),
          insertSliceAsyncOp.getMask(), insertSliceAsyncOp.getOther(),
          // TODO(Chenggang): confirm `boundaryCheck` and `padding`
          /*boundaryCheck=*/nullptr, /*padding=*/nullptr,
          insertSliceAsyncOp.getCache(), insertSliceAsyncOp.getEvict(),
          insertSliceAsyncOp.getIsVolatile());
      addWSNamedAttrs(loadOp, insertSliceAsyncOp->getAttrs());

      // insert_slice: offsets are 0 except for the insertion index on `axis`;
      // sizes are the full dst extent except 1 on `axis`; strides are all 1.
      auto axis = insertSliceAsyncOp.getAxis();
      auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
      auto offsets = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(0));
      auto sizes = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(1));
      auto strides = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(1));
      offsets[axis] = insertSliceAsyncOp.getIndex();
      for (size_t i = 0; i < dstTy.getRank(); i++) {
        if (i != axis)
          sizes[i] = intAttr(dstTy.getShape()[i]);
      }
      auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
          insertSliceAsyncOp.getLoc(), loadOp, insertSliceAsyncOp.getDst(),
          offsets, sizes, strides);
      addWSNamedAttrs(insertSliceOp, insertSliceAsyncOp->getAttrs());

      // Replace
      insertSliceAsyncOp.replaceAllUsesWith(insertSliceOp.getResult());
      insertSliceAsyncOp.erase();
      decomposed = true;
    });

    mod.walk([&](triton::gpu::AsyncCommitGroupOp asyncCommitGroupOp) -> void {
#ifdef USE_ROCM
      // All insert_slice_async ops were decomposed above on ROCm, so the
      // commit-group op is always removed.
      asyncCommitGroupOp.erase();
#else
      if (!triton::gpu::AsyncCommitGroupOp::isSupported(computeCapability))
        asyncCommitGroupOp.erase();
#endif
    });

    mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
#ifdef USE_ROCM
      // AsyncWait is not supported for ROCM and should be removed
      asyncWaitOp.erase();
#else
      if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability)) {
        // async wait is supported in Ampere and later
        asyncWaitOp.erase();
      } else if (decomposed) {
        // Wait for all previous async ops
        OpBuilder builder(asyncWaitOp);
        auto newWaitOp =
            builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
        addWSNamedAttrs(newWaitOp, asyncWaitOp->getAttrs());
        asyncWaitOp.erase();
      }
#endif
    });
  }
|
|
|
|
static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
|
|
Type promotedType) {
|
|
Type tensorPromotedType =
|
|
operand.getType().cast<RankedTensorType>().cloneWith(std::nullopt,
|
|
promotedType);
|
|
return builder.create<triton::FpToFpOp>(loc, tensorPromotedType, operand);
|
|
}
|
|
|
|
// promote operands of dot op if the existing combination is not natively
|
|
// supported.
|
|
  // Walks all tt.dot ops and promotes both operands to a common element type
  // (via promoteOperand) whenever the operand/result type combination is not
  // natively supported by the target layout (MMA, MFMA, or FMA).
  void decomposeMixedModeDotOp(ModuleOp mod) const {
    mod.walk([](triton::DotOp dotOp) -> void {
      Value D = dotOp.getResult();
      OpBuilder builder(dotOp);
      Type AElType =
          dotOp.getA().getType().cast<RankedTensorType>().getElementType();
      Type promoteType;
      MmaEncodingAttr mmaLayout = D.getType()
                                      .cast<RankedTensorType>()
                                      .getEncoding()
                                      .dyn_cast<MmaEncodingAttr>();
      if (mmaLayout) {
        // NVIDIA tensor-core path: non-native fp8 inputs are promoted to f16;
        // the Hopper-native fp8 combinations are left as-is.
        bool isNativeHopperFP8 =
            AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ();
        bool isFP8 = isNativeHopperFP8 || AElType.isFloat8E5M2FNUZ() ||
                     AElType.isFloat8E4M3FN() || AElType.isFloat8E4M3B11FNUZ();
        if (!isFP8 || (isNativeHopperFP8 && mmaLayout.isHopper()))
          return;
        promoteType = builder.getF16Type();
#ifdef USE_ROCM
      } else if (MfmaEncodingAttr mfmaLayout =
                     D.getType()
                         .cast<RankedTensorType>()
                         .getEncoding()
                         .dyn_cast<MfmaEncodingAttr>()) {
        // AMD MFMA path: promote mismatched A/B element types to the wider
        // of the two widths (f16 for <16-bit, f32 for 17..32-bit).
        Type BElType =
            dotOp.getB().getType().cast<RankedTensorType>().getElementType();

        auto maxBitWidth = std::max(AElType.getIntOrFloatBitWidth(),
                                    BElType.getIntOrFloatBitWidth());

        // TODO check mfma tensor core version compatibility
        if (maxBitWidth == 8)
          return;

        if (AElType == BElType)
          return;

        // NOTE(review): if maxBitWidth > 32 neither branch below assigns
        // promoteType, so a null Type would reach promoteOperand — confirm
        // operands wider than 32 bits cannot occur here.
        if (maxBitWidth < 16)
          promoteType = builder.getF16Type();
        else if (maxBitWidth <= 32)
          promoteType = builder.getF32Type();
#endif
      } else {
        // FMA case.
        // This AElType shadows the outer declaration; both read operand A.
        Type AElType =
            dotOp.getA().getType().cast<RankedTensorType>().getElementType();
        Type DElType = D.getType().cast<RankedTensorType>().getElementType();
        if (AElType == DElType)
          return;
        promoteType = DElType;
      }
      Location loc = dotOp.getLoc();
      Value promotedA = promoteOperand(builder, loc, dotOp.getA(), promoteType);
      Value promotedB = promoteOperand(builder, loc, dotOp.getB(), promoteType);
      dotOp.setOperand(0, promotedA);
      dotOp.setOperand(1, promotedB);
    });
  }
|
|
};
|
|
|
|
} // anonymous namespace
|
|
|
|
namespace mlir {
|
|
namespace triton {
|
|
|
|
/// Creates the TritonGPU -> LLVM conversion pass with default options.
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ConvertTritonGPUToLLVM>();
  return pass;
}
|
|
/// Creates the TritonGPU -> LLVM conversion pass for the given compute
/// capability and target; `tmaMetadata` receives TMA descriptor info.
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass(
    int32_t computeCapability, Target target,
    mlir::triton::gpu::TMAMetadataTy *tmaMetadata) {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ConvertTritonGPUToLLVM>(computeCapability, target,
                                               tmaMetadata);
  return pass;
}
|
|
|
|
} // namespace triton
|
|
} // namespace mlir
|