// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt // for license information. #ifndef CONCRETELANG_CLIENTLIB_ENCRYPTED_ARGS_H #define CONCRETELANG_CLIENTLIB_ENCRYPTED_ARGS_H #include #include "boost/outcome.h" #include "../Common/Error.h" #include "concretelang/ClientLib/ClientParameters.h" #include "concretelang/ClientLib/KeySet.h" #include "concretelang/ClientLib/Types.h" #include "concretelang/Common/BitsSize.h" namespace concretelang { namespace clientlib { using concretelang::error::StringError; class PublicArguments; class EncryptedArguments { /// Temporary object used to hold and encrypt parameters before calling a /// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...). /// Otherwise convert it to a PublicArguments and use /// serializeCall(PublicArguments, KeySet). public: EncryptedArguments() : currentPos(0) {} /// Encrypts args thanks the given KeySet and pack the encrypted arguments to /// an EncryptedArguments template static outcome::checked, StringError> create(KeySet &keySet, Args... args) { auto arguments = std::make_unique(); OUTCOME_TRYV(arguments->pushArgs(keySet, args...)); return arguments; } static std::unique_ptr empty() { return std::make_unique(); } /// Export encrypted arguments as public arguments, reset the encrypted /// arguments, i.e. move all buffers to the PublicArguments and reset the /// positional counter. outcome::checked, StringError> exportPublicArguments(ClientParameters clientParameters, RuntimeContext runtimeContext); /// Check that all arguments as been pushed. /// TODO: Remove public method here outcome::checked checkAllArgs(KeySet &keySet); public: // Add a uint64_t scalar argument. outcome::checked pushArg(uint64_t arg, KeySet &keySet); /// Add a vector-tensor argument. outcome::checked pushArg(std::vector arg, KeySet &keySet); // Add a 1D tensor argument with data and size of the dimension. template outcome::checked pushArg(const T *data, int64_t dim1, KeySet &keySet) { return pushArg(std::vector(data, data + dim1), keySet); } // Add a tensor argument. template outcome::checked pushArg(const T *data, llvm::ArrayRef shape, KeySet &keySet) { return pushArg(8 * sizeof(T), static_cast(data), shape, keySet); } /// Add a 1D tensor argument. template outcome::checked pushArg(std::array arg, KeySet &keySet) { return pushArg(8, (void *)arg.data(), {size}, keySet); } /// Add a 2D tensor argument. template outcome::checked pushArg(std::array, size0> arg, KeySet &keySet) { return pushArg(8, (void *)arg.data(), {size0, size1}, keySet); } /// Add a 3D tensor argument. template outcome::checked pushArg(std::array, size1>, size0> arg, KeySet &keySet) { return pushArg(8, (void *)arg.data(), {size0, size1, size2}, keySet); } // Generalize by computing shape by template recursion // Set a argument at the given pos as a 1D tensor of T. template outcome::checked pushArg(T *data, int64_t dim1, KeySet &keySet) { return pushArg(data, llvm::ArrayRef(&dim1, 1), keySet); } // Set a argument at the given pos as a tensor of T. template outcome::checked pushArg(T *data, llvm::ArrayRef shape, KeySet &keySet) { return pushArg(8 * sizeof(T), static_cast(data), shape, keySet); } outcome::checked pushArg(size_t width, const void *data, llvm::ArrayRef shape, KeySet &keySet); // Recursive case for scalars: extract first scalar argument from // parameter pack and forward rest template outcome::checked pushArgs(KeySet &keySet, Arg0 arg0, OtherArgs... others) { OUTCOME_TRYV(pushArg(arg0, keySet)); return pushArgs(keySet, others...); } // Recursive case for tensors: extract pointer and size from // parameter pack and forward rest template outcome::checked pushArgs(KeySet &keySet, Arg0 *arg0, size_t size, OtherArgs... others) { OUTCOME_TRYV(pushArg(arg0, size, keySet)); return pushArgs(keySet, others...); } // Terminal case of pushArgs outcome::checked pushArgs(KeySet &keySet) { return checkAllArgs(keySet); } private: outcome::checked checkPushTooManyArgs(KeySet &keySet); private: // Position of the next pushed argument size_t currentPos; std::vector preparedArgs; // Store buffers of ciphertexts std::vector ciphertextBuffers; }; } // namespace clientlib } // namespace concretelang #endif