mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 04:35:03 -05:00
feat: keep std bsk and conv to fourier when needed
This commit is contained in:
@@ -18,7 +18,7 @@ HPX_TARBALL=$(shell pwd)/hpx-$(HPX_VERSION).tar.gz
|
||||
HPX_LOCAL_DIR=$(shell pwd)/hpx-$(HPX_VERSION)
|
||||
HPX_INSTALL_DIR?=$(HPX_LOCAL_DIR)
|
||||
|
||||
CONCRETE_CORE_FFI_VERSION?=0.2.0-rc.2
|
||||
CONCRETE_CORE_FFI_VERSION?=0.2.0-rc.3
|
||||
ifeq ($(shell uname), Linux)
|
||||
CONCRETE_CORE_FFI_TARBALL=concrete-core-ffi_$(CONCRETE_CORE_FFI_VERSION)_linux_amd64.tar.gz
|
||||
else
|
||||
|
||||
@@ -46,10 +46,10 @@ public:
|
||||
|
||||
// =============================================
|
||||
|
||||
/// Wrapper for `FftwFourierLweBootstrapKey64` so that it cleans up properly.
|
||||
/// Wrapper for `LweBootstrapKey64` so that it cleans up properly.
|
||||
class LweBootstrapKey {
|
||||
private:
|
||||
FftwFourierLweBootstrapKey64 *bsk;
|
||||
LweBootstrapKey64 *bsk;
|
||||
|
||||
protected:
|
||||
friend std::ostream &operator<<(std::ostream &ostream,
|
||||
@@ -58,19 +58,19 @@ protected:
|
||||
LweBootstrapKey &wrappedBsk);
|
||||
|
||||
public:
|
||||
LweBootstrapKey(FftwFourierLweBootstrapKey64 *bsk) : bsk{bsk} {}
|
||||
LweBootstrapKey(LweBootstrapKey64 *bsk) : bsk{bsk} {}
|
||||
LweBootstrapKey(LweBootstrapKey &other) = delete;
|
||||
LweBootstrapKey(LweBootstrapKey &&other) : bsk{other.bsk} {
|
||||
other.bsk = nullptr;
|
||||
}
|
||||
~LweBootstrapKey() {
|
||||
if (this->bsk != nullptr) {
|
||||
CAPI_ASSERT_ERROR(destroy_fftw_fourier_lwe_bootstrap_key_u64(this->bsk));
|
||||
CAPI_ASSERT_ERROR(destroy_lwe_bootstrap_key_u64(this->bsk));
|
||||
this->bsk = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
FftwFourierLweBootstrapKey64 *get() { return this->bsk; }
|
||||
LweBootstrapKey64 *get() { return this->bsk; }
|
||||
};
|
||||
|
||||
// =============================================
|
||||
@@ -97,7 +97,7 @@ public:
|
||||
: sharedKsk{sharedKsk}, sharedBsk{sharedBsk} {}
|
||||
|
||||
LweKeyswitchKey64 *getKsk() { return this->sharedKsk->get(); }
|
||||
FftwFourierLweBootstrapKey64 *getBsk() { return this->sharedBsk->get(); }
|
||||
LweBootstrapKey64 *getBsk() { return this->sharedBsk->get(); }
|
||||
};
|
||||
|
||||
// =============================================
|
||||
|
||||
@@ -21,10 +21,6 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
typedef struct RuntimeContext {
|
||||
::concretelang::clientlib::EvaluationKeys evaluationKeys;
|
||||
DefaultEngine *default_engine;
|
||||
std::map<pthread_t, FftwEngine *> fftw_engines;
|
||||
std::mutex engines_map_guard;
|
||||
|
||||
RuntimeContext() {
|
||||
CAPI_ASSERT_ERROR(new_default_engine(best_seeder, &default_engine));
|
||||
@@ -44,12 +40,64 @@ typedef struct RuntimeContext {
|
||||
for (const auto &key : fftw_engines) {
|
||||
CAPI_ASSERT_ERROR(destroy_fftw_engine(key.second));
|
||||
}
|
||||
if (fbsk != nullptr) {
|
||||
CAPI_ASSERT_ERROR(destroy_fftw_fourier_lwe_bootstrap_key_u64(fbsk));
|
||||
}
|
||||
}
|
||||
|
||||
FftwEngine *get_fftw_engine() {
|
||||
pthread_t threadId = pthread_self();
|
||||
std::lock_guard<std::mutex> guard(engines_map_guard);
|
||||
auto engineIt = fftw_engines.find(threadId);
|
||||
if (engineIt == fftw_engines.end()) {
|
||||
FftwEngine *fftw_engine = nullptr;
|
||||
|
||||
CAPI_ASSERT_ERROR(new_fftw_engine(&fftw_engine));
|
||||
|
||||
engineIt =
|
||||
fftw_engines
|
||||
.insert(std::pair<pthread_t, FftwEngine *>(threadId, fftw_engine))
|
||||
.first;
|
||||
}
|
||||
assert(engineIt->second && "No engine available in context");
|
||||
return engineIt->second;
|
||||
}
|
||||
|
||||
DefaultEngine *get_default_engine() { return default_engine; }
|
||||
|
||||
FftwFourierLweBootstrapKey64 *get_fftw_fourier_bsk() {
|
||||
|
||||
if (fbsk != nullptr) {
|
||||
return fbsk;
|
||||
}
|
||||
|
||||
const std::lock_guard<std::mutex> guard(fbskMutex);
|
||||
if (fbsk == nullptr) {
|
||||
CAPI_ASSERT_ERROR(
|
||||
fftw_engine_convert_lwe_bootstrap_key_to_fftw_fourier_lwe_bootstrap_key_u64(
|
||||
get_fftw_engine(), evaluationKeys.getBsk(), &fbsk));
|
||||
}
|
||||
return fbsk;
|
||||
}
|
||||
|
||||
LweBootstrapKey64 *get_bsk() { return evaluationKeys.getBsk(); }
|
||||
|
||||
LweKeyswitchKey64 *get_ksk() { return evaluationKeys.getKsk(); }
|
||||
|
||||
RuntimeContext &operator=(const RuntimeContext &rhs) {
|
||||
this->evaluationKeys = rhs.evaluationKeys;
|
||||
return *this;
|
||||
}
|
||||
|
||||
::concretelang::clientlib::EvaluationKeys evaluationKeys;
|
||||
|
||||
private:
|
||||
std::mutex fbskMutex;
|
||||
FftwFourierLweBootstrapKey64 *fbsk = nullptr;
|
||||
DefaultEngine *default_engine;
|
||||
std::map<pthread_t, FftwEngine *> fftw_engines;
|
||||
std::mutex engines_map_guard;
|
||||
|
||||
} RuntimeContext;
|
||||
|
||||
} // namespace concretelang
|
||||
@@ -60,6 +108,9 @@ LweKeyswitchKey64 *
|
||||
get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
FftwFourierLweBootstrapKey64 *
|
||||
get_fftw_fourier_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
LweBootstrapKey64 *
|
||||
get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
DefaultEngine *get_engine(mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
@@ -64,18 +64,16 @@ KeyWrapper<LweKeyswitchKey64>::KeyWrapper(LweKeyswitchKey64 *key) : key(key) {
|
||||
&buffer));
|
||||
}
|
||||
template <>
|
||||
KeyWrapper<FftwFourierLweBootstrapKey64>::KeyWrapper(
|
||||
FftwFourierLweBootstrapKey64 *key)
|
||||
: key(key) {
|
||||
KeyWrapper<LweBootstrapKey64>::KeyWrapper(LweBootstrapKey64 *key) : key(key) {
|
||||
|
||||
FftwSerializationEngine *engine;
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
CAPI_ASSERT_ERROR(new_fftw_serialization_engine(&engine));
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
// No Freeing as it doesn't allocate anything.
|
||||
CAPI_ASSERT_ERROR(
|
||||
fftw_serialization_engine_serialize_fftw_fourier_lwe_bootstrap_key_u64(
|
||||
engine, key, &buffer));
|
||||
default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key,
|
||||
&buffer));
|
||||
}
|
||||
|
||||
template <typename LweKeyType>
|
||||
@@ -86,25 +84,25 @@ bool operator==(const KeyWrapper<LweKeyType> &lhs,
|
||||
|
||||
template <>
|
||||
template <class Archive>
|
||||
void KeyWrapper<FftwFourierLweBootstrapKey64>::save(
|
||||
Archive &ar, const unsigned int version) const {
|
||||
void KeyWrapper<LweBootstrapKey64>::save(Archive &ar,
|
||||
const unsigned int version) const {
|
||||
ar << buffer.length;
|
||||
ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
|
||||
}
|
||||
template <>
|
||||
template <class Archive>
|
||||
void KeyWrapper<FftwFourierLweBootstrapKey64>::load(
|
||||
Archive &ar, const unsigned int version) {
|
||||
FftwSerializationEngine *engine;
|
||||
void KeyWrapper<LweBootstrapKey64>::load(Archive &ar,
|
||||
const unsigned int version) {
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
// No Freeing as it doesn't allocate anything.
|
||||
CAPI_ASSERT_ERROR(new_fftw_serialization_engine(&engine));
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
ar >> buffer.length;
|
||||
buffer.pointer = new uint8_t[buffer.length];
|
||||
ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
|
||||
CAPI_ASSERT_ERROR(
|
||||
fftw_serialization_engine_deserialize_fftw_fourier_lwe_bootstrap_key_u64(
|
||||
default_serialization_engine_deserialize_lwe_bootstrap_key_u64(
|
||||
engine, {buffer.pointer, buffer.length}, &key));
|
||||
}
|
||||
|
||||
@@ -155,21 +153,22 @@ struct RuntimeContextManager {
|
||||
if (_dfr_is_root_node()) {
|
||||
RuntimeContext *context = (RuntimeContext *)ctx;
|
||||
LweKeyswitchKey64 *ksk = get_keyswitch_key_u64(context);
|
||||
FftwFourierLweBootstrapKey64 *bsk = get_bootstrap_key_u64(context);
|
||||
LweBootstrapKey64 *bsk = get_bootstrap_key_u64(context);
|
||||
|
||||
KeyWrapper<LweKeyswitchKey64> kskw(ksk);
|
||||
KeyWrapper<FftwFourierLweBootstrapKey64> bskw(bsk);
|
||||
KeyWrapper<LweBootstrapKey64> bskw(bsk);
|
||||
hpx::collectives::broadcast_to("ksk_keystore", kskw);
|
||||
hpx::collectives::broadcast_to("bsk_keystore", bskw);
|
||||
} else {
|
||||
auto kskFut =
|
||||
hpx::collectives::broadcast_from<KeyWrapper<LweKeyswitchKey64>>(
|
||||
"ksk_keystore");
|
||||
auto bskFut = hpx::collectives::broadcast_from<
|
||||
KeyWrapper<FftwFourierLweBootstrapKey64>>("bsk_keystore");
|
||||
auto bskFut =
|
||||
hpx::collectives::broadcast_from<KeyWrapper<LweBootstrapKey64>>(
|
||||
"bsk_keystore");
|
||||
|
||||
KeyWrapper<LweKeyswitchKey64> kskw = kskFut.get();
|
||||
KeyWrapper<FftwFourierLweBootstrapKey64> bskw = bskFut.get();
|
||||
KeyWrapper<LweBootstrapKey64> bskw = bskFut.get();
|
||||
context = new mlir::concretelang::RuntimeContext();
|
||||
context->evaluationKeys = ::concretelang::clientlib::EvaluationKeys(
|
||||
std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>(
|
||||
|
||||
@@ -197,18 +197,10 @@ KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param) {
|
||||
par_engine, inputSk->second.second, output_glwe_sk, param.baseLog,
|
||||
param.level, param.variance, &bsk));
|
||||
|
||||
FftwFourierLweBootstrapKey64 *fbsk;
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
fftw_engine_convert_lwe_bootstrap_key_to_fftw_fourier_lwe_bootstrap_key_u64(
|
||||
fftw_engine, bsk, &fbsk));
|
||||
|
||||
CAPI_ASSERT_ERROR(destroy_lwe_bootstrap_key_u64(bsk));
|
||||
|
||||
CAPI_ASSERT_ERROR(destroy_glwe_secret_key_u64(output_glwe_sk));
|
||||
|
||||
// Store the bootstrap key
|
||||
bootstrapKeys[id] = {param, std::make_shared<LweBootstrapKey>(fbsk)};
|
||||
bootstrapKeys[id] = {param, std::make_shared<LweBootstrapKey>(bsk)};
|
||||
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
@@ -75,17 +75,16 @@ loadKeyswitchKey(llvm::SmallString<0> &path) {
|
||||
engine);
|
||||
}
|
||||
|
||||
outcome::checked<FftwFourierLweBootstrapKey64 *, StringError>
|
||||
outcome::checked<LweBootstrapKey64 *, StringError>
|
||||
loadBootstrapKey(llvm::SmallString<0> &path) {
|
||||
|
||||
FftwSerializationEngine *engine;
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
CAPI_ASSERT_ERROR(new_fftw_serialization_engine(&engine));
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
return load(
|
||||
path,
|
||||
fftw_serialization_engine_deserialize_fftw_fourier_lwe_bootstrap_key_u64,
|
||||
engine);
|
||||
return load(path,
|
||||
default_serialization_engine_deserialize_lwe_bootstrap_key_u64,
|
||||
engine);
|
||||
}
|
||||
|
||||
void saveSecretKey(llvm::SmallString<0> &path, LweSecretKey64 *key) {
|
||||
@@ -103,18 +102,16 @@ void saveSecretKey(llvm::SmallString<0> &path, LweSecretKey64 *key) {
|
||||
free(buffer.pointer);
|
||||
}
|
||||
|
||||
void saveBootstrapKey(llvm::SmallString<0> &path,
|
||||
FftwFourierLweBootstrapKey64 *key) {
|
||||
FftwSerializationEngine *engine;
|
||||
void saveBootstrapKey(llvm::SmallString<0> &path, LweBootstrapKey64 *key) {
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
CAPI_ASSERT_ERROR(new_fftw_serialization_engine(&engine));
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
Buffer buffer;
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
fftw_serialization_engine_serialize_fftw_fourier_lwe_bootstrap_key_u64(
|
||||
engine, key, &buffer));
|
||||
|
||||
default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key,
|
||||
&buffer));
|
||||
writeFile(path, buffer);
|
||||
free(buffer.pointer);
|
||||
}
|
||||
@@ -166,7 +163,7 @@ KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
auto param = bootstrapKeyParam.second;
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "pbsKey_" + id);
|
||||
OUTCOME_TRY(FftwFourierLweBootstrapKey64 * bsk, loadBootstrapKey(path));
|
||||
OUTCOME_TRY(LweBootstrapKey64 * bsk, loadBootstrapKey(path));
|
||||
bootstrapKeys[id] = {param, std::make_shared<LweBootstrapKey>(bsk)};
|
||||
}
|
||||
// Load keyswitch keys
|
||||
|
||||
@@ -60,6 +60,24 @@ std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey64 *key) {
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey64 *key) {
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
// No Freeing as it doesn't allocate anything.
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
Buffer b;
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key,
|
||||
&b))
|
||||
|
||||
writeBufferLike(ostream, b);
|
||||
free((void *)b.pointer);
|
||||
b.pointer = nullptr;
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const FftwFourierLweBootstrapKey64 *key) {
|
||||
FftwSerializationEngine *engine;
|
||||
@@ -91,6 +109,18 @@ std::istream &operator>>(std::istream &istream, LweKeyswitchKey64 *&key) {
|
||||
return istream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream, LweBootstrapKey64 *&key) {
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
// No Freeing as it doesn't allocate anything.
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
key = read_deser(
|
||||
istream, default_serialization_engine_deserialize_lwe_bootstrap_key_u64,
|
||||
engine);
|
||||
return istream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream,
|
||||
FftwFourierLweBootstrapKey64 *&key) {
|
||||
FftwSerializationEngine *engine;
|
||||
|
||||
@@ -11,32 +11,23 @@
|
||||
|
||||
LweKeyswitchKey64 *
|
||||
get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context) {
|
||||
return context->evaluationKeys.getKsk();
|
||||
return context->get_ksk();
|
||||
}
|
||||
|
||||
FftwFourierLweBootstrapKey64 *
|
||||
LweBootstrapKey64 *
|
||||
get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context) {
|
||||
return context->evaluationKeys.getBsk();
|
||||
return context->get_bsk();
|
||||
}
|
||||
|
||||
FftwFourierLweBootstrapKey64 *get_fftw_fourier_bootstrap_key_u64(
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
return context->get_fftw_fourier_bsk();
|
||||
}
|
||||
|
||||
DefaultEngine *get_engine(mlir::concretelang::RuntimeContext *context) {
|
||||
return context->default_engine;
|
||||
return context->get_default_engine();
|
||||
}
|
||||
|
||||
FftwEngine *get_fftw_engine(mlir::concretelang::RuntimeContext *context) {
|
||||
pthread_t threadId = pthread_self();
|
||||
std::lock_guard<std::mutex> guard(context->engines_map_guard);
|
||||
auto engineIt = context->fftw_engines.find(threadId);
|
||||
if (engineIt == context->fftw_engines.end()) {
|
||||
FftwEngine *fftw_engine = nullptr;
|
||||
|
||||
CAPI_ASSERT_ERROR(new_fftw_engine(&fftw_engine));
|
||||
|
||||
engineIt =
|
||||
context->fftw_engines
|
||||
.insert(std::pair<pthread_t, FftwEngine *>(threadId, fftw_engine))
|
||||
.first;
|
||||
}
|
||||
assert(engineIt->second && "No engine available in context");
|
||||
return engineIt->second;
|
||||
return context->get_fftw_engine();
|
||||
}
|
||||
|
||||
@@ -159,7 +159,7 @@ void memref_bootstrap_lwe_u64(
|
||||
CAPI_ASSERT_ERROR(
|
||||
fftw_engine_lwe_ciphertext_discarding_bootstrap_u64_raw_ptr_buffers(
|
||||
get_fftw_engine(context), get_engine(context),
|
||||
get_bootstrap_key_u64(context), out_aligned + out_offset,
|
||||
get_fftw_fourier_bootstrap_key_u64(context), out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset, glwe_ct_aligned + glwe_ct_offset));
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user