mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(KeySetCache): detect corrupted keysetcache
This commit is contained in:
@@ -24,11 +24,25 @@ namespace clientlib {
|
||||
|
||||
using StringError = concretelang::error::StringError;
|
||||
|
||||
static std::string readFile(llvm::SmallString<0> &path) {
|
||||
template <class Key>
|
||||
outcome::checked<Key *, StringError> load(llvm::SmallString<0> &path,
|
||||
Key *(*deser)(BufferView buffer)) {
|
||||
std::ifstream in((std::string)path, std::ofstream::binary);
|
||||
if (in.fail()) {
|
||||
return StringError("Cannot access " + (std::string)path);
|
||||
}
|
||||
std::stringstream sbuffer;
|
||||
sbuffer << in.rdbuf();
|
||||
return sbuffer.str();
|
||||
if (in.fail()) {
|
||||
return StringError("Cannot read " + (std::string)path);
|
||||
}
|
||||
auto content = sbuffer.str();
|
||||
BufferView buffer = {(const uint8_t *)content.c_str(), content.length()};
|
||||
auto result = deser(buffer);
|
||||
if (result == nullptr) {
|
||||
return StringError("Cannot deserialize " + (std::string)path);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static void writeFile(llvm::SmallString<0> &path, Buffer content) {
|
||||
@@ -37,22 +51,19 @@ static void writeFile(llvm::SmallString<0> &path, Buffer content) {
|
||||
out.close();
|
||||
}
|
||||
|
||||
LweSecretKey_u64 *loadSecretKey(llvm::SmallString<0> &path) {
|
||||
std::string content = readFile(path);
|
||||
BufferView buffer = {(const uint8_t *)content.c_str(), content.length()};
|
||||
return deserialize_lwe_secret_key_u64(buffer);
|
||||
outcome::checked<LweSecretKey_u64 *, StringError>
|
||||
loadSecretKey(llvm::SmallString<0> &path) {
|
||||
return load(path, deserialize_lwe_secret_key_u64);
|
||||
}
|
||||
|
||||
LweKeyswitchKey_u64 *loadKeyswitchKey(llvm::SmallString<0> &path) {
|
||||
std::string content = readFile(path);
|
||||
BufferView buffer = {(const uint8_t *)content.c_str(), content.length()};
|
||||
return deserialize_lwe_keyswitching_key_u64(buffer);
|
||||
outcome::checked<LweKeyswitchKey_u64 *, StringError>
|
||||
loadKeyswitchKey(llvm::SmallString<0> &path) {
|
||||
return load(path, deserialize_lwe_keyswitching_key_u64);
|
||||
}
|
||||
|
||||
LweBootstrapKey_u64 *loadBootstrapKey(llvm::SmallString<0> &path) {
|
||||
std::string content = readFile(path);
|
||||
BufferView buffer = {(const uint8_t *)content.c_str(), content.length()};
|
||||
return deserialize_lwe_bootstrap_key_u64(buffer);
|
||||
outcome::checked<LweBootstrapKey_u64 *, StringError>
|
||||
loadBootstrapKey(llvm::SmallString<0> &path) {
|
||||
return load(path, deserialize_lwe_bootstrap_key_u64);
|
||||
}
|
||||
|
||||
void saveSecretKey(llvm::SmallString<0> &path, LweSecretKey_u64 *key) {
|
||||
@@ -96,7 +107,7 @@ KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
auto param = secretKeyParam.second;
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "secretKey_" + id);
|
||||
LweSecretKey_u64 *sk = loadSecretKey(path);
|
||||
OUTCOME_TRY(LweSecretKey_u64 * sk, loadSecretKey(path));
|
||||
secretKeys[id] = {param, sk};
|
||||
}
|
||||
// Load bootstrap keys
|
||||
@@ -105,7 +116,7 @@ KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
auto param = bootstrapKeyParam.second;
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "pbsKey_" + id);
|
||||
LweBootstrapKey_u64 *bsk = loadBootstrapKey(path);
|
||||
OUTCOME_TRY(LweBootstrapKey_u64 * bsk, loadBootstrapKey(path));
|
||||
bootstrapKeys[id] = {param, bsk};
|
||||
}
|
||||
// Load keyswitch keys
|
||||
@@ -114,7 +125,7 @@ KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
auto param = keyswitchParam.second;
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "ksKey_" + id);
|
||||
LweKeyswitchKey_u64 *ksk = loadKeyswitchKey(path);
|
||||
OUTCOME_TRY(LweKeyswitchKey_u64 * ksk, loadKeyswitchKey(path));
|
||||
keyswitchKeys[id] = {param, ksk};
|
||||
}
|
||||
|
||||
@@ -187,7 +198,15 @@ KeySetCache::loadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
std::to_string(seed_lsb));
|
||||
|
||||
if (llvm::sys::fs::exists(folderPath)) {
|
||||
return loadKeys(params, seed_msb, seed_lsb, std::string(folderPath));
|
||||
auto keys = loadKeys(params, seed_msb, seed_lsb, std::string(folderPath));
|
||||
if (keys.has_value()) {
|
||||
return keys;
|
||||
} else {
|
||||
std::cerr << std::string(keys.error().mesg) << "\n";
|
||||
std::cerr << "Regenerating KeySetCache entry " << std::string(folderPath)
|
||||
<< "\n";
|
||||
llvm::sys::fs::remove_directories(folderPath);
|
||||
}
|
||||
}
|
||||
|
||||
// Creating a lock for concurrent generation
|
||||
|
||||
Reference in New Issue
Block a user