refactor: ins forward decl w generic ty @pass-init

Insert forward declarations with generic types at pass initialization.
More docs for all the pass for lowering LUT
This commit is contained in:
youben11
2021-09-08 15:04:10 +01:00
committed by Quentin Bourgerie
parent d97512f507
commit 746d991af6
13 changed files with 513 additions and 414 deletions

View File

@@ -169,6 +169,47 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter rewriter,
return op.getODSResults(0).front();
}
// This is the rewritting of the HLFHE::ApplyLookupTable operation, it will be
// rewritten as 3 new operations:
// - Create the required GLWE ciphertext out of the plain lookup table
// - Keyswitch the input ciphertext to match the input key of the bootstrapping
// - Bootstrap the keyswitched ciphertext with the constructed GLWE ciphertext
// Example:
// from:
// ```
// "%result = MidLFHE.apply_lookup_table"(% arg0, % tlu){
// k = 1 : i32,
// polynomialSize = 2048 : i32,
// levelKS = 3 : i32,
// baseLogKS = 2 : i32,
// levelBS = 5 : i32,
// baseLogBS = 4 : i32,
// outputSizeKS = 600 : i32
// } : (!MidLFHE.glwe<{2048, 1, 64} {4}>, tensor<16xi4>)
// ->(!MidLFHE.glwe<{2048, 1, 64} {4}>)
// ```
// to:
// ```
// % accumulator =
// "LowLFHE.glwe_from_table"(
// % [[TABLE]]){k = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32}
// : (tensor<16xi4>)
// ->!LowLFHE.glwe_ciphertext
// % keyswitched = "LowLFHE.keyswitch_lwe"(% arg0){
// baseLog = 2 : i32,
// inputLweSize = 1 : i32,
// level = 3 : i32,
// outputLweSize = 600 : i32
// } : (!LowLFHE.lwe_ciphertext<2048, 4>)
// ->!LowLFHE.lwe_ciphertext<600, 4>
// % result = "LowLFHE.bootstrap_lwe"(% keyswitched, % accumulator){
// baseLog = 4 : i32,
// k = 1 : i32,
// level = 5 : i32,
// polynomialSize = 2048 : i32
// } : (!LowLFHE.lwe_ciphertext<600, 4>, !LowLFHE.glwe_ciphertext)
// ->!LowLFHE.lwe_ciphertext<2048, 4>
// ```
mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
mlir::Value ct, mlir::Value table, mlir::IntegerAttr k,
mlir::IntegerAttr polynomialSize,
@@ -178,10 +219,9 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
// convert result type
GLWECipherTextType glwe_type = result.getType().cast<GLWECipherTextType>();
LweCiphertextType lwe_type =
convertTypeGLWEToLWE(rewriter.getContext(), glwe_type);
convertTypeToLWE(rewriter.getContext(), glwe_type);
// fill the the table in the GLWE accumulator
mlir::IntegerAttr precision = mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32), glwe_type.getP());
mlir::IntegerAttr precision = rewriter.getI32IntegerAttr(glwe_type.getP());
mlir::Value accumulator =
rewriter
.create<mlir::zamalang::LowLFHE::GlweFromTable>(
@@ -191,8 +231,8 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
// keyswitch
auto ct_type = ct.getType().cast<GLWECipherTextType>();
mlir::SmallVector<mlir::Value, 1> ksArgs{ct};
mlir::SmallVector<mlir::NamedAttribute, 6> ksAttrs{
mlir::SmallVector<mlir::Value> ksArgs{ct};
mlir::SmallVector<mlir::NamedAttribute> ksAttrs{
mlir::NamedAttribute(
mlir::Identifier::get("inputLweSize", rewriter.getContext()), k),
mlir::NamedAttribute(
@@ -203,16 +243,17 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
mlir::NamedAttribute(
mlir::Identifier::get("baseLog", rewriter.getContext()), baseLogKS),
};
auto ksOutType = LweCiphertextType::get(
rewriter.getContext(), outputSizeKS.getInt(), ct_type.getP());
mlir::Value keyswitched =
rewriter
.create<mlir::zamalang::LowLFHE::KeySwitchLweOp>(
loc, convertTypeGLWEToLWE(rewriter.getContext(), ct_type), ksArgs,
ksAttrs)
.create<mlir::zamalang::LowLFHE::KeySwitchLweOp>(loc, ksOutType,
ksArgs, ksAttrs)
.result();
// bootstrap operation
mlir::SmallVector<mlir::Value, 2> bsArgs{keyswitched, accumulator};
mlir::SmallVector<mlir::NamedAttribute, 6> bsAttrs{
mlir::SmallVector<mlir::Value> bsArgs{keyswitched, accumulator};
mlir::SmallVector<mlir::NamedAttribute> bsAttrs{
mlir::NamedAttribute(mlir::Identifier::get("k", rewriter.getContext()),
k),
mlir::NamedAttribute(

View File

@@ -41,7 +41,7 @@ public:
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
void generateRuntimeContext() {
void initGlobalRuntimeContext() {
auto ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]);
auto bsk = std::get<1>(this->bootstrapKeys["bsk_v0"]);
setGlobalRuntimeContext(createRuntimeContext(ksk, bsk));

View File

@@ -26,7 +26,7 @@ public:
};
mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
mlir::PatternRewriter &rewriter,
mlir::RewriterBase &rewriter,
llvm::StringRef funcName,
mlir::FunctionType funcType) {
// Looking for the `funcName` Operation
@@ -54,6 +54,271 @@ mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
return mlir::success();
}
// Set of functions to generate generic types.
// Generic types are used to add forward declarations without a specific type.
// For example, we may need to add LWE ciphertext of different dimensions, or
// allocate them. All the calls to the C API should be done using this generic
// types, and casting should then be performed back to the appropriate type.
inline mlir::zamalang::LowLFHE::LweCiphertextType
getGenericLweCiphertextType(mlir::MLIRContext *context) {
return mlir::zamalang::LowLFHE::LweCiphertextType::get(context, -1, -1);
}
inline mlir::zamalang::LowLFHE::GlweCiphertextType
getGenericGlweCiphertextType(mlir::MLIRContext *context) {
return mlir::zamalang::LowLFHE::GlweCiphertextType::get(context);
}
inline mlir::zamalang::LowLFHE::PlaintextType
getGenericPlaintextType(mlir::MLIRContext *context) {
return mlir::zamalang::LowLFHE::PlaintextType::get(context, -1);
}
inline mlir::zamalang::LowLFHE::PlaintextListType
getGenericPlaintextListType(mlir::MLIRContext *context) {
return mlir::zamalang::LowLFHE::PlaintextListType::get(context);
}
inline mlir::zamalang::LowLFHE::ForeignPlaintextListType
getGenericForeignPlaintextListType(mlir::MLIRContext *context) {
return mlir::zamalang::LowLFHE::ForeignPlaintextListType::get(context);
}
inline mlir::zamalang::LowLFHE::CleartextType
getGenericCleartextType(mlir::MLIRContext *context) {
return mlir::zamalang::LowLFHE::CleartextType::get(context, -1);
}
inline mlir::zamalang::LowLFHE::LweBootstrapKeyType
getGenericLweBootstrapKeyType(mlir::MLIRContext *context) {
return mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(context);
}
inline mlir::zamalang::LowLFHE::LweKeySwitchKeyType
getGenericLweKeySwitchKeyType(mlir::MLIRContext *context) {
return mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(context);
}
// Get the generic version of the type.
// Useful when iterating over a set of types.
mlir::Type getGenericType(mlir::Type baseType) {
if (baseType.isa<mlir::zamalang::LowLFHE::LweCiphertextType>()) {
return getGenericLweCiphertextType(baseType.getContext());
}
if (baseType.isa<mlir::zamalang::LowLFHE::PlaintextType>()) {
return getGenericPlaintextType(baseType.getContext());
}
if (baseType.isa<mlir::zamalang::LowLFHE::CleartextType>()) {
return getGenericCleartextType(baseType.getContext());
}
return baseType;
}
// Insert all forward declarations needed for the pass.
// Should generalize input and output types for all decalarations, and the
// pattern using them would be resposible for casting them to the appropriate
// type.
mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
mlir::IRRewriter &rewriter) {
auto genericLweCiphertextType =
getGenericLweCiphertextType(rewriter.getContext());
auto genericGlweCiphertextType =
getGenericGlweCiphertextType(rewriter.getContext());
auto genericPlaintextType = getGenericPlaintextType(rewriter.getContext());
auto genericPlaintextListType =
getGenericPlaintextListType(rewriter.getContext());
auto genericForeignPlaintextList =
getGenericForeignPlaintextListType(rewriter.getContext());
auto genericCleartextType = getGenericCleartextType(rewriter.getContext());
auto genericBSKType = getGenericLweBootstrapKeyType(rewriter.getContext());
auto genericKSKType = getGenericLweKeySwitchKeyType(rewriter.getContext());
auto errType = mlir::IndexType::get(rewriter.getContext());
// Insert forward declaration of allocate lwe ciphertext
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{
errType,
mlir::IntegerType::get(rewriter.getContext(), 32),
},
{genericLweCiphertextType});
if (insertForwardDeclaration(op, rewriter, "allocate_lwe_ciphertext_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the add_lwe_ciphertexts function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
errType,
genericLweCiphertextType,
genericLweCiphertextType,
genericLweCiphertextType,
},
{});
if (insertForwardDeclaration(op, rewriter, "add_lwe_ciphertexts_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the add_plaintext_lwe_ciphertext_u64 function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
errType,
genericLweCiphertextType,
genericLweCiphertextType,
genericPlaintextType,
},
{});
if (insertForwardDeclaration(op, rewriter,
"add_plaintext_lwe_ciphertext_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the mul_cleartext_lwe_ciphertext_u64 function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
errType,
genericLweCiphertextType,
genericLweCiphertextType,
genericCleartextType,
},
{});
if (insertForwardDeclaration(op, rewriter,
"mul_cleartext_lwe_ciphertext_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the negate_lwe_ciphertext_u64 function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{errType, genericLweCiphertextType, genericLweCiphertextType}, {});
if (insertForwardDeclaration(op, rewriter, "negate_lwe_ciphertext_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the getBsk function
{
auto funcType =
mlir::FunctionType::get(rewriter.getContext(), {}, {genericBSKType});
if (insertForwardDeclaration(op, rewriter, "getGlobalBootstrapKey",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the bootstrap function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
errType,
genericBSKType,
genericLweCiphertextType,
genericLweCiphertextType,
genericGlweCiphertextType,
},
{});
if (insertForwardDeclaration(op, rewriter, "bootstrap_lwe_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the getKsk function
{
auto funcType =
mlir::FunctionType::get(rewriter.getContext(), {}, {genericKSKType});
if (insertForwardDeclaration(op, rewriter, "getGlobalKeyswitchKey",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the keyswitch function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{
errType,
// ksk
genericKSKType,
// output ct
genericLweCiphertextType,
// input ct
genericLweCiphertextType,
},
{});
if (insertForwardDeclaration(op, rewriter, "keyswitch_lwe_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the alloc_glwe function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{
errType,
mlir::IntegerType::get(rewriter.getContext(), 32),
mlir::IntegerType::get(rewriter.getContext(), 32),
},
{genericGlweCiphertextType});
if (insertForwardDeclaration(op, rewriter, "allocate_glwe_ciphertext_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the alloc_plaintext_list function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{errType, mlir::IntegerType::get(rewriter.getContext(), 32)},
{genericPlaintextListType});
if (insertForwardDeclaration(op, rewriter, "allocate_plaintext_list_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the fill_plaintext_list function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{errType, genericPlaintextListType, genericForeignPlaintextList}, {});
if (insertForwardDeclaration(
op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the add_plaintext_list_glwe function
{
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
{errType, genericGlweCiphertextType,
genericGlweCiphertextType,
genericPlaintextListType},
{});
if (insertForwardDeclaration(
op, rewriter, "add_plaintext_list_glwe_ciphertext_u64", funcType)
.failed()) {
return mlir::failure();
}
}
return mlir::success();
}
/// LowLFHEOpToConcreteCAPICallPattern<Op> match the `Op` Operation and
/// replace with a call to `funcName`, the funcName should be an external
/// function that was linked later. It insert the forward declaration of the
@@ -81,29 +346,7 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
LowLFHEToConcreteCAPITypeConverter typeConverter;
auto errType = mlir::IndexType::get(rewriter.getContext());
// Insert forward declaration of the operator function
{
mlir::SmallVector<mlir::Type, 4> operands{errType,
op->getResultTypes().front()};
for (auto ty : op->getOperandTypes()) {
operands.push_back(typeConverter.convertType(ty));
}
auto funcType =
mlir::FunctionType::get(rewriter.getContext(), operands, {});
if (insertForwardDeclaration(op, rewriter, funcName, funcType).failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the alloc function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {errType, rewriter.getIndexType()},
{op->getResultTypes().front()});
if (insertForwardDeclaration(op, rewriter, allocName, funcType)
.failed()) {
return mlir::failure();
}
}
mlir::Type resultType = op->getResultTypes().front();
auto lweResultType =
resultType.cast<mlir::zamalang::LowLFHE::LweCiphertextType>();
@@ -114,18 +357,39 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
rewriter.getIndexAttr(0));
// Add the call to the allocation
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(lweResultType.getSize()));
op.getLoc(), rewriter.getI32IntegerAttr(lweResultType.getSize()));
mlir::SmallVector<mlir::Value> allocOperands{errOp, lweSizeOp};
auto alloc = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, allocName, op.getType(), allocOperands);
// Add err and allocated value to operands
mlir::SmallVector<mlir::Value, 4> newOperands{errOp, alloc.getResult(0)};
for (auto operand : op->getOperands()) {
newOperands.push_back(operand);
auto allocGeneric = rewriter.create<mlir::CallOp>(
op.getLoc(), allocName,
getGenericLweCiphertextType(rewriter.getContext()), allocOperands);
// Construct operands for the operation.
// errOp doesn't need to be casted to something generic, allocGeneric
// already is. All the rest will be converted if needed
mlir::SmallVector<mlir::Value, 4> newOperands{errOp,
allocGeneric.getResult(0)};
for (mlir::Value operand : op->getOperands()) {
mlir::Type operandType = operand.getType();
mlir::Type castedType = getGenericType(operandType);
if (castedType == operandType) {
// Type didn't change, no need for cast
newOperands.push_back(operand);
} else {
// Type changed, need to cast to the generic one
auto castedOperand = rewriter
.create<mlir::UnrealizedConversionCastOp>(
op.getLoc(), castedType, operand)
.getResult(0);
newOperands.push_back(castedOperand);
}
}
// The operations called here are known to be inplace, and no need for a
// return type.
rewriter.create<mlir::CallOp>(op.getLoc(), funcName, mlir::TypeRange{},
newOperands);
// cast result value to the appropriate type
auto alloc =
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
op, op.getType(), allocGeneric.getResult(0));
}
return mlir::success();
};
@@ -145,32 +409,24 @@ struct LowLFHEZeroOpPattern
mlir::LogicalResult
matchAndRewrite(mlir::zamalang::LowLFHE::ZeroLWEOp op,
mlir::PatternRewriter &rewriter) const override {
auto allocName = "allocate_lwe_ciphertext_u64";
auto errType = mlir::IndexType::get(rewriter.getContext());
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {errType, rewriter.getIndexType()},
{op->getResultTypes().front()});
if (insertForwardDeclaration(op, rewriter, allocName, funcType)
.failed()) {
return mlir::failure();
}
}
// Replace the operation with a call to the `funcName`
{
mlir::Type resultType = op->getResultTypes().front();
auto lweResultType =
resultType.cast<mlir::zamalang::LowLFHE::LweCiphertextType>();
// Create the err value
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
// Add the call to the allocation
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(lweResultType.getSize()));
mlir::SmallVector<mlir::Value> allocOperands{errOp, lweSizeOp};
auto alloc = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, allocName, op.getType(), allocOperands);
}
mlir::Type resultType = op->getResultTypes().front();
auto lweResultType =
resultType.cast<mlir::zamalang::LowLFHE::LweCiphertextType>();
// Create the err value
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
// Allocate a fresh new ciphertext
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(lweResultType.getSize()));
mlir::SmallVector<mlir::Value> allocOperands{errOp, lweSizeOp};
auto allocGeneric = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_lwe_ciphertext_u64",
getGenericLweCiphertextType(rewriter.getContext()), allocOperands);
// Cast the result to the appropriate type
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
op, op.getType(), allocGeneric.getResult(0));
return mlir::success();
};
};
@@ -215,6 +471,14 @@ struct LowLFHEIntToCleartextOpPattern
};
};
// Rewrite the GlweFromTable operation to a series of ops:
// - allocation of two GLWE, one for the addition, and one for storing the
// result
// - allocation of plaintext_list to build the GLWE accumulator
// - build the foreign_plaintext_list using the input table
// - fill the plaintext_list with the foreign_plaintext_list
// - construct the GLWE accumulator by adding the plaintext_list to a freshly
// allocated GLWE
struct GlweFromTableOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::GlweFromTable> {
GlweFromTableOpPattern(mlir::MLIRContext *context,
@@ -227,49 +491,19 @@ struct GlweFromTableOpPattern
mlir::PatternRewriter &rewriter) const override {
LowLFHEToConcreteCAPITypeConverter typeConverter;
auto errType = mlir::IndexType::get(rewriter.getContext());
// Insert forward declaration of the alloc_glwe function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{
errType,
mlir::IntegerType::get(rewriter.getContext(), 32),
mlir::IntegerType::get(rewriter.getContext(), 32),
},
{mlir::zamalang::LowLFHE::GlweCiphertextType::get(
rewriter.getContext())});
if (insertForwardDeclaration(op, rewriter, "allocate_glwe_ciphertext_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the alloc_plaintext_list function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{errType, mlir::IntegerType::get(rewriter.getContext(), 32)},
{mlir::zamalang::LowLFHE::PlaintextListType::get(
rewriter.getContext())});
if (insertForwardDeclaration(op, rewriter, "allocate_plaintext_list_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// TODO: move this to insertForwardDeclarations
// issue: can't define function with tensor<*xtype> that accept ranked
// tensors
// Insert forward declaration of the foregin_pt_list function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{errType,
// mlir::UnrankedTensorType::get(
// mlir::IntegerType::get(rewriter.getContext(), 64)),
op->getOperandTypes().front(),
{errType, op->getOperandTypes().front(),
mlir::IntegerType::get(rewriter.getContext(), 64),
mlir::IntegerType::get(rewriter.getContext(), 32)},
{mlir::zamalang::LowLFHE::ForeignPlaintextListType::get(
rewriter.getContext())});
{getGenericForeignPlaintextListType(rewriter.getContext())});
if (insertForwardDeclaration(
op, rewriter, "runtime_foreign_plaintext_list_u64", funcType)
.failed()) {
@@ -277,41 +511,6 @@ struct GlweFromTableOpPattern
}
}
// Insert forward declaration of the fill_plaintext_list function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{errType,
mlir::zamalang::LowLFHE::PlaintextListType::get(
rewriter.getContext()),
mlir::zamalang::LowLFHE::ForeignPlaintextListType::get(
rewriter.getContext())},
{});
if (insertForwardDeclaration(
op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the add_plaintext_list_glwe function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{errType,
mlir::zamalang::LowLFHE::GlweCiphertextType::get(
rewriter.getContext()),
mlir::zamalang::LowLFHE::GlweCiphertextType::get(
rewriter.getContext()),
mlir::zamalang::LowLFHE::PlaintextListType::get(
rewriter.getContext())},
{});
if (insertForwardDeclaration(
op, rewriter, "add_plaintext_list_glwe_ciphertext_u64", funcType)
.failed()) {
return mlir::failure();
}
}
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
// allocate two glwe to build accumulator
@@ -324,39 +523,33 @@ struct GlweFromTableOpPattern
// first accumulator would replace the op since it's the returned value
auto accumulatorOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, "allocate_glwe_ciphertext_u64",
mlir::zamalang::LowLFHE::GlweCiphertextType::get(rewriter.getContext()),
allocGlweOperands);
getGenericGlweCiphertextType(rewriter.getContext()), allocGlweOperands);
// second accumulator is just needed to build the actual accumulator
auto _accumulatorOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_glwe_ciphertext_u64",
mlir::zamalang::LowLFHE::GlweCiphertextType::get(rewriter.getContext()),
allocGlweOperands);
getGenericGlweCiphertextType(rewriter.getContext()), allocGlweOperands);
// allocate plaintext list
mlir::SmallVector<mlir::Value> allocPlaintextListOperands{errOp,
polySizeOp};
auto plaintextListOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_plaintext_list_u64",
mlir::zamalang::LowLFHE::PlaintextListType::get(rewriter.getContext()),
getGenericPlaintextListType(rewriter.getContext()),
allocPlaintextListOperands);
// create foreign plaintext
auto rankedTensorType =
op->getOperandTypes().front().cast<mlir::RankedTensorType>();
if (rankedTensorType.getRank() != 1) {
llvm::errs() << "table lookup must be of a single dimension";
return mlir::failure();
}
assert(rankedTensorType.getRank() == 1 &&
"table lookup must be of a single dimension");
auto sizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(), rewriter.getIntegerAttr(
mlir::IntegerType::get(rewriter.getContext(), 64),
rankedTensorType.getDimSize(0)));
op.getLoc(),
rewriter.getI64IntegerAttr(rankedTensorType.getDimSize(0)));
auto precisionOp =
rewriter.create<mlir::ConstantOp>(op.getLoc(), op->getAttr("p"));
mlir::SmallVector<mlir::Value> ForeignPlaintextListOperands{
errOp, op->getOperand(0), sizeOp, precisionOp};
auto foreignPlaintextListOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "runtime_foreign_plaintext_list_u64",
mlir::zamalang::LowLFHE::ForeignPlaintextListType::get(
rewriter.getContext()),
getGenericForeignPlaintextListType(rewriter.getContext()),
ForeignPlaintextListOperands);
// fill plaintext list
mlir::SmallVector<mlir::Value> FillPlaintextListOperands{
@@ -376,6 +569,11 @@ struct GlweFromTableOpPattern
};
};
// Rewrite a BootstrapLweOp with a series of ops:
// - allocate the result LWE ciphertext
// - get the global bootstrapping key
// - use the key and the input accumulator (GLWE) to bootstrap the input
// ciphertext
struct LowLFHEBootstrapLweOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::BootstrapLweOp> {
LowLFHEBootstrapLweOpPattern(mlir::MLIRContext *context,
@@ -386,141 +584,58 @@ struct LowLFHEBootstrapLweOpPattern
mlir::LogicalResult
matchAndRewrite(mlir::zamalang::LowLFHE::BootstrapLweOp op,
mlir::PatternRewriter &rewriter) const override {
auto errType = mlir::IndexType::get(rewriter.getContext());
auto lweOperandType = op->getOperandTypes().front();
// Insert forward declaration of the allocate_bsk_key function
// {
// auto funcType = mlir::FunctionType::get(
// rewriter.getContext(),
// {
// errType,
// // level
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // baselog
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // glwe size
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // lwe size
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // polynomial size
// mlir::IntegerType::get(rewriter.getContext(), 32),
// },
// {mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
// rewriter.getContext())});
// if (insertForwardDeclaration(op, rewriter,
// "allocate_lwe_bootstrap_key_u64",
// funcType)
// .failed()) {
// return mlir::failure();
// }
// }
// Insert forward declaration of the getBsk function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {},
{mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
rewriter.getContext())});
if (insertForwardDeclaration(op, rewriter, "getGlobalBootstrapKey",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the allocate_lwe_ct function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{
errType,
mlir::IntegerType::get(rewriter.getContext(), 32),
},
{lweOperandType});
if (insertForwardDeclaration(op, rewriter, "allocate_lwe_ciphertext_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the bootstrap function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{
errType,
mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
rewriter.getContext()),
lweOperandType,
lweOperandType,
mlir::zamalang::LowLFHE::GlweCiphertextType::get(
rewriter.getContext()),
},
{});
if (insertForwardDeclaration(op, rewriter, "bootstrap_lwe_u64", funcType)
.failed()) {
return mlir::failure();
}
}
auto resultType = op->getResultTypes().front();
auto bstOutputSize =
resultType.cast<mlir::zamalang::LowLFHE::LweCiphertextType>().getSize();
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
// allocate the result lwe ciphertext
// TODO: use right value for output lwe size
// LweSize output_lwe_size = { (glwe_size._0 -1) * poly_size._0 + 1}
// allocate the result lwe ciphertext, should be of a generic type, to cast
// before return
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(), mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32),
op->getAttr("k").cast<mlir::IntegerAttr>().getInt()));
op.getLoc(),
mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32), bstOutputSize));
mlir::SmallVector<mlir::Value> allocLweCtOperands{errOp, lweSizeOp};
auto allocateLweCtOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, "allocate_lwe_ciphertext_u64", lweOperandType, allocLweCtOperands);
// allocate bsk
// auto decompLevelCountOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("level").cast<mlir::IntegerAttr>().getInt()));
// auto decompBaseLogOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("baseLog").cast<mlir::IntegerAttr>().getInt()));
// auto glweSizeOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32), -1));
// auto polySizeOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("polynomialSize").cast<mlir::IntegerAttr>().getInt()));
// mlir::SmallVector<mlir::Value> allocBskOperands{
// errOp, decompLevelCountOp, decompBaseLogOp,
// glweSizeOp, lweSizeOp, polySizeOp};
// auto allocateBskOp = rewriter.create<mlir::CallOp>(
// op.getLoc(), "allocate_lwe_bootstrap_key_u64",
// mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
// rewriter.getContext()),
// allocBskOperands);
auto allocateGenericLweCtOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_lwe_ciphertext_u64",
getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands);
// get bsk
mlir::SmallVector<mlir::Value> getBskOperands{};
auto getBskOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "getGlobalBootstrapKey",
mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
rewriter.getContext()),
getBskOperands);
getGenericLweBootstrapKeyType(rewriter.getContext()), getBskOperands);
// bootstrap
// cast input ciphertext to a generic type
mlir::Value lweToBootstrap =
rewriter
.create<mlir::UnrealizedConversionCastOp>(
op.getLoc(), getGenericType(op.getOperand(0).getType()),
op.getOperand(0))
.getResult(0);
// cast input accumulator to a generic type
mlir::Value accumulator =
rewriter
.create<mlir::UnrealizedConversionCastOp>(
op.getLoc(), getGenericType(op.getOperand(1).getType()),
op.getOperand(1))
.getResult(0);
mlir::SmallVector<mlir::Value> bootstrapOperands{
errOp, getBskOp.getResult(0), allocateLweCtOp.getResult(0),
op->getOperand(0), op->getOperand(1)};
errOp, getBskOp.getResult(0), allocateGenericLweCtOp.getResult(0),
lweToBootstrap, accumulator};
rewriter.create<mlir::CallOp>(op.getLoc(), "bootstrap_lwe_u64",
mlir::TypeRange({}), bootstrapOperands);
// Cast result to the appropriate type
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
op, resultType, allocateGenericLweCtOp.getResult(0));
return mlir::success();
};
};
// Rewrite a KeySwitchLweOp with a series of ops:
// - allocate the result LWE ciphertext
// - get the global keyswitch key
// - use the key to keyswitch the input ciphertext
struct LowLFHEKeySwitchLweOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::KeySwitchLweOp> {
LowLFHEKeySwitchLweOpPattern(mlir::MLIRContext *context,
@@ -531,139 +646,41 @@ struct LowLFHEKeySwitchLweOpPattern
mlir::LogicalResult
matchAndRewrite(mlir::zamalang::LowLFHE::KeySwitchLweOp op,
mlir::PatternRewriter &rewriter) const override {
auto errType = mlir::IndexType::get(rewriter.getContext());
auto lweOperandType = op->getOperandTypes().front();
// Insert forward declaration of the allocate_ksk_key function
// {
// auto funcType = mlir::FunctionType::get(
// rewriter.getContext(),
// {
// errType,
// // level
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // baselog
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // input lwe size
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // output lwe size
// mlir::IntegerType::get(rewriter.getContext(), 32),
// },
// {mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
// rewriter.getContext())});
// if (insertForwardDeclaration(op, rewriter,
// "allocate_lwe_keyswitch_key_u64",
// funcType)
// .failed()) {
// return mlir::failure();
// }
// }
// Insert forward declaration of the getKsk function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {},
{mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
rewriter.getContext())});
if (insertForwardDeclaration(op, rewriter, "getGlobalKeyswitchKey",
funcType)
.failed()) {
return mlir::failure();
}
}
// Insert forward declaration of the allocate_lwe_ct function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{
errType,
mlir::IntegerType::get(rewriter.getContext(), 32),
},
{lweOperandType});
if (insertForwardDeclaration(op, rewriter, "allocate_lwe_ciphertext_u64",
funcType)
.failed()) {
return mlir::failure();
}
}
// TODO: build the right type here
auto lweOutputType = lweOperandType;
// Insert forward declaration of the keyswitch function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{
errType,
// ksk
mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
rewriter.getContext()),
// output ct
lweOutputType,
// input ct
lweOperandType,
},
{});
if (insertForwardDeclaration(op, rewriter, "keyswitch_lwe_u64", funcType)
.failed()) {
return mlir::failure();
}
}
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
// allocate the result lwe ciphertext
// allocate the result lwe ciphertext, should be of a generic type, to cast
// before return
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(),
mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32),
op->getAttr("outputLweSize").cast<mlir::IntegerAttr>().getInt()));
mlir::SmallVector<mlir::Value> allocLweCtOperands{errOp, lweSizeOp};
auto allocateLweCtOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, "allocate_lwe_ciphertext_u64", lweOutputType, allocLweCtOperands);
// allocate ksk
// auto decompLevelCountOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("level").cast<mlir::IntegerAttr>().getInt()));
// auto decompBaseLogOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("baseLog").cast<mlir::IntegerAttr>().getInt()));
// auto inputLweSizeOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("inputLweSize").cast<mlir::IntegerAttr>().getInt()));
// auto outputLweSizeOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("outputLweSize").cast<mlir::IntegerAttr>().getInt()));
// mlir::SmallVector<mlir::Value> allockskOperands{
// errOp, decompLevelCountOp, decompBaseLogOp, inputLweSizeOp,
// outputLweSizeOp};
// auto allocateKskOp = rewriter.create<mlir::CallOp>(
// op.getLoc(), "allocate_lwe_keyswitch_key_u64",
// mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
// rewriter.getContext()),
// allockskOperands);
auto allocateGenericLweCtOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_lwe_ciphertext_u64",
getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands);
// get ksk
mlir::SmallVector<mlir::Value> getkskOperands{};
auto getKskOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "getGlobalKeyswitchKey",
mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
rewriter.getContext()),
getkskOperands);
getGenericLweKeySwitchKeyType(rewriter.getContext()), getkskOperands);
// keyswitch
// cast input ciphertext to a generic type
mlir::Value lweToKeyswitch =
rewriter
.create<mlir::UnrealizedConversionCastOp>(
op.getLoc(), getGenericType(op.getOperand().getType()),
op.getOperand())
.getResult(0);
mlir::SmallVector<mlir::Value> keyswitchOperands{
errOp, getKskOp.getResult(0), allocateLweCtOp.getResult(0),
op->getOperand(0)};
errOp, getKskOp.getResult(0), allocateGenericLweCtOp.getResult(0),
lweToKeyswitch};
rewriter.create<mlir::CallOp>(op.getLoc(), "keyswitch_lwe_u64",
mlir::TypeRange({}), keyswitchOperands);
// Cast result to the appropriate type
auto lweOutputType = op->getResultTypes().front();
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
op, lweOutputType, allocateGenericLweCtOp.getResult(0));
return mlir::success();
};
};
@@ -713,8 +730,14 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() {
mlir::RewritePatternSet patterns(&getContext());
populateLowLFHEToConcreteCAPICall(patterns);
// Apply the conversion
// Insert forward declarations
mlir::ModuleOp op = getOperation();
mlir::IRRewriter rewriter(&getContext());
if (insertForwardDeclarations(op, rewriter).failed()) {
this->signalPassFailure();
}
// Apply the conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();
}

