enhance(compiler): Lower from Concrete to BConcrete and BConcrete to C API call

This commit is contained in:
Quentin Bourgerie
2022-02-11 13:53:11 +01:00
committed by Quentin Bourgerie
parent b3368027d0
commit 626493dda7
30 changed files with 1984 additions and 16 deletions

View File

@@ -0,0 +1,22 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CONVERSION_BCONCRETETOBCONCRETECAPI_PASS_H_
#define CONCRETELANG_CONVERSION_BCONCRETETOBCONCRETECAPI_PASS_H_
#include "mlir/Pass/Pass.h"
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
namespace mlir {
namespace concretelang {
/// Create a pass to convert `Concrete` operators to function call to the
/// `ConcreteCAPI`
std::unique_ptr<OperationPass<ModuleOp>>
createConvertBConcreteToBConcreteCAPIPass();
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,18 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef ZAMALANG_CONVERSION_CONCRETETOBCONCRETE_PASS_H_
#define ZAMALANG_CONVERSION_CONCRETETOBCONCRETE_PASS_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace concretelang {
/// Create a pass to convert `Concrete` dialect to `BConcrete` dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertConcreteToBConcretePass();
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -11,6 +11,8 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "concretelang/Conversion/BConcreteToBConcreteCAPI/Pass.h"
#include "concretelang/Conversion/ConcreteToBConcrete/Pass.h"
#include "concretelang/Conversion/ConcreteToConcreteCAPI/Pass.h"
#include "concretelang/Conversion/ConcreteUnparametrize/Pass.h"
#include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h"
@@ -18,6 +20,7 @@
#include "concretelang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
#include "concretelang/Conversion/TFHEGlobalParametrization/Pass.h"
#include "concretelang/Conversion/TFHEToConcrete/Pass.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"

View File

@@ -29,7 +29,15 @@ def TFHEToConcrete : Pass<"tfhe-to-concrete", "mlir::ModuleOp"> {
let description = [{ Lowers operations from the TFHE dialect to Concrete }];
let constructor = "mlir::concretelang::createConvertTFHEToConcretePass()";
let options = [];
let dependentDialects = ["mlir::linalg::LinalgDialect"];
let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::TFHE::TFHEDialect"];
}
def ConcreteToBConcrete : Pass<"concrete-to-bconcrete", "mlir::ModuleOp"> {
let summary = "Lowers operations from the Concrete dialect to Bufferized Concrete";
let description = [{ Lowers operations from the Concrete dialect to Bufferized Concrete }];
let constructor = "mlir::concretelang::createConvertConcreteToBConcretePass()";
let options = [];
let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::Concrete::ConcreteDialect", "mlir::concretelang::BConcrete::BConcreteDialect"];
}
def ConcreteToConcreteCAPI : Pass<"concrete-to-concrete-c-api", "mlir::ModuleOp"> {
@@ -38,6 +46,12 @@ def ConcreteToConcreteCAPI : Pass<"concrete-to-concrete-c-api", "mlir::ModuleOp"
let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
}
def BConcreteToBConcreteCAPI : Pass<"bconcrete-to-bconcrete-c-api", "mlir::ModuleOp"> {
let summary = "Lower operations from the Bufferized Concrete dialect to std with function call to the Bufferized Concrete C API";
let constructor = "mlir::concretelang::createConvertBConcreteToBConcreteCAPIPass()";
let dependentDialects = ["mlir::concretelang::BConcrete::BConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
}
def ConcreteUnparametrize : Pass<"concrete-unparametrize", "mlir::ModuleOp"> {
let summary = "Unparametrize Concrete types and remove unrealized_conversion_cast";
let constructor = "mlir::concretelang::createConvertConcreteToConcreteCAPIPass()";

View File

@@ -119,6 +119,10 @@ public:
// operations
CONCRETE,
// Read sources and lower all FHE, TFHE and Concrete operations to BConcrete
// operations
BCONCRETE,
// Read sources and lower all FHE, TFHE and Concrete
// operations to canonical MLIR dialects. Cryptographic operations
// are lowered to invocations of the concrete library.

View File

@@ -43,8 +43,12 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,

View File

@@ -0,0 +1,593 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include "mlir//IR/BuiltinTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
namespace {
class BConcreteToBConcreteCAPITypeConverter : public mlir::TypeConverter {
public:
BConcreteToBConcreteCAPITypeConverter() {
addConversion([](mlir::Type type) { return type; });
addConversion([&](mlir::concretelang::Concrete::PlaintextType type) {
return mlir::IntegerType::get(type.getContext(), 64);
});
addConversion([&](mlir::concretelang::Concrete::CleartextType type) {
return mlir::IntegerType::get(type.getContext(), 64);
});
}
};
mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
mlir::RewriterBase &rewriter,
llvm::StringRef funcName,
mlir::FunctionType funcType) {
// Looking for the `funcName` Operation
auto module = mlir::SymbolTable::getNearestSymbolTable(op);
auto opFunc = mlir::dyn_cast_or_null<mlir::SymbolOpInterface>(
mlir::SymbolTable::lookupSymbolIn(module, funcName));
if (!opFunc) {
// Insert the forward declaration of the funcName
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
opFunc = rewriter.create<mlir::FuncOp>(rewriter.getUnknownLoc(), funcName,
funcType);
opFunc.setPrivate();
} else {
// Check if the `funcName` is well a private function
if (!opFunc.isPrivate()) {
op->emitError() << "the function \"" << funcName
<< "\" conflicts with the concrete C API, please rename";
return mlir::failure();
}
}
assert(mlir::SymbolTable::lookupSymbolIn(module, funcName)
->template hasTrait<mlir::OpTrait::FunctionLike>());
return mlir::success();
}
// Set of functions to generate generic types.
// Generic types are used to add forward declarations without a specific type.
// For example, we may need to add LWE ciphertext of different dimensions, or
// allocate them. All the calls to the C API should be done using this generic
// types, and casting should then be performed back to the appropriate type.
inline mlir::Type getGenericLweBufferType(mlir::MLIRContext *context) {
return mlir::RankedTensorType::get({-1}, mlir::IntegerType::get(context, 64));
}
inline mlir::concretelang::Concrete::GlweCiphertextType
getGenericGlweCiphertextType(mlir::MLIRContext *context) {
return mlir::concretelang::Concrete::GlweCiphertextType::get(context);
}
inline mlir::Type getGenericPlaintextType(mlir::MLIRContext *context) {
return mlir::IntegerType::get(context, 64);
}
inline mlir::Type getGenericCleartextType(mlir::MLIRContext *context) {
return mlir::IntegerType::get(context, 64);
}
inline mlir::concretelang::Concrete::PlaintextListType
getGenericPlaintextListType(mlir::MLIRContext *context) {
return mlir::concretelang::Concrete::PlaintextListType::get(context);
}
inline mlir::concretelang::Concrete::ForeignPlaintextListType
getGenericForeignPlaintextListType(mlir::MLIRContext *context) {
return mlir::concretelang::Concrete::ForeignPlaintextListType::get(context);
}
inline mlir::concretelang::Concrete::LweKeySwitchKeyType
getGenericLweKeySwitchKeyType(mlir::MLIRContext *context) {
return mlir::concretelang::Concrete::LweKeySwitchKeyType::get(context);
}
inline mlir::concretelang::Concrete::LweBootstrapKeyType
getGenericLweBootstrapKeyType(mlir::MLIRContext *context) {
return mlir::concretelang::Concrete::LweBootstrapKeyType::get(context);
}
// Insert all forward declarations needed for the pass.
// Should generalize input and output types for all decalarations, and the
// pattern using them would be resposible for casting them to the appropriate
// type.
mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
mlir::IRRewriter &rewriter) {
auto lweBufferType = getGenericLweBufferType(rewriter.getContext());
auto plaintextType = getGenericPlaintextType(rewriter.getContext());
auto cleartextType = getGenericCleartextType(rewriter.getContext());
auto glweCiphertextType = getGenericGlweCiphertextType(rewriter.getContext());
auto plaintextListType = getGenericPlaintextListType(rewriter.getContext());
auto foreignPlaintextList =
getGenericForeignPlaintextListType(rewriter.getContext());
auto keySwitchKeyType = getGenericLweKeySwitchKeyType(rewriter.getContext());
auto bootstrapKeyType = getGenericLweBootstrapKeyType(rewriter.getContext());
auto contextType =
mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
// Insert forward declaration of the add_lwe_ciphertexts function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {lweBufferType, lweBufferType, lweBufferType},
{});
if (insertForwardDeclaration(op, rewriter, "memref_add_lwe_ciphertexts_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the add_plaintext_lwe_ciphertext_u64 function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {lweBufferType, lweBufferType, plaintextType},
{});
if (insertForwardDeclaration(
op, rewriter, "memref_add_plaintext_lwe_ciphertext_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the mul_cleartext_lwe_ciphertext_u64 function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {lweBufferType, lweBufferType, cleartextType},
{});
if (insertForwardDeclaration(
op, rewriter, "memref_mul_cleartext_lwe_ciphertext_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the negate_lwe_ciphertext_u64 function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{lweBufferType, lweBufferType}, {});
if (insertForwardDeclaration(op, rewriter,
"memref_negate_lwe_ciphertext_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the memref_keyswitch_lwe_u64 function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {keySwitchKeyType, lweBufferType, lweBufferType},
{});
if (insertForwardDeclaration(op, rewriter, "memref_keyswitch_lwe_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the memref_bootstrap_lwe_u64 function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{bootstrapKeyType, lweBufferType, lweBufferType, glweCiphertextType},
{});
if (insertForwardDeclaration(op, rewriter, "memref_bootstrap_lwe_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the fill_plaintext_list function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {plaintextListType, foreignPlaintextList}, {});
if (insertForwardDeclaration(
op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the add_plaintext_list_glwe function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{glweCiphertextType, glweCiphertextType, plaintextListType}, {});
if (insertForwardDeclaration(
op, rewriter, "add_plaintext_list_glwe_ciphertext_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the getGlobalKeyswitchKey function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{contextType}, {keySwitchKeyType});
if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the getGlobalBootstrapKey function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{contextType}, {bootstrapKeyType});
if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key", funcType)
.failed()) {
return mlir::failure();
}
}
return mlir::success();
}
// For all operands `tensor<Axi64>` replace with
// `%casted = tensor.cast %op : tensor<Axi64> to tensor<?xui64>`
template <typename Op>
mlir::SmallVector<mlir::Value>
getCastedTensorOperands(Op op, mlir::PatternRewriter &rewriter) {
mlir::SmallVector<mlir::Value, 4> newOperands{};
for (mlir::Value operand : op->getOperands()) {
mlir::Type operandType = operand.getType();
if (operandType.isa<mlir::RankedTensorType>()) {
mlir::Value castedOp = rewriter.create<mlir::tensor::CastOp>(
op.getLoc(), getGenericLweBufferType(rewriter.getContext()), operand);
newOperands.push_back(castedOp);
} else {
newOperands.push_back(operand);
}
}
return std::move(newOperands);
}
/// BConcreteOpToConcreteCAPICallPattern<Op> match the `BConcreteOp`
/// Operation and replace with a call to `funcName`, the funcName should be an
/// external function that was linked later. It insert the forward declaration
/// of the private `funcName` if it not already in the symbol table. The C
/// signature of the function should be `void (out, args..., lweDimension)`, the
/// pattern rewrite:
/// ```
/// "BConcreteOp"(%out, args ...) :
/// (tensor<sizexi64>, tensor<sizexi64>...) -> ()
/// ```
/// to
/// ```
/// %out0 = tensor.cast %out : tensor<sizexi64> to tensor<?xui64>
/// %args = tensor.cast ...
/// call @funcName(%out, args...) : (tensor<?xi64>, tensor<?xi64>...) -> ()
/// ```
template <typename BConcreteOp>
struct ConcreteOpToConcreteCAPICallPattern
: public mlir::OpRewritePattern<BConcreteOp> {
ConcreteOpToConcreteCAPICallPattern(mlir::MLIRContext *context,
mlir::StringRef funcName,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<BConcreteOp>(context, benefit),
funcName(funcName) {}
mlir::LogicalResult
matchAndRewrite(BConcreteOp op,
mlir::PatternRewriter &rewriter) const override {
BConcreteToBConcreteCAPITypeConverter typeConverter;
rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, funcName, mlir::TypeRange{},
getCastedTensorOperands<BConcreteOp>(op, rewriter));
return mlir::success();
};
private:
std::string funcName;
};
struct ConcreteEncodeIntOpPattern
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp> {
ConcreteEncodeIntOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::Concrete::EncodeIntOp op,
mlir::PatternRewriter &rewriter) const override {
{
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
op.getLoc(), rewriter.getIntegerType(64), op->getOperands().front());
mlir::Value constantShiftOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI64IntegerAttr(64 - op.getType().getP()));
mlir::Type resultType = rewriter.getIntegerType(64);
rewriter.replaceOpWithNewOp<mlir::arith::ShLIOp>(
op, resultType, castedInt, constantShiftOp);
}
return mlir::success();
};
};
struct ConcreteIntToCleartextOpPattern
: public mlir::OpRewritePattern<
mlir::concretelang::Concrete::IntToCleartextOp> {
ConcreteIntToCleartextOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::concretelang::Concrete::IntToCleartextOp>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::Concrete::IntToCleartextOp op,
mlir::PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(
op, rewriter.getIntegerType(64), op->getOperands().front());
return mlir::success();
};
};
mlir::Value getContextArgument(mlir::Operation *op) {
mlir::Block *block = op->getBlock();
while (block != nullptr) {
if (llvm::isa<mlir::FuncOp>(block->getParentOp())) {
mlir::Value context = block->getArguments().back();
assert(
context.getType().isa<mlir::concretelang::Concrete::ContextType>() &&
"the Concrete.context should be the last argument of the enclosing "
"function of the op");
return context;
}
block = block->getParentOp()->getBlock();
}
assert("can't find a function that enclose the op");
return nullptr;
}
// Rewrite pattern that rewrite every
// ```
// "BConcrete.keyswitch_lwe_buffer"(%out, %in) {...}:
// (tensor<2049xi64>, tensor<2049xi64>) -> ()
// ```
//
// to
//
// ```
// %ksk = call @get_keywswitch_key(%ctx) :
// (!Concrete.context) -> !Concrete.lwe_key_switch_key
// %out_ = tensor.cast %out : tensor<sizexi64> to tensor<?xi64>
// %in_ = tensor.cast %in : tensor<size'xi64> to tensor<?xi64>
// call @memref_keyswitch_lwe_u64(%ksk, %out_, %in_) :
// (!Concrete.lwe_key_switch_key, tensor<?xui64>, tensor<?xui64>) -> ()
// ```
struct BConcreteKeySwitchLweOpPattern
: public mlir::OpRewritePattern<
mlir::concretelang::BConcrete::KeySwitchLweBufferOp> {
BConcreteKeySwitchLweOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<
mlir::concretelang::BConcrete::KeySwitchLweBufferOp>(context,
benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::BConcrete::KeySwitchLweBufferOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::CallOp kskOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "get_keyswitch_key",
getGenericLweKeySwitchKeyType(rewriter.getContext()),
mlir::SmallVector<mlir::Value>{getContextArgument(op)});
mlir::SmallVector<mlir::Value, 3> operands{kskOp.getResult(0)};
operands.append(
getCastedTensorOperands<
mlir::concretelang::BConcrete::KeySwitchLweBufferOp>(op, rewriter));
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, "memref_keyswitch_lwe_u64",
mlir::TypeRange({}), operands);
return mlir::success();
};
};
// Rewrite pattern that rewrite every
// ```
// "BConcrete.bootstrap_lwe_buffer"(%out, %in, %acc) {...} :
// (tensor<2049xui64>, tensor<2049xui64>, !Concrete.glwe_ciphertext) -> ()
// ```
//
// to
//
// ```
// %bsk = call @getGlobalBootstrapKey() : () -> !Concrete.lwe_bootstrap_key
// %out_ = tensor.cast %out : tensor<sizexi64> to tensor<?xi64>
// %in_ = tensor.cast %in : tensor<size'xi64> to tensor<?xi64>
// call @memref_bootstrap_lwe_u64(%bsk, %out_, %in_, %acc_) :
// (!Concrete.lwe_bootstrap_key, tensor<?xi64>, tensor<?xi64>,
// !Concrete.glwe_ciphertext) -> ()
// ```
struct BConcreteBootstrapLweOpPattern
: public mlir::OpRewritePattern<
mlir::concretelang::BConcrete::BootstrapLweBufferOp> {
BConcreteBootstrapLweOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<
mlir::concretelang::BConcrete::BootstrapLweBufferOp>(context,
benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::BConcrete::BootstrapLweBufferOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::SmallVector<mlir::Value> getkskOperands{};
mlir::CallOp bskOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "get_bootstrap_key",
getGenericLweBootstrapKeyType(rewriter.getContext()),
mlir::SmallVector<mlir::Value>{getContextArgument(op)});
mlir::SmallVector<mlir::Value, 4> operands{bskOp.getResult(0)};
operands.append(
getCastedTensorOperands<
mlir::concretelang::BConcrete::BootstrapLweBufferOp>(op, rewriter));
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, "memref_bootstrap_lwe_u64",
mlir::TypeRange({}), operands);
return mlir::success();
};
};
/// Populate the RewritePatternSet with all patterns that rewrite Concrete
/// operators to the corresponding function call to the `Concrete C API`.
void populateBConcreteToBConcreteCAPICall(mlir::RewritePatternSet &patterns) {
patterns.add<ConcreteOpToConcreteCAPICallPattern<
mlir::concretelang::BConcrete::AddLweBuffersOp>>(
patterns.getContext(), "memref_add_lwe_ciphertexts_u64");
patterns.add<ConcreteOpToConcreteCAPICallPattern<
mlir::concretelang::BConcrete::AddPlaintextLweBufferOp>>(
patterns.getContext(), "memref_add_plaintext_lwe_ciphertext_u64");
patterns.add<ConcreteOpToConcreteCAPICallPattern<
mlir::concretelang::BConcrete::MulCleartextLweBufferOp>>(
patterns.getContext(), "memref_mul_cleartext_lwe_ciphertext_u64");
patterns.add<ConcreteOpToConcreteCAPICallPattern<
mlir::concretelang::BConcrete::NegateLweBufferOp>>(
patterns.getContext(), "memref_negate_lwe_ciphertext_u64");
patterns.add<ConcreteEncodeIntOpPattern>(patterns.getContext());
patterns.add<ConcreteIntToCleartextOpPattern>(patterns.getContext());
// patterns.add<ConcreteZeroOpPattern>(patterns.getContext());
patterns.add<BConcreteKeySwitchLweOpPattern>(patterns.getContext());
patterns.add<BConcreteBootstrapLweOpPattern>(patterns.getContext());
}
struct AddRuntimeContextToFuncOpPattern
: public mlir::OpRewritePattern<mlir::FuncOp> {
AddRuntimeContextToFuncOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::FuncOp>(context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::FuncOp oldFuncOp,
mlir::PatternRewriter &rewriter) const override {
mlir::OpBuilder::InsertionGuard guard(rewriter);
mlir::FunctionType oldFuncType = oldFuncOp.getType();
// Add a Concrete.context to the function signature
mlir::SmallVector<mlir::Type> newInputs(oldFuncType.getInputs().begin(),
oldFuncType.getInputs().end());
newInputs.push_back(
rewriter.getType<mlir::concretelang::Concrete::ContextType>());
mlir::FunctionType newFuncTy = rewriter.getType<mlir::FunctionType>(
newInputs, oldFuncType.getResults());
// Create the new func
mlir::FuncOp newFuncOp = rewriter.create<mlir::FuncOp>(
oldFuncOp.getLoc(), oldFuncOp.getName(), newFuncTy);
// Create the arguments of the new func
mlir::Region &newFuncBody = newFuncOp.body();
mlir::Block *newFuncEntryBlock = new mlir::Block();
newFuncEntryBlock->addArguments(newFuncTy.getInputs());
newFuncBody.push_back(newFuncEntryBlock);
// Clone the old body to the new one
mlir::BlockAndValueMapping map;
for (auto arg : llvm::enumerate(oldFuncOp.getArguments())) {
map.map(arg.value(), newFuncEntryBlock->getArgument(arg.index()));
}
for (auto &op : oldFuncOp.body().front()) {
newFuncEntryBlock->push_back(op.clone(map));
}
rewriter.eraseOp(oldFuncOp);
return mlir::success();
}
// Legal function are one that are private or has a Concrete.context as last
// arguments.
static bool isLegal(mlir::FuncOp funcOp) {
if (!funcOp.isPublic()) {
return true;
}
// TODO : Don't need to add a runtime context for function that doesn't
// manipulates Concrete types.
//
// if (!llvm::any_of(funcOp.getType().getInputs(), [](mlir::Type t) {
// if (auto tensorTy = t.dyn_cast_or_null<mlir::TensorType>()) {
// t = tensorTy.getElementType();
// }
// return llvm::isa<mlir::concretelang::Concrete::ConcreteDialect>(
// t.getDialect());
// })) {
// return true;
// }
return funcOp.getType().getNumInputs() >= 1 &&
funcOp.getType()
.getInputs()
.back()
.isa<mlir::concretelang::Concrete::ContextType>();
}
};
namespace {
struct BConcreteToBConcreteCAPIPass
: public BConcreteToBConcreteCAPIBase<BConcreteToBConcreteCAPIPass> {
void runOnOperation() final;
};
} // namespace
void BConcreteToBConcreteCAPIPass::runOnOperation() {
mlir::ModuleOp op = getOperation();
// First of all add the Concrete.context to the block arguments of function
// that manipulates ciphertexts.
{
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
return AddRuntimeContextToFuncOpPattern::isLegal(funcOp);
});
patterns.add<AddRuntimeContextToFuncOpPattern>(patterns.getContext());
// Apply the conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))
.failed()) {
this->signalPassFailure();
return;
}
}
// Insert forward declaration
mlir::IRRewriter rewriter(&getContext());
if (insertForwardDeclarations(op, rewriter).failed()) {
this->signalPassFailure();
}
// Rewrite Concrete ops to CallOp to the Concrete C API
{
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
target.addIllegalDialect<mlir::concretelang::BConcrete::BConcreteDialect>();
target.addLegalDialect<mlir::BuiltinDialect, mlir::StandardOpsDialect,
mlir::tensor::TensorDialect,
mlir::arith::ArithmeticDialect>();
populateBConcreteToBConcreteCAPICall(patterns);
if (mlir::applyPartialConversion(op, target, std::move(patterns))
.failed()) {
this->signalPassFailure();
}
}
}
} // namespace
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertBConcreteToBConcreteCAPIPass() {
return std::make_unique<BConcreteToBConcreteCAPIPass>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,16 @@
add_mlir_dialect_library(BConcreteToBConcreteCAPI
BConcreteToBConcreteCAPI.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
DEPENDS
BConcreteDialect
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRTransforms
)
target_link_libraries(BConcreteToBConcreteCAPI PUBLIC MLIRIR)

View File

@@ -2,6 +2,8 @@ add_subdirectory(FHEToTFHE)
add_subdirectory(TFHEGlobalParametrization)
add_subdirectory(TFHEToConcrete)
add_subdirectory(FHETensorOpsToLinalg)
add_subdirectory(ConcreteToBConcrete)
add_subdirectory(ConcreteToConcreteCAPI)
add_subdirectory(BConcreteToBConcreteCAPI)
add_subdirectory(MLIRLowerableDialectsToLLVM)
add_subdirectory(ConcreteUnparametrize)

View File

@@ -0,0 +1,16 @@
add_mlir_dialect_library(ConcreteToBConcrete
ConcreteToBConcrete.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Concrete
DEPENDS
ConcreteDialect
BConcreteDialect
LINK_LIBS PUBLIC
MLIRIR
MLIRTransforms
MLIRMath)
target_link_libraries(ConcreteToBConcrete PUBLIC BConcreteDialect MLIRIR)

View File

@@ -0,0 +1,919 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include <iostream>
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
namespace {
struct ConcreteToBConcretePass
: public ConcreteToBConcreteBase<ConcreteToBConcretePass> {
void runOnOperation() final;
};
} // namespace
/// ConcreteToBConcreteTypeConverter is a TypeConverter that transform
/// `Concrete.lwe_ciphertext<dimension,p>` to `tensor<dimension+1, i64>>`
/// `tensor<...xConcrete.lwe_ciphertext<dimension,p>>` to
/// `tensor<...xdimension+1, i64>>`
class ConcreteToBConcreteTypeConverter : public mlir::TypeConverter {
public:
ConcreteToBConcreteTypeConverter() {
addConversion([](mlir::Type type) { return type; });
addConversion([&](mlir::concretelang::Concrete::LweCiphertextType type) {
return mlir::RankedTensorType::get(
{type.getDimension() + 1},
mlir::IntegerType::get(type.getContext(), 64));
});
addConversion([&](mlir::RankedTensorType type) {
auto lwe = type.getElementType()
.dyn_cast_or_null<
mlir::concretelang::Concrete::LweCiphertextType>();
if (lwe == nullptr) {
return (mlir::Type)(type);
}
mlir::SmallVector<int64_t> newShape;
newShape.reserve(type.getShape().size() + 1);
newShape.append(type.getShape().begin(), type.getShape().end());
newShape.push_back(lwe.getDimension() + 1);
mlir::Type r = mlir::RankedTensorType::get(
newShape, mlir::IntegerType::get(type.getContext(), 64));
return r;
});
addConversion([&](mlir::MemRefType type) {
auto lwe = type.getElementType()
.dyn_cast_or_null<
mlir::concretelang::Concrete::LweCiphertextType>();
if (lwe == nullptr) {
return (mlir::Type)(type);
}
mlir::SmallVector<int64_t> newShape;
newShape.reserve(type.getShape().size() + 1);
newShape.append(type.getShape().begin(), type.getShape().end());
newShape.push_back(lwe.getDimension() + 1);
mlir::Type r = mlir::MemRefType::get(
newShape, mlir::IntegerType::get(type.getContext(), 64));
return r;
});
}
};
// This rewrite pattern transforms any instance of `Concrete.zero_tensor`
// operators.
//
// Example:
//
// ```mlir
// %0 = "Concrete.zero_tensor" () :
// tensor<...x!Concrete.lwe_ciphertext<lweDim,p>>
// ```
//
// becomes:
//
// ```mlir
// %0 = tensor.generate {
// ^bb0(... : index):
// %c0 = arith.constant 0 : i64
// tensor.yield %z
// }: tensor<...xlweDim+1xi64>
// i64>
// ```
template <typename ZeroOp>
struct ZeroOpPattern : public mlir::OpRewritePattern<ZeroOp> {
ZeroOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<ZeroOp>(context, benefit) {}
::mlir::LogicalResult
matchAndRewrite(ZeroOp zeroOp,
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
auto resultTy = zeroOp.getType();
auto newResultTy = converter.convertType(resultTy);
auto generateBody = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
// %c0 = 0 : i64
auto cstOp = nestedBuilder.create<mlir::arith::ConstantOp>(
nestedLoc, nestedBuilder.getI64IntegerAttr(1));
// tensor.yield %z : !FHE.eint<p>
nestedBuilder.create<mlir::tensor::YieldOp>(nestedLoc, cstOp.getResult());
};
// tensor.generate
rewriter.replaceOpWithNewOp<mlir::tensor::GenerateOp>(
zeroOp, newResultTy, mlir::ValueRange{}, generateBody);
return ::mlir::success();
};
};
// This template rewrite pattern transforms any instance of
// `ConcreteOp` to an instance of `BConcreteOp`.
//
// Example:
//
// %0 = "ConcreteOp"(%arg0, ...) :
// (!Concrete.lwe_ciphertext<lwe_dimension, p>, ...) ->
// (!Concrete.lwe_ciphertext<lwe_dimension, p>)
//
// becomes:
//
// %0 = linalg.init_tensor [dimension+1] : tensor<dimension+1, i64>
// "BConcreteOp"(%0, %arg0, ...) : (tensor<dimension+1, i64>>,
// tensor<dimension+1, i64>>, ..., ) -> ()
//
// A reference to the preallocated output is always passed as the first
// argument.
template <typename ConcreteOp, typename BConcreteOp>
struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
LowToBConcrete(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<ConcreteOp>(context, benefit) {}
::mlir::LogicalResult
matchAndRewrite(ConcreteOp concreteOp,
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
mlir::concretelang::Concrete::LweCiphertextType resultTy =
((mlir::Type)concreteOp->getResult(0).getType())
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
auto newResultTy =
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
// %0 = linalg.init_tensor [dimension+1] : tensor<dimension+1, i64>
mlir::Value init = rewriter.replaceOpWithNewOp<mlir::linalg::InitTensorOp>(
concreteOp, newResultTy.getShape(), newResultTy.getElementType());
// "BConcreteOp"(%0, %arg0, ...) : (tensor<dimension+1, i64>>,
// tensor<dimension+1, i64>>, ..., ) -> ()
mlir::SmallVector<mlir::Value, 3> newOperands{init};
newOperands.append(concreteOp.getOperation()->getOperands().begin(),
concreteOp.getOperation()->getOperands().end());
llvm::ArrayRef<::mlir::NamedAttribute> attributes =
concreteOp.getOperation()->getAttrs();
rewriter.create<BConcreteOp>(concreteOp.getLoc(),
mlir::SmallVector<mlir::Type>{}, newOperands,
attributes);
return ::mlir::success();
};
};
// This rewrite pattern transforms any instance of
// `tensor.extract_slice` operators that operates on tensor of lwe ciphertext.
//
// Example:
//
// ```mlir
// %0 = tensor.extract_slice %arg0
// [offsets...] [sizes...] [strides...]
// : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> to
// tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
// ```
//
// becomes:
//
// ```mlir
// %0 = tensor.extract_slice %arg0
// [offsets..., 0] [sizes..., lweDimension+1] [strides..., 1]
// : tensor<...xlweDimension+1,i64> to
// tensor<...xlweDimension+1,i64>
// ```
struct ExtractSliceOpPattern
: public mlir::OpRewritePattern<mlir::tensor::ExtractSliceOp> {
ExtractSliceOpPattern(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<mlir::tensor::ExtractSliceOp>(context,
benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp,
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
auto resultTy = extractSliceOp.result().getType();
auto resultEltTy =
resultTy.cast<mlir::RankedTensorType>()
.getElementType()
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
auto newResultTy = converter.convertType(resultTy);
// add 0 to the static_offsets
mlir::SmallVector<mlir::Attribute> staticOffsets;
staticOffsets.append(extractSliceOp.static_offsets().begin(),
extractSliceOp.static_offsets().end());
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
// add the lweSize to the sizes
mlir::SmallVector<mlir::Attribute> staticSizes;
staticSizes.append(extractSliceOp.static_sizes().begin(),
extractSliceOp.static_sizes().end());
staticSizes.push_back(
rewriter.getI64IntegerAttr(resultEltTy.getDimension() + 1));
// add 1 to the strides
mlir::SmallVector<mlir::Attribute> staticStrides;
staticStrides.append(extractSliceOp.static_strides().begin(),
extractSliceOp.static_strides().end());
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
// replace tensor.extract_slice to the new one
rewriter.replaceOpWithNewOp<mlir::tensor::ExtractSliceOp>(
extractSliceOp, newResultTy, extractSliceOp.source(),
extractSliceOp.offsets(), extractSliceOp.sizes(),
extractSliceOp.strides(), rewriter.getArrayAttr(staticOffsets),
rewriter.getArrayAttr(staticSizes),
rewriter.getArrayAttr(staticStrides));
return ::mlir::success();
};
};
// This rewrite pattern transforms any instance of
// `tensor.extract` operators that operates on tensor of lwe ciphertext.
//
// Example:
//
// ```mlir
// %0 = tensor.extract %t[offsets...]
// : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
// ```
//
// becomes:
//
// ```mlir
// %1 = tensor.extract_slice %arg0
// [offsets...] [1..., lweDimension+1] [1...]
// : tensor<...xlweDimension+1,i64> to
// tensor<1...xlweDimension+1,i64>
// %0 = linalg.tensor_collapse_shape %0 [[...]] :
// tensor<1x1xlweDimension+1xi64> into tensor<lweDimension+1xi64>
// ```
//
// TODO: since they are a bug on lowering extract_slice with rank reduction we
// add a linalg.tensor_collapse_shape after the extract_slice without rank
// reduction. See
// https://github.com/zama-ai/concrete-compiler-internal/issues/396.
struct ExtractOpPattern
: public mlir::OpRewritePattern<mlir::tensor::ExtractOp> {
ExtractOpPattern(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<mlir::tensor::ExtractOp>(context, benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::tensor::ExtractOp extractOp,
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
auto lweResultTy =
extractOp.result()
.getType()
.dyn_cast_or_null<
mlir::concretelang::Concrete::LweCiphertextType>();
if (lweResultTy == nullptr) {
return mlir::failure();
}
auto newResultTy =
converter.convertType(lweResultTy).cast<mlir::RankedTensorType>();
auto rankOfResult = extractOp.indices().size() + 1;
// [min..., 0] for static_offsets ()
mlir::SmallVector<mlir::Attribute> staticOffsets(
rankOfResult,
rewriter.getI64IntegerAttr(std::numeric_limits<int64_t>::min()));
staticOffsets[staticOffsets.size() - 1] = rewriter.getI64IntegerAttr(0);
// [1..., lweDimension+1] for static_sizes
mlir::SmallVector<mlir::Attribute> staticSizes(
rankOfResult, rewriter.getI64IntegerAttr(1));
staticSizes[staticSizes.size() - 1] = rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 1));
// [1...] for static_strides
mlir::SmallVector<mlir::Attribute> staticStrides(
rankOfResult, rewriter.getI64IntegerAttr(1));
// replace tensor.extract_slice to the new one
mlir::SmallVector<int64_t> extractedSliceShape(
extractOp.indices().size() + 1, 0);
extractedSliceShape.reserve(extractOp.indices().size() + 1);
for (size_t i = 0; i < extractedSliceShape.size() - 1; i++) {
extractedSliceShape[i] = 1;
}
extractedSliceShape[extractedSliceShape.size() - 1] =
newResultTy.getDimSize(0);
auto extractedSliceType =
mlir::RankedTensorType::get(extractedSliceShape, rewriter.getI64Type());
auto extractedSlice = rewriter.create<mlir::tensor::ExtractSliceOp>(
extractOp.getLoc(), extractedSliceType, extractOp.tensor(),
extractOp.indices(), mlir::SmallVector<mlir::Value>{},
mlir::SmallVector<mlir::Value>{}, rewriter.getArrayAttr(staticOffsets),
rewriter.getArrayAttr(staticSizes),
rewriter.getArrayAttr(staticStrides));
mlir::ReassociationIndices reassociation;
for (int64_t i = 0; i < extractedSliceType.getRank(); i++) {
reassociation.push_back(i);
}
rewriter.replaceOpWithNewOp<mlir::linalg::TensorCollapseShapeOp>(
extractOp, newResultTy, extractedSlice,
mlir::SmallVector<mlir::ReassociationIndices>{reassociation});
return ::mlir::success();
};
};
// This rewrite pattern transforms any instance of
// `tensor.insert_slice` operators that operates on tensor of lwe ciphertext.
//
// Example:
//
// ```mlir
// %0 = tensor.insert_slice %arg1
// into %arg0[offsets...] [sizes...] [strides...]
// : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> into
// tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
// ```
//
// becomes:
//
// ```mlir
// %0 = tensor.insert_slice %arg1
// into %arg0[offsets..., 0] [sizes..., lweDimension+1] [strides..., 1]
// : tensor<...xlweDimension+1xi64> into
// tensor<...xlweDimension+1xi64>
// ```
struct InsertSliceOpPattern
: public mlir::OpRewritePattern<mlir::tensor::InsertSliceOp> {
InsertSliceOpPattern(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<mlir::tensor::InsertSliceOp>(context,
benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::tensor::InsertSliceOp insertSliceOp,
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
auto resultTy = insertSliceOp.result().getType();
auto newResultTy =
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
// add 0 to static_offsets
mlir::SmallVector<mlir::Attribute> staticOffsets;
staticOffsets.append(insertSliceOp.static_offsets().begin(),
insertSliceOp.static_offsets().end());
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
// add lweDimension+1 to static_sizes
mlir::SmallVector<mlir::Attribute> staticSizes;
staticSizes.append(insertSliceOp.static_sizes().begin(),
insertSliceOp.static_sizes().end());
staticSizes.push_back(rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 1)));
// add 1 to the strides
mlir::SmallVector<mlir::Attribute> staticStrides;
staticStrides.append(insertSliceOp.static_strides().begin(),
insertSliceOp.static_strides().end());
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
// replace tensor.insert_slice with the new one
rewriter.replaceOpWithNewOp<mlir::tensor::InsertSliceOp>(
insertSliceOp, newResultTy, insertSliceOp.source(),
insertSliceOp.dest(), insertSliceOp.offsets(), insertSliceOp.sizes(),
insertSliceOp.strides(), rewriter.getArrayAttr(staticOffsets),
rewriter.getArrayAttr(staticSizes),
rewriter.getArrayAttr(staticStrides));
return ::mlir::success();
};
};
// This rewrite pattern transforms any instance of
// `tensor.from_elements` operators that operates on tensor of lwe ciphertext.
//
// Example:
//
// ```mlir
// %0 = tensor.from_elements %e0, ..., %e(n-1)
// : tensor<Nx!Concrete.lwe_ciphertext<lweDim,p>>
// ```
//
// becomes:
//
// ```mlir
// %m = memref.alloc() : memref<NxlweDim+1xi64>
// %s0 = memref.subview %m[0, 0][1, lweDim+1][1, 1] : memref<lweDim+1xi64>
// %m0 = memref.buffer_cast %e0 : memref<lweDim+1xi64>
// memref.copy %m0, s0 : memref<lweDim+1xi64> to memref<lweDim+1xi64>
// ...
// %s(n-1) = memref.subview %m[(n-1), 0][1, lweDim+1][1, 1]
// : memref<lweDim+1xi64>
// %m(n-1) = memref.buffer_cast %e(n-1) : memref<lweDim+1xi64>
// memref.copy %e(n-1), s(n-1)
// : memref<lweDim+1xi64> to memref<lweDim+1xi64>
// %0 = memref.tensor_load %m : memref<NxlweDim+1xi64>
// ```
struct FromElementsOpPattern
: public mlir::OpRewritePattern<mlir::tensor::FromElementsOp> {
FromElementsOpPattern(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<mlir::tensor::FromElementsOp>(context,
benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::tensor::FromElementsOp fromElementsOp,
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
auto resultTy = fromElementsOp.result().getType();
if (converter.isLegal(resultTy)) {
return mlir::failure();
}
auto eltResultTy =
resultTy.cast<mlir::RankedTensorType>()
.getElementType()
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
auto newTensorResultTy =
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
auto newMemrefResultTy = mlir::MemRefType::get(
newTensorResultTy.getShape(), newTensorResultTy.getElementType());
// %m = memref.alloc() : memref<NxlweDim+1xi64>
auto mOp = rewriter.create<mlir::memref::AllocOp>(fromElementsOp.getLoc(),
newMemrefResultTy);
// for i = 0 to n-1
// %si = memref.subview %m[i, 0][1, lweDim+1][1, 1] : memref<lweDim+1xi64>
// %mi = memref.buffer_cast %ei : memref<lweDim+1xi64>
// memref.copy %mi, si : memref<lweDim+1xi64> to memref<lweDim+1xi64>
auto subviewResultTy = mlir::MemRefType::get(
{eltResultTy.getDimension() + 1}, newMemrefResultTy.getElementType());
auto offset = 0;
for (auto eiOp : fromElementsOp.elements()) {
mlir::SmallVector<mlir::Attribute, 2> staticOffsets{
rewriter.getI64IntegerAttr(offset), rewriter.getI64IntegerAttr(0)};
mlir::SmallVector<mlir::Attribute, 2> staticSizes{
rewriter.getI64IntegerAttr(1),
rewriter.getI64IntegerAttr(eltResultTy.getDimension() + 1)};
mlir::SmallVector<mlir::Attribute, 2> staticStrides{
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
auto siOp = rewriter.create<mlir::memref::SubViewOp>(
fromElementsOp.getLoc(), subviewResultTy, mOp, mlir::ValueRange{},
mlir::ValueRange{}, mlir::ValueRange{},
rewriter.getArrayAttr(staticOffsets),
rewriter.getArrayAttr(staticSizes),
rewriter.getArrayAttr(staticStrides));
auto miOp = rewriter.create<mlir::memref::BufferCastOp>(
fromElementsOp.getLoc(), subviewResultTy, eiOp);
rewriter.create<mlir::memref::CopyOp>(fromElementsOp.getLoc(), miOp,
siOp);
offset++;
}
// Go back to tensor world
// %0 = memref.tensor_load %m : memref<NxlweDim+1xi64>
rewriter.replaceOpWithNewOp<mlir::memref::TensorLoadOp>(fromElementsOp,
mOp);
return ::mlir::success();
};
};
// This template rewrite pattern transforms any instance of
// `ShapeOp` operators that operates on tensor of lwe ciphertext by adding the
// lwe size as a size of the tensor result and by adding a trivial reassociation
// at the end of the reassociations map.
//
// Example:
//
// ```mlir
// %0 = "ShapeOp" %arg0 [reassocations...]
// : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> into
// tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
// ```
//
// becomes:
//
// ```mlir
// %0 = "ShapeOp" %arg0 [reassociations..., [inRank or outRank]]
// : tensor<...xlweDimesion+1xi64> into
// tensor<...xlweDimesion+1xi64>
// ```
template <typename ShapeOp, bool inRank>
struct TensorShapeOpPattern : public mlir::OpRewritePattern<ShapeOp> {
TensorShapeOpPattern(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<ShapeOp>(context, benefit) {}
::mlir::LogicalResult
matchAndRewrite(ShapeOp shapeOp,
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
auto resultTy = shapeOp.result().getType();
auto newResultTy =
((mlir::Type)converter.convertType(resultTy)).cast<mlir::MemRefType>();
// add [rank] to reassociations
auto oldReassocs = shapeOp.getReassociationIndices();
mlir::SmallVector<mlir::ReassociationIndices> newReassocs;
newReassocs.append(oldReassocs.begin(), oldReassocs.end());
mlir::ReassociationIndices lweAssoc;
auto reassocTy =
((mlir::Type)converter.convertType(
(inRank ? shapeOp.src() : shapeOp.result()).getType()))
.cast<mlir::MemRefType>();
lweAssoc.push_back(reassocTy.getRank());
newReassocs.push_back(lweAssoc);
rewriter.replaceOpWithNewOp<ShapeOp>(shapeOp, newResultTy, shapeOp.src(),
newReassocs);
return ::mlir::success();
};
};
// Add the instantiated TensorShapeOpPattern rewrite pattern with the `ShapeOp`
// to the patterns set and populate the conversion target.
template <typename ShapeOp, bool inRank>
void insertTensorShapeOpPattern(mlir::MLIRContext &context,
mlir::RewritePatternSet &patterns,
mlir::ConversionTarget &target) {
patterns.insert<TensorShapeOpPattern<ShapeOp, inRank>>(&context);
target.addDynamicallyLegalOp<ShapeOp>([&](ShapeOp op) {
ConcreteToBConcreteTypeConverter converter;
return converter.isLegal(op.result().getType());
});
}
// This template rewrite pattern transforms any instance of
// `MemrefOp` operators that returns a memref of lwe ciphertext to the same
// operator but which returns the bufferized lwe ciphertext.
//
// Example:
//
// ```mlir
// %0 = "MemrefOp"(...) : ... -> memref<...x!Concrete.lwe_ciphertext<lweDim,p>>
// ```
//
// becomes:
//
// ```mlir
// %0 = "MemrefOp"(...) : ... -> memref<...xlweDim+1xi64>
// ```
template <typename MemrefOp>
struct MemrefOpPattern : public mlir::OpRewritePattern<MemrefOp> {
MemrefOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<MemrefOp>(context, benefit) {}
mlir::LogicalResult
matchAndRewrite(MemrefOp memrefOp,
mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
mlir::SmallVector<mlir::Type, 1> convertedTypes;
if (converter.convertTypes(memrefOp->getResultTypes(), convertedTypes)
.failed()) {
return mlir::failure();
}
rewriter.replaceOpWithNewOp<MemrefOp>(memrefOp, convertedTypes,
memrefOp->getOperands(),
memrefOp->getAttrs());
return ::mlir::success();
};
};
// Add the instantiated MemrefOpPattern rewrite pattern with the `MemrefOp`
// to the patterns set and populate the conversion target.
template <typename... MemrefOp>
void insertMemrefOpPattern(mlir::MLIRContext &context,
mlir::RewritePatternSet &patterns,
mlir::ConversionTarget &target) {
(void)std::initializer_list<int>{
0, (patterns.insert<MemrefOpPattern<MemrefOp>>(&context),
target.addDynamicallyLegalOp<MemrefOp>([&](MemrefOp op) {
ConcreteToBConcreteTypeConverter converter;
return converter.isLegal(op->getResultTypes());
}),
0)...};
}
// cc from Loops.cpp
static mlir::SmallVector<mlir::Value>
makeCanonicalAffineApplies(mlir::OpBuilder &b, mlir::Location loc,
mlir::AffineMap map,
mlir::ArrayRef<mlir::Value> vals) {
if (map.isEmpty())
return {};
assert(map.getNumInputs() == vals.size());
mlir::SmallVector<mlir::Value> res;
res.reserve(map.getNumResults());
auto dims = map.getNumDims();
for (auto e : map.getResults()) {
auto exprMap = mlir::AffineMap::get(dims, map.getNumSymbols(), e);
mlir::SmallVector<mlir::Value> operands(vals.begin(), vals.end());
canonicalizeMapAndOperands(&exprMap, &operands);
res.push_back(b.create<mlir::AffineApplyOp>(loc, exprMap, operands));
}
return res;
}
static std::pair<mlir::Value, mlir::Value>
makeOperandLoadOrSubview(mlir::OpBuilder &builder, mlir::Location loc,
mlir::ArrayRef<mlir::Value> allIvs,
mlir::linalg::LinalgOp linalgOp,
mlir::OpOperand *operand) {
ConcreteToBConcreteTypeConverter converter;
mlir::Value opVal = operand->get();
mlir::MemRefType opTy = opVal.getType().cast<mlir::MemRefType>();
if (auto lweType =
opTy.getElementType()
.dyn_cast_or_null<
mlir::concretelang::Concrete::LweCiphertextType>()) {
// For memref of ciphertexts operands create the inner memref
// subview to the ciphertext, and go back to the tensor type as BConcrete
// operators works with tensor.
// %op : memref<dim...xConcrete.lwe_ciphertext<lweDim,p>>
// %opInner = memref.subview %opInner[offsets...][1...][1,...]
// : memref<...xConcrete.lwe_ciphertext<lweDim,p>> to
// memref<Concrete.lwe_ciphertext<lweDim,p>>
auto tensorizedLweTy =
converter.convertType(lweType).cast<mlir::RankedTensorType>();
auto subviewResultTy = mlir::MemRefType::get(
tensorizedLweTy.getShape(), tensorizedLweTy.getElementType());
auto offsets = makeCanonicalAffineApplies(
builder, loc, linalgOp.getTiedIndexingMap(operand), allIvs);
mlir::SmallVector<mlir::Attribute> staticOffsets(
opTy.getRank(),
builder.getI64IntegerAttr(std::numeric_limits<int64_t>::min()));
mlir::SmallVector<mlir::Attribute> staticSizes(
opTy.getRank(), builder.getI64IntegerAttr(1));
mlir::SmallVector<mlir::Attribute> staticStrides(
opTy.getRank(), builder.getI64IntegerAttr(1));
auto subViewOp = builder.create<mlir::memref::SubViewOp>(
loc, subviewResultTy, opVal, offsets, mlir::ValueRange{},
mlir::ValueRange{}, builder.getArrayAttr(staticOffsets),
builder.getArrayAttr(staticSizes), builder.getArrayAttr(staticStrides));
return std::pair<mlir::Value, mlir::Value>(
subViewOp, builder.create<mlir::memref::TensorLoadOp>(loc, subViewOp));
} else {
// For memref of non ciphertexts load the value from the memref.
// with %op : memref<dim...xip>
// %opInner = memref.load %op[offsets...] : memref<dim...xip>
auto offsets = makeCanonicalAffineApplies(
builder, loc, linalgOp.getTiedIndexingMap(operand), allIvs);
return std::pair<mlir::Value, mlir::Value>(
nullptr,
builder.create<mlir::memref::LoadOp>(loc, operand->get(), offsets));
}
}
static void
inlineRegionAndEmitTensorStore(mlir::OpBuilder &builder, mlir::Location loc,
mlir::linalg::LinalgOp linalgOp,
llvm::ArrayRef<mlir::Value> indexedValues,
mlir::ValueRange outputBuffers) {
// Clone the block with the new operands
auto &block = linalgOp->getRegion(0).front();
mlir::BlockAndValueMapping map;
map.map(block.getArguments(), indexedValues);
for (auto &op : block.without_terminator()) {
auto *newOp = builder.clone(op, map);
map.map(op.getResults(), newOp->getResults());
}
// Create memref.tensor_store operation for each terminator operands
auto *terminator = block.getTerminator();
for (mlir::OpOperand &operand : terminator->getOpOperands()) {
mlir::Value toStore = map.lookupOrDefault(operand.get());
builder.create<mlir::memref::TensorStoreOp>(
loc, toStore, outputBuffers[operand.getOperandNumber()]);
}
}
template <typename LoopType>
class LinalgRewritePattern
: public mlir::OpInterfaceConversionPattern<mlir::linalg::LinalgOp> {
public:
using mlir::OpInterfaceConversionPattern<
mlir::linalg::LinalgOp>::OpInterfaceConversionPattern;
mlir::LogicalResult
matchAndRewrite(mlir::linalg::LinalgOp linalgOp,
mlir::ArrayRef<mlir::Value> operands,
mlir::ConversionPatternRewriter &rewriter) const override {
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
auto iteratorTypes =
llvm::to_vector<4>(linalgOp.iterator_types().getValue());
mlir::SmallVector<mlir::Value> allIvs;
mlir::linalg::GenerateLoopNest<LoopType>::doit(
rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange ivs,
mlir::ValueRange operandValuesToUse) -> mlir::scf::ValueVector {
// Keep indexed values to replace the linalg.generic block arguments
// by them
mlir::SmallVector<mlir::Value> indexedValues;
indexedValues.reserve(linalgOp.getNumInputsAndOutputs());
assert(
operandValuesToUse == linalgOp->getOperands() &&
"expect operands are captured and not passed by loop argument");
allIvs.append(ivs.begin(), ivs.end());
// For all input operands create the inner operand
for (mlir::OpOperand *inputOperand : linalgOp.getInputOperands()) {
auto innerOperand = makeOperandLoadOrSubview(
builder, loc, allIvs, linalgOp, inputOperand);
indexedValues.push_back(innerOperand.second);
}
// For all output operands create the inner operand
assert(linalgOp.getOutputOperands() ==
linalgOp.getOutputBufferOperands() &&
"expect only memref as output operands");
mlir::SmallVector<mlir::Value> outputBuffers;
for (mlir::OpOperand *outputOperand : linalgOp.getOutputOperands()) {
auto innerOperand = makeOperandLoadOrSubview(
builder, loc, allIvs, linalgOp, outputOperand);
indexedValues.push_back(innerOperand.second);
assert(innerOperand.first != nullptr &&
"Expected a memref subview as output buffer");
outputBuffers.push_back(innerOperand.first);
}
// Finally inline the linalgOp region
inlineRegionAndEmitTensorStore(builder, loc, linalgOp, indexedValues,
outputBuffers);
return mlir::scf::ValueVector{};
});
rewriter.eraseOp(linalgOp);
return mlir::success();
};
};
void ConcreteToBConcretePass::runOnOperation() {
auto op = this->getOperation();
// First of all we transform LinalgOp that work on tensor of ciphertext to
// work on memref.
{
mlir::ConversionTarget target(getContext());
mlir::BufferizeTypeConverter converter;
// Mark all Standard operations legal.
target
.addLegalDialect<mlir::arith::ArithmeticDialect, mlir::AffineDialect,
mlir::memref::MemRefDialect, mlir::StandardOpsDialect,
mlir::tensor::TensorDialect>();
// Mark all Linalg operations illegal as long as they work on encrypted
// tensors.
target.addDynamicallyLegalOp<mlir::linalg::GenericOp, mlir::linalg::YieldOp,
mlir::linalg::CopyOp>(
[&](mlir::Operation *op) { return converter.isLegal(op); });
mlir::RewritePatternSet patterns(&getContext());
mlir::linalg::populateLinalgBufferizePatterns(converter, patterns);
if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
signalPassFailure();
return;
}
}
// Then convert ciphertext to tensor or add a dimension to tensor of
// ciphertext and memref of ciphertext
{
mlir::ConversionTarget target(getContext());
ConcreteToBConcreteTypeConverter converter;
mlir::OwningRewritePatternList patterns(&getContext());
// All BConcrete ops are legal after the conversion
target.addLegalDialect<mlir::concretelang::BConcrete::BConcreteDialect>();
// Add Concrete ops are illegal after the conversion unless those which are
// explicitly marked as legal (more or less operators that didn't work on
// ciphertexts)
target.addIllegalDialect<mlir::concretelang::Concrete::ConcreteDialect>();
target.addLegalOp<mlir::concretelang::Concrete::EncodeIntOp>();
target.addLegalOp<mlir::concretelang::Concrete::GlweFromTable>();
target.addLegalOp<mlir::concretelang::Concrete::IntToCleartextOp>();
// Add patterns to convert the zero ops to tensor.generate
patterns
.insert<ZeroOpPattern<mlir::concretelang::Concrete::ZeroTensorLWEOp>,
ZeroOpPattern<mlir::concretelang::Concrete::ZeroLWEOp>>(
&getContext());
target.addLegalOp<mlir::tensor::GenerateOp, mlir::tensor::YieldOp>();
// Add patterns to trivialy convert Concrete op to the equivalent
// BConcrete op
target.addLegalOp<mlir::linalg::InitTensorOp>();
patterns.insert<
LowToBConcrete<mlir::concretelang::Concrete::AddLweCiphertextsOp,
mlir::concretelang::BConcrete::AddLweBuffersOp>,
LowToBConcrete<
mlir::concretelang::Concrete::AddPlaintextLweCiphertextOp,
mlir::concretelang::BConcrete::AddPlaintextLweBufferOp>,
LowToBConcrete<
mlir::concretelang::Concrete::MulCleartextLweCiphertextOp,
mlir::concretelang::BConcrete::MulCleartextLweBufferOp>,
LowToBConcrete<
mlir::concretelang::Concrete::MulCleartextLweCiphertextOp,
mlir::concretelang::BConcrete::MulCleartextLweBufferOp>,
LowToBConcrete<mlir::concretelang::Concrete::NegateLweCiphertextOp,
mlir::concretelang::BConcrete::NegateLweBufferOp>,
LowToBConcrete<mlir::concretelang::Concrete::KeySwitchLweOp,
mlir::concretelang::BConcrete::KeySwitchLweBufferOp>,
LowToBConcrete<mlir::concretelang::Concrete::BootstrapLweOp,
mlir::concretelang::BConcrete::BootstrapLweBufferOp>>(
&getContext());
// Add patterns to rewrite tensor operators that works on encrypted tensors
patterns.insert<ExtractSliceOpPattern, ExtractOpPattern,
InsertSliceOpPattern, FromElementsOpPattern>(&getContext());
target.addDynamicallyLegalOp<
mlir::tensor::ExtractSliceOp, mlir::tensor::ExtractOp,
mlir::tensor::InsertSliceOp, mlir::tensor::FromElementsOp>(
[&](mlir::Operation *op) {
return converter.isLegal(op->getResult(0).getType());
});
target.addLegalOp<mlir::memref::CopyOp,
mlir::linalg::TensorCollapseShapeOp>();
// Add patterns to rewrite some of memref ops that was introduced by the
// linalg bufferization of encrypted tensor (first conversion of this pass)
insertTensorShapeOpPattern<mlir::memref::ExpandShapeOp, true>(
getContext(), patterns, target);
insertTensorShapeOpPattern<mlir::memref::CollapseShapeOp, false>(
getContext(), patterns, target);
// Add patterns to rewrite linalg op to nested loops with views on
// ciphertexts
patterns.insert<LinalgRewritePattern<mlir::scf::ForOp>>(converter,
&getContext());
target.addLegalOp<mlir::arith::ConstantOp, mlir::scf::ForOp,
mlir::scf::YieldOp, mlir::AffineApplyOp,
mlir::memref::SubViewOp, mlir::memref::LoadOp,
mlir::memref::TensorStoreOp>();
// Add patterns to do the conversion of func
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
return converter.isSignatureLegal(funcOp.getType()) &&
converter.isLegal(&funcOp.getBody());
});
// Add patterns to convert some memref operators that is generated by
// previous step
insertMemrefOpPattern<mlir::memref::AllocOp, mlir::memref::BufferCastOp,
mlir::memref::TensorLoadOp>(getContext(), patterns,
target);
// Conversion of RT Dialect Ops
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DataflowTaskOp>>(patterns.getContext(),
converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DataflowTaskOp>(target, converter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))
.failed()) {
this->signalPassFailure();
}
}
}
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertConcreteToBConcretePass() {
return std::make_unique<ConcreteToBConcretePass>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -264,16 +264,6 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
if (target == Target::CONCRETE)
return std::move(res);
// Concrete -> Canonical dialects
if (mlir::concretelang::pipeline::lowerConcreteToStd(mlirContext, module,
enablePass)
.failed()) {
return errorDiag(
"Lowering from Concrete to canonical MLIR dialects failed");
}
if (target == Target::STD)
return std::move(res);
// Generate client parameters if requested
if (this->generateClientParameters) {
if (!this->clientParametersFuncName.hasValue()) {
@@ -304,6 +294,28 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
}
}
// Concrete -> BConcrete
if (mlir::concretelang::pipeline::lowerConcreteToBConcrete(
mlirContext, module, this->enablePass)
.failed()) {
return StreamStringError(
"Lowering from Concrete to Bufferized Concrete failed");
}
if (target == Target::BCONCRETE) {
return std::move(res);
}
// BConcrete -> Canonical dialects
if (mlir::concretelang::pipeline::lowerBConcreteToStd(mlirContext, module,
enablePass)
.failed()) {
return errorDiag(
"Lowering from Bufferized Concrete to canonical MLIR dialects failed");
}
if (target == Target::STD)
return std::move(res);
// MLIR canonical dialects -> LLVM Dialect
if (mlir::concretelang::pipeline::lowerStdToLLVMDialect(
mlirContext, module, enablePass,

View File

@@ -181,10 +181,25 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
}
mlir::LogicalResult
lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("ConcreteToStd", pm, context);
pipelinePrinting("ConcreteToBConcrete", pm, context);
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertConcreteToBConcretePass(),
enablePass);
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("BConcreteToStd", pm, context);
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertBConcreteToBConcreteCAPIPass(),
enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertConcreteToConcreteCAPIPass(),
enablePass);

