feat: keep std bsk and conv to fourier when needed

This commit is contained in:
youben11
2022-09-02 08:47:52 +01:00
committed by Ayoub Benaissa
parent 942b41d07c
commit 661d33c2b6
9 changed files with 134 additions and 74 deletions

View File

@@ -18,7 +18,7 @@ HPX_TARBALL=$(shell pwd)/hpx-$(HPX_VERSION).tar.gz
HPX_LOCAL_DIR=$(shell pwd)/hpx-$(HPX_VERSION)
HPX_INSTALL_DIR?=$(HPX_LOCAL_DIR)
CONCRETE_CORE_FFI_VERSION?=0.2.0-rc.2
CONCRETE_CORE_FFI_VERSION?=0.2.0-rc.3
ifeq ($(shell uname), Linux)
CONCRETE_CORE_FFI_TARBALL=concrete-core-ffi_$(CONCRETE_CORE_FFI_VERSION)_linux_amd64.tar.gz
else

View File

@@ -46,10 +46,10 @@ public:
// =============================================
/// Wrapper for `FftwFourierLweBootstrapKey64` so that it cleans up properly.
/// Wrapper for `LweBootstrapKey64` so that it cleans up properly.
class LweBootstrapKey {
private:
FftwFourierLweBootstrapKey64 *bsk;
LweBootstrapKey64 *bsk;
protected:
friend std::ostream &operator<<(std::ostream &ostream,
@@ -58,19 +58,19 @@ protected:
LweBootstrapKey &wrappedBsk);
public:
LweBootstrapKey(FftwFourierLweBootstrapKey64 *bsk) : bsk{bsk} {}
LweBootstrapKey(LweBootstrapKey64 *bsk) : bsk{bsk} {}
LweBootstrapKey(LweBootstrapKey &other) = delete;
LweBootstrapKey(LweBootstrapKey &&other) : bsk{other.bsk} {
other.bsk = nullptr;
}
~LweBootstrapKey() {
if (this->bsk != nullptr) {
CAPI_ASSERT_ERROR(destroy_fftw_fourier_lwe_bootstrap_key_u64(this->bsk));
CAPI_ASSERT_ERROR(destroy_lwe_bootstrap_key_u64(this->bsk));
this->bsk = nullptr;
}
}
FftwFourierLweBootstrapKey64 *get() { return this->bsk; }
LweBootstrapKey64 *get() { return this->bsk; }
};
// =============================================
@@ -97,7 +97,7 @@ public:
: sharedKsk{sharedKsk}, sharedBsk{sharedBsk} {}
LweKeyswitchKey64 *getKsk() { return this->sharedKsk->get(); }
FftwFourierLweBootstrapKey64 *getBsk() { return this->sharedBsk->get(); }
LweBootstrapKey64 *getBsk() { return this->sharedBsk->get(); }
};
// =============================================

View File

