mirror of
https://github.com/ROCm/ROCm.git
synced 2026-02-21 03:00:39 -05:00
[TritonGPUToLLVM] Correct the usage of option passing (#2104)
For example, when given `--convert-triton-gpu-to-llvm="is-rocm=true"`, `ConvertTritonGPUToLLVMPass` should generate ROCM-compatible LLVM. Before this PR, transformation options passed in command line are not respected.
This commit is contained in:
@@ -27,7 +27,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">,
|
||||
Option<"TmaMetadata", "tma-metadata",
|
||||
Option<"tmaMetadata", "tma-metadata",
|
||||
"mlir::triton::gpu::TMAMetadataTy*", /*default*/"nullptr",
|
||||
"tma metadata to the runtime">,
|
||||
Option<"isROCM", "is-rocm",
|
||||
|
||||
@@ -14,10 +14,12 @@ template <typename T> class OperationPass;
|
||||
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass(
|
||||
int computeCapability = 80,
|
||||
mlir::triton::gpu::TMAMetadataTy *tmaMetadata = nullptr,
|
||||
bool isROCM = false);
|
||||
#define GEN_PASS_DECL
|
||||
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
|
||||
@@ -40,13 +40,17 @@
|
||||
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
#define GEN_PASS_DEF_CONVERTTRITONGPUTOLLVM
|
||||
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
namespace ttng = mlir::triton::nvidia_gpu;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
|
||||
|
||||
namespace {
|
||||
|
||||
// pass ws related named attrs.
|
||||
@@ -372,15 +376,10 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertTritonGPUToLLVM
|
||||
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
|
||||
|
||||
public:
|
||||
explicit ConvertTritonGPUToLLVM(int computeCapability,
|
||||
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
|
||||
bool isROCM)
|
||||
: computeCapability(computeCapability), tmaMetadata(tmaMetadata),
|
||||
isROCM(isROCM) {}
|
||||
struct ConvertTritonGPUToLLVM
|
||||
: public triton::impl::ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
|
||||
using ConvertTritonGPUToLLVMBase<
|
||||
ConvertTritonGPUToLLVM>::ConvertTritonGPUToLLVMBase;
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
@@ -569,10 +568,6 @@ private:
|
||||
CacheKeyDenseMapInfo>
|
||||
indexCache;
|
||||
|
||||
int computeCapability{};
|
||||
bool isROCM{};
|
||||
mlir::triton::gpu::TMAMetadataTy *tmaMetadata;
|
||||
|
||||
void initSharedMemory(ModuleAllocation &allocation,
|
||||
TritonGPUToLLVMTypeConverter &typeConverter) {
|
||||
ModuleOp mod = getOperation();
|
||||
@@ -862,12 +857,12 @@ private:
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass() {
|
||||
return std::make_unique<ConvertTritonGPUToLLVM>();
|
||||
}
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonGPUToLLVMPass(int computeCapability,
|
||||
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
|
||||
bool isROCM) {
|
||||
return std::make_unique<::ConvertTritonGPUToLLVM>(computeCapability,
|
||||
tmaMetadata, isROCM);
|
||||
createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options) {
|
||||
return std::make_unique<ConvertTritonGPUToLLVM>(options);
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
|
||||
@@ -351,7 +351,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
pm.addPass(mlir::createConvertSCFToCFPass());
|
||||
pm.addPass(mlir::createConvertIndexToLLVMPass());
|
||||
pm.addPass(
|
||||
createConvertTritonGPUToLLVMPass(computeCapability, &tmaInfos, isROCM));
|
||||
createConvertTritonGPUToLLVMPass({computeCapability, &tmaInfos, isROCM}));
|
||||
pm.addPass(createConvertNVGPUToLLVMPass());
|
||||
pm.addPass(mlir::createArithToLLVMConversionPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
|
||||
Reference in New Issue
Block a user