From 626493dda76963d2117da760244632678ac67b60 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Fri, 11 Feb 2022 13:53:11 +0100 Subject: [PATCH] enhance(compiler): Lower from Concrete to BConcrete and BConcrete to C API call --- .../BConcreteToBConcreteCAPI/Pass.h | 22 + .../Conversion/ConcreteToBConcrete/Pass.h | 18 + .../include/concretelang/Conversion/Passes.h | 3 + .../include/concretelang/Conversion/Passes.td | 16 +- .../concretelang/Support/CompilerEngine.h | 4 + .../include/concretelang/Support/Pipeline.h | 8 +- .../BConcreteToBConcreteCAPI.cpp | 593 +++++++++++ .../BConcreteToBConcreteCAPI/CMakeLists.txt | 16 + compiler/lib/Conversion/CMakeLists.txt | 2 + .../ConcreteToBConcrete/CMakeLists.txt | 16 + .../ConcreteToBConcrete.cpp | 919 ++++++++++++++++++ compiler/lib/Support/CompilerEngine.cpp | 32 +- compiler/lib/Support/Pipeline.cpp | 21 +- compiler/src/main.cpp | 7 + .../BConcreteToBConcreteCAPI/add_lwe.mlir | 14 + .../BConcreteToBConcreteCAPI/add_lwe_int.mlir | 37 + .../bootstrap_lwe.mlir | 15 + .../keyswitch_lwe.mlir | 15 + .../BConcreteToBConcreteCAPI/mul_lwe_int.mlir | 33 + .../BConcreteToBConcreteCAPI/neg_lwe.mlir | 13 + .../BConcreteToBConcreteCAPI/sub_int_lwe.mlir | 47 + .../ConcreteToBConcrete/add_lwe.mlir | 10 + .../ConcreteToBConcrete/add_lwe_int.mlir | 25 + .../apply_lookup_table.mlir | 15 + .../apply_lookup_table_cst.mlir | 17 + .../ConcreteToBConcrete/identity.mlir | 8 + .../ConcreteToBConcrete/mul_lwe_int.mlir | 25 + .../ConcreteToBConcrete/neg_lwe.mlir | 10 + .../ConcreteToBConcrete/sub_int_lwe.mlir | 32 + .../ConcreteToBConcrete/tensor_identity.mlir | 7 + 30 files changed, 1984 insertions(+), 16 deletions(-) create mode 100644 compiler/include/concretelang/Conversion/BConcreteToBConcreteCAPI/Pass.h create mode 100644 compiler/include/concretelang/Conversion/ConcreteToBConcrete/Pass.h create mode 100644 compiler/lib/Conversion/BConcreteToBConcreteCAPI/BConcreteToBConcreteCAPI.cpp create mode 100644 
compiler/lib/Conversion/BConcreteToBConcreteCAPI/CMakeLists.txt create mode 100644 compiler/lib/Conversion/ConcreteToBConcrete/CMakeLists.txt create mode 100644 compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp create mode 100644 compiler/tests/Conversion/BConcreteToBConcreteCAPI/add_lwe.mlir create mode 100644 compiler/tests/Conversion/BConcreteToBConcreteCAPI/add_lwe_int.mlir create mode 100644 compiler/tests/Conversion/BConcreteToBConcreteCAPI/bootstrap_lwe.mlir create mode 100644 compiler/tests/Conversion/BConcreteToBConcreteCAPI/keyswitch_lwe.mlir create mode 100644 compiler/tests/Conversion/BConcreteToBConcreteCAPI/mul_lwe_int.mlir create mode 100644 compiler/tests/Conversion/BConcreteToBConcreteCAPI/neg_lwe.mlir create mode 100644 compiler/tests/Conversion/BConcreteToBConcreteCAPI/sub_int_lwe.mlir create mode 100644 compiler/tests/Conversion/ConcreteToBConcrete/add_lwe.mlir create mode 100644 compiler/tests/Conversion/ConcreteToBConcrete/add_lwe_int.mlir create mode 100644 compiler/tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir create mode 100644 compiler/tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir create mode 100644 compiler/tests/Conversion/ConcreteToBConcrete/identity.mlir create mode 100644 compiler/tests/Conversion/ConcreteToBConcrete/mul_lwe_int.mlir create mode 100644 compiler/tests/Conversion/ConcreteToBConcrete/neg_lwe.mlir create mode 100644 compiler/tests/Conversion/ConcreteToBConcrete/sub_int_lwe.mlir create mode 100644 compiler/tests/Conversion/ConcreteToBConcrete/tensor_identity.mlir diff --git a/compiler/include/concretelang/Conversion/BConcreteToBConcreteCAPI/Pass.h b/compiler/include/concretelang/Conversion/BConcreteToBConcreteCAPI/Pass.h new file mode 100644 index 000000000..c84d547c0 --- /dev/null +++ b/compiler/include/concretelang/Conversion/BConcreteToBConcreteCAPI/Pass.h @@ -0,0 +1,22 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. 
See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_CONVERSION_BCONCRETETOBCONCRETECAPI_PASS_H_ +#define CONCRETELANG_CONVERSION_BCONCRETETOBCONCRETECAPI_PASS_H_ + +#include "mlir/Pass/Pass.h" + +#include "concretelang/Conversion/Utils/GlobalFHEContext.h" + +namespace mlir { +namespace concretelang { +/// Create a pass to convert `BConcrete` operators to function call to the +/// `BConcreteCAPI` +std::unique_ptr> +createConvertBConcreteToBConcreteCAPIPass(); +} // namespace concretelang +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/include/concretelang/Conversion/ConcreteToBConcrete/Pass.h b/compiler/include/concretelang/Conversion/ConcreteToBConcrete/Pass.h new file mode 100644 index 000000000..cd0353690 --- /dev/null +++ b/compiler/include/concretelang/Conversion/ConcreteToBConcrete/Pass.h @@ -0,0 +1,18 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. + +#ifndef ZAMALANG_CONVERSION_CONCRETETOBCONCRETE_PASS_H_ +#define ZAMALANG_CONVERSION_CONCRETETOBCONCRETE_PASS_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace concretelang { +/// Create a pass to convert `Concrete` dialect to `BConcrete` dialect. 
+std::unique_ptr> createConvertConcreteToBConcretePass(); +} // namespace concretelang +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/include/concretelang/Conversion/Passes.h b/compiler/include/concretelang/Conversion/Passes.h index fcfc9b4ef..82ec18c3f 100644 --- a/compiler/include/concretelang/Conversion/Passes.h +++ b/compiler/include/concretelang/Conversion/Passes.h @@ -11,6 +11,8 @@ #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "concretelang/Conversion/BConcreteToBConcreteCAPI/Pass.h" +#include "concretelang/Conversion/ConcreteToBConcrete/Pass.h" #include "concretelang/Conversion/ConcreteToConcreteCAPI/Pass.h" #include "concretelang/Conversion/ConcreteUnparametrize/Pass.h" #include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h" @@ -18,6 +20,7 @@ #include "concretelang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h" #include "concretelang/Conversion/TFHEGlobalParametrization/Pass.h" #include "concretelang/Conversion/TFHEToConcrete/Pass.h" +#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" #include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" #include "concretelang/Dialect/FHE/IR/FHEDialect.h" #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" diff --git a/compiler/include/concretelang/Conversion/Passes.td b/compiler/include/concretelang/Conversion/Passes.td index 260b6689c..d8e5b0587 100644 --- a/compiler/include/concretelang/Conversion/Passes.td +++ b/compiler/include/concretelang/Conversion/Passes.td @@ -29,7 +29,15 @@ def TFHEToConcrete : Pass<"tfhe-to-concrete", "mlir::ModuleOp"> { let description = [{ Lowers operations from the TFHE dialect to Concrete }]; let constructor = "mlir::concretelang::createConvertTFHEToConcretePass()"; let options = []; - let dependentDialects = ["mlir::linalg::LinalgDialect"]; + let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::TFHE::TFHEDialect"]; +} + +def ConcreteToBConcrete : 
Pass<"concrete-to-bconcrete", "mlir::ModuleOp"> { + let summary = "Lowers operations from the Concrete dialect to Bufferized Concrete"; + let description = [{ Lowers operations from the Concrete dialect to Bufferized Concrete }]; + let constructor = "mlir::concretelang::createConvertConcreteToBConcretePass()"; + let options = []; + let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::Concrete::ConcreteDialect", "mlir::concretelang::BConcrete::BConcreteDialect"]; } def ConcreteToConcreteCAPI : Pass<"concrete-to-concrete-c-api", "mlir::ModuleOp"> { @@ -38,6 +46,12 @@ def ConcreteToConcreteCAPI : Pass<"concrete-to-concrete-c-api", "mlir::ModuleOp" let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"]; } +def BConcreteToBConcreteCAPI : Pass<"bconcrete-to-bconcrete-c-api", "mlir::ModuleOp"> { + let summary = "Lower operations from the Bufferized Concrete dialect to std with function call to the Bufferized Concrete C API"; + let constructor = "mlir::concretelang::createConvertBConcreteToBConcreteCAPIPass()"; + let dependentDialects = ["mlir::concretelang::BConcrete::BConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"]; +} + def ConcreteUnparametrize : Pass<"concrete-unparametrize", "mlir::ModuleOp"> { let summary = "Unparametrize Concrete types and remove unrealized_conversion_cast"; let constructor = "mlir::concretelang::createConvertConcreteToConcreteCAPIPass()"; diff --git a/compiler/include/concretelang/Support/CompilerEngine.h b/compiler/include/concretelang/Support/CompilerEngine.h index da7f6f9bb..69a3669d2 100644 --- a/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compiler/include/concretelang/Support/CompilerEngine.h @@ -119,6 +119,10 @@ public: // operations CONCRETE, + // Read sources and lower all FHE, TFHE and Concrete operations to BConcrete + // operations + BCONCRETE, + // Read sources and lower all FHE, TFHE and 
Concrete // operations to canonical MLIR dialects. Cryptographic operations // are lowered to invocations of the concrete library. diff --git a/compiler/include/concretelang/Support/Pipeline.h b/compiler/include/concretelang/Support/Pipeline.h index 3e5797916..9b590b30f 100644 --- a/compiler/include/concretelang/Support/Pipeline.h +++ b/compiler/include/concretelang/Support/Pipeline.h @@ -43,8 +43,12 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); mlir::LogicalResult -lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::function enablePass); +lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass); + +mlir::LogicalResult +lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass); mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, diff --git a/compiler/lib/Conversion/BConcreteToBConcreteCAPI/BConcreteToBConcreteCAPI.cpp b/compiler/lib/Conversion/BConcreteToBConcreteCAPI/BConcreteToBConcreteCAPI.cpp new file mode 100644 index 000000000..cf8a812a4 --- /dev/null +++ b/compiler/lib/Conversion/BConcreteToBConcreteCAPI/BConcreteToBConcreteCAPI.cpp @@ -0,0 +1,593 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. 
+ +#include "mlir//IR/BuiltinTypes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "concretelang/Conversion/Passes.h" +#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" +#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" +#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" +#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" +#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" + +namespace { +class BConcreteToBConcreteCAPITypeConverter : public mlir::TypeConverter { + +public: + BConcreteToBConcreteCAPITypeConverter() { + addConversion([](mlir::Type type) { return type; }); + addConversion([&](mlir::concretelang::Concrete::PlaintextType type) { + return mlir::IntegerType::get(type.getContext(), 64); + }); + addConversion([&](mlir::concretelang::Concrete::CleartextType type) { + return mlir::IntegerType::get(type.getContext(), 64); + }); + } +}; + +mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op, + mlir::RewriterBase &rewriter, + llvm::StringRef funcName, + mlir::FunctionType funcType) { + // Looking for the `funcName` Operation + auto module = mlir::SymbolTable::getNearestSymbolTable(op); + auto opFunc = mlir::dyn_cast_or_null( + mlir::SymbolTable::lookupSymbolIn(module, funcName)); + if (!opFunc) { + // Insert the forward declaration of the funcName + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&module->getRegion(0).front()); + + opFunc = rewriter.create(rewriter.getUnknownLoc(), funcName, + funcType); + opFunc.setPrivate(); + } else { + // Check if the `funcName` is well a private function + if (!opFunc.isPrivate()) { + op->emitError() << "the function \"" << funcName + << "\" conflicts with the concrete C API, please rename"; + return mlir::failure(); + } + 
} + assert(mlir::SymbolTable::lookupSymbolIn(module, funcName) + ->template hasTrait()); + return mlir::success(); +} + +// Set of functions to generate generic types. +// Generic types are used to add forward declarations without a specific type. +// For example, we may need to add LWE ciphertext of different dimensions, or +// allocate them. All the calls to the C API should be done using this generic +// types, and casting should then be performed back to the appropriate type. + +inline mlir::Type getGenericLweBufferType(mlir::MLIRContext *context) { + return mlir::RankedTensorType::get({-1}, mlir::IntegerType::get(context, 64)); +} + +inline mlir::concretelang::Concrete::GlweCiphertextType +getGenericGlweCiphertextType(mlir::MLIRContext *context) { + return mlir::concretelang::Concrete::GlweCiphertextType::get(context); +} + +inline mlir::Type getGenericPlaintextType(mlir::MLIRContext *context) { + return mlir::IntegerType::get(context, 64); +} + +inline mlir::Type getGenericCleartextType(mlir::MLIRContext *context) { + return mlir::IntegerType::get(context, 64); +} + +inline mlir::concretelang::Concrete::PlaintextListType +getGenericPlaintextListType(mlir::MLIRContext *context) { + return mlir::concretelang::Concrete::PlaintextListType::get(context); +} + +inline mlir::concretelang::Concrete::ForeignPlaintextListType +getGenericForeignPlaintextListType(mlir::MLIRContext *context) { + return mlir::concretelang::Concrete::ForeignPlaintextListType::get(context); +} + +inline mlir::concretelang::Concrete::LweKeySwitchKeyType +getGenericLweKeySwitchKeyType(mlir::MLIRContext *context) { + return mlir::concretelang::Concrete::LweKeySwitchKeyType::get(context); +} + +inline mlir::concretelang::Concrete::LweBootstrapKeyType +getGenericLweBootstrapKeyType(mlir::MLIRContext *context) { + return mlir::concretelang::Concrete::LweBootstrapKeyType::get(context); +} + +// Insert all forward declarations needed for the pass. 
+// Should generalize input and output types for all declarations, and the +// pattern using them would be responsible for casting them to the appropriate +// type. +mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, + mlir::IRRewriter &rewriter) { + auto lweBufferType = getGenericLweBufferType(rewriter.getContext()); + auto plaintextType = getGenericPlaintextType(rewriter.getContext()); + auto cleartextType = getGenericCleartextType(rewriter.getContext()); + auto glweCiphertextType = getGenericGlweCiphertextType(rewriter.getContext()); + auto plaintextListType = getGenericPlaintextListType(rewriter.getContext()); + auto foreignPlaintextList = + getGenericForeignPlaintextListType(rewriter.getContext()); + auto keySwitchKeyType = getGenericLweKeySwitchKeyType(rewriter.getContext()); + auto bootstrapKeyType = getGenericLweBootstrapKeyType(rewriter.getContext()); + auto contextType = + mlir::concretelang::Concrete::ContextType::get(rewriter.getContext()); + + // Insert forward declaration of the add_lwe_ciphertexts function + { + auto funcType = mlir::FunctionType::get( + rewriter.getContext(), {lweBufferType, lweBufferType, lweBufferType}, + {}); + if (insertForwardDeclaration(op, rewriter, "memref_add_lwe_ciphertexts_u64", + funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the add_plaintext_lwe_ciphertext_u64 function + { + auto funcType = mlir::FunctionType::get( + rewriter.getContext(), {lweBufferType, lweBufferType, plaintextType}, + {}); + if (insertForwardDeclaration( + op, rewriter, "memref_add_plaintext_lwe_ciphertext_u64", funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the mul_cleartext_lwe_ciphertext_u64 function + { + auto funcType = mlir::FunctionType::get( + rewriter.getContext(), {lweBufferType, lweBufferType, cleartextType}, + {}); + if (insertForwardDeclaration( + op, rewriter, "memref_mul_cleartext_lwe_ciphertext_u64", funcType) + .failed()) { 
+ return mlir::failure(); + } + } + // Insert forward declaration of the negate_lwe_ciphertext_u64 function + { + auto funcType = mlir::FunctionType::get(rewriter.getContext(), + {lweBufferType, lweBufferType}, {}); + if (insertForwardDeclaration(op, rewriter, + "memref_negate_lwe_ciphertext_u64", funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the memref_keyswitch_lwe_u64 function + { + auto funcType = mlir::FunctionType::get( + rewriter.getContext(), {keySwitchKeyType, lweBufferType, lweBufferType}, + {}); + if (insertForwardDeclaration(op, rewriter, "memref_keyswitch_lwe_u64", + funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the memref_bootstrap_lwe_u64 function + { + auto funcType = mlir::FunctionType::get( + rewriter.getContext(), + {bootstrapKeyType, lweBufferType, lweBufferType, glweCiphertextType}, + {}); + if (insertForwardDeclaration(op, rewriter, "memref_bootstrap_lwe_u64", + funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the fill_plaintext_list function + { + auto funcType = mlir::FunctionType::get( + rewriter.getContext(), {plaintextListType, foreignPlaintextList}, {}); + if (insertForwardDeclaration( + op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the add_plaintext_list_glwe function + { + auto funcType = mlir::FunctionType::get( + rewriter.getContext(), + {glweCiphertextType, glweCiphertextType, plaintextListType}, {}); + if (insertForwardDeclaration( + op, rewriter, "add_plaintext_list_glwe_ciphertext_u64", funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the getGlobalKeyswitchKey function + { + auto funcType = mlir::FunctionType::get(rewriter.getContext(), + {contextType}, {keySwitchKeyType}); + if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key", 
funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the getGlobalBootstrapKey function + { + auto funcType = mlir::FunctionType::get(rewriter.getContext(), + {contextType}, {bootstrapKeyType}); + if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key", funcType) + .failed()) { + return mlir::failure(); + } + } + return mlir::success(); +} + +// For all operands `tensor` replace with +// `%casted = tensor.cast %op : tensor to tensor` +template +mlir::SmallVector +getCastedTensorOperands(Op op, mlir::PatternRewriter &rewriter) { + mlir::SmallVector newOperands{}; + for (mlir::Value operand : op->getOperands()) { + mlir::Type operandType = operand.getType(); + if (operandType.isa()) { + mlir::Value castedOp = rewriter.create( + op.getLoc(), getGenericLweBufferType(rewriter.getContext()), operand); + newOperands.push_back(castedOp); + } else { + newOperands.push_back(operand); + } + } + return std::move(newOperands); +} + +/// BConcreteOpToConcreteCAPICallPattern match the `BConcreteOp` +/// Operation and replace with a call to `funcName`, the funcName should be an +/// external function that was linked later. It insert the forward declaration +/// of the private `funcName` if it not already in the symbol table. The C +/// signature of the function should be `void (out, args..., lweDimension)`, the +/// pattern rewrite: +/// ``` +/// "BConcreteOp"(%out, args ...) : +/// (tensor, tensor...) -> () +/// ``` +/// to +/// ``` +/// %out0 = tensor.cast %out : tensor to tensor +/// %args = tensor.cast ... +/// call @funcName(%out, args...) : (tensor, tensor...) 
-> () +/// ``` +template +struct ConcreteOpToConcreteCAPICallPattern + : public mlir::OpRewritePattern { + ConcreteOpToConcreteCAPICallPattern(mlir::MLIRContext *context, + mlir::StringRef funcName, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit), + funcName(funcName) {} + + mlir::LogicalResult + matchAndRewrite(BConcreteOp op, + mlir::PatternRewriter &rewriter) const override { + BConcreteToBConcreteCAPITypeConverter typeConverter; + rewriter.replaceOpWithNewOp( + op, funcName, mlir::TypeRange{}, + getCastedTensorOperands(op, rewriter)); + return mlir::success(); + }; + +private: + std::string funcName; +}; + +struct ConcreteEncodeIntOpPattern + : public mlir::OpRewritePattern { + ConcreteEncodeIntOpPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern( + context, benefit) {} + + mlir::LogicalResult + matchAndRewrite(mlir::concretelang::Concrete::EncodeIntOp op, + mlir::PatternRewriter &rewriter) const override { + { + mlir::Value castedInt = rewriter.create( + op.getLoc(), rewriter.getIntegerType(64), op->getOperands().front()); + mlir::Value constantShiftOp = rewriter.create( + op.getLoc(), rewriter.getI64IntegerAttr(64 - op.getType().getP())); + + mlir::Type resultType = rewriter.getIntegerType(64); + rewriter.replaceOpWithNewOp( + op, resultType, castedInt, constantShiftOp); + } + return mlir::success(); + }; +}; + +struct ConcreteIntToCleartextOpPattern + : public mlir::OpRewritePattern< + mlir::concretelang::Concrete::IntToCleartextOp> { + ConcreteIntToCleartextOpPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern( + context, benefit) {} + + mlir::LogicalResult + matchAndRewrite(mlir::concretelang::Concrete::IntToCleartextOp op, + mlir::PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, rewriter.getIntegerType(64), op->getOperands().front()); + return mlir::success(); + }; +}; + +mlir::Value 
getContextArgument(mlir::Operation *op) { + mlir::Block *block = op->getBlock(); + while (block != nullptr) { + if (llvm::isa(block->getParentOp())) { + + mlir::Value context = block->getArguments().back(); + + assert( + context.getType().isa() && + "the Concrete.context should be the last argument of the enclosing " + "function of the op"); + + return context; + } + block = block->getParentOp()->getBlock(); + } + assert("can't find a function that enclose the op"); + return nullptr; +} + +// Rewrite pattern that rewrites every +// ``` +// "BConcrete.keyswitch_lwe_buffer"(%out, %in) {...}: +// (tensor<2049xi64>, tensor<2049xi64>) -> () +// ``` +// +// to +// +// ``` +// %ksk = call @get_keyswitch_key(%ctx) : +// (!Concrete.context) -> !Concrete.lwe_key_switch_key +// %out_ = tensor.cast %out : tensor to tensor +// %in_ = tensor.cast %in : tensor to tensor +// call @memref_keyswitch_lwe_u64(%ksk, %out_, %in_) : +// (!Concrete.lwe_key_switch_key, tensor, tensor) -> () +// ``` +struct BConcreteKeySwitchLweOpPattern + : public mlir::OpRewritePattern< + mlir::concretelang::BConcrete::KeySwitchLweBufferOp> { + BConcreteKeySwitchLweOpPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern< + mlir::concretelang::BConcrete::KeySwitchLweBufferOp>(context, + benefit) {} + + mlir::LogicalResult + matchAndRewrite(mlir::concretelang::BConcrete::KeySwitchLweBufferOp op, + mlir::PatternRewriter &rewriter) const override { + + mlir::CallOp kskOp = rewriter.create( + op.getLoc(), "get_keyswitch_key", + getGenericLweKeySwitchKeyType(rewriter.getContext()), + mlir::SmallVector{getContextArgument(op)}); + mlir::SmallVector operands{kskOp.getResult(0)}; + + operands.append( + getCastedTensorOperands< + mlir::concretelang::BConcrete::KeySwitchLweBufferOp>(op, rewriter)); + rewriter.replaceOpWithNewOp(op, "memref_keyswitch_lwe_u64", + mlir::TypeRange({}), operands); + return mlir::success(); + }; +}; + +// Rewrite pattern that rewrites every +// 
``` +// "BConcrete.bootstrap_lwe_buffer"(%out, %in, %acc) {...} : +// (tensor<2049xui64>, tensor<2049xui64>, !Concrete.glwe_ciphertext) -> () +// ``` +// +// to +// +// ``` +// %bsk = call @get_bootstrap_key(%ctx) : (!Concrete.context) -> !Concrete.lwe_bootstrap_key +// %out_ = tensor.cast %out : tensor to tensor +// %in_ = tensor.cast %in : tensor to tensor +// call @memref_bootstrap_lwe_u64(%bsk, %out_, %in_, %acc_) : +// (!Concrete.lwe_bootstrap_key, tensor, tensor, +// !Concrete.glwe_ciphertext) -> () +// ``` +struct BConcreteBootstrapLweOpPattern + : public mlir::OpRewritePattern< + mlir::concretelang::BConcrete::BootstrapLweBufferOp> { + BConcreteBootstrapLweOpPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern< + mlir::concretelang::BConcrete::BootstrapLweBufferOp>(context, + benefit) {} + + mlir::LogicalResult + matchAndRewrite(mlir::concretelang::BConcrete::BootstrapLweBufferOp op, + mlir::PatternRewriter &rewriter) const override { + + mlir::SmallVector getkskOperands{}; + mlir::CallOp bskOp = rewriter.create( + op.getLoc(), "get_bootstrap_key", + getGenericLweBootstrapKeyType(rewriter.getContext()), + mlir::SmallVector{getContextArgument(op)}); + mlir::SmallVector operands{bskOp.getResult(0)}; + operands.append( + getCastedTensorOperands< + mlir::concretelang::BConcrete::BootstrapLweBufferOp>(op, rewriter)); + rewriter.replaceOpWithNewOp(op, "memref_bootstrap_lwe_u64", + mlir::TypeRange({}), operands); + return mlir::success(); + }; +}; + +/// Populate the RewritePatternSet with all patterns that rewrite Concrete +/// operators to the corresponding function call to the `Concrete C API`. 
+void populateBConcreteToBConcreteCAPICall(mlir::RewritePatternSet &patterns) { + patterns.add>( + patterns.getContext(), "memref_add_lwe_ciphertexts_u64"); + patterns.add>( + patterns.getContext(), "memref_add_plaintext_lwe_ciphertext_u64"); + patterns.add>( + patterns.getContext(), "memref_mul_cleartext_lwe_ciphertext_u64"); + patterns.add>( + patterns.getContext(), "memref_negate_lwe_ciphertext_u64"); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + // patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); +} + +struct AddRuntimeContextToFuncOpPattern + : public mlir::OpRewritePattern { + AddRuntimeContextToFuncOpPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit) {} + + mlir::LogicalResult + matchAndRewrite(mlir::FuncOp oldFuncOp, + mlir::PatternRewriter &rewriter) const override { + mlir::OpBuilder::InsertionGuard guard(rewriter); + mlir::FunctionType oldFuncType = oldFuncOp.getType(); + + // Add a Concrete.context to the function signature + mlir::SmallVector newInputs(oldFuncType.getInputs().begin(), + oldFuncType.getInputs().end()); + newInputs.push_back( + rewriter.getType()); + mlir::FunctionType newFuncTy = rewriter.getType( + newInputs, oldFuncType.getResults()); + // Create the new func + mlir::FuncOp newFuncOp = rewriter.create( + oldFuncOp.getLoc(), oldFuncOp.getName(), newFuncTy); + + // Create the arguments of the new func + mlir::Region &newFuncBody = newFuncOp.body(); + mlir::Block *newFuncEntryBlock = new mlir::Block(); + newFuncEntryBlock->addArguments(newFuncTy.getInputs()); + newFuncBody.push_back(newFuncEntryBlock); + + // Clone the old body to the new one + mlir::BlockAndValueMapping map; + for (auto arg : llvm::enumerate(oldFuncOp.getArguments())) { + map.map(arg.value(), newFuncEntryBlock->getArgument(arg.index())); + } + for (auto &op : oldFuncOp.body().front()) { + 
newFuncEntryBlock->push_back(op.clone(map)); + } + rewriter.eraseOp(oldFuncOp); + return mlir::success(); + } + + // Legal function are one that are private or has a Concrete.context as last + // arguments. + static bool isLegal(mlir::FuncOp funcOp) { + if (!funcOp.isPublic()) { + return true; + } + // TODO : Don't need to add a runtime context for function that doesn't + // manipulates Concrete types. + // + // if (!llvm::any_of(funcOp.getType().getInputs(), [](mlir::Type t) { + // if (auto tensorTy = t.dyn_cast_or_null()) { + // t = tensorTy.getElementType(); + // } + // return llvm::isa( + // t.getDialect()); + // })) { + // return true; + // } + return funcOp.getType().getNumInputs() >= 1 && + funcOp.getType() + .getInputs() + .back() + .isa(); + } +}; + +namespace { +struct BConcreteToBConcreteCAPIPass + : public BConcreteToBConcreteCAPIBase { + void runOnOperation() final; +}; +} // namespace + +void BConcreteToBConcreteCAPIPass::runOnOperation() { + mlir::ModuleOp op = getOperation(); + + // First of all add the Concrete.context to the block arguments of function + // that manipulates ciphertexts. 
+ { + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + + target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { + return AddRuntimeContextToFuncOpPattern::isLegal(funcOp); + }); + + patterns.add(patterns.getContext()); + + // Apply the conversion + if (mlir::applyPartialConversion(op, target, std::move(patterns)) + .failed()) { + this->signalPassFailure(); + return; + } + } + + // Insert forward declaration + mlir::IRRewriter rewriter(&getContext()); + if (insertForwardDeclarations(op, rewriter).failed()) { + this->signalPassFailure(); + } + // Rewrite Concrete ops to CallOp to the Concrete C API + { + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + + target.addIllegalDialect(); + + target.addLegalDialect(); + + populateBConcreteToBConcreteCAPICall(patterns); + + if (mlir::applyPartialConversion(op, target, std::move(patterns)) + .failed()) { + this->signalPassFailure(); + } + } +} +} // namespace + +namespace mlir { +namespace concretelang { +std::unique_ptr> +createConvertBConcreteToBConcreteCAPIPass() { + return std::make_unique(); +} +} // namespace concretelang +} // namespace mlir \ No newline at end of file diff --git a/compiler/lib/Conversion/BConcreteToBConcreteCAPI/CMakeLists.txt b/compiler/lib/Conversion/BConcreteToBConcreteCAPI/CMakeLists.txt new file mode 100644 index 000000000..90bf82695 --- /dev/null +++ b/compiler/lib/Conversion/BConcreteToBConcreteCAPI/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_dialect_library(BConcreteToBConcreteCAPI + BConcreteToBConcreteCAPI.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE + + DEPENDS + BConcreteDialect + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRTransforms +) + +target_link_libraries(BConcreteToBConcreteCAPI PUBLIC MLIRIR) diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt index d30b88a76..4ad1f761d 100644 --- 
a/compiler/lib/Conversion/CMakeLists.txt +++ b/compiler/lib/Conversion/CMakeLists.txt @@ -2,6 +2,8 @@ add_subdirectory(FHEToTFHE) add_subdirectory(TFHEGlobalParametrization) add_subdirectory(TFHEToConcrete) add_subdirectory(FHETensorOpsToLinalg) +add_subdirectory(ConcreteToBConcrete) add_subdirectory(ConcreteToConcreteCAPI) +add_subdirectory(BConcreteToBConcreteCAPI) add_subdirectory(MLIRLowerableDialectsToLLVM) add_subdirectory(ConcreteUnparametrize) diff --git a/compiler/lib/Conversion/ConcreteToBConcrete/CMakeLists.txt b/compiler/lib/Conversion/ConcreteToBConcrete/CMakeLists.txt new file mode 100644 index 000000000..f06006f14 --- /dev/null +++ b/compiler/lib/Conversion/ConcreteToBConcrete/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_dialect_library(ConcreteToBConcrete + ConcreteToBConcrete.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Concrete + + DEPENDS + ConcreteDialect + BConcreteDialect + + LINK_LIBS PUBLIC + MLIRIR + MLIRTransforms + MLIRMath) + +target_link_libraries(ConcreteToBConcrete PUBLIC BConcreteDialect MLIRIR) diff --git a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp new file mode 100644 index 000000000..8ca525063 --- /dev/null +++ b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp @@ -0,0 +1,919 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. 
+ +#include + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "concretelang/Conversion/Passes.h" +#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" +#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h" +#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" +#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" +#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" +#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" +#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" +#include "concretelang/Dialect/RT/IR/RTOps.h" + +namespace { +struct ConcreteToBConcretePass + : public ConcreteToBConcreteBase { + void runOnOperation() final; +}; +} // namespace + +/// ConcreteToBConcreteTypeConverter is a TypeConverter that transform +/// `Concrete.lwe_ciphertext` to `tensor>` +/// `tensor<...xConcrete.lwe_ciphertext>` to +/// `tensor<...xdimension+1, i64>>` +class ConcreteToBConcreteTypeConverter : public mlir::TypeConverter { + +public: + ConcreteToBConcreteTypeConverter() { + addConversion([](mlir::Type type) { return type; }); + addConversion([&](mlir::concretelang::Concrete::LweCiphertextType type) { + return mlir::RankedTensorType::get( + {type.getDimension() + 1}, + mlir::IntegerType::get(type.getContext(), 64)); + }); + addConversion([&](mlir::RankedTensorType type) { + auto lwe = type.getElementType() + .dyn_cast_or_null< + mlir::concretelang::Concrete::LweCiphertextType>(); + if (lwe == nullptr) { + return (mlir::Type)(type); + } + mlir::SmallVector newShape; + newShape.reserve(type.getShape().size() + 1); + newShape.append(type.getShape().begin(), type.getShape().end()); + newShape.push_back(lwe.getDimension() + 1); + mlir::Type r = mlir::RankedTensorType::get( + newShape, mlir::IntegerType::get(type.getContext(), 64)); + return r; + }); + 
addConversion([&](mlir::MemRefType type) { + auto lwe = type.getElementType() + .dyn_cast_or_null< + mlir::concretelang::Concrete::LweCiphertextType>(); + if (lwe == nullptr) { + return (mlir::Type)(type); + } + mlir::SmallVector newShape; + newShape.reserve(type.getShape().size() + 1); + newShape.append(type.getShape().begin(), type.getShape().end()); + newShape.push_back(lwe.getDimension() + 1); + mlir::Type r = mlir::MemRefType::get( + newShape, mlir::IntegerType::get(type.getContext(), 64)); + return r; + }); + } +}; + +// This rewrite pattern transforms any instance of `Concrete.zero_tensor` +// operators. +// +// Example: +// +// ```mlir +// %0 = "Concrete.zero_tensor" () : +// tensor<...x!Concrete.lwe_ciphertext> +// ``` +// +// becomes: +// +// ```mlir +// %0 = tensor.generate { +// ^bb0(... : index): +// %c0 = arith.constant 0 : i64 +// tensor.yield %z +// }: tensor<...xlweDim+1xi64> +// i64> +// ``` +template +struct ZeroOpPattern : public mlir::OpRewritePattern { + ZeroOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(ZeroOp zeroOp, + ::mlir::PatternRewriter &rewriter) const override { + ConcreteToBConcreteTypeConverter converter; + auto resultTy = zeroOp.getType(); + auto newResultTy = converter.convertType(resultTy); + + auto generateBody = [&](mlir::OpBuilder &nestedBuilder, + mlir::Location nestedLoc, + mlir::ValueRange blockArgs) { + // %c0 = 0 : i64 + auto cstOp = nestedBuilder.create( + nestedLoc, nestedBuilder.getI64IntegerAttr(1)); + // tensor.yield %z : !FHE.eint