View File

@@ -41,6 +41,7 @@ enum Action {
DUMP_FHE,
DUMP_TFHE,
DUMP_CONCRETE,
DUMP_BCONCRETE,
DUMP_STD,
DUMP_LLVM_DIALECT,
DUMP_LLVM_IR,
@@ -101,6 +102,9 @@ static llvm::cl::opt<enum Action> action(
"Lower to TFHE and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_CONCRETE, "dump-concrete",
"Lower to Concrete and dump result")),
llvm::cl::values(
clEnumValN(Action::DUMP_BCONCRETE, "dump-bconcrete",
"Lower to Bufferized Concrete and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_STD, "dump-std",
"Lower to std and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_LLVM_DIALECT, "dump-llvm-dialect",
@@ -324,6 +328,9 @@ mlir::LogicalResult processInputBuffer(
case Action::DUMP_CONCRETE:
target = mlir::concretelang::CompilerEngine::Target::CONCRETE;
break;
case Action::DUMP_BCONCRETE:
target = mlir::concretelang::CompilerEngine::Target::BCONCRETE;
break;
case Action::DUMP_STD:
target = mlir::concretelang::CompilerEngine::Target::STD;
break;

View File

@@ -0,0 +1,14 @@
// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s
// CHECK: func @add_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>, %arg2: !Concrete.context) -> tensor<2049xi64>
func @add_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>) -> tensor<2049xi64> {
// CHECK-NEXT: %0 = linalg.init_tensor [2049] : tensor<2049xi64>
// CHECK-NEXT: %1 = tensor.cast %0 : tensor<2049xi64> to tensor<?xi64>
// CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<2049xi64> to tensor<?xi64>
// CHECK-NEXT: %3 = tensor.cast %arg1 : tensor<2049xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_add_lwe_ciphertexts_u64(%1, %2, %3) : (tensor<?xi64>, tensor<?xi64>, tensor<?xi64>) -> ()
// CHECK-NEXT: return %0 : tensor<2049xi64>
%0 = linalg.init_tensor [2049] : tensor<2049xi64>
"BConcrete.add_lwe_buffer"(%0, %arg0, %arg1) : (tensor<2049xi64>, tensor<2049xi64>, tensor<2049xi64>) -> ()
return %0 : tensor<2049xi64>
}

View File

@@ -0,0 +1,37 @@
// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe_const_int(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64>
func @add_glwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
// CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64
// CHECK-NEXT: %c56_i64 = arith.constant 56 : i64
// CHECK-NEXT: %1 = arith.shli %0, %c56_i64 : i64
// CHECK-NEXT: %2 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %3 = tensor.cast %2 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %4 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_add_plaintext_lwe_ciphertext_u64(%3, %4, %1) : (tensor<?xi64>, tensor<?xi64>, i64) -> ()
// CHECK-NEXT: return %2 : tensor<1025xi64>
%0 = arith.constant 1 : i8
%1 = "Concrete.encode_int"(%0) : (i8) -> !Concrete.plaintext<8>
%2 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.add_plaintext_lwe_buffer"(%2, %arg0, %1) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<8>) -> ()
return %2 : tensor<1025xi64>
}
// CHECK-LABEL: func @add_glwe_int(%arg0: tensor<1025xi64>, %arg1: i5, %arg2: !Concrete.context) -> tensor<1025xi64>
func @add_glwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> {
// CHECK-NEXT: %0 = arith.extui %arg1 : i5 to i64
// CHECK-NEXT: %c59_i64 = arith.constant 59 : i64
// CHECK-NEXT: %1 = arith.shli %0, %c59_i64 : i64
// CHECK-NEXT: %2 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %3 = tensor.cast %2 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %4 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_add_plaintext_lwe_ciphertext_u64(%3, %4, %1) : (tensor<?xi64>, tensor<?xi64>, i64) -> ()
// CHECK-NEXT: return %2 : tensor<1025xi64>
%0 = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
%1 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.add_plaintext_lwe_buffer"(%1, %arg0, %0) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<5>) -> ()
return %1 : tensor<1025xi64>
}

View File

@@ -0,0 +1,15 @@
// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s
// CHECK: func @bootstrap_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.glwe_ciphertext, %arg2: !Concrete.context) -> tensor<1025xi64> {
// CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %1 = call @get_bootstrap_key(%arg2) : (!Concrete.context) -> !Concrete.lwe_bootstrap_key
// CHECK-NEXT: %2 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_bootstrap_lwe_u64(%1, %2, %3, %arg1) : (!Concrete.lwe_bootstrap_key, tensor<?xi64>, tensor<?xi64>, !Concrete.glwe_ciphertext) -> ()
// CHECK-NEXT: return %0 : tensor<1025xi64>
// CHECK-NEXT: }
func @bootstrap_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.glwe_ciphertext) -> tensor<1025xi64> {
%0 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.bootstrap_lwe_buffer"(%0, %arg0, %arg1) {baseLog = 2 : i32, glweDimension = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.glwe_ciphertext) -> ()
return %0 : tensor<1025xi64>
}

