[TritonGPUToLLVM] Correct the usage of option passing (#2104)

For example, when given `--convert-triton-gpu-to-llvm="is-rocm=true"`,
`ConvertTritonGPUToLLVMPass` should generate ROCM-compatible LLVM.
Before this PR, transformation options passed in command line are not
respected.
This commit is contained in:
Whitney Tsang
2023-08-15 20:56:01 -04:00
committed by GitHub
parent 780266c3a2
commit 129e7dfc6f
4 changed files with 24 additions and 27 deletions

View File

@@ -27,7 +27,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">,
Option<"TmaMetadata", "tma-metadata",
Option<"tmaMetadata", "tma-metadata",
"mlir::triton::gpu::TMAMetadataTy*", /*default*/"nullptr",
"tma metadata to the runtime">,
Option<"isROCM", "is-rocm",

View File

@@ -14,10 +14,12 @@ template <typename T> class OperationPass;
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass(
int computeCapability = 80,
mlir::triton::gpu::TMAMetadataTy *tmaMetadata = nullptr,
bool isROCM = false);
#define GEN_PASS_DECL
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options);
} // namespace triton

View File

@@ -40,13 +40,17 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
namespace mlir {
namespace triton {
#define GEN_PASS_DEF_CONVERTTRITONGPUTOLLVM
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
} // namespace triton
} // namespace mlir
using namespace mlir;
using namespace mlir::triton;
namespace ttng = mlir::triton::nvidia_gpu;
#define GEN_PASS_CLASSES
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
namespace {
// pass ws related named attrs.
@@ -372,15 +376,10 @@ public:
}
};
class ConvertTritonGPUToLLVM
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
public:
explicit ConvertTritonGPUToLLVM(int computeCapability,
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
bool isROCM)
: computeCapability(computeCapability), tmaMetadata(tmaMetadata),
isROCM(isROCM) {}
struct ConvertTritonGPUToLLVM
: public triton::impl::ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
using ConvertTritonGPUToLLVMBase<
ConvertTritonGPUToLLVM>::ConvertTritonGPUToLLVMBase;
void runOnOperation() override {
MLIRContext *context = &getContext();
@@ -569,10 +568,6 @@ private:
CacheKeyDenseMapInfo>
indexCache;
int computeCapability{};
bool isROCM{};
mlir::triton::gpu::TMAMetadataTy *tmaMetadata;
void initSharedMemory(ModuleAllocation &allocation,
TritonGPUToLLVMTypeConverter &typeConverter) {
ModuleOp mod = getOperation();
@@ -862,12 +857,12 @@ private:
namespace mlir {
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass() {
return std::make_unique<ConvertTritonGPUToLLVM>();
}
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonGPUToLLVMPass(int computeCapability,
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
bool isROCM) {
return std::make_unique<::ConvertTritonGPUToLLVM>(computeCapability,
tmaMetadata, isROCM);
createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options) {
return std::make_unique<ConvertTritonGPUToLLVM>(options);
}
} // namespace triton

View File

@@ -351,7 +351,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createConvertIndexToLLVMPass());
pm.addPass(
createConvertTritonGPUToLLVMPass(computeCapability, &tmaInfos, isROCM));
createConvertTritonGPUToLLVMPass({computeCapability, &tmaInfos, isROCM}));
pm.addPass(createConvertNVGPUToLLVMPass());
pm.addPass(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createCanonicalizerPass());