From 8a9cce64e31b2ba1282e1d29a64c9faaa22a757a Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Fri, 11 Feb 2022 14:14:38 +0100 Subject: [PATCH] enhance(compiler): Add custom finalize bufferize pass to handle memref.tensor_load op --- compiler/include/concretelang/CMakeLists.txt | 3 +- .../concretelang/Transforms/Bufferize.h | 21 ++++++ .../concretelang/Transforms/Bufferize.td | 14 ++++ .../concretelang/Transforms/CMakeLists.txt | 3 + compiler/lib/CMakeLists.txt | 1 + compiler/lib/Support/CMakeLists.txt | 1 + compiler/lib/Support/Pipeline.cpp | 6 +- compiler/lib/Transforms/Bufferize.cpp | 73 +++++++++++++++++++ compiler/lib/Transforms/CMakeLists.txt | 16 ++++ 9 files changed, 135 insertions(+), 3 deletions(-) create mode 100644 compiler/include/concretelang/Transforms/Bufferize.h create mode 100644 compiler/include/concretelang/Transforms/Bufferize.td create mode 100644 compiler/include/concretelang/Transforms/CMakeLists.txt create mode 100644 compiler/lib/Transforms/Bufferize.cpp create mode 100644 compiler/lib/Transforms/CMakeLists.txt diff --git a/compiler/include/concretelang/CMakeLists.txt b/compiler/include/concretelang/CMakeLists.txt index cf45425ad..75ae0af0c 100644 --- a/compiler/include/concretelang/CMakeLists.txt +++ b/compiler/include/concretelang/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(Dialect) -add_subdirectory(Conversion) \ No newline at end of file +add_subdirectory(Conversion) +add_subdirectory(Transforms) \ No newline at end of file diff --git a/compiler/include/concretelang/Transforms/Bufferize.h b/compiler/include/concretelang/Transforms/Bufferize.h new file mode 100644 index 000000000..2bceae6aa --- /dev/null +++ b/compiler/include/concretelang/Transforms/Bufferize.h @@ -0,0 +1,21 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_BUFFERIZE_PASS_H +#define CONCRETELANG_BUFFERIZE_PASS_H + +#include +#include + +#define GEN_PASS_CLASSES +#include + +namespace mlir { +namespace concretelang { +std::unique_ptr createFinalizingBufferizePass(); +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Transforms/Bufferize.td b/compiler/include/concretelang/Transforms/Bufferize.td new file mode 100644 index 000000000..d368b79d1 --- /dev/null +++ b/compiler/include/concretelang/Transforms/Bufferize.td @@ -0,0 +1,14 @@ +#ifndef CONCRETELANG_FHELINALG_TILING_PASS +#define CONCRETELANG_FHELINALG_TILING_PASS + +include "mlir/Pass/PassBase.td" + +def FinalizingBufferize : FunctionPass<"concretelang-bufferize"> { + let summary = + "Marks FHELinalg operations for tiling using a vector of tile sizes"; + let constructor = "mlir::concretelang::createBufferizePass()"; + let options = []; + let dependentDialects = [ "mlir::memref::MemRefDialect" ]; +} + +#endif diff --git a/compiler/include/concretelang/Transforms/CMakeLists.txt b/compiler/include/concretelang/Transforms/CMakeLists.txt new file mode 100644 index 000000000..d93f30a0d --- /dev/null +++ b/compiler/include/concretelang/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Bufferize.td) +mlir_tablegen(Bufferize.h.inc -gen-pass-decls -name Transforms) +add_public_tablegen_target(ConcretelangTransformsPassIncGen) diff --git a/compiler/lib/CMakeLists.txt b/compiler/lib/CMakeLists.txt index 96829a24b..6c407d083 100644 --- a/compiler/lib/CMakeLists.txt +++ b/compiler/lib/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(Dialect) add_subdirectory(Conversion) +add_subdirectory(Transforms) add_subdirectory(Support) add_subdirectory(Runtime) add_subdirectory(ClientLib) diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 52a7bbeb5..9ff8f2982 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -26,6 +26,7 @@ add_mlir_library(ConcretelangSupport MLIRLowerableDialectsToLLVM FHEDialectAnalysis RTDialectAnalysis + ConcretelangTransforms MLIRExecutionEngine ${LLVM_PTHREAD_LIB} diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index d2badcd00..36fe58e35 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -26,6 +26,7 @@ #include #include #include +#include namespace mlir { namespace concretelang { @@ -234,8 +235,9 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass(), enablePass); addPotentiallyNestedPass( pm, mlir::concretelang::createBufferizeDataflowTaskOpsPass(), enablePass); - addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass(), - enablePass); + addPotentiallyNestedPass( + pm, mlir::concretelang::createFinalizingBufferizePass(), enablePass); + if (parallelizeLoops) addPotentiallyNestedPass(pm, mlir::createConvertSCFToOpenMPPass(), enablePass); diff --git a/compiler/lib/Transforms/Bufferize.cpp b/compiler/lib/Transforms/Bufferize.cpp new file mode 100644 index 000000000..62b285ac6 --- /dev/null +++ b/compiler/lib/Transforms/Bufferize.cpp @@ -0,0 +1,73 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. + +#include "mlir/Transforms/Bufferize.h" +#include "concretelang/Transforms/Bufferize.h" + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Operation.h" +#include "mlir/Transforms/Passes.h" +using namespace mlir; + +namespace { +// In a finalizing bufferize conversion, we know that all tensors have been +// converted to memrefs, thus, this op becomes an identity. +class BufferizeTensorStoreOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(memref::TensorStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.tensor(), op.memref()); + return success(); + } +}; +} // namespace + +void populatePatterns(BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns) { + mlir::populateEliminateBufferizeMaterializationsPatterns(typeConverter, + patterns); + patterns.add(typeConverter, patterns.getContext()); +} + +namespace { +struct FinalizingBufferizePass + : public FinalizingBufferizeBase { + using FinalizingBufferizeBase< + FinalizingBufferizePass>::FinalizingBufferizeBase; + + void runOnFunction() override { + auto func = getFunction(); + auto *context = &getContext(); + + BufferizeTypeConverter typeConverter; + RewritePatternSet patterns(context); + ConversionTarget target(*context); + populatePatterns(typeConverter, patterns); + + // If all result types are legal, and all block arguments are legal (ensured + // by func conversion above), then all types in the program are legal. + // + // We also check that the operand types are legal to avoid creating invalid + // IR. For example, this prevents + // populateEliminateBufferizeMaterializationsPatterns from updating the + // types of the operands to a return op without updating the enclosing + // function. + target.markUnknownOpDynamicallyLegal( + [&](Operation *op) { return typeConverter.isLegal(op); }); + target.addLegalOp(); + + if (failed(applyFullConversion(func, target, std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr +mlir::concretelang::createFinalizingBufferizePass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/compiler/lib/Transforms/CMakeLists.txt b/compiler/lib/Transforms/CMakeLists.txt new file mode 100644 index 000000000..5606bdd55 --- /dev/null +++ b/compiler/lib/Transforms/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_library(ConcretelangTransforms + Bufferize.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Transforms + + DEPENDS + MLIRTransforms + ConcretelangTransformsPassIncGen + + LINK_LIBS PUBLIC + MLIRMemRef + MLIRTransforms +) + +target_link_libraries(FHELinalgDialectTransforms PUBLIC MLIRIR)