From 247d60503d3afc16c6d3383d92debe36df7764d2 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Thu, 7 Apr 2022 23:49:30 +0200 Subject: [PATCH] fix(compiler): Custom copy op from 1D tensor to avoir stack allocation from mlir memref to llvm lowering --- .../include/concretelang/Conversion/Tools.h | 11 ++++ .../include/concretelang/Runtime/wrappers.h | 6 ++ .../BConcreteToBConcreteCAPI.cpp | 30 +--------- .../BConcreteToBConcreteCAPI/CMakeLists.txt | 1 + compiler/lib/Conversion/CMakeLists.txt | 7 +++ .../CMakeLists.txt | 1 + .../MLIRLowerableDialectsToLLVM.cpp | 57 +++++++++++++++++++ compiler/lib/Conversion/Tools.cpp | 35 ++++++++++++ compiler/lib/Runtime/wrappers.cpp | 18 ++++++ 9 files changed, 137 insertions(+), 29 deletions(-) create mode 100644 compiler/include/concretelang/Conversion/Tools.h create mode 100644 compiler/lib/Conversion/Tools.cpp diff --git a/compiler/include/concretelang/Conversion/Tools.h b/compiler/include/concretelang/Conversion/Tools.h new file mode 100644 index 000000000..8dc30adc0 --- /dev/null +++ b/compiler/include/concretelang/Conversion/Tools.h @@ -0,0 +1,11 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. + +#include "mlir/IR/PatternMatch.h" + +mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op, + mlir::RewriterBase &rewriter, + llvm::StringRef funcName, + mlir::FunctionType funcType); \ No newline at end of file diff --git a/compiler/include/concretelang/Runtime/wrappers.h b/compiler/include/concretelang/Runtime/wrappers.h index f4f85a0ec..9629f2d21 100644 --- a/compiler/include/concretelang/Runtime/wrappers.h +++ b/compiler/include/concretelang/Runtime/wrappers.h @@ -57,6 +57,12 @@ void memref_bootstrap_lwe_u64( uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned, uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride, mlir::concretelang::RuntimeContext *context); + +void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned, + uint64_t src_offset, uint64_t src_size, + uint64_t src_stride, uint64_t *dst_allocated, + uint64_t *dst_aligned, uint64_t dst_offset, + uint64_t dst_size, uint64_t dst_stride); } #endif diff --git a/compiler/lib/Conversion/BConcreteToBConcreteCAPI/BConcreteToBConcreteCAPI.cpp b/compiler/lib/Conversion/BConcreteToBConcreteCAPI/BConcreteToBConcreteCAPI.cpp index 1e98fda62..d8598f8a0 100644 --- a/compiler/lib/Conversion/BConcreteToBConcreteCAPI/BConcreteToBConcreteCAPI.cpp +++ b/compiler/lib/Conversion/BConcreteToBConcreteCAPI/BConcreteToBConcreteCAPI.cpp @@ -12,6 +12,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "concretelang/Conversion/Passes.h" +#include "concretelang/Conversion/Tools.h" #include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" #include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" #include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" @@ -34,35 +35,6 @@ public: } }; -mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op, - mlir::RewriterBase &rewriter, - llvm::StringRef funcName, - mlir::FunctionType funcType) { - // Looking for the `funcName` Operation - auto module = mlir::SymbolTable::getNearestSymbolTable(op); - auto opFunc = mlir::dyn_cast_or_null( - mlir::SymbolTable::lookupSymbolIn(module, funcName)); - if (!opFunc) { - // Insert the forward declaration of the funcName - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&module->getRegion(0).front()); - - opFunc = rewriter.create(rewriter.getUnknownLoc(), funcName, - funcType); - opFunc.setPrivate(); - } else { - // Check if the `funcName` is well a private function - if (!opFunc.isPrivate()) { - op->emitError() << "the function \"" << funcName - << "\" conflicts with the concrete C API, please rename"; - return mlir::failure(); - } - } - assert(mlir::SymbolTable::lookupSymbolIn(module, funcName) - ->template hasTrait()); - return mlir::success(); -} - // Set of functions to generate generic types. // Generic types are used to add forward declarations without a specific type. // For example, we may need to add LWE ciphertext of different dimensions, or diff --git a/compiler/lib/Conversion/BConcreteToBConcreteCAPI/CMakeLists.txt b/compiler/lib/Conversion/BConcreteToBConcreteCAPI/CMakeLists.txt index 22876ab2b..ab2b16111 100644 --- a/compiler/lib/Conversion/BConcreteToBConcreteCAPI/CMakeLists.txt +++ b/compiler/lib/Conversion/BConcreteToBConcreteCAPI/CMakeLists.txt @@ -9,6 +9,7 @@ add_mlir_dialect_library(BConcreteToBConcreteCAPI mlir-headers LINK_LIBS PUBLIC + ConcretelangConversion MLIRIR MLIRTransforms ) diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt index 7933368b4..2e540d09e 100644 --- a/compiler/lib/Conversion/CMakeLists.txt +++ b/compiler/lib/Conversion/CMakeLists.txt @@ -5,3 +5,10 @@ add_subdirectory(FHETensorOpsToLinalg) add_subdirectory(ConcreteToBConcrete) add_subdirectory(BConcreteToBConcreteCAPI) add_subdirectory(MLIRLowerableDialectsToLLVM) + +add_mlir_library(ConcretelangConversion + Tools.cpp + + LINK_LIBS PUBLIC + MLIRIR +) \ No newline at end of file diff --git a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/CMakeLists.txt b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/CMakeLists.txt index 5d5922f7e..b0d26b4e7 100644 --- a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/CMakeLists.txt +++ b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/CMakeLists.txt @@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRLowerableDialectsToLLVM LINK_LIBS PUBLIC ${dialect_libs} ${conversion_libs} + ConcretelangConversion MLIRIR MLIRTransforms MLIRLLVMIR diff --git a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp index 887870d76..9db8a57cd 100644 --- a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp +++ b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp @@ -25,6 +25,7 @@ #include "llvm/ADT/Sequence.h" #include "concretelang/Conversion/Passes.h" +#include "concretelang/Conversion/Tools.h" #include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" #include "concretelang/Dialect/RT/Analysis/Autopar.h" #include "concretelang/Dialect/RT/IR/RTTypes.h" @@ -39,6 +40,61 @@ struct MLIRLowerableDialectsToLLVMPass }; } // namespace +// This rewrite pattern transforms any instance of `memref.copy` +// operators on 1D memref. +// This is introduced to avoid the MLIR lowering of `memref.copy` of ranked +// memref that basically allocate unranked memref structure on the stack before +// calling @memrefCopy. +// +// Example: +// +// ```mlir +// memref.copy %src, %dst : memref to memref +// ``` +// +// becomes: +// +// ```mlir +// %_src = memref.cast %src = memref to memref +// %_dst = memref.cast %dst = memref to memref +// call @memref_copy_one_rank(%_src, %_dst) : (tensor, tensor) -> +// () +// ``` +struct Memref1DCopyOpPattern + : public mlir::OpRewritePattern { + Memref1DCopyOpPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit) {} + + mlir::LogicalResult + matchAndRewrite(mlir::memref::CopyOp copyOp, + mlir::PatternRewriter &rewriter) const override { + if (copyOp.source().getType().cast().getRank() != 1 || + copyOp.source().getType().cast().getRank() != 1) { + return mlir::failure(); + } + auto opType = mlir::MemRefType::get({-1}, rewriter.getI64Type()); + // Insert forward declaration of the add_lwe_ciphertexts function + { + if (insertForwardDeclaration( + copyOp, rewriter, "memref_copy_one_rank", + mlir::FunctionType::get(rewriter.getContext(), {opType, opType}, + {})) + .failed()) { + return mlir::failure(); + } + } + auto sourceOp = rewriter.create( + copyOp.getLoc(), opType, copyOp.source()); + auto targetOp = rewriter.create( + copyOp.getLoc(), opType, copyOp.target()); + rewriter.replaceOpWithNewOp( + copyOp, "memref_copy_one_rank", mlir::TypeRange{}, + mlir::ValueRange{sourceOp, targetOp}); + return mlir::success(); + }; +}; + void MLIRLowerableDialectsToLLVMPass::runOnOperation() { // Setup the conversion target. We reuse the LLVMConversionTarget that // legalize LLVM dialect. @@ -63,6 +119,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() { // Setup the set of the patterns rewriter. At this point we want to // convert the `scf` operations to `std` and `std` operations to `llvm`. mlir::RewritePatternSet patterns(&getContext()); + patterns.add(&getContext(), 100); mlir::concretelang::populateRTToLLVMConversionPatterns(typeConverter, patterns); mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns); diff --git a/compiler/lib/Conversion/Tools.cpp b/compiler/lib/Conversion/Tools.cpp new file mode 100644 index 000000000..e722eefc0 --- /dev/null +++ b/compiler/lib/Conversion/Tools.cpp @@ -0,0 +1,35 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. + +#include "concretelang/Conversion/Tools.h" + +mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op, + mlir::RewriterBase &rewriter, + llvm::StringRef funcName, + mlir::FunctionType funcType) { + // Looking for the `funcName` Operation + auto module = mlir::SymbolTable::getNearestSymbolTable(op); + auto opFunc = mlir::dyn_cast_or_null( + mlir::SymbolTable::lookupSymbolIn(module, funcName)); + if (!opFunc) { + // Insert the forward declaration of the funcName + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&module->getRegion(0).front()); + + opFunc = rewriter.create(rewriter.getUnknownLoc(), funcName, + funcType); + opFunc.setPrivate(); + } else { + // Check if the `funcName` is well a private function + if (!opFunc.isPrivate()) { + op->emitError() << "the function \"" << funcName + << "\" conflicts with the concrete C API, please rename"; + return mlir::failure(); + } + } + assert(mlir::SymbolTable::lookupSymbolIn(module, funcName) + ->template hasTrait()); + return mlir::success(); +} \ No newline at end of file diff --git a/compiler/lib/Runtime/wrappers.cpp b/compiler/lib/Runtime/wrappers.cpp index 8944f59ef..89577bc68 100644 --- a/compiler/lib/Runtime/wrappers.cpp +++ b/compiler/lib/Runtime/wrappers.cpp @@ -6,6 +6,7 @@ #include "concretelang/Runtime/wrappers.h" #include #include +#include void memref_expand_lut_in_trivial_glwe_ct_u64( uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned, @@ -96,3 +97,20 @@ void memref_bootstrap_lwe_u64( out_aligned + out_offset, ct0_aligned + ct0_offset, glwe_ct_aligned + glwe_ct_offset); } + +void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned, + uint64_t src_offset, uint64_t src_size, + uint64_t src_stride, uint64_t *dst_allocated, + uint64_t *dst_aligned, uint64_t dst_offset, + uint64_t dst_size, uint64_t dst_stride) { + assert(src_size == dst_size && "memref_copy_one_rank size differs"); + if (src_stride == dst_stride) { + memcpy(dst_aligned + dst_offset, src_aligned + src_offset, + src_size * sizeof(uint64_t)); + return; + } + for (size_t i = 0; i < src_size; i++) { + dst_aligned[dst_offset + i * dst_stride] = + src_aligned[src_offset + i * src_stride]; + } +} \ No newline at end of file