From 1fca8f4b918db9a60068cb032c198d767dc0832c Mon Sep 17 00:00:00 2001 From: rudy Date: Thu, 14 Apr 2022 11:23:44 +0200 Subject: [PATCH] feat(KeySetCache): detect corrupted keysetcache --- compiler/lib/ClientLib/KeySetCache.cpp | 55 +++++++++++++++++--------- 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/compiler/lib/ClientLib/KeySetCache.cpp b/compiler/lib/ClientLib/KeySetCache.cpp index 4cf2d8629..9338a45eb 100644 --- a/compiler/lib/ClientLib/KeySetCache.cpp +++ b/compiler/lib/ClientLib/KeySetCache.cpp @@ -24,11 +24,25 @@ namespace clientlib { using StringError = concretelang::error::StringError; -static std::string readFile(llvm::SmallString<0> &path) { +template +outcome::checked load(llvm::SmallString<0> &path, + Key *(*deser)(BufferView buffer)) { std::ifstream in((std::string)path, std::ofstream::binary); + if (in.fail()) { + return StringError("Cannot access " + (std::string)path); + } std::stringstream sbuffer; sbuffer << in.rdbuf(); - return sbuffer.str(); + if (in.fail()) { + return StringError("Cannot read " + (std::string)path); + } + auto content = sbuffer.str(); + BufferView buffer = {(const uint8_t *)content.c_str(), content.length()}; + auto result = deser(buffer); + if (result == nullptr) { + return StringError("Cannot deserialize " + (std::string)path); + } + return result; } static void writeFile(llvm::SmallString<0> &path, Buffer content) { @@ -37,22 +51,19 @@ static void writeFile(llvm::SmallString<0> &path, Buffer content) { out.close(); } -LweSecretKey_u64 *loadSecretKey(llvm::SmallString<0> &path) { - std::string content = readFile(path); - BufferView buffer = {(const uint8_t *)content.c_str(), content.length()}; - return deserialize_lwe_secret_key_u64(buffer); +outcome::checked +loadSecretKey(llvm::SmallString<0> &path) { + return load(path, deserialize_lwe_secret_key_u64); } -LweKeyswitchKey_u64 *loadKeyswitchKey(llvm::SmallString<0> &path) { - std::string content = readFile(path); - BufferView buffer = {(const uint8_t *)content.c_str(), content.length()}; - return deserialize_lwe_keyswitching_key_u64(buffer); +outcome::checked +loadKeyswitchKey(llvm::SmallString<0> &path) { + return load(path, deserialize_lwe_keyswitching_key_u64); } -LweBootstrapKey_u64 *loadBootstrapKey(llvm::SmallString<0> &path) { - std::string content = readFile(path); - BufferView buffer = {(const uint8_t *)content.c_str(), content.length()}; - return deserialize_lwe_bootstrap_key_u64(buffer); +outcome::checked +loadBootstrapKey(llvm::SmallString<0> &path) { + return load(path, deserialize_lwe_bootstrap_key_u64); } void saveSecretKey(llvm::SmallString<0> &path, LweSecretKey_u64 *key) { @@ -96,7 +107,7 @@ KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb, auto param = secretKeyParam.second; llvm::SmallString<0> path(folderPath); llvm::sys::path::append(path, "secretKey_" + id); - LweSecretKey_u64 *sk = loadSecretKey(path); + OUTCOME_TRY(LweSecretKey_u64 * sk, loadSecretKey(path)); secretKeys[id] = {param, sk}; } // Load bootstrap keys @@ -105,7 +116,7 @@ KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb, auto param = bootstrapKeyParam.second; llvm::SmallString<0> path(folderPath); llvm::sys::path::append(path, "pbsKey_" + id); - LweBootstrapKey_u64 *bsk = loadBootstrapKey(path); + OUTCOME_TRY(LweBootstrapKey_u64 * bsk, loadBootstrapKey(path)); bootstrapKeys[id] = {param, bsk}; } // Load keyswitch keys @@ -114,7 +125,7 @@ KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb, auto param = keyswitchParam.second; llvm::SmallString<0> path(folderPath); llvm::sys::path::append(path, "ksKey_" + id); - LweKeyswitchKey_u64 *ksk = loadKeyswitchKey(path); + OUTCOME_TRY(LweKeyswitchKey_u64 * ksk, loadKeyswitchKey(path)); keyswitchKeys[id] = {param, ksk}; } @@ -187,7 +198,15 @@ KeySetCache::loadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb, std::to_string(seed_lsb)); if (llvm::sys::fs::exists(folderPath)) { - return loadKeys(params, seed_msb, seed_lsb, std::string(folderPath)); + auto keys = loadKeys(params, seed_msb, seed_lsb, std::string(folderPath)); + if (keys.has_value()) { + return keys; + } else { + std::cerr << std::string(keys.error().mesg) << "\n"; + std::cerr << "Regenerating KeySetCache entry " << std::string(folderPath) + << "\n"; + llvm::sys::fs::remove_directories(folderPath); + } } // Creating a lock for concurrent generation