View File

@@ -0,0 +1,15 @@
// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s
//CHECK: func @keyswitch_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64> {
//CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64>
//CHECK-NEXT: %1 = call @get_keyswitch_key(%arg1) : (!Concrete.context) -> !Concrete.lwe_key_switch_key
//CHECK-NEXT: %2 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
//CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
//CHECK-NEXT: call @memref_keyswitch_lwe_u64(%1, %2, %3) : (!Concrete.lwe_key_switch_key, tensor<?xi64>, tensor<?xi64>) -> ()
//CHECK-NEXT: return %0 : tensor<1025xi64>
//CHECK-NEXT: }
func @keyswitch_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
%0 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.keyswitch_lwe_buffer"(%0, %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 1 : i32} : (tensor<1025xi64>, tensor<1025xi64>) -> ()
return %0 : tensor<1025xi64>
}

View File

@@ -0,0 +1,33 @@
// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_lwe_const_int(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64>
func @mul_lwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
// CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64
// CHECK-NEXT: %1 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %2 = tensor.cast %1 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_mul_cleartext_lwe_ciphertext_u64(%2, %3, %0) : (tensor<?xi64>, tensor<?xi64>, i64) -> ()
// CHECK-NEXT: return %1 : tensor<1025xi64>
%c1_i8 = arith.constant 1 : i8
%1 = "Concrete.int_to_cleartext"(%c1_i8) : (i8) -> !Concrete.cleartext<8>
%2 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.mul_cleartext_lwe_buffer"(%2, %arg0, %1) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.cleartext<8>) -> ()
return %2 : tensor<1025xi64>
}
// CHECK-LABEL: func @mul_lwe_int(%arg0: tensor<1025xi64>, %arg1: i5, %arg2: !Concrete.context) -> tensor<1025xi64>
func @mul_lwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> {
// CHECK-NEXT: %0 = arith.extui %arg1 : i5 to i64
// CHECK-NEXT: %1 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %2 = tensor.cast %1 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_mul_cleartext_lwe_ciphertext_u64(%2, %3, %0) : (tensor<?xi64>, tensor<?xi64>, i64) -> ()
// CHECK-NEXT: return %1 : tensor<1025xi64>
%0 = "Concrete.int_to_cleartext"(%arg1) : (i5) -> !Concrete.cleartext<5>
%1 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.mul_cleartext_lwe_buffer"(%1, %arg0, %0) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.cleartext<5>) -> ()
return %1 : tensor<1025xi64>
}