View File

@@ -35,6 +35,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
// legalize LLVM dialect.
mlir::LLVMConversionTarget target(getContext());
target.addLegalOp<mlir::ModuleOp>();
target.addIllegalOp<mlir::UnrealizedConversionCastOp>();
// Setup the LLVMTypeConverter (that converts `std` types to `llvm` types) and
// add our types conversion to `llvm` compatible type.

View File

@@ -107,7 +107,7 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op,
"`ct` argument.";
return mlir::failure();
}
// Disable this check for the moment
// Disable this check for the moment: issue/111
// Check the witdh of the encrypted integer and the integer of the tabulated
// lambda are equals
// if (ct.getWidth() != l_cst.getElementType().cast<IntegerType>().getWidth())

View File

@@ -123,7 +123,7 @@ mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTable &op) {
"`ct` argument.";
return mlir::failure();
}
// Disable this check for the moment
// Disable this check for the moment: issue/111
// Check the witdh of the encrypted integer and the integer of the tabulated
// lambda are equals
// if (result.getP() < l_cst.getElementType().cast<IntegerType>().getWidth())

View File

@@ -236,7 +236,7 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
}
// Setup runtime context with appropriate keys
keySet.generateRuntimeContext();
keySet.initGlobalRuntimeContext();
}
JITLambda::Argument::~Argument() {

View File

@@ -1,17 +1,30 @@
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext)
// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list)
// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list
// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext
// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key
// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext)
// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key
// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>)
// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>)
// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<_,_>
// CHECK-LABEL: func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> {
// CHECK-NEXT: %[[ERR:.*]] = constant 0 : index
// CHECK-NEXT: %[[C0:.*]] = constant 1 : i32
// CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, i32) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: %[[C0:.*]] = constant 1024 : i32
// CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, i32) -> !LowLFHE.lwe_ciphertext<_,_>
// CHECK-NEXT: %[[V2:.*]] = call @getGlobalBootstrapKey() : () -> !LowLFHE.lwe_bootstrap_key
// CHECK-NEXT: call @bootstrap_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %arg0, %arg1) : (index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> ()
// CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: %[[V3:.*]] = unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_>
// CHECK-NEXT: %[[V4:.*]] = unrealized_conversion_cast %arg1 : !LowLFHE.glwe_ciphertext to !LowLFHE.glwe_ciphertext
// CHECK-NEXT: call @bootstrap_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %[[V3]], %[[V4]]) : (index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext) -> ()
// CHECK-NEXT: %[[RES:.*]] = unrealized_conversion_cast %[[V1]] : !LowLFHE.lwe_ciphertext<_,_> to !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: return %[[RES]] : !LowLFHE.lwe_ciphertext<1024,4>
%1 = "LowLFHE.bootstrap_lwe"(%arg0, %arg1) {baseLog = 2 : i32, k = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
return %1: !LowLFHE.lwe_ciphertext<1024,4>
}

