diff --git a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h index 9466a7097..3a65456c7 100644 --- a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h +++ b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h @@ -169,6 +169,47 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter rewriter, return op.getODSResults(0).front(); } +// This is the rewritting of the HLFHE::ApplyLookupTable operation, it will be +// rewritten as 3 new operations: +// - Create the required GLWE ciphertext out of the plain lookup table +// - Keyswitch the input ciphertext to match the input key of the bootstrapping +// - Bootstrap the keyswitched ciphertext with the constructed GLWE ciphertext +// Example: +// from: +// ``` +// "%result = MidLFHE.apply_lookup_table"(% arg0, % tlu){ +// k = 1 : i32, +// polynomialSize = 2048 : i32, +// levelKS = 3 : i32, +// baseLogKS = 2 : i32, +// levelBS = 5 : i32, +// baseLogBS = 4 : i32, +// outputSizeKS = 600 : i32 +// } : (!MidLFHE.glwe<{2048, 1, 64} {4}>, tensor<16xi4>) +// ->(!MidLFHE.glwe<{2048, 1, 64} {4}>) +// ``` +// to: +// ``` +// % accumulator = +// "LowLFHE.glwe_from_table"( +// % [[TABLE]]){k = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} +// : (tensor<16xi4>) +// ->!LowLFHE.glwe_ciphertext +// % keyswitched = "LowLFHE.keyswitch_lwe"(% arg0){ +// baseLog = 2 : i32, +// inputLweSize = 1 : i32, +// level = 3 : i32, +// outputLweSize = 600 : i32 +// } : (!LowLFHE.lwe_ciphertext<2048, 4>) +// ->!LowLFHE.lwe_ciphertext<600, 4> +// % result = "LowLFHE.bootstrap_lwe"(% keyswitched, % accumulator){ +// baseLog = 4 : i32, +// k = 1 : i32, +// level = 5 : i32, +// polynomialSize = 2048 : i32 +// } : (!LowLFHE.lwe_ciphertext<600, 4>, !LowLFHE.glwe_ciphertext) +// ->!LowLFHE.lwe_ciphertext<2048, 4> +// ``` mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc, mlir::Value ct, mlir::Value table, mlir::IntegerAttr k, mlir::IntegerAttr polynomialSize, @@ -178,10 +219,9 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc, // convert result type GLWECipherTextType glwe_type = result.getType().cast(); LweCiphertextType lwe_type = - convertTypeGLWEToLWE(rewriter.getContext(), glwe_type); + convertTypeToLWE(rewriter.getContext(), glwe_type); // fill the the table in the GLWE accumulator - mlir::IntegerAttr precision = mlir::IntegerAttr::get( - mlir::IntegerType::get(rewriter.getContext(), 32), glwe_type.getP()); + mlir::IntegerAttr precision = rewriter.getI32IntegerAttr(glwe_type.getP()); mlir::Value accumulator = rewriter .create( @@ -191,8 +231,8 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc, // keyswitch auto ct_type = ct.getType().cast(); - mlir::SmallVector ksArgs{ct}; - mlir::SmallVector ksAttrs{ + mlir::SmallVector ksArgs{ct}; + mlir::SmallVector ksAttrs{ mlir::NamedAttribute( mlir::Identifier::get("inputLweSize", rewriter.getContext()), k), mlir::NamedAttribute( @@ -203,16 +243,17 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc, mlir::NamedAttribute( mlir::Identifier::get("baseLog", rewriter.getContext()), baseLogKS), }; + auto ksOutType = LweCiphertextType::get( + rewriter.getContext(), outputSizeKS.getInt(), ct_type.getP()); mlir::Value keyswitched = rewriter - .create( - loc, convertTypeGLWEToLWE(rewriter.getContext(), ct_type), ksArgs, - ksAttrs) + .create(loc, ksOutType, + ksArgs, ksAttrs) .result(); // bootstrap operation - mlir::SmallVector bsArgs{keyswitched, accumulator}; - mlir::SmallVector bsAttrs{ + mlir::SmallVector bsArgs{keyswitched, accumulator}; + mlir::SmallVector bsAttrs{ mlir::NamedAttribute(mlir::Identifier::get("k", rewriter.getContext()), k), mlir::NamedAttribute( diff --git a/compiler/include/zamalang/Support/KeySet.h b/compiler/include/zamalang/Support/KeySet.h index a31e4f017..0b4f50db3 100644 --- a/compiler/include/zamalang/Support/KeySet.h +++ b/compiler/include/zamalang/Support/KeySet.h @@ -41,7 +41,7 @@ public: CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); } CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); } - void generateRuntimeContext() { + void initGlobalRuntimeContext() { auto ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]); auto bsk = std::get<1>(this->bootstrapKeys["bsk_v0"]); setGlobalRuntimeContext(createRuntimeContext(ksk, bsk)); diff --git a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp index 0aa06c307..eb844537c 100644 --- a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp +++ b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp @@ -26,7 +26,7 @@ public: }; mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op, - mlir::PatternRewriter &rewriter, + mlir::RewriterBase &rewriter, llvm::StringRef funcName, mlir::FunctionType funcType) { // Looking for the `funcName` Operation @@ -54,6 +54,271 @@ mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op, return mlir::success(); } +// Set of functions to generate generic types. +// Generic types are used to add forward declarations without a specific type. +// For example, we may need to add LWE ciphertext of different dimensions, or +// allocate them. All the calls to the C API should be done using this generic +// types, and casting should then be performed back to the appropriate type. + +inline mlir::zamalang::LowLFHE::LweCiphertextType +getGenericLweCiphertextType(mlir::MLIRContext *context) { + return mlir::zamalang::LowLFHE::LweCiphertextType::get(context, -1, -1); +} + +inline mlir::zamalang::LowLFHE::GlweCiphertextType +getGenericGlweCiphertextType(mlir::MLIRContext *context) { + return mlir::zamalang::LowLFHE::GlweCiphertextType::get(context); +} + +inline mlir::zamalang::LowLFHE::PlaintextType +getGenericPlaintextType(mlir::MLIRContext *context) { + return mlir::zamalang::LowLFHE::PlaintextType::get(context, -1); +} + +inline mlir::zamalang::LowLFHE::PlaintextListType +getGenericPlaintextListType(mlir::MLIRContext *context) { + return mlir::zamalang::LowLFHE::PlaintextListType::get(context); +} + +inline mlir::zamalang::LowLFHE::ForeignPlaintextListType +getGenericForeignPlaintextListType(mlir::MLIRContext *context) { + return mlir::zamalang::LowLFHE::ForeignPlaintextListType::get(context); +} + +inline mlir::zamalang::LowLFHE::CleartextType +getGenericCleartextType(mlir::MLIRContext *context) { + return mlir::zamalang::LowLFHE::CleartextType::get(context, -1); +} + +inline mlir::zamalang::LowLFHE::LweBootstrapKeyType +getGenericLweBootstrapKeyType(mlir::MLIRContext *context) { + return mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(context); +} + +inline mlir::zamalang::LowLFHE::LweKeySwitchKeyType +getGenericLweKeySwitchKeyType(mlir::MLIRContext *context) { + return mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(context); +} + +// Get the generic version of the type. +// Useful when iterating over a set of types. +mlir::Type getGenericType(mlir::Type baseType) { + if (baseType.isa()) { + return getGenericLweCiphertextType(baseType.getContext()); + } + if (baseType.isa()) { + return getGenericPlaintextType(baseType.getContext()); + } + if (baseType.isa()) { + return getGenericCleartextType(baseType.getContext()); + } + return baseType; +} + +// Insert all forward declarations needed for the pass. +// Should generalize input and output types for all decalarations, and the +// pattern using them would be resposible for casting them to the appropriate +// type. +mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, + mlir::IRRewriter &rewriter) { + auto genericLweCiphertextType = + getGenericLweCiphertextType(rewriter.getContext()); + auto genericGlweCiphertextType = + getGenericGlweCiphertextType(rewriter.getContext()); + auto genericPlaintextType = getGenericPlaintextType(rewriter.getContext()); + auto genericPlaintextListType = + getGenericPlaintextListType(rewriter.getContext()); + auto genericForeignPlaintextList = + getGenericForeignPlaintextListType(rewriter.getContext()); + auto genericCleartextType = getGenericCleartextType(rewriter.getContext()); + auto genericBSKType = getGenericLweBootstrapKeyType(rewriter.getContext()); + auto genericKSKType = getGenericLweKeySwitchKeyType(rewriter.getContext()); + auto errType = mlir::IndexType::get(rewriter.getContext()); + + // Insert forward declaration of allocate lwe ciphertext + { + auto funcType = mlir::FunctionType::get( + rewriter.getContext(), + { + errType, + mlir::IntegerType::get(rewriter.getContext(), 32), + }, + + {genericLweCiphertextType}); + if (insertForwardDeclaration(op, rewriter, "allocate_lwe_ciphertext_u64", + funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the add_lwe_ciphertexts function + { + auto funcType = mlir::FunctionType::get(rewriter.getContext(), + { + errType, + genericLweCiphertextType, + genericLweCiphertextType, + genericLweCiphertextType, + }, + {}); + if (insertForwardDeclaration(op, rewriter, "add_lwe_ciphertexts_u64", + funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the add_plaintext_lwe_ciphertext_u64 function + { + auto funcType = mlir::FunctionType::get(rewriter.getContext(), + { + errType, + genericLweCiphertextType, + genericLweCiphertextType, + genericPlaintextType, + }, + {}); + if (insertForwardDeclaration(op, rewriter, + "add_plaintext_lwe_ciphertext_u64", funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the mul_cleartext_lwe_ciphertext_u64 function + { + auto funcType = mlir::FunctionType::get(rewriter.getContext(), + { + errType, + genericLweCiphertextType, + genericLweCiphertextType, + genericCleartextType, + }, + {}); + if (insertForwardDeclaration(op, rewriter, + "mul_cleartext_lwe_ciphertext_u64", funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the negate_lwe_ciphertext_u64 function + { + auto funcType = mlir::FunctionType::get( + rewriter.getContext(), + {errType, genericLweCiphertextType, genericLweCiphertextType}, {}); + if (insertForwardDeclaration(op, rewriter, "negate_lwe_ciphertext_u64", + funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the getBsk function + { + auto funcType = + mlir::FunctionType::get(rewriter.getContext(), {}, {genericBSKType}); + if (insertForwardDeclaration(op, rewriter, "getGlobalBootstrapKey", + funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the bootstrap function + { + auto funcType = mlir::FunctionType::get(rewriter.getContext(), + { + errType, + genericBSKType, + genericLweCiphertextType, + genericLweCiphertextType, + genericGlweCiphertextType, + }, + {}); + if (insertForwardDeclaration(op, rewriter, "bootstrap_lwe_u64", funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the getKsk function + { + auto funcType = + mlir::FunctionType::get(rewriter.getContext(), {}, {genericKSKType}); + if (insertForwardDeclaration(op, rewriter, "getGlobalKeyswitchKey", + funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the keyswitch function + { + auto funcType = mlir::FunctionType::get(rewriter.getContext(), + { + errType, + // ksk + genericKSKType, + // output ct + genericLweCiphertextType, + // input ct + genericLweCiphertextType, + }, + {}); + if (insertForwardDeclaration(op, rewriter, "keyswitch_lwe_u64", funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the alloc_glwe function + { + auto funcType = mlir::FunctionType::get( + rewriter.getContext(), + { + errType, + mlir::IntegerType::get(rewriter.getContext(), 32), + mlir::IntegerType::get(rewriter.getContext(), 32), + }, + {genericGlweCiphertextType}); + if (insertForwardDeclaration(op, rewriter, "allocate_glwe_ciphertext_u64", + funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the alloc_plaintext_list function + { + auto funcType = mlir::FunctionType::get( + rewriter.getContext(), + {errType, mlir::IntegerType::get(rewriter.getContext(), 32)}, + {genericPlaintextListType}); + if (insertForwardDeclaration(op, rewriter, "allocate_plaintext_list_u64", + funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the fill_plaintext_list function + { + auto funcType = mlir::FunctionType::get( + rewriter.getContext(), + {errType, genericPlaintextListType, genericForeignPlaintextList}, {}); + if (insertForwardDeclaration( + op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType) + .failed()) { + return mlir::failure(); + } + } + // Insert forward declaration of the add_plaintext_list_glwe function + { + auto funcType = mlir::FunctionType::get(rewriter.getContext(), + {errType, genericGlweCiphertextType, + genericGlweCiphertextType, + genericPlaintextListType}, + {}); + if (insertForwardDeclaration( + op, rewriter, "add_plaintext_list_glwe_ciphertext_u64", funcType) + .failed()) { + return mlir::failure(); + } + } + return mlir::success(); +} + /// LowLFHEOpToConcreteCAPICallPattern match the `Op` Operation and /// replace with a call to `funcName`, the funcName should be an external /// function that was linked later. It insert the forward declaration of the @@ -81,29 +346,7 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern { matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { LowLFHEToConcreteCAPITypeConverter typeConverter; auto errType = mlir::IndexType::get(rewriter.getContext()); - // Insert forward declaration of the operator function - { - mlir::SmallVector operands{errType, - op->getResultTypes().front()}; - for (auto ty : op->getOperandTypes()) { - operands.push_back(typeConverter.convertType(ty)); - } - auto funcType = - mlir::FunctionType::get(rewriter.getContext(), operands, {}); - if (insertForwardDeclaration(op, rewriter, funcName, funcType).failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the alloc function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), {errType, rewriter.getIndexType()}, - {op->getResultTypes().front()}); - if (insertForwardDeclaration(op, rewriter, allocName, funcType) - .failed()) { - return mlir::failure(); - } - } + mlir::Type resultType = op->getResultTypes().front(); auto lweResultType = resultType.cast(); @@ -114,18 +357,39 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern { rewriter.getIndexAttr(0)); // Add the call to the allocation auto lweSizeOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(lweResultType.getSize())); + op.getLoc(), rewriter.getI32IntegerAttr(lweResultType.getSize())); mlir::SmallVector allocOperands{errOp, lweSizeOp}; - auto alloc = rewriter.replaceOpWithNewOp( - op, allocName, op.getType(), allocOperands); - - // Add err and allocated value to operands - mlir::SmallVector newOperands{errOp, alloc.getResult(0)}; - for (auto operand : op->getOperands()) { - newOperands.push_back(operand); + auto allocGeneric = rewriter.create( + op.getLoc(), allocName, + getGenericLweCiphertextType(rewriter.getContext()), allocOperands); + // Construct operands for the operation. + // errOp doesn't need to be casted to something generic, allocGeneric + // already is. All the rest will be converted if needed + mlir::SmallVector newOperands{errOp, + allocGeneric.getResult(0)}; + for (mlir::Value operand : op->getOperands()) { + mlir::Type operandType = operand.getType(); + mlir::Type castedType = getGenericType(operandType); + if (castedType == operandType) { + // Type didn't change, no need for cast + newOperands.push_back(operand); + } else { + // Type changed, need to cast to the generic one + auto castedOperand = rewriter + .create( + op.getLoc(), castedType, operand) + .getResult(0); + newOperands.push_back(castedOperand); + } } + // The operations called here are known to be inplace, and no need for a + // return type. rewriter.create(op.getLoc(), funcName, mlir::TypeRange{}, newOperands); + // cast result value to the appropriate type + auto alloc = + rewriter.replaceOpWithNewOp( + op, op.getType(), allocGeneric.getResult(0)); } return mlir::success(); }; @@ -145,32 +409,24 @@ struct LowLFHEZeroOpPattern mlir::LogicalResult matchAndRewrite(mlir::zamalang::LowLFHE::ZeroLWEOp op, mlir::PatternRewriter &rewriter) const override { - auto allocName = "allocate_lwe_ciphertext_u64"; - auto errType = mlir::IndexType::get(rewriter.getContext()); - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), {errType, rewriter.getIndexType()}, - {op->getResultTypes().front()}); - if (insertForwardDeclaration(op, rewriter, allocName, funcType) - .failed()) { - return mlir::failure(); - } - } - // Replace the operation with a call to the `funcName` - { - mlir::Type resultType = op->getResultTypes().front(); - auto lweResultType = - resultType.cast(); - // Create the err value - auto errOp = rewriter.create(op.getLoc(), - rewriter.getIndexAttr(0)); - // Add the call to the allocation - auto lweSizeOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(lweResultType.getSize())); - mlir::SmallVector allocOperands{errOp, lweSizeOp}; - auto alloc = rewriter.replaceOpWithNewOp( - op, allocName, op.getType(), allocOperands); - } + + mlir::Type resultType = op->getResultTypes().front(); + auto lweResultType = + resultType.cast(); + // Create the err value + auto errOp = rewriter.create(op.getLoc(), + rewriter.getIndexAttr(0)); + // Allocate a fresh new ciphertext + auto lweSizeOp = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(lweResultType.getSize())); + mlir::SmallVector allocOperands{errOp, lweSizeOp}; + auto allocGeneric = rewriter.create( + op.getLoc(), "allocate_lwe_ciphertext_u64", + getGenericLweCiphertextType(rewriter.getContext()), allocOperands); + // Cast the result to the appropriate type + rewriter.replaceOpWithNewOp( + op, op.getType(), allocGeneric.getResult(0)); + return mlir::success(); }; }; @@ -215,6 +471,14 @@ struct LowLFHEIntToCleartextOpPattern }; }; +// Rewrite the GlweFromTable operation to a series of ops: +// - allocation of two GLWE, one for the addition, and one for storing the +// result +// - allocation of plaintext_list to build the GLWE accumulator +// - build the foreign_plaintext_list using the input table +// - fill the plaintext_list with the foreign_plaintext_list +// - construct the GLWE accumulator by adding the plaintext_list to a freshly +// allocated GLWE struct GlweFromTableOpPattern : public mlir::OpRewritePattern { GlweFromTableOpPattern(mlir::MLIRContext *context, @@ -227,49 +491,19 @@ struct GlweFromTableOpPattern mlir::PatternRewriter &rewriter) const override { LowLFHEToConcreteCAPITypeConverter typeConverter; auto errType = mlir::IndexType::get(rewriter.getContext()); - // Insert forward declaration of the alloc_glwe function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), - { - errType, - mlir::IntegerType::get(rewriter.getContext(), 32), - mlir::IntegerType::get(rewriter.getContext(), 32), - }, - {mlir::zamalang::LowLFHE::GlweCiphertextType::get( - rewriter.getContext())}); - if (insertForwardDeclaration(op, rewriter, "allocate_glwe_ciphertext_u64", - funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the alloc_plaintext_list function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), - {errType, mlir::IntegerType::get(rewriter.getContext(), 32)}, - {mlir::zamalang::LowLFHE::PlaintextListType::get( - rewriter.getContext())}); - if (insertForwardDeclaration(op, rewriter, "allocate_plaintext_list_u64", - funcType) - .failed()) { - return mlir::failure(); - } - } + + // TODO: move this to insertForwardDeclarations + // issue: can't define function with tensor<*xtype> that accept ranked + // tensors // Insert forward declaration of the foregin_pt_list function { auto funcType = mlir::FunctionType::get( rewriter.getContext(), - {errType, - // mlir::UnrankedTensorType::get( - // mlir::IntegerType::get(rewriter.getContext(), 64)), - op->getOperandTypes().front(), + {errType, op->getOperandTypes().front(), mlir::IntegerType::get(rewriter.getContext(), 64), mlir::IntegerType::get(rewriter.getContext(), 32)}, - {mlir::zamalang::LowLFHE::ForeignPlaintextListType::get( - rewriter.getContext())}); + {getGenericForeignPlaintextListType(rewriter.getContext())}); if (insertForwardDeclaration( op, rewriter, "runtime_foreign_plaintext_list_u64", funcType) .failed()) { @@ -277,41 +511,6 @@ struct GlweFromTableOpPattern } } - // Insert forward declaration of the fill_plaintext_list function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), - {errType, - mlir::zamalang::LowLFHE::PlaintextListType::get( - rewriter.getContext()), - mlir::zamalang::LowLFHE::ForeignPlaintextListType::get( - rewriter.getContext())}, - {}); - if (insertForwardDeclaration( - op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType) - .failed()) { - return mlir::failure(); - } - } - - // Insert forward declaration of the add_plaintext_list_glwe function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), - {errType, - mlir::zamalang::LowLFHE::GlweCiphertextType::get( - rewriter.getContext()), - mlir::zamalang::LowLFHE::GlweCiphertextType::get( - rewriter.getContext()), - mlir::zamalang::LowLFHE::PlaintextListType::get( - rewriter.getContext())}, - {}); - if (insertForwardDeclaration( - op, rewriter, "add_plaintext_list_glwe_ciphertext_u64", funcType) - .failed()) { - return mlir::failure(); - } - } auto errOp = rewriter.create(op.getLoc(), rewriter.getIndexAttr(0)); // allocate two glwe to build accumulator @@ -324,39 +523,33 @@ struct GlweFromTableOpPattern // first accumulator would replace the op since it's the returned value auto accumulatorOp = rewriter.replaceOpWithNewOp( op, "allocate_glwe_ciphertext_u64", - mlir::zamalang::LowLFHE::GlweCiphertextType::get(rewriter.getContext()), - allocGlweOperands); + getGenericGlweCiphertextType(rewriter.getContext()), allocGlweOperands); // second accumulator is just needed to build the actual accumulator auto _accumulatorOp = rewriter.create( op.getLoc(), "allocate_glwe_ciphertext_u64", - mlir::zamalang::LowLFHE::GlweCiphertextType::get(rewriter.getContext()), - allocGlweOperands); + getGenericGlweCiphertextType(rewriter.getContext()), allocGlweOperands); // allocate plaintext list mlir::SmallVector allocPlaintextListOperands{errOp, polySizeOp}; auto plaintextListOp = rewriter.create( op.getLoc(), "allocate_plaintext_list_u64", - mlir::zamalang::LowLFHE::PlaintextListType::get(rewriter.getContext()), + getGenericPlaintextListType(rewriter.getContext()), allocPlaintextListOperands); // create foreign plaintext auto rankedTensorType = op->getOperandTypes().front().cast(); - if (rankedTensorType.getRank() != 1) { - llvm::errs() << "table lookup must be of a single dimension"; - return mlir::failure(); - } + assert(rankedTensorType.getRank() == 1 && + "table lookup must be of a single dimension"); auto sizeOp = rewriter.create( - op.getLoc(), rewriter.getIntegerAttr( - mlir::IntegerType::get(rewriter.getContext(), 64), - rankedTensorType.getDimSize(0))); + op.getLoc(), + rewriter.getI64IntegerAttr(rankedTensorType.getDimSize(0))); auto precisionOp = rewriter.create(op.getLoc(), op->getAttr("p")); mlir::SmallVector ForeignPlaintextListOperands{ errOp, op->getOperand(0), sizeOp, precisionOp}; auto foreignPlaintextListOp = rewriter.create( op.getLoc(), "runtime_foreign_plaintext_list_u64", - mlir::zamalang::LowLFHE::ForeignPlaintextListType::get( - rewriter.getContext()), + getGenericForeignPlaintextListType(rewriter.getContext()), ForeignPlaintextListOperands); // fill plaintext list mlir::SmallVector FillPlaintextListOperands{ @@ -376,6 +569,11 @@ struct GlweFromTableOpPattern }; }; +// Rewrite a BootstrapLweOp with a series of ops: +// - allocate the result LWE ciphertext +// - get the global bootstrapping key +// - use the key and the input accumulator (GLWE) to bootstrap the input +// ciphertext struct LowLFHEBootstrapLweOpPattern : public mlir::OpRewritePattern { LowLFHEBootstrapLweOpPattern(mlir::MLIRContext *context, @@ -386,141 +584,58 @@ struct LowLFHEBootstrapLweOpPattern mlir::LogicalResult matchAndRewrite(mlir::zamalang::LowLFHE::BootstrapLweOp op, mlir::PatternRewriter &rewriter) const override { - auto errType = mlir::IndexType::get(rewriter.getContext()); - auto lweOperandType = op->getOperandTypes().front(); - // Insert forward declaration of the allocate_bsk_key function - // { - // auto funcType = mlir::FunctionType::get( - // rewriter.getContext(), - // { - // errType, - // // level - // mlir::IntegerType::get(rewriter.getContext(), 32), - // // baselog - // mlir::IntegerType::get(rewriter.getContext(), 32), - // // glwe size - // mlir::IntegerType::get(rewriter.getContext(), 32), - // // lwe size - // mlir::IntegerType::get(rewriter.getContext(), 32), - // // polynomial size - // mlir::IntegerType::get(rewriter.getContext(), 32), - // }, - // {mlir::zamalang::LowLFHE::LweBootstrapKeyType::get( - // rewriter.getContext())}); - // if (insertForwardDeclaration(op, rewriter, - // "allocate_lwe_bootstrap_key_u64", - // funcType) - // .failed()) { - // return mlir::failure(); - // } - // } - - // Insert forward declaration of the getBsk function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), {}, - {mlir::zamalang::LowLFHE::LweBootstrapKeyType::get( - rewriter.getContext())}); - if (insertForwardDeclaration(op, rewriter, "getGlobalBootstrapKey", - funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the allocate_lwe_ct function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), - { - errType, - mlir::IntegerType::get(rewriter.getContext(), 32), - }, - {lweOperandType}); - if (insertForwardDeclaration(op, rewriter, "allocate_lwe_ciphertext_u64", - funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the bootstrap function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), - { - errType, - mlir::zamalang::LowLFHE::LweBootstrapKeyType::get( - rewriter.getContext()), - lweOperandType, - lweOperandType, - mlir::zamalang::LowLFHE::GlweCiphertextType::get( - rewriter.getContext()), - }, - {}); - if (insertForwardDeclaration(op, rewriter, "bootstrap_lwe_u64", funcType) - .failed()) { - return mlir::failure(); - } - } - + auto resultType = op->getResultTypes().front(); + auto bstOutputSize = + resultType.cast().getSize(); auto errOp = rewriter.create(op.getLoc(), rewriter.getIndexAttr(0)); - // allocate the result lwe ciphertext - // TODO: use right value for output lwe size - // LweSize output_lwe_size = { (glwe_size._0 -1) * poly_size._0 + 1} + // allocate the result lwe ciphertext, should be of a generic type, to cast + // before return auto lweSizeOp = rewriter.create( - op.getLoc(), mlir::IntegerAttr::get( - mlir::IntegerType::get(rewriter.getContext(), 32), - op->getAttr("k").cast().getInt())); + op.getLoc(), + mlir::IntegerAttr::get( + mlir::IntegerType::get(rewriter.getContext(), 32), bstOutputSize)); mlir::SmallVector allocLweCtOperands{errOp, lweSizeOp}; - auto allocateLweCtOp = rewriter.replaceOpWithNewOp( - op, "allocate_lwe_ciphertext_u64", lweOperandType, allocLweCtOperands); - // allocate bsk - // auto decompLevelCountOp = rewriter.create( - // op.getLoc(), - // mlir::IntegerAttr::get( - // mlir::IntegerType::get(rewriter.getContext(), 32), - // op->getAttr("level").cast().getInt())); - // auto decompBaseLogOp = rewriter.create( - // op.getLoc(), - // mlir::IntegerAttr::get( - // mlir::IntegerType::get(rewriter.getContext(), 32), - // op->getAttr("baseLog").cast().getInt())); - // auto glweSizeOp = rewriter.create( - // op.getLoc(), - // mlir::IntegerAttr::get( - // mlir::IntegerType::get(rewriter.getContext(), 32), -1)); - // auto polySizeOp = rewriter.create( - // op.getLoc(), - // mlir::IntegerAttr::get( - // mlir::IntegerType::get(rewriter.getContext(), 32), - // op->getAttr("polynomialSize").cast().getInt())); - // mlir::SmallVector allocBskOperands{ - // errOp, decompLevelCountOp, decompBaseLogOp, - // glweSizeOp, lweSizeOp, polySizeOp}; - // auto allocateBskOp = rewriter.create( - // op.getLoc(), "allocate_lwe_bootstrap_key_u64", - // mlir::zamalang::LowLFHE::LweBootstrapKeyType::get( - // rewriter.getContext()), - // allocBskOperands); - + auto allocateGenericLweCtOp = rewriter.create( + op.getLoc(), "allocate_lwe_ciphertext_u64", + getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands); // get bsk mlir::SmallVector getBskOperands{}; auto getBskOp = rewriter.create( op.getLoc(), "getGlobalBootstrapKey", - mlir::zamalang::LowLFHE::LweBootstrapKeyType::get( - rewriter.getContext()), - getBskOperands); + getGenericLweBootstrapKeyType(rewriter.getContext()), getBskOperands); // bootstrap + // cast input ciphertext to a generic type + mlir::Value lweToBootstrap = + rewriter + .create( + op.getLoc(), getGenericType(op.getOperand(0).getType()), + op.getOperand(0)) + .getResult(0); + // cast input accumulator to a generic type + mlir::Value accumulator = + rewriter + .create( + op.getLoc(), getGenericType(op.getOperand(1).getType()), + op.getOperand(1)) + .getResult(0); mlir::SmallVector bootstrapOperands{ - errOp, getBskOp.getResult(0), allocateLweCtOp.getResult(0), - op->getOperand(0), op->getOperand(1)}; + errOp, getBskOp.getResult(0), allocateGenericLweCtOp.getResult(0), + lweToBootstrap, accumulator}; rewriter.create(op.getLoc(), "bootstrap_lwe_u64", mlir::TypeRange({}), bootstrapOperands); + // Cast result to the appropriate type + rewriter.replaceOpWithNewOp( + op, resultType, allocateGenericLweCtOp.getResult(0)); return mlir::success(); }; }; +// Rewrite a KeySwitchLweOp with a series of ops: +// - allocate the result LWE ciphertext +// - get the global keyswitch key +// - use the key to keyswitch the input ciphertext struct LowLFHEKeySwitchLweOpPattern : public mlir::OpRewritePattern { LowLFHEKeySwitchLweOpPattern(mlir::MLIRContext *context, @@ -531,139 +646,41 @@ struct LowLFHEKeySwitchLweOpPattern mlir::LogicalResult matchAndRewrite(mlir::zamalang::LowLFHE::KeySwitchLweOp op, mlir::PatternRewriter &rewriter) const override { - auto errType = mlir::IndexType::get(rewriter.getContext()); - auto lweOperandType = op->getOperandTypes().front(); - // Insert forward declaration of the allocate_ksk_key function - // { - // auto funcType = mlir::FunctionType::get( - // rewriter.getContext(), - // { - // errType, - // // level - // mlir::IntegerType::get(rewriter.getContext(), 32), - // // baselog - // mlir::IntegerType::get(rewriter.getContext(), 32), - // // input lwe size - // mlir::IntegerType::get(rewriter.getContext(), 32), - // // output lwe size - // mlir::IntegerType::get(rewriter.getContext(), 32), - // }, - // {mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get( - // rewriter.getContext())}); - // if (insertForwardDeclaration(op, rewriter, - // "allocate_lwe_keyswitch_key_u64", - // funcType) - // .failed()) { - // return mlir::failure(); - // } - // } - - // Insert forward declaration of the getKsk function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), {}, - {mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get( - rewriter.getContext())}); - if (insertForwardDeclaration(op, rewriter, "getGlobalKeyswitchKey", - funcType) - .failed()) { - return mlir::failure(); - } - } - // Insert forward declaration of the allocate_lwe_ct function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), - { - errType, - mlir::IntegerType::get(rewriter.getContext(), 32), - }, - {lweOperandType}); - if (insertForwardDeclaration(op, rewriter, "allocate_lwe_ciphertext_u64", - funcType) - .failed()) { - return mlir::failure(); - } - } - // TODO: build the right type here - auto lweOutputType = lweOperandType; - // Insert forward declaration of the keyswitch function - { - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), - { - errType, - // ksk - mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get( - rewriter.getContext()), - // output ct - lweOutputType, - // input ct - lweOperandType, - }, - {}); - if (insertForwardDeclaration(op, rewriter, "keyswitch_lwe_u64", funcType) - .failed()) { - return mlir::failure(); - } - } - auto errOp = rewriter.create(op.getLoc(), rewriter.getIndexAttr(0)); - // allocate the result lwe ciphertext + // allocate the result lwe ciphertext, should be of a generic type, to cast + // before return auto lweSizeOp = rewriter.create( op.getLoc(), mlir::IntegerAttr::get( mlir::IntegerType::get(rewriter.getContext(), 32), op->getAttr("outputLweSize").cast().getInt())); mlir::SmallVector allocLweCtOperands{errOp, lweSizeOp}; - auto allocateLweCtOp = rewriter.replaceOpWithNewOp( - op, "allocate_lwe_ciphertext_u64", lweOutputType, allocLweCtOperands); - // allocate ksk - // auto decompLevelCountOp = rewriter.create( - // op.getLoc(), - // mlir::IntegerAttr::get( - // mlir::IntegerType::get(rewriter.getContext(), 32), - // op->getAttr("level").cast().getInt())); - // auto decompBaseLogOp = rewriter.create( - // op.getLoc(), - // mlir::IntegerAttr::get( - // mlir::IntegerType::get(rewriter.getContext(), 32), - // op->getAttr("baseLog").cast().getInt())); - // auto inputLweSizeOp = rewriter.create( - // op.getLoc(), - // mlir::IntegerAttr::get( - // mlir::IntegerType::get(rewriter.getContext(), 32), - // op->getAttr("inputLweSize").cast().getInt())); - // auto outputLweSizeOp = rewriter.create( - // op.getLoc(), - // mlir::IntegerAttr::get( - // mlir::IntegerType::get(rewriter.getContext(), 32), - // op->getAttr("outputLweSize").cast().getInt())); - // mlir::SmallVector allockskOperands{ - // errOp, decompLevelCountOp, decompBaseLogOp, inputLweSizeOp, - // outputLweSizeOp}; - // auto allocateKskOp = rewriter.create( - // op.getLoc(), "allocate_lwe_keyswitch_key_u64", - // mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get( - // rewriter.getContext()), - // allockskOperands); - + auto allocateGenericLweCtOp = rewriter.create( + op.getLoc(), "allocate_lwe_ciphertext_u64", + getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands); // get ksk mlir::SmallVector getkskOperands{}; auto getKskOp = rewriter.create( op.getLoc(), "getGlobalKeyswitchKey", - mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get( - rewriter.getContext()), - getkskOperands); - + getGenericLweKeySwitchKeyType(rewriter.getContext()), getkskOperands); // keyswitch + // cast input ciphertext to a generic type + mlir::Value lweToKeyswitch = + rewriter + .create( + op.getLoc(), getGenericType(op.getOperand().getType()), + op.getOperand()) + .getResult(0); mlir::SmallVector keyswitchOperands{ - errOp, getKskOp.getResult(0), allocateLweCtOp.getResult(0), - op->getOperand(0)}; + errOp, getKskOp.getResult(0), allocateGenericLweCtOp.getResult(0), + lweToKeyswitch}; rewriter.create(op.getLoc(), "keyswitch_lwe_u64", mlir::TypeRange({}), keyswitchOperands); - + // Cast result to the appropriate type + auto lweOutputType = op->getResultTypes().front(); + rewriter.replaceOpWithNewOp( + op, lweOutputType, allocateGenericLweCtOp.getResult(0)); return mlir::success(); }; }; @@ -713,8 +730,14 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); populateLowLFHEToConcreteCAPICall(patterns); - // Apply the conversion + // Insert forward declarations mlir::ModuleOp op = getOperation(); + mlir::IRRewriter rewriter(&getContext()); + if (insertForwardDeclarations(op, rewriter).failed()) { + this->signalPassFailure(); + } + + // Apply the conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { this->signalPassFailure(); } diff --git a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp index 34ed4138e..b79504d94 100644 --- a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp +++ b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp @@ -35,6 +35,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() { // legalize LLVM dialect. mlir::LLVMConversionTarget target(getContext()); target.addLegalOp(); + target.addIllegalOp(); // Setup the LLVMTypeConverter (that converts `std` types to `llvm` types) and // add our types conversion to `llvm` compatible type. diff --git a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp index 87805d462..26fb4ab1c 100644 --- a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp +++ b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp @@ -107,7 +107,7 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op, "`ct` argument."; return mlir::failure(); } - // Disable this check for the moment + // Disable this check for the moment: issue/111 // Check the witdh of the encrypted integer and the integer of the tabulated // lambda are equals // if (ct.getWidth() != l_cst.getElementType().cast().getWidth()) diff --git a/compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp b/compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp index 073749940..e8d8d9f2c 100644 --- a/compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp +++ b/compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp @@ -123,7 +123,7 @@ mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTable &op) { "`ct` argument."; return mlir::failure(); } - // Disable this check for the moment + // Disable this check for the moment: issue/111 // Check the witdh of the encrypted integer and the integer of the tabulated // lambda are equals // if (result.getP() < l_cst.getElementType().cast().getWidth()) diff --git a/compiler/lib/Support/CompilerTools.cpp b/compiler/lib/Support/CompilerTools.cpp index 2bc9a03d9..6da29ac63 100644 --- a/compiler/lib/Support/CompilerTools.cpp +++ b/compiler/lib/Support/CompilerTools.cpp @@ -236,7 +236,7 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { } // Setup runtime context with appropriate keys - keySet.generateRuntimeContext(); + keySet.initGlobalRuntimeContext(); } JITLambda::Argument::~Argument() { diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir index 0aad472ba..48b9acc0d 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir @@ -1,17 +1,30 @@ // RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s // CHECK-LABEL: module -// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<1024,4> +// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) +// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list) +// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list +// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext +// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) +// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key +// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext) // CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key +// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) +// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>) +// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>) +// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) +// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<_,_> // CHECK-LABEL: func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> { // CHECK-NEXT: %[[ERR:.*]] = constant 0 : index - // CHECK-NEXT: %[[C0:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, i32) -> !LowLFHE.lwe_ciphertext<1024,4> + // CHECK-NEXT: %[[C0:.*]] = constant 1024 : i32 + // CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, i32) -> !LowLFHE.lwe_ciphertext<_,_> // CHECK-NEXT: %[[V2:.*]] = call @getGlobalBootstrapKey() : () -> !LowLFHE.lwe_bootstrap_key - // CHECK-NEXT: call @bootstrap_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %arg0, %arg1) : (index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> () - // CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<1024,4> + // CHECK-NEXT: %[[V3:.*]] = unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_> + // CHECK-NEXT: %[[V4:.*]] = unrealized_conversion_cast %arg1 : !LowLFHE.glwe_ciphertext to !LowLFHE.glwe_ciphertext + // CHECK-NEXT: call @bootstrap_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %[[V3]], %[[V4]]) : (index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext) -> () + // CHECK-NEXT: %[[RES:.*]] = unrealized_conversion_cast %[[V1]] : !LowLFHE.lwe_ciphertext<_,_> to !LowLFHE.lwe_ciphertext<1024,4> + // CHECK-NEXT: return %[[RES]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "LowLFHE.bootstrap_lwe"(%arg0, %arg1) {baseLog = 2 : i32, k = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> return %1: !LowLFHE.lwe_ciphertext<1024,4> } \ No newline at end of file diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir index 6b9beaae8..89aff9491 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir @@ -1,13 +1,22 @@ // RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s // CHECK-LABEL: module +// CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list // CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) // CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list) -// CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi4>, i64, i32) -> !LowLFHE.foreign_plaintext_list // CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list // CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext -// CHECK-LABEL: func @glwe_from_table(%arg0: tensor<16xi4>) -> !LowLFHE.glwe_ciphertext -func @glwe_from_table(%arg0: tensor<16xi4>) -> !LowLFHE.glwe_ciphertext { +// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) +// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key +// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext) +// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key +// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) +// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>) +// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>) +// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) +// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<_,_> +// CHECK-LABEL: func @glwe_from_table(%arg0: tensor<16xi64>) -> !LowLFHE.glwe_ciphertext +func @glwe_from_table(%arg0: tensor<16xi64>) -> !LowLFHE.glwe_ciphertext { // CHECK-NEXT: %[[V0:.*]] = constant 0 : index // CHECK-NEXT: %[[C0:.*]] = constant 1 : i32 // CHECK-NEXT: %[[C1:.*]] = constant 1024 : i32 @@ -16,10 +25,10 @@ func @glwe_from_table(%arg0: tensor<16xi4>) -> !LowLFHE.glwe_ciphertext { // CHECK-NEXT: %[[V3:.*]] = call @allocate_plaintext_list_u64(%[[V0]], %[[C1]]) : (index, i32) -> !LowLFHE.plaintext_list // CHECK-NEXT: %[[C2:.*]] = constant 16 : i64 // CHECK-NEXT: %[[C3:.*]] = constant 4 : i32 - // CHECK-NEXT: %[[V4:.*]] = call @runtime_foreign_plaintext_list_u64(%[[V0]], %arg0, %[[C2]], %[[C3]]) : (index, tensor<16xi4>, i64, i32) -> !LowLFHE.foreign_plaintext_list + // CHECK-NEXT: %[[V4:.*]] = call @runtime_foreign_plaintext_list_u64(%[[V0]], %arg0, %[[C2]], %[[C3]]) : (index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list // CHECK-NEXT: call @fill_plaintext_list_with_expansion_u64(%[[V0]], %[[V3]], %[[V4]]) : (index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list) -> () // CHECK-NEXT: call @add_plaintext_list_glwe_ciphertext_u64(%[[V0]], %[[V1]], %[[V2]], %[[V3]]) : (index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) -> () // CHECK-NEXT: return %[[V1]] : !LowLFHE.glwe_ciphertext - %1 = "LowLFHE.glwe_from_table"(%arg0) {k = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext + %1 = "LowLFHE.glwe_from_table"(%arg0) {k = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !LowLFHE.glwe_ciphertext return %1: !LowLFHE.glwe_ciphertext } \ No newline at end of file diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir index a9e684177..90b44a10b 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir @@ -1,16 +1,29 @@ // RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s // CHECK-LABEL: module -// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>) -// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<1024,4> +// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) +// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list) +// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list +// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext +// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) +// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key +// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext) +// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key +// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) +// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>) +// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>) +// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) +// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<_,_> // CHECK-LABEL: func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> { // CHECK-NEXT: %[[ERR:.*]] = constant 0 : index // CHECK-NEXT: %[[C0:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, i32) -> !LowLFHE.lwe_ciphertext<1024,4> + // CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, i32) -> !LowLFHE.lwe_ciphertext<_,_> // CHECK-NEXT: %[[V2:.*]] = call @getGlobalKeyswitchKey() : () -> !LowLFHE.lwe_key_switch_key - // CHECK-NEXT: call @keyswitch_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %arg0) : (index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>) -> () - // CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<1024,4> + // CHECK-NEXT: %[[V3:.*]] = unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_> + // CHECK-NEXT: call @keyswitch_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %[[V3]]) : (index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) -> () + // CHECK-NEXT: %[[RES:.*]] = unrealized_conversion_cast %[[V1]] : !LowLFHE.lwe_ciphertext<_,_> to !LowLFHE.lwe_ciphertext<1024,4> + // CHECK-NEXT: return %[[RES]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> return %1: !LowLFHE.lwe_ciphertext<1024,4> } \ No newline at end of file diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir index 828a438f5..f3cda0d8f 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir @@ -3,8 +3,8 @@ // CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi4>) -> !LowLFHE.lwe_ciphertext<1024,4> func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi4>) -> !MidLFHE.glwe<{1024,1,64}{4}> { // CHECK-NEXT: %[[V1:.*]] = "LowLFHE.glwe_from_table"(%arg1) {k = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext - // CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 600 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> - // CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> + // CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 600 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<600,4> + // CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<600,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> // CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1){k=1:i32, polynomialSize=1024:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!MidLFHE.glwe<{1024,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{1024,1,64}{4}>) return %1: !MidLFHE.glwe<{1024,1,64}{4}> diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir index 06c678633..e202ae025 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir @@ -4,8 +4,8 @@ func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> { // CHECK-NEXT: %[[TABLE:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1]> : tensor<16xi4> // CHECK-NEXT: %[[V1:.*]] = "LowLFHE.glwe_from_table"(%[[TABLE]]) {k = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext - // CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 600 : i32} : (!LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4> - // CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!LowLFHE.lwe_ciphertext<2048,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<2048,4> + // CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 600 : i32} : (!LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<600,4> + // CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!LowLFHE.lwe_ciphertext<600,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<2048,4> // CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext<2048,4> %tlu = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi4> %1 = "MidLFHE.apply_lookup_table"(%arg0, %tlu){k=1:i32, polynomialSize=2048:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!MidLFHE.glwe<{2048,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{2048,1,64}{4}>) diff --git a/compiler/tests/python/test_compiler_engine.py b/compiler/tests/python/test_compiler_engine.py index 49ae284ba..52ba1b074 100644 --- a/compiler/tests/python/test_compiler_engine.py +++ b/compiler/tests/python/test_compiler_engine.py @@ -29,7 +29,6 @@ def test_compile_and_run(mlir_input, args, expected_result): ( """ func @main(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { - // 0..128 shifted << 55 %tlu = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> %1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<7>, tensor<128xi64>) -> (!HLFHE.eint<7>) return %1: !HLFHE.eint<7>