View File

@@ -0,0 +1,13 @@
// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: func @neg_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64> {
func @neg_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
// CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %1 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_negate_lwe_ciphertext_u64(%1, %2) : (tensor<?xi64>, tensor<?xi64>) -> ()
// CHECK-NEXT: return %0 : tensor<1025xi64>
%0 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.negate_lwe_buffer"(%0, %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> ()
return %0 : tensor<1025xi64>
}

View File

@@ -0,0 +1,47 @@
// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: func @sub_const_int_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64> {
func @sub_const_int_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
// CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %1 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_negate_lwe_ciphertext_u64(%1, %2) : (tensor<?xi64>, tensor<?xi64>) -> ()
// CHECK-NEXT: %3 = arith.extui %c1_i8 : i8 to i64
// CHECK-NEXT: %c56_i64 = arith.constant 56 : i64
// CHECK-NEXT: %4 = arith.shli %3, %c56_i64 : i64
// CHECK-NEXT: %5 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %6 = tensor.cast %5 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %7 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_add_plaintext_lwe_ciphertext_u64(%6, %7, %4) : (tensor<?xi64>, tensor<?xi64>, i64) -> ()
// CHECK-NEXT: return %5 : tensor<1025xi64>
%0 = arith.constant 1 : i8
%1 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.negate_lwe_buffer"(%1, %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> ()
%2 = "Concrete.encode_int"(%0) : (i8) -> !Concrete.plaintext<8>
%3 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.add_plaintext_lwe_buffer"(%3, %1, %2) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<8>) -> ()
return %3 : tensor<1025xi64>
}
// CHECK-LABEL: func @sub_int_lwe(%arg0: tensor<1025xi64>, %arg1: i5, %arg2: !Concrete.context) -> tensor<1025xi64> {
func @sub_int_lwe(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> {
// CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %1 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_negate_lwe_ciphertext_u64(%1, %2) : (tensor<?xi64>, tensor<?xi64>) -> ()
// CHECK-NEXT: %3 = arith.extui %arg1 : i5 to i64
// CHECK-NEXT: %c59_i64 = arith.constant 59 : i64
// CHECK-NEXT: %4 = arith.shli %3, %c59_i64 : i64
// CHECK-NEXT: %5 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %6 = tensor.cast %5 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %7 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_add_plaintext_lwe_ciphertext_u64(%6, %7, %4) : (tensor<?xi64>, tensor<?xi64>, i64) -> ()
// CHECK-NEXT: return %5 : tensor<1025xi64>
%0 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.negate_lwe_buffer"(%0, %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> ()
%1 = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
%2 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.add_plaintext_lwe_buffer"(%2, %0, %1) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<5>) -> ()
return %2 : tensor<1025xi64>
}

View File

@@ -0,0 +1,10 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>) -> tensor<2049xi64>
func @add_glwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V1:.*]] = linalg.init_tensor [2049] : tensor<2049xi64>
// CHECK-NEXT: "BConcrete.add_lwe_buffer"(%[[V1]], %arg0, %arg1) : (tensor<2049xi64>, tensor<2049xi64>, tensor<2049xi64>) -> ()
// CHECK-NEXT: return %[[V1]] : tensor<2049xi64>
%0 = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
return %0 : !Concrete.lwe_ciphertext<2048,7>
}

