mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
fix(compiler): Integrate the security estimator to compute variances
This commit is contained in:
40
compiler/include/zamalang/Support/V0Curves.h
Normal file
40
compiler/include/zamalang/Support/V0Curves.h
Normal file
@@ -0,0 +1,40 @@
|
||||
#ifndef ZAMALANG_SUPPORT_V0CURVES_H_
|
||||
#define ZAMALANG_SUPPORT_V0CURVES_H_
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
#define SECURITY_LEVEL_80 0
|
||||
#define SECURITY_LEVEL_128 1
|
||||
#define SECURITY_LEVEL_192 2
|
||||
#define SECURITY_LEVEL_256 3
|
||||
#define SECURITY_LEVEL_MAX 4
|
||||
|
||||
#define KEY_FORMAT_BINARY 0
|
||||
#define KEY_FORMAT_MAX 1
|
||||
|
||||
struct V0Curves {
|
||||
int securityLevel;
|
||||
double linearTerm1;
|
||||
double linearTerm2;
|
||||
int nAlpha;
|
||||
int keyFormat;
|
||||
V0Curves(int securityLevel, double linearTerm1, double linearTerm2,
|
||||
int nAlpha, int keyFormat)
|
||||
: securityLevel(securityLevel), linearTerm1(linearTerm1),
|
||||
linearTerm2(linearTerm2), nAlpha(nAlpha), keyFormat(keyFormat) {}
|
||||
|
||||
double getVariance(int k, int polynomialSize, int logQ) {
|
||||
auto a = std::pow(2, (linearTerm1 * k * polynomialSize + linearTerm2) * 2);
|
||||
auto b = std::pow(2, -2 * (logQ - 2));
|
||||
return a > b ? a : b;
|
||||
}
|
||||
};
|
||||
|
||||
V0Curves *getV0Curves(int securityLevel, int keyFormat);
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
#endif
|
||||
@@ -2,6 +2,7 @@ add_mlir_library(ZamalangSupport
|
||||
CompilerTools.cpp
|
||||
CompilerEngine.cpp
|
||||
V0Parameters.cpp
|
||||
V0Curves.cpp
|
||||
ClientParameters.cpp
|
||||
KeySet.cpp
|
||||
|
||||
|
||||
@@ -5,13 +5,19 @@
|
||||
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Support/ClientParameters.h"
|
||||
#include "zamalang/Support/V0Curves.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
const auto securityLevel = SECURITY_LEVEL_128;
|
||||
const auto keyFormat = KEY_FORMAT_BINARY;
|
||||
const auto v0Curve = getV0Curves(securityLevel, keyFormat);
|
||||
|
||||
// For the v0 the secretKeyID and precision are the same for all gates.
|
||||
llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
|
||||
Precision precision,
|
||||
Variance variance,
|
||||
mlir::Type type) {
|
||||
if (type.isIntOrIndex()) {
|
||||
// TODO - The index type is dependant of the target architecture, so
|
||||
@@ -36,8 +42,7 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
|
||||
return CircuitGate{
|
||||
.encryption = llvm::Optional<EncryptionGate>({
|
||||
.secretKeyID = secretKeyID,
|
||||
// TODO - Compute variance, wait for security estimator
|
||||
.variance = 0.,
|
||||
.variance = variance,
|
||||
.encoding = {.precision = precision},
|
||||
}),
|
||||
.shape = {.width = precision, .size = 0},
|
||||
@@ -45,8 +50,8 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
|
||||
}
|
||||
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
if (tensor != nullptr) {
|
||||
auto gate =
|
||||
gateFromMLIRType(secretKeyID, precision, tensor.getElementType());
|
||||
auto gate = gateFromMLIRType(secretKeyID, precision, variance,
|
||||
tensor.getElementType());
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
@@ -61,6 +66,9 @@ llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
|
||||
mlir::ModuleOp module) {
|
||||
auto v0Param = fheContext.parameter;
|
||||
Variance encryptionVariance =
|
||||
v0Curve->getVariance(1, 1 << v0Param.polynomialSize, 64);
|
||||
Variance keyswitchVariance = v0Curve->getVariance(1, v0Param.nSmall, 64);
|
||||
// Static client parameters from global parameters for v0
|
||||
ClientParameters c{
|
||||
.secretKeys{
|
||||
@@ -76,8 +84,7 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
|
||||
.level = v0Param.brLevel,
|
||||
.baseLog = v0Param.brLogBase,
|
||||
.k = v0Param.k,
|
||||
// TODO - Compute variance, wait for security estimator
|
||||
.variance = 0,
|
||||
.variance = encryptionVariance,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -89,8 +96,7 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
|
||||
.outputSecretKeyID = "small",
|
||||
.level = v0Param.ksLevel,
|
||||
.baseLog = v0Param.ksLogBase,
|
||||
// TODO - Compute variance, wait for security estimator
|
||||
.variance = 0,
|
||||
.variance = keyswitchVariance,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -112,14 +118,14 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
|
||||
// Create input and output circuit gate parameters
|
||||
auto funcType = (*funcOp).getType();
|
||||
for (auto inType : funcType.getInputs()) {
|
||||
auto gate = gateFromMLIRType("big", precision, inType);
|
||||
auto gate = gateFromMLIRType("big", precision, encryptionVariance, inType);
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
c.inputs.push_back(gate.get());
|
||||
}
|
||||
for (auto outType : funcType.getResults()) {
|
||||
auto gate = gateFromMLIRType("big", precision, outType);
|
||||
auto gate = gateFromMLIRType("big", precision, encryptionVariance, outType);
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
25
compiler/lib/Support/V0Curves.cpp
Normal file
25
compiler/lib/Support/V0Curves.cpp
Normal file
@@ -0,0 +1,25 @@
|
||||
#include <cmath>
|
||||
|
||||
#include "zamalang/Support/V0Curves.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
V0Curves curves[SECURITY_LEVEL_MAX][KEY_FORMAT_MAX] = {
|
||||
{V0Curves(SECURITY_LEVEL_80, -0.04047677865612648, 1.1433465085639063, 160,
|
||||
1)},
|
||||
{V0Curves(SECURITY_LEVEL_128, -0.026374888765705498, 2.012143923330495, 256,
|
||||
1)},
|
||||
{V0Curves(SECURITY_LEVEL_192, -0.018504919354426233, 2.6634073426215843,
|
||||
381, 1)},
|
||||
{V0Curves(SECURITY_LEVEL_256, -0.014327640360322604, 2.899270827311091, 781,
|
||||
1)}};
|
||||
|
||||
V0Curves *getV0Curves(int securityLevel, int keyFormat) {
|
||||
if (securityLevel >= SECURITY_LEVEL_MAX || keyFormat >= KEY_FORMAT_MAX) {
|
||||
return nullptr;
|
||||
}
|
||||
return &curves[securityLevel][keyFormat];
|
||||
}
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
Reference in New Issue
Block a user