[TESTS] Add triton version of mlir-reduce (#1498)

[`mlir-reduce`](https://mlir.llvm.org/docs/Tools/mlir-reduce/) is a tool
to reduce the complexity of bug reproducers written in mlir. Similar to
`triton-opt`, `triton` needs to have its own version with the dialects
registered properly for it to work.
This commit is contained in:
peterbell10
2023-04-10 20:31:11 +00:00
committed by GitHub
parent 8c55276c90
commit 2c06f875e4
4 changed files with 71 additions and 34 deletions

View File

@@ -21,6 +21,26 @@ target_link_libraries(triton-opt PRIVATE
mlir_check_all_link_libraries(triton-opt)
add_llvm_executable(triton-reduce triton-reduce.cpp PARTIAL_SOURCES_INTENDED)
mlir_check_all_link_libraries(triton-reduce)
llvm_update_compile_flags(triton-reduce)
target_link_libraries(triton-reduce PRIVATE
TritonAnalysis
TritonTransforms
TritonGPUTransforms
${dialect_libs}
${conversion_libs}
# tests
TritonTestAnalysis
# MLIR core
MLIRReduceLib
MLIRPass
MLIRTransforms
)
mlir_check_all_link_libraries(triton-reduce)
add_llvm_executable(triton-translate triton-translate.cpp PARTIAL_SOURCES_INTENDED)
llvm_update_compile_flags(triton-translate)

View File

@@ -0,0 +1,38 @@
#pragma once
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
#include "mlir/InitAllPasses.h"
namespace mlir {
namespace test {
void registerTestAliasPass();
void registerTestAlignmentPass();
void registerTestAllocationPass();
void registerTestMembarPass();
} // namespace test
} // namespace mlir
inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerAllPasses();
mlir::registerTritonPasses();
mlir::registerTritonGPUPasses();
mlir::test::registerTestAliasPass();
mlir::test::registerTestAlignmentPass();
mlir::test::registerTestAllocationPass();
mlir::test::registerTestMembarPass();
mlir::triton::registerConvertTritonToTritonGPUPass();
mlir::triton::registerConvertTritonGPUToLLVMPass();
// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::func::FuncDialect,
mlir::math::MathDialect, mlir::arith::ArithDialect,
mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();
}

View File

@@ -1,42 +1,10 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "./RegisterTritonDialects.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
namespace mlir {
namespace test {
void registerTestAliasPass();
void registerTestAlignmentPass();
void registerTestAllocationPass();
void registerTestMembarPass();
} // namespace test
} // namespace mlir
int main(int argc, char **argv) {
mlir::registerAllPasses();
mlir::registerTritonPasses();
mlir::registerTritonGPUPasses();
mlir::test::registerTestAliasPass();
mlir::test::registerTestAlignmentPass();
mlir::test::registerTestAllocationPass();
mlir::test::registerTestMembarPass();
mlir::triton::registerConvertTritonToTritonGPUPass();
mlir::triton::registerConvertTritonGPUToLLVMPass();
// TODO: register Triton & TritonGPU passes
mlir::DialectRegistry registry;
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::func::FuncDialect,
mlir::math::MathDialect, mlir::arith::ArithDialect,
mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();
registerTritonDialects(registry);
return mlir::asMainReturnCode(mlir::MlirOptMain(
argc, argv, "Triton (GPU) optimizer driver\n", registry));

11
bin/triton-reduce.cpp Normal file
View File

@@ -0,0 +1,11 @@
#include "./RegisterTritonDialects.h"
#include "mlir/Tools/mlir-reduce/MlirReduceMain.h"
int main(int argc, char **argv) {
mlir::DialectRegistry registry;
registerTritonDialects(registry);
mlir::MLIRContext context(registry);
return mlir::failed(mlir::mlirReduceMain(argc, argv, context));
}