View File

@@ -0,0 +1,25 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64>
func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V2:.*]] = "Concrete.encode_int"(%[[V1]]) : (i8) -> !Concrete.plaintext<8>
// CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.add_plaintext_lwe_buffer"(%1, %arg0, %0) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<8>) -> ()
// CHECK-NEXT: return %[[V3]] : tensor<1025xi64>
%0 = arith.constant 1 : i8
%1 = "Concrete.encode_int"(%0) : (i8) -> !Concrete.plaintext<8>
%2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %1) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.plaintext<8>) -> !Concrete.lwe_ciphertext<1024,7>
return %2 : !Concrete.lwe_ciphertext<1024,7>
}
// CHECK-LABEL: func @add_glwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64>
func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
// CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.add_plaintext_lwe_buffer"(%[[V2:.*]], %arg0, %[[V1:.*]]) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<5>) -> ()
// CHECK-NEXT: return %[[V2]] : tensor<1025xi64>
%0 = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
%1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<1024,4>
return %1 : !Concrete.lwe_ciphertext<1024,4>
}

View File

@@ -0,0 +1,15 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: tensor<1025xi64>, %arg1: tensor<16xi64>) -> tensor<1025xi64>
func @apply_lookup_table(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !Concrete.lwe_ciphertext<1024,4> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
// CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [601] : tensor<601xi64>
// CHECK-NEXT: "BConcrete.keyswitch_lwe_buffer"(%[[V2]], %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<601xi64>, tensor<1025xi64>) -> ()
// CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3]], %[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<601xi64>, !Concrete.glwe_ciphertext) -> ()
// CHECK-NEXT: return %[[V3]] : tensor<1025xi64>
%0 = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
%1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<600,4>
%2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<1024,4>
return %2 : !Concrete.lwe_ciphertext<1024,4>
}

