feat: accept no evaluation keys

This commit is contained in:
rudy
2022-09-19 15:01:55 +02:00
committed by rudy-6-4
parent 08ed2fc49b
commit cb2c9ef6bf
7 changed files with 82 additions and 47 deletions

View File

@@ -132,6 +132,8 @@ jobs:
then
echo Cleaning $TO_CLEAN
rm -rf $TO_CLEAN
echo New cache size is
du -sh KeySetCache
else
echo Nothing to clean
fi

View File

@@ -37,6 +37,8 @@ using concretelang::error::StringError;
const std::string SMALL_KEY = "small";
const std::string BIG_KEY = "big";
const std::string BOOTSTRAP_KEY = "bsk_v0";
const std::string KEYSWITCH_KEY = "ksk_v0";
const std::string CLIENT_PARAMETERS_EXT = ".concrete.params.json";

View File

@@ -87,8 +87,11 @@ public:
}
EvaluationKeys evaluationKeys() {
auto kskIt = this->keyswitchKeys.find("ksk_v0");
auto bskIt = this->bootstrapKeys.find("bsk_v0");
if (this->bootstrapKeys.empty() && this->keyswitchKeys.empty()) {
return EvaluationKeys();
}
auto kskIt = this->keyswitchKeys.find(clientlib::KEYSWITCH_KEY);
auto bskIt = this->bootstrapKeys.find(clientlib::BOOTSTRAP_KEY);
if (kskIt != this->keyswitchKeys.end() &&
bskIt != this->bootstrapKeys.end()) {
auto sharedKsk = std::get<1>(kskIt->second);

View File

@@ -241,24 +241,40 @@ std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk) {
std::ostream &operator<<(std::ostream &ostream,
const EvaluationKeys &evaluationKeys) {
ostream << *evaluationKeys.sharedKsk;
ostream << *evaluationKeys.sharedBsk;
bool has_ksk = (bool)evaluationKeys.sharedKsk;
writeWord(ostream, has_ksk);
if (has_ksk) {
ostream << *evaluationKeys.sharedKsk;
}
bool has_bsk = (bool)evaluationKeys.sharedBsk;
writeWord(ostream, has_bsk);
if (has_bsk) {
ostream << *evaluationKeys.sharedBsk;
}
assert(ostream.good());
return ostream;
}
std::istream &operator>>(std::istream &istream,
EvaluationKeys &evaluationKeys) {
auto sharedKsk = LweKeyswitchKey(nullptr);
auto sharedBsk = LweBootstrapKey(nullptr);
bool has_ksk;
readWord(istream, has_ksk);
if (has_ksk) {
auto sharedKsk = LweKeyswitchKey(nullptr);
istream >> sharedKsk;
evaluationKeys.sharedKsk =
std::make_shared<LweKeyswitchKey>(std::move(sharedKsk));
}
istream >> sharedKsk;
istream >> sharedBsk;
evaluationKeys.sharedKsk =
std::make_shared<LweKeyswitchKey>(std::move(sharedKsk));
evaluationKeys.sharedBsk =
std::make_shared<LweBootstrapKey>(std::move(sharedBsk));
bool has_bsk;
readWord(istream, has_bsk);
if (has_bsk) {
auto sharedBsk = LweBootstrapKey(nullptr);
istream >> sharedBsk;
evaluationKeys.sharedBsk =
std::make_shared<LweBootstrapKey>(std::move(sharedBsk));
}
assert(istream.good());
return istream;

View File

@@ -2,6 +2,7 @@
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <map>
#include <llvm/ADT/Optional.h>
#include <llvm/ADT/STLExtras.h>
@@ -19,14 +20,13 @@
namespace mlir {
namespace concretelang {
using ::concretelang::clientlib::BIG_KEY;
namespace clientlib = ::concretelang::clientlib;
using ::concretelang::clientlib::CircuitGate;
using ::concretelang::clientlib::ClientParameters;
using ::concretelang::clientlib::Encoding;
using ::concretelang::clientlib::EncryptionGate;
using ::concretelang::clientlib::LweSecretKeyID;
using ::concretelang::clientlib::Precision;
using ::concretelang::clientlib::SMALL_KEY;
using ::concretelang::clientlib::Variance;
const auto securityLevel = SECURITY_LEVEL_128;
@@ -105,34 +105,43 @@ createClientParametersForV0(V0FHEContext fheContext,
// Static client parameters from global parameters for v0
ClientParameters c;
c.secretKeys = {
{SMALL_KEY, {/*.size = */ v0Param.nSmall}},
{BIG_KEY, {/*.size = */ v0Param.getNBigGlweDimension()}},
};
c.bootstrapKeys = {
{
"bsk_v0",
{
/*.inputSecretKeyID = */ SMALL_KEY,
/*.outputSecretKeyID = */ BIG_KEY,
/*.level = */ v0Param.brLevel,
/*.baseLog = */ v0Param.brLogBase,
/*.glweDimension = */ v0Param.glweDimension,
/*.variance = */ encryptionVariance,
},
},
};
c.keyswitchKeys = {
{
"ksk_v0",
{
/*.inputSecretKeyID = */ BIG_KEY,
/*.outputSecretKeyID = */ SMALL_KEY,
/*.level = */ v0Param.ksLevel,
/*.baseLog = */ v0Param.ksLogBase,
/*.variance = */ keyswitchVariance,
},
},
{clientlib::BIG_KEY, {/*.size = */ v0Param.getNBigGlweDimension()}},
};
bool has_small_key = v0Param.nSmall != 0;
bool has_bootstrap = v0Param.brLevel != 0;
if (has_small_key) {
c.secretKeys.insert({clientlib::SMALL_KEY, {/*.size = */ v0Param.nSmall}});
}
if (has_bootstrap) {
auto inputKey = (has_small_key) ? clientlib::SMALL_KEY : clientlib::BIG_KEY;
c.bootstrapKeys = {
{
clientlib::BOOTSTRAP_KEY,
{
/*.inputSecretKeyID = */ inputKey,
/*.outputSecretKeyID = */ clientlib::BIG_KEY,
/*.level = */ v0Param.brLevel,
/*.baseLog = */ v0Param.brLogBase,
/*.glweDimension = */ v0Param.glweDimension,
/*.variance = */ encryptionVariance,
},
},
};
}
if (has_small_key) {
c.keyswitchKeys = {
{
clientlib::KEYSWITCH_KEY,
{
/*.inputSecretKeyID = */ clientlib::BIG_KEY,
/*.outputSecretKeyID = */ clientlib::SMALL_KEY,
/*.level = */ v0Param.ksLevel,
/*.baseLog = */ v0Param.ksLogBase,
/*.variance = */ keyswitchVariance,
},
},
};
}
c.functionName = (std::string)functionName;
// Find the input function
auto rangeOps = module.getOps<mlir::func::FuncOp>();
@@ -155,16 +164,19 @@ createClientParametersForV0(V0FHEContext fheContext,
? false
: inputs.back().isa<mlir::concretelang::Concrete::ContextType>();
auto gateFromType = [&](mlir::Type ty) {
return gateFromMLIRType(clientlib::BIG_KEY, encryptionVariance, ty);
};
for (auto inType = funcType.getInputs().begin();
inType < funcType.getInputs().end() - hasContext; inType++) {
auto gate = gateFromMLIRType(BIG_KEY, encryptionVariance, *inType);
auto gate = gateFromType(*inType);
if (auto err = gate.takeError()) {
return std::move(err);
}
c.inputs.push_back(gate.get());
}
for (auto outType : funcType.getResults()) {
auto gate = gateFromMLIRType(BIG_KEY, encryptionVariance, outType);
auto gate = gateFromType(outType);
if (auto err = gate.takeError()) {
return std::move(err);
}

View File

@@ -13,7 +13,7 @@ TEST(Support, client_parameters_json_serde) {
{clientlib::BIG_KEY, {/*.size = */ 14}},
};
params0.bootstrapKeys = {
{"bsk_v0",
{clientlib::BOOTSTRAP_KEY,
{/*.inputSecretKeyID = */ clientlib::SMALL_KEY,
/*.outputSecretKeyID = */ clientlib::BIG_KEY,
/*.level = */ 1,
@@ -30,7 +30,7 @@ TEST(Support, client_parameters_json_serde) {
/*.variance = */ 0.0001,
}},
};
params0.keyswitchKeys = {{"ksk_v0",
params0.keyswitchKeys = {{clientlib::KEYSWITCH_KEY,
{
/*.inputSecretKeyID = */
clientlib::BIG_KEY,