diff --git a/compiler/include/concretelang/ClientLib/ClientLambda.h b/compiler/include/concretelang/ClientLib/ClientLambda.h index b0a5a0280..e83a4acc0 100644 --- a/compiler/include/concretelang/ClientLib/ClientLambda.h +++ b/compiler/include/concretelang/ClientLib/ClientLambda.h @@ -9,7 +9,7 @@ #include #include "concretelang/ClientLib/ClientParameters.h" -#include "concretelang/ClientLib/EncryptedArgs.h" +#include "concretelang/ClientLib/EncryptedArguments.h" #include "concretelang/ClientLib/KeySet.h" #include "concretelang/ClientLib/KeySetCache.h" #include "concretelang/ClientLib/PublicArguments.h" @@ -33,15 +33,10 @@ class ClientLambda { /// Low-level class to create the client side view of a FHE function. public: virtual ~ClientLambda() = default; - static outcome::checked - /// Construct a ClientLambda from a ClientParameter file. - load(std::string funcName, std::string jsonPath); - /// Emit a call to the given ostream, no meta-date are include, so it's the - /// responsability of the the caller/callee to verify the add/verify the - /// function to be called. - outcome::checked - untypedSerializeCall(PublicArguments &publicArguments, std::ostream &ostream); + /// Construct a ClientLambda from a ClientParameter file. + static outcome::checked load(std::string funcName, + std::string jsonPath); /// Generate or get from cache a KeySet suitable for this ClientLambda outcome::checked, StringError> @@ -49,19 +44,19 @@ public: uint64_t seed_lsb); outcome::checked, StringError> - decryptReturnedValues(KeySet &keySet, std::istream &istream); + decryptReturnedValues(KeySet &keySet, PublicResult &result); outcome::checked - decryptReturnedScalar(KeySet &keySet, std::istream &istream); + decryptReturnedScalar(KeySet &keySet, PublicResult &result); outcome::checked - decryptReturnedTensor1(KeySet &keySet, std::istream &istream); + decryptReturnedTensor1(KeySet &keySet, PublicResult &result); outcome::checked - decryptReturnedTensor2(KeySet &keySet, std::istream &istream); + decryptReturnedTensor2(KeySet &keySet, PublicResult &result); outcome::checked - decryptReturnedTensor3(KeySet &keySet, std::istream &istream); + decryptReturnedTensor3(KeySet &keySet, PublicResult &result); public: ClientParameters clientParameters; @@ -70,7 +65,7 @@ public: template outcome::checked topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet, - std::istream &istream); + PublicResult &result); template class TypedClientLambda : public ClientLambda { @@ -90,19 +85,21 @@ public: serializeCall(Args... args, std::shared_ptr keySet, std::ostream &ostream) { OUTCOME_TRY(auto publicArguments, publicArguments(args..., keySet)); - return ClientLambda::untypedSerializeCall(publicArguments, ostream); + return publicArguments->serialize(ostream); } - outcome::checked + outcome::checked, StringError> publicArguments(Args... args, std::shared_ptr keySet) { - OUTCOME_TRY(auto clientArguments, EncryptedArgs::create(keySet, args...)); - return clientArguments->asPublicArguments(clientParameters, - keySet->runtimeContext()); + OUTCOME_TRY(auto clientArguments, + EncryptedArguments::create(keySet, args...)); + + return clientArguments->exportPublicArguments(clientParameters, + keySet->runtimeContext()); } - outcome::checked decryptReturned(KeySet &keySet, - std::istream &istream) { - return topLevelDecryptResult((*this), keySet, istream); + outcome::checked decryptResult(KeySet &keySet, + PublicResult &result) { + return topLevelDecryptResult((*this), keySet, result); } TypedClientLambda(ClientLambda &lambda) : ClientLambda(lambda) { @@ -115,25 +112,25 @@ protected: template friend outcome::checked topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet, - std::istream &istream); + PublicResult &result); }; template <> outcome::checked topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet, - std::istream &istream); + PublicResult &result); template <> outcome::checked topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet, - std::istream &istream); + PublicResult &result); template <> outcome::checked topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet, - std::istream &istream); + PublicResult &result); } // namespace clientlib } // namespace concretelang diff --git a/compiler/include/concretelang/ClientLib/EncryptedArgs.h b/compiler/include/concretelang/ClientLib/EncryptedArguments.h similarity index 80% rename from compiler/include/concretelang/ClientLib/EncryptedArgs.h rename to compiler/include/concretelang/ClientLib/EncryptedArguments.h index 6d3e93a68..f6edc3487 100644 --- a/compiler/include/concretelang/ClientLib/EncryptedArgs.h +++ b/compiler/include/concretelang/ClientLib/EncryptedArguments.h @@ -23,39 +23,52 @@ using concretelang::error::StringError; class PublicArguments; -class EncryptedArgs { +class EncryptedArguments { /// Temporary object used to hold and encrypt parameters before calling a /// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...). /// Otherwise convert it to a PublicArguments and use /// serializeCall(PublicArguments, KeySet). public: - // Create EncryptedArgument that use the given KeySet to perform - // encryption/decryption operations. + EncryptedArguments() : currentPos(0) {} + + /// Encrypts args thanks the given KeySet and pack the encrypted arguments to + /// an EncryptedArguments template - static outcome::checked, StringError> + static outcome::checked, StringError> create(std::shared_ptr keySet, Args... args) { - auto arguments = std::make_shared(); + auto arguments = std::make_unique(); OUTCOME_TRYV(arguments->pushArgs(keySet, args...)); return arguments; } - /** Low level interface */ + /// Export encrypted arguments as public arguments, reset the encrypted + /// arguments, i.e. move all buffers to the PublicArguments and reset the + /// positional counter. + outcome::checked, StringError> + exportPublicArguments(ClientParameters clientParameters, + RuntimeContext runtimeContext); + public: - // Add a scalar argument. + /// Add a uint8_t scalar argument. outcome::checked pushArg(uint8_t arg, std::shared_ptr keySet); - // Add a vector-tensor argument. + // Add a uint64_t scalar argument. + outcome::checked pushArg(uint64_t arg, + std::shared_ptr keySet); + + /// Add a vector-tensor argument. outcome::checked pushArg(std::vector arg, std::shared_ptr keySet); + /// Add a 1D tensor argument. template outcome::checked pushArg(std::array arg, std::shared_ptr keySet) { return pushArg(8, (void *)arg.data(), {size}, keySet); } - // Add a matrix-tensor argument. + /// Add a 2D tensor argument. template outcome::checked pushArg(std::array, size0> arg, @@ -63,7 +76,7 @@ public: return pushArg(8, (void *)arg.data(), {size0, size1}, keySet); } - // Add a rank3 tensor. + /// Add a 3D tensor argument. template outcome::checked pushArg(std::array, size1>, size0> arg, @@ -92,6 +105,7 @@ public: llvm::ArrayRef shape, std::shared_ptr keySet); + /// Push a variadic list of arguments. template outcome::checked pushArgs(std::shared_ptr keySet, Arg0 arg0, OtherArgs... others) { @@ -99,26 +113,18 @@ public: return pushArgs(keySet, others...); } + // Terminal case of pushArgs outcome::checked pushArgs(std::shared_ptr keySet) { return checkAllArgs(keySet); } - outcome::checked - asPublicArguments(ClientParameters clientParameters, - RuntimeContext runtimeContext); - - EncryptedArgs(); - ~EncryptedArgs(); - private: outcome::checked - checkPushTooManyArgs(std::shared_ptr keySetPtr); + checkPushTooManyArgs(std::shared_ptr keySet); outcome::checked checkAllArgs(std::shared_ptr keySet); - // Add a scalar argument. - outcome::checked pushArg(uint64_t arg, - std::shared_ptr keySet); +private: // Position of the next pushed argument size_t currentPos; std::vector preparedArgs; diff --git a/compiler/include/concretelang/ClientLib/KeySet.h b/compiler/include/concretelang/ClientLib/KeySet.h index 2d00242ef..adee94de0 100644 --- a/compiler/include/concretelang/ClientLib/KeySet.h +++ b/compiler/include/concretelang/ClientLib/KeySet.h @@ -29,6 +29,7 @@ class KeySet { public: KeySet(); ~KeySet(); + KeySet(KeySet &other) = delete; // allocate a KeySet according the ClientParameters. static outcome::checked, StringError> diff --git a/compiler/include/concretelang/ClientLib/PublicArguments.h b/compiler/include/concretelang/ClientLib/PublicArguments.h index 4b2faac37..ef53d4ff4 100644 --- a/compiler/include/concretelang/ClientLib/PublicArguments.h +++ b/compiler/include/concretelang/ClientLib/PublicArguments.h @@ -11,7 +11,7 @@ #include "boost/outcome.h" #include "concretelang/ClientLib/ClientParameters.h" -#include "concretelang/ClientLib/EncryptedArgs.h" +#include "concretelang/ClientLib/EncryptedArguments.h" #include "concretelang/ClientLib/Types.h" #include "concretelang/Common/Error.h" #include "concretelang/Runtime/context.h" @@ -26,7 +26,7 @@ namespace clientlib { using concretelang::error::StringError; -class EncryptedArgs; +class EncryptedArguments; class PublicArguments { /// PublicArguments will be sended to the server. It includes encrypted /// arguments and public keys. @@ -35,12 +35,9 @@ public: const ClientParameters &clientParameters, RuntimeContext runtimeContext, bool clearRuntimeContext, std::vector &&preparedArgs, std::vector &&ciphertextBuffers); - PublicArguments(PublicArguments &other) = delete; - // to have proper owership transfer (outcome and local object) - PublicArguments(PublicArguments &&other); ~PublicArguments(); - - void freeIfNotOwned(std::vector res); + PublicArguments(PublicArguments &other) = delete; + PublicArguments(PublicArguments &&other) = delete; static outcome::checked, StringError> unserialize(ClientParameters &expectedParams, std::istream &istream); @@ -57,9 +54,44 @@ private: std::vector preparedArgs; // Store buffers of ciphertexts std::vector ciphertextBuffers; + + // Indicates if this public argument own the runtime keys. bool clearRuntimeContext; }; +struct PublicResult { + /// PublicResult is a result of a ServerLambda call which contains encrypted + /// results. + + PublicResult(const ClientParameters &clientParameters, + std::vector buffers = {}) + : clientParameters(clientParameters), buffers(buffers){}; + + PublicResult(PublicResult &) = delete; + + /// Create a public result from buffers. + static std::unique_ptr + fromBuffers(const ClientParameters &clientParameters, + std::vector buffers) { + return std::make_unique(clientParameters, buffers); + } + + /// Unserialize from a input stream. + outcome::checked unserialize(std::istream &istream); + + /// Serialize into an output stream. + outcome::checked serialize(std::ostream &ostream); + + /// Decrypt the result at `pos` as a vector. + outcome::checked, StringError> + decryptVector(KeySet &keySet, size_t pos); + +private: + friend class ::concretelang::serverlib::ServerLambda; + ClientParameters clientParameters; + std::vector buffers; +}; + } // namespace clientlib } // namespace concretelang diff --git a/compiler/include/concretelang/ServerLib/ServerLambda.h b/compiler/include/concretelang/ServerLib/ServerLambda.h index 1bafbdcfe..a809b252a 100644 --- a/compiler/include/concretelang/ServerLib/ServerLambda.h +++ b/compiler/include/concretelang/ServerLib/ServerLambda.h @@ -27,17 +27,23 @@ encrypted_scalars_and_sizes_t encrypted_scalars_and_sizes_t_from_MemRef( size_t rank, encrypted_scalars_t allocated, encrypted_scalars_t aligned, size_t offset, size_t *sizes, size_t *strides); +/// ServerLambda is a utility class that allows to call a function of a +/// compilation result. class ServerLambda { public: + /// Load the symbol `funcName` of the compilation result located at the path + /// `outputLib`. static outcome::checked load(std::string funcName, std::string outputLib); + /// Load the symbol `funcName` of the dynamic loaded library static outcome::checked loadFromModule(std::shared_ptr module, std::string funcName); - outcome::checked - read_call_write(std::istream &istream, std::ostream &ostream); + /// Call the ServerLambda with public arguments. + std::unique_ptr + call(clientlib::PublicArguments &args); protected: ClientParameters clientParameters; diff --git a/compiler/include/concretelang/TestLib/TestTypedLambda.h b/compiler/include/concretelang/TestLib/TestTypedLambda.h index 235b65c83..2a07b5598 100644 --- a/compiler/include/concretelang/TestLib/TestTypedLambda.h +++ b/compiler/include/concretelang/TestLib/TestTypedLambda.h @@ -63,41 +63,17 @@ public: keySet(keySet) {} outcome::checked call(Args... args) { - // client - auto BINARY = std::ios::binary; - std::string message; - { - // client - std::ostringstream clientOuput(BINARY); - OUTCOME_TRYV(this->serializeCall(args..., keySet, clientOuput)); - if (clientOuput.fail()) { - return StringError("Error on clientOuput"); - } - message = clientOuput.str(); - } - { - // server - std::istringstream serverInput(message, BINARY); - freeStringMemory(message); - assert(serverInput.tellg() == 0); - std::ostringstream serverOutput(BINARY); - OUTCOME_TRYV(serverLambda.read_call_write(serverInput, serverOutput)); - if (serverInput.fail()) { - return StringError("Error on serverInput"); - } - if (serverOutput.fail()) { - return StringError("Error on serverOutput"); - } - message = serverOutput.str(); - } - { - // client - std::istringstream clientInput(message, BINARY); - freeStringMemory(message); - OUTCOME_TRY(auto result, this->decryptReturned(*keySet, clientInput)); - assert(clientInput.good()); - return result; - } + // client argument encryption + OUTCOME_TRY(auto encryptedArgs, + clientlib::EncryptedArguments::create(keySet, args...)); + OUTCOME_TRY(auto publicArgument, + encryptedArgs->exportPublicArguments(this->clientParameters, + keySet->runtimeContext())); + // server function call + auto publicResult = serverLambda.call(*publicArgument); + + // client result decryption + return this->decryptResult(*keySet, *publicResult); } private: diff --git a/compiler/lib/ClientLib/CMakeLists.txt b/compiler/lib/ClientLib/CMakeLists.txt index 99cd9b781..89d074187 100644 --- a/compiler/lib/ClientLib/CMakeLists.txt +++ b/compiler/lib/ClientLib/CMakeLists.txt @@ -14,7 +14,7 @@ add_mlir_library( ConcretelangClientLib ClientLambda.cpp ClientParameters.cpp - EncryptedArgs.cpp + EncryptedArguments.cpp KeySet.cpp KeySetCache.cpp PublicArguments.cpp diff --git a/compiler/lib/ClientLib/ClientLambda.cpp b/compiler/lib/ClientLib/ClientLambda.cpp index 203ce6798..5d4ace6d9 100644 --- a/compiler/lib/ClientLib/ClientLambda.cpp +++ b/compiler/lib/ClientLib/ClientLambda.cpp @@ -46,35 +46,15 @@ ClientLambda::keySet(std::shared_ptr optionalCache, seed_lsb); } -outcome::checked -ClientLambda::untypedSerializeCall(PublicArguments &serverArguments, - std::ostream &ostream) { - return serverArguments.serialize(ostream); -} - outcome::checked -ClientLambda::decryptReturnedScalar(KeySet &keySet, std::istream &istream) { - OUTCOME_TRY(auto v, decryptReturnedValues(keySet, istream)); +ClientLambda::decryptReturnedScalar(KeySet &keySet, PublicResult &result) { + OUTCOME_TRY(auto v, decryptReturnedValues(keySet, result)); return v[0]; } outcome::checked, StringError> -ClientLambda::decryptReturnedValues(KeySet &keySet, std::istream &istream) { - auto lweSize = - clientParameters.lweSecretKeyParam(clientParameters.outputs[0]).lweSize(); - std::vector sizes = clientParameters.outputs[0].shape.dimensions; - sizes.push_back(lweSize); - auto encryptedValues = unserializeEncryptedValues(sizes, istream); - if (istream.fail()) { - return StringError("Encrypted scalars has not the right size"); - } - auto len = encryptedValues.length(); - decrypted_tensor_1_t decryptedValues(len / lweSize); - for (size_t i = 0; i < decryptedValues.size(); i++) { - auto buffer = (uint64_t *)(&encryptedValues.values[i * lweSize]); - OUTCOME_TRYV(keySet.decrypt_lwe(0, buffer, decryptedValues[i])); - } - return decryptedValues; +ClientLambda::decryptReturnedValues(KeySet &keySet, PublicResult &result) { + return result.decryptVector(keySet, 0); } outcome::checked errorResultRank(size_t expected, @@ -128,7 +108,7 @@ decrypted_tensor_3_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) { template outcome::checked -decryptReturnedTensor(std::istream &istream, ClientLambda &lambda, +decryptReturnedTensor(PublicResult &result, ClientLambda &lambda, ClientParameters ¶ms, size_t expectedRank, KeySet &keySet) { auto shape = params.outputs[0].shape; @@ -137,7 +117,7 @@ decryptReturnedTensor(std::istream &istream, ClientLambda &lambda, return StringError("Function returns a tensor of rank ") << expectedRank << " which cannot be decrypted to rank " << rank; } - OUTCOME_TRY(auto values, lambda.decryptReturnedValues(keySet, istream)); + OUTCOME_TRY(auto values, lambda.decryptReturnedValues(keySet, result)); llvm::SmallVector sizes; for (size_t dim = 0; dim < rank; dim++) { sizes.push_back(shape.dimensions[dim]); @@ -146,27 +126,27 @@ decryptReturnedTensor(std::istream &istream, ClientLambda &lambda, } outcome::checked -ClientLambda::decryptReturnedTensor1(KeySet &keySet, std::istream &istream) { +ClientLambda::decryptReturnedTensor1(KeySet &keySet, PublicResult &result) { return decryptReturnedTensor( - istream, *this, this->clientParameters, 1, keySet); + result, *this, this->clientParameters, 1, keySet); } outcome::checked -ClientLambda::decryptReturnedTensor2(KeySet &keySet, std::istream &istream) { +ClientLambda::decryptReturnedTensor2(KeySet &keySet, PublicResult &result) { return decryptReturnedTensor( - istream, *this, this->clientParameters, 2, keySet); + result, *this, this->clientParameters, 2, keySet); } outcome::checked -ClientLambda::decryptReturnedTensor3(KeySet &keySet, std::istream &istream) { +ClientLambda::decryptReturnedTensor3(KeySet &keySet, PublicResult &result) { return decryptReturnedTensor( - istream, *this, this->clientParameters, 3, keySet); + result, *this, this->clientParameters, 3, keySet); } template outcome::checked topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet, - std::istream &istream) { + PublicResult &result) { // compile time error if used using COMPATIBLE_RESULT_TYPE = void; return (Result)(COMPATIBLE_RESULT_TYPE)0; @@ -175,32 +155,32 @@ topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet, template <> outcome::checked topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet, - std::istream &istream) { - return lambda.decryptReturnedScalar(keySet, istream); + PublicResult &result) { + return lambda.decryptReturnedScalar(keySet, result); } template <> outcome::checked topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet, - std::istream &istream) { - return lambda.decryptReturnedTensor1(keySet, istream); + PublicResult &result) { + return lambda.decryptReturnedTensor1(keySet, result); } template <> outcome::checked topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet, - std::istream &istream) { - return lambda.decryptReturnedTensor2(keySet, istream); + PublicResult &result) { + return lambda.decryptReturnedTensor2(keySet, result); } template <> outcome::checked topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet, - std::istream &istream) { - return lambda.decryptReturnedTensor3(keySet, istream); + PublicResult &result) { + return lambda.decryptReturnedTensor3(keySet, result); } } // namespace clientlib diff --git a/compiler/lib/ClientLib/EncryptedArgs.cpp b/compiler/lib/ClientLib/EncryptedArguments.cpp similarity index 83% rename from compiler/lib/ClientLib/EncryptedArgs.cpp rename to compiler/lib/ClientLib/EncryptedArguments.cpp index 1c37ffdbc..ae9d5ab99 100644 --- a/compiler/lib/ClientLib/EncryptedArgs.cpp +++ b/compiler/lib/ClientLib/EncryptedArguments.cpp @@ -3,7 +3,7 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt // for license information. -#include "concretelang/ClientLib/EncryptedArgs.h" +#include "concretelang/ClientLib/EncryptedArguments.h" #include "concretelang/ClientLib/PublicArguments.h" namespace concretelang { @@ -11,20 +11,23 @@ namespace clientlib { using StringError = concretelang::error::StringError; -EncryptedArgs::~EncryptedArgs() { - // There is no explicit allocation - // All buffers are owned by ciphertextBuffers +outcome::checked, StringError> +EncryptedArguments::exportPublicArguments(ClientParameters clientParameters, + RuntimeContext runtimeContext) { + // On client side the runtimeContext is hold by the KeySet + bool clearContext = false; + return std::make_unique( + clientParameters, runtimeContext, clearContext, std::move(preparedArgs), + std::move(ciphertextBuffers)); } -EncryptedArgs::EncryptedArgs() : currentPos(0) {} - outcome::checked -EncryptedArgs::pushArg(uint8_t arg, std::shared_ptr keySet) { +EncryptedArguments::pushArg(uint8_t arg, std::shared_ptr keySet) { return pushArg((uint64_t)arg, keySet); } outcome::checked -EncryptedArgs::pushArg(uint64_t arg, std::shared_ptr keySet) { +EncryptedArguments::pushArg(uint64_t arg, std::shared_ptr keySet) { // TODO: NON ENCRYPTED OUTCOME_TRYV(checkPushTooManyArgs(keySet)); auto pos = currentPos; @@ -65,14 +68,15 @@ EncryptedArgs::pushArg(uint64_t arg, std::shared_ptr keySet) { } outcome::checked -EncryptedArgs::pushArg(std::vector arg, - std::shared_ptr keySet) { +EncryptedArguments::pushArg(std::vector arg, + std::shared_ptr keySet) { return pushArg(8, (void *)arg.data(), {(int64_t)arg.size()}, keySet); } outcome::checked -EncryptedArgs::pushArg(size_t width, void *data, llvm::ArrayRef shape, - std::shared_ptr keySet) { +EncryptedArguments::pushArg(size_t width, void *data, + llvm::ArrayRef shape, + std::shared_ptr keySet) { OUTCOME_TRYV(checkPushTooManyArgs(keySet)); auto pos = currentPos; CircuitGate input = keySet->inputGate(pos); @@ -148,7 +152,7 @@ EncryptedArgs::pushArg(size_t width, void *data, llvm::ArrayRef shape, } outcome::checked -EncryptedArgs::checkPushTooManyArgs(std::shared_ptr keySet) { +EncryptedArguments::checkPushTooManyArgs(std::shared_ptr keySet) { size_t arity = keySet->numInputs(); if (currentPos < arity) { return outcome::success(); @@ -158,7 +162,7 @@ EncryptedArgs::checkPushTooManyArgs(std::shared_ptr keySet) { } outcome::checked -EncryptedArgs::checkAllArgs(std::shared_ptr keySet) { +EncryptedArguments::checkAllArgs(std::shared_ptr keySet) { size_t arity = keySet->numInputs(); if (currentPos == arity) { return outcome::success(); @@ -168,14 +172,5 @@ EncryptedArgs::checkAllArgs(std::shared_ptr keySet) { << " arguments"; } -outcome::checked -EncryptedArgs::asPublicArguments(ClientParameters clientParameters, - RuntimeContext runtimeContext) { - // On client side the runtimeContext is hold by the KeySet - bool clearContext = false; - return PublicArguments(clientParameters, runtimeContext, clearContext, - std::move(preparedArgs), std::move(ciphertextBuffers)); -} - } // namespace clientlib } // namespace concretelang diff --git a/compiler/lib/ClientLib/KeySet.cpp b/compiler/lib/ClientLib/KeySet.cpp index c1e53fb6e..2bfbbbc72 100644 --- a/compiler/lib/ClientLib/KeySet.cpp +++ b/compiler/lib/ClientLib/KeySet.cpp @@ -230,7 +230,6 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) { outcome::checked KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) { - if (argPos >= outputs.size()) { return StringError("decrypt_lwe: position of argument is too high"); } diff --git a/compiler/lib/ClientLib/PublicArguments.cpp b/compiler/lib/ClientLib/PublicArguments.cpp index fde5fdf89..c0023a3ec 100644 --- a/compiler/lib/ClientLib/PublicArguments.cpp +++ b/compiler/lib/ClientLib/PublicArguments.cpp @@ -29,18 +29,6 @@ PublicArguments::PublicArguments( ciphertextBuffers = std::move(ciphertextBuffers_); } -PublicArguments::PublicArguments(PublicArguments &&other) { - clientParameters = other.clientParameters; - runtimeContext = other.runtimeContext; - runtimeContext.bsk = std::move(other.runtimeContext.bsk); - clearRuntimeContext = other.clearRuntimeContext; - preparedArgs = std::move(other.preparedArgs); - ciphertextBuffers = std::move(other.ciphertextBuffers); - // transfer ownership - other.clearRuntimeContext = false; - other.runtimeContext.ksk = nullptr; -} - PublicArguments::~PublicArguments() { if (!clearRuntimeContext) { return; @@ -68,7 +56,7 @@ PublicArguments::serialize(std::ostream &ostream) { size_t rank = gate.shape.dimensions.size(); if (!gate.encryption.hasValue()) { return StringError("PublicArguments::serialize: Clear arguments " - "are not supported. Argument ") + "are not yet supported. Argument ") << iGate; } /*auto allocated = */ preparedArgs[iPreparedArgs++]; @@ -140,14 +128,27 @@ PublicArguments::unserialize(ClientParameters &clientParameters, } std::vector empty; std::vector emptyBuffers; - // On server side the PublicArguments is responsible for the context - auto clearRuntimeContext = true; auto sArguments = std::make_shared( - clientParameters, runtimeContext, clearRuntimeContext, std::move(empty), + clientParameters, runtimeContext, true, std::move(empty), std::move(emptyBuffers)); OUTCOME_TRYV(sArguments->unserializeArgs(istream)); return sArguments; } +outcome::checked, StringError> +PublicResult::decryptVector(KeySet &keySet, size_t pos) { + auto lweSize = + clientParameters.lweSecretKeyParam(clientParameters.outputs[pos]) + .lweSize(); + + auto buffer = buffers[pos]; + decrypted_tensor_1_t decryptedValues(buffer.length() / lweSize); + for (size_t i = 0; i < decryptedValues.size(); i++) { + auto ciphertext = &buffer.values[i * lweSize]; + OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertext, decryptedValues[i])); + } + return decryptedValues; +} + } // namespace clientlib } // namespace concretelang diff --git a/compiler/lib/ServerLib/ServerLambda.cpp b/compiler/lib/ServerLib/ServerLambda.cpp index 94f2c7834..b0e791904 100644 --- a/compiler/lib/ServerLib/ServerLambda.cpp +++ b/compiler/lib/ServerLib/ServerLambda.cpp @@ -123,37 +123,20 @@ ServerLambda::load(std::string funcName, std::string outputLib) { encrypted_scalars_and_sizes_t dynamicCall(void *(*func)(void *...), std::vector &preparedArgs, - CircuitGate &output, - std::ostream &ostream) { + CircuitGate &output) { size_t rank = output.shape.dimensions.size(); return multi_arity_call_dynamic_rank(func, preparedArgs, rank); } -outcome::checked -ServerLambda::read_call_write(std::istream &istream, std::ostream &ostream) { - OUTCOME_TRY(auto argumentsPtr, - PublicArguments::unserialize(clientParameters, istream)); - assert(istream.good()); - - PublicArguments &arguments = *argumentsPtr; - // The runtime context is always the last argument list - arguments.preparedArgs.push_back((void *)&arguments.runtimeContext); - auto values_and_sizes = dynamicCall(this->func, arguments.preparedArgs, - clientParameters.outputs[0], ostream); - auto shape = clientParameters.outputs[0].shape; - size_t rank = shape.dimensions.size(); - for (size_t dim = 0; dim < rank; dim++) { - if (values_and_sizes.sizes[dim] != (size_t)shape.dimensions[dim]) { - return StringError("Dimension mismatch on dim ") - << dim << " actual: " << values_and_sizes.sizes[dim] - << " vs expected: " << shape.dimensions[dim] << "\n"; - } - } - serializeEncryptedValues(values_and_sizes, ostream); - if (ostream.fail()) { - return StringError("Cannot write result"); - } - return outcome::success(); +std::unique_ptr +ServerLambda::call(PublicArguments &args) { + std::vector preparedArgs(args.preparedArgs.begin(), + args.preparedArgs.end()); + preparedArgs.push_back((void *)&args.runtimeContext); + return clientlib::PublicResult::fromBuffers( + clientParameters, + {dynamicCall(this->func, preparedArgs, clientParameters.outputs[0])}); + ; } } // namespace serverlib diff --git a/compiler/tests/Support/support_unit_test.cpp b/compiler/tests/Support/support_unit_test.cpp index 52ab9eac9..dd9fcb0ef 100644 --- a/compiler/tests/Support/support_unit_test.cpp +++ b/compiler/tests/Support/support_unit_test.cpp @@ -2,7 +2,7 @@ #include "../unittest/end_to_end_jit_test.h" #include "concretelang/ClientLib/ClientParameters.h" -#include "concretelang/ClientLib/EncryptedArgs.h" +#include "concretelang/ClientLib/EncryptedArguments.h" namespace clientlib = concretelang::clientlib; diff --git a/compiler/tests/TestLib/testlib_unit_test.cpp b/compiler/tests/TestLib/testlib_unit_test.cpp index 6ad13e80c..270786bbd 100644 --- a/compiler/tests/TestLib/testlib_unit_test.cpp +++ b/compiler/tests/TestLib/testlib_unit_test.cpp @@ -73,10 +73,11 @@ func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { ASSERT_TRUE(maybeKeySet.has_value()); std::shared_ptr keySet = std::move(maybeKeySet.value()); auto maybePublicArguments = lambda.publicArguments(1, keySet); + ASSERT_TRUE(maybePublicArguments.has_value()); auto publicArguments = std::move(maybePublicArguments.value()); std::ostringstream osstream(std::ios::binary); - EXPECT_TRUE(lambda.untypedSerializeCall(publicArguments, osstream)); + ASSERT_TRUE(publicArguments->serialize(osstream).has_value()); EXPECT_TRUE(osstream.good()); // Direct call without intermediate EXPECT_TRUE(lambda.serializeCall(1, keySet, osstream)); @@ -119,6 +120,7 @@ func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { } TEST(CompiledModule, call_2s_1s) { + std::string source = R"( func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)