From 3a254bcb8725507a11538913d3d5e9657ac00043 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Wed, 8 Sep 2021 15:24:52 +0200 Subject: [PATCH] fix(compiler): Integrate the security estimator to compute variances --- compiler/include/zamalang/Support/V0Curves.h | 40 ++++++++++++++++++++ compiler/lib/Support/CMakeLists.txt | 1 + compiler/lib/Support/ClientParameters.cpp | 26 ++++++++----- compiler/lib/Support/V0Curves.cpp | 25 ++++++++++++ 4 files changed, 82 insertions(+), 10 deletions(-) create mode 100644 compiler/include/zamalang/Support/V0Curves.h create mode 100644 compiler/lib/Support/V0Curves.cpp diff --git a/compiler/include/zamalang/Support/V0Curves.h b/compiler/include/zamalang/Support/V0Curves.h new file mode 100644 index 000000000..51d904420 --- /dev/null +++ b/compiler/include/zamalang/Support/V0Curves.h @@ -0,0 +1,40 @@ +#ifndef ZAMALANG_SUPPORT_V0CURVES_H_ +#define ZAMALANG_SUPPORT_V0CURVES_H_ + +#include + +namespace mlir { +namespace zamalang { + +#define SECURITY_LEVEL_80 0 +#define SECURITY_LEVEL_128 1 +#define SECURITY_LEVEL_192 2 +#define SECURITY_LEVEL_256 3 +#define SECURITY_LEVEL_MAX 4 + +#define KEY_FORMAT_BINARY 0 +#define KEY_FORMAT_MAX 1 + +struct V0Curves { + int securityLevel; + double linearTerm1; + double linearTerm2; + int nAlpha; + int keyFormat; + V0Curves(int securityLevel, double linearTerm1, double linearTerm2, + int nAlpha, int keyFormat) + : securityLevel(securityLevel), linearTerm1(linearTerm1), + linearTerm2(linearTerm2), nAlpha(nAlpha), keyFormat(keyFormat) {} + + double getVariance(int k, int polynomialSize, int logQ) { + auto a = std::pow(2, (linearTerm1 * k * polynomialSize + linearTerm2) * 2); + auto b = std::pow(2, -2 * (logQ - 2)); + return a > b ? a : b; + } +}; + +V0Curves *getV0Curves(int securityLevel, int keyFormat); + +} // namespace zamalang +} // namespace mlir +#endif \ No newline at end of file diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 0a611ab20..279efbaf2 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(ZamalangSupport CompilerTools.cpp CompilerEngine.cpp V0Parameters.cpp + V0Curves.cpp ClientParameters.cpp KeySet.cpp diff --git a/compiler/lib/Support/ClientParameters.cpp b/compiler/lib/Support/ClientParameters.cpp index 8eea22d7f..368050a0a 100644 --- a/compiler/lib/Support/ClientParameters.cpp +++ b/compiler/lib/Support/ClientParameters.cpp @@ -5,13 +5,19 @@ #include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" #include "zamalang/Support/ClientParameters.h" +#include "zamalang/Support/V0Curves.h" namespace mlir { namespace zamalang { +const auto securityLevel = SECURITY_LEVEL_128; +const auto keyFormat = KEY_FORMAT_BINARY; +const auto v0Curve = getV0Curves(securityLevel, keyFormat); + // For the v0 the secretKeyID and precision are the same for all gates. llvm::Expected gateFromMLIRType(std::string secretKeyID, Precision precision, + Variance variance, mlir::Type type) { if (type.isIntOrIndex()) { // TODO - The index type is dependant of the target architecture, so @@ -36,8 +42,7 @@ llvm::Expected gateFromMLIRType(std::string secretKeyID, return CircuitGate{ .encryption = llvm::Optional({ .secretKeyID = secretKeyID, - // TODO - Compute variance, wait for security estimator - .variance = 0., + .variance = variance, .encoding = {.precision = precision}, }), .shape = {.width = precision, .size = 0}, @@ -45,8 +50,8 @@ llvm::Expected gateFromMLIRType(std::string secretKeyID, } auto tensor = type.dyn_cast_or_null(); if (tensor != nullptr) { - auto gate = - gateFromMLIRType(secretKeyID, precision, tensor.getElementType()); + auto gate = gateFromMLIRType(secretKeyID, precision, variance, + tensor.getElementType()); if (auto err = gate.takeError()) { return std::move(err); } @@ -61,6 +66,9 @@ llvm::Expected createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name, mlir::ModuleOp module) { auto v0Param = fheContext.parameter; + Variance encryptionVariance = + v0Curve->getVariance(1, 1 << v0Param.polynomialSize, 64); + Variance keyswitchVariance = v0Curve->getVariance(1, v0Param.nSmall, 64); // Static client parameters from global parameters for v0 ClientParameters c{ .secretKeys{ @@ -76,8 +84,7 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name, .level = v0Param.brLevel, .baseLog = v0Param.brLogBase, .k = v0Param.k, - // TODO - Compute variance, wait for security estimator - .variance = 0, + .variance = encryptionVariance, }, }, }, @@ -89,8 +96,7 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name, .outputSecretKeyID = "small", .level = v0Param.ksLevel, .baseLog = v0Param.ksLogBase, - // TODO - Compute variance, wait for security estimator - .variance = 0, + .variance = keyswitchVariance, }, }, }, @@ -112,14 +118,14 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name, // Create input and output circuit gate parameters auto funcType = (*funcOp).getType(); for (auto inType : funcType.getInputs()) { - auto gate = gateFromMLIRType("big", precision, inType); + auto gate = gateFromMLIRType("big", precision, encryptionVariance, inType); if (auto err = gate.takeError()) { return std::move(err); } c.inputs.push_back(gate.get()); } for (auto outType : funcType.getResults()) { - auto gate = gateFromMLIRType("big", precision, outType); + auto gate = gateFromMLIRType("big", precision, encryptionVariance, outType); if (auto err = gate.takeError()) { return std::move(err); } diff --git a/compiler/lib/Support/V0Curves.cpp b/compiler/lib/Support/V0Curves.cpp new file mode 100644 index 000000000..37263fecf --- /dev/null +++ b/compiler/lib/Support/V0Curves.cpp @@ -0,0 +1,25 @@ +#include + +#include "zamalang/Support/V0Curves.h" + +namespace mlir { +namespace zamalang { + +V0Curves curves[SECURITY_LEVEL_MAX][KEY_FORMAT_MAX] = { + {V0Curves(SECURITY_LEVEL_80, -0.04047677865612648, 1.1433465085639063, 160, + 1)}, + {V0Curves(SECURITY_LEVEL_128, -0.026374888765705498, 2.012143923330495, 256, + 1)}, + {V0Curves(SECURITY_LEVEL_192, -0.018504919354426233, 2.6634073426215843, + 381, 1)}, + {V0Curves(SECURITY_LEVEL_256, -0.014327640360322604, 2.899270827311091, 781, + 1)}}; + +V0Curves *getV0Curves(int securityLevel, int keyFormat) { + if (securityLevel >= SECURITY_LEVEL_MAX || keyFormat >= KEY_FORMAT_MAX) { + return nullptr; + } + return &curves[securityLevel][keyFormat]; +} +} // namespace zamalang +} // namespace mlir \ No newline at end of file