+ nestedBuilder.create(nestedLoc, cstOp.getResult()); + }; + // tensor.generate + rewriter.replaceOpWithNewOp( + zeroOp, newResultTy, mlir::ValueRange{}, generateBody); + + return ::mlir::success(); + }; +}; + +// This template rewrite pattern transforms any instance of +// `ConcreteOp` to an instance of `BConcreteOp`. +// +// Example: +// +// %0 = "ConcreteOp"(%arg0, ...) : +// (!Concrete.lwe_ciphertext, ...) -> +// (!Concrete.lwe_ciphertext) +// +// becomes: +// +// %0 = linalg.init_tensor [dimension+1] : tensor +// "BConcreteOp"(%0, %arg0, ...) : (tensor>, +// tensor>, ..., ) -> () +// +// A reference to the preallocated output is always passed as the first +// argument. +template +struct LowToBConcrete : public mlir::OpRewritePattern { + LowToBConcrete(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(ConcreteOp concreteOp, + ::mlir::PatternRewriter &rewriter) const override { + ConcreteToBConcreteTypeConverter converter; + mlir::concretelang::Concrete::LweCiphertextType resultTy = + ((mlir::Type)concreteOp->getResult(0).getType()) + .cast(); + auto newResultTy = + converter.convertType(resultTy).cast(); + + // %0 = linalg.init_tensor [dimension+1] : tensor + mlir::Value init = rewriter.replaceOpWithNewOp( + concreteOp, newResultTy.getShape(), newResultTy.getElementType()); + + // "BConcreteOp"(%0, %arg0, ...) 
: (tensor>, + // tensor>, ..., ) -> () + mlir::SmallVector newOperands{init}; + + newOperands.append(concreteOp.getOperation()->getOperands().begin(), + concreteOp.getOperation()->getOperands().end()); + + llvm::ArrayRef<::mlir::NamedAttribute> attributes = + concreteOp.getOperation()->getAttrs(); + + rewriter.create(concreteOp.getLoc(), + mlir::SmallVector{}, newOperands, + attributes); + + return ::mlir::success(); + }; +}; + +// This rewrite pattern transforms any instance of +// `tensor.extract_slice` operators that operates on tensor of lwe ciphertext. +// +// Example: +// +// ```mlir +// %0 = tensor.extract_slice %arg0 +// [offsets...] [sizes...] [strides...] +// : tensor<...x!Concrete.lwe_ciphertext> to +// tensor<...x!Concrete.lwe_ciphertext> +// ``` +// +// becomes: +// +// ```mlir +// %0 = tensor.extract_slice %arg0 +// [offsets..., 0] [sizes..., lweDimension+1] [strides..., 1] +// : tensor<...xlweDimension+1,i64> to +// tensor<...xlweDimension+1,i64> +// ``` +struct ExtractSliceOpPattern + : public mlir::OpRewritePattern { + ExtractSliceOpPattern(::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, + benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, + ::mlir::PatternRewriter &rewriter) const override { + ConcreteToBConcreteTypeConverter converter; + auto resultTy = extractSliceOp.result().getType(); + auto resultEltTy = + resultTy.cast() + .getElementType() + .cast(); + auto newResultTy = converter.convertType(resultTy); + + // add 0 to the static_offsets + mlir::SmallVector staticOffsets; + staticOffsets.append(extractSliceOp.static_offsets().begin(), + extractSliceOp.static_offsets().end()); + staticOffsets.push_back(rewriter.getI64IntegerAttr(0)); + + // add the lweSize to the sizes + mlir::SmallVector staticSizes; + staticSizes.append(extractSliceOp.static_sizes().begin(), + extractSliceOp.static_sizes().end()); + staticSizes.push_back( + 
rewriter.getI64IntegerAttr(resultEltTy.getDimension() + 1)); + + // add 1 to the strides + mlir::SmallVector staticStrides; + staticStrides.append(extractSliceOp.static_strides().begin(), + extractSliceOp.static_strides().end()); + staticStrides.push_back(rewriter.getI64IntegerAttr(1)); + + // replace tensor.extract_slice to the new one + rewriter.replaceOpWithNewOp( + extractSliceOp, newResultTy, extractSliceOp.source(), + extractSliceOp.offsets(), extractSliceOp.sizes(), + extractSliceOp.strides(), rewriter.getArrayAttr(staticOffsets), + rewriter.getArrayAttr(staticSizes), + rewriter.getArrayAttr(staticStrides)); + + return ::mlir::success(); + }; +}; + +// This rewrite pattern transforms any instance of +// `tensor.extract` operators that operates on tensor of lwe ciphertext. +// +// Example: +// +// ```mlir +// %0 = tensor.extract %t[offsets...] +// : tensor<...x!Concrete.lwe_ciphertext> +// ``` +// +// becomes: +// +// ```mlir +// %1 = tensor.extract_slice %arg0 +// [offsets...] [1..., lweDimension+1] [1...] +// : tensor<...xlweDimension+1,i64> to +// tensor<1...xlweDimension+1,i64> +// %0 = linalg.tensor_collapse_shape %0 [[...]] : +// tensor<1x1xlweDimension+1xi64> into tensor +// ``` +// +// TODO: since they are a bug on lowering extract_slice with rank reduction we +// add a linalg.tensor_collapse_shape after the extract_slice without rank +// reduction. See +// https://github.com/zama-ai/concrete-compiler-internal/issues/396. 
+struct ExtractOpPattern + : public mlir::OpRewritePattern { + ExtractOpPattern(::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(mlir::tensor::ExtractOp extractOp, + ::mlir::PatternRewriter &rewriter) const override { + ConcreteToBConcreteTypeConverter converter; + auto lweResultTy = + extractOp.result() + .getType() + .dyn_cast_or_null< + mlir::concretelang::Concrete::LweCiphertextType>(); + if (lweResultTy == nullptr) { + return mlir::failure(); + } + + auto newResultTy = + converter.convertType(lweResultTy).cast(); + auto rankOfResult = extractOp.indices().size() + 1; + + // [min..., 0] for static_offsets () + mlir::SmallVector staticOffsets( + rankOfResult, + rewriter.getI64IntegerAttr(std::numeric_limits::min())); + staticOffsets[staticOffsets.size() - 1] = rewriter.getI64IntegerAttr(0); + + // [1..., lweDimension+1] for static_sizes + mlir::SmallVector staticSizes( + rankOfResult, rewriter.getI64IntegerAttr(1)); + staticSizes[staticSizes.size() - 1] = rewriter.getI64IntegerAttr( + newResultTy.getDimSize(newResultTy.getRank() - 1)); + + // [1...] 
for static_strides + mlir::SmallVector staticStrides( + rankOfResult, rewriter.getI64IntegerAttr(1)); + + // replace tensor.extract_slice to the new one + mlir::SmallVector extractedSliceShape( + extractOp.indices().size() + 1, 0); + extractedSliceShape.reserve(extractOp.indices().size() + 1); + for (size_t i = 0; i < extractedSliceShape.size() - 1; i++) { + extractedSliceShape[i] = 1; + } + extractedSliceShape[extractedSliceShape.size() - 1] = + newResultTy.getDimSize(0); + + auto extractedSliceType = + mlir::RankedTensorType::get(extractedSliceShape, rewriter.getI64Type()); + auto extractedSlice = rewriter.create( + extractOp.getLoc(), extractedSliceType, extractOp.tensor(), + extractOp.indices(), mlir::SmallVector{}, + mlir::SmallVector{}, rewriter.getArrayAttr(staticOffsets), + rewriter.getArrayAttr(staticSizes), + rewriter.getArrayAttr(staticStrides)); + mlir::ReassociationIndices reassociation; + for (int64_t i = 0; i < extractedSliceType.getRank(); i++) { + reassociation.push_back(i); + } + rewriter.replaceOpWithNewOp( + extractOp, newResultTy, extractedSlice, + mlir::SmallVector{reassociation}); + return ::mlir::success(); + }; +}; + +// This rewrite pattern transforms any instance of +// `tensor.insert_slice` operators that operates on tensor of lwe ciphertext. +// +// Example: +// +// ```mlir +// %0 = tensor.insert_slice %arg1 +// into %arg0[offsets...] [sizes...] [strides...] 
+// : tensor<...x!Concrete.lwe_ciphertext> into +// tensor<...x!Concrete.lwe_ciphertext> +// ``` +// +// becomes: +// +// ```mlir +// %0 = tensor.insert_slice %arg1 +// into %arg0[offsets..., 0] [sizes..., lweDimension+1] [strides..., 1] +// : tensor<...xlweDimension+1xi64> into +// tensor<...xlweDimension+1xi64> +// ``` +struct InsertSliceOpPattern + : public mlir::OpRewritePattern { + InsertSliceOpPattern(::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, + benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(mlir::tensor::InsertSliceOp insertSliceOp, + ::mlir::PatternRewriter &rewriter) const override { + ConcreteToBConcreteTypeConverter converter; + auto resultTy = insertSliceOp.result().getType(); + + auto newResultTy = + converter.convertType(resultTy).cast(); + + // add 0 to static_offsets + mlir::SmallVector staticOffsets; + staticOffsets.append(insertSliceOp.static_offsets().begin(), + insertSliceOp.static_offsets().end()); + staticOffsets.push_back(rewriter.getI64IntegerAttr(0)); + + // add lweDimension+1 to static_sizes + mlir::SmallVector staticSizes; + staticSizes.append(insertSliceOp.static_sizes().begin(), + insertSliceOp.static_sizes().end()); + staticSizes.push_back(rewriter.getI64IntegerAttr( + newResultTy.getDimSize(newResultTy.getRank() - 1))); + + // add 1 to the strides + mlir::SmallVector staticStrides; + staticStrides.append(insertSliceOp.static_strides().begin(), + insertSliceOp.static_strides().end()); + staticStrides.push_back(rewriter.getI64IntegerAttr(1)); + + // replace tensor.insert_slice with the new one + rewriter.replaceOpWithNewOp( + insertSliceOp, newResultTy, insertSliceOp.source(), + insertSliceOp.dest(), insertSliceOp.offsets(), insertSliceOp.sizes(), + insertSliceOp.strides(), rewriter.getArrayAttr(staticOffsets), + rewriter.getArrayAttr(staticSizes), + rewriter.getArrayAttr(staticStrides)); + + return ::mlir::success(); + }; +}; + +// This rewrite pattern transforms any 
instance of +// `tensor.from_elements` operators that operates on tensor of lwe ciphertext. +// +// Example: +// +// ```mlir +// %0 = tensor.from_elements %e0, ..., %e(n-1) +// : tensor> +// ``` +// +// becomes: +// +// ```mlir +// %m = memref.alloc() : memref +// %s0 = memref.subview %m[0, 0][1, lweDim+1][1, 1] : memref +// %m0 = memref.buffer_cast %e0 : memref +// memref.copy %m0, s0 : memref to memref +// ... +// %s(n-1) = memref.subview %m[(n-1), 0][1, lweDim+1][1, 1] +// : memref +// %m(n-1) = memref.buffer_cast %e(n-1) : memref +// memref.copy %e(n-1), s(n-1) +// : memref to memref +// %0 = memref.tensor_load %m : memref +// ``` +struct FromElementsOpPattern + : public mlir::OpRewritePattern { + FromElementsOpPattern(::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, + benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(mlir::tensor::FromElementsOp fromElementsOp, + ::mlir::PatternRewriter &rewriter) const override { + ConcreteToBConcreteTypeConverter converter; + + auto resultTy = fromElementsOp.result().getType(); + if (converter.isLegal(resultTy)) { + return mlir::failure(); + } + auto eltResultTy = + resultTy.cast() + .getElementType() + .cast(); + auto newTensorResultTy = + converter.convertType(resultTy).cast(); + auto newMemrefResultTy = mlir::MemRefType::get( + newTensorResultTy.getShape(), newTensorResultTy.getElementType()); + + // %m = memref.alloc() : memref + auto mOp = rewriter.create(fromElementsOp.getLoc(), + newMemrefResultTy); + + // for i = 0 to n-1 + // %si = memref.subview %m[i, 0][1, lweDim+1][1, 1] : memref + // %mi = memref.buffer_cast %ei : memref + // memref.copy %mi, si : memref to memref + auto subviewResultTy = mlir::MemRefType::get( + {eltResultTy.getDimension() + 1}, newMemrefResultTy.getElementType()); + auto offset = 0; + for (auto eiOp : fromElementsOp.elements()) { + mlir::SmallVector staticOffsets{ + rewriter.getI64IntegerAttr(offset), rewriter.getI64IntegerAttr(0)}; 
+ mlir::SmallVector staticSizes{ + rewriter.getI64IntegerAttr(1), + rewriter.getI64IntegerAttr(eltResultTy.getDimension() + 1)}; + mlir::SmallVector staticStrides{ + rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)}; + auto siOp = rewriter.create( + fromElementsOp.getLoc(), subviewResultTy, mOp, mlir::ValueRange{}, + mlir::ValueRange{}, mlir::ValueRange{}, + rewriter.getArrayAttr(staticOffsets), + rewriter.getArrayAttr(staticSizes), + rewriter.getArrayAttr(staticStrides)); + auto miOp = rewriter.create( + fromElementsOp.getLoc(), subviewResultTy, eiOp); + rewriter.create(fromElementsOp.getLoc(), miOp, + siOp); + offset++; + } + + // Go back to tensor world + // %0 = memref.tensor_load %m : memref + rewriter.replaceOpWithNewOp(fromElementsOp, + mOp); + + return ::mlir::success(); + }; +}; + +// This template rewrite pattern transforms any instance of +// `ShapeOp` operators that operates on tensor of lwe ciphertext by adding the +// lwe size as a size of the tensor result and by adding a trivial reassociation +// at the end of the reassociations map. +// +// Example: +// +// ```mlir +// %0 = "ShapeOp" %arg0 [reassocations...] 
+// : tensor<...x!Concrete.lwe_ciphertext> into +// tensor<...x!Concrete.lwe_ciphertext> +// ``` +// +// becomes: +// +// ```mlir +// %0 = "ShapeOp" %arg0 [reassociations..., [inRank or outRank]] +// : tensor<...xlweDimesion+1xi64> into +// tensor<...xlweDimesion+1xi64> +// ``` +template +struct TensorShapeOpPattern : public mlir::OpRewritePattern { + TensorShapeOpPattern(::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(ShapeOp shapeOp, + ::mlir::PatternRewriter &rewriter) const override { + ConcreteToBConcreteTypeConverter converter; + auto resultTy = shapeOp.result().getType(); + + auto newResultTy = + ((mlir::Type)converter.convertType(resultTy)).cast(); + + // add [rank] to reassociations + auto oldReassocs = shapeOp.getReassociationIndices(); + mlir::SmallVector newReassocs; + newReassocs.append(oldReassocs.begin(), oldReassocs.end()); + mlir::ReassociationIndices lweAssoc; + auto reassocTy = + ((mlir::Type)converter.convertType( + (inRank ? shapeOp.src() : shapeOp.result()).getType())) + .cast(); + lweAssoc.push_back(reassocTy.getRank()); + newReassocs.push_back(lweAssoc); + + rewriter.replaceOpWithNewOp(shapeOp, newResultTy, shapeOp.src(), + newReassocs); + return ::mlir::success(); + }; +}; + +// Add the instantiated TensorShapeOpPattern rewrite pattern with the `ShapeOp` +// to the patterns set and populate the conversion target. 
+template +void insertTensorShapeOpPattern(mlir::MLIRContext &context, + mlir::RewritePatternSet &patterns, + mlir::ConversionTarget &target) { + patterns.insert>(&context); + target.addDynamicallyLegalOp([&](ShapeOp op) { + ConcreteToBConcreteTypeConverter converter; + return converter.isLegal(op.result().getType()); + }); +} + +// This template rewrite pattern transforms any instance of +// `MemrefOp` operators that returns a memref of lwe ciphertext to the same +// operator but which returns the bufferized lwe ciphertext. +// +// Example: +// +// ```mlir +// %0 = "MemrefOp"(...) : ... -> memref<...x!Concrete.lwe_ciphertext> +// ``` +// +// becomes: +// +// ```mlir +// %0 = "MemrefOp"(...) : ... -> memref<...xlweDim+1xi64> +// ``` +template +struct MemrefOpPattern : public mlir::OpRewritePattern { + MemrefOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit) {} + + mlir::LogicalResult + matchAndRewrite(MemrefOp memrefOp, + mlir::PatternRewriter &rewriter) const override { + ConcreteToBConcreteTypeConverter converter; + + mlir::SmallVector convertedTypes; + if (converter.convertTypes(memrefOp->getResultTypes(), convertedTypes) + .failed()) { + return mlir::failure(); + } + + rewriter.replaceOpWithNewOp(memrefOp, convertedTypes, + memrefOp->getOperands(), + memrefOp->getAttrs()); + return ::mlir::success(); + }; +}; + +// Add the instantiated MemrefOpPattern rewrite pattern with the `MemrefOp` +// to the patterns set and populate the conversion target. 
+template +void insertMemrefOpPattern(mlir::MLIRContext &context, + mlir::RewritePatternSet &patterns, + mlir::ConversionTarget &target) { + (void)std::initializer_list{ + 0, (patterns.insert>(&context), + target.addDynamicallyLegalOp([&](MemrefOp op) { + ConcreteToBConcreteTypeConverter converter; + return converter.isLegal(op->getResultTypes()); + }), + 0)...}; +} + +// cc from Loops.cpp +static mlir::SmallVector +makeCanonicalAffineApplies(mlir::OpBuilder &b, mlir::Location loc, + mlir::AffineMap map, + mlir::ArrayRef vals) { + if (map.isEmpty()) + return {}; + + assert(map.getNumInputs() == vals.size()); + mlir::SmallVector res; + res.reserve(map.getNumResults()); + auto dims = map.getNumDims(); + for (auto e : map.getResults()) { + auto exprMap = mlir::AffineMap::get(dims, map.getNumSymbols(), e); + mlir::SmallVector operands(vals.begin(), vals.end()); + canonicalizeMapAndOperands(&exprMap, &operands); + res.push_back(b.create(loc, exprMap, operands)); + } + return res; +} + +static std::pair +makeOperandLoadOrSubview(mlir::OpBuilder &builder, mlir::Location loc, + mlir::ArrayRef allIvs, + mlir::linalg::LinalgOp linalgOp, + mlir::OpOperand *operand) { + ConcreteToBConcreteTypeConverter converter; + + mlir::Value opVal = operand->get(); + mlir::MemRefType opTy = opVal.getType().cast(); + + if (auto lweType = + opTy.getElementType() + .dyn_cast_or_null< + mlir::concretelang::Concrete::LweCiphertextType>()) { + // For memref of ciphertexts operands create the inner memref + // subview to the ciphertext, and go back to the tensor type as BConcrete + // operators works with tensor. + // %op : memref> + // %opInner = memref.subview %opInner[offsets...][1...][1,...] 
+ // : memref<...xConcrete.lwe_ciphertext> to + // memref> + + auto tensorizedLweTy = + converter.convertType(lweType).cast(); + auto subviewResultTy = mlir::MemRefType::get( + tensorizedLweTy.getShape(), tensorizedLweTy.getElementType()); + auto offsets = makeCanonicalAffineApplies( + builder, loc, linalgOp.getTiedIndexingMap(operand), allIvs); + mlir::SmallVector staticOffsets( + opTy.getRank(), + builder.getI64IntegerAttr(std::numeric_limits::min())); + mlir::SmallVector staticSizes( + opTy.getRank(), builder.getI64IntegerAttr(1)); + mlir::SmallVector staticStrides( + opTy.getRank(), builder.getI64IntegerAttr(1)); + + auto subViewOp = builder.create( + loc, subviewResultTy, opVal, offsets, mlir::ValueRange{}, + mlir::ValueRange{}, builder.getArrayAttr(staticOffsets), + builder.getArrayAttr(staticSizes), builder.getArrayAttr(staticStrides)); + return std::pair( + subViewOp, builder.create(loc, subViewOp)); + } else { + // For memref of non ciphertexts load the value from the memref. + // with %op : memref + // %opInner = memref.load %op[offsets...] 
: memref + auto offsets = makeCanonicalAffineApplies( + builder, loc, linalgOp.getTiedIndexingMap(operand), allIvs); + return std::pair( + nullptr, + builder.create(loc, operand->get(), offsets)); + } +} + +static void +inlineRegionAndEmitTensorStore(mlir::OpBuilder &builder, mlir::Location loc, + mlir::linalg::LinalgOp linalgOp, + llvm::ArrayRef indexedValues, + mlir::ValueRange outputBuffers) { + // Clone the block with the new operands + auto &block = linalgOp->getRegion(0).front(); + mlir::BlockAndValueMapping map; + map.map(block.getArguments(), indexedValues); + for (auto &op : block.without_terminator()) { + auto *newOp = builder.clone(op, map); + map.map(op.getResults(), newOp->getResults()); + } + // Create memref.tensor_store operation for each terminator operands + auto *terminator = block.getTerminator(); + for (mlir::OpOperand &operand : terminator->getOpOperands()) { + mlir::Value toStore = map.lookupOrDefault(operand.get()); + builder.create( + loc, toStore, outputBuffers[operand.getOperandNumber()]); + } +} + +template +class LinalgRewritePattern + : public mlir::OpInterfaceConversionPattern { +public: + using mlir::OpInterfaceConversionPattern< + mlir::linalg::LinalgOp>::OpInterfaceConversionPattern; + + mlir::LogicalResult + matchAndRewrite(mlir::linalg::LinalgOp linalgOp, + mlir::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const override { + assert(linalgOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + + auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc()); + auto iteratorTypes = + llvm::to_vector<4>(linalgOp.iterator_types().getValue()); + + mlir::SmallVector allIvs; + mlir::linalg::GenerateLoopNest::doit( + rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes, + [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange ivs, + mlir::ValueRange operandValuesToUse) -> mlir::scf::ValueVector { + // Keep indexed values to replace the linalg.generic block 
arguments + // by them + mlir::SmallVector indexedValues; + indexedValues.reserve(linalgOp.getNumInputsAndOutputs()); + assert( + operandValuesToUse == linalgOp->getOperands() && + "expect operands are captured and not passed by loop argument"); + allIvs.append(ivs.begin(), ivs.end()); + + // For all input operands create the inner operand + for (mlir::OpOperand *inputOperand : linalgOp.getInputOperands()) { + auto innerOperand = makeOperandLoadOrSubview( + builder, loc, allIvs, linalgOp, inputOperand); + indexedValues.push_back(innerOperand.second); + } + + // For all output operands create the inner operand + assert(linalgOp.getOutputOperands() == + linalgOp.getOutputBufferOperands() && + "expect only memref as output operands"); + mlir::SmallVector outputBuffers; + for (mlir::OpOperand *outputOperand : linalgOp.getOutputOperands()) { + auto innerOperand = makeOperandLoadOrSubview( + builder, loc, allIvs, linalgOp, outputOperand); + indexedValues.push_back(innerOperand.second); + assert(innerOperand.first != nullptr && + "Expected a memref subview as output buffer"); + outputBuffers.push_back(innerOperand.first); + } + // Finally inline the linalgOp region + inlineRegionAndEmitTensorStore(builder, loc, linalgOp, indexedValues, + outputBuffers); + + return mlir::scf::ValueVector{}; + }); + rewriter.eraseOp(linalgOp); + return mlir::success(); + }; +}; + +void ConcreteToBConcretePass::runOnOperation() { + auto op = this->getOperation(); + + // First of all we transform LinalgOp that work on tensor of ciphertext to + // work on memref. + { + mlir::ConversionTarget target(getContext()); + mlir::BufferizeTypeConverter converter; + + // Mark all Standard operations legal. + target + .addLegalDialect(); + + // Mark all Linalg operations illegal as long as they work on encrypted + // tensors. 
+ target.addDynamicallyLegalOp( + [&](mlir::Operation *op) { return converter.isLegal(op); }); + + mlir::RewritePatternSet patterns(&getContext()); + mlir::linalg::populateLinalgBufferizePatterns(converter, patterns); + if (failed(applyPartialConversion(op, target, std::move(patterns)))) { + signalPassFailure(); + return; + } + } + + // Then convert ciphertext to tensor or add a dimension to tensor of + // ciphertext and memref of ciphertext + { + mlir::ConversionTarget target(getContext()); + ConcreteToBConcreteTypeConverter converter; + mlir::OwningRewritePatternList patterns(&getContext()); + + // All BConcrete ops are legal after the conversion + target.addLegalDialect(); + + // Add Concrete ops are illegal after the conversion unless those which are + // explicitly marked as legal (more or less operators that didn't work on + // ciphertexts) + target.addIllegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + + // Add patterns to convert the zero ops to tensor.generate + patterns + .insert, + ZeroOpPattern>( + &getContext()); + target.addLegalOp(); + + // Add patterns to trivialy convert Concrete op to the equivalent + // BConcrete op + target.addLegalOp(); + patterns.insert< + LowToBConcrete, + LowToBConcrete< + mlir::concretelang::Concrete::AddPlaintextLweCiphertextOp, + mlir::concretelang::BConcrete::AddPlaintextLweBufferOp>, + LowToBConcrete< + mlir::concretelang::Concrete::MulCleartextLweCiphertextOp, + mlir::concretelang::BConcrete::MulCleartextLweBufferOp>, + LowToBConcrete< + mlir::concretelang::Concrete::MulCleartextLweCiphertextOp, + mlir::concretelang::BConcrete::MulCleartextLweBufferOp>, + LowToBConcrete, + LowToBConcrete, + LowToBConcrete>( + &getContext()); + + // Add patterns to rewrite tensor operators that works on encrypted tensors + patterns.insert(&getContext()); + target.addDynamicallyLegalOp< + mlir::tensor::ExtractSliceOp, mlir::tensor::ExtractOp, + mlir::tensor::InsertSliceOp, 
mlir::tensor::FromElementsOp>( + [&](mlir::Operation *op) { + return converter.isLegal(op->getResult(0).getType()); + }); + target.addLegalOp(); + + // Add patterns to rewrite some of memref ops that was introduced by the + // linalg bufferization of encrypted tensor (first conversion of this pass) + insertTensorShapeOpPattern( + getContext(), patterns, target); + insertTensorShapeOpPattern( + getContext(), patterns, target); + + // Add patterns to rewrite linalg op to nested loops with views on + // ciphertexts + patterns.insert>(converter, + &getContext()); + target.addLegalOp(); + + // Add patterns to do the conversion of func + mlir::populateFuncOpTypeConversionPattern(patterns, converter); + target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { + return converter.isSignatureLegal(funcOp.getType()) && + converter.isLegal(&funcOp.getBody()); + }); + + // Add patterns to convert some memref operators that is generated by + // previous step + insertMemrefOpPattern(getContext(), patterns, + target); + + // Conversion of RT Dialect Ops + patterns.add>(patterns.getContext(), + converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::DataflowTaskOp>(target, converter); + + // Apply conversion + if (mlir::applyPartialConversion(op, target, std::move(patterns)) + .failed()) { + this->signalPassFailure(); + } + } +} + +namespace mlir { +namespace concretelang { +std::unique_ptr> +createConvertConcreteToBConcretePass() { + return std::make_unique(); +} +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 274ef0cd3..b084e924e 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -264,16 +264,6 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { if (target == Target::CONCRETE) return std::move(res); - // Concrete -> Canonical dialects - if 
(mlir::concretelang::pipeline::lowerConcreteToStd(mlirContext, module, - enablePass) - .failed()) { - return errorDiag( - "Lowering from Concrete to canonical MLIR dialects failed"); - } - if (target == Target::STD) - return std::move(res); - // Generate client parameters if requested if (this->generateClientParameters) { if (!this->clientParametersFuncName.hasValue()) { @@ -304,6 +294,28 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { } } + // Concrete -> BConcrete + if (mlir::concretelang::pipeline::lowerConcreteToBConcrete( + mlirContext, module, this->enablePass) + .failed()) { + return StreamStringError( + "Lowering from Concrete to Bufferized Concrete failed"); + } + + if (target == Target::BCONCRETE) { + return std::move(res); + } + + // BConcrete -> Canonical dialects + if (mlir::concretelang::pipeline::lowerBConcreteToStd(mlirContext, module, + enablePass) + .failed()) { + return errorDiag( + "Lowering from Bufferized Concrete to canonical MLIR dialects failed"); + } + if (target == Target::STD) + return std::move(res); + // MLIR canonical dialects -> LLVM Dialect if (mlir::concretelang::pipeline::lowerStdToLLVMDialect( mlirContext, module, enablePass, diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 5d0f1e37e..d2badcd00 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -181,10 +181,25 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, } mlir::LogicalResult -lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::function enablePass) { +lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { mlir::PassManager pm(&context); - pipelinePrinting("ConcreteToStd", pm, context); + pipelinePrinting("ConcreteToBConcrete", pm, context); + addPotentiallyNestedPass( + pm, mlir::concretelang::createConvertConcreteToBConcretePass(), + enablePass); + + return 
pm.run(module.getOperation()); +} + +mlir::LogicalResult +lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { + mlir::PassManager pm(&context); + pipelinePrinting("BConcreteToStd", pm, context); + addPotentiallyNestedPass( + pm, mlir::concretelang::createConvertBConcreteToBConcreteCAPIPass(), + enablePass); addPotentiallyNestedPass( pm, mlir::concretelang::createConvertConcreteToConcreteCAPIPass(), enablePass); diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 74b1f83e0..7d835f118 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -41,6 +41,7 @@ enum Action { DUMP_FHE, DUMP_TFHE, DUMP_CONCRETE, + DUMP_BCONCRETE, DUMP_STD, DUMP_LLVM_DIALECT, DUMP_LLVM_IR, @@ -101,6 +102,9 @@ static llvm::cl::opt action( "Lower to TFHE and dump result")), llvm::cl::values(clEnumValN(Action::DUMP_CONCRETE, "dump-concrete", "Lower to Concrete and dump result")), + llvm::cl::values( + clEnumValN(Action::DUMP_BCONCRETE, "dump-bconcrete", + "Lower to Bufferized Concrete and dump result")), llvm::cl::values(clEnumValN(Action::DUMP_STD, "dump-std", "Lower to std and dump result")), llvm::cl::values(clEnumValN(Action::DUMP_LLVM_DIALECT, "dump-llvm-dialect", @@ -324,6 +328,9 @@ mlir::LogicalResult processInputBuffer( case Action::DUMP_CONCRETE: target = mlir::concretelang::CompilerEngine::Target::CONCRETE; break; + case Action::DUMP_BCONCRETE: + target = mlir::concretelang::CompilerEngine::Target::BCONCRETE; + break; case Action::DUMP_STD: target = mlir::concretelang::CompilerEngine::Target::STD; break; diff --git a/compiler/tests/Conversion/BConcreteToBConcreteCAPI/add_lwe.mlir b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/add_lwe.mlir new file mode 100644 index 000000000..bd952dad5 --- /dev/null +++ b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/add_lwe.mlir @@ -0,0 +1,14 @@ +// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s + +// CHECK: func 
@add_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>, %arg2: !Concrete.context) -> tensor<2049xi64> +func @add_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>) -> tensor<2049xi64> { + // CHECK-NEXT: %0 = linalg.init_tensor [2049] : tensor<2049xi64> + // CHECK-NEXT: %1 = tensor.cast %0 : tensor<2049xi64> to tensor + // CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<2049xi64> to tensor + // CHECK-NEXT: %3 = tensor.cast %arg1 : tensor<2049xi64> to tensor + // CHECK-NEXT: call @memref_add_lwe_ciphertexts_u64(%1, %2, %3) : (tensor, tensor, tensor) -> () + // CHECK-NEXT: return %0 : tensor<2049xi64> + %0 = linalg.init_tensor [2049] : tensor<2049xi64> + "BConcrete.add_lwe_buffer"(%0, %arg0, %arg1) : (tensor<2049xi64>, tensor<2049xi64>, tensor<2049xi64>) -> () + return %0 : tensor<2049xi64> +} diff --git a/compiler/tests/Conversion/BConcreteToBConcreteCAPI/add_lwe_int.mlir b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/add_lwe_int.mlir new file mode 100644 index 000000000..0ea0246ba --- /dev/null +++ b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/add_lwe_int.mlir @@ -0,0 +1,37 @@ +// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s + + +// CHECK-LABEL: func @add_glwe_const_int(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64> +func @add_glwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64> { + // CHECK-NEXT: %c1_i8 = arith.constant 1 : i8 + // CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64 + // CHECK-NEXT: %c56_i64 = arith.constant 56 : i64 + // CHECK-NEXT: %1 = arith.shli %0, %c56_i64 : i64 + // CHECK-NEXT: %2 = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: %3 = tensor.cast %2 : tensor<1025xi64> to tensor + // CHECK-NEXT: %4 = tensor.cast %arg0 : tensor<1025xi64> to tensor + // CHECK-NEXT: call @memref_add_plaintext_lwe_ciphertext_u64(%3, %4, %1) : (tensor, tensor, i64) -> () + // CHECK-NEXT: return %2 : tensor<1025xi64> + %0 = arith.constant 1 : i8 + %1 = 
"Concrete.encode_int"(%0) : (i8) -> !Concrete.plaintext<8> + %2 = linalg.init_tensor [1025] : tensor<1025xi64> + "BConcrete.add_plaintext_lwe_buffer"(%2, %arg0, %1) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<8>) -> () + return %2 : tensor<1025xi64> +} + + +// CHECK-LABEL: func @add_glwe_int(%arg0: tensor<1025xi64>, %arg1: i5, %arg2: !Concrete.context) -> tensor<1025xi64> +func @add_glwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> { + // CHECK-NEXT: %0 = arith.extui %arg1 : i5 to i64 + // CHECK-NEXT: %c59_i64 = arith.constant 59 : i64 + // CHECK-NEXT: %1 = arith.shli %0, %c59_i64 : i64 + // CHECK-NEXT: %2 = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: %3 = tensor.cast %2 : tensor<1025xi64> to tensor + // CHECK-NEXT: %4 = tensor.cast %arg0 : tensor<1025xi64> to tensor + // CHECK-NEXT: call @memref_add_plaintext_lwe_ciphertext_u64(%3, %4, %1) : (tensor, tensor, i64) -> () + // CHECK-NEXT: return %2 : tensor<1025xi64> + %0 = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5> + %1 = linalg.init_tensor [1025] : tensor<1025xi64> + "BConcrete.add_plaintext_lwe_buffer"(%1, %arg0, %0) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<5>) -> () + return %1 : tensor<1025xi64> +} diff --git a/compiler/tests/Conversion/BConcreteToBConcreteCAPI/bootstrap_lwe.mlir b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/bootstrap_lwe.mlir new file mode 100644 index 000000000..a86287f15 --- /dev/null +++ b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/bootstrap_lwe.mlir @@ -0,0 +1,15 @@ +// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s + +// CHECK: func @bootstrap_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.glwe_ciphertext, %arg2: !Concrete.context) -> tensor<1025xi64> { +// CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64> +// CHECK-NEXT: %1 = call @get_bootstrap_key(%arg2) : (!Concrete.context) -> !Concrete.lwe_bootstrap_key +// CHECK-NEXT: %2 
= tensor.cast %0 : tensor<1025xi64> to tensor +// CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor +// CHECK-NEXT: call @memref_bootstrap_lwe_u64(%1, %2, %3, %arg1) : (!Concrete.lwe_bootstrap_key, tensor, tensor, !Concrete.glwe_ciphertext) -> () +// CHECK-NEXT: return %0 : tensor<1025xi64> +// CHECK-NEXT: } +func @bootstrap_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.glwe_ciphertext) -> tensor<1025xi64> { + %0 = linalg.init_tensor [1025] : tensor<1025xi64> + "BConcrete.bootstrap_lwe_buffer"(%0, %arg0, %arg1) {baseLog = 2 : i32, glweDimension = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.glwe_ciphertext) -> () + return %0 : tensor<1025xi64> +} \ No newline at end of file diff --git a/compiler/tests/Conversion/BConcreteToBConcreteCAPI/keyswitch_lwe.mlir b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/keyswitch_lwe.mlir new file mode 100644 index 000000000..054ee8f03 --- /dev/null +++ b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/keyswitch_lwe.mlir @@ -0,0 +1,15 @@ +// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s + +//CHECK: func @keyswitch_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64> { +//CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64> +//CHECK-NEXT: %1 = call @get_keyswitch_key(%arg1) : (!Concrete.context) -> !Concrete.lwe_key_switch_key +//CHECK-NEXT: %2 = tensor.cast %0 : tensor<1025xi64> to tensor +//CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor +//CHECK-NEXT: call @memref_keyswitch_lwe_u64(%1, %2, %3) : (!Concrete.lwe_key_switch_key, tensor, tensor) -> () +//CHECK-NEXT: return %0 : tensor<1025xi64> +//CHECK-NEXT: } +func @keyswitch_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64> { + %0 = linalg.init_tensor [1025] : tensor<1025xi64> + "BConcrete.keyswitch_lwe_buffer"(%0, %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, 
outputLweDimension = 1 : i32} : (tensor<1025xi64>, tensor<1025xi64>) -> () + return %0 : tensor<1025xi64> +} \ No newline at end of file diff --git a/compiler/tests/Conversion/BConcreteToBConcreteCAPI/mul_lwe_int.mlir b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/mul_lwe_int.mlir new file mode 100644 index 000000000..ee8ec21a5 --- /dev/null +++ b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/mul_lwe_int.mlir @@ -0,0 +1,33 @@ +// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @mul_lwe_const_int(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64> +func @mul_lwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64> { + // CHECK-NEXT: %c1_i8 = arith.constant 1 : i8 + // CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64 + // CHECK-NEXT: %1 = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: %2 = tensor.cast %1 : tensor<1025xi64> to tensor + // CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor + // CHECK-NEXT: call @memref_mul_cleartext_lwe_ciphertext_u64(%2, %3, %0) : (tensor, tensor, i64) -> () + // CHECK-NEXT: return %1 : tensor<1025xi64> + %c1_i8 = arith.constant 1 : i8 + %1 = "Concrete.int_to_cleartext"(%c1_i8) : (i8) -> !Concrete.cleartext<8> + %2 = linalg.init_tensor [1025] : tensor<1025xi64> + "BConcrete.mul_cleartext_lwe_buffer"(%2, %arg0, %1) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.cleartext<8>) -> () + return %2 : tensor<1025xi64> +} + + + +// CHECK-LABEL: func @mul_lwe_int(%arg0: tensor<1025xi64>, %arg1: i5, %arg2: !Concrete.context) -> tensor<1025xi64> +func @mul_lwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> { + // CHECK-NEXT: %0 = arith.extui %arg1 : i5 to i64 + // CHECK-NEXT: %1 = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: %2 = tensor.cast %1 : tensor<1025xi64> to tensor + // CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor + // CHECK-NEXT: call 
@memref_mul_cleartext_lwe_ciphertext_u64(%2, %3, %0) : (tensor, tensor, i64) -> () + // CHECK-NEXT: return %1 : tensor<1025xi64> + %0 = "Concrete.int_to_cleartext"(%arg1) : (i5) -> !Concrete.cleartext<5> + %1 = linalg.init_tensor [1025] : tensor<1025xi64> + "BConcrete.mul_cleartext_lwe_buffer"(%1, %arg0, %0) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.cleartext<5>) -> () + return %1 : tensor<1025xi64> +} diff --git a/compiler/tests/Conversion/BConcreteToBConcreteCAPI/neg_lwe.mlir b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/neg_lwe.mlir new file mode 100644 index 000000000..2ec41fec9 --- /dev/null +++ b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/neg_lwe.mlir @@ -0,0 +1,13 @@ +// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @neg_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64> { +func @neg_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64> { + // CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: %1 = tensor.cast %0 : tensor<1025xi64> to tensor + // CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<1025xi64> to tensor + // CHECK-NEXT: call @memref_negate_lwe_ciphertext_u64(%1, %2) : (tensor, tensor) -> () + // CHECK-NEXT: return %0 : tensor<1025xi64> + %0 = linalg.init_tensor [1025] : tensor<1025xi64> + "BConcrete.negate_lwe_buffer"(%0, %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> () + return %0 : tensor<1025xi64> +} diff --git a/compiler/tests/Conversion/BConcreteToBConcreteCAPI/sub_int_lwe.mlir b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/sub_int_lwe.mlir new file mode 100644 index 000000000..4792386a9 --- /dev/null +++ b/compiler/tests/Conversion/BConcreteToBConcreteCAPI/sub_int_lwe.mlir @@ -0,0 +1,47 @@ +// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @sub_const_int_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> 
tensor<1025xi64> { +func @sub_const_int_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64> { + // CHECK-NEXT: %c1_i8 = arith.constant 1 : i8 + // CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: %1 = tensor.cast %0 : tensor<1025xi64> to tensor + // CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<1025xi64> to tensor + // CHECK-NEXT: call @memref_negate_lwe_ciphertext_u64(%1, %2) : (tensor, tensor) -> () + // CHECK-NEXT: %3 = arith.extui %c1_i8 : i8 to i64 + // CHECK-NEXT: %c56_i64 = arith.constant 56 : i64 + // CHECK-NEXT: %4 = arith.shli %3, %c56_i64 : i64 + // CHECK-NEXT: %5 = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: %6 = tensor.cast %5 : tensor<1025xi64> to tensor + // CHECK-NEXT: %7 = tensor.cast %0 : tensor<1025xi64> to tensor + // CHECK-NEXT: call @memref_add_plaintext_lwe_ciphertext_u64(%6, %7, %4) : (tensor, tensor, i64) -> () + // CHECK-NEXT: return %5 : tensor<1025xi64> + %0 = arith.constant 1 : i8 + %1 = linalg.init_tensor [1025] : tensor<1025xi64> + "BConcrete.negate_lwe_buffer"(%1, %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> () + %2 = "Concrete.encode_int"(%0) : (i8) -> !Concrete.plaintext<8> + %3 = linalg.init_tensor [1025] : tensor<1025xi64> + "BConcrete.add_plaintext_lwe_buffer"(%3, %1, %2) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<8>) -> () + return %3 : tensor<1025xi64> +} + +// CHECK-LABEL: func @sub_int_lwe(%arg0: tensor<1025xi64>, %arg1: i5, %arg2: !Concrete.context) -> tensor<1025xi64> { +func @sub_int_lwe(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> { + // CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: %1 = tensor.cast %0 : tensor<1025xi64> to tensor + // CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<1025xi64> to tensor + // CHECK-NEXT: call @memref_negate_lwe_ciphertext_u64(%1, %2) : (tensor, tensor) -> () + // CHECK-NEXT: %3 = arith.extui %arg1 : i5 to i64 + // CHECK-NEXT: %c59_i64 = arith.constant 59 : i64 + // CHECK-NEXT: %4 = 
arith.shli %3, %c59_i64 : i64 + // CHECK-NEXT: %5 = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: %6 = tensor.cast %5 : tensor<1025xi64> to tensor + // CHECK-NEXT: %7 = tensor.cast %0 : tensor<1025xi64> to tensor + // CHECK-NEXT: call @memref_add_plaintext_lwe_ciphertext_u64(%6, %7, %4) : (tensor, tensor, i64) -> () + // CHECK-NEXT: return %5 : tensor<1025xi64> + %0 = linalg.init_tensor [1025] : tensor<1025xi64> + "BConcrete.negate_lwe_buffer"(%0, %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> () + %1 = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5> + %2 = linalg.init_tensor [1025] : tensor<1025xi64> + "BConcrete.add_plaintext_lwe_buffer"(%2, %0, %1) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<5>) -> () + return %2 : tensor<1025xi64> +} diff --git a/compiler/tests/Conversion/ConcreteToBConcrete/add_lwe.mlir b/compiler/tests/Conversion/ConcreteToBConcrete/add_lwe.mlir new file mode 100644 index 000000000..ab9adda07 --- /dev/null +++ b/compiler/tests/Conversion/ConcreteToBConcrete/add_lwe.mlir @@ -0,0 +1,10 @@ +// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @add_glwe(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>) -> tensor<2049xi64> +func @add_glwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { + // CHECK-NEXT: %[[V1:.*]] = linalg.init_tensor [2049] : tensor<2049xi64> + // CHECK-NEXT: "BConcrete.add_lwe_buffer"(%[[V1]], %arg0, %arg1) : (tensor<2049xi64>, tensor<2049xi64>, tensor<2049xi64>) -> () + // CHECK-NEXT: return %[[V1]] : tensor<2049xi64> + %0 = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> + return %0 : !Concrete.lwe_ciphertext<2048,7> +} diff --git a/compiler/tests/Conversion/ConcreteToBConcrete/add_lwe_int.mlir 
b/compiler/tests/Conversion/ConcreteToBConcrete/add_lwe_int.mlir new file mode 100644 index 000000000..18442f452 --- /dev/null +++ b/compiler/tests/Conversion/ConcreteToBConcrete/add_lwe_int.mlir @@ -0,0 +1,25 @@ +// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @add_glwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64> +func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { + // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8 + // CHECK-NEXT: %[[V2:.*]] = "Concrete.encode_int"(%[[V1]]) : (i8) -> !Concrete.plaintext<8> + // CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: "BConcrete.add_plaintext_lwe_buffer"(%1, %arg0, %0) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<8>) -> () + // CHECK-NEXT: return %[[V3]] : tensor<1025xi64> + %0 = arith.constant 1 : i8 + %1 = "Concrete.encode_int"(%0) : (i8) -> !Concrete.plaintext<8> + %2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %1) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.plaintext<8>) -> !Concrete.lwe_ciphertext<1024,7> + return %2 : !Concrete.lwe_ciphertext<1024,7> +} + +// CHECK-LABEL: func @add_glwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> +func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> { + // CHECK-NEXT: %[[V1:.*]] = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5> + // CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: "BConcrete.add_plaintext_lwe_buffer"(%[[V2:.*]], %arg0, %[[V1:.*]]) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<5>) -> () + // CHECK-NEXT: return %[[V2]] : tensor<1025xi64> + %0 = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5> + %1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.plaintext<5>) -> 
!Concrete.lwe_ciphertext<1024,4> + return %1 : !Concrete.lwe_ciphertext<1024,4> +} diff --git a/compiler/tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir b/compiler/tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir new file mode 100644 index 000000000..457dbf8cc --- /dev/null +++ b/compiler/tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir @@ -0,0 +1,15 @@ +// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @apply_lookup_table(%arg0: tensor<1025xi64>, %arg1: tensor<16xi64>) -> tensor<1025xi64> +func @apply_lookup_table(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !Concrete.lwe_ciphertext<1024,4> { + // CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext + // CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [601] : tensor<601xi64> + // CHECK-NEXT: "BConcrete.keyswitch_lwe_buffer"(%[[V2]], %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<601xi64>, tensor<1025xi64>) -> () + // CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3]], %[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<601xi64>, !Concrete.glwe_ciphertext) -> () + // CHECK-NEXT: return %[[V3]] : tensor<1025xi64> + %0 = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext + %1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<600,4> + %2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, 
glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<1024,4> + return %2 : !Concrete.lwe_ciphertext<1024,4> +} diff --git a/compiler/tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir b/compiler/tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir new file mode 100644 index 000000000..92f2a49d6 --- /dev/null +++ b/compiler/tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir @@ -0,0 +1,17 @@ +// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: tensor<2049xi64>) -> tensor<2049xi64> +func @apply_lookup_table_cst(%arg0: !Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<2048,4> { + // CHECK-NEXT: %[[TABLE:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> + // CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%[[TABLE:.*]]) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext + // CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [601] : tensor<601xi64> + // CHECK-NEXT: "BConcrete.keyswitch_lwe_buffer"([[V2:.*]], %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<601xi64>, tensor<2049xi64>) -> () + // CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [2049] : tensor<2049xi64> + // CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3:.*]], %[[V2:.*]], %[[V1:.*]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (tensor<2049xi64>, tensor<601xi64>, !Concrete.glwe_ciphertext) -> () + // CHECK-NEXT: return %[[V3]] : tensor<2049xi64> + %tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> + %0 = "Concrete.glwe_from_table"(%tlu) {glweDimension = 1 : i32, p = 4 : 
i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext + %1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<600,4> + %2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,4> + return %2 : !Concrete.lwe_ciphertext<2048,4> +} diff --git a/compiler/tests/Conversion/ConcreteToBConcrete/identity.mlir b/compiler/tests/Conversion/ConcreteToBConcrete/identity.mlir new file mode 100644 index 000000000..7a69e1620 --- /dev/null +++ b/compiler/tests/Conversion/ConcreteToBConcrete/identity.mlir @@ -0,0 +1,8 @@ +// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s + +// CHECK: func @identity(%arg0: tensor<1025xi64>) -> tensor<1025xi64> { +// CHECK-NEXT: return %arg0 : tensor<1025xi64> +// CHECK-NEXT: } +func @identity(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { + return %arg0 : !Concrete.lwe_ciphertext<1024,7> +} \ No newline at end of file diff --git a/compiler/tests/Conversion/ConcreteToBConcrete/mul_lwe_int.mlir b/compiler/tests/Conversion/ConcreteToBConcrete/mul_lwe_int.mlir new file mode 100644 index 000000000..c13616579 --- /dev/null +++ b/compiler/tests/Conversion/ConcreteToBConcrete/mul_lwe_int.mlir @@ -0,0 +1,25 @@ +// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @mul_lwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64> +func @mul_lwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { + // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8 + // CHECK-NEXT: %[[V2:.*]] = "Concrete.int_to_cleartext"(%c1_i8) : (i8) -> 
!Concrete.cleartext<8> + // CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: "BConcrete.mul_cleartext_lwe_buffer"(%[[V3]], %arg0, %[[V2]]) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.cleartext<8>) -> () + // CHECK-NEXT: return %[[V3]] : tensor<1025xi64> + %0 = arith.constant 1 : i8 + %1 = "Concrete.int_to_cleartext"(%0) : (i8) -> !Concrete.cleartext<8> + %2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.cleartext<8>) -> !Concrete.lwe_ciphertext<1024,7> + return %2 : !Concrete.lwe_ciphertext<1024,7> +} + +// CHECK-LABEL: func @mul_lwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> +func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> { + // CHECK-NEXT: %[[V1:.*]] = "Concrete.int_to_cleartext"(%arg1) : (i5) -> !Concrete.cleartext<5> + // CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: "BConcrete.mul_cleartext_lwe_buffer"(%[[V2]], %arg0, %[[V1]]) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.cleartext<5>) -> () + // CHECK-NEXT: return %[[V2]] : tensor<1025xi64> + %0 = "Concrete.int_to_cleartext"(%arg1) : (i5) -> !Concrete.cleartext<5> + %1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.cleartext<5>) -> !Concrete.lwe_ciphertext<1024,4> + return %1 : !Concrete.lwe_ciphertext<1024,4> +} diff --git a/compiler/tests/Conversion/ConcreteToBConcrete/neg_lwe.mlir b/compiler/tests/Conversion/ConcreteToBConcrete/neg_lwe.mlir new file mode 100644 index 000000000..6e64b89d2 --- /dev/null +++ b/compiler/tests/Conversion/ConcreteToBConcrete/neg_lwe.mlir @@ -0,0 +1,10 @@ +// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @neg_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64> +func @neg_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> 
!Concrete.lwe_ciphertext<1024,4> { + // CHECK-NEXT: %[[V1:.*]] = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: "BConcrete.negate_lwe_buffer"(%[[V1]], %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> () + // CHECK-NEXT: return %[[V1]] : tensor<1025xi64> + %0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> + return %0 : !Concrete.lwe_ciphertext<1024,4> +} diff --git a/compiler/tests/Conversion/ConcreteToBConcrete/sub_int_lwe.mlir b/compiler/tests/Conversion/ConcreteToBConcrete/sub_int_lwe.mlir new file mode 100644 index 000000000..906f81191 --- /dev/null +++ b/compiler/tests/Conversion/ConcreteToBConcrete/sub_int_lwe.mlir @@ -0,0 +1,32 @@ +// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @sub_const_int_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64> +func @sub_const_int_lwe(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { + // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8 + // CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: "BConcrete.negate_lwe_buffer"(%[[V2]], %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> () + // CHECK-NEXT: %[[V3:.*]] = "Concrete.encode_int"(%[[V1]]) : (i8) -> !Concrete.plaintext<8> + // CHECK-NEXT: %[[V4:.*]] = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: "BConcrete.add_plaintext_lwe_buffer"(%[[V4]], %[[V2]], %[[V3]]) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<8>) -> () + // CHECK-NEXT: return %[[V4]] : tensor<1025xi64> + %0 = arith.constant 1 : i8 + %1 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> + %2 = "Concrete.encode_int"(%0) : (i8) -> !Concrete.plaintext<8> + %3 = "Concrete.add_plaintext_lwe_ciphertext"(%1, %2) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.plaintext<8>) -> !Concrete.lwe_ciphertext<1024,7> + 
return %3 : !Concrete.lwe_ciphertext<1024,7> +} + + +// CHECK-LABEL: func @sub_int_lwe(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> +func @sub_int_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> { + // CHECK-NEXT: %[[V1:.*]] = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: "BConcrete.negate_lwe_buffer"(%[[V1]], %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> () + // CHECK-NEXT: %[[V2:.*]] = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5> + // CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [1025] : tensor<1025xi64> + // CHECK-NEXT: "BConcrete.add_plaintext_lwe_buffer"(%[[V3]], %[[V1]], %[[V2]]) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<5>) -> () + // CHECK-NEXT: return %[[V3]] : tensor<1025xi64> + %0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> + %1 = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5> + %2 = "Concrete.add_plaintext_lwe_ciphertext"(%0, %1) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<1024,4> + return %2 : !Concrete.lwe_ciphertext<1024,4> +} diff --git a/compiler/tests/Conversion/ConcreteToBConcrete/tensor_identity.mlir b/compiler/tests/Conversion/ConcreteToBConcrete/tensor_identity.mlir new file mode 100644 index 000000000..d8b1f6fd4 --- /dev/null +++ b/compiler/tests/Conversion/ConcreteToBConcrete/tensor_identity.mlir @@ -0,0 +1,7 @@ +// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s +// CHECK: func @tensor_identity(%arg0: tensor<2x3x4x1025xi64>) -> tensor<2x3x4x1025xi64> { +// CHECK-NEXT: return %arg0 : tensor<2x3x4x1025xi64> +// CHECK-NEXT: } +func @tensor_identity(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>> { + return %arg0 : tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>> +} \ No newline at end of file