From d0877536ed8659cf8f03acc44ca14a88c8bb9119 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Wed, 4 Aug 2021 15:12:48 +0200 Subject: [PATCH] feat(compiler): First draft of client parameters generation, runtime support for encrypting and decrypting circuit gates, integration of fhe parameters for the v0 (#65, #66, #56) --- compiler/CMakeLists.txt | 7 +- compiler/Makefile | 5 +- .../zamalang/Support/ClientParameters.h | 81 +++++ .../include/zamalang/Support/CompilerTools.h | 55 +++- compiler/include/zamalang/Support/KeySet.h | 65 +++++ .../include/zamalang/Support/V0Parameters.h | 29 ++ .../MLIRLowerableDialectsToLLVM.cpp | 19 +- compiler/lib/Support/CMakeLists.txt | 8 +- compiler/lib/Support/ClientParameters.cpp | 117 ++++++++ compiler/lib/Support/CompilerTools.cpp | 97 +++++- compiler/lib/Support/KeySet.cpp | 276 ++++++++++++++++++ compiler/lib/Support/V0Parameters.cpp | 134 +++++++++ compiler/src/main.cpp | 129 ++++++-- compiler/tests/RunJit/lowlfhe_identity.mlir | 6 + 14 files changed, 984 insertions(+), 44 deletions(-) create mode 100644 compiler/include/zamalang/Support/ClientParameters.h create mode 100644 compiler/include/zamalang/Support/KeySet.h create mode 100644 compiler/include/zamalang/Support/V0Parameters.h create mode 100644 compiler/lib/Support/ClientParameters.cpp create mode 100644 compiler/lib/Support/KeySet.cpp create mode 100644 compiler/lib/Support/V0Parameters.cpp create mode 100644 compiler/tests/RunJit/lowlfhe_identity.mlir diff --git a/compiler/CMakeLists.txt b/compiler/CMakeLists.txt index e818b19b0..69a56fc4f 100644 --- a/compiler/CMakeLists.txt +++ b/compiler/CMakeLists.txt @@ -6,7 +6,6 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) set(CMAKE_CXX_STANDARD 14) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-rtti") find_package(MLIR REQUIRED CONFIG) message(STATUS "Using MLIR cmake file from: ${MLIR_DIR}") @@ -27,6 +26,12 @@ include_directories(${PROJECT_BINARY_DIR}/include) link_directories(${LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) +#------------------------------------------------------------------------------- +# Concrete FFI Configuration +#------------------------------------------------------------------------------- +include_directories(${CONCRETE_FFI_RELEASE}) +add_library(Concrete SHARED IMPORTED) +set_target_properties(Concrete PROPERTIES IMPORTED_LOCATION ${CONCRETE_FFI_RELEASE}/libconcrete_ffi.so ) #------------------------------------------------------------------------------- # Python Configuration diff --git a/compiler/Makefile b/compiler/Makefile index db107e5c5..dc3fbe0fd 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -1,5 +1,8 @@ build: - cmake -B build . -DLLVM_DIR=${LLVM_PROJECT}/build/lib/cmake/llvm -DMLIR_DIR=${LLVM_PROJECT}/build/lib/cmake/mlir + cmake -B build . \ + -DLLVM_DIR=${LLVM_PROJECT}/build/lib/cmake/llvm \ + -DMLIR_DIR=${LLVM_PROJECT}/build/lib/cmake/mlir \ + -DCONCRETE_FFI_RELEASE=${CONCRETE_PROJECT}/target/release zamacompiler: make -C build/ zamacompiler diff --git a/compiler/include/zamalang/Support/ClientParameters.h b/compiler/include/zamalang/Support/ClientParameters.h new file mode 100644 index 000000000..d2fbac27b --- /dev/null +++ b/compiler/include/zamalang/Support/ClientParameters.h @@ -0,0 +1,81 @@ +#ifndef ZAMALANG_SUPPORT_CLIENTPARAMETERS_H_ +#define ZAMALANG_SUPPORT_CLIENTPARAMETERS_H_ +#include +#include +#include + +#include +#include + +#include "zamalang/Support/V0Parameters.h" + +namespace mlir { +namespace zamalang { + +typedef size_t DecompositionLevelCount; +typedef size_t DecompositionBaseLog; +typedef size_t PolynomialSize; +typedef size_t Precision; +typedef double Variance; + +typedef uint64_t LweSize; +typedef uint64_t GLWESize; + +typedef std::string LweSecretKeyID; +struct LweSecretKeyParam { + LweSize size; +}; + +typedef std::string BootstrapKeyID; +struct BootstrapKeyParam { + LweSecretKeyID inputSecretKeyID; + LweSecretKeyID outputSecretKeyID; + DecompositionLevelCount level; + DecompositionBaseLog baseLog; + GLWESize k; + Variance variance; +}; + +typedef std::string KeyswitchKeyID; +struct KeyswitchKeyParam { + LweSecretKeyID inputSecretKeyID; + LweSecretKeyID outputSecretKeyID; + DecompositionLevelCount level; + DecompositionBaseLog baseLog; + Variance variance; +}; + +struct Encoding { + Precision precision; +}; + +struct EncryptionGate { + LweSecretKeyID secretKeyID; + Variance variance; + Encoding encoding; +}; + +struct CircuitGateShape { + uint64_t size; +}; + +struct CircuitGate { + llvm::Optional encryption; + CircuitGateShape shape; +}; + +struct ClientParameters { + std::map secretKeys; + std::map bootstrapKeys; + std::map keyswitchKeys; + std::vector inputs; + std::vector outputs; +}; + +llvm::Expected +createClientParametersForV0(V0Parameter *v0Param, Precision precision, + llvm::StringRef name, mlir::ModuleOp module); +} // namespace zamalang +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/include/zamalang/Support/CompilerTools.h b/compiler/include/zamalang/Support/CompilerTools.h index f95ceea4c..3e9b67b75 100644 --- a/compiler/include/zamalang/Support/CompilerTools.h +++ b/compiler/include/zamalang/Support/CompilerTools.h @@ -5,21 +5,43 @@ #include #include +#include "zamalang/Support/ClientParameters.h" +#include "zamalang/Support/KeySet.h" +#include "zamalang/Support/V0Parameters.h" + namespace mlir { namespace zamalang { +/// For the v0 we compute a global constraint, this is defined here as the +/// high-level verification pass is not yet implemented. +struct FHECircuitConstraint { + size_t norm2; + size_t p; +}; + class CompilerTools { public: /// lowerHLFHEToMlirLLVMDialect run all passes to lower FHE dialects to mlir - /// LLVM dialect. - static mlir::LogicalResult lowerHLFHEToMlirLLVMDialect( + /// lowerable to llvm dialect. + /// The given module MLIR operation would be modified and the constraints set. + static mlir::LogicalResult lowerHLFHEToMlirStdsDialect( + mlir::MLIRContext &context, mlir::Operation *module, + FHECircuitConstraint &constraint, + llvm::function_ref enablePass = [](std::string pass) { + return true; + }); + + /// lowerMlirStdsDialectToMlirLLVMDialect run all passes to lower MLIR + /// dialects to MLIR LLVM dialect. The given module MLIR operation would be + /// modified. + static mlir::LogicalResult lowerMlirStdsDialectToMlirLLVMDialect( mlir::MLIRContext &context, mlir::Operation *module, llvm::function_ref enablePass = [](std::string pass) { return true; }); static llvm::Expected> - toLLVMModule(llvm::LLVMContext &context, mlir::ModuleOp &module, + toLLVMModule(llvm::LLVMContext &llvmContext, mlir::ModuleOp &module, llvm::function_ref optPipeline); }; @@ -27,6 +49,28 @@ public: /// of the module. class JITLambda { public: + class Argument { + public: + Argument(KeySet &keySet); + ~Argument(); + + // Create lambda Argument that use the given KeySet to perform encryption + // and decryption operations. + static llvm::Expected> create(KeySet &keySet); + + // Set the argument at the given pos as a uint64_t. + llvm::Error setArg(size_t pos, uint64_t arg); + + // Get the result at the given pos as an uint64_t. + llvm::Error getResult(size_t pos, uint64_t &res); + + private: + friend JITLambda; + std::vector rawArg; + std::vector inputs; + std::vector results; + KeySet &keySet; + }; JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name) : type(type), name(name){}; @@ -35,7 +79,7 @@ public: create(llvm::StringRef name, mlir::ModuleOp &module, llvm::function_ref optPipeline); - /// invokeRaw execute the jit lambda with a lits of arguments, the last one is + /// invokeRaw execute the jit lambda with a list of Argument, the last one is /// used to store the result of the computation. /// Example: /// uin64_t arg0 = 1; @@ -44,6 +88,9 @@ public: /// lambda.invokeRaw(args); llvm::Error invokeRaw(llvm::MutableArrayRef args); + /// invoke the jit lambda with the Argument. + llvm::Error invoke(Argument &args); + private: mlir::LLVM::LLVMFunctionType type; llvm::StringRef name; diff --git a/compiler/include/zamalang/Support/KeySet.h b/compiler/include/zamalang/Support/KeySet.h new file mode 100644 index 000000000..dcb290d50 --- /dev/null +++ b/compiler/include/zamalang/Support/KeySet.h @@ -0,0 +1,65 @@ +#ifndef ZAMALANG_SUPPORT_KEYSET_H_ +#define ZAMALANG_SUPPORT_KEYSET_H_ + +#include "llvm/Support/Error.h" +#include + +extern "C" { +#include "concrete-ffi.h" +} + +#include "zamalang/Support/ClientParameters.h" + +namespace mlir { +namespace zamalang { + +class KeySet { +public: + ~KeySet(); + // allocate a KeySet according the ClientParameters. + static llvm::Expected> + generate(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb); + + // isInputEncrypted return true if the input at the given pos is encrypted. + bool isInputEncrypted(size_t pos); + // allocate a lwe ciphertext for the argument at argPos. + llvm::Error allocate_lwe(size_t argPos, LweCiphertext_u64 **ciphertext); + // encrypt the input to the ciphertext for the argument at argPos. + llvm::Error encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext, + uint64_t input); + + // isOuputEncrypted return true if the output at the given pos is encrypted. + bool isOutputEncrypted(size_t pos); + // decrypt the ciphertext to the output for the argument at argPos. + llvm::Error decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext, + uint64_t &output); + + size_t numInputs() { return inputs.size(); } + size_t numOutputs() { return outputs.size(); } + +protected: + llvm::Error generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param, + SecretRandomGenerator *generator); + llvm::Error generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param, + EncryptionRandomGenerator *generator); + llvm::Error generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param, + EncryptionRandomGenerator *generator); + +private: + EncryptionRandomGenerator *encryptionRandomGenerator; + std::map> + secretKeys; + std::map> + bootstrapKeys; + std::map> + keyswitchKeys; + std::vector> + inputs; + std::vector> + outputs; +}; + +} // namespace zamalang +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/include/zamalang/Support/V0Parameters.h b/compiler/include/zamalang/Support/V0Parameters.h new file mode 100644 index 000000000..45b887d18 --- /dev/null +++ b/compiler/include/zamalang/Support/V0Parameters.h @@ -0,0 +1,29 @@ +#ifndef ZAMALANG_SUPPORT_V0Parameter_H_ +#define ZAMALANG_SUPPORT_V0Parameter_H_ + +#include + +namespace mlir { +namespace zamalang { + +typedef struct V0Parameter { + size_t k; + size_t polynomialSize; + size_t nSmall; + size_t brLevel; + size_t brLogBase; + size_t ksLevel; + size_t ksLogBase; + + V0Parameter(size_t k, size_t polynomialSize, size_t nSmall, size_t brLevel, + size_t brLogBase, size_t ksLevel, size_t ksLogBase) + : k(k), polynomialSize(polynomialSize), nSmall(nSmall), brLevel(brLevel), + brLogBase(brLogBase), ksLevel(ksLevel), ksLogBase(ksLogBase) {} + +} V0Parameter; + +V0Parameter *getV0Parameter(size_t norm, size_t p); + +} // namespace zamalang +} // namespace mlir +#endif \ No newline at end of file diff --git a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp index 6e56ba541..9ec6f2909 100644 --- a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp +++ b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp @@ -3,9 +3,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "zamalang/Conversion/Passes.h" -#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" - #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" @@ -19,10 +16,16 @@ #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/Sequence.h" +#include "zamalang/Conversion/Passes.h" +#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" + namespace { struct MLIRLowerableDialectsToLLVMPass : public MLIRLowerableDialectsToLLVMBase { void runOnOperation() final; + + /// Convert types to the LLVM dialect-compatible type + static llvm::Optional convertTypes(mlir::Type type); }; } // namespace @@ -35,6 +38,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() { // Setup the LLVMTypeConverter (that converts `std` types to `llvm` types) and // add our types conversion to `llvm` compatible type. mlir::LLVMTypeConverter typeConverter(&getContext()); + typeConverter.addConversion(convertTypes); // Setup the set of the patterns rewriter. At this point we want to // convert the `scf` operations to `std` and `std` operations to `llvm`. @@ -49,6 +53,15 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() { } } +llvm::Optional +MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) { + if (type.isa()) { + return mlir::LLVM::LLVMPointerType::get( + mlir::IntegerType::get(type.getContext(), 8)); + } + return llvm::None; +} + namespace mlir { namespace zamalang { /// Create a pass for lowering operations the remaining mlir dialects diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 4c89f456c..0f677b3dc 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -1,5 +1,8 @@ add_mlir_library(ZamalangSupport CompilerTools.cpp + V0Parameters.cpp + ClientParameters.cpp + KeySet.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/zamalang/Support @@ -13,4 +16,7 @@ add_mlir_library(ZamalangSupport MLIRLowerableDialectsToLLVM MLIRExecutionEngine - ${LLVM_PTHREAD_LIB}) \ No newline at end of file + ${LLVM_PTHREAD_LIB} + + Concrete +) diff --git a/compiler/lib/Support/ClientParameters.cpp b/compiler/lib/Support/ClientParameters.cpp new file mode 100644 index 000000000..423646550 --- /dev/null +++ b/compiler/lib/Support/ClientParameters.cpp @@ -0,0 +1,117 @@ +#include +#include + +#include + +#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" +#include "zamalang/Support/ClientParameters.h" + +namespace mlir { +namespace zamalang { + +// For the v0 the secretKeyID and precision are the same for all gates. +llvm::Expected gateFromMLIRType(std::string secretKeyID, + Precision precision, + mlir::Type type) { + if (type.isInteger(64)) { + return CircuitGate{ + .encryption = llvm::None, + .shape = {.size = 0}, + }; + } + if (type.isa()) { + return CircuitGate{ + .encryption = llvm::Optional({ + .secretKeyID = secretKeyID, + // TODO - Compute variance, wait for security estimator + .variance = 0., + .encoding = {.precision = precision}, + }), + .shape = {.size = 0}, + }; + } + auto memref = type.dyn_cast_or_null(); + if (memref != nullptr) { + auto gate = + gateFromMLIRType(secretKeyID, precision, memref.getElementType()); + if (auto err = gate.takeError()) { + return std::move(err); + } + gate->shape.size = memref.getDimSize(0); + return gate; + } + return llvm::make_error( + "cannot convert MLIR type to shape", llvm::inconvertibleErrorCode()); +} + +llvm::Expected +createClientParametersForV0(V0Parameter *v0Param, Precision precision, + llvm::StringRef name, mlir::ModuleOp module) { + // Static client parameters from global parameters for v0 + ClientParameters c{ + .secretKeys{ + {"small", {.size = v0Param->nSmall}}, + {"big", {.size = v0Param->k * (1 << v0Param->polynomialSize)}}, + }, + .bootstrapKeys{ + { + "bsk_v0", + { + .inputSecretKeyID = "small", + .outputSecretKeyID = "big", + .level = v0Param->brLevel, + .baseLog = v0Param->brLogBase, + .k = v0Param->k, + // TODO - Compute variance, wait for security estimator + .variance = 0, + }, + }, + }, + .keyswitchKeys{ + { + "ksk_v0", + { + .inputSecretKeyID = "big", + .outputSecretKeyID = "small", + .level = v0Param->ksLevel, + .baseLog = v0Param->ksLogBase, + // TODO - Compute variance, wait for security estimator + .variance = 0, + }, + }, + }, + }; + // Find the input function + auto rangeOps = module.getOps(); + auto funcOp = llvm::find_if( + rangeOps, [&](mlir::FuncOp op) { return op.getName() == name; }); + if (funcOp == rangeOps.end()) { + return llvm::make_error( + "cannot find the function for generate client parameters", + llvm::inconvertibleErrorCode()); + } + + // For the v0 the precision is global + Encoding encoding = {.precision = precision}; + + // Create input and output circuit gate parameters + auto funcType = (*funcOp).getType(); + for (auto inType : funcType.getInputs()) { + auto gate = gateFromMLIRType("big", precision, inType); + if (auto err = gate.takeError()) { + return std::move(err); + } + c.inputs.push_back(gate.get()); + } + for (auto outType : funcType.getResults()) { + auto gate = gateFromMLIRType("big", precision, outType); + if (auto err = gate.takeError()) { + return std::move(err); + } + c.outputs.push_back(gate.get()); + } + return c; +} + +} // namespace zamalang +} // namespace mlir \ No newline at end of file diff --git a/compiler/lib/Support/CompilerTools.cpp b/compiler/lib/Support/CompilerTools.cpp index 598d9ee65..c03c7682a 100644 --- a/compiler/lib/Support/CompilerTools.cpp +++ b/compiler/lib/Support/CompilerTools.cpp @@ -8,6 +8,9 @@ namespace mlir { namespace zamalang { +// This is temporary while we doesn't yet have the high-level verification pass +FHECircuitConstraint defaultGlobalFHECircuitConstraint{.norm2 = 20, .p = 6}; + void initLLVMNativeTarget() { // Initialize LLVM targets. llvm::InitializeNativeTarget(); @@ -27,8 +30,9 @@ void addFilteredPassToPassManager( } }; -mlir::LogicalResult CompilerTools::lowerHLFHEToMlirLLVMDialect( +mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect( mlir::MLIRContext &context, mlir::Operation *module, + FHECircuitConstraint &constraint, llvm::function_ref enablePass) { mlir::PassManager pm(&context); @@ -37,11 +41,7 @@ mlir::LogicalResult CompilerTools::lowerHLFHEToMlirLLVMDialect( pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass); addFilteredPassToPassManager( pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass); - addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass); - addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(), - enablePass); + constraint = defaultGlobalFHECircuitConstraint; // Run the passes if (pm.run(module).failed()) { @@ -51,14 +51,31 @@ mlir::LogicalResult CompilerTools::lowerHLFHEToMlirLLVMDialect( return mlir::success(); } +mlir::LogicalResult CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect( + mlir::MLIRContext &context, mlir::Operation *module, + llvm::function_ref enablePass) { + + mlir::PassManager pm(&context); + addFilteredPassToPassManager( + pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass); + addFilteredPassToPassManager( + pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(), + enablePass); + + if (pm.run(module).failed()) { + return mlir::failure(); + } + return mlir::success(); +} + llvm::Expected> CompilerTools::toLLVMModule( - llvm::LLVMContext &context, mlir::ModuleOp &module, + llvm::LLVMContext &llvmContext, mlir::ModuleOp &module, llvm::function_ref optPipeline) { initLLVMNativeTarget(); mlir::registerLLVMDialectTranslation(*module->getContext()); - auto llvmModule = mlir::translateModuleToLLVMIR(module, context); + auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); if (!llvmModule) { return llvm::make_error( "failed to translate MLIR to LLVM IR", llvm::inconvertibleErrorCode()); @@ -113,5 +130,69 @@ llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef args) { llvm::inconvertibleErrorCode()); } +llvm::Error JITLambda::invoke(Argument &args) { return invokeRaw(args.rawArg); } + +JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { + inputs = std::vector(keySet.numInputs()); + results = std::vector(keySet.numOutputs()); + // The raw argument contains pointers to inputs and pointers to store the + // results + rawArg = + std::vector(keySet.numInputs() + keySet.numOutputs(), nullptr); + // Set the results pointer on the rawArg + for (auto i = keySet.numInputs(); i < rawArg.size(); i++) { + rawArg[i] = &results[i - keySet.numInputs()]; + } +} + +JITLambda::Argument::~Argument() { + int err; + for (auto i = 0; i < keySet.numInputs(); i++) { + if (keySet.isInputEncrypted(i)) { + free_lwe_ciphertext_u64(&err, (LweCiphertext_u64 *)(inputs[i])); + } + } +} + +llvm::Expected> +JITLambda::Argument::create(KeySet &keySet) { + auto args = std::make_unique(keySet); + return std::move(args); +} + +llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) { + // If argument is not encrypted, just save. + if (!keySet.isInputEncrypted(pos)) { + inputs[pos] = (void *)arg; + rawArg[pos] = &inputs[pos]; + return llvm::Error::success(); + } + // Else if is encryted, allocate ciphertext. + LweCiphertext_u64 *ctArg; + if (auto err = this->keySet.allocate_lwe(pos, &ctArg)) { + return std::move(err); + } + if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) { + return std::move(err); + } + inputs[pos] = ctArg; + rawArg[pos] = &inputs[pos]; + return llvm::Error::success(); +} + +llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) { + // If result is not encrypted, just set the result + if (!keySet.isOutputEncrypted(pos)) { + res = (uint64_t)(results[pos]); + return llvm::Error::success(); + } + // Else if is encryted, decrypt + LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(results[pos]); + if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) { + return std::move(err); + } + return llvm::Error::success(); +} + } // namespace zamalang } // namespace mlir \ No newline at end of file diff --git a/compiler/lib/Support/KeySet.cpp b/compiler/lib/Support/KeySet.cpp new file mode 100644 index 000000000..739083c8d --- /dev/null +++ b/compiler/lib/Support/KeySet.cpp @@ -0,0 +1,276 @@ +#include "zamalang/Support/KeySet.h" + +#define CAPI_ERR_TO_LLVM_ERROR(s, msg) \ + { \ + int err; \ + s; \ + if (err != 0) { \ + return llvm::make_error( \ + msg, llvm::inconvertibleErrorCode()); \ + } \ + } + +namespace mlir { +namespace zamalang { + +KeySet::~KeySet() { + int err; + for (auto it : secretKeys) { + free_lwe_secret_key_u64(&err, it.second.second); + } + for (auto it : bootstrapKeys) { + free_lwe_bootstrap_key_u64(&err, it.second.second); + } + for (auto it : keyswitchKeys) { + free_lwe_keyswitch_key_u64(&err, it.second.second); + } + free_encryption_generator(&err, encryptionRandomGenerator); +} + +llvm::Expected> +KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb, + uint64_t seed_lsb) { + auto keySet = std::make_unique(); + + { + // Generate LWE secret keys + SecretRandomGenerator *generator; + CAPI_ERR_TO_LLVM_ERROR( + generator = allocate_secret_generator(&err, seed_msb, seed_lsb), + "cannot allocate random generator"); + for (auto secretKeyParam : params.secretKeys) { + auto e = keySet->generateSecretKey(secretKeyParam.first, + secretKeyParam.second, generator); + if (e) { + return e; + } + } + CAPI_ERR_TO_LLVM_ERROR(free_secret_generator(&err, generator), + "cannot free random generator"); + } + // Allocate the encryption random generator + CAPI_ERR_TO_LLVM_ERROR( + keySet->encryptionRandomGenerator = + allocate_encryption_generator(&err, seed_msb, seed_lsb), + "cannot allocate encryption generator"); + // Generate bootstrap and keyswitch keys + { + for (auto bootstrapKeyParam : params.bootstrapKeys) { + auto e = keySet->generateBootstrapKey(bootstrapKeyParam.first, + bootstrapKeyParam.second, + keySet->encryptionRandomGenerator); + if (e) { + return e; + } + } + for (auto keyswitchParam : params.keyswitchKeys) { + auto e = keySet->generateKeyswitchKey(keyswitchParam.first, + keyswitchParam.second, + keySet->encryptionRandomGenerator); + if (e) { + return e; + } + } + } + // Set inputs and outputs LWE secret keys + { + for (auto param : params.inputs) { + std::tuple input = { + param, nullptr, nullptr}; + if (param.encryption.hasValue()) { + auto inputSk = keySet->secretKeys.find(param.encryption->secretKeyID); + if (inputSk == keySet->secretKeys.end()) { + return llvm::make_error( + "cannot find input key to generate bootstrap key", + llvm::inconvertibleErrorCode()); + } + std::get<1>(input) = &inputSk->second.first; + std::get<2>(input) = inputSk->second.second; + } + keySet->inputs.push_back(input); + } + for (auto param : params.outputs) { + std::tuple output = + {param, nullptr, nullptr}; + if (param.encryption.hasValue()) { + auto outputSk = keySet->secretKeys.find(param.encryption->secretKeyID); + if (outputSk == keySet->secretKeys.end()) { + return llvm::make_error( + "cannot find output key to generate bootstrap key", + llvm::inconvertibleErrorCode()); + } + std::get<1>(output) = &outputSk->second.first; + std::get<2>(output) = outputSk->second.second; + } + keySet->outputs.push_back(output); + } + } + return std::move(keySet); +} + +llvm::Error KeySet::generateSecretKey(LweSecretKeyID id, + LweSecretKeyParam param, + SecretRandomGenerator *generator) { + LweSecretKey_u64 *sk; + CAPI_ERR_TO_LLVM_ERROR( + sk = allocate_lwe_secret_key_u64(&err, {_0 : param.size}), + "cannot allocate secret key"); + CAPI_ERR_TO_LLVM_ERROR(fill_lwe_secret_key_u64(&err, sk, generator), + "cannot fill secret key with random generator") + secretKeys[id] = {param, sk}; + return llvm::Error::success(); +} + +llvm::Error KeySet::generateBootstrapKey(BootstrapKeyID id, + BootstrapKeyParam param, + EncryptionRandomGenerator *generator) { + // Finding input and output secretKeys + auto inputSk = secretKeys.find(param.inputSecretKeyID); + if (inputSk == secretKeys.end()) { + return llvm::make_error( + "cannot find input key to generate bootstrap key", + llvm::inconvertibleErrorCode()); + } + auto outputSk = secretKeys.find(param.outputSecretKeyID); + if (outputSk == secretKeys.end()) { + return llvm::make_error( + "cannot find input key to generate bootstrap key", + llvm::inconvertibleErrorCode()); + } + // Allocate the bootstrap key + LweBootstrapKey_u64 *bsk; + CAPI_ERR_TO_LLVM_ERROR( + bsk = allocate_lwe_bootstrap_key_u64( + &err, {param.level}, {param.baseLog}, {param.k}, + {inputSk->second.first.size}, + {outputSk->second.first.size /*TODO: size / k ?*/}), + "cannot allocate bootstrap key"); + // Store the bootstrap key + bootstrapKeys[id] = {param, bsk}; + // Convert the output lwe key to glwe key + GlweSecretKey_u64 *glwe_sk; + CAPI_ERR_TO_LLVM_ERROR( + glwe_sk = allocate_glwe_secret_key_u64(&err, {param.k}, + {outputSk->second.first.size}), + "cannot allocate glwe key for initiliazation of bootstrap key"); + // Initialize the bootstrap key + CAPI_ERR_TO_LLVM_ERROR( + fill_lwe_bootstrap_key_u64(&err, bsk, inputSk->second.second, glwe_sk, + generator, {param.variance}), + "cannot fill bootstrap key"); + CAPI_ERR_TO_LLVM_ERROR( + free_glwe_secret_key_u64(&err, glwe_sk), + "cannot free glwe key for initiliazation of bootstrap key") + return llvm::Error::success(); +} + +llvm::Error KeySet::generateKeyswitchKey(KeyswitchKeyID id, + KeyswitchKeyParam param, + EncryptionRandomGenerator *generator) { + // Finding input and output secretKeys + auto inputSk = secretKeys.find(param.inputSecretKeyID); + if (inputSk == secretKeys.end()) { + return llvm::make_error( + "cannot find input key to generate keyswitch key", + llvm::inconvertibleErrorCode()); + } + auto outputSk = secretKeys.find(param.outputSecretKeyID); + if (outputSk == secretKeys.end()) { + return llvm::make_error( + "cannot find input key to generate keyswitch key", + llvm::inconvertibleErrorCode()); + } + // Allocate the keyswitch key + LweKeyswitchKey_u64 *ksk; + CAPI_ERR_TO_LLVM_ERROR( + ksk = allocate_lwe_keyswitch_key_u64(&err, {param.level}, {param.baseLog}, + {inputSk->second.first.size}, + {outputSk->second.first.size}), + "cannot allocate keyswitch key"); + // Store the keyswitch key + keyswitchKeys[id] = {param, ksk}; + // Initialize the keyswitch key + CAPI_ERR_TO_LLVM_ERROR( + fill_lwe_keyswitch_key_u64(&err, ksk, inputSk->second.second, + outputSk->second.second, generator, + {param.variance}), + "cannot fill bootsrap key"); + return llvm::Error::success(); +} + +llvm::Error KeySet::allocate_lwe(size_t argPos, + LweCiphertext_u64 **ciphertext) { + if (argPos >= inputs.size()) { + return llvm::make_error( + "allocate_lwe position of argument is too high", + llvm::inconvertibleErrorCode()); + } + auto inputSk = inputs[argPos]; + CAPI_ERR_TO_LLVM_ERROR(*ciphertext = allocate_lwe_ciphertext_u64( + &err, {std::get<1>(inputSk)->size}), + "cannot allocate ciphertext"); + return llvm::Error::success(); +} + +bool KeySet::isInputEncrypted(size_t argPos) { + return argPos < inputs.size() && + std::get<0>(inputs[argPos]).encryption.hasValue(); +} + +bool KeySet::isOutputEncrypted(size_t argPos) { + return argPos < outputs.size() && + std::get<0>(outputs[argPos]).encryption.hasValue(); +} + +llvm::Error KeySet::encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext, + uint64_t input) { + if (argPos >= inputs.size()) { + return llvm::make_error( + "encrypt_lwe position of argument is too high", + llvm::inconvertibleErrorCode()); + } + auto inputSk = inputs[argPos]; + if (!std::get<0>(inputSk).encryption.hasValue()) { + return llvm::make_error( + "encrypt_lwe the positional argument is not encrypted", + llvm::inconvertibleErrorCode()); + } + // Encode - TODO we could check if the input value is in the right range + Plaintext_u64 plaintext = { + input << (64 - + (std::get<0>(inputSk).encryption->encoding.precision + 1))}; + // Encrypt + CAPI_ERR_TO_LLVM_ERROR( + encrypt_lwe_u64(&err, std::get<2>(inputSk), ciphertext, plaintext, + encryptionRandomGenerator, + {std::get<0>(inputSk).encryption->variance}), + "cannot encrypt"); + return llvm::Error::success(); +} + +llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext, + uint64_t &output) { + if (argPos >= outputs.size()) { + return llvm::make_error( + "decrypt_lwe: position of argument is too high", + llvm::inconvertibleErrorCode()); + } + auto outputSk = outputs[argPos]; + if (!std::get<0>(outputSk).encryption.hasValue()) { + return llvm::make_error( + "decrypt_lwe: the positional argument is not encrypted", + llvm::inconvertibleErrorCode()); + } + // Decrypt + Plaintext_u64 plaintext; + CAPI_ERR_TO_LLVM_ERROR( + decrypt_lwe_u64(&err, std::get<2>(outputSk), ciphertext, &plaintext), + "cannot decrypt"); + // Decode + output = plaintext._0 >> + (64 - (std::get<0>(outputSk).encryption->encoding.precision + 1)); + return llvm::Error::success(); +} + +} // namespace zamalang +} // namespace mlir \ No newline at end of file diff --git a/compiler/lib/Support/V0Parameters.cpp b/compiler/lib/Support/V0Parameters.cpp new file mode 100644 index 000000000..d0e1ce089 --- /dev/null +++ b/compiler/lib/Support/V0Parameters.cpp @@ -0,0 +1,134 @@ +/// DO NOT MANUALLY EDIT THIS FILE. +/// This file was generated thanks the "parameters optimizer". +/// We should include this in our build system, but for moment it is just a cc +/// from the optimizer output. + +#include "zamalang/Support/V0Parameters.h" + +namespace mlir { +namespace zamalang { +using namespace std; + +const int NORM2_MAX = 25; +const int P_MAX = 7; + +V0Parameter parameters[NORM2_MAX][P_MAX] = { + {V0Parameter(1, 10, 481, 2, 8, 4, 2), V0Parameter(1, 10, 483, 2, 8, 4, 2), + V0Parameter(1, 10, 491, 2, 8, 4, 2), V0Parameter(1, 10, 484, 3, 6, 4, 2), + V0Parameter(1, 10, 497, 3, 6, 4, 2), V0Parameter(1, 11, 506, 1, 24, 4, 2), + V0Parameter(1, 11, 506, 1, 24, 4, 2)}, + {V0Parameter(1, 11, 506, 1, 23, 4, 2), V0Parameter(1, 11, 508, 1, 23, 4, 2), + V0Parameter(1, 11, 533, 1, 23, 3, 3), V0Parameter(1, 11, 506, 2, 13, 4, 2), + V0Parameter(1, 11, 506, 2, 15, 4, 2), V0Parameter(1, 11, 506, 2, 17, 4, 2), + V0Parameter(1, 11, 506, 2, 15, 4, 2)}, + {V0Parameter(1, 11, 506, 2, 15, 4, 2), V0Parameter(1, 11, 508, 2, 15, 4, 2), + V0Parameter(1, 11, 513, 2, 15, 4, 2), V0Parameter(1, 11, 530, 2, 15, 5, 2), + V0Parameter(1, 11, 507, 3, 12, 4, 2), V0Parameter(1, 11, 511, 3, 11, 4, 2), + V0Parameter(1, 11, 517, 3, 12, 5, 2)}, + {V0Parameter(1, 11, 508, 4, 9, 4, 2), V0Parameter(1, 11, 509, 4, 9, 5, 2), + V0Parameter(1, 11, 507, 5, 8, 5, 2), V0Parameter(1, 11, 541, 5, 8, 5, 2), + V0Parameter(1, 10, 549, 2, 8, 3, 3), V0Parameter(1, 10, 530, 2, 8, 5, 2), + V0Parameter(1, 10, 524, 3, 6, 5, 2)}, + {V0Parameter(1, 10, 536, 3, 6, 5, 2), V0Parameter(1, 11, 571, 1, 22, 3, 3), + V0Parameter(1, 11, 572, 1, 22, 3, 3), V0Parameter(1, 11, 572, 1, 22, 3, 3), + V0Parameter(1, 11, 574, 1, 23, 3, 3), V0Parameter(1, 11, 585, 1, 23, 3, 3), + V0Parameter(1, 11, 541, 2, 14, 5, 2)}, + {V0Parameter(1, 11, 541, 2, 14, 5, 2), V0Parameter(1, 11, 541, 2, 17, 5, 2), + V0Parameter(1, 11, 541, 2, 15, 5, 2), V0Parameter(1, 11, 541, 2, 15, 5, 2), + V0Parameter(1, 11, 542, 2, 15, 5, 2), V0Parameter(1, 11, 547, 2, 15, 5, 2), + V0Parameter(1, 11, 580, 2, 15, 5, 2)}, + {V0Parameter(1, 11, 542, 3, 11, 5, 2), V0Parameter(1, 11, 545, 3, 12, 5, 2), + V0Parameter(1, 11, 561, 3, 12, 5, 2), V0Parameter(1, 11, 543, 4, 9, 5, 2), + V0Parameter(1, 11, 549, 4, 9, 5, 2), V0Parameter(1, 11, 547, 5, 8, 5, 2), + V0Parameter(1, 11, 591, 5, 8, 6, 2)}, + {V0Parameter(1, 11, 550, 7, 6, 5, 2), V0Parameter(1, 10, 576, 2, 8, 5, 2), + V0Parameter(1, 10, 567, 3, 6, 5, 2), V0Parameter(1, 10, 584, 3, 6, 5, 2), + V0Parameter(1, 11, 607, 1, 20, 4, 3), V0Parameter(1, 11, 607, 1, 22, 4, 3), + V0Parameter(1, 11, 607, 1, 22, 4, 3)}, + {V0Parameter(1, 11, 609, 1, 23, 4, 3), V0Parameter(1, 11, 616, 1, 23, 4, 3), + V0Parameter(1, 11, 588, 2, 13, 5, 2), V0Parameter(1, 11, 588, 2, 17, 5, 2), + V0Parameter(1, 11, 588, 2, 14, 5, 2), V0Parameter(1, 11, 588, 2, 16, 5, 2), + V0Parameter(1, 11, 588, 2, 15, 5, 2)}, + {V0Parameter(1, 11, 590, 2, 15, 5, 2), V0Parameter(1, 11, 596, 2, 15, 5, 2), + V0Parameter(1, 11, 622, 2, 15, 6, 2), V0Parameter(1, 11, 589, 3, 11, 5, 2), + V0Parameter(1, 11, 594, 3, 12, 5, 2), V0Parameter(1, 11, 602, 3, 12, 6, 2), + V0Parameter(1, 11, 591, 4, 9, 5, 2)}, + {V0Parameter(1, 11, 591, 4, 9, 6, 2), V0Parameter(1, 11, 589, 5, 8, 6, 2), + V0Parameter(1, 11, 655, 5, 8, 6, 2), V0Parameter(1, 11, 592, 7, 6, 6, 2), + V0Parameter(1, 11, 598, 8, 5, 6, 2), V0Parameter(1, 10, 608, 3, 6, 6, 2), + V0Parameter(1, 10, 625, 3, 6, 6, 2)}, + {V0Parameter(1, 11, 647, 1, 24, 4, 3), V0Parameter(1, 11, 647, 1, 22, 4, 3), + V0Parameter(1, 11, 647, 1, 23, 4, 3), V0Parameter(1, 11, 649, 1, 23, 4, 3), + V0Parameter(1, 11, 658, 1, 23, 4, 3), V0Parameter(1, 11, 647, 2, 19, 4, 3), + V0Parameter(1, 11, 647, 2, 14, 4, 3)}, + {V0Parameter(1, 11, 647, 2, 17, 4, 3), V0Parameter(1, 11, 647, 2, 15, 4, 3), + V0Parameter(1, 11, 647, 2, 15, 4, 3), V0Parameter(1, 11, 648, 2, 15, 4, 3), + V0Parameter(1, 11, 654, 2, 15, 4, 3), V0Parameter(1, 11, 674, 2, 15, 7, 2), + V0Parameter(1, 11, 624, 3, 12, 6, 2)}, + {V0Parameter(1, 11, 627, 3, 11, 6, 2), V0Parameter(1, 11, 648, 3, 12, 6, 2), + V0Parameter(1, 11, 625, 4, 9, 6, 2), V0Parameter(1, 11, 633, 4, 9, 6, 2), + V0Parameter(1, 11, 630, 5, 8, 6, 2), V0Parameter(1, 11, 632, 6, 7, 6, 2), + V0Parameter(1, 11, 634, 7, 6, 6, 2)}, + {V0Parameter(1, 11, 642, 8, 5, 6, 2), V0Parameter(1, 11, 644, 11, 4, 6, 2), + V0Parameter(1, 11, 698, 1, 20, 4, 3), V0Parameter(1, 11, 698, 1, 21, 4, 3), + V0Parameter(1, 11, 698, 1, 22, 4, 3), V0Parameter(1, 11, 699, 1, 22, 4, 3), + V0Parameter(1, 11, 702, 1, 23, 4, 3)}, + {V0Parameter(1, 11, 721, 1, 23, 4, 3), V0Parameter(1, 11, 698, 2, 13, 4, 3), + V0Parameter(1, 11, 698, 2, 17, 4, 3), V0Parameter(1, 11, 698, 2, 14, 4, 3), + V0Parameter(1, 11, 698, 2, 16, 4, 3), V0Parameter(1, 11, 698, 2, 15, 4, 3), + V0Parameter(1, 11, 700, 2, 15, 4, 3)}, + {V0Parameter(1, 11, 711, 2, 15, 4, 3), V0Parameter(1, 11, 674, 3, 12, 6, 2), + V0Parameter(1, 11, 675, 3, 11, 6, 2), V0Parameter(1, 11, 681, 3, 12, 6, 2), + V0Parameter(1, 11, 695, 3, 12, 7, 2), V0Parameter(1, 11, 677, 4, 9, 6, 2), + V0Parameter(1, 11, 677, 4, 9, 7, 2)}, + {V0Parameter(1, 11, 674, 5, 8, 7, 2), V0Parameter(1, 11, 676, 6, 7, 7, 2), + V0Parameter(1, 11, 678, 7, 6, 7, 2), V0Parameter(1, 11, 688, 8, 5, 7, 2), + V0Parameter(1, 11, 690, 11, 4, 7, 2), + V0Parameter(1, 11, 731, 14, 3, 15, 1), + V0Parameter(1, 11, 745, 1, 21, 5, 3)}, + {V0Parameter(1, 11, 745, 1, 22, 5, 3), V0Parameter(1, 11, 746, 1, 23, 5, 3), + V0Parameter(1, 11, 750, 1, 23, 5, 3), V0Parameter(1, 11, 781, 1, 23, 5, 3), + V0Parameter(1, 11, 745, 2, 19, 5, 3), V0Parameter(1, 11, 745, 2, 14, 5, 3), + V0Parameter(1, 11, 745, 2, 17, 5, 3)}, + {V0Parameter(1, 11, 745, 2, 15, 5, 3), V0Parameter(1, 11, 746, 2, 15, 5, 3), + V0Parameter(1, 11, 748, 2, 15, 5, 3), V0Parameter(1, 11, 763, 2, 15, 5, 3), + V0Parameter(1, 11, 745, 3, 12, 5, 3), V0Parameter(1, 11, 747, 3, 11, 5, 3), + V0Parameter(1, 11, 756, 3, 11, 5, 3)}, + {V0Parameter(1, 11, 722, 4, 9, 7, 2), V0Parameter(1, 11, 727, 4, 9, 7, 2), + V0Parameter(1, 11, 750, 4, 9, 8, 2), V0Parameter(1, 11, 747, 5, 8, 7, 2), + V0Parameter(1, 11, 747, 6, 7, 8, 2), V0Parameter(1, 11, 757, 7, 6, 8, 2), + V0Parameter(1, 11, 739, 10, 4, 8, 2)}, + {V0Parameter(1, 11, 735, 14, 3, 8, 2), V0Parameter(1, 11, 767, 21, 2, 8, 2), + V0Parameter(0, 0, 0, 0, 0, 0, 0), V0Parameter(1, 12, 839, 1, 22, 4, 4), + V0Parameter(1, 12, 844, 1, 23, 4, 4), V0Parameter(1, 12, 847, 1, 23, 6, 3), + V0Parameter(1, 12, 815, 2, 13, 5, 3)}, + {V0Parameter(1, 12, 815, 2, 17, 5, 3), V0Parameter(1, 12, 815, 2, 14, 5, 3), + V0Parameter(1, 12, 815, 2, 15, 5, 3), V0Parameter(1, 12, 816, 2, 15, 5, 3), + V0Parameter(1, 12, 820, 2, 15, 5, 3), V0Parameter(1, 12, 858, 2, 15, 4, 4), + V0Parameter(1, 12, 815, 3, 11, 5, 3)}, + {V0Parameter(1, 12, 818, 3, 11, 5, 3), V0Parameter(1, 12, 790, 3, 11, 8, 2), + V0Parameter(1, 12, 783, 4, 9, 8, 2), V0Parameter(1, 12, 787, 4, 9, 8, 2), + V0Parameter(1, 12, 823, 4, 9, 8, 2), V0Parameter(1, 12, 822, 5, 8, 8, 2), + V0Parameter(1, 12, 843, 6, 7, 9, 2)}, + {V0Parameter(1, 12, 792, 8, 5, 8, 2), V0Parameter(1, 12, 802, 10, 4, 8, 2), + V0Parameter(1, 12, 804, 14, 3, 8, 2), V0Parameter(1, 12, 825, 22, 2, 9, 2), + V0Parameter(0, 0, 0, 0, 0, 0, 0), V0Parameter(0, 0, 0, 0, 0, 0, 0), + V0Parameter(0, 0, 0, 0, 0, 0, 0)}}; + +V0Parameter *getV0Parameter(size_t norm, size_t p) { + if (norm > NORM2_MAX) { + return nullptr; + } + if (p >= P_MAX) { + return nullptr; + } + // - 1 is an offset as norm and p are in [1, ...] and not [0, ...] + auto param = ¶meters[norm - 1][p - 1]; + if (param->k == 0) { + return nullptr; + } + return param; +} + +} // namespace zamalang +} // namespace mlir \ No newline at end of file diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index b732a71d3..9bfc47455 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -32,6 +32,9 @@ llvm::cl::opt output("o", llvm::cl::value_desc("filename"), llvm::cl::init("-")); +llvm::cl::opt verbose("verbose", llvm::cl::desc("verbose logs"), + llvm::cl::init(false)); + llvm::cl::list passes( "passes", llvm::cl::desc("Specify the passes to run (use only for compiler tests)"), @@ -53,8 +56,19 @@ llvm::cl::opt splitInputFile( "chunk independently"), llvm::cl::init(false)); +llvm::cl::opt generateKeySet( + "generate-keyset", + llvm::cl::desc("[tmp] Generate a key set for the compiled fhe circuit"), + llvm::cl::init(false)); + llvm::cl::opt runJit("run-jit", llvm::cl::desc("JIT the code and run it"), llvm::cl::init(false)); + +llvm::cl::opt jitFuncname( + "jit-funcname", + llvm::cl::desc("Name of the function to execute, default 'main'"), + llvm::cl::init("main")); + llvm::cl::list jitArgs("jit-args", llvm::cl::desc("Value of arguments to pass to the main func"), @@ -64,6 +78,12 @@ llvm::cl::opt toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "), llvm::cl::init(false)); }; // namespace cmdline +#define LOG_VERBOSE(expr) \ + if (cmdline::verbose) \ + llvm::errs() << expr; + +#define LOG_ERROR(expr) llvm::errs() << expr; + auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr); mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) { @@ -77,32 +97,42 @@ mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) { return mlir::success(); } -mlir::LogicalResult runJit(mlir::ModuleOp module, llvm::raw_ostream &os) { +mlir::LogicalResult runJit(mlir::ModuleOp module, + mlir::zamalang::KeySet &keySet, + llvm::raw_ostream &os) { // Create the JIT lambda - auto maybeLambda = - mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline); + auto maybeLambda = mlir::zamalang::JITLambda::create( + cmdline::jitFuncname, module, defaultOptPipeline); if (!maybeLambda) { return mlir::failure(); } - auto lambda = maybeLambda.get().get(); + auto lambda = std::move(maybeLambda.get()); - // Create buffer to copy argument - std::vector dummy(cmdline::jitArgs.size()); - llvm::SmallVector llvmArgs; - for (auto i = 0; i < cmdline::jitArgs.size(); i++) { - dummy[i] = cmdline::jitArgs[i]; - llvmArgs.push_back(&dummy[i]); - } - // Add the result pointer - uint64_t res = 0; - llvmArgs.push_back(&res); + // Create the arguments of the JIT lambda + auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet); + if (auto err = maybeArguments.takeError()) { - // Invoke the lambda - if (lambda->invokeRaw(llvmArgs)) { + LOG_ERROR("Cannot create lambda arguments: " << err << "\n"); return mlir::failure(); } - - std::cerr << res << "\n"; + // Set the arguments + auto arguments = std::move(maybeArguments.get()); + for (auto i = 0; i < cmdline::jitArgs.size(); i++) { + if (auto err = arguments->setArg(i, cmdline::jitArgs[i])) { + LOG_ERROR("Cannot push argument " << i << ": " << err << "\n"); + return mlir::failure(); + } + } + // Invoke the lambda + if (lambda->invoke(*arguments)) { + return mlir::failure(); + } + uint64_t res = 0; + if (auto err = arguments->getResult(0, res)) { + LOG_ERROR("Cannot get result : " << err << "\n"); + return mlir::failure(); + } + llvm::errs() << res << "\n"; return mlir::success(); } @@ -137,20 +167,67 @@ processInputBuffer(mlir::MLIRContext &context, return mlir::success(); } - if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirLLVMDialect( - context, *module, - [](std::string passName) { - return cmdline::passes.size() == 0 || - std::any_of( - cmdline::passes.begin(), cmdline::passes.end(), + auto enablePass = [](std::string passName) { + return cmdline::passes.size() == 0 || + std::any_of(cmdline::passes.begin(), cmdline::passes.end(), [&](const std::string &p) { return passName == p; }); - }) + }; + + // Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit. + mlir::zamalang::FHECircuitConstraint constraint; + LOG_VERBOSE("### Lower from HLFHE to MLIR standards \n"); + if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect( + context, *module, constraint, enablePass) + .failed()) { + return mlir::failure(); + } + LOG_VERBOSE("### Global FHE constraint: {norm2:" << constraint.norm2 << ", p:" + << constraint.p << "}\n"); + + // Retreive the parameters for the v0 approach + mlir::zamalang::V0Parameter *fheParameter = + mlir::zamalang::getV0Parameter(constraint.norm2, constraint.p); + LOG_VERBOSE("### FHE parameters for the atomic pattern: {k: " + << fheParameter->k + << ", polynomialSize: " << fheParameter->polynomialSize + << ", nSmall: " << fheParameter->nSmall + << ", brLevel: " << fheParameter->brLevel + << ", brLogBase: " << fheParameter->brLogBase + << ", ksLevel: " << fheParameter->ksLevel + << ", polynomialSize: " << fheParameter->ksLogBase << "}\n"); + + // Generate the keySet + std::unique_ptr keySet; + if (cmdline::generateKeySet || cmdline::runJit) { + // Create the client parameters + auto clientParameter = mlir::zamalang::createClientParametersForV0( + fheParameter, constraint.p, cmdline::jitFuncname, *module); + if (auto err = clientParameter.takeError()) { + LOG_ERROR("cannot generate client parameters: " << err << "\n"); + return mlir::failure(); + } + LOG_VERBOSE("### Generate the key set\n"); + auto maybeKeySet = + mlir::zamalang::KeySet::generate(clientParameter.get(), 0, + 0); // TODO: seed + if (auto err = maybeKeySet.takeError()) { + llvm::errs() << err; + return mlir::failure(); + } + keySet = std::move(maybeKeySet.get()); + } + + // Lower to MLIR LLVM Dialect + LOG_VERBOSE("### Lower from MLIR standards to LLVM\n"); + if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect( + context, *module, enablePass) .failed()) { return mlir::failure(); } if (cmdline::runJit) { - return runJit(module.get(), os); + LOG_VERBOSE("### JIT compile & running\n"); + return runJit(module.get(), *keySet, os); } if (cmdline::toLLVM) { return dumpLLVMIR(module.get(), os); diff --git a/compiler/tests/RunJit/lowlfhe_identity.mlir b/compiler/tests/RunJit/lowlfhe_identity.mlir new file mode 100644 index 000000000..9cd0e8f91 --- /dev/null +++ b/compiler/tests/RunJit/lowlfhe_identity.mlir @@ -0,0 +1,6 @@ +// RUN: zamacompiler %s --run-jit --jit-args 42 2>&1| FileCheck %s + +// CHECK-LABEL: 42 +func @main(%arg0: !LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext { + return %arg0 : !LowLFHE.lwe_ciphertext +} \ No newline at end of file