View File

@@ -0,0 +1,17 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: tensor<2049xi64>) -> tensor<2049xi64>
func @apply_lookup_table_cst(%arg0: !Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<2048,4> {
// CHECK-NEXT: %[[TABLE:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
// CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%[[TABLE:.*]]) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
// CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [601] : tensor<601xi64>
// CHECK-NEXT: "BConcrete.keyswitch_lwe_buffer"([[V2:.*]], %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<601xi64>, tensor<2049xi64>) -> ()
// CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [2049] : tensor<2049xi64>
// CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3:.*]], %[[V2:.*]], %[[V1:.*]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (tensor<2049xi64>, tensor<601xi64>, !Concrete.glwe_ciphertext) -> ()
// CHECK-NEXT: return %[[V3]] : tensor<2049xi64>
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
%0 = "Concrete.glwe_from_table"(%tlu) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
%1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<600,4>
%2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,4>
return %2 : !Concrete.lwe_ciphertext<2048,4>
}

View File

@@ -0,0 +1,8 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
// CHECK: func @identity(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
// CHECK-NEXT: return %arg0 : tensor<1025xi64>
// CHECK-NEXT: }
func @identity(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
return %arg0 : !Concrete.lwe_ciphertext<1024,7>
}

View File

@@ -0,0 +1,25 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_lwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64>
func @mul_lwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V2:.*]] = "Concrete.int_to_cleartext"(%c1_i8) : (i8) -> !Concrete.cleartext<8>
// CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.mul_cleartext_lwe_buffer"(%[[V3]], %arg0, %[[V2]]) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.cleartext<8>) -> ()
// CHECK-NEXT: return %[[V3]] : tensor<1025xi64>
%0 = arith.constant 1 : i8
%1 = "Concrete.int_to_cleartext"(%0) : (i8) -> !Concrete.cleartext<8>
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.cleartext<8>) -> !Concrete.lwe_ciphertext<1024,7>
return %2 : !Concrete.lwe_ciphertext<1024,7>
}
// CHECK-LABEL: func @mul_lwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64>
func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.int_to_cleartext"(%arg1) : (i5) -> !Concrete.cleartext<5>
// CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.mul_cleartext_lwe_buffer"(%[[V2]], %arg0, %[[V1]]) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.cleartext<5>) -> ()
// CHECK-NEXT: return %[[V2]] : tensor<1025xi64>
%0 = "Concrete.int_to_cleartext"(%arg1) : (i5) -> !Concrete.cleartext<5>
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.cleartext<5>) -> !Concrete.lwe_ciphertext<1024,4>
return %1 : !Concrete.lwe_ciphertext<1024,4>
}

