From f7f94a166336bb62d0ffb0068a696d87bdea543c Mon Sep 17 00:00:00 2001 From: Bourgerie Quentin Date: Tue, 30 May 2023 14:30:46 +0200 Subject: [PATCH] feat(compiler/client-lib): Implement ValueExporter to allows partial encryption --- compilers/concrete-compiler/compiler/Makefile | 4 +- .../ClientLib/EncryptedArguments.h | 209 ++++++++++++------ .../concretelang/ClientLib/PublicArguments.h | 149 ++++++++----- .../include/concretelang/ClientLib/Types.h | 1 + .../lib/ClientLib/EncryptedArguments.cpp | 53 +---- 5 files changed, 242 insertions(+), 174 deletions(-) diff --git a/compilers/concrete-compiler/compiler/Makefile b/compilers/concrete-compiler/compiler/Makefile index a015e493d..c7c4b5473 100644 --- a/compilers/concrete-compiler/compiler/Makefile +++ b/compilers/concrete-compiler/compiler/Makefile @@ -295,9 +295,7 @@ generate-cpu-tests: \ SECURITY_TO_TEST=128 OPTIMIZATION_STRATEGY_TO_TEST=dag-mono dag-multi -PARALLEL_END_2_END_TESTS= end_to_end_jit_test \ - end_to_end_jit_lambda \ - end_to_end_jit_lambda +PARALLEL_END_2_END_TESTS= end_to_end_jit_test end_to_end_jit_lambda run-end-to-end-tests: $(GTEST_PARALLEL_PY) build-end-to-end-tests generate-cpu-tests $(foreach TEST,$(PARALLEL_END_2_END_TESTS), \ $(GTEST_PARALLEL_CMD) $(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/$(TEST);) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/EncryptedArguments.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/EncryptedArguments.h index ab3e0cfcd..6a88d9392 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/EncryptedArguments.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/EncryptedArguments.h @@ -23,6 +23,139 @@ using concretelang::error::StringError; class PublicArguments; +/// @brief The ArgumentsExporter allows to transform clear +/// arguments to the one expected by a server lambda. +class ValueExporter { + +public: + /// @brief + /// @param keySet + /// @param clientParameters + // TODO: Get rid of the reference here could make troubles (see for KeySet + // copy constructor or shared pointers) + ValueExporter(KeySet &keySet, ClientParameters clientParameters) + : _keySet(keySet), _clientParameters(clientParameters) {} + + /// @brief Export a scalar 64 bits integer to a concreteprocol::Value + /// @param arg An 64 bits integer + /// @param argPos The position of the argument to export + /// @return Either the exported value ready to be sent to the server or an + /// error if the gate doesn't match the expected argument. + outcome::checked exportValue(uint64_t arg, + size_t argPos) { + OUTCOME_TRY(auto gate, _clientParameters.input(argPos)); + if (gate.shape.size != 0) { + return StringError("argument #") << argPos << " is not a scalar"; + } + if (gate.encryption.has_value()) { + return exportEncryptValue(arg, gate, argPos); + } + return exportClearValue(arg); + } + + /// @brief Export a tensor like buffer of values to a serializable value + /// @tparam T The type of values hold by the buffer + /// @param arg A pointer to a memory area where the values are stored + /// @param shape The shape of the tensor + /// @param argPos The position of the argument to export + /// @return Either the exported value ready to be sent to the server or an + /// error if the gate doesn't match the expected argument. + template + outcome::checked + exportValue(const T *arg, llvm::ArrayRef shape, size_t argPos) { + OUTCOME_TRY(auto gate, _clientParameters.input(argPos)); + OUTCOME_TRYV(checkShape(shape, gate.shape, argPos)); + if (gate.encryption.has_value()) { + return exportEncryptTensor(arg, shape, gate, argPos); + } + return exportClearTensor(arg, shape, gate); + } + +private: + /// Export a 64bits integer to a serializable value + outcome::checked + exportClearValue(uint64_t arg) { + return ScalarData(arg); + } + + /// Encrypt and export a 64bits integer to a serializale value + outcome::checked + exportEncryptValue(uint64_t arg, CircuitGate &gate, size_t argPos) { + std::vector shape = _clientParameters.bufferShape(gate); + + // Create and allocate the TensorData that will holds encrypted value + TensorData td(shape, clientlib::EncryptedScalarElementType, + clientlib::EncryptedScalarElementWidth); + + // Encrypt the value + OUTCOME_TRYV( + _keySet.encrypt_lwe(argPos, td.getElementPointer(0), arg)); + return std::move(td); + } + + /// Export a tensor like buffer to a serializable value + template + outcome::checked + exportClearTensor(const T *arg, llvm::ArrayRef shape, + CircuitGate &gate) { + auto bitsPerValue = bitWidthAsWord(gate.shape.width); + auto sizes = _clientParameters.bufferShape(gate); + TensorData td(sizes, bitsPerValue, gate.shape.sign); + llvm::ArrayRef values(arg, TensorData::getNumElements(sizes)); + td.bulkAssign(values); + return std::move(td); + } + + /// Export and encrypt a tensor like buffer to a serializable value + template + outcome::checked + exportEncryptTensor(const T *arg, llvm::ArrayRef shape, + CircuitGate &gate, size_t argPos) { + // Create and allocate the TensorData that will holds encrypted values + auto sizes = _clientParameters.bufferShape(gate); + TensorData td(sizes, EncryptedScalarElementType, + EncryptedScalarElementWidth); + + // Iterate over values and encrypt at the right place the value + auto lweSize = _clientParameters.lweBufferSize(gate); + for (size_t i = 0, offset = 0; i < gate.shape.size; + i++, offset += lweSize) { + OUTCOME_TRYV(_keySet.encrypt_lwe( + argPos, td.getElementPointer(offset), arg[i])); + } + return std::move(td); + } + + static outcome::checked + checkShape(llvm::ArrayRef shape, CircuitGateShape expected, + size_t argPos) { + // Check the shape of tensor + if (expected.dimensions.empty()) { + return StringError("argument #") << argPos << "is not a tensor"; + } + if (shape.size() != expected.dimensions.size()) { + return StringError("argument #") + << argPos << "has not the expected number of dimension, got " + << shape.size() << " expected " << expected.dimensions.size(); + } + + // Check shape + for (size_t i = 0; i < shape.size(); i++) { + if (shape[i] != expected.dimensions[i]) { + return StringError("argument #") + << argPos << " has not the expected dimension #" << i + << " , got " << shape[i] << " expected " + << expected.dimensions[i]; + } + } + return outcome::success(); + } + +private: + KeySet &_keySet; + ClientParameters _clientParameters; +}; + /// Temporary object used to hold and encrypt parameters before calling a /// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...). /// Otherwise convert it to a PublicArguments and use @@ -30,10 +163,10 @@ class PublicArguments; class EncryptedArguments { public: - EncryptedArguments() : currentPos(0) {} + EncryptedArguments() {} - /// Encrypts args thanks the given KeySet and pack the encrypted arguments to - /// an EncryptedArguments + /// Encrypts args thanks the given KeySet and pack the encrypted arguments + /// to an EncryptedArguments template static outcome::checked, StringError> create(KeySet &keySet, Args... args) { @@ -69,7 +202,12 @@ public: public: /// Add a uint64_t scalar argument. - outcome::checked pushArg(uint64_t arg, KeySet &keySet); + outcome::checked pushArg(uint64_t arg, KeySet &keySet) { + ValueExporter exporter(keySet, keySet.clientParameters()); + OUTCOME_TRY(auto value, exporter.exportValue(arg, values.size())); + values.push_back(std::move(value)); + return outcome::success(); + } /// Add a vector-tensor argument. outcome::checked pushArg(std::vector arg, @@ -129,58 +267,9 @@ public: template outcome::checked pushArg(const T *data, llvm::ArrayRef shape, KeySet &keySet) { - OUTCOME_TRYV(checkPushTooManyArgs(keySet)); - auto pos = currentPos; - CircuitGate input = keySet.inputGate(pos); - // Check the width of data - if (input.shape.width > 64) { - return StringError("argument #") - << pos << " width > 64 bits is not supported"; - } - // Check the shape of tensor - if (input.shape.dimensions.empty()) { - return StringError("argument #") << pos << "is not a tensor"; - } - if (shape.size() != input.shape.dimensions.size()) { - return StringError("argument #") - << pos << "has not the expected number of dimension, got " - << shape.size() << " expected " << input.shape.dimensions.size(); - } - - // Check shape - for (size_t i = 0; i < shape.size(); i++) { - if (shape[i] != input.shape.dimensions[i]) { - return StringError("argument #") - << pos << " has not the expected dimension #" << i << " , got " - << shape[i] << " expected " << input.shape.dimensions[i]; - } - } - - // Set sizes - std::vector sizes = keySet.clientParameters().bufferShape(input); - - if (input.encryption.has_value()) { - TensorData td(sizes, EncryptedScalarElementType, - EncryptedScalarElementWidth); - - auto lweSize = keySet.clientParameters().lweBufferSize(input); - - for (size_t i = 0, offset = 0; i < input.shape.size; - i++, offset += lweSize) { - OUTCOME_TRYV(keySet.encrypt_lwe( - pos, td.getElementPointer(offset), data[i])); - } - ciphertextBuffers.push_back(std::move(td)); - } else { - auto bitsPerValue = bitWidthAsWord(input.shape.width); - - TensorData td(sizes, bitsPerValue, input.shape.sign); - llvm::ArrayRef values(data, TensorData::getNumElements(sizes)); - td.bulkAssign(values); - ciphertextBuffers.push_back(std::move(td)); - } - - currentPos++; + ValueExporter exporter(keySet, keySet.clientParameters()); + OUTCOME_TRY(auto value, exporter.exportValue(data, shape, values.size())); + values.push_back(std::move(value)); return outcome::success(); } @@ -208,14 +297,8 @@ public: } private: - outcome::checked checkPushTooManyArgs(KeySet &keySet); - -private: - /// Position of the next pushed argument - size_t currentPos; - /// Store buffers of ciphertexts - std::vector ciphertextBuffers; + std::vector values; }; } // namespace clientlib diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/PublicArguments.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/PublicArguments.h index a1ae46a2a..a8376b148 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/PublicArguments.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/PublicArguments.h @@ -32,6 +32,81 @@ using concretelang::error::StringError; class EncryptedArguments; +/// @brief allows to transform a serializable value into a clear value +class ValueDecrypter { +public: + ValueDecrypter(KeySet &keySet, ClientParameters clientParameters) + : _keySet(keySet), _clientParameters(clientParameters) {} + + /// @brief Transforms a FHE value into a clear scalar value + /// @tparam T The type of the clear scalar value + /// @param value The value to decrypt + /// @param pos The position of the argument + /// @return Either the decrypted value or an error if the gate doesn't match + /// the expected result. + template + outcome::checked decrypt(ScalarOrTensorData &value, + size_t pos) { + OUTCOME_TRY(auto gate, _clientParameters.ouput(pos)); + if (!gate.isEncrypted()) + return value.getScalar().getValue(); + + auto &buffer = value.getTensor(); + + auto ciphertext = buffer.getOpaqueElementPointer(0); + uint64_t decrypted; + + // Convert to uint64_t* as required by `KeySet::decrypt_lwe` + // FIXME: this may break alignment restrictions on some + // architectures + auto ciphertextu64 = reinterpret_cast(ciphertext); + OUTCOME_TRYV(_keySet.decrypt_lwe(0, ciphertextu64, decrypted)); + + return (T)decrypted; + } + + /// @brief Transforms a FHE value into a vector of clear value + /// @tparam T The type of the clear scalar value + /// @param value The value to decrypt + /// @param pos The position of the argument + /// @return Either the decrypted value or an error if the gate doesn't match + /// the expected result. + template + outcome::checked, StringError> + decryptTensor(ScalarOrTensorData &value, size_t pos) { + OUTCOME_TRY(auto gate, _clientParameters.ouput(pos)); + if (!gate.isEncrypted()) + return value.getTensor().asFlatVector(); + + auto &buffer = value.getTensor(); + auto lweSize = _clientParameters.lweBufferSize(gate); + + std::vector decryptedValues(buffer.length() / lweSize); + for (size_t i = 0; i < decryptedValues.size(); i++) { + auto ciphertext = buffer.getOpaqueElementPointer(i * lweSize); + uint64_t decrypted; + + // Convert to uint64_t* as required by `KeySet::decrypt_lwe` + // FIXME: this may break alignment restrictions on some + // architectures + auto ciphertextu64 = reinterpret_cast(ciphertext); + OUTCOME_TRYV(_keySet.decrypt_lwe(0, ciphertextu64, decrypted)); + decryptedValues[i] = decrypted; + } + return decryptedValues; + } + + /// Return the shape of the clear tensor of a result. + outcome::checked, StringError> getShape(size_t pos) { + OUTCOME_TRY(auto gate, _clientParameters.ouput(pos)); + return gate.shape.dimensions; + } + +private: + KeySet &_keySet; + ClientParameters _clientParameters; +}; + /// PublicArguments will be sended to the server. It includes encrypted /// arguments and public keys. class PublicArguments { @@ -71,6 +146,17 @@ struct PublicResult { PublicResult(PublicResult &) = delete; + /// @brief Return a value from the PublicResult + /// @param argPos The position of the value in the PublicResult + /// @return Either the value or an error if there are no value at this + /// position + outcome::checked getValue(size_t argPos) { + if (argPos >= buffers.size()) { + return StringError("result #") << argPos << " does not exists"; + } + return std::move(buffers[argPos]); + } + /// Create a public result from buffers. static std::unique_ptr fromBuffers(const ClientParameters &clientParameters, @@ -90,49 +176,14 @@ struct PublicResult { /// Serialize into an output stream. outcome::checked serialize(std::ostream &ostream); - /// Get the original integer that was decomposed into chunks of `chunkWidth` - /// bits each - uint64_t fromChunks(std::vector chunks, unsigned int chunkWidth) { - uint64_t value = 0; - uint64_t mask = (1 << chunkWidth) - 1; - for (size_t i = 0; i < chunks.size(); i++) { - auto chunk = chunks[i] & mask; - value += chunk << (chunkWidth * i); - } - return value; - } - /// Get the result at `pos` as a scalar. Decryption happens if the /// result is encrypted. template outcome::checked asClearTextScalar(KeySet &keySet, size_t pos) { - OUTCOME_TRY(auto gate, clientParameters.ouput(pos)); - if (!gate.isEncrypted()) - return buffers[pos].getScalar().getValue(); - - // Chunked integers are represented as tensors at a lower level, so we need - // to deal with them as tensors, then build the resulting scalar out of the - // tensor values - if (gate.chunkInfo.has_value()) { - OUTCOME_TRY(std::vector decryptedChunks, - this->asClearTextVector(keySet, pos)); - uint64_t decrypted = fromChunks(decryptedChunks, gate.chunkInfo->width); - return (T)decrypted; - } - - auto &buffer = buffers[pos].getTensor(); - - auto ciphertext = buffer.getOpaqueElementPointer(0); - uint64_t decrypted; - - // Convert to uint64_t* as required by `KeySet::decrypt_lwe` - // FIXME: this may break alignment restrictions on some - // architectures - auto ciphertextu64 = reinterpret_cast(ciphertext); - OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertextu64, decrypted)); - - return (T)decrypted; + ValueDecrypter decrypter(keySet, clientParameters); + auto &data = buffers[pos]; + return decrypter.template decrypt(data, pos); } /// Get the result at `pos` as a vector. Decryption happens if the @@ -140,26 +191,8 @@ struct PublicResult { template outcome::checked, StringError> asClearTextVector(KeySet &keySet, size_t pos) { - OUTCOME_TRY(auto gate, clientParameters.ouput(pos)); - if (!gate.isEncrypted()) - return buffers[pos].getTensor().asFlatVector(); - - auto &buffer = buffers[pos].getTensor(); - auto lweSize = clientParameters.lweBufferSize(gate); - - std::vector decryptedValues(buffer.length() / lweSize); - for (size_t i = 0; i < decryptedValues.size(); i++) { - auto ciphertext = buffer.getOpaqueElementPointer(i * lweSize); - uint64_t decrypted; - - // Convert to uint64_t* as required by `KeySet::decrypt_lwe` - // FIXME: this may break alignment restrictions on some - // architectures - auto ciphertextu64 = reinterpret_cast(ciphertext); - OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertextu64, decrypted)); - decryptedValues[i] = decrypted; - } - return decryptedValues; + ValueDecrypter decrypter(keySet, clientParameters); + return decrypter.template decryptTensor(buffers[pos], pos); } /// Return the shape of the clear tensor of a result. diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Types.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Types.h index 84874c772..bd76135e5 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Types.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Types.h @@ -842,6 +842,7 @@ protected: std::unique_ptr tensor; public: + ScalarOrTensorData(const ScalarOrTensorData &td) = delete; ScalarOrTensorData(ScalarOrTensorData &&td) : scalar(std::move(td.scalar)), tensor(std::move(td.tensor)) {} diff --git a/compilers/concrete-compiler/compiler/lib/ClientLib/EncryptedArguments.cpp b/compilers/concrete-compiler/compiler/lib/ClientLib/EncryptedArguments.cpp index f1677fb20..dc512cf89 100644 --- a/compilers/concrete-compiler/compiler/lib/ClientLib/EncryptedArguments.cpp +++ b/compilers/concrete-compiler/compiler/lib/ClientLib/EncryptedArguments.cpp @@ -13,8 +13,7 @@ using StringError = concretelang::error::StringError; outcome::checked, StringError> EncryptedArguments::exportPublicArguments(ClientParameters clientParameters) { - return std::make_unique(clientParameters, - std::move(ciphertextBuffers)); + return std::make_unique(clientParameters, std::move(values)); } /// Split the input integer into `size` chunks of `chunkWidth` bits each @@ -31,60 +30,14 @@ std::vector chunkInput(uint64_t value, size_t size, return chunks; } -outcome::checked -EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) { - OUTCOME_TRYV(checkPushTooManyArgs(keySet)); - OUTCOME_TRY(CircuitGate input, keySet.clientParameters().input(currentPos)); - // a chunked input is represented as a tensor in lower levels, and need to to - // splitted into chunks and encrypted as such - if (input.chunkInfo.has_value()) { - std::vector chunks = - chunkInput(arg, input.shape.size, input.chunkInfo.value().width); - return this->pushArg(chunks.data(), input.shape.size, keySet); - } - // we only increment if we don't forward the call to another pushArg method - auto pos = currentPos++; - if (input.shape.size != 0) { - return StringError("argument #") << pos << " is not a scalar"; - } - if (!input.encryption.has_value()) { - // clear scalar: just push the argument - ciphertextBuffers.push_back(ScalarData(arg)); - return outcome::success(); - } - - std::vector shape = keySet.clientParameters().bufferShape(input); - - // Allocate empty - ciphertextBuffers.emplace_back( - TensorData(shape, clientlib::EncryptedScalarElementType, - clientlib::EncryptedScalarElementWidth)); - TensorData &values_and_sizes = ciphertextBuffers.back().getTensor(); - - OUTCOME_TRYV(keySet.encrypt_lwe( - pos, values_and_sizes.getElementPointer(0), arg)); - - return outcome::success(); -} - -outcome::checked -EncryptedArguments::checkPushTooManyArgs(KeySet &keySet) { - size_t arity = keySet.numInputs(); - if (currentPos < arity) { - return outcome::success(); - } - return StringError("function has arity ") - << arity << " but is applied to too many arguments"; -} - outcome::checked EncryptedArguments::checkAllArgs(KeySet &keySet) { size_t arity = keySet.numInputs(); - if (currentPos == arity) { + if (values.size() == arity) { return outcome::success(); } return StringError("function expects ") - << arity << " arguments but has been called with " << currentPos + << arity << " arguments but has been called with " << values.size() << " arguments"; }