diff --git a/compiler/Makefile b/compiler/Makefile index 09859e466..f93395992 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -18,7 +18,7 @@ HPX_TARBALL=$(shell pwd)/hpx-$(HPX_VERSION).tar.gz HPX_LOCAL_DIR=$(shell pwd)/hpx-$(HPX_VERSION) HPX_INSTALL_DIR?=$(HPX_LOCAL_DIR) -CONCRETE_CORE_FFI_VERSION?=0.2.0-rc.2 +CONCRETE_CORE_FFI_VERSION?=0.2.0-rc.3 ifeq ($(shell uname), Linux) CONCRETE_CORE_FFI_TARBALL=concrete-core-ffi_$(CONCRETE_CORE_FFI_VERSION)_linux_amd64.tar.gz else diff --git a/compiler/include/concretelang/ClientLib/EvaluationKeys.h b/compiler/include/concretelang/ClientLib/EvaluationKeys.h index 2f29ba468..8ad52b37a 100644 --- a/compiler/include/concretelang/ClientLib/EvaluationKeys.h +++ b/compiler/include/concretelang/ClientLib/EvaluationKeys.h @@ -46,10 +46,10 @@ public: // ============================================= -/// Wrapper for `FftwFourierLweBootstrapKey64` so that it cleans up properly. +/// Wrapper for `LweBootstrapKey64` so that it cleans up properly. class LweBootstrapKey { private: - FftwFourierLweBootstrapKey64 *bsk; + LweBootstrapKey64 *bsk; protected: friend std::ostream &operator<<(std::ostream &ostream, @@ -58,19 +58,19 @@ protected: LweBootstrapKey &wrappedBsk); public: - LweBootstrapKey(FftwFourierLweBootstrapKey64 *bsk) : bsk{bsk} {} + LweBootstrapKey(LweBootstrapKey64 *bsk) : bsk{bsk} {} LweBootstrapKey(LweBootstrapKey &other) = delete; LweBootstrapKey(LweBootstrapKey &&other) : bsk{other.bsk} { other.bsk = nullptr; } ~LweBootstrapKey() { if (this->bsk != nullptr) { - CAPI_ASSERT_ERROR(destroy_fftw_fourier_lwe_bootstrap_key_u64(this->bsk)); + CAPI_ASSERT_ERROR(destroy_lwe_bootstrap_key_u64(this->bsk)); this->bsk = nullptr; } } - FftwFourierLweBootstrapKey64 *get() { return this->bsk; } + LweBootstrapKey64 *get() { return this->bsk; } }; // ============================================= @@ -97,7 +97,7 @@ public: : sharedKsk{sharedKsk}, sharedBsk{sharedBsk} {} LweKeyswitchKey64 *getKsk() { return this->sharedKsk->get(); } - FftwFourierLweBootstrapKey64 *getBsk() { return this->sharedBsk->get(); } + LweBootstrapKey64 *getBsk() { return this->sharedBsk->get(); } }; // ============================================= diff --git a/compiler/include/concretelang/Runtime/context.h b/compiler/include/concretelang/Runtime/context.h index 551104bb7..e025f7349 100644 --- a/compiler/include/concretelang/Runtime/context.h +++ b/compiler/include/concretelang/Runtime/context.h @@ -21,10 +21,6 @@ namespace mlir { namespace concretelang { typedef struct RuntimeContext { - ::concretelang::clientlib::EvaluationKeys evaluationKeys; - DefaultEngine *default_engine; - std::map fftw_engines; - std::mutex engines_map_guard; RuntimeContext() { CAPI_ASSERT_ERROR(new_default_engine(best_seeder, &default_engine)); @@ -44,12 +40,64 @@ typedef struct RuntimeContext { for (const auto &key : fftw_engines) { CAPI_ASSERT_ERROR(destroy_fftw_engine(key.second)); } + if (fbsk != nullptr) { + CAPI_ASSERT_ERROR(destroy_fftw_fourier_lwe_bootstrap_key_u64(fbsk)); + } } + FftwEngine *get_fftw_engine() { + pthread_t threadId = pthread_self(); + std::lock_guard guard(engines_map_guard); + auto engineIt = fftw_engines.find(threadId); + if (engineIt == fftw_engines.end()) { + FftwEngine *fftw_engine = nullptr; + + CAPI_ASSERT_ERROR(new_fftw_engine(&fftw_engine)); + + engineIt = + fftw_engines + .insert(std::pair(threadId, fftw_engine)) + .first; + } + assert(engineIt->second && "No engine available in context"); + return engineIt->second; + } + + DefaultEngine *get_default_engine() { return default_engine; } + + FftwFourierLweBootstrapKey64 *get_fftw_fourier_bsk() { + + if (fbsk != nullptr) { + return fbsk; + } + + const std::lock_guard guard(fbskMutex); + if (fbsk == nullptr) { + CAPI_ASSERT_ERROR( + fftw_engine_convert_lwe_bootstrap_key_to_fftw_fourier_lwe_bootstrap_key_u64( + get_fftw_engine(), evaluationKeys.getBsk(), &fbsk)); + } + return fbsk; + } + + LweBootstrapKey64 *get_bsk() { return evaluationKeys.getBsk(); } + + LweKeyswitchKey64 *get_ksk() { return evaluationKeys.getKsk(); } + RuntimeContext &operator=(const RuntimeContext &rhs) { this->evaluationKeys = rhs.evaluationKeys; return *this; } + + ::concretelang::clientlib::EvaluationKeys evaluationKeys; + +private: + std::mutex fbskMutex; + FftwFourierLweBootstrapKey64 *fbsk = nullptr; + DefaultEngine *default_engine; + std::map fftw_engines; + std::mutex engines_map_guard; + } RuntimeContext; } // namespace concretelang @@ -60,6 +108,9 @@ LweKeyswitchKey64 * get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context); FftwFourierLweBootstrapKey64 * +get_fftw_fourier_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context); + +LweBootstrapKey64 * get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context); DefaultEngine *get_engine(mlir::concretelang::RuntimeContext *context); diff --git a/compiler/include/concretelang/Runtime/key_manager.hpp b/compiler/include/concretelang/Runtime/key_manager.hpp index a3e648280..e80dc1773 100644 --- a/compiler/include/concretelang/Runtime/key_manager.hpp +++ b/compiler/include/concretelang/Runtime/key_manager.hpp @@ -64,18 +64,16 @@ KeyWrapper::KeyWrapper(LweKeyswitchKey64 *key) : key(key) { &buffer)); } template <> -KeyWrapper::KeyWrapper( - FftwFourierLweBootstrapKey64 *key) - : key(key) { +KeyWrapper::KeyWrapper(LweBootstrapKey64 *key) : key(key) { - FftwSerializationEngine *engine; + DefaultSerializationEngine *engine; - CAPI_ASSERT_ERROR(new_fftw_serialization_engine(&engine)); + CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); // No Freeing as it doesn't allocate anything. CAPI_ASSERT_ERROR( - fftw_serialization_engine_serialize_fftw_fourier_lwe_bootstrap_key_u64( - engine, key, &buffer)); + default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key, + &buffer)); } template @@ -86,25 +84,25 @@ bool operator==(const KeyWrapper &lhs, template <> template -void KeyWrapper::save( - Archive &ar, const unsigned int version) const { +void KeyWrapper::save(Archive &ar, + const unsigned int version) const { ar << buffer.length; ar << hpx::serialization::make_array(buffer.pointer, buffer.length); } template <> template -void KeyWrapper::load( - Archive &ar, const unsigned int version) { - FftwSerializationEngine *engine; +void KeyWrapper::load(Archive &ar, + const unsigned int version) { + DefaultSerializationEngine *engine; // No Freeing as it doesn't allocate anything. - CAPI_ASSERT_ERROR(new_fftw_serialization_engine(&engine)); + CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); ar >> buffer.length; buffer.pointer = new uint8_t[buffer.length]; ar >> hpx::serialization::make_array(buffer.pointer, buffer.length); CAPI_ASSERT_ERROR( - fftw_serialization_engine_deserialize_fftw_fourier_lwe_bootstrap_key_u64( + default_serialization_engine_deserialize_lwe_bootstrap_key_u64( engine, {buffer.pointer, buffer.length}, &key)); } @@ -155,21 +153,22 @@ struct RuntimeContextManager { if (_dfr_is_root_node()) { RuntimeContext *context = (RuntimeContext *)ctx; LweKeyswitchKey64 *ksk = get_keyswitch_key_u64(context); - FftwFourierLweBootstrapKey64 *bsk = get_bootstrap_key_u64(context); + LweBootstrapKey64 *bsk = get_bootstrap_key_u64(context); KeyWrapper kskw(ksk); - KeyWrapper bskw(bsk); + KeyWrapper bskw(bsk); hpx::collectives::broadcast_to("ksk_keystore", kskw); hpx::collectives::broadcast_to("bsk_keystore", bskw); } else { auto kskFut = hpx::collectives::broadcast_from>( "ksk_keystore"); - auto bskFut = hpx::collectives::broadcast_from< - KeyWrapper>("bsk_keystore"); + auto bskFut = + hpx::collectives::broadcast_from>( + "bsk_keystore"); KeyWrapper kskw = kskFut.get(); - KeyWrapper bskw = bskFut.get(); + KeyWrapper bskw = bskFut.get(); context = new mlir::concretelang::RuntimeContext(); context->evaluationKeys = ::concretelang::clientlib::EvaluationKeys( std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>( diff --git a/compiler/lib/ClientLib/KeySet.cpp b/compiler/lib/ClientLib/KeySet.cpp index 22f79d4bd..e02c7a946 100644 --- a/compiler/lib/ClientLib/KeySet.cpp +++ b/compiler/lib/ClientLib/KeySet.cpp @@ -197,18 +197,10 @@ KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param) { par_engine, inputSk->second.second, output_glwe_sk, param.baseLog, param.level, param.variance, &bsk)); - FftwFourierLweBootstrapKey64 *fbsk; - - CAPI_ASSERT_ERROR( - fftw_engine_convert_lwe_bootstrap_key_to_fftw_fourier_lwe_bootstrap_key_u64( - fftw_engine, bsk, &fbsk)); - - CAPI_ASSERT_ERROR(destroy_lwe_bootstrap_key_u64(bsk)); - CAPI_ASSERT_ERROR(destroy_glwe_secret_key_u64(output_glwe_sk)); // Store the bootstrap key - bootstrapKeys[id] = {param, std::make_shared(fbsk)}; + bootstrapKeys[id] = {param, std::make_shared(bsk)}; return outcome::success(); } diff --git a/compiler/lib/ClientLib/KeySetCache.cpp b/compiler/lib/ClientLib/KeySetCache.cpp index 4d7cd1123..82adf1481 100644 --- a/compiler/lib/ClientLib/KeySetCache.cpp +++ b/compiler/lib/ClientLib/KeySetCache.cpp @@ -75,17 +75,16 @@ loadKeyswitchKey(llvm::SmallString<0> &path) { engine); } -outcome::checked +outcome::checked loadBootstrapKey(llvm::SmallString<0> &path) { - FftwSerializationEngine *engine; + DefaultSerializationEngine *engine; - CAPI_ASSERT_ERROR(new_fftw_serialization_engine(&engine)); + CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - return load( - path, - fftw_serialization_engine_deserialize_fftw_fourier_lwe_bootstrap_key_u64, - engine); + return load(path, + default_serialization_engine_deserialize_lwe_bootstrap_key_u64, + engine); } void saveSecretKey(llvm::SmallString<0> &path, LweSecretKey64 *key) { @@ -103,18 +102,16 @@ void saveSecretKey(llvm::SmallString<0> &path, LweSecretKey64 *key) { free(buffer.pointer); } -void saveBootstrapKey(llvm::SmallString<0> &path, - FftwFourierLweBootstrapKey64 *key) { - FftwSerializationEngine *engine; +void saveBootstrapKey(llvm::SmallString<0> &path, LweBootstrapKey64 *key) { + DefaultSerializationEngine *engine; - CAPI_ASSERT_ERROR(new_fftw_serialization_engine(&engine)); + CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); Buffer buffer; CAPI_ASSERT_ERROR( - fftw_serialization_engine_serialize_fftw_fourier_lwe_bootstrap_key_u64( - engine, key, &buffer)); - + default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key, + &buffer)); writeFile(path, buffer); free(buffer.pointer); } @@ -166,7 +163,7 @@ KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb, auto param = bootstrapKeyParam.second; llvm::SmallString<0> path(folderPath); llvm::sys::path::append(path, "pbsKey_" + id); - OUTCOME_TRY(FftwFourierLweBootstrapKey64 * bsk, loadBootstrapKey(path)); + OUTCOME_TRY(LweBootstrapKey64 * bsk, loadBootstrapKey(path)); bootstrapKeys[id] = {param, std::make_shared(bsk)}; } // Load keyswitch keys diff --git a/compiler/lib/ClientLib/Serializers.cpp b/compiler/lib/ClientLib/Serializers.cpp index dafe977ea..a2746f858 100644 --- a/compiler/lib/ClientLib/Serializers.cpp +++ b/compiler/lib/ClientLib/Serializers.cpp @@ -60,6 +60,24 @@ std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey64 *key) { return ostream; } +std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey64 *key) { + DefaultSerializationEngine *engine; + + // No Freeing as it doesn't allocate anything. + CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); + + Buffer b; + + CAPI_ASSERT_ERROR( + default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key, + &b)) + + writeBufferLike(ostream, b); + free((void *)b.pointer); + b.pointer = nullptr; + return ostream; +} + std::ostream &operator<<(std::ostream &ostream, const FftwFourierLweBootstrapKey64 *key) { FftwSerializationEngine *engine; @@ -91,6 +109,18 @@ std::istream &operator>>(std::istream &istream, LweKeyswitchKey64 *&key) { return istream; } +std::istream &operator>>(std::istream &istream, LweBootstrapKey64 *&key) { + DefaultSerializationEngine *engine; + + // No Freeing as it doesn't allocate anything. + CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); + + key = read_deser( + istream, default_serialization_engine_deserialize_lwe_bootstrap_key_u64, + engine); + return istream; +} + std::istream &operator>>(std::istream &istream, FftwFourierLweBootstrapKey64 *&key) { FftwSerializationEngine *engine; diff --git a/compiler/lib/Runtime/context.cpp b/compiler/lib/Runtime/context.cpp index 1e9af538e..2377d59c7 100644 --- a/compiler/lib/Runtime/context.cpp +++ b/compiler/lib/Runtime/context.cpp @@ -11,32 +11,23 @@ LweKeyswitchKey64 * get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context) { - return context->evaluationKeys.getKsk(); + return context->get_ksk(); } -FftwFourierLweBootstrapKey64 * +LweBootstrapKey64 * get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context) { - return context->evaluationKeys.getBsk(); + return context->get_bsk(); +} + +FftwFourierLweBootstrapKey64 *get_fftw_fourier_bootstrap_key_u64( + mlir::concretelang::RuntimeContext *context) { + return context->get_fftw_fourier_bsk(); } DefaultEngine *get_engine(mlir::concretelang::RuntimeContext *context) { - return context->default_engine; + return context->get_default_engine(); } FftwEngine *get_fftw_engine(mlir::concretelang::RuntimeContext *context) { - pthread_t threadId = pthread_self(); - std::lock_guard guard(context->engines_map_guard); - auto engineIt = context->fftw_engines.find(threadId); - if (engineIt == context->fftw_engines.end()) { - FftwEngine *fftw_engine = nullptr; - - CAPI_ASSERT_ERROR(new_fftw_engine(&fftw_engine)); - - engineIt = - context->fftw_engines - .insert(std::pair(threadId, fftw_engine)) - .first; - } - assert(engineIt->second && "No engine available in context"); - return engineIt->second; + return context->get_fftw_engine(); } diff --git a/compiler/lib/Runtime/wrappers.cpp b/compiler/lib/Runtime/wrappers.cpp index ae2795ff1..002ad5651 100644 --- a/compiler/lib/Runtime/wrappers.cpp +++ b/compiler/lib/Runtime/wrappers.cpp @@ -159,7 +159,7 @@ void memref_bootstrap_lwe_u64( CAPI_ASSERT_ERROR( fftw_engine_lwe_ciphertext_discarding_bootstrap_u64_raw_ptr_buffers( get_fftw_engine(context), get_engine(context), - get_bootstrap_key_u64(context), out_aligned + out_offset, + get_fftw_fourier_bootstrap_key_u64(context), out_aligned + out_offset, ct0_aligned + ct0_offset, glwe_ct_aligned + glwe_ct_offset)); }