enhance(compiler): Lower from Concrete to BConcrete and BConcrete to C API call
committed by Quentin Bourgerie
parent b3368027d0
commit 626493dda7
@@ -0,0 +1,22 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.

#ifndef CONCRETELANG_CONVERSION_BCONCRETETOBCONCRETECAPI_PASS_H_
#define CONCRETELANG_CONVERSION_BCONCRETETOBCONCRETECAPI_PASS_H_

#include "mlir/Pass/Pass.h"

#include "concretelang/Conversion/Utils/GlobalFHEContext.h"

namespace mlir {
namespace concretelang {
/// Create a pass to convert `BConcrete` operators to function calls to the
/// `BConcrete C API`.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertBConcreteToBConcreteCAPIPass();
} // namespace concretelang
} // namespace mlir

#endif
@@ -0,0 +1,18 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.

#ifndef CONCRETELANG_CONVERSION_CONCRETETOBCONCRETE_PASS_H_
#define CONCRETELANG_CONVERSION_CONCRETETOBCONCRETE_PASS_H_

#include "mlir/Pass/Pass.h"

namespace mlir {
namespace concretelang {
/// Create a pass to convert the `Concrete` dialect to the `BConcrete` dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertConcreteToBConcretePass();
} // namespace concretelang
} // namespace mlir

#endif
@@ -11,6 +11,8 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"

#include "concretelang/Conversion/BConcreteToBConcreteCAPI/Pass.h"
#include "concretelang/Conversion/ConcreteToBConcrete/Pass.h"
#include "concretelang/Conversion/ConcreteToConcreteCAPI/Pass.h"
#include "concretelang/Conversion/ConcreteUnparametrize/Pass.h"
#include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h"
@@ -18,6 +20,7 @@
#include "concretelang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
#include "concretelang/Conversion/TFHEGlobalParametrization/Pass.h"
#include "concretelang/Conversion/TFHEToConcrete/Pass.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"

@@ -29,7 +29,15 @@ def TFHEToConcrete : Pass<"tfhe-to-concrete", "mlir::ModuleOp"> {
  let description = [{ Lowers operations from the TFHE dialect to Concrete }];
  let constructor = "mlir::concretelang::createConvertTFHEToConcretePass()";
  let options = [];
  let dependentDialects = ["mlir::linalg::LinalgDialect"];
  let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::TFHE::TFHEDialect"];
}

def ConcreteToBConcrete : Pass<"concrete-to-bconcrete", "mlir::ModuleOp"> {
  let summary = "Lowers operations from the Concrete dialect to Bufferized Concrete";
  let description = [{ Lowers operations from the Concrete dialect to Bufferized Concrete }];
  let constructor = "mlir::concretelang::createConvertConcreteToBConcretePass()";
  let options = [];
  let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::Concrete::ConcreteDialect", "mlir::concretelang::BConcrete::BConcreteDialect"];
}

def ConcreteToConcreteCAPI : Pass<"concrete-to-concrete-c-api", "mlir::ModuleOp"> {
@@ -38,6 +46,12 @@ def ConcreteToConcreteCAPI : Pass<"concrete-to-concrete-c-api", "mlir::ModuleOp"
  let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
}

def BConcreteToBConcreteCAPI : Pass<"bconcrete-to-bconcrete-c-api", "mlir::ModuleOp"> {
  let summary = "Lowers operations from the Bufferized Concrete dialect to std with function calls to the Bufferized Concrete C API";
  let constructor = "mlir::concretelang::createConvertBConcreteToBConcreteCAPIPass()";
  let dependentDialects = ["mlir::concretelang::BConcrete::BConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
}
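
// A minimal usage sketch (assumption: the pass is driven through an
// mlir-opt-style tool, here a hypothetical `concretelang-opt`): the
// Pass<"bconcrete-to-bconcrete-c-api", ...> argument above becomes the
// command-line flag that runs this pass on a module:
//
//   concretelang-opt --bconcrete-to-bconcrete-c-api input.mlir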

def ConcreteUnparametrize : Pass<"concrete-unparametrize", "mlir::ModuleOp"> {
  let summary = "Unparametrize Concrete types and remove unrealized_conversion_cast";
  let constructor = "mlir::concretelang::createConvertConcreteToConcreteCAPIPass()";

@@ -119,6 +119,10 @@ public:
  // operations
  CONCRETE,

  // Read sources and lower all FHE, TFHE and Concrete operations to BConcrete
  // operations
  BCONCRETE,

  // Read sources and lower all FHE, TFHE and Concrete
  // operations to canonical MLIR dialects. Cryptographic operations
  // are lowered to invocations of the concrete library.

@@ -43,8 +43,12 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
                    std::function<bool(mlir::Pass *)> enablePass);

mlir::LogicalResult
lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
                   std::function<bool(mlir::Pass *)> enablePass);
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
                         std::function<bool(mlir::Pass *)> enablePass);

mlir::LogicalResult
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
                    std::function<bool(mlir::Pass *)> enablePass);

mlir::LogicalResult
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,

@@ -0,0 +1,593 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"

#include "concretelang/Conversion/Passes.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"

namespace {
class BConcreteToBConcreteCAPITypeConverter : public mlir::TypeConverter {

public:
  BConcreteToBConcreteCAPITypeConverter() {
    addConversion([](mlir::Type type) { return type; });
    addConversion([&](mlir::concretelang::Concrete::PlaintextType type) {
      return mlir::IntegerType::get(type.getContext(), 64);
    });
    addConversion([&](mlir::concretelang::Concrete::CleartextType type) {
      return mlir::IntegerType::get(type.getContext(), 64);
    });
  }
};

mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
                                             mlir::RewriterBase &rewriter,
                                             llvm::StringRef funcName,
                                             mlir::FunctionType funcType) {
  // Look for the `funcName` operation
  auto module = mlir::SymbolTable::getNearestSymbolTable(op);
  auto opFunc = mlir::dyn_cast_or_null<mlir::SymbolOpInterface>(
      mlir::SymbolTable::lookupSymbolIn(module, funcName));
  if (!opFunc) {
    // Insert the forward declaration of funcName
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&module->getRegion(0).front());

    opFunc = rewriter.create<mlir::FuncOp>(rewriter.getUnknownLoc(), funcName,
                                           funcType);
    opFunc.setPrivate();
  } else {
    // Check that `funcName` is indeed a private function
    if (!opFunc.isPrivate()) {
      op->emitError() << "the function \"" << funcName
                      << "\" conflicts with the concrete C API, please rename";
      return mlir::failure();
    }
  }
  assert(mlir::SymbolTable::lookupSymbolIn(module, funcName)
             ->template hasTrait<mlir::OpTrait::FunctionLike>());
  return mlir::success();
}
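
// As an illustration (a sketch, not part of the sources): after
// `insertForwardDeclaration` runs for `memref_add_lwe_ciphertexts_u64`, the
// module starts with a private declaration such as
//
//   func private @memref_add_lwe_ciphertexts_u64(
//       tensor<?xi64>, tensor<?xi64>, tensor<?xi64>)
//
// that later patterns can `call` without re-declaring it.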

// Set of functions to generate generic types.
// Generic types are used to add forward declarations without a specific type.
// For example, we may need to add LWE ciphertexts of different dimensions, or
// allocate them. All the calls to the C API should be done using these generic
// types, and casting should then be performed back to the appropriate type.
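//
// For instance (illustrative sketch): a concrete `tensor<1025xi64>` LWE buffer
// is cast to the generic `tensor<?xi64>` before the call, and the result can
// be cast back afterwards:
//
//   %generic = tensor.cast %buffer : tensor<1025xi64> to tensor<?xi64>
//   ... call into the C API with %generic ...
//   %back = tensor.cast %generic : tensor<?xi64> to tensor<1025xi64>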

inline mlir::Type getGenericLweBufferType(mlir::MLIRContext *context) {
  return mlir::RankedTensorType::get({-1}, mlir::IntegerType::get(context, 64));
}

inline mlir::concretelang::Concrete::GlweCiphertextType
getGenericGlweCiphertextType(mlir::MLIRContext *context) {
  return mlir::concretelang::Concrete::GlweCiphertextType::get(context);
}

inline mlir::Type getGenericPlaintextType(mlir::MLIRContext *context) {
  return mlir::IntegerType::get(context, 64);
}

inline mlir::Type getGenericCleartextType(mlir::MLIRContext *context) {
  return mlir::IntegerType::get(context, 64);
}

inline mlir::concretelang::Concrete::PlaintextListType
getGenericPlaintextListType(mlir::MLIRContext *context) {
  return mlir::concretelang::Concrete::PlaintextListType::get(context);
}

inline mlir::concretelang::Concrete::ForeignPlaintextListType
getGenericForeignPlaintextListType(mlir::MLIRContext *context) {
  return mlir::concretelang::Concrete::ForeignPlaintextListType::get(context);
}

inline mlir::concretelang::Concrete::LweKeySwitchKeyType
getGenericLweKeySwitchKeyType(mlir::MLIRContext *context) {
  return mlir::concretelang::Concrete::LweKeySwitchKeyType::get(context);
}

inline mlir::concretelang::Concrete::LweBootstrapKeyType
getGenericLweBootstrapKeyType(mlir::MLIRContext *context) {
  return mlir::concretelang::Concrete::LweBootstrapKeyType::get(context);
}

// Insert all forward declarations needed for the pass.
// Input and output types should be generalized for all declarations, and the
// patterns using them are responsible for casting them to the appropriate
// type.
mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
                                              mlir::IRRewriter &rewriter) {
  auto lweBufferType = getGenericLweBufferType(rewriter.getContext());
  auto plaintextType = getGenericPlaintextType(rewriter.getContext());
  auto cleartextType = getGenericCleartextType(rewriter.getContext());
  auto glweCiphertextType = getGenericGlweCiphertextType(rewriter.getContext());
  auto plaintextListType = getGenericPlaintextListType(rewriter.getContext());
  auto foreignPlaintextList =
      getGenericForeignPlaintextListType(rewriter.getContext());
  auto keySwitchKeyType = getGenericLweKeySwitchKeyType(rewriter.getContext());
  auto bootstrapKeyType = getGenericLweBootstrapKeyType(rewriter.getContext());
  auto contextType =
      mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());

  // Insert forward declaration of the add_lwe_ciphertexts function
  {
    auto funcType = mlir::FunctionType::get(
        rewriter.getContext(), {lweBufferType, lweBufferType, lweBufferType},
        {});
    if (insertForwardDeclaration(op, rewriter, "memref_add_lwe_ciphertexts_u64",
                                 funcType)
            .failed()) {
      return mlir::failure();
    }
  }
  // Insert forward declaration of the add_plaintext_lwe_ciphertext_u64 function
  {
    auto funcType = mlir::FunctionType::get(
        rewriter.getContext(), {lweBufferType, lweBufferType, plaintextType},
        {});
    if (insertForwardDeclaration(
            op, rewriter, "memref_add_plaintext_lwe_ciphertext_u64", funcType)
            .failed()) {
      return mlir::failure();
    }
  }
  // Insert forward declaration of the mul_cleartext_lwe_ciphertext_u64 function
  {
    auto funcType = mlir::FunctionType::get(
        rewriter.getContext(), {lweBufferType, lweBufferType, cleartextType},
        {});
    if (insertForwardDeclaration(
            op, rewriter, "memref_mul_cleartext_lwe_ciphertext_u64", funcType)
            .failed()) {
      return mlir::failure();
    }
  }
  // Insert forward declaration of the negate_lwe_ciphertext_u64 function
  {
    auto funcType = mlir::FunctionType::get(rewriter.getContext(),
                                            {lweBufferType, lweBufferType}, {});
    if (insertForwardDeclaration(op, rewriter,
                                 "memref_negate_lwe_ciphertext_u64", funcType)
            .failed()) {
      return mlir::failure();
    }
  }
  // Insert forward declaration of the memref_keyswitch_lwe_u64 function
  {
    auto funcType = mlir::FunctionType::get(
        rewriter.getContext(), {keySwitchKeyType, lweBufferType, lweBufferType},
        {});
    if (insertForwardDeclaration(op, rewriter, "memref_keyswitch_lwe_u64",
                                 funcType)
            .failed()) {
      return mlir::failure();
    }
  }
  // Insert forward declaration of the memref_bootstrap_lwe_u64 function
  {
    auto funcType = mlir::FunctionType::get(
        rewriter.getContext(),
        {bootstrapKeyType, lweBufferType, lweBufferType, glweCiphertextType},
        {});
    if (insertForwardDeclaration(op, rewriter, "memref_bootstrap_lwe_u64",
                                 funcType)
            .failed()) {
      return mlir::failure();
    }
  }
  // Insert forward declaration of the fill_plaintext_list function
  {
    auto funcType = mlir::FunctionType::get(
        rewriter.getContext(), {plaintextListType, foreignPlaintextList}, {});
    if (insertForwardDeclaration(
            op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType)
            .failed()) {
      return mlir::failure();
    }
  }
  // Insert forward declaration of the add_plaintext_list_glwe function
  {
    auto funcType = mlir::FunctionType::get(
        rewriter.getContext(),
        {glweCiphertextType, glweCiphertextType, plaintextListType}, {});
    if (insertForwardDeclaration(
            op, rewriter, "add_plaintext_list_glwe_ciphertext_u64", funcType)
            .failed()) {
      return mlir::failure();
    }
  }
  // Insert forward declaration of the get_keyswitch_key function
  {
    auto funcType = mlir::FunctionType::get(rewriter.getContext(),
                                            {contextType}, {keySwitchKeyType});
    if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key", funcType)
            .failed()) {
      return mlir::failure();
    }
  }
  // Insert forward declaration of the get_bootstrap_key function
  {
    auto funcType = mlir::FunctionType::get(rewriter.getContext(),
                                            {contextType}, {bootstrapKeyType});
    if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key", funcType)
            .failed()) {
      return mlir::failure();
    }
  }
  return mlir::success();
}

// For each `tensor<Axi64>` operand, replace it with
// `%casted = tensor.cast %op : tensor<Axi64> to tensor<?xi64>`
template <typename Op>
mlir::SmallVector<mlir::Value>
getCastedTensorOperands(Op op, mlir::PatternRewriter &rewriter) {
  mlir::SmallVector<mlir::Value, 4> newOperands{};
  for (mlir::Value operand : op->getOperands()) {
    mlir::Type operandType = operand.getType();
    if (operandType.isa<mlir::RankedTensorType>()) {
      mlir::Value castedOp = rewriter.create<mlir::tensor::CastOp>(
          op.getLoc(), getGenericLweBufferType(rewriter.getContext()), operand);
      newOperands.push_back(castedOp);
    } else {
      newOperands.push_back(operand);
    }
  }
  return std::move(newOperands);
}

/// ConcreteOpToConcreteCAPICallPattern<Op> matches the `BConcreteOp`
/// operation and replaces it with a call to `funcName`; `funcName` should be
/// an external function that is linked later. It inserts the forward
/// declaration of the private `funcName` if it is not already in the symbol
/// table. The C signature of the function should be
/// `void (out, args..., lweDimension)`; the pattern rewrites:
/// ```
/// "BConcreteOp"(%out, args ...) :
///   (tensor<sizexi64>, tensor<sizexi64>...) -> ()
/// ```
/// to
/// ```
/// %out0 = tensor.cast %out : tensor<sizexi64> to tensor<?xi64>
/// %args = tensor.cast ...
/// call @funcName(%out, args...) : (tensor<?xi64>, tensor<?xi64>...) -> ()
/// ```
template <typename BConcreteOp>
struct ConcreteOpToConcreteCAPICallPattern
    : public mlir::OpRewritePattern<BConcreteOp> {
  ConcreteOpToConcreteCAPICallPattern(mlir::MLIRContext *context,
                                      mlir::StringRef funcName,
                                      mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<BConcreteOp>(context, benefit),
        funcName(funcName) {}

  mlir::LogicalResult
  matchAndRewrite(BConcreteOp op,
                  mlir::PatternRewriter &rewriter) const override {
    BConcreteToBConcreteCAPITypeConverter typeConverter;
    rewriter.replaceOpWithNewOp<mlir::CallOp>(
        op, funcName, mlir::TypeRange{},
        getCastedTensorOperands<BConcreteOp>(op, rewriter));
    return mlir::success();
  };

private:
  std::string funcName;
};

struct ConcreteEncodeIntOpPattern
    : public mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp> {
  ConcreteEncodeIntOpPattern(mlir::MLIRContext *context,
                             mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp>(
            context, benefit) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::concretelang::Concrete::EncodeIntOp op,
                  mlir::PatternRewriter &rewriter) const override {
    {
      mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
          op.getLoc(), rewriter.getIntegerType(64), op->getOperands().front());
      mlir::Value constantShiftOp = rewriter.create<mlir::arith::ConstantOp>(
          op.getLoc(), rewriter.getI64IntegerAttr(64 - op.getType().getP()));

      mlir::Type resultType = rewriter.getIntegerType(64);
      rewriter.replaceOpWithNewOp<mlir::arith::ShLIOp>(
          op, resultType, castedInt, constantShiftOp);
    }
    return mlir::success();
  };
};
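
// Worked example (illustrative): for a message with precision p = 3, `getP()`
// is 3, so the shift is 64 - 3 = 61 and the value 5 is encoded as the
// plaintext 5 << 61, i.e. the message ends up in the 3 most significant bits
// of the 64-bit word.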

struct ConcreteIntToCleartextOpPattern
    : public mlir::OpRewritePattern<
          mlir::concretelang::Concrete::IntToCleartextOp> {
  ConcreteIntToCleartextOpPattern(mlir::MLIRContext *context,
                                  mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<mlir::concretelang::Concrete::IntToCleartextOp>(
            context, benefit) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::concretelang::Concrete::IntToCleartextOp op,
                  mlir::PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(
        op, rewriter.getIntegerType(64), op->getOperands().front());
    return mlir::success();
  };
};

mlir::Value getContextArgument(mlir::Operation *op) {
  mlir::Block *block = op->getBlock();
  while (block != nullptr) {
    if (llvm::isa<mlir::FuncOp>(block->getParentOp())) {

      mlir::Value context = block->getArguments().back();

      assert(
          context.getType().isa<mlir::concretelang::Concrete::ContextType>() &&
          "the Concrete.context should be the last argument of the enclosing "
          "function of the op");

      return context;
    }
    block = block->getParentOp()->getBlock();
  }
  assert(false && "can't find a function that encloses the op");
  return nullptr;
}
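
// Illustrative shape of a function this helper expects (a sketch): after
// AddRuntimeContextToFuncOpPattern below has run, every public function takes
// the runtime context as its trailing argument, e.g.
//
//   func @main(%arg0: tensor<1025xi64>, %ctx: !Concrete.context) -> ...
//
// and getContextArgument(op) returns %ctx for any op nested inside it.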

// Rewrite pattern that rewrites every
// ```
// "BConcrete.keyswitch_lwe_buffer"(%out, %in) {...}:
//   (tensor<2049xi64>, tensor<2049xi64>) -> ()
// ```
//
// to
//
// ```
// %ksk = call @get_keyswitch_key(%ctx) :
//   (!Concrete.context) -> !Concrete.lwe_key_switch_key
// %out_ = tensor.cast %out : tensor<sizexi64> to tensor<?xi64>
// %in_ = tensor.cast %in : tensor<size'xi64> to tensor<?xi64>
// call @memref_keyswitch_lwe_u64(%ksk, %out_, %in_) :
//   (!Concrete.lwe_key_switch_key, tensor<?xi64>, tensor<?xi64>) -> ()
// ```
struct BConcreteKeySwitchLweOpPattern
    : public mlir::OpRewritePattern<
          mlir::concretelang::BConcrete::KeySwitchLweBufferOp> {
  BConcreteKeySwitchLweOpPattern(mlir::MLIRContext *context,
                                 mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<
            mlir::concretelang::BConcrete::KeySwitchLweBufferOp>(context,
                                                                 benefit) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::concretelang::BConcrete::KeySwitchLweBufferOp op,
                  mlir::PatternRewriter &rewriter) const override {

    mlir::CallOp kskOp = rewriter.create<mlir::CallOp>(
        op.getLoc(), "get_keyswitch_key",
        getGenericLweKeySwitchKeyType(rewriter.getContext()),
        mlir::SmallVector<mlir::Value>{getContextArgument(op)});
    mlir::SmallVector<mlir::Value, 3> operands{kskOp.getResult(0)};

    operands.append(
        getCastedTensorOperands<
            mlir::concretelang::BConcrete::KeySwitchLweBufferOp>(op, rewriter));
    rewriter.replaceOpWithNewOp<mlir::CallOp>(op, "memref_keyswitch_lwe_u64",
                                              mlir::TypeRange({}), operands);
    return mlir::success();
  };
};

// Rewrite pattern that rewrites every
// ```
// "BConcrete.bootstrap_lwe_buffer"(%out, %in, %acc) {...} :
//   (tensor<2049xi64>, tensor<2049xi64>, !Concrete.glwe_ciphertext) -> ()
// ```
//
// to
//
// ```
// %bsk = call @get_bootstrap_key(%ctx) :
//   (!Concrete.context) -> !Concrete.lwe_bootstrap_key
// %out_ = tensor.cast %out : tensor<sizexi64> to tensor<?xi64>
// %in_ = tensor.cast %in : tensor<size'xi64> to tensor<?xi64>
// call @memref_bootstrap_lwe_u64(%bsk, %out_, %in_, %acc) :
//   (!Concrete.lwe_bootstrap_key, tensor<?xi64>, tensor<?xi64>,
//    !Concrete.glwe_ciphertext) -> ()
// ```
struct BConcreteBootstrapLweOpPattern
    : public mlir::OpRewritePattern<
          mlir::concretelang::BConcrete::BootstrapLweBufferOp> {
  BConcreteBootstrapLweOpPattern(mlir::MLIRContext *context,
                                 mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<
            mlir::concretelang::BConcrete::BootstrapLweBufferOp>(context,
                                                                 benefit) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::concretelang::BConcrete::BootstrapLweBufferOp op,
                  mlir::PatternRewriter &rewriter) const override {

    mlir::CallOp bskOp = rewriter.create<mlir::CallOp>(
        op.getLoc(), "get_bootstrap_key",
        getGenericLweBootstrapKeyType(rewriter.getContext()),
        mlir::SmallVector<mlir::Value>{getContextArgument(op)});
    mlir::SmallVector<mlir::Value, 4> operands{bskOp.getResult(0)};
    operands.append(
        getCastedTensorOperands<
            mlir::concretelang::BConcrete::BootstrapLweBufferOp>(op, rewriter));
    rewriter.replaceOpWithNewOp<mlir::CallOp>(op, "memref_bootstrap_lwe_u64",
                                              mlir::TypeRange({}), operands);
    return mlir::success();
  };
};

/// Populate the RewritePatternSet with all patterns that rewrite BConcrete
/// operators to the corresponding function calls to the `BConcrete C API`.
void populateBConcreteToBConcreteCAPICall(mlir::RewritePatternSet &patterns) {
  patterns.add<ConcreteOpToConcreteCAPICallPattern<
      mlir::concretelang::BConcrete::AddLweBuffersOp>>(
      patterns.getContext(), "memref_add_lwe_ciphertexts_u64");
  patterns.add<ConcreteOpToConcreteCAPICallPattern<
      mlir::concretelang::BConcrete::AddPlaintextLweBufferOp>>(
      patterns.getContext(), "memref_add_plaintext_lwe_ciphertext_u64");
  patterns.add<ConcreteOpToConcreteCAPICallPattern<
      mlir::concretelang::BConcrete::MulCleartextLweBufferOp>>(
      patterns.getContext(), "memref_mul_cleartext_lwe_ciphertext_u64");
  patterns.add<ConcreteOpToConcreteCAPICallPattern<
      mlir::concretelang::BConcrete::NegateLweBufferOp>>(
      patterns.getContext(), "memref_negate_lwe_ciphertext_u64");
  patterns.add<ConcreteEncodeIntOpPattern>(patterns.getContext());
  patterns.add<ConcreteIntToCleartextOpPattern>(patterns.getContext());
  // patterns.add<ConcreteZeroOpPattern>(patterns.getContext());
  patterns.add<BConcreteKeySwitchLweOpPattern>(patterns.getContext());
  patterns.add<BConcreteBootstrapLweOpPattern>(patterns.getContext());
}
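
// Usage sketch (assumption: called from a conversion driver such as the one
// in runOnOperation below):
//
//   mlir::RewritePatternSet patterns(&getContext());
//   populateBConcreteToBConcreteCAPICall(patterns);
//   // then hand `patterns` to applyPartialConversion with a target that
//   // marks the BConcrete dialect illegal.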

struct AddRuntimeContextToFuncOpPattern
    : public mlir::OpRewritePattern<mlir::FuncOp> {
  AddRuntimeContextToFuncOpPattern(mlir::MLIRContext *context,
                                   mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<mlir::FuncOp>(context, benefit) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::FuncOp oldFuncOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::FunctionType oldFuncType = oldFuncOp.getType();

    // Add a Concrete.context to the function signature
    mlir::SmallVector<mlir::Type> newInputs(oldFuncType.getInputs().begin(),
                                            oldFuncType.getInputs().end());
    newInputs.push_back(
        rewriter.getType<mlir::concretelang::Concrete::ContextType>());
    mlir::FunctionType newFuncTy = rewriter.getType<mlir::FunctionType>(
        newInputs, oldFuncType.getResults());
    // Create the new func
    mlir::FuncOp newFuncOp = rewriter.create<mlir::FuncOp>(
        oldFuncOp.getLoc(), oldFuncOp.getName(), newFuncTy);

    // Create the arguments of the new func
    mlir::Region &newFuncBody = newFuncOp.body();
    mlir::Block *newFuncEntryBlock = new mlir::Block();
    newFuncEntryBlock->addArguments(newFuncTy.getInputs());
    newFuncBody.push_back(newFuncEntryBlock);

    // Clone the old body into the new one
    mlir::BlockAndValueMapping map;
    for (auto arg : llvm::enumerate(oldFuncOp.getArguments())) {
      map.map(arg.value(), newFuncEntryBlock->getArgument(arg.index()));
    }
    for (auto &op : oldFuncOp.body().front()) {
      newFuncEntryBlock->push_back(op.clone(map));
    }
    rewriter.eraseOp(oldFuncOp);
    return mlir::success();
  }

  // Legal functions are those that are private or take a Concrete.context as
  // last argument.
  static bool isLegal(mlir::FuncOp funcOp) {
    if (!funcOp.isPublic()) {
      return true;
    }
    // TODO: No need to add a runtime context for functions that don't
    // manipulate Concrete types.
    //
    // if (!llvm::any_of(funcOp.getType().getInputs(), [](mlir::Type t) {
    //       if (auto tensorTy = t.dyn_cast_or_null<mlir::TensorType>()) {
    //         t = tensorTy.getElementType();
    //       }
    //       return llvm::isa<mlir::concretelang::Concrete::ConcreteDialect>(
    //           t.getDialect());
    //     })) {
    //   return true;
    // }
    return funcOp.getType().getNumInputs() >= 1 &&
           funcOp.getType()
               .getInputs()
               .back()
               .isa<mlir::concretelang::Concrete::ContextType>();
  }
};
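
// Illustrative signatures (a sketch):
//   func private @helper(tensor<1025xi64>)                    // legal (private)
//   func @main(%a: tensor<1025xi64>, %ctx: !Concrete.context) // legal
//   func @main(%a: tensor<1025xi64>)                          // illegal: the
//                                                // pattern above adds the ctx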

namespace {
struct BConcreteToBConcreteCAPIPass
    : public BConcreteToBConcreteCAPIBase<BConcreteToBConcreteCAPIPass> {
  void runOnOperation() final;
};
} // namespace

void BConcreteToBConcreteCAPIPass::runOnOperation() {
  mlir::ModuleOp op = getOperation();

  // First of all, add the Concrete.context to the block arguments of functions
  // that manipulate ciphertexts.
  {
    mlir::ConversionTarget target(getContext());
    mlir::RewritePatternSet patterns(&getContext());

    target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
      return AddRuntimeContextToFuncOpPattern::isLegal(funcOp);
    });

    patterns.add<AddRuntimeContextToFuncOpPattern>(patterns.getContext());

    // Apply the conversion
    if (mlir::applyPartialConversion(op, target, std::move(patterns))
            .failed()) {
      this->signalPassFailure();
      return;
    }
  }

  // Insert forward declarations
  mlir::IRRewriter rewriter(&getContext());
  if (insertForwardDeclarations(op, rewriter).failed()) {
    this->signalPassFailure();
  }
  // Rewrite BConcrete ops to CallOps to the BConcrete C API
  {
    mlir::ConversionTarget target(getContext());
    mlir::RewritePatternSet patterns(&getContext());

    target.addIllegalDialect<mlir::concretelang::BConcrete::BConcreteDialect>();

    target.addLegalDialect<mlir::BuiltinDialect, mlir::StandardOpsDialect,
                           mlir::tensor::TensorDialect,
                           mlir::arith::ArithmeticDialect>();

    populateBConcreteToBConcreteCAPICall(patterns);

    if (mlir::applyPartialConversion(op, target, std::move(patterns))
            .failed()) {
      this->signalPassFailure();
    }
  }
}
} // namespace

namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertBConcreteToBConcreteCAPIPass() {
  return std::make_unique<BConcreteToBConcreteCAPIPass>();
}
} // namespace concretelang
} // namespace mlir
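
// Usage sketch (assumption: a standard mlir::PassManager driver, not shown in
// this diff):
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::concretelang::createConvertConcreteToBConcretePass());
//   pm.addPass(mlir::concretelang::createConvertBConcreteToBConcreteCAPIPass());
//   if (pm.run(module).failed()) { /* handle the error */ }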
@@ -0,0 +1,16 @@
add_mlir_dialect_library(BConcreteToBConcreteCAPI
  BConcreteToBConcreteCAPI.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE

  DEPENDS
  BConcreteDialect
  MLIRConversionPassIncGen

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRTransforms
)

target_link_libraries(BConcreteToBConcreteCAPI PUBLIC MLIRIR)
@@ -2,6 +2,8 @@ add_subdirectory(FHEToTFHE)
add_subdirectory(TFHEGlobalParametrization)
add_subdirectory(TFHEToConcrete)
add_subdirectory(FHETensorOpsToLinalg)
add_subdirectory(ConcreteToBConcrete)
add_subdirectory(ConcreteToConcreteCAPI)
add_subdirectory(BConcreteToBConcreteCAPI)
add_subdirectory(MLIRLowerableDialectsToLLVM)
add_subdirectory(ConcreteUnparametrize)

compiler/lib/Conversion/ConcreteToBConcrete/CMakeLists.txt (new file, 16 lines)
@@ -0,0 +1,16 @@
add_mlir_dialect_library(ConcreteToBConcrete
  ConcreteToBConcrete.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Concrete

  DEPENDS
  ConcreteDialect
  BConcreteDialect

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRTransforms
  MLIRMath)

target_link_libraries(ConcreteToBConcrete PUBLIC BConcreteDialect MLIRIR)
@@ -0,0 +1,919 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.

#include <iostream>

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"

namespace {
struct ConcreteToBConcretePass
    : public ConcreteToBConcreteBase<ConcreteToBConcretePass> {
  void runOnOperation() final;
};
} // namespace

/// ConcreteToBConcreteTypeConverter is a TypeConverter that transforms
/// `Concrete.lwe_ciphertext<dimension,p>` into `tensor<dimension+1xi64>` and
/// `tensor<...x!Concrete.lwe_ciphertext<dimension,p>>` into
/// `tensor<...xdimension+1xi64>`
class ConcreteToBConcreteTypeConverter : public mlir::TypeConverter {

public:
  ConcreteToBConcreteTypeConverter() {
    addConversion([](mlir::Type type) { return type; });
    addConversion([&](mlir::concretelang::Concrete::LweCiphertextType type) {
      return mlir::RankedTensorType::get(
          {type.getDimension() + 1},
          mlir::IntegerType::get(type.getContext(), 64));
    });
    addConversion([&](mlir::RankedTensorType type) {
      auto lwe = type.getElementType()
                     .dyn_cast_or_null<
                         mlir::concretelang::Concrete::LweCiphertextType>();
      if (lwe == nullptr) {
        return (mlir::Type)(type);
      }
      mlir::SmallVector<int64_t> newShape;
      newShape.reserve(type.getShape().size() + 1);
      newShape.append(type.getShape().begin(), type.getShape().end());
      newShape.push_back(lwe.getDimension() + 1);
      mlir::Type r = mlir::RankedTensorType::get(
          newShape, mlir::IntegerType::get(type.getContext(), 64));
      return r;
    });
    addConversion([&](mlir::MemRefType type) {
      auto lwe = type.getElementType()
                     .dyn_cast_or_null<
                         mlir::concretelang::Concrete::LweCiphertextType>();
      if (lwe == nullptr) {
        return (mlir::Type)(type);
      }
      mlir::SmallVector<int64_t> newShape;
      newShape.reserve(type.getShape().size() + 1);
      newShape.append(type.getShape().begin(), type.getShape().end());
      newShape.push_back(lwe.getDimension() + 1);
      mlir::Type r = mlir::MemRefType::get(
          newShape, mlir::IntegerType::get(type.getContext(), 64));
      return r;
    });
  }
};
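
// Illustrative conversions (a sketch):
//   !Concrete.lwe_ciphertext<1024,4>           -> tensor<1025xi64>
//   tensor<4x!Concrete.lwe_ciphertext<1024,4>> -> tensor<4x1025xi64>
//   memref<4x!Concrete.lwe_ciphertext<1024,4>> -> memref<4x1025xi64>
//   i32 (or any other type)                    -> unchanged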

// This rewrite pattern transforms any instance of the `Concrete.zero_tensor`
// operator.
//
// Example:
//
// ```mlir
// %0 = "Concrete.zero_tensor" () :
//   tensor<...x!Concrete.lwe_ciphertext<lweDim,p>>
// ```
//
// becomes:
//
// ```mlir
// %0 = tensor.generate {
//   ^bb0(... : index):
//     %c0 = arith.constant 0 : i64
//     tensor.yield %c0
// }: tensor<...xlweDim+1xi64>
// ```
template <typename ZeroOp>
struct ZeroOpPattern : public mlir::OpRewritePattern<ZeroOp> {
  ZeroOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<ZeroOp>(context, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(ZeroOp zeroOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;
    auto resultTy = zeroOp.getType();
    auto newResultTy = converter.convertType(resultTy);

    auto generateBody = [&](mlir::OpBuilder &nestedBuilder,
                            mlir::Location nestedLoc,
                            mlir::ValueRange blockArgs) {
      // %c0 = arith.constant 0 : i64
      auto cstOp = nestedBuilder.create<mlir::arith::ConstantOp>(
          nestedLoc, nestedBuilder.getI64IntegerAttr(0));
      // tensor.yield %c0 : i64
      nestedBuilder.create<mlir::tensor::YieldOp>(nestedLoc, cstOp.getResult());
    };
    // tensor.generate
    rewriter.replaceOpWithNewOp<mlir::tensor::GenerateOp>(
        zeroOp, newResultTy, mlir::ValueRange{}, generateBody);

    return ::mlir::success();
  };
};

// This template rewrite pattern transforms any instance of
// `ConcreteOp` into an instance of `BConcreteOp`.
//
// Example:
//
// %0 = "ConcreteOp"(%arg0, ...) :
//   (!Concrete.lwe_ciphertext<lwe_dimension, p>, ...) ->
//   (!Concrete.lwe_ciphertext<lwe_dimension, p>)
//
// becomes:
//
// %0 = linalg.init_tensor [dimension+1] : tensor<dimension+1xi64>
// "BConcreteOp"(%0, %arg0, ...) : (tensor<dimension+1xi64>,
//   tensor<dimension+1xi64>, ...) -> ()
//
// A reference to the preallocated output is always passed as the first
// argument.
template <typename ConcreteOp, typename BConcreteOp>
struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
  LowToBConcrete(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<ConcreteOp>(context, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(ConcreteOp concreteOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;
    mlir::concretelang::Concrete::LweCiphertextType resultTy =
        ((mlir::Type)concreteOp->getResult(0).getType())
            .cast<mlir::concretelang::Concrete::LweCiphertextType>();
    auto newResultTy =
        converter.convertType(resultTy).cast<mlir::RankedTensorType>();

    // %0 = linalg.init_tensor [dimension+1] : tensor<dimension+1xi64>
    mlir::Value init = rewriter.replaceOpWithNewOp<mlir::linalg::InitTensorOp>(
        concreteOp, newResultTy.getShape(), newResultTy.getElementType());

    // "BConcreteOp"(%0, %arg0, ...) : (tensor<dimension+1xi64>,
    //   tensor<dimension+1xi64>, ...) -> ()
    mlir::SmallVector<mlir::Value, 3> newOperands{init};

    newOperands.append(concreteOp.getOperation()->getOperands().begin(),
                       concreteOp.getOperation()->getOperands().end());

    llvm::ArrayRef<::mlir::NamedAttribute> attributes =
        concreteOp.getOperation()->getAttrs();

    rewriter.create<BConcreteOp>(concreteOp.getLoc(),
                                 mlir::SmallVector<mlir::Type>{}, newOperands,
                                 attributes);

    return ::mlir::success();
  };
};
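
// Usage sketch (the Concrete op name here is an assumption for illustration;
// the BConcrete op is taken from the C API lowering above):
//
//   patterns.insert<LowToBConcrete<
//       mlir::concretelang::Concrete::AddLweCiphertextsOp,
//       mlir::concretelang::BConcrete::AddLweBuffersOp>>(&getContext());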

// This rewrite pattern transforms any instance of the
// `tensor.extract_slice` operator that operates on tensors of lwe ciphertexts.
//
// Example:
//
// ```mlir
// %0 = tensor.extract_slice %arg0
//   [offsets...] [sizes...] [strides...]
//   : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> to
//     tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
// ```
//
// becomes:
//
// ```mlir
// %0 = tensor.extract_slice %arg0
//   [offsets..., 0] [sizes..., lweDimension+1] [strides..., 1]
//   : tensor<...xlweDimension+1xi64> to
//     tensor<...xlweDimension+1xi64>
// ```
struct ExtractSliceOpPattern
    : public mlir::OpRewritePattern<mlir::tensor::ExtractSliceOp> {
  ExtractSliceOpPattern(::mlir::MLIRContext *context,
                        mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<mlir::tensor::ExtractSliceOp>(context,
                                                               benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;
    auto resultTy = extractSliceOp.result().getType();
    auto resultEltTy =
        resultTy.cast<mlir::RankedTensorType>()
            .getElementType()
            .cast<mlir::concretelang::Concrete::LweCiphertextType>();
    auto newResultTy = converter.convertType(resultTy);

    // add 0 to the static_offsets
    mlir::SmallVector<mlir::Attribute> staticOffsets;
    staticOffsets.append(extractSliceOp.static_offsets().begin(),
                         extractSliceOp.static_offsets().end());
    staticOffsets.push_back(rewriter.getI64IntegerAttr(0));

    // add the lweSize to the sizes
    mlir::SmallVector<mlir::Attribute> staticSizes;
    staticSizes.append(extractSliceOp.static_sizes().begin(),
                       extractSliceOp.static_sizes().end());
    staticSizes.push_back(
        rewriter.getI64IntegerAttr(resultEltTy.getDimension() + 1));

    // add 1 to the strides
    mlir::SmallVector<mlir::Attribute> staticStrides;
    staticStrides.append(extractSliceOp.static_strides().begin(),
                         extractSliceOp.static_strides().end());
    staticStrides.push_back(rewriter.getI64IntegerAttr(1));

    // replace tensor.extract_slice with the new one
    rewriter.replaceOpWithNewOp<mlir::tensor::ExtractSliceOp>(
        extractSliceOp, newResultTy, extractSliceOp.source(),
        extractSliceOp.offsets(), extractSliceOp.sizes(),
        extractSliceOp.strides(), rewriter.getArrayAttr(staticOffsets),
        rewriter.getArrayAttr(staticSizes),
        rewriter.getArrayAttr(staticStrides));

    return ::mlir::success();
  };
};

// This rewrite pattern transforms any instance of the
// `tensor.extract` operator that operates on tensors of lwe ciphertexts.
//
// Example:
//
// ```mlir
// %0 = tensor.extract %t[offsets...]
//   : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
// ```
//
// becomes:
//
// ```mlir
// %1 = tensor.extract_slice %arg0
//   [offsets...] [1..., lweDimension+1] [1...]
//   : tensor<...xlweDimension+1xi64> to
//     tensor<1...xlweDimension+1xi64>
// %0 = linalg.tensor_collapse_shape %0 [[...]] :
//   tensor<1x1xlweDimension+1xi64> into tensor<lweDimension+1xi64>
// ```
//
// TODO: since there is a bug when lowering extract_slice with rank reduction,
// we add a linalg.tensor_collapse_shape after the extract_slice without rank
// reduction. See
// https://github.com/zama-ai/concrete-compiler-internal/issues/396.
struct ExtractOpPattern
    : public mlir::OpRewritePattern<mlir::tensor::ExtractOp> {
  ExtractOpPattern(::mlir::MLIRContext *context,
                   mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<mlir::tensor::ExtractOp>(context, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(mlir::tensor::ExtractOp extractOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;
    auto lweResultTy =
        extractOp.result()
            .getType()
            .dyn_cast_or_null<
                mlir::concretelang::Concrete::LweCiphertextType>();
    if (lweResultTy == nullptr) {
      return mlir::failure();
    }

    auto newResultTy =
        converter.convertType(lweResultTy).cast<mlir::RankedTensorType>();
    auto rankOfResult = extractOp.indices().size() + 1;

    // [min..., 0] for static_offsets
    mlir::SmallVector<mlir::Attribute> staticOffsets(
        rankOfResult,
        rewriter.getI64IntegerAttr(std::numeric_limits<int64_t>::min()));
    staticOffsets[staticOffsets.size() - 1] = rewriter.getI64IntegerAttr(0);

    // [1..., lweDimension+1] for static_sizes
    mlir::SmallVector<mlir::Attribute> staticSizes(
        rankOfResult, rewriter.getI64IntegerAttr(1));
    staticSizes[staticSizes.size() - 1] = rewriter.getI64IntegerAttr(
        newResultTy.getDimSize(newResultTy.getRank() - 1));

    // [1...] for static_strides
    mlir::SmallVector<mlir::Attribute> staticStrides(
        rankOfResult, rewriter.getI64IntegerAttr(1));

    // create the rank-preserving tensor.extract_slice that extracts the
    // ciphertext as a slice of shape [1..., lweDimension+1]
    mlir::SmallVector<int64_t> extractedSliceShape(
        extractOp.indices().size() + 1, 1);
    extractedSliceShape.back() = newResultTy.getDimSize(0);

    auto extractedSliceType =
        mlir::RankedTensorType::get(extractedSliceShape, rewriter.getI64Type());
    auto extractedSlice = rewriter.create<mlir::tensor::ExtractSliceOp>(
        extractOp.getLoc(), extractedSliceType, extractOp.tensor(),
        extractOp.indices(), mlir::SmallVector<mlir::Value>{},
        mlir::SmallVector<mlir::Value>{}, rewriter.getArrayAttr(staticOffsets),
        rewriter.getArrayAttr(staticSizes),
        rewriter.getArrayAttr(staticStrides));
    mlir::ReassociationIndices reassociation;
    for (int64_t i = 0; i < extractedSliceType.getRank(); i++) {
      reassociation.push_back(i);
    }
    rewriter.replaceOpWithNewOp<mlir::linalg::TensorCollapseShapeOp>(
        extractOp, newResultTy, extractedSlice,
        mlir::SmallVector<mlir::ReassociationIndices>{reassociation});
    return ::mlir::success();
  };
};

// This rewrite pattern transforms any instance of the
// `tensor.insert_slice` operator that operates on tensors of lwe ciphertexts.
//
// Example:
//
// ```mlir
// %0 = tensor.insert_slice %arg1
//   into %arg0[offsets...] [sizes...] [strides...]
//   : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> into
//     tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
// ```
//
// becomes:
//
// ```mlir
// %0 = tensor.insert_slice %arg1
//   into %arg0[offsets..., 0] [sizes..., lweDimension+1] [strides..., 1]
//   : tensor<...xlweDimension+1xi64> into
//     tensor<...xlweDimension+1xi64>
// ```
struct InsertSliceOpPattern
    : public mlir::OpRewritePattern<mlir::tensor::InsertSliceOp> {
  InsertSliceOpPattern(::mlir::MLIRContext *context,
                       mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<mlir::tensor::InsertSliceOp>(context,
                                                              benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(mlir::tensor::InsertSliceOp insertSliceOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;
    auto resultTy = insertSliceOp.result().getType();

    auto newResultTy =
        converter.convertType(resultTy).cast<mlir::RankedTensorType>();

    // add 0 to static_offsets
    mlir::SmallVector<mlir::Attribute> staticOffsets;
    staticOffsets.append(insertSliceOp.static_offsets().begin(),
                         insertSliceOp.static_offsets().end());
    staticOffsets.push_back(rewriter.getI64IntegerAttr(0));

    // add lweDimension+1 to static_sizes
    mlir::SmallVector<mlir::Attribute> staticSizes;
    staticSizes.append(insertSliceOp.static_sizes().begin(),
                       insertSliceOp.static_sizes().end());
    staticSizes.push_back(rewriter.getI64IntegerAttr(
        newResultTy.getDimSize(newResultTy.getRank() - 1)));

    // add 1 to the strides
    mlir::SmallVector<mlir::Attribute> staticStrides;
    staticStrides.append(insertSliceOp.static_strides().begin(),
                         insertSliceOp.static_strides().end());
    staticStrides.push_back(rewriter.getI64IntegerAttr(1));

    // replace tensor.insert_slice with the new one
    rewriter.replaceOpWithNewOp<mlir::tensor::InsertSliceOp>(
        insertSliceOp, newResultTy, insertSliceOp.source(),
        insertSliceOp.dest(), insertSliceOp.offsets(), insertSliceOp.sizes(),
        insertSliceOp.strides(), rewriter.getArrayAttr(staticOffsets),
        rewriter.getArrayAttr(staticSizes),
        rewriter.getArrayAttr(staticStrides));

    return ::mlir::success();
  };
};

// This rewrite pattern transforms any instance of the
// `tensor.from_elements` operator that operates on tensors of lwe ciphertexts.
//
// Example:
//
// ```mlir
// %0 = tensor.from_elements %e0, ..., %e(n-1)
//   : tensor<Nx!Concrete.lwe_ciphertext<lweDim,p>>
// ```
//
// becomes:
//
// ```mlir
// %m = memref.alloc() : memref<NxlweDim+1xi64>
// %s0 = memref.subview %m[0, 0][1, lweDim+1][1, 1] : memref<lweDim+1xi64>
// %m0 = memref.buffer_cast %e0 : memref<lweDim+1xi64>
// memref.copy %m0, %s0 : memref<lweDim+1xi64> to memref<lweDim+1xi64>
// ...
// %s(n-1) = memref.subview %m[(n-1), 0][1, lweDim+1][1, 1]
//   : memref<lweDim+1xi64>
// %m(n-1) = memref.buffer_cast %e(n-1) : memref<lweDim+1xi64>
// memref.copy %m(n-1), %s(n-1)
//   : memref<lweDim+1xi64> to memref<lweDim+1xi64>
// %0 = memref.tensor_load %m : memref<NxlweDim+1xi64>
// ```
struct FromElementsOpPattern
    : public mlir::OpRewritePattern<mlir::tensor::FromElementsOp> {
  FromElementsOpPattern(::mlir::MLIRContext *context,
                        mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<mlir::tensor::FromElementsOp>(context,
                                                               benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(mlir::tensor::FromElementsOp fromElementsOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;

    auto resultTy = fromElementsOp.result().getType();
    if (converter.isLegal(resultTy)) {
      return mlir::failure();
    }
    auto eltResultTy =
        resultTy.cast<mlir::RankedTensorType>()
            .getElementType()
            .cast<mlir::concretelang::Concrete::LweCiphertextType>();
    auto newTensorResultTy =
        converter.convertType(resultTy).cast<mlir::RankedTensorType>();
    auto newMemrefResultTy = mlir::MemRefType::get(
        newTensorResultTy.getShape(), newTensorResultTy.getElementType());

    // %m = memref.alloc() : memref<NxlweDim+1xi64>
    auto mOp = rewriter.create<mlir::memref::AllocOp>(fromElementsOp.getLoc(),
                                                      newMemrefResultTy);

    // for i = 0 to n-1
    //   %si = memref.subview %m[i, 0][1, lweDim+1][1, 1] : memref<lweDim+1xi64>
    //   %mi = memref.buffer_cast %ei : memref<lweDim+1xi64>
    //   memref.copy %mi, %si : memref<lweDim+1xi64> to memref<lweDim+1xi64>
    auto subviewResultTy = mlir::MemRefType::get(
        {eltResultTy.getDimension() + 1}, newMemrefResultTy.getElementType());
    auto offset = 0;
    for (auto eiOp : fromElementsOp.elements()) {
      mlir::SmallVector<mlir::Attribute, 2> staticOffsets{
          rewriter.getI64IntegerAttr(offset), rewriter.getI64IntegerAttr(0)};
      mlir::SmallVector<mlir::Attribute, 2> staticSizes{
          rewriter.getI64IntegerAttr(1),
          rewriter.getI64IntegerAttr(eltResultTy.getDimension() + 1)};
      mlir::SmallVector<mlir::Attribute, 2> staticStrides{
          rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
      auto siOp = rewriter.create<mlir::memref::SubViewOp>(
          fromElementsOp.getLoc(), subviewResultTy, mOp, mlir::ValueRange{},
          mlir::ValueRange{}, mlir::ValueRange{},
          rewriter.getArrayAttr(staticOffsets),
          rewriter.getArrayAttr(staticSizes),
          rewriter.getArrayAttr(staticStrides));
      auto miOp = rewriter.create<mlir::memref::BufferCastOp>(
          fromElementsOp.getLoc(), subviewResultTy, eiOp);
      rewriter.create<mlir::memref::CopyOp>(fromElementsOp.getLoc(), miOp,
                                            siOp);
      offset++;
    }

    // Go back to the tensor world
    // %0 = memref.tensor_load %m : memref<NxlweDim+1xi64>
    rewriter.replaceOpWithNewOp<mlir::memref::TensorLoadOp>(fromElementsOp,
                                                            mOp);

    return ::mlir::success();
  };
};

// This template rewrite pattern transforms any instance of a
// `ShapeOp` operator that operates on tensors of lwe ciphertexts by adding the
// lwe size as a trailing size of the tensor result and by adding a trivial
// reassociation at the end of the reassociation map.
//
// Example:
//
// ```mlir
// %0 = "ShapeOp" %arg0 [reassociations...]
//   : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> into
//     tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
// ```
//
// becomes:
//
// ```mlir
// %0 = "ShapeOp" %arg0 [reassociations..., [inRank or outRank]]
//   : tensor<...xlweDimension+1xi64> into
//     tensor<...xlweDimension+1xi64>
// ```
template <typename ShapeOp, bool inRank>
struct TensorShapeOpPattern : public mlir::OpRewritePattern<ShapeOp> {
  TensorShapeOpPattern(::mlir::MLIRContext *context,
                       mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<ShapeOp>(context, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(ShapeOp shapeOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;
    auto resultTy = shapeOp.result().getType();

    auto newResultTy =
        ((mlir::Type)converter.convertType(resultTy)).cast<mlir::MemRefType>();

    // add [rank] to the reassociations
    auto oldReassocs = shapeOp.getReassociationIndices();
    mlir::SmallVector<mlir::ReassociationIndices> newReassocs;
    newReassocs.append(oldReassocs.begin(), oldReassocs.end());
    mlir::ReassociationIndices lweAssoc;
    auto reassocTy =
        ((mlir::Type)converter.convertType(
             (inRank ? shapeOp.src() : shapeOp.result()).getType()))
            .cast<mlir::MemRefType>();
    lweAssoc.push_back(reassocTy.getRank());
    newReassocs.push_back(lweAssoc);

    rewriter.replaceOpWithNewOp<ShapeOp>(shapeOp, newResultTy, shapeOp.src(),
                                         newReassocs);
    return ::mlir::success();
  };
};

// Add the instantiated TensorShapeOpPattern rewrite pattern with the `ShapeOp`
// to the patterns set and populate the conversion target.
template <typename ShapeOp, bool inRank>
void insertTensorShapeOpPattern(mlir::MLIRContext &context,
                                mlir::RewritePatternSet &patterns,
                                mlir::ConversionTarget &target) {
  patterns.insert<TensorShapeOpPattern<ShapeOp, inRank>>(&context);
  target.addDynamicallyLegalOp<ShapeOp>([&](ShapeOp op) {
    ConcreteToBConcreteTypeConverter converter;
    return converter.isLegal(op.result().getType());
  });
}
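
// Usage sketch (illustrative; assumes the linalg reshape ops of this MLIR
// revision and that `inRank` selects whether the reassociation is indexed by
// the source or the result type):
//
//   insertTensorShapeOpPattern<mlir::linalg::TensorCollapseShapeOp,
//                              /*inRank=*/true>(context, patterns, target);
//   insertTensorShapeOpPattern<mlir::linalg::TensorExpandShapeOp,
//                              /*inRank=*/false>(context, patterns, target);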
|
||||
|
||||
// This template rewrite pattern transforms any instance of
|
||||
// `MemrefOp` operators that returns a memref of lwe ciphertext to the same
|
||||
// operator but which returns the bufferized lwe ciphertext.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = "MemrefOp"(...) : ... -> memref<...x!Concrete.lwe_ciphertext<lweDim,p>>
|
||||
// ```
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = "MemrefOp"(...) : ... -> memref<...xlweDim+1xi64>
|
||||
// ```
|
||||
template <typename MemrefOp>
struct MemrefOpPattern : public mlir::OpRewritePattern<MemrefOp> {
  MemrefOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<MemrefOp>(context, benefit) {}

  mlir::LogicalResult
  matchAndRewrite(MemrefOp memrefOp,
                  mlir::PatternRewriter &rewriter) const override {
    ConcreteToBConcreteTypeConverter converter;

    mlir::SmallVector<mlir::Type, 1> convertedTypes;
    if (converter.convertTypes(memrefOp->getResultTypes(), convertedTypes)
            .failed()) {
      return mlir::failure();
    }

    rewriter.replaceOpWithNewOp<MemrefOp>(memrefOp, convertedTypes,
                                          memrefOp->getOperands(),
                                          memrefOp->getAttrs());
    return ::mlir::success();
  }
};

// Add an instantiation of MemrefOpPattern for each `MemrefOp` to the pattern
// set, and mark each `MemrefOp` as dynamically legal on the conversion target.
template <typename... MemrefOp>
void insertMemrefOpPattern(mlir::MLIRContext &context,
                           mlir::RewritePatternSet &patterns,
                           mlir::ConversionTarget &target) {
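  // Expand the `MemrefOp` parameter pack: for each operator, register its
  // rewrite pattern and mark it dynamically legal once its result types are
  // converted. The initializer_list trick sequences the side effects of the
  // pack expansion (a C++17 fold expression would presumably express the same
  // thing more directly).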
  (void)std::initializer_list<int>{
      0, (patterns.insert<MemrefOpPattern<MemrefOp>>(&context),
          target.addDynamicallyLegalOp<MemrefOp>([&](MemrefOp op) {
            ConcreteToBConcreteTypeConverter converter;
            return converter.isLegal(op->getResultTypes());
          }),
          0)...};
}

// Copied from MLIR's Linalg lowering (Loops.cpp): applies each result
// expression of `map` to `vals` and materializes the results as affine.apply
// ops.
static mlir::SmallVector<mlir::Value>
makeCanonicalAffineApplies(mlir::OpBuilder &b, mlir::Location loc,
                           mlir::AffineMap map,
                           mlir::ArrayRef<mlir::Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size());
  mlir::SmallVector<mlir::Value> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = mlir::AffineMap::get(dims, map.getNumSymbols(), e);
    mlir::SmallVector<mlir::Value> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<mlir::AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}

static std::pair<mlir::Value, mlir::Value>
makeOperandLoadOrSubview(mlir::OpBuilder &builder, mlir::Location loc,
                         mlir::ArrayRef<mlir::Value> allIvs,
                         mlir::linalg::LinalgOp linalgOp,
                         mlir::OpOperand *operand) {
  ConcreteToBConcreteTypeConverter converter;

  mlir::Value opVal = operand->get();
  mlir::MemRefType opTy = opVal.getType().cast<mlir::MemRefType>();

  if (auto lweType =
          opTy.getElementType()
              .dyn_cast_or_null<
                  mlir::concretelang::Concrete::LweCiphertextType>()) {
    // For memrefs of ciphertexts, create the inner memref subview on the
    // ciphertext, then go back to the tensor type, as BConcrete operators
    // work on tensors.
    // %op : memref<dim...x!Concrete.lwe_ciphertext<lweDim,p>>
    // %opInner = memref.subview %op[offsets...][1...][1...]
    //   : memref<...x!Concrete.lwe_ciphertext<lweDim,p>> to
    //     memref<!Concrete.lwe_ciphertext<lweDim,p>>

    auto tensorizedLweTy =
        converter.convertType(lweType).cast<mlir::RankedTensorType>();
    auto subviewResultTy = mlir::MemRefType::get(
        tensorizedLweTy.getShape(), tensorizedLweTy.getElementType());
    auto offsets = makeCanonicalAffineApplies(
        builder, loc, linalgOp.getTiedIndexingMap(operand), allIvs);
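    // Note: std::numeric_limits<int64_t>::min() is MLIR's
    // ShapedType::kDynamicStrideOrOffset sentinel, so each offset is marked
    // dynamic and taken from the affine.apply results above, while the sizes
    // and strides stay static 1.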
    mlir::SmallVector<mlir::Attribute> staticOffsets(
        opTy.getRank(),
        builder.getI64IntegerAttr(std::numeric_limits<int64_t>::min()));
    mlir::SmallVector<mlir::Attribute> staticSizes(
        opTy.getRank(), builder.getI64IntegerAttr(1));
    mlir::SmallVector<mlir::Attribute> staticStrides(
        opTy.getRank(), builder.getI64IntegerAttr(1));

    auto subViewOp = builder.create<mlir::memref::SubViewOp>(
        loc, subviewResultTy, opVal, offsets, mlir::ValueRange{},
        mlir::ValueRange{}, builder.getArrayAttr(staticOffsets),
        builder.getArrayAttr(staticSizes), builder.getArrayAttr(staticStrides));
    return std::pair<mlir::Value, mlir::Value>(
        subViewOp, builder.create<mlir::memref::TensorLoadOp>(loc, subViewOp));
  } else {
    // For memrefs of non-ciphertexts, load the value from the memref.
    // with %op : memref<dim...xip>
    // %opInner = memref.load %op[offsets...] : memref<dim...xip>
    auto offsets = makeCanonicalAffineApplies(
        builder, loc, linalgOp.getTiedIndexingMap(operand), allIvs);
    return std::pair<mlir::Value, mlir::Value>(
        nullptr,
        builder.create<mlir::memref::LoadOp>(loc, operand->get(), offsets));
  }
}

static void
inlineRegionAndEmitTensorStore(mlir::OpBuilder &builder, mlir::Location loc,
                               mlir::linalg::LinalgOp linalgOp,
                               llvm::ArrayRef<mlir::Value> indexedValues,
                               mlir::ValueRange outputBuffers) {
  // Clone the block with the new operands
  auto &block = linalgOp->getRegion(0).front();
  mlir::BlockAndValueMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = builder.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }
  // Create a memref.tensor_store operation for each terminator operand
  auto *terminator = block.getTerminator();
  for (mlir::OpOperand &operand : terminator->getOpOperands()) {
    mlir::Value toStore = map.lookupOrDefault(operand.get());
    builder.create<mlir::memref::TensorStoreOp>(
        loc, toStore, outputBuffers[operand.getOperandNumber()]);
  }
}

template <typename LoopType>
class LinalgRewritePattern
    : public mlir::OpInterfaceConversionPattern<mlir::linalg::LinalgOp> {
public:
  using mlir::OpInterfaceConversionPattern<
      mlir::linalg::LinalgOp>::OpInterfaceConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::linalg::LinalgOp linalgOp,
                  mlir::ArrayRef<mlir::Value> operands,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    assert(linalgOp.hasBufferSemantics() &&
           "expected linalg op with buffer semantics");

    auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
    auto iteratorTypes =
        llvm::to_vector<4>(linalgOp.iterator_types().getValue());

    mlir::SmallVector<mlir::Value> allIvs;
    mlir::linalg::GenerateLoopNest<LoopType>::doit(
        rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange ivs,
            mlir::ValueRange operandValuesToUse) -> mlir::scf::ValueVector {
          // Collect the indexed values that will replace the linalg.generic
          // block arguments
          mlir::SmallVector<mlir::Value> indexedValues;
          indexedValues.reserve(linalgOp.getNumInputsAndOutputs());
          assert(
              operandValuesToUse == linalgOp->getOperands() &&
              "expect operands are captured and not passed by loop argument");
          allIvs.append(ivs.begin(), ivs.end());

          // For all input operands create the inner operand
          for (mlir::OpOperand *inputOperand : linalgOp.getInputOperands()) {
            auto innerOperand = makeOperandLoadOrSubview(
                builder, loc, allIvs, linalgOp, inputOperand);
            indexedValues.push_back(innerOperand.second);
          }

          // For all output operands create the inner operand
          assert(linalgOp.getOutputOperands() ==
                     linalgOp.getOutputBufferOperands() &&
                 "expect only memref as output operands");
          mlir::SmallVector<mlir::Value> outputBuffers;
          for (mlir::OpOperand *outputOperand : linalgOp.getOutputOperands()) {
            auto innerOperand = makeOperandLoadOrSubview(
                builder, loc, allIvs, linalgOp, outputOperand);
            indexedValues.push_back(innerOperand.second);
            assert(innerOperand.first != nullptr &&
                   "Expected a memref subview as output buffer");
            outputBuffers.push_back(innerOperand.first);
          }
          // Finally inline the linalgOp region
          inlineRegionAndEmitTensorStore(builder, loc, linalgOp, indexedValues,
                                         outputBuffers);

          return mlir::scf::ValueVector{};
        });
    rewriter.eraseOp(linalgOp);
    return mlir::success();
  }
};
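
// Illustrative effect (a sketch, not taken from the sources): a linalg.generic
// iterating over a memref<4x1025xi64> of bufferized ciphertexts lowers to an
// scf.for nest whose body takes a memref.subview of the current 1025-element
// row, runs the cloned region on it, and memref.tensor_store's each yielded
// value into the matching output subview.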

void ConcreteToBConcretePass::runOnOperation() {
  auto op = this->getOperation();

  // First of all, transform linalg ops that work on tensors of ciphertexts to
  // work on memrefs.
  {
    mlir::ConversionTarget target(getContext());
    mlir::BufferizeTypeConverter converter;

    // Mark all Standard operations legal.
    target
        .addLegalDialect<mlir::arith::ArithmeticDialect, mlir::AffineDialect,
                         mlir::memref::MemRefDialect, mlir::StandardOpsDialect,
                         mlir::tensor::TensorDialect>();

    // Mark all Linalg operations illegal as long as they work on encrypted
    // tensors.
    target.addDynamicallyLegalOp<mlir::linalg::GenericOp, mlir::linalg::YieldOp,
                                 mlir::linalg::CopyOp>(
        [&](mlir::Operation *op) { return converter.isLegal(op); });

    mlir::RewritePatternSet patterns(&getContext());
    mlir::linalg::populateLinalgBufferizePatterns(converter, patterns);
    if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
      signalPassFailure();
      return;
    }
  }

  // Then convert each scalar ciphertext to a tensor, and add a dimension to
  // tensors and memrefs of ciphertexts.
  {
    mlir::ConversionTarget target(getContext());
    ConcreteToBConcreteTypeConverter converter;
    mlir::OwningRewritePatternList patterns(&getContext());

    // All BConcrete ops are legal after the conversion
    target.addLegalDialect<mlir::concretelang::BConcrete::BConcreteDialect>();

    // All Concrete ops are illegal after the conversion, except those
    // explicitly marked as legal (more or less the operators that don't work
    // on ciphertexts)
    target.addIllegalDialect<mlir::concretelang::Concrete::ConcreteDialect>();
    target.addLegalOp<mlir::concretelang::Concrete::EncodeIntOp>();
    target.addLegalOp<mlir::concretelang::Concrete::GlweFromTable>();
    target.addLegalOp<mlir::concretelang::Concrete::IntToCleartextOp>();

    // Add patterns to convert the zero ops to tensor.generate
    patterns
        .insert<ZeroOpPattern<mlir::concretelang::Concrete::ZeroTensorLWEOp>,
                ZeroOpPattern<mlir::concretelang::Concrete::ZeroLWEOp>>(
            &getContext());
    target.addLegalOp<mlir::tensor::GenerateOp, mlir::tensor::YieldOp>();

    // Add patterns to trivially convert each Concrete op to the equivalent
    // BConcrete op
    target.addLegalOp<mlir::linalg::InitTensorOp>();
    patterns.insert<
        LowToBConcrete<mlir::concretelang::Concrete::AddLweCiphertextsOp,
                       mlir::concretelang::BConcrete::AddLweBuffersOp>,
        LowToBConcrete<
            mlir::concretelang::Concrete::AddPlaintextLweCiphertextOp,
            mlir::concretelang::BConcrete::AddPlaintextLweBufferOp>,
        LowToBConcrete<
            mlir::concretelang::Concrete::MulCleartextLweCiphertextOp,
            mlir::concretelang::BConcrete::MulCleartextLweBufferOp>,
        LowToBConcrete<mlir::concretelang::Concrete::NegateLweCiphertextOp,
                       mlir::concretelang::BConcrete::NegateLweBufferOp>,
        LowToBConcrete<mlir::concretelang::Concrete::KeySwitchLweOp,
                       mlir::concretelang::BConcrete::KeySwitchLweBufferOp>,
        LowToBConcrete<mlir::concretelang::Concrete::BootstrapLweOp,
                       mlir::concretelang::BConcrete::BootstrapLweBufferOp>>(
        &getContext());

    // Add patterns to rewrite tensor operators that work on encrypted tensors
    patterns.insert<ExtractSliceOpPattern, ExtractOpPattern,
                    InsertSliceOpPattern, FromElementsOpPattern>(&getContext());
    target.addDynamicallyLegalOp<
        mlir::tensor::ExtractSliceOp, mlir::tensor::ExtractOp,
        mlir::tensor::InsertSliceOp, mlir::tensor::FromElementsOp>(
        [&](mlir::Operation *op) {
          return converter.isLegal(op->getResult(0).getType());
        });
    target.addLegalOp<mlir::memref::CopyOp,
                      mlir::linalg::TensorCollapseShapeOp>();

    // Add patterns to rewrite some of the memref ops that were introduced by
    // the linalg bufferization of encrypted tensors (the first conversion of
    // this pass)
    insertTensorShapeOpPattern<mlir::memref::ExpandShapeOp, true>(
        getContext(), patterns, target);
    insertTensorShapeOpPattern<mlir::memref::CollapseShapeOp, false>(
        getContext(), patterns, target);

    // Add patterns to rewrite linalg ops to nested loops with views on
    // ciphertexts
    patterns.insert<LinalgRewritePattern<mlir::scf::ForOp>>(converter,
                                                            &getContext());
    target.addLegalOp<mlir::arith::ConstantOp, mlir::scf::ForOp,
                      mlir::scf::YieldOp, mlir::AffineApplyOp,
                      mlir::memref::SubViewOp, mlir::memref::LoadOp,
                      mlir::memref::TensorStoreOp>();

    // Add patterns to convert function signatures and bodies
    mlir::populateFuncOpTypeConversionPattern(patterns, converter);
    target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
      return converter.isSignatureLegal(funcOp.getType()) &&
             converter.isLegal(&funcOp.getBody());
    });

    // Add patterns to convert some memref operators that are generated by the
    // previous step
    insertMemrefOpPattern<mlir::memref::AllocOp, mlir::memref::BufferCastOp,
                          mlir::memref::TensorLoadOp>(getContext(), patterns,
                                                      target);

    // Conversion of RT Dialect Ops
    patterns.add<mlir::concretelang::GenericTypeConverterPattern<
        mlir::concretelang::RT::DataflowTaskOp>>(patterns.getContext(),
                                                 converter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::DataflowTaskOp>(target, converter);

    // Apply conversion
    if (mlir::applyPartialConversion(op, target, std::move(patterns))
            .failed()) {
      this->signalPassFailure();
    }
  }
}

namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertConcreteToBConcretePass() {
  return std::make_unique<ConcreteToBConcretePass>();
}
} // namespace concretelang
} // namespace mlir
@@ -264,16 +264,6 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
  if (target == Target::CONCRETE)
    return std::move(res);

  // Concrete -> Canonical dialects
  if (mlir::concretelang::pipeline::lowerConcreteToStd(mlirContext, module,
                                                       enablePass)
          .failed()) {
    return errorDiag(
        "Lowering from Concrete to canonical MLIR dialects failed");
  }
  if (target == Target::STD)
    return std::move(res);

  // Generate client parameters if requested
  if (this->generateClientParameters) {
    if (!this->clientParametersFuncName.hasValue()) {
@@ -304,6 +294,28 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
    }
  }

  // Concrete -> BConcrete
  if (mlir::concretelang::pipeline::lowerConcreteToBConcrete(
          mlirContext, module, this->enablePass)
          .failed()) {
    return StreamStringError(
        "Lowering from Concrete to Bufferized Concrete failed");
  }

  if (target == Target::BCONCRETE) {
    return std::move(res);
  }

  // BConcrete -> Canonical dialects
  if (mlir::concretelang::pipeline::lowerBConcreteToStd(mlirContext, module,
                                                        enablePass)
          .failed()) {
    return errorDiag(
        "Lowering from Bufferized Concrete to canonical MLIR dialects failed");
  }
  if (target == Target::STD)
    return std::move(res);

  // MLIR canonical dialects -> LLVM Dialect
  if (mlir::concretelang::pipeline::lowerStdToLLVMDialect(
          mlirContext, module, enablePass,

@@ -181,10 +181,25 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
}

mlir::LogicalResult
lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
                   std::function<bool(mlir::Pass *)> enablePass) {
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
                         std::function<bool(mlir::Pass *)> enablePass) {
  mlir::PassManager pm(&context);
  pipelinePrinting("ConcreteToStd", pm, context);
  pipelinePrinting("ConcreteToBConcrete", pm, context);
  addPotentiallyNestedPass(
      pm, mlir::concretelang::createConvertConcreteToBConcretePass(),
      enablePass);

  return pm.run(module.getOperation());
}

mlir::LogicalResult
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
                    std::function<bool(mlir::Pass *)> enablePass) {
  mlir::PassManager pm(&context);
  pipelinePrinting("BConcreteToStd", pm, context);
  addPotentiallyNestedPass(
      pm, mlir::concretelang::createConvertBConcreteToBConcreteCAPIPass(),
      enablePass);
  addPotentiallyNestedPass(
      pm, mlir::concretelang::createConvertConcreteToConcreteCAPIPass(),
      enablePass);

@@ -41,6 +41,7 @@ enum Action {
  DUMP_FHE,
  DUMP_TFHE,
  DUMP_CONCRETE,
  DUMP_BCONCRETE,
  DUMP_STD,
  DUMP_LLVM_DIALECT,
  DUMP_LLVM_IR,
@@ -101,6 +102,9 @@ static llvm::cl::opt<enum Action> action(
                                "Lower to TFHE and dump result")),
    llvm::cl::values(clEnumValN(Action::DUMP_CONCRETE, "dump-concrete",
                                "Lower to Concrete and dump result")),
    llvm::cl::values(
        clEnumValN(Action::DUMP_BCONCRETE, "dump-bconcrete",
                   "Lower to Bufferized Concrete and dump result")),
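    // The new action can be exercised as in the tests below, for instance:
    //   concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete input.mlir
    // (input.mlir stands for any input file.)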
    llvm::cl::values(clEnumValN(Action::DUMP_STD, "dump-std",
                                "Lower to std and dump result")),
    llvm::cl::values(clEnumValN(Action::DUMP_LLVM_DIALECT, "dump-llvm-dialect",
@@ -324,6 +328,9 @@ mlir::LogicalResult processInputBuffer(
  case Action::DUMP_CONCRETE:
    target = mlir::concretelang::CompilerEngine::Target::CONCRETE;
    break;
  case Action::DUMP_BCONCRETE:
    target = mlir::concretelang::CompilerEngine::Target::BCONCRETE;
    break;
  case Action::DUMP_STD:
    target = mlir::concretelang::CompilerEngine::Target::STD;
    break;

@@ -0,0 +1,14 @@
// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s

// CHECK: func @add_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>, %arg2: !Concrete.context) -> tensor<2049xi64>
func @add_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>) -> tensor<2049xi64> {
// CHECK-NEXT: %0 = linalg.init_tensor [2049] : tensor<2049xi64>
// CHECK-NEXT: %1 = tensor.cast %0 : tensor<2049xi64> to tensor<?xi64>
// CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<2049xi64> to tensor<?xi64>
// CHECK-NEXT: %3 = tensor.cast %arg1 : tensor<2049xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_add_lwe_ciphertexts_u64(%1, %2, %3) : (tensor<?xi64>, tensor<?xi64>, tensor<?xi64>) -> ()
// CHECK-NEXT: return %0 : tensor<2049xi64>
%0 = linalg.init_tensor [2049] : tensor<2049xi64>
"BConcrete.add_lwe_buffer"(%0, %arg0, %arg1) : (tensor<2049xi64>, tensor<2049xi64>, tensor<2049xi64>) -> ()
return %0 : tensor<2049xi64>
}
@@ -0,0 +1,37 @@
// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s


// CHECK-LABEL: func @add_glwe_const_int(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64>
func @add_glwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
// CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64
// CHECK-NEXT: %c56_i64 = arith.constant 56 : i64
// CHECK-NEXT: %1 = arith.shli %0, %c56_i64 : i64
// CHECK-NEXT: %2 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %3 = tensor.cast %2 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %4 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_add_plaintext_lwe_ciphertext_u64(%3, %4, %1) : (tensor<?xi64>, tensor<?xi64>, i64) -> ()
// CHECK-NEXT: return %2 : tensor<1025xi64>
%0 = arith.constant 1 : i8
%1 = "Concrete.encode_int"(%0) : (i8) -> !Concrete.plaintext<8>
%2 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.add_plaintext_lwe_buffer"(%2, %arg0, %1) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<8>) -> ()
return %2 : tensor<1025xi64>
}


// CHECK-LABEL: func @add_glwe_int(%arg0: tensor<1025xi64>, %arg1: i5, %arg2: !Concrete.context) -> tensor<1025xi64>
func @add_glwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> {
// CHECK-NEXT: %0 = arith.extui %arg1 : i5 to i64
// CHECK-NEXT: %c59_i64 = arith.constant 59 : i64
// CHECK-NEXT: %1 = arith.shli %0, %c59_i64 : i64
// CHECK-NEXT: %2 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %3 = tensor.cast %2 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %4 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_add_plaintext_lwe_ciphertext_u64(%3, %4, %1) : (tensor<?xi64>, tensor<?xi64>, i64) -> ()
// CHECK-NEXT: return %2 : tensor<1025xi64>
%0 = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
%1 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.add_plaintext_lwe_buffer"(%1, %arg0, %0) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<5>) -> ()
return %1 : tensor<1025xi64>
}
@@ -0,0 +1,15 @@
// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s

// CHECK: func @bootstrap_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.glwe_ciphertext, %arg2: !Concrete.context) -> tensor<1025xi64> {
// CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %1 = call @get_bootstrap_key(%arg2) : (!Concrete.context) -> !Concrete.lwe_bootstrap_key
// CHECK-NEXT: %2 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_bootstrap_lwe_u64(%1, %2, %3, %arg1) : (!Concrete.lwe_bootstrap_key, tensor<?xi64>, tensor<?xi64>, !Concrete.glwe_ciphertext) -> ()
// CHECK-NEXT: return %0 : tensor<1025xi64>
// CHECK-NEXT: }
func @bootstrap_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.glwe_ciphertext) -> tensor<1025xi64> {
%0 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.bootstrap_lwe_buffer"(%0, %arg0, %arg1) {baseLog = 2 : i32, glweDimension = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.glwe_ciphertext) -> ()
return %0 : tensor<1025xi64>
}
@@ -0,0 +1,15 @@
// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s

//CHECK: func @keyswitch_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64> {
//CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64>
//CHECK-NEXT: %1 = call @get_keyswitch_key(%arg1) : (!Concrete.context) -> !Concrete.lwe_key_switch_key
//CHECK-NEXT: %2 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
//CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
//CHECK-NEXT: call @memref_keyswitch_lwe_u64(%1, %2, %3) : (!Concrete.lwe_key_switch_key, tensor<?xi64>, tensor<?xi64>) -> ()
//CHECK-NEXT: return %0 : tensor<1025xi64>
//CHECK-NEXT: }
func @keyswitch_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
%0 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.keyswitch_lwe_buffer"(%0, %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 1 : i32} : (tensor<1025xi64>, tensor<1025xi64>) -> ()
return %0 : tensor<1025xi64>
}
@@ -0,0 +1,33 @@
// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s

// CHECK-LABEL: func @mul_lwe_const_int(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64>
func @mul_lwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
// CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64
// CHECK-NEXT: %1 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %2 = tensor.cast %1 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_mul_cleartext_lwe_ciphertext_u64(%2, %3, %0) : (tensor<?xi64>, tensor<?xi64>, i64) -> ()
// CHECK-NEXT: return %1 : tensor<1025xi64>
%c1_i8 = arith.constant 1 : i8
%1 = "Concrete.int_to_cleartext"(%c1_i8) : (i8) -> !Concrete.cleartext<8>
%2 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.mul_cleartext_lwe_buffer"(%2, %arg0, %1) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.cleartext<8>) -> ()
return %2 : tensor<1025xi64>
}



// CHECK-LABEL: func @mul_lwe_int(%arg0: tensor<1025xi64>, %arg1: i5, %arg2: !Concrete.context) -> tensor<1025xi64>
func @mul_lwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> {
// CHECK-NEXT: %0 = arith.extui %arg1 : i5 to i64
// CHECK-NEXT: %1 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %2 = tensor.cast %1 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %3 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_mul_cleartext_lwe_ciphertext_u64(%2, %3, %0) : (tensor<?xi64>, tensor<?xi64>, i64) -> ()
// CHECK-NEXT: return %1 : tensor<1025xi64>
%0 = "Concrete.int_to_cleartext"(%arg1) : (i5) -> !Concrete.cleartext<5>
%1 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.mul_cleartext_lwe_buffer"(%1, %arg0, %0) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.cleartext<5>) -> ()
return %1 : tensor<1025xi64>
}
@@ -0,0 +1,13 @@
// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s

// CHECK-LABEL: func @neg_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64> {
func @neg_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
// CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %1 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_negate_lwe_ciphertext_u64(%1, %2) : (tensor<?xi64>, tensor<?xi64>) -> ()
// CHECK-NEXT: return %0 : tensor<1025xi64>
%0 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.negate_lwe_buffer"(%0, %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> ()
return %0 : tensor<1025xi64>
}
@@ -0,0 +1,47 @@
// RUN: concretecompiler --passes bconcrete-to-bconcrete-c-api --action=dump-std %s 2>&1| FileCheck %s

// CHECK-LABEL: func @sub_const_int_lwe(%arg0: tensor<1025xi64>, %arg1: !Concrete.context) -> tensor<1025xi64> {
func @sub_const_int_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
// CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %1 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_negate_lwe_ciphertext_u64(%1, %2) : (tensor<?xi64>, tensor<?xi64>) -> ()
// CHECK-NEXT: %3 = arith.extui %c1_i8 : i8 to i64
// CHECK-NEXT: %c56_i64 = arith.constant 56 : i64
// CHECK-NEXT: %4 = arith.shli %3, %c56_i64 : i64
// CHECK-NEXT: %5 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %6 = tensor.cast %5 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %7 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_add_plaintext_lwe_ciphertext_u64(%6, %7, %4) : (tensor<?xi64>, tensor<?xi64>, i64) -> ()
// CHECK-NEXT: return %5 : tensor<1025xi64>
%0 = arith.constant 1 : i8
%1 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.negate_lwe_buffer"(%1, %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> ()
%2 = "Concrete.encode_int"(%0) : (i8) -> !Concrete.plaintext<8>
%3 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.add_plaintext_lwe_buffer"(%3, %1, %2) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<8>) -> ()
return %3 : tensor<1025xi64>
}

// CHECK-LABEL: func @sub_int_lwe(%arg0: tensor<1025xi64>, %arg1: i5, %arg2: !Concrete.context) -> tensor<1025xi64> {
func @sub_int_lwe(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> {
// CHECK-NEXT: %0 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %1 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %2 = tensor.cast %arg0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_negate_lwe_ciphertext_u64(%1, %2) : (tensor<?xi64>, tensor<?xi64>) -> ()
// CHECK-NEXT: %3 = arith.extui %arg1 : i5 to i64
// CHECK-NEXT: %c59_i64 = arith.constant 59 : i64
// CHECK-NEXT: %4 = arith.shli %3, %c59_i64 : i64
// CHECK-NEXT: %5 = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: %6 = tensor.cast %5 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: %7 = tensor.cast %0 : tensor<1025xi64> to tensor<?xi64>
// CHECK-NEXT: call @memref_add_plaintext_lwe_ciphertext_u64(%6, %7, %4) : (tensor<?xi64>, tensor<?xi64>, i64) -> ()
// CHECK-NEXT: return %5 : tensor<1025xi64>
%0 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.negate_lwe_buffer"(%0, %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> ()
%1 = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
%2 = linalg.init_tensor [1025] : tensor<1025xi64>
"BConcrete.add_plaintext_lwe_buffer"(%2, %0, %1) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<5>) -> ()
return %2 : tensor<1025xi64>
}
compiler/tests/Conversion/ConcreteToBConcrete/add_lwe.mlir
@@ -0,0 +1,10 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s

// CHECK-LABEL: func @add_glwe(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>) -> tensor<2049xi64>
func @add_glwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V1:.*]] = linalg.init_tensor [2049] : tensor<2049xi64>
// CHECK-NEXT: "BConcrete.add_lwe_buffer"(%[[V1]], %arg0, %arg1) : (tensor<2049xi64>, tensor<2049xi64>, tensor<2049xi64>) -> ()
// CHECK-NEXT: return %[[V1]] : tensor<2049xi64>
%0 = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
return %0 : !Concrete.lwe_ciphertext<2048,7>
}
@@ -0,0 +1,25 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s

// CHECK-LABEL: func @add_glwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64>
func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V2:.*]] = "Concrete.encode_int"(%[[V1]]) : (i8) -> !Concrete.plaintext<8>
// CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.add_plaintext_lwe_buffer"(%[[V3]], %arg0, %[[V2]]) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<8>) -> ()
// CHECK-NEXT: return %[[V3]] : tensor<1025xi64>
%0 = arith.constant 1 : i8
%1 = "Concrete.encode_int"(%0) : (i8) -> !Concrete.plaintext<8>
%2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %1) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.plaintext<8>) -> !Concrete.lwe_ciphertext<1024,7>
return %2 : !Concrete.lwe_ciphertext<1024,7>
}

// CHECK-LABEL: func @add_glwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64>
func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
// CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.add_plaintext_lwe_buffer"(%[[V2]], %arg0, %[[V1]]) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<5>) -> ()
// CHECK-NEXT: return %[[V2]] : tensor<1025xi64>
%0 = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
%1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<1024,4>
return %1 : !Concrete.lwe_ciphertext<1024,4>
}
@@ -0,0 +1,15 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s

// CHECK-LABEL: func @apply_lookup_table(%arg0: tensor<1025xi64>, %arg1: tensor<16xi64>) -> tensor<1025xi64>
func @apply_lookup_table(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !Concrete.lwe_ciphertext<1024,4> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
// CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [601] : tensor<601xi64>
// CHECK-NEXT: "BConcrete.keyswitch_lwe_buffer"(%[[V2]], %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<601xi64>, tensor<1025xi64>) -> ()
// CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3]], %[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (tensor<1025xi64>, tensor<601xi64>, !Concrete.glwe_ciphertext) -> ()
// CHECK-NEXT: return %[[V3]] : tensor<1025xi64>
%0 = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
%1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<600,4>
%2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<1024,4>
return %2 : !Concrete.lwe_ciphertext<1024,4>
}
@@ -0,0 +1,17 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s

// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: tensor<2049xi64>) -> tensor<2049xi64>
func @apply_lookup_table_cst(%arg0: !Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<2048,4> {
// CHECK-NEXT: %[[TABLE:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
// CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%[[TABLE]]) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
// CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [601] : tensor<601xi64>
// CHECK-NEXT: "BConcrete.keyswitch_lwe_buffer"(%[[V2]], %arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<601xi64>, tensor<2049xi64>) -> ()
// CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [2049] : tensor<2049xi64>
// CHECK-NEXT: "BConcrete.bootstrap_lwe_buffer"(%[[V3]], %[[V2]], %[[V1]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (tensor<2049xi64>, tensor<601xi64>, !Concrete.glwe_ciphertext) -> ()
// CHECK-NEXT: return %[[V3]] : tensor<2049xi64>
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
%0 = "Concrete.glwe_from_table"(%tlu) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext
%1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<600,4>
%2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext) -> !Concrete.lwe_ciphertext<2048,4>
return %2 : !Concrete.lwe_ciphertext<2048,4>
}
@@ -0,0 +1,8 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s

// CHECK: func @identity(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
// CHECK-NEXT: return %arg0 : tensor<1025xi64>
// CHECK-NEXT: }
func @identity(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
return %arg0 : !Concrete.lwe_ciphertext<1024,7>
}
@@ -0,0 +1,25 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s

// CHECK-LABEL: func @mul_lwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64>
func @mul_lwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V2:.*]] = "Concrete.int_to_cleartext"(%[[V1]]) : (i8) -> !Concrete.cleartext<8>
// CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.mul_cleartext_lwe_buffer"(%[[V3]], %arg0, %[[V2]]) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.cleartext<8>) -> ()
// CHECK-NEXT: return %[[V3]] : tensor<1025xi64>
%0 = arith.constant 1 : i8
%1 = "Concrete.int_to_cleartext"(%0) : (i8) -> !Concrete.cleartext<8>
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.cleartext<8>) -> !Concrete.lwe_ciphertext<1024,7>
return %2 : !Concrete.lwe_ciphertext<1024,7>
}

// CHECK-LABEL: func @mul_lwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64>
func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.int_to_cleartext"(%arg1) : (i5) -> !Concrete.cleartext<5>
// CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.mul_cleartext_lwe_buffer"(%[[V2]], %arg0, %[[V1]]) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.cleartext<5>) -> ()
// CHECK-NEXT: return %[[V2]] : tensor<1025xi64>
%0 = "Concrete.int_to_cleartext"(%arg1) : (i5) -> !Concrete.cleartext<5>
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.cleartext<5>) -> !Concrete.lwe_ciphertext<1024,4>
return %1 : !Concrete.lwe_ciphertext<1024,4>
}
compiler/tests/Conversion/ConcreteToBConcrete/neg_lwe.mlir
@@ -0,0 +1,10 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s

// CHECK-LABEL: func @neg_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64>
func @neg_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> {
// CHECK-NEXT: %[[V1:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.negate_lwe_buffer"(%[[V1]], %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> ()
// CHECK-NEXT: return %[[V1]] : tensor<1025xi64>
%0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4>
return %0 : !Concrete.lwe_ciphertext<1024,4>
}
@@ -0,0 +1,32 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s

// CHECK-LABEL: func @sub_const_int_lwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64>
func @sub_const_int_lwe(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V2:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.negate_lwe_buffer"(%[[V2]], %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> ()
// CHECK-NEXT: %[[V3:.*]] = "Concrete.encode_int"(%[[V1]]) : (i8) -> !Concrete.plaintext<8>
// CHECK-NEXT: %[[V4:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.add_plaintext_lwe_buffer"(%[[V4]], %[[V2]], %[[V3]]) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<8>) -> ()
// CHECK-NEXT: return %[[V4]] : tensor<1025xi64>
%0 = arith.constant 1 : i8
%1 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7>
%2 = "Concrete.encode_int"(%0) : (i8) -> !Concrete.plaintext<8>
%3 = "Concrete.add_plaintext_lwe_ciphertext"(%1, %2) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.plaintext<8>) -> !Concrete.lwe_ciphertext<1024,7>
return %3 : !Concrete.lwe_ciphertext<1024,7>
}


// CHECK-LABEL: func @sub_int_lwe(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64>
func @sub_int_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
// CHECK-NEXT: %[[V1:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.negate_lwe_buffer"(%[[V1]], %arg0) : (tensor<1025xi64>, tensor<1025xi64>) -> ()
// CHECK-NEXT: %[[V2:.*]] = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
// CHECK-NEXT: %[[V3:.*]] = linalg.init_tensor [1025] : tensor<1025xi64>
// CHECK-NEXT: "BConcrete.add_plaintext_lwe_buffer"(%[[V3]], %[[V1]], %[[V2]]) : (tensor<1025xi64>, tensor<1025xi64>, !Concrete.plaintext<5>) -> ()
// CHECK-NEXT: return %[[V3]] : tensor<1025xi64>
%0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4>
%1 = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
%2 = "Concrete.add_plaintext_lwe_ciphertext"(%0, %1) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<1024,4>
return %2 : !Concrete.lwe_ciphertext<1024,4>
}
@@ -0,0 +1,7 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
// CHECK: func @tensor_identity(%arg0: tensor<2x3x4x1025xi64>) -> tensor<2x3x4x1025xi64> {
// CHECK-NEXT: return %arg0 : tensor<2x3x4x1025xi64>
// CHECK-NEXT: }
func @tensor_identity(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>> {
return %arg0 : tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>>
}