From cb2c9ef6bf286d90d63123e7e1c24e1a657144e7 Mon Sep 17 00:00:00 2001 From: rudy Date: Mon, 19 Sep 2022 15:01:55 +0200 Subject: [PATCH] feat: accept no evaluation keys --- .github/workflows/continuous-integration.yml | 2 + compiler/concrete-optimizer | 2 +- .../concretelang/ClientLib/ClientParameters.h | 2 + .../include/concretelang/ClientLib/KeySet.h | 7 +- compiler/lib/ClientLib/Serializers.cpp | 38 +++++++--- compiler/lib/Support/V0ClientParameters.cpp | 74 +++++++++++-------- .../ClientLib/ClientParameters.cpp | 4 +- 7 files changed, 82 insertions(+), 47 deletions(-) diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml index 6c1002a4d..6b9c4d2ff 100644 --- a/.github/workflows/continuous-integration.yml +++ b/.github/workflows/continuous-integration.yml @@ -132,6 +132,8 @@ jobs: then echo Cleaning $TO_CLEAN rm -rf $TO_CLEAN + echo New cache size is + du -sh KeySetCache else echo Nothing to clean fi diff --git a/compiler/concrete-optimizer b/compiler/concrete-optimizer index bf6bdcfec..06c8bdfce 160000 --- a/compiler/concrete-optimizer +++ b/compiler/concrete-optimizer @@ -1 +1 @@ -Subproject commit bf6bdcfec247d4bba8f425f0565efffe3e961d97 +Subproject commit 06c8bdfce065b1d07e2be13ee96c68fe4dbabff1 diff --git a/compiler/include/concretelang/ClientLib/ClientParameters.h b/compiler/include/concretelang/ClientLib/ClientParameters.h index cad26c1ed..f0206594e 100644 --- a/compiler/include/concretelang/ClientLib/ClientParameters.h +++ b/compiler/include/concretelang/ClientLib/ClientParameters.h @@ -37,6 +37,8 @@ using concretelang::error::StringError; const std::string SMALL_KEY = "small"; const std::string BIG_KEY = "big"; +const std::string BOOTSTRAP_KEY = "bsk_v0"; +const std::string KEYSWITCH_KEY = "ksk_v0"; const std::string CLIENT_PARAMETERS_EXT = ".concrete.params.json"; diff --git a/compiler/include/concretelang/ClientLib/KeySet.h b/compiler/include/concretelang/ClientLib/KeySet.h index d3b189e3b..4b5e582f5 100644 --- a/compiler/include/concretelang/ClientLib/KeySet.h +++ b/compiler/include/concretelang/ClientLib/KeySet.h @@ -87,8 +87,11 @@ public: } EvaluationKeys evaluationKeys() { - auto kskIt = this->keyswitchKeys.find("ksk_v0"); - auto bskIt = this->bootstrapKeys.find("bsk_v0"); + if (this->bootstrapKeys.empty() && this->keyswitchKeys.empty()) { + return EvaluationKeys(); + } + auto kskIt = this->keyswitchKeys.find(clientlib::KEYSWITCH_KEY); + auto bskIt = this->bootstrapKeys.find(clientlib::BOOTSTRAP_KEY); if (kskIt != this->keyswitchKeys.end() && bskIt != this->bootstrapKeys.end()) { auto sharedKsk = std::get<1>(kskIt->second); diff --git a/compiler/lib/ClientLib/Serializers.cpp b/compiler/lib/ClientLib/Serializers.cpp index a2746f858..f57b03266 100644 --- a/compiler/lib/ClientLib/Serializers.cpp +++ b/compiler/lib/ClientLib/Serializers.cpp @@ -241,24 +241,40 @@ std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk) { std::ostream &operator<<(std::ostream &ostream, const EvaluationKeys &evaluationKeys) { - ostream << *evaluationKeys.sharedKsk; - ostream << *evaluationKeys.sharedBsk; + bool has_ksk = (bool)evaluationKeys.sharedKsk; + writeWord(ostream, has_ksk); + if (has_ksk) { + ostream << *evaluationKeys.sharedKsk; + } + + bool has_bsk = (bool)evaluationKeys.sharedBsk; + writeWord(ostream, has_bsk); + if (has_bsk) { + ostream << *evaluationKeys.sharedBsk; + } assert(ostream.good()); return ostream; } std::istream &operator>>(std::istream &istream, EvaluationKeys &evaluationKeys) { - auto sharedKsk = LweKeyswitchKey(nullptr); - auto sharedBsk = LweBootstrapKey(nullptr); + bool has_ksk; + readWord(istream, has_ksk); + if (has_ksk) { + auto sharedKsk = LweKeyswitchKey(nullptr); + istream >> sharedKsk; + evaluationKeys.sharedKsk = + std::make_shared(std::move(sharedKsk)); + } - istream >> sharedKsk; - istream >> sharedBsk; - - evaluationKeys.sharedKsk = - std::make_shared(std::move(sharedKsk)); - evaluationKeys.sharedBsk = - std::make_shared(std::move(sharedBsk)); + bool has_bsk; + readWord(istream, has_bsk); + if (has_bsk) { + auto sharedBsk = LweBootstrapKey(nullptr); + istream >> sharedBsk; + evaluationKeys.sharedBsk = + std::make_shared(std::move(sharedBsk)); + } assert(istream.good()); return istream; diff --git a/compiler/lib/Support/V0ClientParameters.cpp b/compiler/lib/Support/V0ClientParameters.cpp index d98072a5c..728f27b07 100644 --- a/compiler/lib/Support/V0ClientParameters.cpp +++ b/compiler/lib/Support/V0ClientParameters.cpp @@ -2,6 +2,7 @@ // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. +#include #include #include @@ -19,14 +20,13 @@ namespace mlir { namespace concretelang { -using ::concretelang::clientlib::BIG_KEY; +namespace clientlib = ::concretelang::clientlib; using ::concretelang::clientlib::CircuitGate; using ::concretelang::clientlib::ClientParameters; using ::concretelang::clientlib::Encoding; using ::concretelang::clientlib::EncryptionGate; using ::concretelang::clientlib::LweSecretKeyID; using ::concretelang::clientlib::Precision; -using ::concretelang::clientlib::SMALL_KEY; using ::concretelang::clientlib::Variance; const auto securityLevel = SECURITY_LEVEL_128; @@ -105,34 +105,43 @@ createClientParametersForV0(V0FHEContext fheContext, // Static client parameters from global parameters for v0 ClientParameters c; c.secretKeys = { - {SMALL_KEY, {/*.size = */ v0Param.nSmall}}, - {BIG_KEY, {/*.size = */ v0Param.getNBigGlweDimension()}}, - }; - c.bootstrapKeys = { - { - "bsk_v0", - { - /*.inputSecretKeyID = */ SMALL_KEY, - /*.outputSecretKeyID = */ BIG_KEY, - /*.level = */ v0Param.brLevel, - /*.baseLog = */ v0Param.brLogBase, - /*.glweDimension = */ v0Param.glweDimension, - /*.variance = */ encryptionVariance, - }, - }, - }; - c.keyswitchKeys = { - { - "ksk_v0", - { - /*.inputSecretKeyID = */ BIG_KEY, - /*.outputSecretKeyID = */ SMALL_KEY, - /*.level = */ v0Param.ksLevel, - /*.baseLog = */ v0Param.ksLogBase, - /*.variance = */ keyswitchVariance, - }, - }, + {clientlib::BIG_KEY, {/*.size = */ v0Param.getNBigGlweDimension()}}, }; + bool has_small_key = v0Param.nSmall != 0; + bool has_bootstrap = v0Param.brLevel != 0; + if (has_small_key) { + c.secretKeys.insert({clientlib::SMALL_KEY, {/*.size = */ v0Param.nSmall}}); + } + if (has_bootstrap) { + auto inputKey = (has_small_key) ? clientlib::SMALL_KEY : clientlib::BIG_KEY; + c.bootstrapKeys = { + { + clientlib::BOOTSTRAP_KEY, + { + /*.inputSecretKeyID = */ inputKey, + /*.outputSecretKeyID = */ clientlib::BIG_KEY, + /*.level = */ v0Param.brLevel, + /*.baseLog = */ v0Param.brLogBase, + /*.glweDimension = */ v0Param.glweDimension, + /*.variance = */ encryptionVariance, + }, + }, + }; + } + if (has_small_key) { + c.keyswitchKeys = { + { + clientlib::KEYSWITCH_KEY, + { + /*.inputSecretKeyID = */ clientlib::BIG_KEY, + /*.outputSecretKeyID = */ clientlib::SMALL_KEY, + /*.level = */ v0Param.ksLevel, + /*.baseLog = */ v0Param.ksLogBase, + /*.variance = */ keyswitchVariance, + }, + }, + }; + } c.functionName = (std::string)functionName; // Find the input function auto rangeOps = module.getOps(); @@ -155,16 +164,19 @@ createClientParametersForV0(V0FHEContext fheContext, ? false : inputs.back().isa(); + auto gateFromType = [&](mlir::Type ty) { + return gateFromMLIRType(clientlib::BIG_KEY, encryptionVariance, ty); + }; for (auto inType = funcType.getInputs().begin(); inType < funcType.getInputs().end() - hasContext; inType++) { - auto gate = gateFromMLIRType(BIG_KEY, encryptionVariance, *inType); + auto gate = gateFromType(*inType); if (auto err = gate.takeError()) { return std::move(err); } c.inputs.push_back(gate.get()); } for (auto outType : funcType.getResults()) { - auto gate = gateFromMLIRType(BIG_KEY, encryptionVariance, outType); + auto gate = gateFromType(outType); if (auto err = gate.takeError()) { return std::move(err); } diff --git a/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp b/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp index 2ef688186..b960d1222 100644 --- a/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp +++ b/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp @@ -13,7 +13,7 @@ TEST(Support, client_parameters_json_serde) { {clientlib::BIG_KEY, {/*.size = */ 14}}, }; params0.bootstrapKeys = { - {"bsk_v0", + {clientlib::BOOTSTRAP_KEY, {/*.inputSecretKeyID = */ clientlib::SMALL_KEY, /*.outputSecretKeyID = */ clientlib::BIG_KEY, /*.level = */ 1, @@ -30,7 +30,7 @@ TEST(Support, client_parameters_json_serde) { /*.variance = */ 0.0001, }}, }; - params0.keyswitchKeys = {{"ksk_v0", + params0.keyswitchKeys = {{clientlib::KEYSWITCH_KEY, { /*.inputSecretKeyID = */ clientlib::BIG_KEY,