fix(compiler): Integrate the security estimator to compute variances

This commit is contained in:
Quentin Bourgerie
2021-09-08 15:24:52 +02:00
parent 967fda07a0
commit 3a254bcb87
4 changed files with 82 additions and 10 deletions

View File

@@ -0,0 +1,40 @@
#ifndef ZAMALANG_SUPPORT_V0CURVES_H_
#define ZAMALANG_SUPPORT_V0CURVES_H_
#include <cstddef>
namespace mlir {
namespace zamalang {
#define SECURITY_LEVEL_80 0
#define SECURITY_LEVEL_128 1
#define SECURITY_LEVEL_192 2
#define SECURITY_LEVEL_256 3
#define SECURITY_LEVEL_MAX 4
#define KEY_FORMAT_BINARY 0
#define KEY_FORMAT_MAX 1
struct V0Curves {
int securityLevel;
double linearTerm1;
double linearTerm2;
int nAlpha;
int keyFormat;
V0Curves(int securityLevel, double linearTerm1, double linearTerm2,
int nAlpha, int keyFormat)
: securityLevel(securityLevel), linearTerm1(linearTerm1),
linearTerm2(linearTerm2), nAlpha(nAlpha), keyFormat(keyFormat) {}
double getVariance(int k, int polynomialSize, int logQ) {
auto a = std::pow(2, (linearTerm1 * k * polynomialSize + linearTerm2) * 2);
auto b = std::pow(2, -2 * (logQ - 2));
return a > b ? a : b;
}
};
V0Curves *getV0Curves(int securityLevel, int keyFormat);
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -2,6 +2,7 @@ add_mlir_library(ZamalangSupport
CompilerTools.cpp
CompilerEngine.cpp
V0Parameters.cpp
V0Curves.cpp
ClientParameters.cpp
KeySet.cpp

View File

@@ -5,13 +5,19 @@
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Support/ClientParameters.h"
#include "zamalang/Support/V0Curves.h"
namespace mlir {
namespace zamalang {
const auto securityLevel = SECURITY_LEVEL_128;
const auto keyFormat = KEY_FORMAT_BINARY;
const auto v0Curve = getV0Curves(securityLevel, keyFormat);
// For the v0 the secretKeyID and precision are the same for all gates.
llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
Precision precision,
Variance variance,
mlir::Type type) {
if (type.isIntOrIndex()) {
// TODO - The index type is dependant of the target architecture, so
@@ -36,8 +42,7 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
return CircuitGate{
.encryption = llvm::Optional<EncryptionGate>({
.secretKeyID = secretKeyID,
// TODO - Compute variance, wait for security estimator
.variance = 0.,
.variance = variance,
.encoding = {.precision = precision},
}),
.shape = {.width = precision, .size = 0},
@@ -45,8 +50,8 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
}
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
if (tensor != nullptr) {
auto gate =
gateFromMLIRType(secretKeyID, precision, tensor.getElementType());
auto gate = gateFromMLIRType(secretKeyID, precision, variance,
tensor.getElementType());
if (auto err = gate.takeError()) {
return std::move(err);
}
@@ -61,6 +66,9 @@ llvm::Expected<ClientParameters>
createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
mlir::ModuleOp module) {
auto v0Param = fheContext.parameter;
Variance encryptionVariance =
v0Curve->getVariance(1, 1 << v0Param.polynomialSize, 64);
Variance keyswitchVariance = v0Curve->getVariance(1, v0Param.nSmall, 64);
// Static client parameters from global parameters for v0
ClientParameters c{
.secretKeys{
@@ -76,8 +84,7 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
.level = v0Param.brLevel,
.baseLog = v0Param.brLogBase,
.k = v0Param.k,
// TODO - Compute variance, wait for security estimator
.variance = 0,
.variance = encryptionVariance,
},
},
},
@@ -89,8 +96,7 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
.outputSecretKeyID = "small",
.level = v0Param.ksLevel,
.baseLog = v0Param.ksLogBase,
// TODO - Compute variance, wait for security estimator
.variance = 0,
.variance = keyswitchVariance,
},
},
},
@@ -112,14 +118,14 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
// Create input and output circuit gate parameters
auto funcType = (*funcOp).getType();
for (auto inType : funcType.getInputs()) {
auto gate = gateFromMLIRType("big", precision, inType);
auto gate = gateFromMLIRType("big", precision, encryptionVariance, inType);
if (auto err = gate.takeError()) {
return std::move(err);
}
c.inputs.push_back(gate.get());
}
for (auto outType : funcType.getResults()) {
auto gate = gateFromMLIRType("big", precision, outType);
auto gate = gateFromMLIRType("big", precision, encryptionVariance, outType);
if (auto err = gate.takeError()) {
return std::move(err);
}

View File

@@ -0,0 +1,25 @@
#include <cmath>
#include "zamalang/Support/V0Curves.h"
namespace mlir {
namespace zamalang {
V0Curves curves[SECURITY_LEVEL_MAX][KEY_FORMAT_MAX] = {
{V0Curves(SECURITY_LEVEL_80, -0.04047677865612648, 1.1433465085639063, 160,
1)},
{V0Curves(SECURITY_LEVEL_128, -0.026374888765705498, 2.012143923330495, 256,
1)},
{V0Curves(SECURITY_LEVEL_192, -0.018504919354426233, 2.6634073426215843,
381, 1)},
{V0Curves(SECURITY_LEVEL_256, -0.014327640360322604, 2.899270827311091, 781,
1)}};
V0Curves *getV0Curves(int securityLevel, int keyFormat) {
if (securityLevel >= SECURITY_LEVEL_MAX || keyFormat >= KEY_FORMAT_MAX) {
return nullptr;
}
return &curves[securityLevel][keyFormat];
}
} // namespace zamalang
} // namespace mlir