View File

@@ -1,13 +1,22 @@
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list)
// CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi4>, i64, i32) -> !LowLFHE.foreign_plaintext_list
// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list
// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext
// CHECK-LABEL: func @glwe_from_table(%arg0: tensor<16xi4>) -> !LowLFHE.glwe_ciphertext
func @glwe_from_table(%arg0: tensor<16xi4>) -> !LowLFHE.glwe_ciphertext {
// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key
// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext)
// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key
// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>)
// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>)
// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<_,_>
// CHECK-LABEL: func @glwe_from_table(%arg0: tensor<16xi64>) -> !LowLFHE.glwe_ciphertext
func @glwe_from_table(%arg0: tensor<16xi64>) -> !LowLFHE.glwe_ciphertext {
// CHECK-NEXT: %[[V0:.*]] = constant 0 : index
// CHECK-NEXT: %[[C0:.*]] = constant 1 : i32
// CHECK-NEXT: %[[C1:.*]] = constant 1024 : i32
@@ -16,10 +25,10 @@ func @glwe_from_table(%arg0: tensor<16xi4>) -> !LowLFHE.glwe_ciphertext {
// CHECK-NEXT: %[[V3:.*]] = call @allocate_plaintext_list_u64(%[[V0]], %[[C1]]) : (index, i32) -> !LowLFHE.plaintext_list
// CHECK-NEXT: %[[C2:.*]] = constant 16 : i64
// CHECK-NEXT: %[[C3:.*]] = constant 4 : i32
// CHECK-NEXT: %[[V4:.*]] = call @runtime_foreign_plaintext_list_u64(%[[V0]], %arg0, %[[C2]], %[[C3]]) : (index, tensor<16xi4>, i64, i32) -> !LowLFHE.foreign_plaintext_list
// CHECK-NEXT: %[[V4:.*]] = call @runtime_foreign_plaintext_list_u64(%[[V0]], %arg0, %[[C2]], %[[C3]]) : (index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list
// CHECK-NEXT: call @fill_plaintext_list_with_expansion_u64(%[[V0]], %[[V3]], %[[V4]]) : (index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list) -> ()
// CHECK-NEXT: call @add_plaintext_list_glwe_ciphertext_u64(%[[V0]], %[[V1]], %[[V2]], %[[V3]]) : (index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) -> ()
// CHECK-NEXT: return %[[V1]] : !LowLFHE.glwe_ciphertext
%1 = "LowLFHE.glwe_from_table"(%arg0) {k = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext
%1 = "LowLFHE.glwe_from_table"(%arg0) {k = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !LowLFHE.glwe_ciphertext
return %1: !LowLFHE.glwe_ciphertext
}

View File

@@ -1,16 +1,29 @@
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>)
// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list)
// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list
// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext
// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key
// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext)
// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key
// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>)
// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>)
// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<_,_>
// CHECK-LABEL: func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> {
// CHECK-NEXT: %[[ERR:.*]] = constant 0 : index
// CHECK-NEXT: %[[C0:.*]] = constant 1 : i32
// CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, i32) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, i32) -> !LowLFHE.lwe_ciphertext<_,_>
// CHECK-NEXT: %[[V2:.*]] = call @getGlobalKeyswitchKey() : () -> !LowLFHE.lwe_key_switch_key
// CHECK-NEXT: call @keyswitch_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %arg0) : (index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>) -> ()
// CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: %[[V3:.*]] = unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_>
// CHECK-NEXT: call @keyswitch_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %[[V3]]) : (index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) -> ()
// CHECK-NEXT: %[[RES:.*]] = unrealized_conversion_cast %[[V1]] : !LowLFHE.lwe_ciphertext<_,_> to !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: return %[[RES]] : !LowLFHE.lwe_ciphertext<1024,4>
%1 = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
return %1: !LowLFHE.lwe_ciphertext<1024,4>
}

View File

@@ -3,8 +3,8 @@
// CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi4>) -> !LowLFHE.lwe_ciphertext<1024,4>
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi4>) -> !MidLFHE.glwe<{1024,1,64}{4}> {
// CHECK-NEXT: %[[V1:.*]] = "LowLFHE.glwe_from_table"(%arg1) {k = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 600 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 600 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<600,4>
// CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<600,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext<1024,4>
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1){k=1:i32, polynomialSize=1024:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!MidLFHE.glwe<{1024,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{1024,1,64}{4}>)
return %1: !MidLFHE.glwe<{1024,1,64}{4}>

View File

@@ -4,8 +4,8 @@
func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> {
// CHECK-NEXT: %[[TABLE:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1]> : tensor<16xi4>
// CHECK-NEXT: %[[V1:.*]] = "LowLFHE.glwe_from_table"(%[[TABLE]]) {k = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 600 : i32} : (!LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4>
// CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!LowLFHE.lwe_ciphertext<2048,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<2048,4>
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 600 : i32} : (!LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<600,4>
// CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!LowLFHE.lwe_ciphertext<600,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<2048,4>
// CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext<2048,4>
%tlu = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi4>
%1 = "MidLFHE.apply_lookup_table"(%arg0, %tlu){k=1:i32, polynomialSize=2048:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!MidLFHE.glwe<{2048,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{2048,1,64}{4}>)

View File

@@ -29,7 +29,6 @@ def test_compile_and_run(mlir_input, args, expected_result):
(
"""
func @main(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
// 0..128 shifted << 55
%tlu = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>
%1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<7>, tensor<128xi64>) -> (!HLFHE.eint<7>)
return %1: !HLFHE.eint<7>