feat(KeySetCache): detect corrupted keysetcache

This commit is contained in:
rudy
2022-04-14 11:23:44 +02:00
committed by rudy-6-4
parent 1b34388d6e
commit 1fca8f4b91

View File

@@ -24,11 +24,25 @@ namespace clientlib {
using StringError = concretelang::error::StringError;
static std::string readFile(llvm::SmallString<0> &path) {
template <class Key>
outcome::checked<Key *, StringError> load(llvm::SmallString<0> &path,
Key *(*deser)(BufferView buffer)) {
std::ifstream in((std::string)path, std::ofstream::binary);
if (in.fail()) {
return StringError("Cannot access " + (std::string)path);
}
std::stringstream sbuffer;
sbuffer << in.rdbuf();
return sbuffer.str();
if (in.fail()) {
return StringError("Cannot read " + (std::string)path);
}
auto content = sbuffer.str();
BufferView buffer = {(const uint8_t *)content.c_str(), content.length()};
auto result = deser(buffer);
if (result == nullptr) {
return StringError("Cannot deserialize " + (std::string)path);
}
return result;
}
static void writeFile(llvm::SmallString<0> &path, Buffer content) {
@@ -37,22 +51,19 @@ static void writeFile(llvm::SmallString<0> &path, Buffer content) {
out.close();
}
LweSecretKey_u64 *loadSecretKey(llvm::SmallString<0> &path) {
std::string content = readFile(path);
BufferView buffer = {(const uint8_t *)content.c_str(), content.length()};
return deserialize_lwe_secret_key_u64(buffer);
outcome::checked<LweSecretKey_u64 *, StringError>
loadSecretKey(llvm::SmallString<0> &path) {
return load(path, deserialize_lwe_secret_key_u64);
}
LweKeyswitchKey_u64 *loadKeyswitchKey(llvm::SmallString<0> &path) {
std::string content = readFile(path);
BufferView buffer = {(const uint8_t *)content.c_str(), content.length()};
return deserialize_lwe_keyswitching_key_u64(buffer);
outcome::checked<LweKeyswitchKey_u64 *, StringError>
loadKeyswitchKey(llvm::SmallString<0> &path) {
return load(path, deserialize_lwe_keyswitching_key_u64);
}
LweBootstrapKey_u64 *loadBootstrapKey(llvm::SmallString<0> &path) {
std::string content = readFile(path);
BufferView buffer = {(const uint8_t *)content.c_str(), content.length()};
return deserialize_lwe_bootstrap_key_u64(buffer);
outcome::checked<LweBootstrapKey_u64 *, StringError>
loadBootstrapKey(llvm::SmallString<0> &path) {
return load(path, deserialize_lwe_bootstrap_key_u64);
}
void saveSecretKey(llvm::SmallString<0> &path, LweSecretKey_u64 *key) {
@@ -96,7 +107,7 @@ KeySetCache::loadKeys(ClientParameters &params, uint64_t seed_msb,
auto param = secretKeyParam.second;
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "secretKey_" + id);
LweSecretKey_u64 *sk = loadSecretKey(path);
OUTCOME_TRY(LweSecretKey_u64 * sk, loadSecretKey(path));
secretKeys[id] = {param, sk};
}
// Load bootstrap keys
@@ -105,7 +116,7 @@ KeySetCache::loadKeys(ClientParameters &params, uint64_t seed_msb,
auto param = bootstrapKeyParam.second;
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "pbsKey_" + id);
LweBootstrapKey_u64 *bsk = loadBootstrapKey(path);
OUTCOME_TRY(LweBootstrapKey_u64 * bsk, loadBootstrapKey(path));
bootstrapKeys[id] = {param, bsk};
}
// Load keyswitch keys
@@ -114,7 +125,7 @@ KeySetCache::loadKeys(ClientParameters &params, uint64_t seed_msb,
auto param = keyswitchParam.second;
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "ksKey_" + id);
LweKeyswitchKey_u64 *ksk = loadKeyswitchKey(path);
OUTCOME_TRY(LweKeyswitchKey_u64 * ksk, loadKeyswitchKey(path));
keyswitchKeys[id] = {param, ksk};
}
@@ -187,7 +198,15 @@ KeySetCache::loadOrGenerateSave(ClientParameters &params, uint64_t seed_msb,
std::to_string(seed_lsb));
if (llvm::sys::fs::exists(folderPath)) {
return loadKeys(params, seed_msb, seed_lsb, std::string(folderPath));
auto keys = loadKeys(params, seed_msb, seed_lsb, std::string(folderPath));
if (keys.has_value()) {
return keys;
} else {
std::cerr << std::string(keys.error().mesg) << "\n";
std::cerr << "Regenerating KeySetCache entry " << std::string(folderPath)
<< "\n";
llvm::sys::fs::remove_directories(folderPath);
}
}
// Creating a lock for concurrent generation