View File

@@ -0,0 +1,10 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
// CHECK-LABEL: func @neg_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64>
func @neg_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> {
// CHECK-NEXT: %[[V1:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.negate_lwe_buffer"(%[[V1]], %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> ()
// CHECK-NEXT: return %[[V1]] : tensor<1025xi64>
%0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4>
return %0 : !Concrete.lwe_ciphertext<1024,4>
}

View File

@@ -0,0 +1,32 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
// CHECK-LABEL: func @sub_const_int_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64>
func @sub_const_int_lwe(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.negate_lwe_buffer"(%[[V2]], %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> ()
// CHECK-NEXT: %[[V3:.*]] = "Concrete.encode_int"(%[[V1]]) : (i8) -> !Concrete.plaintext<8>
// CHECK-NEXT: %[[V4:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.add_plaintext_lwe_buffer"(%[[V4]], %[[V2]], %[[V3]]) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<8>) -> ()
// CHECK-NEXT: return %[[V4]] : tensor<1025xi64>
%0 = arith.constant 1 : i8
%1 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7>
%2 = "Concrete.encode_int"(%0) : (i8) -> !Concrete.plaintext<8>
%3 = "Concrete.add_plaintext_lwe_ciphertext"(%1, %2) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.plaintext<8>) -> !Concrete.lwe_ciphertext<1024,7>
return %3 : !Concrete.lwe_ciphertext<1024,7>
}
// CHECK-LABEL: func @sub_int_lwe(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64>
func @sub_int_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
// CHECK-NEXT: %[[V1:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.negate_lwe_buffer"(%[[V1]], %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> ()
// CHECK-NEXT: %[[V2:.*]] = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
// CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.add_plaintext_lwe_buffer"(%[[V3]], %[[V1]], %[[V2]]) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<5>) -> ()
// CHECK-NEXT: return %[[V3]] : tensor<1025xi64>
%0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4>
%1 = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
%2 = "Concrete.add_plaintext_lwe_ciphertext"(%0, %1) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<1024,4>
return %2 : !Concrete.lwe_ciphertext<1024,4>
}

View File

@@ -0,0 +1,7 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
// CHECK: func @tensor_identity(%arg0: tensor<2x3x4x1025xi64>) -> tensor<2x3x4x1025xi64> {
// CHECK-NEXT: return %arg0 : tensor<2x3x4x1025xi64>
// CHECK-NEXT: }
func @tensor_identity(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>> {
return %arg0 : tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>>
}