// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt // for license information. #ifndef CONCRETELANG_CLIENTLIB_PUBLIC_ARGUMENTS_H #define CONCRETELANG_CLIENTLIB_PUBLIC_ARGUMENTS_H #include #include "boost/outcome.h" #include "concretelang/ClientLib/ClientParameters.h" #include "concretelang/ClientLib/EncryptedArguments.h" #include "concretelang/ClientLib/Types.h" #include "concretelang/Common/Error.h" #include "concretelang/Runtime/context.h" namespace concretelang { namespace serverlib { class ServerLambda; } } // namespace concretelang namespace mlir { namespace concretelang { class JITLambda; } } // namespace mlir namespace concretelang { namespace clientlib { using concretelang::error::StringError; class EncryptedArguments; class PublicArguments { /// PublicArguments will be sended to the server. It includes encrypted /// arguments and public keys. public: PublicArguments(const ClientParameters &clientParameters, RuntimeContext runtimeContext, bool clearRuntimeContext, std::vector &&preparedArgs, std::vector &&ciphertextBuffers); ~PublicArguments(); PublicArguments(PublicArguments &other) = delete; PublicArguments(PublicArguments &&other) = delete; static outcome::checked, StringError> unserialize(ClientParameters &expectedParams, std::istream &istream); outcome::checked serialize(std::ostream &ostream); private: friend class ::concretelang::serverlib::ServerLambda; friend class ::mlir::concretelang::JITLambda; outcome::checked unserializeArgs(std::istream &istream); ClientParameters clientParameters; RuntimeContext runtimeContext; std::vector preparedArgs; // Store buffers of ciphertexts std::vector ciphertextBuffers; // Indicates if this public argument own the runtime keys. bool clearRuntimeContext; }; struct PublicResult { /// PublicResult is a result of a ServerLambda call which contains encrypted /// results. PublicResult(const ClientParameters &clientParameters, std::vector buffers = {}) : clientParameters(clientParameters), buffers(buffers){}; PublicResult(PublicResult &) = delete; /// Create a public result from buffers. static std::unique_ptr fromBuffers(const ClientParameters &clientParameters, std::vector buffers) { return std::make_unique(clientParameters, buffers); } /// Unserialize from a input stream. outcome::checked unserialize(std::istream &istream); /// Serialize into an output stream. outcome::checked serialize(std::ostream &ostream); /// Get the result at `pos` as a vector, if the result is a scalar returns a /// vector of size 1. Decryption happens if the result is encrypted. // outcome::checked, StringError> // asClearTextVector(KeySet &keySet, size_t pos); template outcome::checked, StringError> asClearTextVector(KeySet &keySet, size_t pos) { OUTCOME_TRY(auto gate, clientParameters.ouput(pos)); if (!gate.isEncrypted()) { std::vector result; result.reserve(buffers[pos].values.size()); std::copy(buffers[pos].values.begin(), buffers[pos].values.end(), std::back_inserter(result)); return result; } auto buffer = buffers[pos]; auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize(); std::vector decryptedValues(buffer.length() / lweSize); for (size_t i = 0; i < decryptedValues.size(); i++) { auto ciphertext = &buffer.values[i * lweSize]; uint64_t decrypted; OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertext, decrypted)); decryptedValues[i] = decrypted; } return decryptedValues; } // private: TODO tmp friend class ::concretelang::serverlib::ServerLambda; ClientParameters clientParameters; std::vector buffers; }; /// Helper function to convert from a scalar to TensorData TensorData tensorDataFromScalar(uint64_t value); /// Helper function to convert from MemRefDescriptor to /// TensorData TensorData tensorDataFromMemRef(size_t memref_rank, encrypted_scalars_t allocated, encrypted_scalars_t aligned, size_t offset, size_t *sizes, size_t *strides); } // namespace clientlib } // namespace concretelang #endif