@@ -21,10 +21,6 @@ namespace mlir {
namespace concretelang {
typedef struct RuntimeContext {
::concretelang::clientlib::EvaluationKeys evaluationKeys;
DefaultEngine *default_engine;
std::map<pthread_t, FftwEngine *> fftw_engines;
std::mutex engines_map_guard;
RuntimeContext() {
CAPI_ASSERT_ERROR(new_default_engine(best_seeder, &default_engine));
@@ -44,12 +40,64 @@ typedef struct RuntimeContext {
for (const auto &key : fftw_engines) {
CAPI_ASSERT_ERROR(destroy_fftw_engine(key.second));
}
if (fbsk != nullptr) {
CAPI_ASSERT_ERROR(destroy_fftw_fourier_lwe_bootstrap_key_u64(fbsk));
}
}
FftwEngine *get_fftw_engine() {
pthread_t threadId = pthread_self();
std::lock_guard<std::mutex> guard(engines_map_guard);
auto engineIt = fftw_engines.find(threadId);
if (engineIt == fftw_engines.end()) {
FftwEngine *fftw_engine = nullptr;
CAPI_ASSERT_ERROR(new_fftw_engine(&fftw_engine));
engineIt =
fftw_engines
.insert(std::pair<pthread_t, FftwEngine *>(threadId, fftw_engine))
.first;
}
assert(engineIt->second && "No engine available in context");
return engineIt->second;
}
DefaultEngine *get_default_engine() { return default_engine; }
FftwFourierLweBootstrapKey64 *get_fftw_fourier_bsk() {
if (fbsk != nullptr) {
return fbsk;
}
const std::lock_guard<std::mutex> guard(fbskMutex);
if (fbsk == nullptr) {
CAPI_ASSERT_ERROR(
fftw_engine_convert_lwe_bootstrap_key_to_fftw_fourier_lwe_bootstrap_key_u64(
get_fftw_engine(), evaluationKeys.getBsk(), &fbsk));
}
return fbsk;
}
LweBootstrapKey64 *get_bsk() { return evaluationKeys.getBsk(); }
LweKeyswitchKey64 *get_ksk() { return evaluationKeys.getKsk(); }
RuntimeContext &operator=(const RuntimeContext &rhs) {
this->evaluationKeys = rhs.evaluationKeys;
return *this;
}
::concretelang::clientlib::EvaluationKeys evaluationKeys;
private:
std::mutex fbskMutex;
FftwFourierLweBootstrapKey64 *fbsk = nullptr;
DefaultEngine *default_engine;
std::map<pthread_t, FftwEngine *> fftw_engines;
std::mutex engines_map_guard;
} RuntimeContext;
} // namespace concretelang
@@ -60,6 +108,9 @@ LweKeyswitchKey64 *
get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context);
FftwFourierLweBootstrapKey64 *
get_fftw_fourier_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context);
LweBootstrapKey64 *
get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context);
DefaultEngine *get_engine(mlir::concretelang::RuntimeContext *context);

View File

