// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #ifndef CONCRETELANG_CLIENTLIB_ENCRYPTED_ARGS_H #define CONCRETELANG_CLIENTLIB_ENCRYPTED_ARGS_H #include #include "boost/outcome.h" #include "../Common/Error.h" #include "concretelang/ClientLib/ClientParameters.h" #include "concretelang/ClientLib/KeySet.h" #include "concretelang/ClientLib/Types.h" #include "concretelang/Common/BitsSize.h" namespace concretelang { namespace clientlib { using concretelang::error::StringError; class PublicArguments; /// Temporary object used to hold and encrypt parameters before calling a /// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...). /// Otherwise convert it to a PublicArguments and use /// serializeCall(PublicArguments, KeySet). class EncryptedArguments { public: EncryptedArguments() : currentPos(0) {} /// Encrypts args thanks the given KeySet and pack the encrypted arguments to /// an EncryptedArguments template static outcome::checked, StringError> create(KeySet &keySet, Args... args) { auto encryptedArgs = std::make_unique(); OUTCOME_TRYV(encryptedArgs->pushArgs(keySet, args...)); return std::move(encryptedArgs); } template static outcome::checked, StringError> create(KeySet &keySet, const llvm::ArrayRef args) { auto encryptedArgs = EncryptedArguments::empty(); for (size_t i = 0; i < args.size(); i++) { OUTCOME_TRYV(encryptedArgs->pushArg(args[i], keySet)); } OUTCOME_TRYV(encryptedArgs->checkAllArgs(keySet)); return std::move(encryptedArgs); } static std::unique_ptr empty() { return std::make_unique(); } /// Export encrypted arguments as public arguments, reset the encrypted /// arguments, i.e. move all buffers to the PublicArguments and reset the /// positional counter. outcome::checked, StringError> exportPublicArguments(ClientParameters clientParameters, RuntimeContext runtimeContext); /// Check that all arguments as been pushed. // TODO: Remove public method here outcome::checked checkAllArgs(KeySet &keySet); public: /// Add a uint64_t scalar argument. outcome::checked pushArg(uint64_t arg, KeySet &keySet); /// Add a vector-tensor argument. outcome::checked pushArg(std::vector arg, KeySet &keySet) { return pushArg((uint8_t *)arg.data(), llvm::ArrayRef{(int64_t)arg.size()}, keySet); } /// Add a 1D tensor argument with data and size of the dimension. template outcome::checked pushArg(const T *data, int64_t dim1, KeySet &keySet) { return pushArg(std::vector(data, data + dim1), keySet); } /// Add a 1D tensor argument. template outcome::checked pushArg(std::array arg, KeySet &keySet) { return pushArg((uint8_t *)arg.data(), llvm::ArrayRef{size}, keySet); } /// Add a 2D tensor argument. template outcome::checked pushArg(std::array, size0> arg, KeySet &keySet) { return pushArg((uint8_t *)arg.data(), llvm::ArrayRef{size0, size1}, keySet); } /// Add a 3D tensor argument. template outcome::checked pushArg(std::array, size1>, size0> arg, KeySet &keySet) { return pushArg((uint8_t *)arg.data(), llvm::ArrayRef{size0, size1, size2}, keySet); } // Generalize by computing shape by template recursion /// Set a argument at the given pos as a 1D tensor of T. template outcome::checked pushArg(T *data, int64_t dim1, KeySet &keySet) { return pushArg(data, llvm::ArrayRef(&dim1, 1), keySet); } /// Set a argument at the given pos as a tensor of T. template outcome::checked pushArg(T *data, llvm::ArrayRef shape, KeySet &keySet) { return pushArg(static_cast(data), shape, keySet); } template outcome::checked pushArg(const T *data, llvm::ArrayRef shape, KeySet &keySet) { OUTCOME_TRYV(checkPushTooManyArgs(keySet)); auto pos = currentPos; CircuitGate input = keySet.inputGate(pos); // Check the width of data if (input.shape.width > 64) { return StringError("argument #") << pos << " width > 64 bits is not supported"; } // Check the shape of tensor if (input.shape.dimensions.empty()) { return StringError("argument #") << pos << "is not a tensor"; } if (shape.size() != input.shape.dimensions.size()) { return StringError("argument #") << pos << "has not the expected number of dimension, got " << shape.size() << " expected " << input.shape.dimensions.size(); } // Check shape for (size_t i = 0; i < shape.size(); i++) { if (shape[i] != input.shape.dimensions[i]) { return StringError("argument #") << pos << " has not the expected dimension #" << i << " , got " << shape[i] << " expected " << input.shape.dimensions[i]; } } // Set sizes std::vector sizes = keySet.clientParameters().bufferShape(input); if (input.encryption.hasValue()) { TensorData td(sizes, EncryptedScalarElementType, EncryptedScalarElementWidth); auto lweSize = keySet.clientParameters().lweBufferSize(input); for (size_t i = 0, offset = 0; i < input.shape.size; i++, offset += lweSize) { OUTCOME_TRYV(keySet.encrypt_lwe( pos, td.getElementPointer(offset), data[i])); } ciphertextBuffers.push_back(std::move(td)); } else { auto bitsPerValue = bitWidthAsWord(input.shape.width); TensorData td(sizes, bitsPerValue, input.shape.sign); llvm::ArrayRef values(data, TensorData::getNumElements(sizes)); td.bulkAssign(values); ciphertextBuffers.push_back(std::move(td)); } TensorData &td = ciphertextBuffers.back().getTensor(); // allocated preparedArgs.push_back(nullptr); // aligned preparedArgs.push_back(td.getValuesAsOpaquePointer()); // offset preparedArgs.push_back((void *)0); // sizes for (size_t size : td.getDimensions()) { preparedArgs.push_back((void *)size); } // Set the stride for each dimension, equal to the product of the // following dimensions. int64_t stride = td.getNumElements(); for (size_t size : td.getDimensions()) { stride = (size == 0 ? 0 : (stride / size)); preparedArgs.push_back((void *)stride); } currentPos++; return outcome::success(); } /// Recursive case for scalars: extract first scalar argument from /// parameter pack and forward rest template outcome::checked pushArgs(KeySet &keySet, Arg0 arg0, OtherArgs... others) { OUTCOME_TRYV(pushArg(arg0, keySet)); return pushArgs(keySet, others...); } /// Recursive case for tensors: extract pointer and size from /// parameter pack and forward rest template outcome::checked pushArgs(KeySet &keySet, Arg0 *arg0, size_t size, OtherArgs... others) { OUTCOME_TRYV(pushArg(arg0, size, keySet)); return pushArgs(keySet, others...); } /// Terminal case of pushArgs outcome::checked pushArgs(KeySet &keySet) { return checkAllArgs(keySet); } private: outcome::checked checkPushTooManyArgs(KeySet &keySet); private: /// Position of the next pushed argument size_t currentPos; std::vector preparedArgs; /// Store buffers of ciphertexts std::vector ciphertextBuffers; }; } // namespace clientlib } // namespace concretelang #endif