refactor: Integrate concrete-cpu and remove concrete-core

Co-authored-by: Mayeul@Zama <mayeul.debellabre@zama.ai>
This commit is contained in:
Quentin Bourgerie
2023-02-03 15:54:51 +01:00
parent 9c784a2243
commit a62b3b1d74
53 changed files with 1502 additions and 1868 deletions

View File

@@ -92,8 +92,7 @@ public:
OUTCOME_TRY(auto clientArguments,
EncryptedArguments::create(keySet, args...));
return clientArguments->exportPublicArguments(clientParameters,
keySet.runtimeContext());
return clientArguments->exportPublicArguments(clientParameters);
}
outcome::checked<Result, StringError> decryptResult(KeySet &keySet,

View File

@@ -35,10 +35,8 @@ namespace clientlib {
using concretelang::error::StringError;
const std::string SMALL_KEY = "small";
const std::string BIG_KEY = "big";
const std::string BOOTSTRAP_KEY = "bsk_v0";
const std::string KEYSWITCH_KEY = "ksk_v0";
const uint64_t SMALL_KEY = 1;
const uint64_t BIG_KEY = 0;
const std::string CLIENT_PARAMETERS_EXT = ".concrete.params.json";
@@ -52,7 +50,7 @@ typedef std::vector<int64_t> CRTDecomposition;
typedef uint64_t LweDimension;
typedef uint64_t GlweDimension;
typedef std::string LweSecretKeyID;
typedef uint64_t LweSecretKeyID;
struct LweSecretKeyParam {
LweDimension dimension;
@@ -66,7 +64,7 @@ static bool operator==(const LweSecretKeyParam &lhs,
return lhs.dimension == rhs.dimension;
}
typedef std::string BootstrapKeyID;
typedef uint64_t BootstrapKeyID;
struct BootstrapKeyParam {
LweSecretKeyID inputSecretKeyID;
LweSecretKeyID outputSecretKeyID;
@@ -74,6 +72,8 @@ struct BootstrapKeyParam {
DecompositionBaseLog baseLog;
GlweDimension glweDimension;
Variance variance;
PolynomialSize polynomialSize;
LweDimension inputLweDimension;
void hash(size_t &seed);
@@ -90,7 +90,7 @@ static inline bool operator==(const BootstrapKeyParam &lhs,
lhs.glweDimension == rhs.glweDimension && lhs.variance == rhs.variance;
}
typedef std::string KeyswitchKeyID;
typedef uint64_t KeyswitchKeyID;
struct KeyswitchKeyParam {
LweSecretKeyID inputSecretKeyID;
LweSecretKeyID outputSecretKeyID;
@@ -112,22 +112,28 @@ static inline bool operator==(const KeyswitchKeyParam &lhs,
lhs.variance == rhs.variance;
}
typedef std::string PackingKeySwitchID;
struct PackingKeySwitchParam {
typedef uint64_t PackingKeyswitchKeyID;
struct PackingKeyswitchKeyParam {
LweSecretKeyID inputSecretKeyID;
LweSecretKeyID outputSecretKeyID;
BootstrapKeyID bootstrapKeyID;
DecompositionLevelCount level;
DecompositionBaseLog baseLog;
GlweDimension glweDimension;
PolynomialSize polynomialSize;
LweDimension inputLweDimension;
Variance variance;
void hash(size_t &seed);
};
static inline bool operator==(const PackingKeySwitchParam &lhs,
const PackingKeySwitchParam &rhs) {
static inline bool operator==(const PackingKeyswitchKeyParam &lhs,
const PackingKeyswitchKeyParam &rhs) {
return lhs.inputSecretKeyID == rhs.inputSecretKeyID &&
lhs.outputSecretKeyID == rhs.outputSecretKeyID &&
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog;
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog &&
lhs.glweDimension == rhs.glweDimension &&
lhs.polynomialSize == rhs.polynomialSize &&
lhs.variance == lhs.variance &&
lhs.inputLweDimension == rhs.inputLweDimension;
}
struct Encoding {
@@ -185,13 +191,13 @@ struct CircuitGate {
bool isEncrypted() { return encryption.hasValue(); }
/// byteSize returns the size in bytes for this gate.
size_t byteSize(std::map<LweSecretKeyID, LweSecretKeyParam> secretKeys) {
size_t byteSize(std::vector<LweSecretKeyParam> secretKeys) {
auto width = shape.width;
auto numElts = shape.size == 0 ? 1 : shape.size;
if (isEncrypted()) {
auto skParam = secretKeys.find(encryption->secretKeyID);
assert(skParam != secretKeys.end());
return 8 * skParam->second.lweSize() * numElts;
assert(encryption->secretKeyID < secretKeys.size());
auto skParam = secretKeys[encryption->secretKeyID];
return 8 * skParam.lweSize() * numElts;
}
width = bitWidthAsWord(width) / 8;
return width * numElts;
@@ -203,10 +209,10 @@ static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) {
}
struct ClientParameters {
std::map<LweSecretKeyID, LweSecretKeyParam> secretKeys;
std::map<BootstrapKeyID, BootstrapKeyParam> bootstrapKeys;
std::map<KeyswitchKeyID, KeyswitchKeyParam> keyswitchKeys;
std::map<PackingKeySwitchID, PackingKeySwitchParam> packingKeys;
std::vector<LweSecretKeyParam> secretKeys;
std::vector<BootstrapKeyParam> bootstrapKeys;
std::vector<KeyswitchKeyParam> keyswitchKeys;
std::vector<PackingKeyswitchKeyParam> packingKeyswitchKeys;
std::vector<CircuitGate> inputs;
std::vector<CircuitGate> outputs;
std::string functionName;
@@ -237,12 +243,9 @@ struct ClientParameters {
if (!gate.encryption.hasValue()) {
return StringError("gate is not encrypted");
}
auto secretKey = secretKeys.find(gate.encryption->secretKeyID);
if (secretKey == secretKeys.end()) {
return StringError("cannot find ")
<< gate.encryption->secretKeyID << " in client parameters";
}
return secretKey->second;
assert(gate.encryption->secretKeyID < secretKeys.size());
auto secretKey = secretKeys[gate.encryption->secretKeyID];
return secretKey;
}
/// bufferSize returns the size of the whole buffer of a gate.
@@ -309,8 +312,8 @@ bool fromJSON(const llvm::json::Value, BootstrapKeyParam &, llvm::json::Path);
llvm::json::Value toJSON(const KeyswitchKeyParam &);
bool fromJSON(const llvm::json::Value, KeyswitchKeyParam &, llvm::json::Path);
llvm::json::Value toJSON(const PackingKeySwitchParam &);
bool fromJSON(const llvm::json::Value, PackingKeySwitchParam &,
llvm::json::Value toJSON(const PackingKeyswitchKeyParam &);
bool fromJSON(const llvm::json::Value, PackingKeyswitchKeyParam &,
llvm::json::Path);
llvm::json::Value toJSON(const Encoding &);

View File

@@ -61,8 +61,7 @@ public:
/// arguments, i.e. move all buffers to the PublicArguments and reset the
/// positional counter.
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
exportPublicArguments(ClientParameters clientParameters,
RuntimeContext runtimeContext);
exportPublicArguments(ClientParameters clientParameters);
/// Check that all arguments as been pushed.
// TODO: Remove public method here

View File

@@ -6,113 +6,143 @@
#ifndef CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_
#define CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_
#include <cassert>
#include <memory>
#include <vector>
#include "concrete-core-ffi.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/Common/Error.h"
typedef struct LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64;
int destroy_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64(
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *);
struct Csprng;
struct CsprngVtable;
namespace concretelang {
namespace clientlib {
// =============================================
class CSPRNG {
public:
struct Csprng *ptr;
const struct CsprngVtable *vtable;
/// Wrapper for `LweKeyswitchKey64` so that it cleans up properly.
CSPRNG() = delete;
CSPRNG(CSPRNG &) = delete;
CSPRNG(CSPRNG &&other) : ptr(other.ptr), vtable(other.vtable) {
assert(ptr != nullptr);
other.ptr = nullptr;
};
CSPRNG(Csprng *ptr, const CsprngVtable *vtable) : ptr(ptr), vtable(vtable){};
};
class ConcreteCSPRNG : public CSPRNG {
public:
ConcreteCSPRNG(__uint128_t seed);
ConcreteCSPRNG() = delete;
ConcreteCSPRNG(ConcreteCSPRNG &) = delete;
ConcreteCSPRNG(ConcreteCSPRNG &&other);
~ConcreteCSPRNG();
};
/// @brief LweSecretKey implements tools for manipulating lwe secret key on
/// client.
class LweSecretKey {
std::shared_ptr<std::vector<uint64_t>> _buffer;
LweSecretKeyParam _parameters;
public:
LweSecretKey() = delete;
LweSecretKey(LweSecretKeyParam &parameters, CSPRNG &csprng);
LweSecretKey(std::shared_ptr<std::vector<uint64_t>> buffer,
LweSecretKeyParam parameters)
: _buffer(buffer), _parameters(parameters){};
/// @brief Encrypt the plaintext to the lwe ciphertext buffer.
void encrypt(uint64_t *ciphertext, uint64_t plaintext, double variance,
CSPRNG &csprng) const;
/// @brief Decrypt the ciphertext to the plaintext
void decrypt(const uint64_t *ciphertext, uint64_t &plaintext) const;
/// @brief Returns the buffer that hold the keyswitch key.
const uint64_t *buffer() const { return _buffer->data(); }
size_t size() const { return _buffer->size(); }
/// @brief Returns the parameters of the keyswicth key.
LweSecretKeyParam parameters() const { return this->_parameters; }
/// @brief Returns the lwe dimension of the secret key.
size_t dimension() const { return parameters().dimension; }
};
/// @brief LweKeyswitchKey implements tools for manipulating keyswitch key on
/// client.
class LweKeyswitchKey {
private:
LweKeyswitchKey64 *ksk;
protected:
friend std::ostream &operator<<(std::ostream &ostream,
const LweKeyswitchKey &wrappedKsk);
friend std::istream &operator>>(std::istream &istream,
LweKeyswitchKey &wrappedKsk);
std::shared_ptr<std::vector<uint64_t>> _buffer;
KeyswitchKeyParam _parameters;
public:
LweKeyswitchKey(LweKeyswitchKey64 *ksk) : ksk{ksk} {}
LweKeyswitchKey(LweKeyswitchKey &other) = delete;
LweKeyswitchKey(LweKeyswitchKey &&other) : ksk{other.ksk} {
other.ksk = nullptr;
}
~LweKeyswitchKey() {
if (this->ksk != nullptr) {
CAPI_ASSERT_ERROR(destroy_lwe_keyswitch_key_u64(this->ksk));
LweKeyswitchKey() = delete;
LweKeyswitchKey(KeyswitchKeyParam &parameters, LweSecretKey &inputKey,
LweSecretKey &outputKey, CSPRNG &csprng);
LweKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
KeyswitchKeyParam parameters)
: _buffer(buffer), _parameters(parameters){};
this->ksk = nullptr;
}
}
/// @brief Returns the buffer that hold the keyswitch key.
const uint64_t *buffer() const { return _buffer->data(); }
size_t size() const { return _buffer->size(); }
LweKeyswitchKey64 *get() { return this->ksk; }
/// @brief Returns the parameters of the keyswicth key.
KeyswitchKeyParam parameters() const { return this->_parameters; }
};
// =============================================
/// Wrapper for `LweBootstrapKey64` so that it cleans up properly.
/// @brief LweBootstrapKey implements tools for manipulating bootstrap key on
/// client.
class LweBootstrapKey {
private:
LweBootstrapKey64 *bsk;
protected:
friend std::ostream &operator<<(std::ostream &ostream,
const LweBootstrapKey &wrappedBsk);
friend std::istream &operator>>(std::istream &istream,
LweBootstrapKey &wrappedBsk);
std::shared_ptr<std::vector<uint64_t>> _buffer;
BootstrapKeyParam _parameters;
public:
LweBootstrapKey(LweBootstrapKey64 *bsk) : bsk{bsk} {}
LweBootstrapKey(LweBootstrapKey &other) = delete;
LweBootstrapKey(LweBootstrapKey &&other) : bsk{other.bsk} {
other.bsk = nullptr;
}
~LweBootstrapKey() {
if (this->bsk != nullptr) {
CAPI_ASSERT_ERROR(destroy_lwe_bootstrap_key_u64(this->bsk));
this->bsk = nullptr;
}
}
LweBootstrapKey() = delete;
LweBootstrapKey(std::shared_ptr<std::vector<uint64_t>> buffer,
BootstrapKeyParam &parameters)
: _buffer(buffer), _parameters(parameters){};
LweBootstrapKey(BootstrapKeyParam &parameters, LweSecretKey &inputKey,
LweSecretKey &outputKey, CSPRNG &csprng);
LweBootstrapKey64 *get() { return this->bsk; }
///// @brief Returns the buffer that hold the bootstrap key.
const uint64_t *buffer() const { return _buffer->data(); }
size_t size() const { return _buffer->size(); }
/// @brief Returns the parameters of the bootsrap key.
BootstrapKeyParam parameters() const { return this->_parameters; }
};
// =============================================
/// Wrapper for `LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64` so
/// that it cleans up properly.
/// @brief PackingKeyswitchKey implements tools for manipulating privat packing
/// keyswitch key on client.
class PackingKeyswitchKey {
private:
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *key;
protected:
friend std::ostream &operator<<(std::ostream &ostream,
const PackingKeyswitchKey &wrappedFpksk);
friend std::istream &operator>>(std::istream &istream,
PackingKeyswitchKey &wrappedFpksk);
std::shared_ptr<std::vector<uint64_t>> _buffer;
PackingKeyswitchKeyParam _parameters;
public:
PackingKeyswitchKey(
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *key)
: key{key} {}
PackingKeyswitchKey(PackingKeyswitchKey &other) = delete;
PackingKeyswitchKey(PackingKeyswitchKey &&other) : key{other.key} {
other.key = nullptr;
}
~PackingKeyswitchKey() {
if (this->key != nullptr) {
CAPI_ASSERT_ERROR(
destroy_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64(
this->key));
this->key = nullptr;
}
}
PackingKeyswitchKey() = delete;
PackingKeyswitchKey(PackingKeyswitchKeyParam &parameters,
LweSecretKey &inputKey, LweSecretKey &outputKey,
CSPRNG &csprng);
PackingKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
PackingKeyswitchKeyParam parameters)
: _buffer(buffer), _parameters(parameters){};
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *get() {
return this->key;
}
/// @brief Returns the buffer that hold the keyswitch key.
const uint64_t *buffer() const { return _buffer->data(); }
size_t size() const { return _buffer->size(); }
/// @brief Returns the parameters of the keyswicth key.
PackingKeyswitchKeyParam parameters() const { return this->_parameters; }
};
// =============================================
@@ -120,31 +150,40 @@ public:
/// Evalution keys required for execution.
class EvaluationKeys {
private:
std::shared_ptr<LweKeyswitchKey> sharedKsk;
std::shared_ptr<LweBootstrapKey> sharedBsk;
std::shared_ptr<PackingKeyswitchKey> sharedFpksk;
protected:
friend std::ostream &operator<<(std::ostream &ostream,
const EvaluationKeys &evaluationKeys);
friend std::istream &operator>>(std::istream &istream,
EvaluationKeys &evaluationKeys);
std::vector<LweKeyswitchKey> keyswitchKeys;
std::vector<LweBootstrapKey> bootstrapKeys;
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
public:
EvaluationKeys()
: sharedKsk{std::shared_ptr<LweKeyswitchKey>(nullptr)},
sharedBsk{std::shared_ptr<LweBootstrapKey>(nullptr)} {}
EvaluationKeys() = delete;
EvaluationKeys(std::shared_ptr<LweKeyswitchKey> sharedKsk,
std::shared_ptr<LweBootstrapKey> sharedBsk,
std::shared_ptr<PackingKeyswitchKey> sharedFpksk)
: sharedKsk{sharedKsk}, sharedBsk{sharedBsk}, sharedFpksk{sharedFpksk} {}
EvaluationKeys(const std::vector<LweKeyswitchKey> keyswitchKeys,
const std::vector<LweBootstrapKey> bootstrapKeys,
const std::vector<PackingKeyswitchKey> packingKeyswitchKeys)
: keyswitchKeys(keyswitchKeys), bootstrapKeys(bootstrapKeys),
packingKeyswitchKeys(packingKeyswitchKeys) {}
LweKeyswitchKey64 *getKsk() { return this->sharedKsk->get(); }
LweBootstrapKey64 *getBsk() { return this->sharedBsk->get(); }
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *getFpksk() {
return this->sharedFpksk->get();
const LweKeyswitchKey &getKeyswitchKey(size_t id) const {
return this->keyswitchKeys[id];
}
const std::vector<LweKeyswitchKey> getKeyswitchKeys() const {
return this->keyswitchKeys;
}
const LweBootstrapKey &getBootstrapKey(size_t id) const {
return bootstrapKeys[id];
}
const std::vector<LweBootstrapKey> getBootstrapKeys() const {
return this->bootstrapKeys;
}
const PackingKeyswitchKey &getPackingKeyswitchKey(size_t id) const {
return this->packingKeyswitchKeys[id];
};
const std::vector<PackingKeyswitchKey> getPackingKeyswitchKeys() const {
return this->packingKeyswitchKeys;
}
};
// =============================================

View File

@@ -10,30 +10,33 @@
#include "boost/outcome.h"
#include "concrete-core-ffi.h"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/context.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Runtime/DFRuntime.hpp"
namespace concretelang {
namespace clientlib {
using concretelang::error::StringError;
using RuntimeContext = mlir::concretelang::RuntimeContext;
class KeySet {
public:
KeySet();
~KeySet();
KeySet(ClientParameters clientParameters, CSPRNG &&csprng)
: csprng(std::move(csprng)), _clientParameters(clientParameters){};
KeySet(KeySet &other) = delete;
/// allocate a KeySet according the ClientParameters.
/// Generate a KeySet from a ClientParameters specification.
static outcome::checked<std::unique_ptr<KeySet>, StringError>
generate(ClientParameters &params, uint64_t seed_msb, uint64_t seed_lsb);
generate(ClientParameters clientParameters, CSPRNG &&csprng);
/// Create a KeySet from a set of given keys
static outcome::checked<std::unique_ptr<KeySet>, StringError> fromKeys(
ClientParameters clientParameters, std::vector<LweSecretKey> secretKeys,
std::vector<LweBootstrapKey> bootstrapKeys,
std::vector<LweKeyswitchKey> keyswitchKeys,
std::vector<PackingKeyswitchKey> packingKeyswitchKeys, CSPRNG &&csprng);
/// Returns the ClientParameters associated with the KeySet.
ClientParameters clientParameters() { return _clientParameters; }
@@ -41,23 +44,6 @@ public:
// isInputEncrypted return true if the input at the given pos is encrypted.
bool isInputEncrypted(size_t pos);
/// getInputLweSecretKeyParam returns the parameters of the lwe secret key for
/// the input at the given `pos`.
/// The input must be encrupted
LweSecretKeyParam getInputLweSecretKeyParam(size_t pos) {
auto gate = inputGate(pos);
auto inputSk = this->secretKeys.find(gate.encryption->secretKeyID);
return inputSk->second.first;
}
/// getOutputLweSecretKeyParam returns the parameters of the lwe secret key
/// for the given output.
LweSecretKeyParam getOutputLweSecretKeyParam(size_t pos) {
auto gate = outputGate(pos);
auto outputSk = this->secretKeys.find(gate.encryption->secretKeyID);
return outputSk->second.first;
}
/// allocate a lwe ciphertext buffer for the argument at argPos, set the size
/// of the allocated buffer.
outcome::checked<void, StringError>
@@ -80,103 +66,58 @@ public:
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
RuntimeContext runtimeContext() {
RuntimeContext context;
context.evaluationKeys = this->evaluationKeys();
return context;
}
/// @brief evaluationKeys returns the evaluation keys associate to this client
/// keyset. Those evaluations keys can be safely shared publicly
EvaluationKeys evaluationKeys();
EvaluationKeys evaluationKeys() {
if (this->bootstrapKeys.empty() && this->keyswitchKeys.empty()) {
return EvaluationKeys();
}
auto kskIt = this->keyswitchKeys.find(clientlib::KEYSWITCH_KEY);
auto bskIt = this->bootstrapKeys.find(clientlib::BOOTSTRAP_KEY);
auto fpkskIt = this->packingKeys.find("fpksk_v0");
if (kskIt != this->keyswitchKeys.end() &&
bskIt != this->bootstrapKeys.end()) {
auto sharedKsk = std::get<1>(kskIt->second);
auto sharedBsk = std::get<1>(bskIt->second);
auto sharedFpksk = fpkskIt == this->packingKeys.end()
? std::make_shared<PackingKeyswitchKey>(nullptr)
: std::get<1>(fpkskIt->second);
return EvaluationKeys(sharedKsk, sharedBsk, sharedFpksk);
}
assert(!mlir::concretelang::dfr::_dfr_is_root_node() &&
"Evaluation keys missing in KeySet (on root node).");
return EvaluationKeys();
}
const std::vector<LweSecretKey> &getSecretKeys();
const std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey64 *>>
&getSecretKeys();
const std::vector<LweBootstrapKey> &getBootstrapKeys();
const std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
&getBootstrapKeys();
const std::vector<LweKeyswitchKey> &getKeyswitchKeys();
const std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
&getKeyswitchKeys();
const std::map<
LweSecretKeyID,
std::pair<PackingKeySwitchParam, std::shared_ptr<PackingKeyswitchKey>>> &
getPackingKeys();
const std::vector<PackingKeyswitchKey> &getPackingKeyswitchKeys();
protected:
outcome::checked<void, StringError>
generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param);
generateSecretKey(LweSecretKeyParam param);
outcome::checked<void, StringError>
generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param);
generateBootstrapKey(BootstrapKeyParam param);
outcome::checked<void, StringError>
generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param);
generateKeyswitchKey(KeyswitchKeyParam param);
outcome::checked<void, StringError>
generatePackingKey(PackingKeySwitchID id, PackingKeySwitchParam param);
generatePackingKeyswitchKey(PackingKeyswitchKeyParam param);
outcome::checked<void, StringError>
generateKeysFromParams(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb);
outcome::checked<void, StringError> generateKeysFromParams();
outcome::checked<void, StringError>
setupEncryptionMaterial(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb);
outcome::checked<void, StringError> setupEncryptionMaterial();
friend class KeySetCache;
private:
DefaultEngine *engine;
DefaultParallelEngine *par_engine;
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey64 *>>
secretKeys;
std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
bootstrapKeys;
std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
keyswitchKeys;
std::map<LweSecretKeyID, std::pair<PackingKeySwitchParam,
std::shared_ptr<PackingKeyswitchKey>>>
packingKeys;
std::vector<std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey64 *>>
inputs;
std::vector<std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey64 *>>
outputs;
CSPRNG csprng;
void setKeys(
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey64 *>>
secretKeys,
std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
bootstrapKeys,
std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
keyswitchKeys,
std::map<LweSecretKeyID, std::pair<PackingKeySwitchParam,
std::shared_ptr<PackingKeyswitchKey>>>
packingKeys);
///////////////////////////////////////////////
// Keys mappings
std::vector<LweSecretKey> secretKeys;
std::vector<LweBootstrapKey> bootstrapKeys;
std::vector<LweKeyswitchKey> keyswitchKeys;
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
outcome::checked<LweSecretKey, StringError> findLweSecretKey(LweSecretKeyID);
///////////////////////////////////////////////
// Convenient positional mapping between positional gate en secret key
typedef std::vector<std::pair<CircuitGate, llvm::Optional<LweSecretKey>>>
SecretKeyGateMapping;
outcome::checked<SecretKeyGateMapping, StringError>
mapCircuitGateLweSecretKey(std::vector<CircuitGate> gates);
SecretKeyGateMapping inputs;
SecretKeyGateMapping outputs;
clientlib::ClientParameters _clientParameters;
};

View File

@@ -14,7 +14,6 @@
#include "concretelang/ClientLib/EncryptedArguments.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Runtime/context.h"
namespace concretelang {
namespace serverlib {

View File

@@ -12,13 +12,10 @@
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/Runtime/context.h"
namespace concretelang {
namespace clientlib {
using RuntimeContext = mlir::concretelang::RuntimeContext;
// integers are not serialized as binary values even on a binary stream
// so we cannot rely on << operator directly
template <typename Word>
@@ -62,10 +59,6 @@ template <typename Stream> bool incorrectMode(Stream &stream) {
return !binary;
}
std::ostream &operator<<(std::ostream &ostream,
const RuntimeContext &runtimeContext);
std::istream &operator>>(std::istream &istream, RuntimeContext &runtimeContext);
std::ostream &serializeScalarData(const ScalarData &sd, std::ostream &ostream);
outcome::checked<ScalarData, StringError>
@@ -105,17 +98,24 @@ outcome::checked<ScalarOrTensorData, StringError>
unserializeScalarOrTensorData(const std::vector<int64_t> &expectedSizes,
std::istream &istream);
std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &wrappedKsk);
LweSecretKey readLweSecretKey(std::istream &istream);
std::ostream &operator<<(std::ostream &ostream,
const LweKeyswitchKey &wrappedKsk);
std::istream &operator>>(std::istream &istream, LweKeyswitchKey &wrappedKsk);
LweKeyswitchKey readLweKeyswitchKey(std::istream &istream);
std::ostream &operator<<(std::ostream &ostream,
const LweBootstrapKey &wrappedBsk);
std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk);
LweBootstrapKey readLweBootstrapKey(std::istream &istream);
std::ostream &operator<<(std::ostream &ostream,
const PackingKeyswitchKey &wrappedKsk);
PackingKeyswitchKey readPackingKeyswitchKey(std::istream &istream);
std::ostream &operator<<(std::ostream &ostream,
const EvaluationKeys &evaluationKeys);
std::istream &operator>>(std::istream &istream, EvaluationKeys &evaluationKeys);
EvaluationKeys readEvaluationKeys(std::istream &istream);
} // namespace clientlib
} // namespace concretelang

View File

@@ -7,12 +7,6 @@
#include <string>
#define CAPI_ASSERT_ERROR(instr) \
{ \
int err = instr; \
assert(err == 0); \
}
namespace concretelang {
namespace error {

View File

@@ -8,6 +8,7 @@
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Casting.h"
#include <cstddef>
#include <list>
namespace mlir {
@@ -19,11 +20,9 @@ struct CrtLoweringParameters {
size_t nMods;
size_t modsProd;
size_t bitsTotal;
size_t polynomialSize;
size_t lutSize;
size_t singleLutSize;
CrtLoweringParameters(mlir::SmallVector<int64_t> mods, size_t polySize)
: mods(mods), polynomialSize(polySize) {
CrtLoweringParameters(mlir::SmallVector<int64_t> mods) : mods(mods) {
nMods = mods.size();
modsProd = 1;
bitsTotal = 0;
@@ -35,9 +34,7 @@ struct CrtLoweringParameters {
bits.push_back(nbits);
bitsTotal += nbits;
}
size_t lutCrtSize = size_t(1) << bitsTotal;
lutCrtSize = std::max(lutCrtSize, polynomialSize);
lutSize = mods.size() * lutCrtSize;
singleLutSize = size_t(1) << bitsTotal;
}
};

View File

@@ -15,12 +15,14 @@ include "concretelang/Dialect/RT/IR/RTTypes.td"
def Concrete_LweTensor : 1DTensorOf<[I64]>;
def Concrete_LutTensor : 1DTensorOf<[I64]>;
def Concrete_CrtLutsTensor : 2DTensorOf<[I64]>;
def Concrete_CrtPlaintextTensor : 1DTensorOf<[I64]>;
def Concrete_LweCRTTensor : 2DTensorOf<[I64]>;
def Concrete_BatchLweTensor : 2DTensorOf<[I64]>;
def Concrete_LweBuffer : MemRefRankOf<[I64], [1]>;
def Concrete_LutBuffer : MemRefRankOf<[I64], [1]>;
def Concrete_CrtLutsBuffer : MemRefRankOf<[I64], [2]>;
def Concrete_CrtPlaintextBuffer : MemRefRankOf<[I64], [1]>;
def Concrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>;
def Concrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>;
@@ -126,7 +128,7 @@ def Concrete_EncodeExpandLutForBootstrapBufferOp : Concrete_Op<"encode_expand_lu
);
}
def Concrete_EncodeExpandLutForWopPBSTensorOp : Concrete_Op<"encode_expand_lut_for_woppbs_tensor", [NoSideEffect]> {
def Concrete_EncodeLutForCrtWopPBSTensorOp : Concrete_Op<"encode_lut_for_crt_woppbs_tensor", [NoSideEffect]> {
let summary =
"Encode and expand a lookup table so that it can be used for a wop pbs";
@@ -134,24 +136,22 @@ def Concrete_EncodeExpandLutForWopPBSTensorOp : Concrete_Op<"encode_expand_lut_f
Concrete_LutTensor : $input_lookup_table,
I64ArrayAttr: $crtDecomposition,
I64ArrayAttr: $crtBits,
I32Attr : $polySize,
I32Attr : $modulusProduct,
BoolAttr: $isSigned
);
let results = (outs Concrete_LutTensor : $result);
let results = (outs Concrete_CrtLutsTensor : $result);
}
def Concrete_EncodeExpandLutForWopPBSBufferOp : Concrete_Op<"encode_expand_lut_for_woppbs_buffer"> {
def Concrete_EncodeLutForCrtWopPBSBufferOp : Concrete_Op<"encode_lut_for_crt_woppbs_buffer"> {
let summary =
"Encode and expand a lookup table so that it can be used for a wop pbs";
"Encode and expand a lookup table so that it can be used for a crt wop pbs";
let arguments = (ins
Concrete_LutBuffer : $result,
Concrete_CrtLutsBuffer : $result,
Concrete_LutBuffer : $input_lookup_table,
I64ArrayAttr: $crtDecomposition,
I64ArrayAttr: $crtBits,
I32Attr : $polySize,
I32Attr : $modulusProduct,
BoolAttr: $isSigned
);
@@ -347,7 +347,7 @@ def Concrete_BatchedKeySwitchLweBufferOp : Concrete_Op<"batched_keyswitch_lwe_bu
def Concrete_WopPBSCRTLweTensorOp : Concrete_Op<"wop_pbs_crt_lwe_tensor", [NoSideEffect]> {
let arguments = (ins
Concrete_LweCRTTensor:$ciphertext,
Concrete_LutTensor:$lookupTable,
Concrete_CrtLutsTensor:$lookupTable,
// Bootstrap parameters
I32Attr : $bootstrapLevel,
I32Attr : $bootstrapBaseLog,
@@ -370,7 +370,7 @@ def Concrete_WopPBSCRTLweBufferOp : Concrete_Op<"wop_pbs_crt_lwe_buffer"> {
let arguments = (ins
Concrete_LweCRTBuffer:$result,
Concrete_LweCRTBuffer:$ciphertext,
Concrete_LutBuffer:$lookup_table,
Concrete_CrtLutsBuffer:$lookup_table,
// Bootstrap parameters
I32Attr : $bootstrapLevel,
I32Attr : $bootstrapBaseLog,

View File

@@ -33,7 +33,7 @@ def TFHE_EncodeExpandLutForBootstrapOp : TFHE_Op<"encode_expand_lut_for_bootstra
let results = (outs 1DTensorOf<[I64]> : $result);
}
def TFHE_EncodeExpandLutForWopPBSOp : TFHE_Op<"encode_expand_lut_for_woppbs"> {
def TFHE_EncodeLutForCrtWopPBSOp : TFHE_Op<"encode_lut_for_crt_woppbs"> {
let summary =
"Encode and expand a lookup table so that it can be used for a wop pbs.";
@@ -41,12 +41,11 @@ def TFHE_EncodeExpandLutForWopPBSOp : TFHE_Op<"encode_expand_lut_for_woppbs"> {
1DTensorOf<[I64]> : $input_lookup_table,
I64ArrayAttr: $crtDecomposition,
I64ArrayAttr: $crtBits,
I32Attr : $polySize,
I32Attr : $modulusProduct,
BoolAttr: $isSigned
);
let results = (outs 1DTensorOf<[I64]> : $result);
let results = (outs 2DTensorOf<[I64]> : $result);
}
def TFHE_EncodePlaintextWithCrtOp : TFHE_Op<"encode_plaintext_with_crt"> {
@@ -158,7 +157,7 @@ def TFHE_WopPBSGLWEOp : TFHE_Op<"wop_pbs_glwe"> {
let arguments = (ins
Type<And<[TensorOf<[TFHE_GLWECipherTextType]>.predicate, HasStaticShapePred]>>: $ciphertexts,
1DTensorOf<[I64]> : $lookupTable,
2DTensorOf<[I64]> : $lookupTable,
// Bootstrap parameters
I32Attr : $bootstrapLevel,
I32Attr : $bootstrapBaseLog,

View File

@@ -12,11 +12,10 @@
#include <pthread.h>
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/Runtime/seeder.h"
#include "concrete-core-ffi.h"
#include "concretelang/Common/Error.h"
#include "concrete-cpu.h"
#ifdef CONCRETELANG_CUDA_SUPPORT
#include "bootstrap.h"
#include "device.h"
@@ -26,27 +25,23 @@
namespace mlir {
namespace concretelang {
typedef struct FFT {
FFT() = delete;
FFT(size_t polynomial_size);
FFT(FFT &other) = delete;
FFT(FFT &&other);
~FFT();
struct Fft *fft;
size_t polynomial_size;
} FFT;
typedef struct RuntimeContext {
RuntimeContext() {
CAPI_ASSERT_ERROR(new_default_engine(best_seeder, &default_engine));
#ifdef CONCRETELANG_CUDA_SUPPORT
bsk_gpu = nullptr;
ksk_gpu = nullptr;
#endif
}
/// Ensure that the engines map is not copied
RuntimeContext(const RuntimeContext &ctx){};
RuntimeContext() = delete;
RuntimeContext(::concretelang::clientlib::EvaluationKeys evaluationKeys);
~RuntimeContext() {
CAPI_ASSERT_ERROR(destroy_default_engine(default_engine));
for (const auto &key : fft_engines) {
CAPI_ASSERT_ERROR(destroy_fft_engine(key.second));
}
if (fbsk != nullptr) {
CAPI_ASSERT_ERROR(destroy_fft_fourier_lwe_bootstrap_key_u64(fbsk));
}
#ifdef CONCRETELANG_CUDA_SUPPORT
if (bsk_gpu != nullptr) {
cuda_drop(bsk_gpu, 0);
@@ -55,44 +50,33 @@ typedef struct RuntimeContext {
cuda_drop(ksk_gpu, 0);
}
#endif
};
const uint64_t *keyswitch_key_buffer(size_t keyId) {
return evaluationKeys.getKeyswitchKey(keyId).buffer();
}
FftEngine *get_fft_engine() {
pthread_t threadId = pthread_self();
std::lock_guard<std::mutex> guard(engines_map_guard);
auto engineIt = fft_engines.find(threadId);
if (engineIt == fft_engines.end()) {
FftEngine *fft_engine = nullptr;
CAPI_ASSERT_ERROR(new_fft_engine(&fft_engine));
engineIt =
fft_engines
.insert(std::pair<pthread_t, FftEngine *>(threadId, fft_engine))
.first;
}
assert(engineIt->second && "No engine available in context");
return engineIt->second;
const double *fourier_bootstrap_key_buffer(size_t keyId) {
return fourier_bootstrap_keys[keyId]->data();
}
DefaultEngine *get_default_engine() { return default_engine; }
FftFourierLweBootstrapKey64 *get_fft_fourier_bsk() {
if (fbsk != nullptr) {
return fbsk;
}
const std::lock_guard<std::mutex> guard(fbskMutex);
if (fbsk == nullptr) {
CAPI_ASSERT_ERROR(
fft_engine_convert_lwe_bootstrap_key_to_fft_fourier_lwe_bootstrap_key_u64(
get_fft_engine(), evaluationKeys.getBsk(), &fbsk));
}
return fbsk;
const uint64_t *fp_keyswitch_key_buffer(size_t keyId) {
return evaluationKeys.getPackingKeyswitchKey(keyId).buffer();
}
const struct Fft *fft(size_t keyId) { return ffts[keyId].fft; }
const ::concretelang::clientlib::EvaluationKeys getKeys() const {
return evaluationKeys;
}
private:
::concretelang::clientlib::EvaluationKeys evaluationKeys;
std::vector<std::shared_ptr<std::vector<double>>> fourier_bootstrap_keys;
std::vector<FFT> ffts;
#ifdef CONCRETELANG_CUDA_SUPPORT
public:
void *get_bsk_gpu(uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
uint32_t glwe_dim, uint32_t gpu_idx, void *stream) {
@@ -104,25 +88,21 @@ typedef struct RuntimeContext {
if (bsk_gpu != nullptr) {
return bsk_gpu;
}
LweBootstrapKey64 *bsk = get_bsk();
size_t bsk_buffer_len =
input_lwe_dim * (glwe_dim + 1) * (glwe_dim + 1) * poly_size * level;
size_t bsk_buffer_size = bsk_buffer_len * sizeof(uint64_t);
uint64_t *bsk_buffer =
(uint64_t *)aligned_alloc(U64_ALIGNMENT, bsk_buffer_size);
auto bsk = evaluationKeys.getBootstrapKey(0);
size_t bsk_buffer_len = bsk.size();
size_t bsk_gpu_buffer_size = bsk_buffer_len * sizeof(double);
void *bsk_gpu_tmp = cuda_malloc(bsk_gpu_buffer_size, gpu_idx);
CAPI_ASSERT_ERROR(
default_engine_discard_convert_lwe_bootstrap_key_to_lwe_bootstrap_key_mut_view_u64_raw_ptr_buffers(
default_engine, bsk, bsk_buffer));
cuda_initialize_twiddles(poly_size, gpu_idx);
cuda_convert_lwe_bootstrap_key_64(bsk_gpu_tmp, bsk_buffer, stream, gpu_idx,
input_lwe_dim, glwe_dim, level,
cuda_convert_lwe_bootstrap_key_64(bsk_gpu_tmp, (void *)bsk.buffer(), stream,
gpu_idx, input_lwe_dim, glwe_dim, level,
poly_size);
// This is currently not 100% async as we have to free CPU memory after
// This is currently not 100% async as
// we have to free CPU memory after
// conversion
cuda_synchronize_device(gpu_idx);
free(bsk_buffer);
bsk_gpu = bsk_gpu_tmp;
return bsk_gpu;
}
@@ -138,49 +118,23 @@ typedef struct RuntimeContext {
if (ksk_gpu != nullptr) {
return ksk_gpu;
}
LweKeyswitchKey64 *ksk = get_ksk();
size_t ksk_buffer_len = input_lwe_dim * (output_lwe_dim + 1) * level;
size_t ksk_buffer_size = sizeof(uint64_t) * ksk_buffer_len;
uint64_t *ksk_buffer =
(uint64_t *)aligned_alloc(U64_ALIGNMENT, ksk_buffer_size);
auto ksk = evaluationKeys.getKeyswitchKey(0);
size_t ksk_buffer_size = sizeof(uint64_t) * ksk.size();
void *ksk_gpu_tmp = cuda_malloc(ksk_buffer_size, gpu_idx);
CAPI_ASSERT_ERROR(
default_engine_discard_convert_lwe_keyswitch_key_to_lwe_keyswitch_key_mut_view_u64_raw_ptr_buffers(
default_engine, ksk, ksk_buffer));
cuda_memcpy_async_to_gpu(ksk_gpu_tmp, ksk_buffer, ksk_buffer_size, stream,
gpu_idx);
// This is currently not 100% async as we have to free CPU memory after
cuda_memcpy_async_to_gpu(ksk_gpu_tmp, (void *)ksk.buffer(), ksk_buffer_size,
stream, gpu_idx);
// This is currently not 100% async as
// we have to free CPU memory after
// conversion
cuda_synchronize_device(gpu_idx);
free(ksk_buffer);
ksk_gpu = ksk_gpu_tmp;
return ksk_gpu;
}
#endif
LweBootstrapKey64 *get_bsk() { return evaluationKeys.getBsk(); }
LweKeyswitchKey64 *get_ksk() { return evaluationKeys.getKsk(); }
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *get_fpksk() {
return evaluationKeys.getFpksk();
}
RuntimeContext &operator=(const RuntimeContext &rhs) {
this->evaluationKeys = rhs.evaluationKeys;
return *this;
}
::concretelang::clientlib::EvaluationKeys evaluationKeys;
private:
std::mutex fbskMutex;
FftFourierLweBootstrapKey64 *fbsk = nullptr;
DefaultEngine *default_engine;
std::map<pthread_t, FftEngine *> fft_engines;
std::mutex engines_map_guard;
#ifdef CONCRETELANG_CUDA_SUPPORT
std::mutex bsk_gpu_mutex;
void *bsk_gpu;
std::mutex ksk_gpu_mutex;
@@ -192,18 +146,4 @@ private:
} // namespace concretelang
} // namespace mlir
extern "C" {
LweKeyswitchKey64 *
get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context);
FftFourierLweBootstrapKey64 *
get_fft_fourier_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context);
LweBootstrapKey64 *
get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context);
DefaultEngine *get_engine(mlir::concretelang::RuntimeContext *context);
FftEngine *get_fft_engine(mlir::concretelang::RuntimeContext *context);
}
#endif

View File

@@ -14,17 +14,16 @@
#include <hpx/modules/collectives.hpp>
#include <hpx/modules/serialization.hpp>
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/context.h"
#include "concrete-core-ffi.h"
#include "concretelang/Common/Error.h"
namespace mlir {
namespace concretelang {
namespace dfr {
template <typename T> struct KeyManager;
struct RuntimeContextManager;
namespace {
static void *dl_handle;
@@ -32,103 +31,83 @@ static RuntimeContextManager *_dfr_node_level_runtime_context_manager;
} // namespace
template <typename LweKeyType> struct KeyWrapper {
LweKeyType *key;
Buffer buffer;
std::vector<LweKeyType> keys;
KeyWrapper() : key(nullptr) {}
KeyWrapper(KeyWrapper &&moved) noexcept
: key(moved.key), buffer(moved.buffer) {}
KeyWrapper(LweKeyType *key);
KeyWrapper(const KeyWrapper &kw) : key(kw.key), buffer(kw.buffer) {}
KeyWrapper() {}
KeyWrapper(KeyWrapper &&moved) noexcept : keys(moved.keys) {}
KeyWrapper(const KeyWrapper &kw) : keys(kw.keys) {}
KeyWrapper &operator=(const KeyWrapper &rhs) {
this->key = rhs.key;
this->buffer = rhs.buffer;
this->keys = rhs.keys;
return *this;
}
KeyWrapper(std::vector<LweKeyType> keyvec) : keys(keyvec) {}
friend class hpx::serialization::access;
// template <class Archive>
// void save(Archive &ar, const unsigned int version) const;
template <class Archive>
void save(Archive &ar, const unsigned int version) const;
template <class Archive> void load(Archive &ar, const unsigned int version);
HPX_SERIALIZATION_SPLIT_MEMBER()
void serialize(Archive &ar, const unsigned int version) const {}
// template <class Archive> void load(Archive &ar, const unsigned int
// version); HPX_SERIALIZATION_SPLIT_MEMBER()
};
template <>
KeyWrapper<LweKeyswitchKey64>::KeyWrapper(LweKeyswitchKey64 *key) : key(key) {
DefaultSerializationEngine *engine;
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(
default_serialization_engine_serialize_lwe_keyswitch_key_u64(engine, key,
&buffer));
}
template <>
KeyWrapper<LweBootstrapKey64>::KeyWrapper(LweBootstrapKey64 *key) : key(key) {
DefaultSerializationEngine *engine;
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(
default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key,
&buffer));
}
template <typename LweKeyType>
bool operator==(const KeyWrapper<LweKeyType> &lhs,
const KeyWrapper<LweKeyType> &rhs) {
return lhs.key == rhs.key;
if (lhs.keys.size() != rhs.keys.size())
return false;
for (size_t i = 0; i < lhs.keys.size(); ++i)
if (lhs.keys[i].buffer() != rhs.keys[i].buffer())
return false;
return true;
}
template <>
template <class Archive>
void KeyWrapper<LweBootstrapKey64>::save(Archive &ar,
const unsigned int version) const {
ar << buffer.length;
ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
}
template <>
template <class Archive>
void KeyWrapper<LweBootstrapKey64>::load(Archive &ar,
const unsigned int version) {
DefaultSerializationEngine *engine;
// template <>
// template <class Archive>
// void KeyWrapper<LweBootstrapKey>::save(Archive &ar,
// const unsigned int version) const {
// ar << buffer.length;
// ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
// }
// template <>
// template <class Archive>
// void KeyWrapper<LweBootstrapKey>::load(Archive &ar,
// const unsigned int version) {
// DefaultSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
// // No Freeing as it doesn't allocate anything.
// CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
ar >> buffer.length;
buffer.pointer = new uint8_t[buffer.length];
ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
CAPI_ASSERT_ERROR(
default_serialization_engine_deserialize_lwe_bootstrap_key_u64(
engine, {buffer.pointer, buffer.length}, &key));
}
// ar >> buffer.length;
// buffer.pointer = new uint8_t[buffer.length];
// ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
// CAPI_ASSERT_ERROR(
// default_serialization_engine_deserialize_lwe_bootstrap_key_u64(
// engine, {buffer.pointer, buffer.length}, &key));
// }
template <>
template <class Archive>
void KeyWrapper<LweKeyswitchKey64>::save(Archive &ar,
const unsigned int version) const {
ar << buffer.length;
ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
}
template <>
template <class Archive>
void KeyWrapper<LweKeyswitchKey64>::load(Archive &ar,
const unsigned int version) {
DefaultSerializationEngine *engine;
// template <>
// template <class Archive>
// void KeyWrapper<LweKeyswitchKey>::save(Archive &ar,
// const unsigned int version) const {
// ar << buffer.length;
// ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
// }
// template <>
// template <class Archive>
// void KeyWrapper<LweKeyswitchKey>::load(Archive &ar,
// const unsigned int version) {
// DefaultSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
// // No Freeing as it doesn't allocate anything.
// CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
ar >> buffer.length;
buffer.pointer = new uint8_t[buffer.length];
ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
CAPI_ASSERT_ERROR(
default_serialization_engine_deserialize_lwe_keyswitch_key_u64(
engine, {buffer.pointer, buffer.length}, &key));
}
// ar >> buffer.length;
// buffer.pointer = new uint8_t[buffer.length];
// ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
// CAPI_ASSERT_ERROR(
// default_serialization_engine_deserialize_lwe_keyswitch_key_u64(
// engine, {buffer.pointer, buffer.length}, &key));
// }
/************************/
/* Context management. */
@@ -152,34 +131,36 @@ struct RuntimeContextManager {
// instantiates a local RuntimeContext.
if (_dfr_is_root_node()) {
RuntimeContext *context = (RuntimeContext *)ctx;
LweKeyswitchKey64 *ksk = get_keyswitch_key_u64(context);
LweBootstrapKey64 *bsk = get_bootstrap_key_u64(context);
KeyWrapper<LweKeyswitchKey64> kskw(ksk);
KeyWrapper<LweBootstrapKey64> bskw(bsk);
KeyWrapper<::concretelang::clientlib::LweKeyswitchKey> kskw(
context->getKeys().getKeyswitchKeys());
KeyWrapper<::concretelang::clientlib::LweBootstrapKey> bskw(
context->getKeys().getBootstrapKeys());
KeyWrapper<::concretelang::clientlib::PackingKeyswitchKey> pkskw(
context->getKeys().getPackingKeyswitchKeys());
hpx::collectives::broadcast_to("ksk_keystore", kskw);
hpx::collectives::broadcast_to("bsk_keystore", bskw);
hpx::collectives::broadcast_to("pksk_keystore", pkskw);
} else {
auto kskFut =
hpx::collectives::broadcast_from<KeyWrapper<LweKeyswitchKey64>>(
"ksk_keystore");
auto bskFut =
hpx::collectives::broadcast_from<KeyWrapper<LweBootstrapKey64>>(
"bsk_keystore");
auto kskFut = hpx::collectives::broadcast_from<
KeyWrapper<::concretelang::clientlib::LweKeyswitchKey>>(
"ksk_keystore");
auto bskFut = hpx::collectives::broadcast_from<
KeyWrapper<::concretelang::clientlib::LweBootstrapKey>>(
"bsk_keystore");
auto pkskFut = hpx::collectives::broadcast_from<
KeyWrapper<::concretelang::clientlib::PackingKeyswitchKey>>(
"pksk_keystore");
KeyWrapper<LweKeyswitchKey64> kskw = kskFut.get();
KeyWrapper<LweBootstrapKey64> bskw = bskFut.get();
context = new mlir::concretelang::RuntimeContext();
// TODO - packing keyswitch key for distributed
context->evaluationKeys = ::concretelang::clientlib::EvaluationKeys(
std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>(
new ::concretelang::clientlib::LweKeyswitchKey(kskw.key)),
std::shared_ptr<::concretelang::clientlib::LweBootstrapKey>(
new ::concretelang::clientlib::LweBootstrapKey(bskw.key)),
std::shared_ptr<::concretelang::clientlib::PackingKeyswitchKey>(
nullptr));
delete (kskw.buffer.pointer);
delete (bskw.buffer.pointer);
KeyWrapper<::concretelang::clientlib::LweKeyswitchKey> kskw =
kskFut.get();
KeyWrapper<::concretelang::clientlib::LweBootstrapKey> bskw =
bskFut.get();
KeyWrapper<::concretelang::clientlib::PackingKeyswitchKey> pkskw =
pkskFut.get();
context = new mlir::concretelang::RuntimeContext(
::concretelang::clientlib::EvaluationKeys(kskw.keys, bskw.keys,
pkskw.keys));
}
}

View File

@@ -1,13 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_RUNTIME_SEEDER_H
#define CONCRETELANG_RUNTIME_SEEDER_H
#include "concrete-core-ffi.h"
extern SeederBuilder *best_seeder;
#endif

View File

@@ -29,18 +29,19 @@ void memref_encode_expand_lut_for_bootstrap(
uint64_t input_lut_size, uint64_t input_lut_stride, uint32_t poly_size,
uint32_t out_MESSAGE_BITS, bool is_signed);
void memref_encode_expand_lut_for_woppbs(
void memref_encode_lut_for_crt_woppbs(
uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
uint64_t output_lut_offset, uint64_t output_lut_size,
uint64_t output_lut_stride, uint64_t *input_lut_allocated,
uint64_t output_lut_offset, uint64_t output_lut_size0,
uint64_t output_lut_size1, uint64_t output_lut_stride0,
uint64_t output_lut_stride1, uint64_t *input_lut_allocated,
uint64_t *input_lut_aligned, uint64_t input_lut_offset,
uint64_t input_lut_size, uint64_t input_lut_stride,
uint64_t *crt_decomposition_allocated, uint64_t *crt_decomposition_aligned,
uint64_t crt_decomposition_offset, uint64_t crt_decomposition_size,
uint64_t crt_decomposition_stride, uint64_t *crt_bits_allocated,
uint64_t *crt_bits_aligned, uint64_t crt_bits_offset,
uint64_t crt_bits_size, uint64_t crt_bits_stride, uint32_t poly_size,
uint32_t modulus_product, bool is_signed);
uint64_t crt_bits_size, uint64_t crt_bits_stride, uint32_t modulus_product,
bool is_signed);
void memref_encode_plaintext_with_crt(
uint64_t *output_allocated, uint64_t *output_aligned,
@@ -149,13 +150,16 @@ void memref_wop_pbs_crt_buffer(
uint64_t in_stride_1,
// clear text lut
uint64_t *lut_ct_allocated, uint64_t *lut_ct_aligned,
uint64_t lut_ct_offset, uint64_t lut_ct_size, uint64_t lut_ct_stride,
uint64_t lut_ct_offset, uint64_t lut_ct_size0, uint64_t lut_ct_size1,
uint64_t lut_ct_stride0, uint64_t lut_ct_stride1,
// CRT decomposition
uint64_t *crt_decomp_allocated, uint64_t *crt_decomp_aligned,
uint64_t crt_decomp_offset, uint64_t crt_decomp_size,
uint64_t crt_decomp_stride,
// Additional crypto parameters
uint32_t lwe_small_size, uint32_t cbs_level_count, uint32_t cbs_base_log,
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
uint32_t polynomial_size,
// runtime context that hold evluation keys
mlir::concretelang::RuntimeContext *context);

View File

@@ -337,8 +337,8 @@ public:
if (check.has_error()) {
return StreamStringError(check.error().mesg);
}
auto publicArguments = encryptedArgs->exportPublicArguments(
clientParameters, keySet.runtimeContext());
auto publicArguments =
encryptedArgs->exportPublicArguments(clientParameters);
if (publicArguments.has_error()) {
return StreamStringError(publicArguments.error().mesg);
}
@@ -485,8 +485,8 @@ public:
if (encryptedArgs.has_error()) {
return StreamStringError(encryptedArgs.error().mesg);
}
auto publicArguments = encryptedArgs.value()->exportPublicArguments(
clientParameters, keySet->runtimeContext());
auto publicArguments =
encryptedArgs.value()->exportPublicArguments(clientParameters);
if (!publicArguments.has_value()) {
return StreamStringError(publicArguments.error().mesg);
}
@@ -506,8 +506,8 @@ public:
if (encryptedArgs.has_error()) {
return StreamStringError(encryptedArgs.error().mesg);
}
auto publicArguments = encryptedArgs.value()->exportPublicArguments(
clientParameters, keySet->runtimeContext());
auto publicArguments =
encryptedArgs.value()->exportPublicArguments(clientParameters);
if (publicArguments.has_error()) {
return StreamStringError(publicArguments.error().mesg);
}

View File

@@ -73,8 +73,7 @@ public:
OUTCOME_TRY(auto encryptedArgs,
clientlib::EncryptedArguments::create(*keySet, args...));
OUTCOME_TRY(auto publicArgument,
encryptedArgs->exportPublicArguments(this->clientParameters,
keySet->runtimeContext()));
encryptedArgs->exportPublicArguments(this->clientParameters));
// client argument serialization
// publicArgument->serialize(clientOuput);
// message = clientOuput.str();