From fb58dcc59dcf361fec55ab73b5088cf3b91c477a Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Fri, 26 Nov 2021 18:02:12 +0100 Subject: [PATCH] enhance(compiler/lowlfhe): Give the runtime context as function argument instead of a global variable (close #195) --- .../Dialect/LowLFHE/IR/LowLFHETypes.td | 16 ++ compiler/include/zamalang/Runtime/context.h | 19 +- compiler/include/zamalang/Support/Jit.h | 1 + compiler/include/zamalang/Support/KeySet.h | 7 +- .../LowLFHEToConcreteCAPI.cpp | 167 +++++++++++++++--- .../MLIRLowerableDialectsToLLVM.cpp | 1 + .../lib/Dialect/LowLFHE/IR/LowLFHEDialect.cpp | 24 +-- compiler/lib/Runtime/context.c | 45 +---- compiler/lib/Support/ClientParameters.cpp | 7 +- compiler/lib/Support/Jit.cpp | 10 +- compiler/lib/Support/Pipeline.cpp | 3 +- .../LowLFHEToConcreteCAPI/bootstrap.mlir | 22 +-- .../glwe_from_table.mlir | 21 +-- .../LowLFHEToConcreteCAPI/keyswitch_lwe.mlir | 22 +-- 14 files changed, 201 insertions(+), 164 deletions(-) diff --git a/compiler/include/zamalang/Dialect/LowLFHE/IR/LowLFHETypes.td b/compiler/include/zamalang/Dialect/LowLFHE/IR/LowLFHETypes.td index 58dc8a085..10831eda3 100644 --- a/compiler/include/zamalang/Dialect/LowLFHE/IR/LowLFHETypes.td +++ b/compiler/include/zamalang/Dialect/LowLFHE/IR/LowLFHETypes.td @@ -210,7 +210,23 @@ def LweBootstrapKeyType : LowLFHE_Type<"LweBootstrapKey"> { }]; } +def Context : LowLFHE_Type<"Context"> { + let mnemonic = "context"; + let summary = "Runtime context"; + + let description = [{ + An abstract runtime context to pass contextual value, like public keys, ... + }]; + + let printer = [{ + $_printer << "context"; + }]; + + let parser = [{ + return get($_ctxt); + }]; +} diff --git a/compiler/include/zamalang/Runtime/context.h b/compiler/include/zamalang/Runtime/context.h index e0f883176..5f62418d7 100644 --- a/compiler/include/zamalang/Runtime/context.h +++ b/compiler/include/zamalang/Runtime/context.h @@ -8,23 +8,8 @@ typedef struct RuntimeContext { struct LweBootstrapKey_u64 *bsk; } RuntimeContext; -extern RuntimeContext *globalRuntimeContext; +LweKeyswitchKey_u64 *get_keyswitch_key(RuntimeContext *context); -RuntimeContext *createRuntimeContext(LweKeyswitchKey_u64 *ksk, - LweBootstrapKey_u64 *bsk); - -void setGlobalRuntimeContext(RuntimeContext *context); - -RuntimeContext *getGlobalRuntimeContext(); - -LweKeyswitchKey_u64 *getGlobalKeyswitchKey(); - -LweBootstrapKey_u64 *getGlobalBootstrapKey(); - -LweKeyswitchKey_u64 *getKeyswitckKeyFromContext(RuntimeContext *context); - -LweBootstrapKey_u64 *getBootstrapKeyFromContext(RuntimeContext *context); - -bool checkError(int *err); +LweBootstrapKey_u64 *get_bootstrap_key(RuntimeContext *context); #endif diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index 61bfff18e..d959f1460 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -96,6 +96,7 @@ public: std::vector ciphertextBuffers; KeySet &keySet; + RuntimeContext context; }; JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name) : type(type), name(name){}; diff --git a/compiler/include/zamalang/Support/KeySet.h b/compiler/include/zamalang/Support/KeySet.h index 0b4f50db3..aab217caf 100644 --- a/compiler/include/zamalang/Support/KeySet.h +++ b/compiler/include/zamalang/Support/KeySet.h @@ -41,10 +41,9 @@ public: CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); } CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); } - void initGlobalRuntimeContext() { - auto ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]); - auto bsk = std::get<1>(this->bootstrapKeys["bsk_v0"]); - setGlobalRuntimeContext(createRuntimeContext(ksk, bsk)); + void setRuntimeContext(RuntimeContext &context) { + context.ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]); + context.bsk = std::get<1>(this->bootstrapKeys["bsk_v0"]); } protected: diff --git a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp index f97dc62fb..3fe5b3ac5 100644 --- a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp +++ b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp @@ -133,6 +133,9 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, auto genericCleartextType = getGenericCleartextType(rewriter.getContext()); auto genericBSKType = getGenericLweBootstrapKeyType(rewriter.getContext()); auto genericKSKType = getGenericLweKeySwitchKeyType(rewriter.getContext()); + auto contextType = + mlir::zamalang::LowLFHE::ContextType::get(rewriter.getContext()); + auto errType = mlir::IndexType::get(rewriter.getContext()); // Insert forward declaration of allocate lwe ciphertext @@ -211,10 +214,9 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, } // Insert forward declaration of the getBsk function { - auto funcType = - mlir::FunctionType::get(rewriter.getContext(), {}, {genericBSKType}); - if (insertForwardDeclaration(op, rewriter, "getGlobalBootstrapKey", - funcType) + auto funcType = mlir::FunctionType::get(rewriter.getContext(), + {contextType}, {genericBSKType}); + if (insertForwardDeclaration(op, rewriter, "get_bootstrap_key", funcType) .failed()) { return mlir::failure(); } @@ -237,10 +239,9 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, } // Insert forward declaration of the getKsk function { - auto funcType = - mlir::FunctionType::get(rewriter.getContext(), {}, {genericKSKType}); - if (insertForwardDeclaration(op, rewriter, "getGlobalKeyswitchKey", - funcType) + auto funcType = mlir::FunctionType::get(rewriter.getContext(), + {contextType}, {genericKSKType}); + if (insertForwardDeclaration(op, rewriter, "get_keyswitch_key", funcType) .failed()) { return mlir::failure(); } @@ -563,6 +564,25 @@ struct GlweFromTableOpPattern }; }; +mlir::Value getContextArgument(mlir::Operation *op) { + mlir::Block *block = op->getBlock(); + while (block != nullptr) { + if (llvm::isa(block->getParentOp())) { + + mlir::Value context = block->getArguments().back(); + + assert(context.getType().isa() && + "the LowLFHE.context should be the last argument of the enclosing " + "function of the op"); + + return context; + } + block = block->getParentOp()->getBlock(); + } + assert("can't find a function that enclose the op"); + return nullptr; +} + // Rewrite a BootstrapLweOp with a series of ops: // - allocate the result LWE ciphertext // - get the global bootstrapping key @@ -592,10 +612,10 @@ struct LowLFHEBootstrapLweOpPattern op.getLoc(), "allocate_lwe_ciphertext_u64", getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands); // get bsk - mlir::SmallVector getBskOperands{}; auto getBskOp = rewriter.create( - op.getLoc(), "getGlobalBootstrapKey", - getGenericLweBootstrapKeyType(rewriter.getContext()), getBskOperands); + op.getLoc(), "get_bootstrap_key", + getGenericLweBootstrapKeyType(rewriter.getContext()), + mlir::SmallVector{getContextArgument(op)}); // bootstrap // cast input ciphertext to a generic type mlir::Value lweToBootstrap = @@ -651,10 +671,10 @@ struct LowLFHEKeySwitchLweOpPattern op.getLoc(), "allocate_lwe_ciphertext_u64", getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands); // get ksk - mlir::SmallVector getkskOperands{}; auto getKskOp = rewriter.create( - op.getLoc(), "getGlobalKeyswitchKey", - getGenericLweKeySwitchKeyType(rewriter.getContext()), getkskOperands); + op.getLoc(), "get_keyswitch_key", + getGenericLweKeySwitchKeyType(rewriter.getContext()), + mlir::SmallVector{getContextArgument(op)}); // keyswitch // cast input ciphertext to a generic type mlir::Value lweToKeyswitch = @@ -703,6 +723,73 @@ void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } +struct AddRuntimeContextToFuncOpPattern + : public mlir::OpRewritePattern { + AddRuntimeContextToFuncOpPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit) {} + + mlir::LogicalResult + matchAndRewrite(mlir::FuncOp oldFuncOp, + mlir::PatternRewriter &rewriter) const override { + mlir::OpBuilder::InsertionGuard guard(rewriter); + mlir::FunctionType oldFuncType = oldFuncOp.getType(); + + // Add a LowLFHE.context to the function signature + mlir::SmallVector newInputs(oldFuncType.getInputs().begin(), + oldFuncType.getInputs().end()); + newInputs.push_back( + rewriter.getType()); + mlir::FunctionType newFuncTy = rewriter.getType( + newInputs, oldFuncType.getResults()); + // Create the new func + mlir::FuncOp newFuncOp = rewriter.create( + oldFuncOp.getLoc(), oldFuncOp.getName(), newFuncTy); + + // Create the arguments of the new func + mlir::Region &newFuncBody = newFuncOp.body(); + mlir::Block *newFuncEntryBlock = new mlir::Block(); + newFuncEntryBlock->addArguments(newFuncTy.getInputs()); + newFuncBody.push_back(newFuncEntryBlock); + + // Clone the old body to the new one + mlir::BlockAndValueMapping map; + for (auto arg : llvm::enumerate(oldFuncOp.getArguments())) { + map.map(arg.value(), newFuncEntryBlock->getArgument(arg.index())); + } + for (auto &op : oldFuncOp.body().front()) { + newFuncEntryBlock->push_back(op.clone(map)); + } + rewriter.eraseOp(oldFuncOp); + return mlir::success(); + } + + // Legal function are one that are private or has a LowLFHE.context as last + // arguments. + static bool isLegal(mlir::FuncOp funcOp) { + if (!funcOp.isPublic()) { + return true; + } + // TODO : Don't need to add a runtime context for function that doesn't + // manipulates lowlfhe types. + // + // if (!llvm::any_of(funcOp.getType().getInputs(), [](mlir::Type t) { + // if (auto tensorTy = t.dyn_cast_or_null()) { + // t = tensorTy.getElementType(); + // } + // return llvm::isa( + // t.getDialect()); + // })) { + // return true; + // } + return funcOp.getType().getNumInputs() >= 1 && + funcOp.getType() + .getInputs() + .back() + .isa(); + } +}; + namespace { struct LowLFHEToConcreteCAPIPass : public LowLFHEToConcreteCAPIBase { @@ -711,27 +798,49 @@ struct LowLFHEToConcreteCAPIPass } // namespace void LowLFHEToConcreteCAPIPass::runOnOperation() { - // Setup the conversion target. - mlir::ConversionTarget target(getContext()); - target.addIllegalDialect(); - target.addLegalDialect(); - - // Setup rewrite patterns - mlir::RewritePatternSet patterns(&getContext()); - populateLowLFHEToConcreteCAPICall(patterns); - - // Insert forward declarations mlir::ModuleOp op = getOperation(); + + // First of all add the LowLFHE.context to the block arguments of function + // that manipulates ciphertexts. + { + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + + target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { + return AddRuntimeContextToFuncOpPattern::isLegal(funcOp); + }); + + patterns.add(patterns.getContext()); + + // Apply the conversion + if (mlir::applyPartialConversion(op, target, std::move(patterns)) + .failed()) { + this->signalPassFailure(); + return; + } + } + + // Insert forward declaration mlir::IRRewriter rewriter(&getContext()); if (insertForwardDeclarations(op, rewriter).failed()) { this->signalPassFailure(); } + // Rewrite LowLFHE ops to CallOp to the Concrete C API + { + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); - // Apply the conversion - if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { - this->signalPassFailure(); + target.addIllegalDialect(); + target.addLegalDialect(); + + populateLowLFHEToConcreteCAPICall(patterns); + + if (mlir::applyPartialConversion(op, target, std::move(patterns)) + .failed()) { + this->signalPassFailure(); + } } } diff --git a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp index e147b7d38..0fba4a7bb 100644 --- a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp +++ b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp @@ -70,6 +70,7 @@ MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) { type.isa() || type.isa() || type.isa() || + type.isa() || type.isa() || type.isa()) { return mlir::LLVM::LLVMPointerType::get( diff --git a/compiler/lib/Dialect/LowLFHE/IR/LowLFHEDialect.cpp b/compiler/lib/Dialect/LowLFHE/IR/LowLFHEDialect.cpp index 60e193e07..e57103060 100644 --- a/compiler/lib/Dialect/LowLFHE/IR/LowLFHEDialect.cpp +++ b/compiler/lib/Dialect/LowLFHE/IR/LowLFHEDialect.cpp @@ -25,24 +25,9 @@ void LowLFHEDialect::initialize() { mlir::Type type; std::string types_str[] = { - "enc_rand_gen", - "secret_rand_gen", - "plaintext", - "plaintext_list", - "foreign_plaintext_list", - "lwe_ciphertext", - "lwe_key_switch_key", - "lwe_bootstrap_key", - "lwe_secret_key", - "lwe_size", - "glwe_ciphertext", - "glwe_secret_key", - "glwe_size", - "polynomial_size", - "decomp_level_count", - "decomp_base_log", - "variance", - "cleartext", + "plaintext", "plaintext_list", "foreign_plaintext_list", + "lwe_ciphertext", "lwe_key_switch_key", "lwe_bootstrap_key", + "glwe_ciphertext", "cleartext", "context", }; for (const std::string &type_str : types_str) { @@ -53,8 +38,7 @@ void LowLFHEDialect::initialize() { } parser.emitError(parser.getCurrentLocation(), "Unknown LowLFHE type"); - // call default parser - parser.parseType(type); + return type; } diff --git a/compiler/lib/Runtime/context.c b/compiler/lib/Runtime/context.c index 6074858a7..8da6d7e4f 100644 --- a/compiler/lib/Runtime/context.c +++ b/compiler/lib/Runtime/context.c @@ -1,51 +1,10 @@ #include "zamalang/Runtime/context.h" #include -RuntimeContext *globalRuntimeContext; - -RuntimeContext *createRuntimeContext(LweKeyswitchKey_u64 *ksk, - LweBootstrapKey_u64 *bsk) { - RuntimeContext *context = (RuntimeContext *)malloc(sizeof(RuntimeContext)); - context->ksk = ksk; - context->bsk = bsk; - return context; -} - -void setGlobalRuntimeContext(RuntimeContext *context) { - globalRuntimeContext = context; -} - -RuntimeContext *getGlobalRuntimeContext() { return globalRuntimeContext; } - -LweKeyswitchKey_u64 *getGlobalKeyswitchKey() { - return globalRuntimeContext->ksk; -} - -LweBootstrapKey_u64 *getGlobalBootstrapKey() { - return globalRuntimeContext->bsk; -} - -LweKeyswitchKey_u64 *getKeyswitckKeyFromContext(RuntimeContext *context) { +LweKeyswitchKey_u64 *get_keyswitch_key(RuntimeContext *context) { return context->ksk; } -LweBootstrapKey_u64 *getBootstrapKeyFromContext(RuntimeContext *context) { +LweBootstrapKey_u64 *get_bootstrap_key(RuntimeContext *context) { return context->bsk; } - -bool checkError(int *err) { - switch (*err) { - case ERR_INDEX_OUT_OF_BOUND: - fprintf(stderr, "Runtime: index out of bound"); - break; - case ERR_NULL_POINTER: - fprintf(stderr, "Runtime: null pointer"); - break; - case ERR_SIZE_MISMATCH: - fprintf(stderr, "Runtime: size mismatch"); - break; - default: - return false; - } - return true; -} \ No newline at end of file diff --git a/compiler/lib/Support/ClientParameters.cpp b/compiler/lib/Support/ClientParameters.cpp index 335e490cc..0bd9ee519 100644 --- a/compiler/lib/Support/ClientParameters.cpp +++ b/compiler/lib/Support/ClientParameters.cpp @@ -125,8 +125,11 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name, // Create input and output circuit gate parameters auto funcType = (*funcOp).getType(); - for (auto inType : funcType.getInputs()) { - auto gate = gateFromMLIRType("big", precision, encryptionVariance, inType); + bool hasContext = + funcType.getInputs().back().isa(); + for (auto inType = funcType.getInputs().begin(); + inType < funcType.getInputs().end() - hasContext; inType++) { + auto gate = gateFromMLIRType("big", precision, encryptionVariance, *inType); if (auto err = gate.takeError()) { return std::move(err); } diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 181ba654b..18badbddd 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -78,8 +78,8 @@ llvm::Error JITLambda::invoke(Argument &args) { JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { // Setting the inputs + auto numInputs = 0; { - auto numInputs = 0; for (size_t i = 0; i < keySet.numInputs(); i++) { auto offset = numInputs; auto gate = keySet.inputGate(i); @@ -95,6 +95,8 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { // dimension of the tensor. numInputs = numInputs + 2 * keySet.inputGate(i).shape.dimensions.size(); } + // Reserve for the context argument + numInputs = numInputs + 1; inputs = std::vector(numInputs); } @@ -128,8 +130,10 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { rawArg[i] = &outputs[i - inputs.size()]; } - // Setup runtime context with appropriate keys - keySet.initGlobalRuntimeContext(); + // Set the context argument + keySet.setRuntimeContext(context); + inputs[numInputs - 1] = &context; + rawArg[numInputs - 1] = &inputs[numInputs - 1]; } JITLambda::Argument::~Argument() { diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 310247e25..295096a3d 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -140,7 +140,8 @@ lowerLowLFHEToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass) { mlir::PassManager pm(&context); pipelinePrinting("LowLFHEToStd", pm, context); - pm.addPass(mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass()); + addPotentiallyNestedPass( + pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(), enablePass); return pm.run(module.getOperation()); } diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir index 57158e1f5..7f22ad340 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir @@ -1,25 +1,17 @@ // RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module -// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) -// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list) -// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list -// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext -// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) -// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key -// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext) -// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key -// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) -// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>) -// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>) -// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) -// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, index) -> !LowLFHE.lwe_ciphertext<_,_> -// CHECK-LABEL: func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> +// CHECK: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) +// CHECK: func private @get_keyswitch_key(!LowLFHE.context) -> !LowLFHE.lwe_key_switch_key +// CHECK: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext) +// CHECK: func private @get_bootstrap_key(!LowLFHE.context) -> !LowLFHE.lwe_bootstrap_key +// CHECK: func private @allocate_lwe_ciphertext_u64(index, index) -> !LowLFHE.lwe_ciphertext<_,_> +// CHECK-LABEL: func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext, %arg2: !LowLFHE.context) -> !LowLFHE.lwe_ciphertext<1024,4> func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> { // CHECK-NEXT: %[[ERR:.*]] = arith.constant 0 : index // CHECK-NEXT: %[[C0:.*]] = arith.constant 1024 : index // CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, index) -> !LowLFHE.lwe_ciphertext<_,_> - // CHECK-NEXT: %[[V2:.*]] = call @getGlobalBootstrapKey() : () -> !LowLFHE.lwe_bootstrap_key + // CHECK-NEXT: %[[V2:.*]] = call @get_bootstrap_key(%arg2) : (!LowLFHE.context) -> !LowLFHE.lwe_bootstrap_key // CHECK-NEXT: %[[V3:.*]] = builtin.unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_> // CHECK-NEXT: %[[V4:.*]] = builtin.unrealized_conversion_cast %arg1 : !LowLFHE.glwe_ciphertext to !LowLFHE.glwe_ciphertext // CHECK-NEXT: call @bootstrap_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %[[V3]], %[[V4]]) : (index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext) -> () diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir index 815366d04..3edf320b6 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir @@ -1,21 +1,12 @@ // RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module -// CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list -// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) -// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list) -// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list -// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext -// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) -// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key -// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext) -// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key -// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) -// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>) -// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>) -// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) -// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, index) -> !LowLFHE.lwe_ciphertext<_,_> -// CHECK-LABEL: func @glwe_from_table(%arg0: tensor<16xi64>) -> !LowLFHE.glwe_ciphertext +// CHECK: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list +// CHECK: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) +// CHECK: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list) +// CHECK: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list +// CHECK: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext +// CHECK-LABEL: func @glwe_from_table(%arg0: tensor<16xi64>, %arg1: !LowLFHE.context) -> !LowLFHE.glwe_ciphertext func @glwe_from_table(%arg0: tensor<16xi64>) -> !LowLFHE.glwe_ciphertext { // CHECK-NEXT: %[[V0:.*]] = arith.constant 0 : index // CHECK-NEXT: %[[C0:.*]] = arith.constant 1 : i32 diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir index 5077fd23d..179f2052f 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir @@ -1,25 +1,17 @@ // RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module -// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) -// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list) -// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list -// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext -// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) -// CHECK-NEXT: func private @getGlobalKeyswitchKey() -> !LowLFHE.lwe_key_switch_key -// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext) -// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key -// CHECK-NEXT: func private @negate_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) -// CHECK-NEXT: func private @mul_cleartext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.cleartext<_>) -// CHECK-NEXT: func private @add_plaintext_lwe_ciphertext_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.plaintext<_>) -// CHECK-NEXT: func private @add_lwe_ciphertexts_u64(index, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) -// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, index) -> !LowLFHE.lwe_ciphertext<_,_> -// CHECK-LABEL: func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> +// CHECK: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) +// CHECK: func private @get_keyswitch_key(!LowLFHE.context) -> !LowLFHE.lwe_key_switch_key +// CHECK: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.glwe_ciphertext) +// CHECK: func private @get_bootstrap_key(!LowLFHE.context) -> !LowLFHE.lwe_bootstrap_key +// CHECK: func private @allocate_lwe_ciphertext_u64(index, index) -> !LowLFHE.lwe_ciphertext<_,_> +// CHECK-LABEL: func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.context) -> !LowLFHE.lwe_ciphertext<1024,4> func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> { // CHECK-NEXT: %[[ERR:.*]] = arith.constant 0 : index // CHECK-NEXT: %[[C0:.*]] = arith.constant 1 : index // CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, index) -> !LowLFHE.lwe_ciphertext<_,_> - // CHECK-NEXT: %[[V2:.*]] = call @getGlobalKeyswitchKey() : () -> !LowLFHE.lwe_key_switch_key + // CHECK-NEXT: %[[V2:.*]] = call @get_keyswitch_key(%arg1) : (!LowLFHE.context) -> !LowLFHE.lwe_key_switch_key // CHECK-NEXT: %[[V3:.*]] = builtin.unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_> // CHECK-NEXT: call @keyswitch_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %[[V3]]) : (index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<_,_>, !LowLFHE.lwe_ciphertext<_,_>) -> () // CHECK-NEXT: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !LowLFHE.lwe_ciphertext<_,_> to !LowLFHE.lwe_ciphertext<1024,4>