@@ -64,18 +64,16 @@ KeyWrapper<LweKeyswitchKey64>::KeyWrapper(LweKeyswitchKey64 *key) : key(key) {
&buffer));
}
template <>
KeyWrapper<FftwFourierLweBootstrapKey64>::KeyWrapper(
FftwFourierLweBootstrapKey64 *key)
: key(key) {
KeyWrapper<LweBootstrapKey64>::KeyWrapper(LweBootstrapKey64 *key) : key(key) {
FftwSerializationEngine *engine;
DefaultSerializationEngine *engine;
CAPI_ASSERT_ERROR(new_fftw_serialization_engine(&engine));
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(
fftw_serialization_engine_serialize_fftw_fourier_lwe_bootstrap_key_u64(
engine, key, &buffer));
default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key,
&buffer));
}
template <typename LweKeyType>
@@ -86,25 +84,25 @@ bool operator==(const KeyWrapper<LweKeyType> &lhs,
template <>
template <class Archive>
void KeyWrapper<FftwFourierLweBootstrapKey64>::save(
Archive &ar, const unsigned int version) const {
void KeyWrapper<LweBootstrapKey64>::save(Archive &ar,
const unsigned int version) const {
ar << buffer.length;
ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
}
template <>
template <class Archive>
void KeyWrapper<FftwFourierLweBootstrapKey64>::load(
Archive &ar, const unsigned int version) {
FftwSerializationEngine *engine;
void KeyWrapper<LweBootstrapKey64>::load(Archive &ar,
const unsigned int version) {
DefaultSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_fftw_serialization_engine(&engine));
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
ar >> buffer.length;
buffer.pointer = new uint8_t[buffer.length];
ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
CAPI_ASSERT_ERROR(
fftw_serialization_engine_deserialize_fftw_fourier_lwe_bootstrap_key_u64(
default_serialization_engine_deserialize_lwe_bootstrap_key_u64(
engine, {buffer.pointer, buffer.length}, &key));
}
@@ -155,21 +153,22 @@ struct RuntimeContextManager {
if (_dfr_is_root_node()) {
RuntimeContext *context = (RuntimeContext *)ctx;
LweKeyswitchKey64 *ksk = get_keyswitch_key_u64(context);
FftwFourierLweBootstrapKey64 *bsk = get_bootstrap_key_u64(context);
LweBootstrapKey64 *bsk = get_bootstrap_key_u64(context);
KeyWrapper<LweKeyswitchKey64> kskw(ksk);
KeyWrapper<FftwFourierLweBootstrapKey64> bskw(bsk);
KeyWrapper<LweBootstrapKey64> bskw(bsk);
hpx::collectives::broadcast_to("ksk_keystore", kskw);
hpx::collectives::broadcast_to("bsk_keystore", bskw);
} else {
auto kskFut =
hpx::collectives::broadcast_from<KeyWrapper<LweKeyswitchKey64>>(
"ksk_keystore");
auto bskFut = hpx::collectives::broadcast_from<
KeyWrapper<FftwFourierLweBootstrapKey64>>("bsk_keystore");
auto bskFut =
hpx::collectives::broadcast_from<KeyWrapper<LweBootstrapKey64>>(
"bsk_keystore");
KeyWrapper<LweKeyswitchKey64> kskw = kskFut.get();
KeyWrapper<FftwFourierLweBootstrapKey64> bskw = bskFut.get();
KeyWrapper<LweBootstrapKey64> bskw = bskFut.get();
context = new mlir::concretelang::RuntimeContext();
context->evaluationKeys = ::concretelang::clientlib::EvaluationKeys(
std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>(

View File

@@ -197,18 +197,10 @@ KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param) {
par_engine, inputSk->second.second, output_glwe_sk, param.baseLog,
param.level, param.variance, &bsk));
FftwFourierLweBootstrapKey64 *fbsk;
CAPI_ASSERT_ERROR(
fftw_engine_convert_lwe_bootstrap_key_to_fftw_fourier_lwe_bootstrap_key_u64(
fftw_engine, bsk, &fbsk));
CAPI_ASSERT_ERROR(destroy_lwe_bootstrap_key_u64(bsk));
CAPI_ASSERT_ERROR(destroy_glwe_secret_key_u64(output_glwe_sk));
// Store the bootstrap key
bootstrapKeys[id] = {param, std::make_shared<LweBootstrapKey>(fbsk)};
bootstrapKeys[id] = {param, std::make_shared<LweBootstrapKey>(bsk)};
return outcome::success();
}

View File

@@ -75,17 +75,16 @@ loadKeyswitchKey(llvm::SmallString<0> &path) {
engine);
}
outcome::checked<FftwFourierLweBootstrapKey64 *, StringError>
outcome::checked<LweBootstrapKey64 *, StringError>
loadBootstrapKey(llvm::SmallString<0> &path) {
FftwSerializationEngine *engine;
DefaultSerializationEngine *engine;
CAPI_ASSERT_ERROR(new_fftw_serialization_engine(&engine));
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
return load(
path,
fftw_serialization_engine_deserialize_fftw_fourier_lwe_bootstrap_key_u64,
engine);
return load(path,
default_serialization_engine_deserialize_lwe_bootstrap_key_u64,
engine);
}
void saveSecretKey(llvm::SmallString<0> &path, LweSecretKey64 *key) {
@@ -103,18 +102,16 @@ void saveSecretKey(llvm::SmallString<0> &path, LweSecretKey64 *key) {
free(buffer.pointer);
}
void saveBootstrapKey(llvm::SmallString<0> &path,
FftwFourierLweBootstrapKey64 *key) {
FftwSerializationEngine *engine;
void saveBootstrapKey(llvm::SmallString<0> &path, LweBootstrapKey64 *key) {
DefaultSerializationEngine *engine;
CAPI_ASSERT_ERROR(new_fftw_serialization_engine(&engine));
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
Buffer buffer;
CAPI_ASSERT_ERROR(
fftw_serialization_engine_serialize_fftw_fourier_lwe_bootstrap_key_u64(
engine, key, &buffer));
default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key,
&buffer));
writeFile(path, buffer);
free(buffer.pointer);
}
@@ -166,7 +163,7 @@ KeySetCache::loadKeys(ClientParameters &params, uint64_t seed_msb,
auto param = bootstrapKeyParam.second;
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "pbsKey_" + id);
OUTCOME_TRY(FftwFourierLweBootstrapKey64 * bsk, loadBootstrapKey(path));
OUTCOME_TRY(LweBootstrapKey64 * bsk, loadBootstrapKey(path));
bootstrapKeys[id] = {param, std::make_shared<LweBootstrapKey>(bsk)};
}
// Load keyswitch keys

View File

@@ -60,6 +60,24 @@ std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey64 *key) {
return ostream;
}
std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey64 *key) {
DefaultSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
Buffer b;
CAPI_ASSERT_ERROR(
default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key,
&b))
writeBufferLike(ostream, b);
free((void *)b.pointer);
b.pointer = nullptr;
return ostream;
}
std::ostream &operator<<(std::ostream &ostream,
const FftwFourierLweBootstrapKey64 *key) {
FftwSerializationEngine *engine;
@@ -91,6 +109,18 @@ std::istream &operator>>(std::istream &istream, LweKeyswitchKey64 *&key) {
return istream;
}
std::istream &operator>>(std::istream &istream, LweBootstrapKey64 *&key) {
DefaultSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
key = read_deser(
istream, default_serialization_engine_deserialize_lwe_bootstrap_key_u64,
engine);
return istream;
}
std::istream &operator>>(std::istream &istream,
FftwFourierLweBootstrapKey64 *&key) {
FftwSerializationEngine *engine;

View File

@@ -11,32 +11,23 @@
LweKeyswitchKey64 *
get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context) {
return context->evaluationKeys.getKsk();
return context->get_ksk();
}
FftwFourierLweBootstrapKey64 *
LweBootstrapKey64 *
get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context) {
return context->evaluationKeys.getBsk();
return context->get_bsk();
}
FftwFourierLweBootstrapKey64 *get_fftw_fourier_bootstrap_key_u64(
mlir::concretelang::RuntimeContext *context) {
return context->get_fftw_fourier_bsk();
}
DefaultEngine *get_engine(mlir::concretelang::RuntimeContext *context) {
return context->default_engine;
return context->get_default_engine();
}
FftwEngine *get_fftw_engine(mlir::concretelang::RuntimeContext *context) {
pthread_t threadId = pthread_self();
std::lock_guard<std::mutex> guard(context->engines_map_guard);
auto engineIt = context->fftw_engines.find(threadId);
if (engineIt == context->fftw_engines.end()) {
FftwEngine *fftw_engine = nullptr;
CAPI_ASSERT_ERROR(new_fftw_engine(&fftw_engine));
engineIt =
context->fftw_engines
.insert(std::pair<pthread_t, FftwEngine *>(threadId, fftw_engine))
.first;
}
assert(engineIt->second && "No engine available in context");
return engineIt->second;
return context->get_fftw_engine();
}

View File

@@ -159,7 +159,7 @@ void memref_bootstrap_lwe_u64(
CAPI_ASSERT_ERROR(
fftw_engine_lwe_ciphertext_discarding_bootstrap_u64_raw_ptr_buffers(
get_fftw_engine(context), get_engine(context),
get_bootstrap_key_u64(context), out_aligned + out_offset,
get_fftw_fourier_bootstrap_key_u64(context), out_aligned + out_offset,
ct0_aligned + ct0_offset, glwe_ct_aligned + glwe_ct_offset));
}