// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #ifndef CONCRETELANG_CLIENTLIB_ENCRYPTED_ARGS_H #define CONCRETELANG_CLIENTLIB_ENCRYPTED_ARGS_H #include #include "boost/outcome.h" #include "../Common/Error.h" #include "concretelang/ClientLib/ClientParameters.h" #include "concretelang/ClientLib/KeySet.h" #include "concretelang/ClientLib/Types.h" #include "concretelang/Common/BitsSize.h" namespace concretelang { namespace clientlib { using concretelang::error::StringError; class PublicArguments; inline size_t bitWidthAsWord(size_t exactBitWidth) { if (exactBitWidth <= 8) return 8; if (exactBitWidth <= 16) return 16; if (exactBitWidth <= 32) return 32; if (exactBitWidth <= 64) return 64; assert(false && "Bit witdh > 64 not supported"); } /// Temporary object used to hold and encrypt parameters before calling a /// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...). /// Otherwise convert it to a PublicArguments and use /// serializeCall(PublicArguments, KeySet). class EncryptedArguments { public: EncryptedArguments() : currentPos(0) {} /// Encrypts args thanks the given KeySet and pack the encrypted arguments to /// an EncryptedArguments template static outcome::checked, StringError> create(KeySet &keySet, Args... args) { auto encryptedArgs = std::make_unique(); OUTCOME_TRYV(encryptedArgs->pushArgs(keySet, args...)); return std::move(encryptedArgs); } template static outcome::checked, StringError> create(KeySet &keySet, const llvm::ArrayRef args) { auto encryptedArgs = EncryptedArguments::empty(); for (size_t i = 0; i < args.size(); i++) { OUTCOME_TRYV(encryptedArgs->pushArg(args[i], keySet)); } OUTCOME_TRYV(encryptedArgs->checkAllArgs(keySet)); return std::move(encryptedArgs); } static std::unique_ptr empty() { return std::make_unique(); } /// Export encrypted arguments as public arguments, reset the encrypted /// arguments, i.e. move all buffers to the PublicArguments and reset the /// positional counter. outcome::checked, StringError> exportPublicArguments(ClientParameters clientParameters, RuntimeContext runtimeContext); /// Check that all arguments as been pushed. // TODO: Remove public method here outcome::checked checkAllArgs(KeySet &keySet); public: /// Add a uint64_t scalar argument. outcome::checked pushArg(uint64_t arg, KeySet &keySet); /// Add a vector-tensor argument. outcome::checked pushArg(std::vector arg, KeySet &keySet) { return pushArg((uint8_t *)arg.data(), llvm::ArrayRef{(int64_t)arg.size()}, keySet); } /// Add a 1D tensor argument with data and size of the dimension. template outcome::checked pushArg(const T *data, int64_t dim1, KeySet &keySet) { return pushArg(std::vector(data, data + dim1), keySet); } /// Add a 1D tensor argument. template outcome::checked pushArg(std::array arg, KeySet &keySet) { return pushArg((uint8_t *)arg.data(), llvm::ArrayRef{size}, keySet); } /// Add a 2D tensor argument. template outcome::checked pushArg(std::array, size0> arg, KeySet &keySet) { return pushArg((uint8_t *)arg.data(), llvm::ArrayRef{size0, size1}, keySet); } /// Add a 3D tensor argument. template outcome::checked pushArg(std::array, size1>, size0> arg, KeySet &keySet) { return pushArg((uint8_t *)arg.data(), llvm::ArrayRef{size0, size1, size2}, keySet); } // Generalize by computing shape by template recursion /// Set a argument at the given pos as a 1D tensor of T. template outcome::checked pushArg(T *data, int64_t dim1, KeySet &keySet) { return pushArg(data, llvm::ArrayRef(&dim1, 1), keySet); } /// Set a argument at the given pos as a tensor of T. template outcome::checked pushArg(T *data, llvm::ArrayRef shape, KeySet &keySet) { return pushArg(static_cast(data), shape, keySet); } template outcome::checked pushArg(const T *data, llvm::ArrayRef shape, KeySet &keySet) { OUTCOME_TRYV(checkPushTooManyArgs(keySet)); auto pos = currentPos; CircuitGate input = keySet.inputGate(pos); // Check the width of data if (input.shape.width > 64) { return StringError("argument #") << pos << " width > 64 bits is not supported"; } // Check the shape of tensor if (input.shape.dimensions.empty()) { return StringError("argument #") << pos << "is not a tensor"; } if (shape.size() != input.shape.dimensions.size()) { return StringError("argument #") << pos << "has not the expected number of dimension, got " << shape.size() << " expected " << input.shape.dimensions.size(); } // Allocate empty ciphertextBuffers.resize(ciphertextBuffers.size() + 1); TensorData &values_and_sizes = ciphertextBuffers.back(); // Check shape for (size_t i = 0; i < shape.size(); i++) { if (shape[i] != input.shape.dimensions[i]) { return StringError("argument #") << pos << " has not the expected dimension #" << i << " , got " << shape[i] << " expected " << input.shape.dimensions[i]; } } // Set sizes values_and_sizes.sizes = keySet.clientParameters().bufferShape(input); if (input.encryption.hasValue()) { // Allocate values values_and_sizes.values.resize( keySet.clientParameters().bufferSize(input)); auto lweSize = keySet.clientParameters().lweBufferSize(input); auto &values = values_and_sizes.values; for (size_t i = 0, offset = 0; i < input.shape.size; i++, offset += lweSize) { OUTCOME_TRYV(keySet.encrypt_lwe(pos, values.data() + offset, data[i])); } } else { // Allocate values take care of gate bitwidth auto bitsPerValue = bitWidthAsWord(input.shape.width); auto bytesPerValue = bitsPerValue / 8; auto nbWordPerValue = 8 / bytesPerValue; // ceil division auto size = (input.shape.size / nbWordPerValue) + (input.shape.size % nbWordPerValue != 0); size = size == 0 ? 1 : size; values_and_sizes.values.resize(size); auto v = (uint8_t *)values_and_sizes.values.data(); for (size_t i = 0; i < input.shape.size; i++) { auto dst = v + i * bytesPerValue; auto src = (const uint8_t *)&data[i]; for (size_t j = 0; j < bytesPerValue; j++) { dst[j] = src[j]; } } } // allocated preparedArgs.push_back(nullptr); // aligned preparedArgs.push_back((void *)values_and_sizes.values.data()); // offset preparedArgs.push_back((void *)0); // sizes for (size_t size : values_and_sizes.sizes) { preparedArgs.push_back((void *)size); } // Set the stride for each dimension, equal to the product of the // following dimensions. int64_t stride = values_and_sizes.length(); for (size_t size : values_and_sizes.sizes) { stride = (size == 0 ? 0 : (stride / size)); preparedArgs.push_back((void *)stride); } currentPos++; return outcome::success(); } /// Recursive case for scalars: extract first scalar argument from /// parameter pack and forward rest template outcome::checked pushArgs(KeySet &keySet, Arg0 arg0, OtherArgs... others) { OUTCOME_TRYV(pushArg(arg0, keySet)); return pushArgs(keySet, others...); } /// Recursive case for tensors: extract pointer and size from /// parameter pack and forward rest template outcome::checked pushArgs(KeySet &keySet, Arg0 *arg0, size_t size, OtherArgs... others) { OUTCOME_TRYV(pushArg(arg0, size, keySet)); return pushArgs(keySet, others...); } /// Terminal case of pushArgs outcome::checked pushArgs(KeySet &keySet) { return checkAllArgs(keySet); } private: outcome::checked checkPushTooManyArgs(KeySet &keySet); private: /// Position of the next pushed argument size_t currentPos; std::vector preparedArgs; /// Store buffers of ciphertexts std::vector ciphertextBuffers; }; } // namespace clientlib } // namespace concretelang #endif