mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor: ins forward decl w generic ty @pass-init
Insert forward declarations with generic types at pass initialization. More docs for all the pass for lowering LUT
This commit is contained in:
committed by
Quentin Bourgerie
parent
d97512f507
commit
746d991af6
@@ -169,6 +169,47 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter rewriter,
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
// This is the rewritting of the HLFHE::ApplyLookupTable operation, it will be
|
||||
// rewritten as 3 new operations:
|
||||
// - Create the required GLWE ciphertext out of the plain lookup table
|
||||
// - Keyswitch the input ciphertext to match the input key of the bootstrapping
|
||||
// - Bootstrap the keyswitched ciphertext with the constructed GLWE ciphertext
|
||||
// Example:
|
||||
// from:
|
||||
// ```
|
||||
// "%result = MidLFHE.apply_lookup_table"(% arg0, % tlu){
|
||||
// k = 1 : i32,
|
||||
// polynomialSize = 2048 : i32,
|
||||
// levelKS = 3 : i32,
|
||||
// baseLogKS = 2 : i32,
|
||||
// levelBS = 5 : i32,
|
||||
// baseLogBS = 4 : i32,
|
||||
// outputSizeKS = 600 : i32
|
||||
// } : (!MidLFHE.glwe<{2048, 1, 64} {4}>, tensor<16xi4>)
|
||||
// ->(!MidLFHE.glwe<{2048, 1, 64} {4}>)
|
||||
// ```
|
||||
// to:
|
||||
// ```
|
||||
// % accumulator =
|
||||
// "LowLFHE.glwe_from_table"(
|
||||
// % [[TABLE]]){k = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32}
|
||||
// : (tensor<16xi4>)
|
||||
// ->!LowLFHE.glwe_ciphertext
|
||||
// % keyswitched = "LowLFHE.keyswitch_lwe"(% arg0){
|
||||
// baseLog = 2 : i32,
|
||||
// inputLweSize = 1 : i32,
|
||||
// level = 3 : i32,
|
||||
// outputLweSize = 600 : i32
|
||||
// } : (!LowLFHE.lwe_ciphertext<2048, 4>)
|
||||
// ->!LowLFHE.lwe_ciphertext<600, 4>
|
||||
// % result = "LowLFHE.bootstrap_lwe"(% keyswitched, % accumulator){
|
||||
// baseLog = 4 : i32,
|
||||
// k = 1 : i32,
|
||||
// level = 5 : i32,
|
||||
// polynomialSize = 2048 : i32
|
||||
// } : (!LowLFHE.lwe_ciphertext<600, 4>, !LowLFHE.glwe_ciphertext)
|
||||
// ->!LowLFHE.lwe_ciphertext<2048, 4>
|
||||
// ```
|
||||
mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
|
||||
mlir::Value ct, mlir::Value table, mlir::IntegerAttr k,
|
||||
mlir::IntegerAttr polynomialSize,
|
||||
@@ -178,10 +219,9 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
|
||||
// convert result type
|
||||
GLWECipherTextType glwe_type = result.getType().cast<GLWECipherTextType>();
|
||||
LweCiphertextType lwe_type =
|
||||
convertTypeGLWEToLWE(rewriter.getContext(), glwe_type);
|
||||
convertTypeToLWE(rewriter.getContext(), glwe_type);
|
||||
// fill the the table in the GLWE accumulator
|
||||
mlir::IntegerAttr precision = mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32), glwe_type.getP());
|
||||
mlir::IntegerAttr precision = rewriter.getI32IntegerAttr(glwe_type.getP());
|
||||
mlir::Value accumulator =
|
||||
rewriter
|
||||
.create<mlir::zamalang::LowLFHE::GlweFromTable>(
|
||||
@@ -191,8 +231,8 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
|
||||
|
||||
// keyswitch
|
||||
auto ct_type = ct.getType().cast<GLWECipherTextType>();
|
||||
mlir::SmallVector<mlir::Value, 1> ksArgs{ct};
|
||||
mlir::SmallVector<mlir::NamedAttribute, 6> ksAttrs{
|
||||
mlir::SmallVector<mlir::Value> ksArgs{ct};
|
||||
mlir::SmallVector<mlir::NamedAttribute> ksAttrs{
|
||||
mlir::NamedAttribute(
|
||||
mlir::Identifier::get("inputLweSize", rewriter.getContext()), k),
|
||||
mlir::NamedAttribute(
|
||||
@@ -203,16 +243,17 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
|
||||
mlir::NamedAttribute(
|
||||
mlir::Identifier::get("baseLog", rewriter.getContext()), baseLogKS),
|
||||
};
|
||||
auto ksOutType = LweCiphertextType::get(
|
||||
rewriter.getContext(), outputSizeKS.getInt(), ct_type.getP());
|
||||
mlir::Value keyswitched =
|
||||
rewriter
|
||||
.create<mlir::zamalang::LowLFHE::KeySwitchLweOp>(
|
||||
loc, convertTypeGLWEToLWE(rewriter.getContext(), ct_type), ksArgs,
|
||||
ksAttrs)
|
||||
.create<mlir::zamalang::LowLFHE::KeySwitchLweOp>(loc, ksOutType,
|
||||
ksArgs, ksAttrs)
|
||||
.result();
|
||||
|
||||
// bootstrap operation
|
||||
mlir::SmallVector<mlir::Value, 2> bsArgs{keyswitched, accumulator};
|
||||
mlir::SmallVector<mlir::NamedAttribute, 6> bsAttrs{
|
||||
mlir::SmallVector<mlir::Value> bsArgs{keyswitched, accumulator};
|
||||
mlir::SmallVector<mlir::NamedAttribute> bsAttrs{
|
||||
mlir::NamedAttribute(mlir::Identifier::get("k", rewriter.getContext()),
|
||||
k),
|
||||
mlir::NamedAttribute(
|
||||
|
||||
@@ -41,7 +41,7 @@ public:
|
||||
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
|
||||
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
|
||||
|
||||
void generateRuntimeContext() {
|
||||
void initGlobalRuntimeContext() {
|
||||
auto ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]);
|
||||
auto bsk = std::get<1>(this->bootstrapKeys["bsk_v0"]);
|
||||
setGlobalRuntimeContext(createRuntimeContext(ksk, bsk));
|
||||
|
||||
@@ -26,7 +26,7 @@ public:
|
||||
};
|
||||
|
||||
mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter,
|
||||
mlir::RewriterBase &rewriter,
|
||||
llvm::StringRef funcName,
|
||||
mlir::FunctionType funcType) {
|
||||
// Looking for the `funcName` Operation
|
||||
@@ -54,6 +54,271 @@ mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Set of functions to generate generic types.
|
||||
// Generic types are used to add forward declarations without a specific type.
|
||||
// For example, we may need to add LWE ciphertext of different dimensions, or
|
||||
// allocate them. All the calls to the C API should be done using this generic
|
||||
// types, and casting should then be performed back to the appropriate type.
|
||||
|
||||
inline mlir::zamalang::LowLFHE::LweCiphertextType
|
||||
getGenericLweCiphertextType(mlir::MLIRContext *context) {
|
||||
return mlir::zamalang::LowLFHE::LweCiphertextType::get(context, -1, -1);
|
||||
}
|
||||
|
||||
inline mlir::zamalang::LowLFHE::GlweCiphertextType
|
||||
getGenericGlweCiphertextType(mlir::MLIRContext *context) {
|
||||
return mlir::zamalang::LowLFHE::GlweCiphertextType::get(context);
|
||||
}
|
||||
|
||||
inline mlir::zamalang::LowLFHE::PlaintextType
|
||||
getGenericPlaintextType(mlir::MLIRContext *context) {
|
||||
return mlir::zamalang::LowLFHE::PlaintextType::get(context, -1);
|
||||
}
|
||||
|
||||
inline mlir::zamalang::LowLFHE::PlaintextListType
|
||||
getGenericPlaintextListType(mlir::MLIRContext *context) {
|
||||
return mlir::zamalang::LowLFHE::PlaintextListType::get(context);
|
||||
}
|
||||
|
||||
inline mlir::zamalang::LowLFHE::ForeignPlaintextListType
|
||||
getGenericForeignPlaintextListType(mlir::MLIRContext *context) {
|
||||
return mlir::zamalang::LowLFHE::ForeignPlaintextListType::get(context);
|
||||
}
|
||||
|
||||
inline mlir::zamalang::LowLFHE::CleartextType
|
||||
getGenericCleartextType(mlir::MLIRContext *context) {
|
||||
return mlir::zamalang::LowLFHE::CleartextType::get(context, -1);
|
||||
}
|
||||
|
||||
inline mlir::zamalang::LowLFHE::LweBootstrapKeyType
|
||||
getGenericLweBootstrapKeyType(mlir::MLIRContext *context) {
|
||||
return mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(context);
|
||||
}
|
||||
|
||||
inline mlir::zamalang::LowLFHE::LweKeySwitchKeyType
|
||||
getGenericLweKeySwitchKeyType(mlir::MLIRContext *context) {
|
||||
return mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(context);
|
||||
}
|
||||
|
||||
// Get the generic version of the type.
|
||||
// Useful when iterating over a set of types.
|
||||
mlir::Type getGenericType(mlir::Type baseType) {
|
||||
if (baseType.isa<mlir::zamalang::LowLFHE::LweCiphertextType>()) {
|
||||
return getGenericLweCiphertextType(baseType.getContext());
|
||||
}
|
||||
if (baseType.isa<mlir::zamalang::LowLFHE::PlaintextType>()) {
|
||||
return getGenericPlaintextType(baseType.getContext());
|
||||
}
|
||||
if (baseType.isa<mlir::zamalang::LowLFHE::CleartextType>()) {
|
||||
return getGenericCleartextType(baseType.getContext());
|
||||
}
|
||||
return baseType;
|
||||
}
|
||||
|
||||
// Insert all forward declarations needed for the pass.
|
||||
// Should generalize input and output types for all decalarations, and the
|
||||
// pattern using them would be resposible for casting them to the appropriate
|
||||
// type.
|
||||
mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
mlir::IRRewriter &rewriter) {
|
||||
auto genericLweCiphertextType =
|
||||
getGenericLweCiphertextType(rewriter.getContext());
|
||||
auto genericGlweCiphertextType =
|
||||
getGenericGlweCiphertextType(rewriter.getContext());
|
||||
auto genericPlaintextType = getGenericPlaintextType(rewriter.getContext());
|
||||
auto genericPlaintextListType =
|
||||
getGenericPlaintextListType(rewriter.getContext());
|
||||
auto genericForeignPlaintextList =
|
||||
getGenericForeignPlaintextListType(rewriter.getContext());
|
||||
auto genericCleartextType = getGenericCleartextType(rewriter.getContext());
|
||||
auto genericBSKType = getGenericLweBootstrapKeyType(rewriter.getContext());
|
||||
auto genericKSKType = getGenericLweKeySwitchKeyType(rewriter.getContext());
|
||||
auto errType = mlir::IndexType::get(rewriter.getContext());
|
||||
|
||||
// Insert forward declaration of allocate lwe ciphertext
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{
|
||||
errType,
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
},
|
||||
|
||||
{genericLweCiphertextType});
|
||||
if (insertForwardDeclaration(op, rewriter, "allocate_lwe_ciphertext_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the add_lwe_ciphertexts function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
errType,
|
||||
genericLweCiphertextType,
|
||||
genericLweCiphertextType,
|
||||
genericLweCiphertextType,
|
||||
},
|
||||
{});
|
||||
if (insertForwardDeclaration(op, rewriter, "add_lwe_ciphertexts_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the add_plaintext_lwe_ciphertext_u64 function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
errType,
|
||||
genericLweCiphertextType,
|
||||
genericLweCiphertextType,
|
||||
genericPlaintextType,
|
||||
},
|
||||
{});
|
||||
if (insertForwardDeclaration(op, rewriter,
|
||||
"add_plaintext_lwe_ciphertext_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the mul_cleartext_lwe_ciphertext_u64 function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
errType,
|
||||
genericLweCiphertextType,
|
||||
genericLweCiphertextType,
|
||||
genericCleartextType,
|
||||
},
|
||||
{});
|
||||
if (insertForwardDeclaration(op, rewriter,
|
||||
"mul_cleartext_lwe_ciphertext_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the negate_lwe_ciphertext_u64 function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{errType, genericLweCiphertextType, genericLweCiphertextType}, {});
|
||||
if (insertForwardDeclaration(op, rewriter, "negate_lwe_ciphertext_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the getBsk function
|
||||
{
|
||||
auto funcType =
|
||||
mlir::FunctionType::get(rewriter.getContext(), {}, {genericBSKType});
|
||||
if (insertForwardDeclaration(op, rewriter, "getGlobalBootstrapKey",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the bootstrap function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
errType,
|
||||
genericBSKType,
|
||||
genericLweCiphertextType,
|
||||
genericLweCiphertextType,
|
||||
genericGlweCiphertextType,
|
||||
},
|
||||
{});
|
||||
if (insertForwardDeclaration(op, rewriter, "bootstrap_lwe_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the getKsk function
|
||||
{
|
||||
auto funcType =
|
||||
mlir::FunctionType::get(rewriter.getContext(), {}, {genericKSKType});
|
||||
if (insertForwardDeclaration(op, rewriter, "getGlobalKeyswitchKey",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the keyswitch function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
errType,
|
||||
// ksk
|
||||
genericKSKType,
|
||||
// output ct
|
||||
genericLweCiphertextType,
|
||||
// input ct
|
||||
genericLweCiphertextType,
|
||||
},
|
||||
{});
|
||||
if (insertForwardDeclaration(op, rewriter, "keyswitch_lwe_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the alloc_glwe function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{
|
||||
errType,
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
},
|
||||
{genericGlweCiphertextType});
|
||||
if (insertForwardDeclaration(op, rewriter, "allocate_glwe_ciphertext_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the alloc_plaintext_list function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{errType, mlir::IntegerType::get(rewriter.getContext(), 32)},
|
||||
{genericPlaintextListType});
|
||||
if (insertForwardDeclaration(op, rewriter, "allocate_plaintext_list_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the fill_plaintext_list function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{errType, genericPlaintextListType, genericForeignPlaintextList}, {});
|
||||
if (insertForwardDeclaration(
|
||||
op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the add_plaintext_list_glwe function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{errType, genericGlweCiphertextType,
|
||||
genericGlweCiphertextType,
|
||||
genericPlaintextListType},
|
||||
{});
|
||||
if (insertForwardDeclaration(
|
||||
op, rewriter, "add_plaintext_list_glwe_ciphertext_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
/// LowLFHEOpToConcreteCAPICallPattern<Op> match the `Op` Operation and
|
||||
/// replace with a call to `funcName`, the funcName should be an external
|
||||
/// function that was linked later. It insert the forward declaration of the
|
||||
@@ -81,29 +346,7 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
|
||||
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
|
||||
LowLFHEToConcreteCAPITypeConverter typeConverter;
|
||||
auto errType = mlir::IndexType::get(rewriter.getContext());
|
||||
// Insert forward declaration of the operator function
|
||||
{
|
||||
mlir::SmallVector<mlir::Type, 4> operands{errType,
|
||||
op->getResultTypes().front()};
|
||||
for (auto ty : op->getOperandTypes()) {
|
||||
operands.push_back(typeConverter.convertType(ty));
|
||||
}
|
||||
auto funcType =
|
||||
mlir::FunctionType::get(rewriter.getContext(), operands, {});
|
||||
if (insertForwardDeclaration(op, rewriter, funcName, funcType).failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the alloc function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {errType, rewriter.getIndexType()},
|
||||
{op->getResultTypes().front()});
|
||||
if (insertForwardDeclaration(op, rewriter, allocName, funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
mlir::Type resultType = op->getResultTypes().front();
|
||||
auto lweResultType =
|
||||
resultType.cast<mlir::zamalang::LowLFHE::LweCiphertextType>();
|
||||
@@ -114,18 +357,39 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
|
||||
rewriter.getIndexAttr(0));
|
||||
// Add the call to the allocation
|
||||
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(lweResultType.getSize()));
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(lweResultType.getSize()));
|
||||
mlir::SmallVector<mlir::Value> allocOperands{errOp, lweSizeOp};
|
||||
auto alloc = rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
op, allocName, op.getType(), allocOperands);
|
||||
|
||||
// Add err and allocated value to operands
|
||||
mlir::SmallVector<mlir::Value, 4> newOperands{errOp, alloc.getResult(0)};
|
||||
for (auto operand : op->getOperands()) {
|
||||
newOperands.push_back(operand);
|
||||
auto allocGeneric = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), allocName,
|
||||
getGenericLweCiphertextType(rewriter.getContext()), allocOperands);
|
||||
// Construct operands for the operation.
|
||||
// errOp doesn't need to be casted to something generic, allocGeneric
|
||||
// already is. All the rest will be converted if needed
|
||||
mlir::SmallVector<mlir::Value, 4> newOperands{errOp,
|
||||
allocGeneric.getResult(0)};
|
||||
for (mlir::Value operand : op->getOperands()) {
|
||||
mlir::Type operandType = operand.getType();
|
||||
mlir::Type castedType = getGenericType(operandType);
|
||||
if (castedType == operandType) {
|
||||
// Type didn't change, no need for cast
|
||||
newOperands.push_back(operand);
|
||||
} else {
|
||||
// Type changed, need to cast to the generic one
|
||||
auto castedOperand = rewriter
|
||||
.create<mlir::UnrealizedConversionCastOp>(
|
||||
op.getLoc(), castedType, operand)
|
||||
.getResult(0);
|
||||
newOperands.push_back(castedOperand);
|
||||
}
|
||||
}
|
||||
// The operations called here are known to be inplace, and no need for a
|
||||
// return type.
|
||||
rewriter.create<mlir::CallOp>(op.getLoc(), funcName, mlir::TypeRange{},
|
||||
newOperands);
|
||||
// cast result value to the appropriate type
|
||||
auto alloc =
|
||||
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
|
||||
op, op.getType(), allocGeneric.getResult(0));
|
||||
}
|
||||
return mlir::success();
|
||||
};
|
||||
@@ -145,32 +409,24 @@ struct LowLFHEZeroOpPattern
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::zamalang::LowLFHE::ZeroLWEOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto allocName = "allocate_lwe_ciphertext_u64";
|
||||
auto errType = mlir::IndexType::get(rewriter.getContext());
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {errType, rewriter.getIndexType()},
|
||||
{op->getResultTypes().front()});
|
||||
if (insertForwardDeclaration(op, rewriter, allocName, funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Replace the operation with a call to the `funcName`
|
||||
{
|
||||
mlir::Type resultType = op->getResultTypes().front();
|
||||
auto lweResultType =
|
||||
resultType.cast<mlir::zamalang::LowLFHE::LweCiphertextType>();
|
||||
// Create the err value
|
||||
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
|
||||
rewriter.getIndexAttr(0));
|
||||
// Add the call to the allocation
|
||||
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(lweResultType.getSize()));
|
||||
mlir::SmallVector<mlir::Value> allocOperands{errOp, lweSizeOp};
|
||||
auto alloc = rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
op, allocName, op.getType(), allocOperands);
|
||||
}
|
||||
|
||||
mlir::Type resultType = op->getResultTypes().front();
|
||||
auto lweResultType =
|
||||
resultType.cast<mlir::zamalang::LowLFHE::LweCiphertextType>();
|
||||
// Create the err value
|
||||
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
|
||||
rewriter.getIndexAttr(0));
|
||||
// Allocate a fresh new ciphertext
|
||||
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(lweResultType.getSize()));
|
||||
mlir::SmallVector<mlir::Value> allocOperands{errOp, lweSizeOp};
|
||||
auto allocGeneric = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "allocate_lwe_ciphertext_u64",
|
||||
getGenericLweCiphertextType(rewriter.getContext()), allocOperands);
|
||||
// Cast the result to the appropriate type
|
||||
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
|
||||
op, op.getType(), allocGeneric.getResult(0));
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
@@ -215,6 +471,14 @@ struct LowLFHEIntToCleartextOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
// Rewrite the GlweFromTable operation to a series of ops:
|
||||
// - allocation of two GLWE, one for the addition, and one for storing the
|
||||
// result
|
||||
// - allocation of plaintext_list to build the GLWE accumulator
|
||||
// - build the foreign_plaintext_list using the input table
|
||||
// - fill the plaintext_list with the foreign_plaintext_list
|
||||
// - construct the GLWE accumulator by adding the plaintext_list to a freshly
|
||||
// allocated GLWE
|
||||
struct GlweFromTableOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::GlweFromTable> {
|
||||
GlweFromTableOpPattern(mlir::MLIRContext *context,
|
||||
@@ -227,49 +491,19 @@ struct GlweFromTableOpPattern
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
LowLFHEToConcreteCAPITypeConverter typeConverter;
|
||||
auto errType = mlir::IndexType::get(rewriter.getContext());
|
||||
// Insert forward declaration of the alloc_glwe function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{
|
||||
errType,
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
},
|
||||
{mlir::zamalang::LowLFHE::GlweCiphertextType::get(
|
||||
rewriter.getContext())});
|
||||
if (insertForwardDeclaration(op, rewriter, "allocate_glwe_ciphertext_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the alloc_plaintext_list function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{errType, mlir::IntegerType::get(rewriter.getContext(), 32)},
|
||||
{mlir::zamalang::LowLFHE::PlaintextListType::get(
|
||||
rewriter.getContext())});
|
||||
if (insertForwardDeclaration(op, rewriter, "allocate_plaintext_list_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: move this to insertForwardDeclarations
|
||||
// issue: can't define function with tensor<*xtype> that accept ranked
|
||||
// tensors
|
||||
|
||||
// Insert forward declaration of the foregin_pt_list function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{errType,
|
||||
// mlir::UnrankedTensorType::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 64)),
|
||||
op->getOperandTypes().front(),
|
||||
{errType, op->getOperandTypes().front(),
|
||||
mlir::IntegerType::get(rewriter.getContext(), 64),
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32)},
|
||||
{mlir::zamalang::LowLFHE::ForeignPlaintextListType::get(
|
||||
rewriter.getContext())});
|
||||
{getGenericForeignPlaintextListType(rewriter.getContext())});
|
||||
if (insertForwardDeclaration(
|
||||
op, rewriter, "runtime_foreign_plaintext_list_u64", funcType)
|
||||
.failed()) {
|
||||
@@ -277,41 +511,6 @@ struct GlweFromTableOpPattern
|
||||
}
|
||||
}
|
||||
|
||||
// Insert forward declaration of the fill_plaintext_list function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{errType,
|
||||
mlir::zamalang::LowLFHE::PlaintextListType::get(
|
||||
rewriter.getContext()),
|
||||
mlir::zamalang::LowLFHE::ForeignPlaintextListType::get(
|
||||
rewriter.getContext())},
|
||||
{});
|
||||
if (insertForwardDeclaration(
|
||||
op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
// Insert forward declaration of the add_plaintext_list_glwe function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{errType,
|
||||
mlir::zamalang::LowLFHE::GlweCiphertextType::get(
|
||||
rewriter.getContext()),
|
||||
mlir::zamalang::LowLFHE::GlweCiphertextType::get(
|
||||
rewriter.getContext()),
|
||||
mlir::zamalang::LowLFHE::PlaintextListType::get(
|
||||
rewriter.getContext())},
|
||||
{});
|
||||
if (insertForwardDeclaration(
|
||||
op, rewriter, "add_plaintext_list_glwe_ciphertext_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
|
||||
rewriter.getIndexAttr(0));
|
||||
// allocate two glwe to build accumulator
|
||||
@@ -324,39 +523,33 @@ struct GlweFromTableOpPattern
|
||||
// first accumulator would replace the op since it's the returned value
|
||||
auto accumulatorOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
op, "allocate_glwe_ciphertext_u64",
|
||||
mlir::zamalang::LowLFHE::GlweCiphertextType::get(rewriter.getContext()),
|
||||
allocGlweOperands);
|
||||
getGenericGlweCiphertextType(rewriter.getContext()), allocGlweOperands);
|
||||
// second accumulator is just needed to build the actual accumulator
|
||||
auto _accumulatorOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "allocate_glwe_ciphertext_u64",
|
||||
mlir::zamalang::LowLFHE::GlweCiphertextType::get(rewriter.getContext()),
|
||||
allocGlweOperands);
|
||||
getGenericGlweCiphertextType(rewriter.getContext()), allocGlweOperands);
|
||||
// allocate plaintext list
|
||||
mlir::SmallVector<mlir::Value> allocPlaintextListOperands{errOp,
|
||||
polySizeOp};
|
||||
auto plaintextListOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "allocate_plaintext_list_u64",
|
||||
mlir::zamalang::LowLFHE::PlaintextListType::get(rewriter.getContext()),
|
||||
getGenericPlaintextListType(rewriter.getContext()),
|
||||
allocPlaintextListOperands);
|
||||
// create foreign plaintext
|
||||
auto rankedTensorType =
|
||||
op->getOperandTypes().front().cast<mlir::RankedTensorType>();
|
||||
if (rankedTensorType.getRank() != 1) {
|
||||
llvm::errs() << "table lookup must be of a single dimension";
|
||||
return mlir::failure();
|
||||
}
|
||||
assert(rankedTensorType.getRank() == 1 &&
|
||||
"table lookup must be of a single dimension");
|
||||
auto sizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIntegerAttr(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 64),
|
||||
rankedTensorType.getDimSize(0)));
|
||||
op.getLoc(),
|
||||
rewriter.getI64IntegerAttr(rankedTensorType.getDimSize(0)));
|
||||
auto precisionOp =
|
||||
rewriter.create<mlir::ConstantOp>(op.getLoc(), op->getAttr("p"));
|
||||
mlir::SmallVector<mlir::Value> ForeignPlaintextListOperands{
|
||||
errOp, op->getOperand(0), sizeOp, precisionOp};
|
||||
auto foreignPlaintextListOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "runtime_foreign_plaintext_list_u64",
|
||||
mlir::zamalang::LowLFHE::ForeignPlaintextListType::get(
|
||||
rewriter.getContext()),
|
||||
getGenericForeignPlaintextListType(rewriter.getContext()),
|
||||
ForeignPlaintextListOperands);
|
||||
// fill plaintext list
|
||||
mlir::SmallVector<mlir::Value> FillPlaintextListOperands{
|
||||
@@ -376,6 +569,11 @@ struct GlweFromTableOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
// Rewrite a BootstrapLweOp with a series of ops:
|
||||
// - allocate the result LWE ciphertext
|
||||
// - get the global bootstrapping key
|
||||
// - use the key and the input accumulator (GLWE) to bootstrap the input
|
||||
// ciphertext
|
||||
struct LowLFHEBootstrapLweOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::BootstrapLweOp> {
|
||||
LowLFHEBootstrapLweOpPattern(mlir::MLIRContext *context,
|
||||
@@ -386,141 +584,58 @@ struct LowLFHEBootstrapLweOpPattern
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::zamalang::LowLFHE::BootstrapLweOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto errType = mlir::IndexType::get(rewriter.getContext());
|
||||
auto lweOperandType = op->getOperandTypes().front();
|
||||
// Insert forward declaration of the allocate_bsk_key function
|
||||
// {
|
||||
// auto funcType = mlir::FunctionType::get(
|
||||
// rewriter.getContext(),
|
||||
// {
|
||||
// errType,
|
||||
// // level
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// // baselog
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// // glwe size
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// // lwe size
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// // polynomial size
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// },
|
||||
// {mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
|
||||
// rewriter.getContext())});
|
||||
// if (insertForwardDeclaration(op, rewriter,
|
||||
// "allocate_lwe_bootstrap_key_u64",
|
||||
// funcType)
|
||||
// .failed()) {
|
||||
// return mlir::failure();
|
||||
// }
|
||||
// }
|
||||
|
||||
// Insert forward declaration of the getBsk function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {},
|
||||
{mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
|
||||
rewriter.getContext())});
|
||||
if (insertForwardDeclaration(op, rewriter, "getGlobalBootstrapKey",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the allocate_lwe_ct function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{
|
||||
errType,
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
},
|
||||
{lweOperandType});
|
||||
if (insertForwardDeclaration(op, rewriter, "allocate_lwe_ciphertext_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the bootstrap function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{
|
||||
errType,
|
||||
mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
|
||||
rewriter.getContext()),
|
||||
lweOperandType,
|
||||
lweOperandType,
|
||||
mlir::zamalang::LowLFHE::GlweCiphertextType::get(
|
||||
rewriter.getContext()),
|
||||
},
|
||||
{});
|
||||
if (insertForwardDeclaration(op, rewriter, "bootstrap_lwe_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
auto resultType = op->getResultTypes().front();
|
||||
auto bstOutputSize =
|
||||
resultType.cast<mlir::zamalang::LowLFHE::LweCiphertextType>().getSize();
|
||||
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
|
||||
rewriter.getIndexAttr(0));
|
||||
// allocate the result lwe ciphertext
|
||||
// TODO: use right value for output lwe size
|
||||
// LweSize output_lwe_size = { (glwe_size._0 -1) * poly_size._0 + 1}
|
||||
// allocate the result lwe ciphertext, should be of a generic type, to cast
|
||||
// before return
|
||||
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(), mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
op->getAttr("k").cast<mlir::IntegerAttr>().getInt()));
|
||||
op.getLoc(),
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32), bstOutputSize));
|
||||
mlir::SmallVector<mlir::Value> allocLweCtOperands{errOp, lweSizeOp};
|
||||
auto allocateLweCtOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
op, "allocate_lwe_ciphertext_u64", lweOperandType, allocLweCtOperands);
|
||||
// allocate bsk
|
||||
// auto decompLevelCountOp = rewriter.create<mlir::ConstantOp>(
|
||||
// op.getLoc(),
|
||||
// mlir::IntegerAttr::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// op->getAttr("level").cast<mlir::IntegerAttr>().getInt()));
|
||||
// auto decompBaseLogOp = rewriter.create<mlir::ConstantOp>(
|
||||
// op.getLoc(),
|
||||
// mlir::IntegerAttr::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// op->getAttr("baseLog").cast<mlir::IntegerAttr>().getInt()));
|
||||
// auto glweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
// op.getLoc(),
|
||||
// mlir::IntegerAttr::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32), -1));
|
||||
// auto polySizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
// op.getLoc(),
|
||||
// mlir::IntegerAttr::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// op->getAttr("polynomialSize").cast<mlir::IntegerAttr>().getInt()));
|
||||
// mlir::SmallVector<mlir::Value> allocBskOperands{
|
||||
// errOp, decompLevelCountOp, decompBaseLogOp,
|
||||
// glweSizeOp, lweSizeOp, polySizeOp};
|
||||
// auto allocateBskOp = rewriter.create<mlir::CallOp>(
|
||||
// op.getLoc(), "allocate_lwe_bootstrap_key_u64",
|
||||
// mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
|
||||
// rewriter.getContext()),
|
||||
// allocBskOperands);
|
||||
|
||||
auto allocateGenericLweCtOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "allocate_lwe_ciphertext_u64",
|
||||
getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands);
|
||||
// get bsk
|
||||
mlir::SmallVector<mlir::Value> getBskOperands{};
|
||||
auto getBskOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "getGlobalBootstrapKey",
|
||||
mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
|
||||
rewriter.getContext()),
|
||||
getBskOperands);
|
||||
getGenericLweBootstrapKeyType(rewriter.getContext()), getBskOperands);
|
||||
// bootstrap
|
||||
// cast input ciphertext to a generic type
|
||||
mlir::Value lweToBootstrap =
|
||||
rewriter
|
||||
.create<mlir::UnrealizedConversionCastOp>(
|
||||
op.getLoc(), getGenericType(op.getOperand(0).getType()),
|
||||
op.getOperand(0))
|
||||
.getResult(0);
|
||||
// cast input accumulator to a generic type
|
||||
mlir::Value accumulator =
|
||||
rewriter
|
||||
.create<mlir::UnrealizedConversionCastOp>(
|
||||
op.getLoc(), getGenericType(op.getOperand(1).getType()),
|
||||
op.getOperand(1))
|
||||
.getResult(0);
|
||||
mlir::SmallVector<mlir::Value> bootstrapOperands{
|
||||
errOp, getBskOp.getResult(0), allocateLweCtOp.getResult(0),
|
||||
op->getOperand(0), op->getOperand(1)};
|
||||
errOp, getBskOp.getResult(0), allocateGenericLweCtOp.getResult(0),
|
||||
lweToBootstrap, accumulator};
|
||||
rewriter.create<mlir::CallOp>(op.getLoc(), "bootstrap_lwe_u64",
|
||||
mlir::TypeRange({}), bootstrapOperands);
|
||||
// Cast result to the appropriate type
|
||||
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
|
||||
op, resultType, allocateGenericLweCtOp.getResult(0));
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
// Rewrite a KeySwitchLweOp with a series of ops:
|
||||
// - allocate the result LWE ciphertext
|
||||
// - get the global keyswitch key
|
||||
// - use the key to keyswitch the input ciphertext
|
||||
struct LowLFHEKeySwitchLweOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::KeySwitchLweOp> {
|
||||
LowLFHEKeySwitchLweOpPattern(mlir::MLIRContext *context,
|
||||
@@ -531,139 +646,41 @@ struct LowLFHEKeySwitchLweOpPattern
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::zamalang::LowLFHE::KeySwitchLweOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto errType = mlir::IndexType::get(rewriter.getContext());
|
||||
auto lweOperandType = op->getOperandTypes().front();
|
||||
// Insert forward declaration of the allocate_ksk_key function
|
||||
// {
|
||||
// auto funcType = mlir::FunctionType::get(
|
||||
// rewriter.getContext(),
|
||||
// {
|
||||
// errType,
|
||||
// // level
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// // baselog
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// // input lwe size
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// // output lwe size
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// },
|
||||
// {mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
|
||||
// rewriter.getContext())});
|
||||
// if (insertForwardDeclaration(op, rewriter,
|
||||
// "allocate_lwe_keyswitch_key_u64",
|
||||
// funcType)
|
||||
// .failed()) {
|
||||
// return mlir::failure();
|
||||
// }
|
||||
// }
|
||||
|
||||
// Insert forward declaration of the getKsk function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {},
|
||||
{mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
|
||||
rewriter.getContext())});
|
||||
if (insertForwardDeclaration(op, rewriter, "getGlobalKeyswitchKey",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Insert forward declaration of the allocate_lwe_ct function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{
|
||||
errType,
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
},
|
||||
{lweOperandType});
|
||||
if (insertForwardDeclaration(op, rewriter, "allocate_lwe_ciphertext_u64",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// TODO: build the right type here
|
||||
auto lweOutputType = lweOperandType;
|
||||
// Insert forward declaration of the keyswitch function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{
|
||||
errType,
|
||||
// ksk
|
||||
mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
|
||||
rewriter.getContext()),
|
||||
// output ct
|
||||
lweOutputType,
|
||||
// input ct
|
||||
lweOperandType,
|
||||
},
|
||||
{});
|
||||
if (insertForwardDeclaration(op, rewriter, "keyswitch_lwe_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
|
||||
rewriter.getIndexAttr(0));
|
||||
// allocate the result lwe ciphertext
|
||||
// allocate the result lwe ciphertext, should be of a generic type, to cast
|
||||
// before return
|
||||
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(),
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
op->getAttr("outputLweSize").cast<mlir::IntegerAttr>().getInt()));
|
||||
mlir::SmallVector<mlir::Value> allocLweCtOperands{errOp, lweSizeOp};
|
||||
auto allocateLweCtOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
op, "allocate_lwe_ciphertext_u64", lweOutputType, allocLweCtOperands);
|
||||
// allocate ksk
|
||||
// auto decompLevelCountOp = rewriter.create<mlir::ConstantOp>(
|
||||
// op.getLoc(),
|
||||
// mlir::IntegerAttr::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// op->getAttr("level").cast<mlir::IntegerAttr>().getInt()));
|
||||
// auto decompBaseLogOp = rewriter.create<mlir::ConstantOp>(
|
||||
// op.getLoc(),
|
||||
// mlir::IntegerAttr::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// op->getAttr("baseLog").cast<mlir::IntegerAttr>().getInt()));
|
||||
// auto inputLweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
// op.getLoc(),
|
||||
// mlir::IntegerAttr::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// op->getAttr("inputLweSize").cast<mlir::IntegerAttr>().getInt()));
|
||||
// auto outputLweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
// op.getLoc(),
|
||||
// mlir::IntegerAttr::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// op->getAttr("outputLweSize").cast<mlir::IntegerAttr>().getInt()));
|
||||
// mlir::SmallVector<mlir::Value> allockskOperands{
|
||||
// errOp, decompLevelCountOp, decompBaseLogOp, inputLweSizeOp,
|
||||
// outputLweSizeOp};
|
||||
// auto allocateKskOp = rewriter.create<mlir::CallOp>(
|
||||
// op.getLoc(), "allocate_lwe_keyswitch_key_u64",
|
||||
// mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
|
||||
// rewriter.getContext()),
|
||||
// allockskOperands);
|
||||
|
||||
auto allocateGenericLweCtOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "allocate_lwe_ciphertext_u64",
|
||||
getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands);
|
||||
// get ksk
|
||||
mlir::SmallVector<mlir::Value> getkskOperands{};
|
||||
auto getKskOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "getGlobalKeyswitchKey",
|
||||
mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
|
||||
rewriter.getContext()),
|
||||
getkskOperands);
|
||||
|
||||
getGenericLweKeySwitchKeyType(rewriter.getContext()), getkskOperands);
|
||||
// keyswitch
|
||||
// cast input ciphertext to a generic type
|
||||
mlir::Value lweToKeyswitch =
|
||||
rewriter
|
||||
.create<mlir::UnrealizedConversionCastOp>(
|
||||
op.getLoc(), getGenericType(op.getOperand().getType()),
|
||||
op.getOperand())
|
||||
.getResult(0);
|
||||
mlir::SmallVector<mlir::Value> keyswitchOperands{
|
||||
errOp, getKskOp.getResult(0), allocateLweCtOp.getResult(0),
|
||||
op->getOperand(0)};
|
||||
errOp, getKskOp.getResult(0), allocateGenericLweCtOp.getResult(0),
|
||||
lweToKeyswitch};
|
||||
rewriter.create<mlir::CallOp>(op.getLoc(), "keyswitch_lwe_u64",
|
||||
mlir::TypeRange({}), keyswitchOperands);
|
||||
|
||||
// Cast result to the appropriate type
|
||||
auto lweOutputType = op->getResultTypes().front();
|
||||
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
|
||||
op, lweOutputType, allocateGenericLweCtOp.getResult(0));
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
@@ -713,8 +730,14 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() {
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
populateLowLFHEToConcreteCAPICall(patterns);
|
||||
|
||||
// Apply the conversion
|
||||
// Insert forward declarations
|
||||
mlir::ModuleOp op = getOperation();
|
||||
mlir::IRRewriter rewriter(&getContext());
|
||||
if (insertForwardDeclarations(op, rewriter).failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
|
||||
// Apply the conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
|
||||
// legalize LLVM dialect.
|
||||
mlir::LLVMConversionTarget target(getContext());
|
||||
target.addLegalOp<mlir::ModuleOp>();
|
||||
target.addIllegalOp<mlir::UnrealizedConversionCastOp>();
|
||||
|
||||
// Setup the LLVMTypeConverter (that converts `std` types to `llvm` types) and
|
||||
// add our types conversion to `llvm` compatible type.
|
||||
|
||||
@@ -107,7 +107,7 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op,
|
||||
"`ct` argument.";
|
||||
return mlir::failure();
|
||||
}
|
||||
// Disable this check for the moment
|
||||
// Disable this check for the moment: issue/111
|
||||
// Check the witdh of the encrypted integer and the integer of the tabulated
|
||||
// lambda are equals
|
||||
// if (ct.getWidth() != l_cst.getElementType().cast<IntegerType>().getWidth())
|
||||
|
||||
@@ -123,7 +123,7 @@ mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTable &op) {
|
||||
"`ct` argument.";
|
||||
return mlir::failure();
|
||||
}
|
||||
// Disable this check for the moment
|
||||
// Disable this check for the moment: issue/111
|
||||
// Check the witdh of the encrypted integer and the integer of the tabulated
|
||||
// lambda are equals
|
||||
// if (result.getP() < l_cst.getElementType().cast<IntegerType>().getWidth())
|
||||
|
||||
@@ -236,7 +236,7 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
|
||||
}
|
||||
|
||||
// Setup runtime context with appropriate keys
|
||||
keySet.generateRuntimeContext();
|
||||
keySet.initGlobalRuntimeContext();
|
||||
}
|
||||
|
||||
JITLambda::Argument::~Argument() {
|
||||
|
||||
@@ -1,17 +1,30 @@
|
||||
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: module
|
||||
// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext)
|
||||
// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
|
||||
// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list)
|
||||
// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list
|
||||
// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext
|
||||
// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key
|
||||
// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext)
|
||||
// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key
|
||||
// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>)
|
||||
// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>)
|
||||
// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<_,_>
|
||||
// CHECK-LABEL: func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> {
|
||||
// CHECK-NEXT: %[[ERR:.*]] = constant 0 : index
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 1 : i32
|
||||
// CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, i32) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 1024 : i32
|
||||
// CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, i32) -> !LowLFHE.lwe_ciphertext<_,_>
|
||||
// CHECK-NEXT: %[[V2:.*]] = call @getGlobalBootstrapKey() : () -> !LowLFHE.lwe_bootstrap_key
|
||||
// CHECK-NEXT: call @bootstrap_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %arg0, %arg1) : (index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> ()
|
||||
// CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: %[[V3:.*]] = unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_>
|
||||
// CHECK-NEXT: %[[V4:.*]] = unrealized_conversion_cast %arg1 : !LowLFHE.glwe_ciphertext to !LowLFHE.glwe_ciphertext
|
||||
// CHECK-NEXT: call @bootstrap_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %[[V3]], %[[V4]]) : (index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext) -> ()
|
||||
// CHECK-NEXT: %[[RES:.*]] = unrealized_conversion_cast %[[V1]] : !LowLFHE.lwe_ciphertext<_,_> to !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: return %[[RES]] : !LowLFHE.lwe_ciphertext<1024,4>
|
||||
%1 = "LowLFHE.bootstrap_lwe"(%arg0, %arg1) {baseLog = 2 : i32, k = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
return %1: !LowLFHE.lwe_ciphertext<1024,4>
|
||||
}
|
||||
@@ -1,13 +1,22 @@
|
||||
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: module
|
||||
// CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list
|
||||
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
|
||||
// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list)
|
||||
// CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi4>, i64, i32) -> !LowLFHE.foreign_plaintext_list
|
||||
// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list
|
||||
// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext
|
||||
// CHECK-LABEL: func @glwe_from_table(%arg0: tensor<16xi4>) -> !LowLFHE.glwe_ciphertext
|
||||
func @glwe_from_table(%arg0: tensor<16xi4>) -> !LowLFHE.glwe_ciphertext {
|
||||
// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key
|
||||
// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext)
|
||||
// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key
|
||||
// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>)
|
||||
// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>)
|
||||
// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<_,_>
|
||||
// CHECK-LABEL: func @glwe_from_table(%arg0: tensor<16xi64>) -> !LowLFHE.glwe_ciphertext
|
||||
func @glwe_from_table(%arg0: tensor<16xi64>) -> !LowLFHE.glwe_ciphertext {
|
||||
// CHECK-NEXT: %[[V0:.*]] = constant 0 : index
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 1 : i32
|
||||
// CHECK-NEXT: %[[C1:.*]] = constant 1024 : i32
|
||||
@@ -16,10 +25,10 @@ func @glwe_from_table(%arg0: tensor<16xi4>) -> !LowLFHE.glwe_ciphertext {
|
||||
// CHECK-NEXT: %[[V3:.*]] = call @allocate_plaintext_list_u64(%[[V0]], %[[C1]]) : (index, i32) -> !LowLFHE.plaintext_list
|
||||
// CHECK-NEXT: %[[C2:.*]] = constant 16 : i64
|
||||
// CHECK-NEXT: %[[C3:.*]] = constant 4 : i32
|
||||
// CHECK-NEXT: %[[V4:.*]] = call @runtime_foreign_plaintext_list_u64(%[[V0]], %arg0, %[[C2]], %[[C3]]) : (index, tensor<16xi4>, i64, i32) -> !LowLFHE.foreign_plaintext_list
|
||||
// CHECK-NEXT: %[[V4:.*]] = call @runtime_foreign_plaintext_list_u64(%[[V0]], %arg0, %[[C2]], %[[C3]]) : (index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list
|
||||
// CHECK-NEXT: call @fill_plaintext_list_with_expansion_u64(%[[V0]], %[[V3]], %[[V4]]) : (index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list) -> ()
|
||||
// CHECK-NEXT: call @add_plaintext_list_glwe_ciphertext_u64(%[[V0]], %[[V1]], %[[V2]], %[[V3]]) : (index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) -> ()
|
||||
// CHECK-NEXT: return %[[V1]] : !LowLFHE.glwe_ciphertext
|
||||
%1 = "LowLFHE.glwe_from_table"(%arg0) {k = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext
|
||||
%1 = "LowLFHE.glwe_from_table"(%arg0) {k = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !LowLFHE.glwe_ciphertext
|
||||
return %1: !LowLFHE.glwe_ciphertext
|
||||
}
|
||||
@@ -1,16 +1,29 @@
|
||||
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: module
|
||||
// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>)
|
||||
// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
|
||||
// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list)
|
||||
// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list
|
||||
// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext
|
||||
// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key
|
||||
// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext)
|
||||
// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key
|
||||
// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>)
|
||||
// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>)
|
||||
// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>)
|
||||
// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<_,_>
|
||||
// CHECK-LABEL: func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> {
|
||||
// CHECK-NEXT: %[[ERR:.*]] = constant 0 : index
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 1 : i32
|
||||
// CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, i32) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, i32) -> !LowLFHE.lwe_ciphertext<_,_>
|
||||
// CHECK-NEXT: %[[V2:.*]] = call @getGlobalKeyswitchKey() : () -> !LowLFHE.lwe_key_switch_key
|
||||
// CHECK-NEXT: call @keyswitch_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %arg0) : (index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>) -> ()
|
||||
// CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: %[[V3:.*]] = unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_>
|
||||
// CHECK-NEXT: call @keyswitch_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %[[V3]]) : (index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) -> ()
|
||||
// CHECK-NEXT: %[[RES:.*]] = unrealized_conversion_cast %[[V1]] : !LowLFHE.lwe_ciphertext<_,_> to !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: return %[[RES]] : !LowLFHE.lwe_ciphertext<1024,4>
|
||||
%1 = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
return %1: !LowLFHE.lwe_ciphertext<1024,4>
|
||||
}
|
||||
@@ -3,8 +3,8 @@
|
||||
// CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi4>) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi4>) -> !MidLFHE.glwe<{1024,1,64}{4}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "LowLFHE.glwe_from_table"(%arg1) {k = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext
|
||||
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 600 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 600 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<600,4>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<600,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext<1024,4>
|
||||
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1){k=1:i32, polynomialSize=1024:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!MidLFHE.glwe<{1024,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{1024,1,64}{4}>)
|
||||
return %1: !MidLFHE.glwe<{1024,1,64}{4}>
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> {
|
||||
// CHECK-NEXT: %[[TABLE:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1]> : tensor<16xi4>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "LowLFHE.glwe_from_table"(%[[TABLE]]) {k = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext
|
||||
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 600 : i32} : (!LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!LowLFHE.lwe_ciphertext<2048,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<2048,4>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 600 : i32} : (!LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<600,4>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!LowLFHE.lwe_ciphertext<600,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<2048,4>
|
||||
// CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext<2048,4>
|
||||
%tlu = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi4>
|
||||
%1 = "MidLFHE.apply_lookup_table"(%arg0, %tlu){k=1:i32, polynomialSize=2048:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!MidLFHE.glwe<{2048,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{2048,1,64}{4}>)
|
||||
|
||||
@@ -29,7 +29,6 @@ def test_compile_and_run(mlir_input, args, expected_result):
|
||||
(
|
||||
"""
|
||||
func @main(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
// 0..128 shifted << 55
|
||||
%tlu = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>
|
||||
%1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<7>, tensor<128xi64>) -> (!HLFHE.eint<7>)
|
||||
return %1: !HLFHE.eint<7>
|
||||
|
||||
Reference in New Issue
Block a user