diff --git a/include/triton/Conversion/TritonGPUToLLVM/Passes.td b/include/triton/Conversion/TritonGPUToLLVM/Passes.td index 3380360d0..6d3dbaf54 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Passes.td +++ b/include/triton/Conversion/TritonGPUToLLVM/Passes.td @@ -27,7 +27,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp" Option<"computeCapability", "compute-capability", "int32_t", /*default*/"80", "device compute capability">, - Option<"TmaMetadata", "tma-metadata", + Option<"tmaMetadata", "tma-metadata", "mlir::triton::gpu::TMAMetadataTy*", /*default*/"nullptr", "tma metadata to the runtime">, Option<"isROCM", "is-rocm", diff --git a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h index e1e1bf297..df581d911 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h @@ -14,10 +14,12 @@ template class OperationPass; namespace triton { -std::unique_ptr> createConvertTritonGPUToLLVMPass( - int computeCapability = 80, - mlir::triton::gpu::TMAMetadataTy *tmaMetadata = nullptr, - bool isROCM = false); +#define GEN_PASS_DECL +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" + +std::unique_ptr> createConvertTritonGPUToLLVMPass(); +std::unique_ptr> +createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options); } // namespace triton diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index 112e2e7b4..30ee8f18b 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -40,13 +40,17 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTTRITONGPUTOLLVM +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + using namespace mlir; using namespace mlir::triton; namespace ttng = mlir::triton::nvidia_gpu; -#define GEN_PASS_CLASSES -#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" - namespace { // pass ws related named attrs. @@ -372,15 +376,10 @@ public: } }; -class ConvertTritonGPUToLLVM - : public ConvertTritonGPUToLLVMBase { - -public: - explicit ConvertTritonGPUToLLVM(int computeCapability, - mlir::triton::gpu::TMAMetadataTy *tmaMetadata, - bool isROCM) - : computeCapability(computeCapability), tmaMetadata(tmaMetadata), - isROCM(isROCM) {} +struct ConvertTritonGPUToLLVM + : public triton::impl::ConvertTritonGPUToLLVMBase { + using ConvertTritonGPUToLLVMBase< + ConvertTritonGPUToLLVM>::ConvertTritonGPUToLLVMBase; void runOnOperation() override { MLIRContext *context = &getContext(); @@ -569,10 +568,6 @@ private: CacheKeyDenseMapInfo> indexCache; - int computeCapability{}; - bool isROCM{}; - mlir::triton::gpu::TMAMetadataTy *tmaMetadata; - void initSharedMemory(ModuleAllocation &allocation, TritonGPUToLLVMTypeConverter &typeConverter) { ModuleOp mod = getOperation(); @@ -862,12 +857,12 @@ private: namespace mlir { namespace triton { +std::unique_ptr> createConvertTritonGPUToLLVMPass() { + return std::make_unique(); +} std::unique_ptr> -createConvertTritonGPUToLLVMPass(int computeCapability, - mlir::triton::gpu::TMAMetadataTy *tmaMetadata, - bool isROCM) { - return std::make_unique<::ConvertTritonGPUToLLVM>(computeCapability, - tmaMetadata, isROCM); +createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options) { + return std::make_unique(options); } } // namespace triton diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 6e6319be2..3c6c165c1 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -351,7 +351,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, pm.addPass(mlir::createConvertSCFToCFPass()); pm.addPass(mlir::createConvertIndexToLLVMPass()); pm.addPass( - createConvertTritonGPUToLLVMPass(computeCapability, &tmaInfos, isROCM)); + createConvertTritonGPUToLLVMPass({computeCapability, &tmaInfos, isROCM})); pm.addPass(createConvertNVGPUToLLVMPass()); pm.addPass(mlir::createArithToLLVMConversionPass()); pm.addPass(mlir::createCanonicalizerPass());