mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
cleanup(compiler): Remove useless concrete types, simplify print and parse, and remove BConcreteToBConcreteCAPI pass
This commit is contained in:
@@ -1,22 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_BCONCRETETOBCONCRETECAPI_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_BCONCRETETOBCONCRETECAPI_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `Concrete` operators to function call to the
|
||||
/// `ConcreteCAPI`
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertBConcreteToBConcreteCAPIPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -11,7 +11,6 @@
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
|
||||
#include "concretelang/Conversion/BConcreteToBConcreteCAPI/Pass.h"
|
||||
#include "concretelang/Conversion/ConcreteToBConcrete/Pass.h"
|
||||
#include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h"
|
||||
#include "concretelang/Conversion/FHEToTFHE/Pass.h"
|
||||
|
||||
@@ -10,11 +10,6 @@ def Concrete_Dialect : Dialect {
|
||||
A dialect for representation of low level operation on fully homomorphic ciphertext.
|
||||
}];
|
||||
let cppNamespace = "::mlir::concretelang::Concrete";
|
||||
let useDefaultTypePrinterParser = 0;
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
|
||||
void printType(::mlir::Type type, ::mlir::DialectAsmPrinter &printer) const override;
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -5,141 +5,85 @@ include "mlir/IR/BuiltinTypes.td"
|
||||
|
||||
include "concretelang/Dialect/Concrete/IR/ConcreteDialect.td"
|
||||
|
||||
class Concrete_Type<string name, list<Trait> traits = []> : TypeDef<Concrete_Dialect, name, traits> { }
|
||||
class Concrete_Type<string name, list<Trait> traits = []>
|
||||
: TypeDef<Concrete_Dialect, name, traits> {}
|
||||
|
||||
def GlweCiphertextType : Concrete_Type<"GlweCiphertext"> {
|
||||
let mnemonic = "glwe_ciphertext";
|
||||
let mnemonic = "glwe_ciphertext";
|
||||
|
||||
let summary = "A GLWE ciphertext (encryption of a polynomial of fixed-precision integers)";
|
||||
let summary = "A GLWE ciphertext (encryption of a polynomial of "
|
||||
"fixed-precision integers)";
|
||||
|
||||
let description = [{
|
||||
GLWE ciphertext.
|
||||
}];
|
||||
let description = [{GLWE ciphertext.}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let parameters = (ins
|
||||
"signed":$polynomialSize,
|
||||
"signed":$glweDimension,
|
||||
// Precision of the lwe ciphertext
|
||||
"signed":$p
|
||||
);
|
||||
let parameters = (ins "signed"
|
||||
: $polynomialSize, "signed"
|
||||
: $glweDimension,
|
||||
// Precision of the lwe ciphertext
|
||||
"signed"
|
||||
: $p);
|
||||
}
|
||||
|
||||
def LweCiphertextType : Concrete_Type<"LweCiphertext", [MemRefElementTypeInterface]> {
|
||||
let mnemonic = "lwe_ciphertext";
|
||||
def LweCiphertextType
|
||||
: Concrete_Type<"LweCiphertext", [MemRefElementTypeInterface]> {
|
||||
let mnemonic = "lwe_ciphertext";
|
||||
|
||||
let summary = "A LWE ciphertext (encryption of a fixed-precision integer)";
|
||||
let summary = "A LWE ciphertext (encryption of a fixed-precision integer)";
|
||||
|
||||
let description = [{
|
||||
Learning With Error ciphertext.
|
||||
}];
|
||||
let description = [{Learning With Error ciphertext.}];
|
||||
|
||||
let parameters = (ins
|
||||
// The dimension of the lwe ciphertext
|
||||
"signed"
|
||||
: $dimension,
|
||||
// Precision of the lwe ciphertext
|
||||
"signed"
|
||||
: $p);
|
||||
|
||||
let parameters = (ins
|
||||
// The dimension of the lwe ciphertext
|
||||
"signed":$dimension,
|
||||
// Precision of the lwe ciphertext
|
||||
"signed":$p
|
||||
);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def CleartextType : Concrete_Type<"Cleartext"> {
|
||||
let mnemonic = "cleartext";
|
||||
let mnemonic = "cleartext";
|
||||
|
||||
let summary = "A cleartext (a fixed-precision integer) ready to be multiplied to a LWE ciphertext";
|
||||
let summary = "A cleartext (a fixed-precision integer) ready to be "
|
||||
"multiplied to a LWE ciphertext";
|
||||
|
||||
let description = [{
|
||||
Cleartext.
|
||||
}];
|
||||
let description = [{Cleartext.}];
|
||||
|
||||
let parameters = (ins
|
||||
// Number of bits of the cleartext representation
|
||||
"signed":$p
|
||||
);
|
||||
let parameters = (ins
|
||||
// Number of bits of the cleartext representation
|
||||
"signed"
|
||||
: $p);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PlaintextType : Concrete_Type<"Plaintext"> {
|
||||
let mnemonic = "plaintext";
|
||||
let mnemonic = "plaintext";
|
||||
|
||||
let summary = "A Plaintext (a fixed-precision integer) ready to be added to a LWE ciphertext";
|
||||
let summary = "A Plaintext (a fixed-precision integer) ready to be added to "
|
||||
"a LWE ciphertext";
|
||||
|
||||
let description = [{
|
||||
Plaintext.
|
||||
}];
|
||||
let description = [{Plaintext.}];
|
||||
|
||||
let parameters = (ins
|
||||
// Number of bits of the cleartext representation
|
||||
"signed":$p
|
||||
);
|
||||
let parameters = (ins
|
||||
// Number of bits of the cleartext representation
|
||||
"signed"
|
||||
: $p);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PlaintextListType : Concrete_Type<"PlaintextList"> {
|
||||
let mnemonic = "plaintext_list";
|
||||
|
||||
let summary = "List of plaintexts";
|
||||
|
||||
let description = [{
|
||||
Plaintext list.
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def ForeignPlaintextListType : Concrete_Type<"ForeignPlaintextList"> {
|
||||
let mnemonic = "foreign_plaintext_list";
|
||||
|
||||
let summary = "A foreign (reference to a independently allocated memory space) plaintext list";
|
||||
|
||||
let description = [{
|
||||
Foreign plaintext list.
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def LweKeySwitchKeyType : Concrete_Type<"LweKeySwitchKey"> {
|
||||
let mnemonic = "lwe_key_switch_key";
|
||||
|
||||
let summary = "A LWE keyswitching key";
|
||||
|
||||
let description = [{
|
||||
Learning With Error keyswitching key.
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def LweBootstrapKeyType : Concrete_Type<"LweBootstrapKey"> {
|
||||
let mnemonic = "lwe_bootstrap_key";
|
||||
|
||||
let summary = "A LWE bootstrapping key";
|
||||
|
||||
let description = [{
|
||||
Learning With Error bootstrapping key.
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def Context : Concrete_Type<"Context"> {
|
||||
let mnemonic = "context";
|
||||
let mnemonic = "context";
|
||||
|
||||
let summary = "A runtime context";
|
||||
let summary = "A runtime context";
|
||||
|
||||
let description = [{
|
||||
An abstract runtime context to pass contextual value, like public keys, ...
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let description = [{An abstract runtime context to pass contextual value,
|
||||
like public keys, ...}];
|
||||
}
|
||||
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,718 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR//BuiltinTypes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Tools.h"
|
||||
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include <mlir/IR/OperationSupport.h>
|
||||
#include <mlir/IR/Value.h>
|
||||
|
||||
static mlir::Type convertTypeIfConcreteType(mlir::MLIRContext *context,
|
||||
mlir::Type t) {
|
||||
if (t.isa<mlir::concretelang::Concrete::PlaintextType>() ||
|
||||
t.isa<mlir::concretelang::Concrete::CleartextType>()) {
|
||||
return mlir::IntegerType::get(context, 64);
|
||||
} else {
|
||||
return t;
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
class BConcreteToBConcreteCAPITypeConverter : public mlir::TypeConverter {
|
||||
|
||||
public:
|
||||
BConcreteToBConcreteCAPITypeConverter() {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([&](mlir::concretelang::Concrete::PlaintextType type) {
|
||||
return convertTypeIfConcreteType(type.getContext(), type);
|
||||
});
|
||||
addConversion([&](mlir::concretelang::Concrete::CleartextType type) {
|
||||
return convertTypeIfConcreteType(type.getContext(), type);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Set of functions to generate generic types.
|
||||
// Generic types are used to add forward declarations without a specific type.
|
||||
// For example, we may need to add LWE ciphertext of different dimensions, or
|
||||
// allocate them. All the calls to the C API should be done using this generic
|
||||
// types, and casting should then be performed back to the appropriate type.
|
||||
|
||||
inline mlir::Type getGenericLweBufferType(mlir::MLIRContext *context) {
|
||||
return mlir::RankedTensorType::get({-1}, mlir::IntegerType::get(context, 64));
|
||||
}
|
||||
|
||||
inline mlir::Type getGenericLweMemrefType(mlir::MLIRContext *context) {
|
||||
return mlir::MemRefType::get({-1}, mlir::IntegerType::get(context, 64));
|
||||
}
|
||||
|
||||
inline mlir::Type getGenericGlweBufferType(mlir::MLIRContext *context) {
|
||||
return mlir::RankedTensorType::get({-1}, mlir::IntegerType::get(context, 64));
|
||||
}
|
||||
|
||||
inline mlir::Type getGenericPlaintextType(mlir::MLIRContext *context) {
|
||||
return mlir::IntegerType::get(context, 64);
|
||||
}
|
||||
|
||||
inline mlir::Type getGenericCleartextType(mlir::MLIRContext *context) {
|
||||
return mlir::IntegerType::get(context, 64);
|
||||
}
|
||||
|
||||
inline mlir::concretelang::Concrete::LweKeySwitchKeyType
|
||||
getGenericLweKeySwitchKeyType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::Concrete::LweKeySwitchKeyType::get(context);
|
||||
}
|
||||
|
||||
inline mlir::concretelang::Concrete::LweBootstrapKeyType
|
||||
getGenericLweBootstrapKeyType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::Concrete::LweBootstrapKeyType::get(context);
|
||||
}
|
||||
|
||||
// Insert all forward declarations needed for the pass.
|
||||
// Should generalize input and output types for all decalarations, and the
|
||||
// pattern using them would be resposible for casting them to the appropriate
|
||||
// type.
|
||||
mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
mlir::IRRewriter &rewriter) {
|
||||
auto lweBufferType = getGenericLweMemrefType(rewriter.getContext());
|
||||
auto plaintextType = getGenericPlaintextType(rewriter.getContext());
|
||||
auto cleartextType = getGenericCleartextType(rewriter.getContext());
|
||||
auto keySwitchKeyType = getGenericLweKeySwitchKeyType(rewriter.getContext());
|
||||
auto bootstrapKeyType = getGenericLweBootstrapKeyType(rewriter.getContext());
|
||||
auto contextType =
|
||||
mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
|
||||
|
||||
// Insert forward declaration of the add_lwe_ciphertexts function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {lweBufferType, lweBufferType, lweBufferType},
|
||||
{});
|
||||
if (insertForwardDeclaration(op, rewriter, "memref_add_lwe_ciphertexts_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the add_plaintext_lwe_ciphertext function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {lweBufferType, lweBufferType, plaintextType},
|
||||
{});
|
||||
if (insertForwardDeclaration(
|
||||
op, rewriter, "memref_add_plaintext_lwe_ciphertext_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the mul_cleartext_lwe_ciphertext function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {lweBufferType, lweBufferType, cleartextType},
|
||||
{});
|
||||
if (insertForwardDeclaration(
|
||||
op, rewriter, "memref_mul_cleartext_lwe_ciphertext_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the negate_lwe_ciphertext function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{lweBufferType, lweBufferType}, {});
|
||||
if (insertForwardDeclaration(op, rewriter,
|
||||
"memref_negate_lwe_ciphertext_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the memref_keyswitch_lwe_u64 function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {lweBufferType, lweBufferType, contextType}, {});
|
||||
if (insertForwardDeclaration(op, rewriter, "memref_keyswitch_lwe_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the memref_bootstrap_lwe_u64 function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{lweBufferType, lweBufferType, lweBufferType, contextType}, {});
|
||||
if (insertForwardDeclaration(op, rewriter, "memref_bootstrap_lwe_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
// Insert forward declaration of the expand_lut_in_trivial_glwe_ct function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{
|
||||
getGenericGlweBufferType(rewriter.getContext()),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
mlir::RankedTensorType::get(
|
||||
{-1}, mlir::IntegerType::get(rewriter.getContext(), 64)),
|
||||
},
|
||||
{});
|
||||
if (insertForwardDeclaration(
|
||||
op, rewriter, "memref_expand_lut_in_trivial_glwe_ct_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
// Insert forward declaration of the getGlobalKeyswitchKey function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{contextType}, {keySwitchKeyType});
|
||||
if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the getGlobalBootstrapKey function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{contextType}, {bootstrapKeyType});
|
||||
if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Replaces an operand `tensor<Axi64>` with
|
||||
// ```
|
||||
// %casted_tensor = tensor.cast %op : tensor<Axi64> to tensor<?xui64>
|
||||
// %casted_memref = bufferization.to_memref %casted_tensor : memref<?xui64>
|
||||
// ```
|
||||
mlir::Value getCastedTensorOperand(mlir::PatternRewriter &rewriter,
|
||||
mlir::Location loc, mlir::Value operand) {
|
||||
mlir::Type operandType = operand.getType();
|
||||
if (operandType.isa<mlir::RankedTensorType>()) {
|
||||
mlir::Value castedTensor = rewriter.create<mlir::tensor::CastOp>(
|
||||
loc, getGenericLweBufferType(rewriter.getContext()), operand);
|
||||
|
||||
mlir::Value castedMemRef = rewriter.create<mlir::bufferization::ToMemrefOp>(
|
||||
loc, getGenericLweMemrefType(rewriter.getContext()), castedTensor);
|
||||
return castedMemRef;
|
||||
} else {
|
||||
return operand;
|
||||
}
|
||||
}
|
||||
|
||||
mlir::SmallVector<mlir::Value>
|
||||
getCastedTensorOperands(mlir::PatternRewriter &rewriter, mlir::Operation *op) {
|
||||
return llvm::to_vector<3>(
|
||||
llvm::map_range(op->getOperands(), [&](mlir::Value operand) {
|
||||
return getCastedTensorOperand(rewriter, op->getLoc(), operand);
|
||||
}));
|
||||
}
|
||||
|
||||
// template <typename Op>
|
||||
// mlir::SmallVector<mlir::Value>
|
||||
// getCastedTensorOperands(Op op, mlir::PatternRewriter &rewriter) {
|
||||
// mlir::SmallVector<mlir::Value, 4> newOperands{};
|
||||
// for (mlir::Value operand : op->getOperands()) {
|
||||
// mlir::Type operandType = operand.getType();
|
||||
// if (operandType.isa<mlir::RankedTensorType>()) {
|
||||
// mlir::Value castedTensor = rewriter.create<mlir::tensor::CastOp>(
|
||||
// op.getLoc(), getGenericLweBufferType(rewriter.getContext()),
|
||||
// operand);
|
||||
|
||||
// mlir::Value castedMemRef =
|
||||
// rewriter.create<mlir::bufferization::ToMemrefOp>(
|
||||
// op.getLoc(), getGenericLweBufferType(rewriter.getContext()),
|
||||
// operand);
|
||||
// newOperands.push_back(castedMemRef);
|
||||
// } else {
|
||||
// newOperands.push_back(operand);
|
||||
// }
|
||||
// }
|
||||
// return std::move(newOperands);
|
||||
// }
|
||||
|
||||
/// BConcreteOpToConcreteCAPICallPattern<Op> matches the `BConcreteOp`
|
||||
/// Operation and replaces it with a call to `funcName`, the funcName should be
|
||||
/// an external function that is linked later. It inserts the forward
|
||||
/// declaration of the private `funcName` if it not already in the symbol table.
|
||||
/// The C signature of the function should be `void (out, args...,
|
||||
/// lweDimension)`, the pattern rewrites:
|
||||
/// ```
|
||||
/// "%out = BConcreteOp"(args ...) :
|
||||
/// (tensor<sizexi64>...) -> tensor<sizexi64>
|
||||
/// ```
|
||||
/// to
|
||||
/// ```
|
||||
/// %args_tensor = tensor.cast ...
|
||||
/// %args_memref = bufferize.to_memref ...
|
||||
/// %out_tensor_ranked = linalg.tensor_init ...
|
||||
// %out_tensor = tensor.cast ...
|
||||
/// %out_memref = bufferize.to_memref ...
|
||||
/// call @funcName(%out_memref, %args_memref...) :
|
||||
/// (memref<?xi64>, memref<?xi64>...) -> ()
|
||||
// %out = bufferize.to_tensor ...
|
||||
/// ```
|
||||
template <typename BConcreteOp>
|
||||
struct ConcreteOpToConcreteCAPICallPattern
|
||||
: public mlir::OpRewritePattern<BConcreteOp> {
|
||||
ConcreteOpToConcreteCAPICallPattern(mlir::MLIRContext *context,
|
||||
mlir::StringRef funcName,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<BConcreteOp>(context, benefit),
|
||||
funcName(funcName) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(BConcreteOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
BConcreteToBConcreteCAPITypeConverter typeConverter;
|
||||
|
||||
mlir::RankedTensorType tensorResultTy =
|
||||
op.getResult().getType().template cast<mlir::RankedTensorType>();
|
||||
|
||||
mlir::Value outTensor = rewriter.create<mlir::linalg::InitTensorOp>(
|
||||
op.getLoc(), tensorResultTy.getShape(),
|
||||
tensorResultTy.getElementType());
|
||||
|
||||
mlir::Value outMemref =
|
||||
getCastedTensorOperand(rewriter, op.getLoc(), outTensor);
|
||||
|
||||
mlir::SmallVector<mlir::Value> castedOperands{outMemref};
|
||||
castedOperands.append(getCastedTensorOperands(rewriter, op));
|
||||
|
||||
mlir::func::CallOp callOp = rewriter.create<mlir::func::CallOp>(
|
||||
op.getLoc(), funcName, mlir::TypeRange{}, castedOperands);
|
||||
|
||||
// Convert remaining, non-tensor types (e.g., plaintext values)
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, callOp, [&](mlir::MLIRContext *context, mlir::Type t) {
|
||||
return typeConverter.convertType(t);
|
||||
});
|
||||
|
||||
mlir::Value updatedOutTensor =
|
||||
rewriter.create<mlir::bufferization::ToTensorOp>(op.getLoc(),
|
||||
outMemref);
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::tensor::CastOp>(op, tensorResultTy,
|
||||
updatedOutTensor);
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
|
||||
private:
|
||||
std::string funcName;
|
||||
};
|
||||
|
||||
struct ConcreteEncodeIntOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp> {
|
||||
ConcreteEncodeIntOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp>(
|
||||
context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::EncodeIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
{
|
||||
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
|
||||
op.getLoc(), rewriter.getIntegerType(64), op->getOperands().front());
|
||||
mlir::Value constantShiftOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI64IntegerAttr(64 - op.getType().getP()));
|
||||
|
||||
mlir::Type resultType = rewriter.getIntegerType(64);
|
||||
rewriter.replaceOpWithNewOp<mlir::arith::ShLIOp>(
|
||||
op, resultType, castedInt, constantShiftOp);
|
||||
}
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct ConcreteIntToCleartextOpPattern
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::IntToCleartextOp> {
|
||||
ConcreteIntToCleartextOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::Concrete::IntToCleartextOp>(
|
||||
context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::IntToCleartextOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(
|
||||
op, rewriter.getIntegerType(64), op->getOperands().front());
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
mlir::Value getContextArgument(mlir::Operation *op) {
|
||||
mlir::Block *block = op->getBlock();
|
||||
while (block != nullptr) {
|
||||
if (llvm::isa<mlir::func::FuncOp>(block->getParentOp())) {
|
||||
|
||||
mlir::Value context = block->getArguments().back();
|
||||
|
||||
assert(
|
||||
context.getType().isa<mlir::concretelang::Concrete::ContextType>() &&
|
||||
"the Concrete.context should be the last argument of the enclosing "
|
||||
"function of the op");
|
||||
|
||||
return context;
|
||||
}
|
||||
block = block->getParentOp()->getBlock();
|
||||
}
|
||||
assert("can't find a function that enclose the op");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Rewrite pattern that rewrite every
|
||||
// ```
|
||||
// %out = "BConcrete.keyswitch_lwe_buffer"(%out, %in) {...}:
|
||||
// (tensor<2049xi64>) -> (tensor<2049xi64>)
|
||||
// ```
|
||||
//
|
||||
// to
|
||||
//
|
||||
// ```
|
||||
// %out = linalg.tensor_init [B] : tensor<Bxui64>
|
||||
// %out_casted = tensor.cast %out : tensor<Axi64> to tensor<?xi64>
|
||||
// %out_memref = bufferize.to_memref %out_casted ...
|
||||
// %in_casted = tensor.cast %in : tensor<Axi64> to tensor<?xi64>
|
||||
// %in_memref = bufferize.to_memref ...
|
||||
// call @memref_keyswitch_lwe_u64(%out_memref, %in_memref) :
|
||||
// (tensor<?xui64>, !Concrete.context) -> (tensor<?xui64>)
|
||||
// ```
|
||||
struct BConcreteKeySwitchLweOpPattern
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::concretelang::BConcrete::KeySwitchLweBufferOp> {
|
||||
BConcreteKeySwitchLweOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<
|
||||
mlir::concretelang::BConcrete::KeySwitchLweBufferOp>(context,
|
||||
benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::BConcrete::KeySwitchLweBufferOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
// Create the output operand
|
||||
mlir::RankedTensorType tensorResultTy =
|
||||
op.getResult().getType().template cast<mlir::RankedTensorType>();
|
||||
mlir::Value outTensor =
|
||||
rewriter.replaceOpWithNewOp<mlir::linalg::InitTensorOp>(
|
||||
op, tensorResultTy.getShape(), tensorResultTy.getElementType());
|
||||
mlir::Value outMemref =
|
||||
getCastedTensorOperand(rewriter, op.getLoc(), outTensor);
|
||||
|
||||
mlir::SmallVector<mlir::Value> operands{outMemref};
|
||||
operands.append(getCastedTensorOperands(rewriter, op));
|
||||
operands.push_back(getContextArgument(op));
|
||||
|
||||
rewriter.create<mlir::func::CallOp>(op.getLoc(), "memref_keyswitch_lwe_u64",
|
||||
mlir::TypeRange({}), operands);
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
// Rewrite pattern that rewrite every
|
||||
// ```
|
||||
// %out = "BConcrete.bootstrap_lwe_buffer"(%in, %acc) {...} :
|
||||
// (tensor<Axui64>, !Concrete.glwe_ciphertext) -> (tensor<Bxui64>)
|
||||
// ```
|
||||
//
|
||||
// to
|
||||
//
|
||||
// ```
|
||||
// %out = linalg.tensor_init [B] : tensor<Bxui64>
|
||||
// %out_casted = tensor.cast %out : tensor<Axi64> to tensor<?xi64>
|
||||
// %out_memref = bufferize.to_memref %out_casted ...
|
||||
// %in_casted = tensor.cast %in : tensor<Axi64> to tensor<?xi64>
|
||||
// %in_memref = bufferize.to_memref ...
|
||||
// call @memref_bootstrap_lwe_u64(%out_memref, %in_memref, %acc_, %ctx) :
|
||||
// (memref<?xi64>, memref<?xi64>,
|
||||
// !Concrete.glwe_ciphertext, !Concrete.context) -> ()
|
||||
// ```
|
||||
struct BConcreteBootstrapLweOpPattern
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::concretelang::BConcrete::BootstrapLweBufferOp> {
|
||||
BConcreteBootstrapLweOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<
|
||||
mlir::concretelang::BConcrete::BootstrapLweBufferOp>(context,
|
||||
benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::BConcrete::BootstrapLweBufferOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
// Create the output operand
|
||||
mlir::RankedTensorType tensorResultTy =
|
||||
op.getResult().getType().template cast<mlir::RankedTensorType>();
|
||||
mlir::Value outTensor =
|
||||
rewriter.replaceOpWithNewOp<mlir::linalg::InitTensorOp>(
|
||||
op, tensorResultTy.getShape(), tensorResultTy.getElementType());
|
||||
mlir::Value outMemref =
|
||||
getCastedTensorOperand(rewriter, op.getLoc(), outTensor);
|
||||
|
||||
mlir::SmallVector<mlir::Value> operands{outMemref};
|
||||
operands.append(getCastedTensorOperands(rewriter, op));
|
||||
operands.push_back(getContextArgument(op));
|
||||
|
||||
rewriter.create<mlir::func::CallOp>(op.getLoc(), "memref_bootstrap_lwe_u64",
|
||||
mlir::TypeRange({}), operands);
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
// Rewrite pattern that rewrite every
|
||||
// ```
|
||||
// "BConcrete.fill_glwe_table"(%glwe, %lut) {glweDimension=1,
|
||||
// polynomialSize=2048, outPrecision=3} :
|
||||
// (tensor<4096xi64>, tensor<32xi64>) -> ()
|
||||
// ```
|
||||
//
|
||||
// to
|
||||
//
|
||||
// ```
|
||||
// %glweDim = arith.constant 1 : i32
|
||||
// %polySize = arith.constant 2048 : i32
|
||||
// %outPrecision = arith.constant 3 : i32
|
||||
// %glwe_ = tensor.cast %glwe : tensor<4096xi64> to tensor<?xi64>
|
||||
// %lut_ = tensor.cast %lut : tensor<32xi64> to tensor<?xi64>
|
||||
// call @expand_lut_in_trivial_glwe_ct(%glwe, %polySize, %glweDim,
|
||||
// %outPrecision, %lut_) :
|
||||
// (tensor<?xi64>, i32, i32, tensor<?xi64>) -> ()
|
||||
// ```
|
||||
struct BConcreteGlweFromTableOpPattern
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::concretelang::BConcrete::FillGlweFromTable> {
|
||||
BConcreteGlweFromTableOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<
|
||||
mlir::concretelang::BConcrete::FillGlweFromTable>(context,
|
||||
benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::BConcrete::FillGlweFromTable op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
BConcreteToBConcreteCAPITypeConverter typeConverter;
|
||||
// %glweDim = arith.constant 1 : i32
|
||||
// %polySize = arith.constant 2048 : i32
|
||||
// %outPrecision = arith.constant 3 : i32
|
||||
|
||||
auto castedOp = getCastedTensorOperands(rewriter, op);
|
||||
|
||||
auto polySizeOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(op.polynomialSize()));
|
||||
auto glweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(op.glweDimension()));
|
||||
auto outPrecisionOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(op.outPrecision()));
|
||||
|
||||
mlir::SmallVector<mlir::Value> newOperands{
|
||||
castedOp[0], polySizeOp, glweDimensionOp, outPrecisionOp, castedOp[1]};
|
||||
|
||||
// getCastedTensor(op.getLoc(), newOperands, rewriter);
|
||||
// perform operands conversion
|
||||
// %glwe_ = tensor.cast %glwe : tensor<4096xi64> to tensor<?xi64>
|
||||
// %lut_ = tensor.cast %lut : tensor<32xi64> to tensor<?xi64>
|
||||
|
||||
// call @expand_lut_in_trivial_glwe_ct(%glwe, %polySize, %glweDim,
|
||||
// %lut_) :
|
||||
// (tensor<?xi64>, i32, i32, tensor<?xi64>) -> ()
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
|
||||
op, "memref_expand_lut_in_trivial_glwe_ct_u64",
|
||||
mlir::SmallVector<mlir::Type>{}, newOperands);
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
/// Populate the RewritePatternSet with all patterns that rewrite Concrete
|
||||
/// operators to the corresponding function call to the `Concrete C API`.
|
||||
void populateBConcreteToBConcreteCAPICall(mlir::RewritePatternSet &patterns) {
|
||||
patterns.add<ConcreteOpToConcreteCAPICallPattern<
|
||||
mlir::concretelang::BConcrete::AddLweBuffersOp>>(
|
||||
patterns.getContext(), "memref_add_lwe_ciphertexts_u64");
|
||||
patterns.add<ConcreteOpToConcreteCAPICallPattern<
|
||||
mlir::concretelang::BConcrete::AddPlaintextLweBufferOp>>(
|
||||
patterns.getContext(), "memref_add_plaintext_lwe_ciphertext_u64");
|
||||
patterns.add<ConcreteOpToConcreteCAPICallPattern<
|
||||
mlir::concretelang::BConcrete::MulCleartextLweBufferOp>>(
|
||||
patterns.getContext(), "memref_mul_cleartext_lwe_ciphertext_u64");
|
||||
patterns.add<ConcreteOpToConcreteCAPICallPattern<
|
||||
mlir::concretelang::BConcrete::NegateLweBufferOp>>(
|
||||
patterns.getContext(), "memref_negate_lwe_ciphertext_u64");
|
||||
patterns.add<ConcreteEncodeIntOpPattern>(patterns.getContext());
|
||||
patterns.add<ConcreteIntToCleartextOpPattern>(patterns.getContext());
|
||||
patterns.add<BConcreteKeySwitchLweOpPattern>(patterns.getContext());
|
||||
patterns.add<BConcreteBootstrapLweOpPattern>(patterns.getContext());
|
||||
patterns.add<BConcreteGlweFromTableOpPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
struct AddRuntimeContextToFuncOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::func::FuncOp> {
|
||||
AddRuntimeContextToFuncOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<mlir::func::FuncOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::func::FuncOp oldFuncOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::OpBuilder::InsertionGuard guard(rewriter);
|
||||
mlir::FunctionType oldFuncType = oldFuncOp.getFunctionType();
|
||||
|
||||
// Add a Concrete.context to the function signature
|
||||
mlir::SmallVector<mlir::Type> newInputs(oldFuncType.getInputs().begin(),
|
||||
oldFuncType.getInputs().end());
|
||||
newInputs.push_back(
|
||||
rewriter.getType<mlir::concretelang::Concrete::ContextType>());
|
||||
mlir::FunctionType newFuncTy = rewriter.getType<mlir::FunctionType>(
|
||||
newInputs, oldFuncType.getResults());
|
||||
// Create the new func
|
||||
mlir::func::FuncOp newFuncOp = rewriter.create<mlir::func::FuncOp>(
|
||||
oldFuncOp.getLoc(), oldFuncOp.getName(), newFuncTy);
|
||||
|
||||
// Create the arguments of the new func
|
||||
mlir::Region &newFuncBody = newFuncOp.getBody();
|
||||
mlir::Block *newFuncEntryBlock = new mlir::Block();
|
||||
llvm::SmallVector<mlir::Location> locations(newFuncTy.getInputs().size(),
|
||||
oldFuncOp.getLoc());
|
||||
|
||||
newFuncEntryBlock->addArguments(newFuncTy.getInputs(), locations);
|
||||
newFuncBody.push_back(newFuncEntryBlock);
|
||||
|
||||
// Clone the old body to the new one
|
||||
mlir::BlockAndValueMapping map;
|
||||
for (auto arg : llvm::enumerate(oldFuncOp.getArguments())) {
|
||||
map.map(arg.value(), newFuncEntryBlock->getArgument(arg.index()));
|
||||
}
|
||||
for (auto &op : oldFuncOp.getBody().front()) {
|
||||
newFuncEntryBlock->push_back(op.clone(map));
|
||||
}
|
||||
rewriter.eraseOp(oldFuncOp);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Legal function are one that are private or has a Concrete.context as last
|
||||
// arguments.
|
||||
static bool isLegal(mlir::func::FuncOp funcOp) {
|
||||
if (!funcOp.isPublic()) {
|
||||
return true;
|
||||
}
|
||||
// TODO : Don't need to add a runtime context for function that doesn't
|
||||
// manipulates Concrete types.
|
||||
//
|
||||
// if (!llvm::any_of(funcOp.getType().getInputs(), [](mlir::Type t) {
|
||||
// if (auto tensorTy = t.dyn_cast_or_null<mlir::TensorType>()) {
|
||||
// t = tensorTy.getElementType();
|
||||
// }
|
||||
// return llvm::isa<mlir::concretelang::Concrete::ConcreteDialect>(
|
||||
// t.getDialect());
|
||||
// })) {
|
||||
// return true;
|
||||
// }
|
||||
return funcOp.getFunctionType().getNumInputs() >= 1 &&
|
||||
funcOp.getFunctionType()
|
||||
.getInputs()
|
||||
.back()
|
||||
.isa<mlir::concretelang::Concrete::ContextType>();
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
struct BConcreteToBConcreteCAPIPass
|
||||
: public BConcreteToBConcreteCAPIBase<BConcreteToBConcreteCAPIPass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void BConcreteToBConcreteCAPIPass::runOnOperation() {
|
||||
mlir::ModuleOp op = getOperation();
|
||||
|
||||
// First of all add the Concrete.context to the block arguments of function
|
||||
// that manipulates ciphertexts.
|
||||
{
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
target.addDynamicallyLegalOp<mlir::func::FuncOp>(
|
||||
[&](mlir::func::FuncOp funcOp) {
|
||||
return AddRuntimeContextToFuncOpPattern::isLegal(funcOp);
|
||||
});
|
||||
|
||||
patterns.add<AddRuntimeContextToFuncOpPattern>(patterns.getContext());
|
||||
|
||||
// Apply the conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
this->signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Insert forward declaration
|
||||
mlir::IRRewriter rewriter(&getContext());
|
||||
if (insertForwardDeclarations(op, rewriter).failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
// Rewrite Concrete ops to CallOp to the Concrete C API
|
||||
{
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
target.addIllegalDialect<mlir::concretelang::BConcrete::BConcreteDialect>();
|
||||
|
||||
target.addLegalDialect<mlir::BuiltinDialect, mlir::func::FuncDialect,
|
||||
mlir::tensor::TensorDialect,
|
||||
mlir::arith::ArithmeticDialect>();
|
||||
|
||||
target.addLegalOp<mlir::linalg::InitTensorOp>();
|
||||
target.addLegalOp<mlir::bufferization::ToMemrefOp>();
|
||||
target.addLegalOp<mlir::bufferization::ToTensorOp>();
|
||||
|
||||
populateBConcreteToBConcreteCAPICall(patterns);
|
||||
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertBConcreteToBConcreteCAPIPass() {
|
||||
return std::make_unique<BConcreteToBConcreteCAPIPass>();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -1,17 +0,0 @@
|
||||
add_mlir_dialect_library(BConcreteToBConcreteCAPI
|
||||
BConcreteToBConcreteCAPI.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
|
||||
|
||||
DEPENDS
|
||||
BConcreteDialect
|
||||
mlir-headers
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
ConcretelangConversion
|
||||
MLIRIR
|
||||
MLIRTransforms
|
||||
)
|
||||
|
||||
target_link_libraries(BConcreteToBConcreteCAPI PUBLIC MLIRIR)
|
||||
@@ -3,7 +3,6 @@ add_subdirectory(TFHEGlobalParametrization)
|
||||
add_subdirectory(TFHEToConcrete)
|
||||
add_subdirectory(FHETensorOpsToLinalg)
|
||||
add_subdirectory(ConcreteToBConcrete)
|
||||
add_subdirectory(BConcreteToBConcreteCAPI)
|
||||
add_subdirectory(MLIRLowerableDialectsToLLVM)
|
||||
add_subdirectory(LinalgExtras)
|
||||
|
||||
|
||||
@@ -154,11 +154,7 @@ llvm::Optional<mlir::Type>
|
||||
MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) {
|
||||
if (type.isa<mlir::concretelang::Concrete::LweCiphertextType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::GlweCiphertextType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::LweKeySwitchKeyType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::LweBootstrapKeyType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::ContextType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::ForeignPlaintextListType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::PlaintextListType>() ||
|
||||
type.isa<mlir::concretelang::RT::FutureType>()) {
|
||||
return mlir::LLVM::LLVMPointerType::get(
|
||||
mlir::IntegerType::get(type.getContext(), 64));
|
||||
|
||||
@@ -152,62 +152,3 @@ mlir::Type PlaintextType::parse(mlir::AsmParser &parser) {
|
||||
|
||||
return getChecked(loc, loc.getContext(), p);
|
||||
}
|
||||
|
||||
mlir::Type PlaintextListType::parse(mlir::AsmParser &parser) {
|
||||
return get(parser.getContext());
|
||||
}
|
||||
|
||||
void PlaintextListType::print(mlir::AsmPrinter &p) const {}
|
||||
|
||||
mlir::Type ForeignPlaintextListType::parse(mlir::AsmParser &parser) {
|
||||
return get(parser.getContext());
|
||||
}
|
||||
|
||||
void ForeignPlaintextListType::print(mlir::AsmPrinter &p) const {}
|
||||
|
||||
mlir::Type LweKeySwitchKeyType::parse(mlir::AsmParser &parser) {
|
||||
return get(parser.getContext());
|
||||
}
|
||||
|
||||
void LweKeySwitchKeyType::print(mlir::AsmPrinter &p) const {}
|
||||
|
||||
mlir::Type LweBootstrapKeyType::parse(mlir::AsmParser &parser) {
|
||||
return get(parser.getContext());
|
||||
}
|
||||
|
||||
void LweBootstrapKeyType::print(mlir::AsmPrinter &p) const {}
|
||||
|
||||
void ContextType::print(mlir::AsmPrinter &p) const {}
|
||||
|
||||
mlir::Type ContextType::parse(mlir::AsmParser &parser) {
|
||||
return get(parser.getContext());
|
||||
}
|
||||
|
||||
::mlir::Type
|
||||
ConcreteDialect::parseType(::mlir::DialectAsmParser &parser) const {
|
||||
mlir::Type type;
|
||||
|
||||
std::string types_str[] = {
|
||||
"plaintext", "plaintext_list", "foreign_plaintext_list",
|
||||
"lwe_ciphertext", "lwe_key_switch_key", "lwe_bootstrap_key",
|
||||
"glwe_ciphertext", "cleartext", "context",
|
||||
};
|
||||
|
||||
for (const std::string &type_str : types_str) {
|
||||
if (parser.parseOptionalKeyword(type_str).succeeded()) {
|
||||
generatedTypeParser(parser, type_str, type);
|
||||
return type;
|
||||
}
|
||||
}
|
||||
|
||||
parser.emitError(parser.getCurrentLocation(), "Unknown Concrete type");
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
void ConcreteDialect::printType(::mlir::Type type,
|
||||
::mlir::DialectAsmPrinter &printer) const {
|
||||
if (generatedTypePrinter(type, printer).failed())
|
||||
// Calling default printer if failed to print Concrete type
|
||||
printer.printType(type);
|
||||
}
|
||||
|
||||
@@ -166,11 +166,7 @@ static mlir::Value getSizeInBytes(Value val, Location loc, OpBuilder builder) {
|
||||
// bytes until we can get the actual size of the actual types.
|
||||
if (type.isa<mlir::concretelang::Concrete::ContextType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::LweCiphertextType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::GlweCiphertextType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::LweKeySwitchKeyType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::LweBootstrapKeyType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::ForeignPlaintextListType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::PlaintextListType>())
|
||||
type.isa<mlir::concretelang::Concrete::GlweCiphertextType>())
|
||||
return builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
|
||||
|
||||
// For all other types, get type size.
|
||||
|
||||
@@ -7,36 +7,12 @@ func @type_plaintext(%arg0: !Concrete.plaintext<7>) -> !Concrete.plaintext<7> {
|
||||
return %arg0: !Concrete.plaintext<7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @type_plaintext_list(%arg0: !Concrete.plaintext_list) -> !Concrete.plaintext_list
|
||||
func @type_plaintext_list(%arg0: !Concrete.plaintext_list) -> !Concrete.plaintext_list {
|
||||
// CHECK-NEXT: return %arg0 : !Concrete.plaintext_list
|
||||
return %arg0: !Concrete.plaintext_list
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @type_foreign_plaintext_list(%arg0: !Concrete.foreign_plaintext_list) -> !Concrete.foreign_plaintext_list
|
||||
func @type_foreign_plaintext_list(%arg0: !Concrete.foreign_plaintext_list) -> !Concrete.foreign_plaintext_list {
|
||||
// CHECK-NEXT: return %arg0 : !Concrete.foreign_plaintext_list
|
||||
return %arg0: !Concrete.foreign_plaintext_list
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @type_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func @type_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: return %arg0 : !Concrete.lwe_ciphertext<2048,7>
|
||||
return %arg0: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @type_lwe_key_switch_key(%arg0: !Concrete.lwe_key_switch_key) -> !Concrete.lwe_key_switch_key
|
||||
func @type_lwe_key_switch_key(%arg0: !Concrete.lwe_key_switch_key) -> !Concrete.lwe_key_switch_key {
|
||||
// CHECK-NEXT: return %arg0 : !Concrete.lwe_key_switch_key
|
||||
return %arg0: !Concrete.lwe_key_switch_key
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @type_lwe_bootstrap_key(%arg0: !Concrete.lwe_bootstrap_key) -> !Concrete.lwe_bootstrap_key
|
||||
func @type_lwe_bootstrap_key(%arg0: !Concrete.lwe_bootstrap_key) -> !Concrete.lwe_bootstrap_key {
|
||||
// CHECK-NEXT: return %arg0 : !Concrete.lwe_bootstrap_key
|
||||
return %arg0: !Concrete.lwe_bootstrap_key
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5>
|
||||
func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5> {
|
||||
// CHECK-NEXT: return %arg0 : !Concrete.cleartext<5>
|
||||
|
||||
Reference in New Issue
Block a user