feat(compiler/client-lib): Implement ValueExporter to allows partial encryption

This commit is contained in:
Bourgerie Quentin
2023-05-30 14:30:46 +02:00
committed by Umut
parent 17f1107231
commit f7f94a1663
5 changed files with 242 additions and 174 deletions

View File

@@ -295,9 +295,7 @@ generate-cpu-tests: \
SECURITY_TO_TEST=128
OPTIMIZATION_STRATEGY_TO_TEST=dag-mono dag-multi
PARALLEL_END_2_END_TESTS= end_to_end_jit_test \
end_to_end_jit_lambda \
end_to_end_jit_lambda
PARALLEL_END_2_END_TESTS= end_to_end_jit_test end_to_end_jit_lambda
run-end-to-end-tests: $(GTEST_PARALLEL_PY) build-end-to-end-tests generate-cpu-tests
$(foreach TEST,$(PARALLEL_END_2_END_TESTS), \
$(GTEST_PARALLEL_CMD) $(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/$(TEST);)

View File

@@ -23,6 +23,139 @@ using concretelang::error::StringError;
class PublicArguments;
/// @brief The ArgumentsExporter allows to transform clear
/// arguments to the one expected by a server lambda.
class ValueExporter {
public:
/// @brief
/// @param keySet
/// @param clientParameters
// TODO: Get rid of the reference here could make troubles (see for KeySet
// copy constructor or shared pointers)
ValueExporter(KeySet &keySet, ClientParameters clientParameters)
: _keySet(keySet), _clientParameters(clientParameters) {}
/// @brief Export a scalar 64 bits integer to a concreteprocol::Value
/// @param arg An 64 bits integer
/// @param argPos The position of the argument to export
/// @return Either the exported value ready to be sent to the server or an
/// error if the gate doesn't match the expected argument.
outcome::checked<ScalarOrTensorData, StringError> exportValue(uint64_t arg,
size_t argPos) {
OUTCOME_TRY(auto gate, _clientParameters.input(argPos));
if (gate.shape.size != 0) {
return StringError("argument #") << argPos << " is not a scalar";
}
if (gate.encryption.has_value()) {
return exportEncryptValue(arg, gate, argPos);
}
return exportClearValue(arg);
}
/// @brief Export a tensor like buffer of values to a serializable value
/// @tparam T The type of values hold by the buffer
/// @param arg A pointer to a memory area where the values are stored
/// @param shape The shape of the tensor
/// @param argPos The position of the argument to export
/// @return Either the exported value ready to be sent to the server or an
/// error if the gate doesn't match the expected argument.
template <typename T>
outcome::checked<ScalarOrTensorData, StringError>
exportValue(const T *arg, llvm::ArrayRef<int64_t> shape, size_t argPos) {
OUTCOME_TRY(auto gate, _clientParameters.input(argPos));
OUTCOME_TRYV(checkShape(shape, gate.shape, argPos));
if (gate.encryption.has_value()) {
return exportEncryptTensor(arg, shape, gate, argPos);
}
return exportClearTensor(arg, shape, gate);
}
private:
/// Export a 64bits integer to a serializable value
outcome::checked<ScalarOrTensorData, StringError>
exportClearValue(uint64_t arg) {
return ScalarData(arg);
}
/// Encrypt and export a 64bits integer to a serializale value
outcome::checked<ScalarOrTensorData, StringError>
exportEncryptValue(uint64_t arg, CircuitGate &gate, size_t argPos) {
std::vector<int64_t> shape = _clientParameters.bufferShape(gate);
// Create and allocate the TensorData that will holds encrypted value
TensorData td(shape, clientlib::EncryptedScalarElementType,
clientlib::EncryptedScalarElementWidth);
// Encrypt the value
OUTCOME_TRYV(
_keySet.encrypt_lwe(argPos, td.getElementPointer<uint64_t>(0), arg));
return std::move(td);
}
/// Export a tensor like buffer to a serializable value
template <typename T>
outcome::checked<ScalarOrTensorData, StringError>
exportClearTensor(const T *arg, llvm::ArrayRef<int64_t> shape,
CircuitGate &gate) {
auto bitsPerValue = bitWidthAsWord(gate.shape.width);
auto sizes = _clientParameters.bufferShape(gate);
TensorData td(sizes, bitsPerValue, gate.shape.sign);
llvm::ArrayRef<T> values(arg, TensorData::getNumElements(sizes));
td.bulkAssign(values);
return std::move(td);
}
/// Export and encrypt a tensor like buffer to a serializable value
template <typename T>
outcome::checked<ScalarOrTensorData, StringError>
exportEncryptTensor(const T *arg, llvm::ArrayRef<int64_t> shape,
CircuitGate &gate, size_t argPos) {
// Create and allocate the TensorData that will holds encrypted values
auto sizes = _clientParameters.bufferShape(gate);
TensorData td(sizes, EncryptedScalarElementType,
EncryptedScalarElementWidth);
// Iterate over values and encrypt at the right place the value
auto lweSize = _clientParameters.lweBufferSize(gate);
for (size_t i = 0, offset = 0; i < gate.shape.size;
i++, offset += lweSize) {
OUTCOME_TRYV(_keySet.encrypt_lwe(
argPos, td.getElementPointer<uint64_t>(offset), arg[i]));
}
return std::move(td);
}
static outcome::checked<void, StringError>
checkShape(llvm::ArrayRef<int64_t> shape, CircuitGateShape expected,
size_t argPos) {
// Check the shape of tensor
if (expected.dimensions.empty()) {
return StringError("argument #") << argPos << "is not a tensor";
}
if (shape.size() != expected.dimensions.size()) {
return StringError("argument #")
<< argPos << "has not the expected number of dimension, got "
<< shape.size() << " expected " << expected.dimensions.size();
}
// Check shape
for (size_t i = 0; i < shape.size(); i++) {
if (shape[i] != expected.dimensions[i]) {
return StringError("argument #")
<< argPos << " has not the expected dimension #" << i
<< " , got " << shape[i] << " expected "
<< expected.dimensions[i];
}
}
return outcome::success();
}
private:
KeySet &_keySet;
ClientParameters _clientParameters;
};
/// Temporary object used to hold and encrypt parameters before calling a
/// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...).
/// Otherwise convert it to a PublicArguments and use
@@ -30,10 +163,10 @@ class PublicArguments;
class EncryptedArguments {
public:
EncryptedArguments() : currentPos(0) {}
EncryptedArguments() {}
/// Encrypts args thanks the given KeySet and pack the encrypted arguments to
/// an EncryptedArguments
/// Encrypts args thanks the given KeySet and pack the encrypted arguments
/// to an EncryptedArguments
template <typename... Args>
static outcome::checked<std::unique_ptr<EncryptedArguments>, StringError>
create(KeySet &keySet, Args... args) {
@@ -69,7 +202,12 @@ public:
public:
/// Add a uint64_t scalar argument.
outcome::checked<void, StringError> pushArg(uint64_t arg, KeySet &keySet);
outcome::checked<void, StringError> pushArg(uint64_t arg, KeySet &keySet) {
ValueExporter exporter(keySet, keySet.clientParameters());
OUTCOME_TRY(auto value, exporter.exportValue(arg, values.size()));
values.push_back(std::move(value));
return outcome::success();
}
/// Add a vector-tensor argument.
outcome::checked<void, StringError> pushArg(std::vector<uint8_t> arg,
@@ -129,58 +267,9 @@ public:
template <typename T>
outcome::checked<void, StringError>
pushArg(const T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
auto pos = currentPos;
CircuitGate input = keySet.inputGate(pos);
// Check the width of data
if (input.shape.width > 64) {
return StringError("argument #")
<< pos << " width > 64 bits is not supported";
}
// Check the shape of tensor
if (input.shape.dimensions.empty()) {
return StringError("argument #") << pos << "is not a tensor";
}
if (shape.size() != input.shape.dimensions.size()) {
return StringError("argument #")
<< pos << "has not the expected number of dimension, got "
<< shape.size() << " expected " << input.shape.dimensions.size();
}
// Check shape
for (size_t i = 0; i < shape.size(); i++) {
if (shape[i] != input.shape.dimensions[i]) {
return StringError("argument #")
<< pos << " has not the expected dimension #" << i << " , got "
<< shape[i] << " expected " << input.shape.dimensions[i];
}
}
// Set sizes
std::vector<int64_t> sizes = keySet.clientParameters().bufferShape(input);
if (input.encryption.has_value()) {
TensorData td(sizes, EncryptedScalarElementType,
EncryptedScalarElementWidth);
auto lweSize = keySet.clientParameters().lweBufferSize(input);
for (size_t i = 0, offset = 0; i < input.shape.size;
i++, offset += lweSize) {
OUTCOME_TRYV(keySet.encrypt_lwe(
pos, td.getElementPointer<uint64_t>(offset), data[i]));
}
ciphertextBuffers.push_back(std::move(td));
} else {
auto bitsPerValue = bitWidthAsWord(input.shape.width);
TensorData td(sizes, bitsPerValue, input.shape.sign);
llvm::ArrayRef<T> values(data, TensorData::getNumElements(sizes));
td.bulkAssign(values);
ciphertextBuffers.push_back(std::move(td));
}
currentPos++;
ValueExporter exporter(keySet, keySet.clientParameters());
OUTCOME_TRY(auto value, exporter.exportValue(data, shape, values.size()));
values.push_back(std::move(value));
return outcome::success();
}
@@ -208,14 +297,8 @@ public:
}
private:
outcome::checked<void, StringError> checkPushTooManyArgs(KeySet &keySet);
private:
/// Position of the next pushed argument
size_t currentPos;
/// Store buffers of ciphertexts
std::vector<ScalarOrTensorData> ciphertextBuffers;
std::vector<ScalarOrTensorData> values;
};
} // namespace clientlib

View File

@@ -32,6 +32,81 @@ using concretelang::error::StringError;
class EncryptedArguments;
/// @brief allows to transform a serializable value into a clear value
class ValueDecrypter {
public:
ValueDecrypter(KeySet &keySet, ClientParameters clientParameters)
: _keySet(keySet), _clientParameters(clientParameters) {}
/// @brief Transforms a FHE value into a clear scalar value
/// @tparam T The type of the clear scalar value
/// @param value The value to decrypt
/// @param pos The position of the argument
/// @return Either the decrypted value or an error if the gate doesn't match
/// the expected result.
template <typename T>
outcome::checked<T, StringError> decrypt(ScalarOrTensorData &value,
size_t pos) {
OUTCOME_TRY(auto gate, _clientParameters.ouput(pos));
if (!gate.isEncrypted())
return value.getScalar().getValue<T>();
auto &buffer = value.getTensor();
auto ciphertext = buffer.getOpaqueElementPointer(0);
uint64_t decrypted;
// Convert to uint64_t* as required by `KeySet::decrypt_lwe`
// FIXME: this may break alignment restrictions on some
// architectures
auto ciphertextu64 = reinterpret_cast<uint64_t *>(ciphertext);
OUTCOME_TRYV(_keySet.decrypt_lwe(0, ciphertextu64, decrypted));
return (T)decrypted;
}
/// @brief Transforms a FHE value into a vector of clear value
/// @tparam T The type of the clear scalar value
/// @param value The value to decrypt
/// @param pos The position of the argument
/// @return Either the decrypted value or an error if the gate doesn't match
/// the expected result.
template <typename T>
outcome::checked<std::vector<T>, StringError>
decryptTensor(ScalarOrTensorData &value, size_t pos) {
OUTCOME_TRY(auto gate, _clientParameters.ouput(pos));
if (!gate.isEncrypted())
return value.getTensor().asFlatVector<T>();
auto &buffer = value.getTensor();
auto lweSize = _clientParameters.lweBufferSize(gate);
std::vector<T> decryptedValues(buffer.length() / lweSize);
for (size_t i = 0; i < decryptedValues.size(); i++) {
auto ciphertext = buffer.getOpaqueElementPointer(i * lweSize);
uint64_t decrypted;
// Convert to uint64_t* as required by `KeySet::decrypt_lwe`
// FIXME: this may break alignment restrictions on some
// architectures
auto ciphertextu64 = reinterpret_cast<uint64_t *>(ciphertext);
OUTCOME_TRYV(_keySet.decrypt_lwe(0, ciphertextu64, decrypted));
decryptedValues[i] = decrypted;
}
return decryptedValues;
}
/// Return the shape of the clear tensor of a result.
outcome::checked<std::vector<int64_t>, StringError> getShape(size_t pos) {
OUTCOME_TRY(auto gate, _clientParameters.ouput(pos));
return gate.shape.dimensions;
}
private:
KeySet &_keySet;
ClientParameters _clientParameters;
};
/// PublicArguments will be sended to the server. It includes encrypted
/// arguments and public keys.
class PublicArguments {
@@ -71,6 +146,17 @@ struct PublicResult {
PublicResult(PublicResult &) = delete;
/// @brief Return a value from the PublicResult
/// @param argPos The position of the value in the PublicResult
/// @return Either the value or an error if there are no value at this
/// position
outcome::checked<ScalarOrTensorData, StringError> getValue(size_t argPos) {
if (argPos >= buffers.size()) {
return StringError("result #") << argPos << " does not exists";
}
return std::move(buffers[argPos]);
}
/// Create a public result from buffers.
static std::unique_ptr<PublicResult>
fromBuffers(const ClientParameters &clientParameters,
@@ -90,49 +176,14 @@ struct PublicResult {
/// Serialize into an output stream.
outcome::checked<void, StringError> serialize(std::ostream &ostream);
/// Get the original integer that was decomposed into chunks of `chunkWidth`
/// bits each
uint64_t fromChunks(std::vector<uint64_t> chunks, unsigned int chunkWidth) {
uint64_t value = 0;
uint64_t mask = (1 << chunkWidth) - 1;
for (size_t i = 0; i < chunks.size(); i++) {
auto chunk = chunks[i] & mask;
value += chunk << (chunkWidth * i);
}
return value;
}
/// Get the result at `pos` as a scalar. Decryption happens if the
/// result is encrypted.
template <typename T>
outcome::checked<T, StringError> asClearTextScalar(KeySet &keySet,
size_t pos) {
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
if (!gate.isEncrypted())
return buffers[pos].getScalar().getValue<T>();
// Chunked integers are represented as tensors at a lower level, so we need
// to deal with them as tensors, then build the resulting scalar out of the
// tensor values
if (gate.chunkInfo.has_value()) {
OUTCOME_TRY(std::vector<uint64_t> decryptedChunks,
this->asClearTextVector<uint64_t>(keySet, pos));
uint64_t decrypted = fromChunks(decryptedChunks, gate.chunkInfo->width);
return (T)decrypted;
}
auto &buffer = buffers[pos].getTensor();
auto ciphertext = buffer.getOpaqueElementPointer(0);
uint64_t decrypted;
// Convert to uint64_t* as required by `KeySet::decrypt_lwe`
// FIXME: this may break alignment restrictions on some
// architectures
auto ciphertextu64 = reinterpret_cast<uint64_t *>(ciphertext);
OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertextu64, decrypted));
return (T)decrypted;
ValueDecrypter decrypter(keySet, clientParameters);
auto &data = buffers[pos];
return decrypter.template decrypt<T>(data, pos);
}
/// Get the result at `pos` as a vector. Decryption happens if the
@@ -140,26 +191,8 @@ struct PublicResult {
template <typename T>
outcome::checked<std::vector<T>, StringError>
asClearTextVector(KeySet &keySet, size_t pos) {
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
if (!gate.isEncrypted())
return buffers[pos].getTensor().asFlatVector<T>();
auto &buffer = buffers[pos].getTensor();
auto lweSize = clientParameters.lweBufferSize(gate);
std::vector<T> decryptedValues(buffer.length() / lweSize);
for (size_t i = 0; i < decryptedValues.size(); i++) {
auto ciphertext = buffer.getOpaqueElementPointer(i * lweSize);
uint64_t decrypted;
// Convert to uint64_t* as required by `KeySet::decrypt_lwe`
// FIXME: this may break alignment restrictions on some
// architectures
auto ciphertextu64 = reinterpret_cast<uint64_t *>(ciphertext);
OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertextu64, decrypted));
decryptedValues[i] = decrypted;
}
return decryptedValues;
ValueDecrypter decrypter(keySet, clientParameters);
return decrypter.template decryptTensor<T>(buffers[pos], pos);
}
/// Return the shape of the clear tensor of a result.

View File

@@ -842,6 +842,7 @@ protected:
std::unique_ptr<TensorData> tensor;
public:
ScalarOrTensorData(const ScalarOrTensorData &td) = delete;
ScalarOrTensorData(ScalarOrTensorData &&td)
: scalar(std::move(td.scalar)), tensor(std::move(td.tensor)) {}

View File

@@ -13,8 +13,7 @@ using StringError = concretelang::error::StringError;
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters) {
return std::make_unique<PublicArguments>(clientParameters,
std::move(ciphertextBuffers));
return std::make_unique<PublicArguments>(clientParameters, std::move(values));
}
/// Split the input integer into `size` chunks of `chunkWidth` bits each
@@ -31,60 +30,14 @@ std::vector<uint64_t> chunkInput(uint64_t value, size_t size,
return chunks;
}
outcome::checked<void, StringError>
EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
OUTCOME_TRY(CircuitGate input, keySet.clientParameters().input(currentPos));
// a chunked input is represented as a tensor in lower levels, and need to to
// splitted into chunks and encrypted as such
if (input.chunkInfo.has_value()) {
std::vector<uint64_t> chunks =
chunkInput(arg, input.shape.size, input.chunkInfo.value().width);
return this->pushArg(chunks.data(), input.shape.size, keySet);
}
// we only increment if we don't forward the call to another pushArg method
auto pos = currentPos++;
if (input.shape.size != 0) {
return StringError("argument #") << pos << " is not a scalar";
}
if (!input.encryption.has_value()) {
// clear scalar: just push the argument
ciphertextBuffers.push_back(ScalarData(arg));
return outcome::success();
}
std::vector<int64_t> shape = keySet.clientParameters().bufferShape(input);
// Allocate empty
ciphertextBuffers.emplace_back(
TensorData(shape, clientlib::EncryptedScalarElementType,
clientlib::EncryptedScalarElementWidth));
TensorData &values_and_sizes = ciphertextBuffers.back().getTensor();
OUTCOME_TRYV(keySet.encrypt_lwe(
pos, values_and_sizes.getElementPointer<decrypted_scalar_t>(0), arg));
return outcome::success();
}
outcome::checked<void, StringError>
EncryptedArguments::checkPushTooManyArgs(KeySet &keySet) {
size_t arity = keySet.numInputs();
if (currentPos < arity) {
return outcome::success();
}
return StringError("function has arity ")
<< arity << " but is applied to too many arguments";
}
outcome::checked<void, StringError>
EncryptedArguments::checkAllArgs(KeySet &keySet) {
size_t arity = keySet.numInputs();
if (currentPos == arity) {
if (values.size() == arity) {
return outcome::success();
}
return StringError("function expects ")
<< arity << " arguments but has been called with " << currentPos
<< arity << " arguments but has been called with " << values.size()
<< " arguments";
}