diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/EncryptedArguments.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/EncryptedArguments.h index ac7caaf21..ab3e0cfcd 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/EncryptedArguments.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/EncryptedArguments.h @@ -179,26 +179,7 @@ public: td.bulkAssign(values); ciphertextBuffers.push_back(std::move(td)); } - TensorData &td = ciphertextBuffers.back().getTensor(); - // allocated - preparedArgs.push_back(nullptr); - // aligned - preparedArgs.push_back(td.getValuesAsOpaquePointer()); - // offset - preparedArgs.push_back((void *)0); - // sizes - for (size_t size : td.getDimensions()) { - preparedArgs.push_back((void *)size); - } - - // Set the stride for each dimension, equal to the product of the - // following dimensions. - int64_t stride = td.getNumElements(); - for (size_t size : td.getDimensions()) { - stride = (size == 0 ? 0 : (stride / size)); - preparedArgs.push_back((void *)stride); - } currentPos++; return outcome::success(); } @@ -232,7 +213,6 @@ private: private: /// Position of the next pushed argument size_t currentPos; - std::vector preparedArgs; /// Store buffers of ciphertexts std::vector ciphertextBuffers; diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/PublicArguments.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/PublicArguments.h index b8480b251..a1ae46a2a 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/PublicArguments.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/PublicArguments.h @@ -37,7 +37,6 @@ class EncryptedArguments; class PublicArguments { public: PublicArguments(const ClientParameters &clientParameters, - std::vector &&preparedArgs, std::vector &&ciphertextBuffers); ~PublicArguments(); PublicArguments(PublicArguments &other) = delete; @@ -48,16 +47,18 @@ public: outcome::checked serialize(std::ostream &ostream); -private: + std::vector &getArguments() { return arguments; } + ClientParameters &getClientParameters() { return clientParameters; } + friend class ::concretelang::serverlib::ServerLambda; friend class ::mlir::concretelang::JITLambda; +private: outcome::checked unserializeArgs(std::istream &istream); ClientParameters clientParameters; - std::vector preparedArgs; /// Store buffers of ciphertexts - std::vector ciphertextBuffers; + std::vector arguments; }; /// PublicResult is a result of a ServerLambda call which contains encrypted diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Serializers.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Serializers.h index 22682d2c5..cea56e4f5 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Serializers.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Serializers.h @@ -87,17 +87,20 @@ std::ostream &serializeTensorDataRaw(const llvm::ArrayRef &dimensions, return ostream; } -outcome::checked unserializeTensorData( - std::vector &expectedSizes, // includes unsigned to - // accomodate non static sizes - std::istream &istream); +outcome::checked +unserializeTensorData(std::istream &istream); std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd, std::ostream &ostream); outcome::checked -unserializeScalarOrTensorData(const std::vector &expectedSizes, - std::istream &istream); +unserializeScalarOrTensorData(std::istream &istream); + +std::ostream & +serializeVectorOfScalarOrTensorData(const std::vector &sotd, + std::ostream &ostream); +outcome::checked, StringError> +unserializeVectorOfScalarOrTensorData(std::istream &istream); std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &wrappedKsk); LweSecretKey readLweSecretKey(std::istream &istream); diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Types.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Types.h index 2a7f4cad2..84874c772 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Types.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Types.h @@ -7,6 +7,7 @@ #define CONCRETELANG_CLIENTLIB_TYPES_H_ #include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/raw_ostream.h" #include #include @@ -818,6 +819,8 @@ public: // Retrieves the value as a generic `uint64_t` uint64_t getValueAsU64() const { size_t width = getElementTypeWidth(type); + if (width == 64) + return value.u64; uint64_t mask = ((uint64_t)1 << width) - 1; uint64_t val = value.u64 & mask; return val; diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/Utils.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/Utils.h index f2920e39f..003ec51cd 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/Utils.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/Utils.h @@ -114,6 +114,42 @@ invokeRawOnLambda(Lambda *lambda, clientlib::ClientParameters clientParameters, std::move(buffers)); } +template +llvm::Expected> +invokeRawOnLambda(Lambda *lambda, clientlib::PublicArguments &arguments, + clientlib::EvaluationKeys &evaluationKeys) { + // Prepare arguments with the right calling convention + std::vector preparedArgs; + for (auto &arg : arguments.getArguments()) { + if (arg.isScalar()) { + auto scalar = arg.getScalar().getValueAsU64(); + preparedArgs.push_back((void *)scalar); + } else { + clientlib::TensorData &td = arg.getTensor(); + // allocated + preparedArgs.push_back(nullptr); + // aligned + preparedArgs.push_back(td.getValuesAsOpaquePointer()); + // offset + preparedArgs.push_back((void *)0); + // sizes + for (size_t size : td.getDimensions()) { + preparedArgs.push_back((void *)size); + } + + // Set the stride for each dimension, equal to the product of the + // following dimensions. + int64_t stride = td.getNumElements(); + for (size_t size : td.getDimensions()) { + stride = (size == 0 ? 0 : (stride / size)); + preparedArgs.push_back((void *)stride); + } + } + } + return invokeRawOnLambda(lambda, arguments.getClientParameters(), + preparedArgs, evaluationKeys); +} + template llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const llvm::SmallVector vect) { diff --git a/compilers/concrete-compiler/compiler/lib/ClientLib/EncryptedArguments.cpp b/compilers/concrete-compiler/compiler/lib/ClientLib/EncryptedArguments.cpp index 4fe022296..f1677fb20 100644 --- a/compilers/concrete-compiler/compiler/lib/ClientLib/EncryptedArguments.cpp +++ b/compilers/concrete-compiler/compiler/lib/ClientLib/EncryptedArguments.cpp @@ -13,8 +13,8 @@ using StringError = concretelang::error::StringError; outcome::checked, StringError> EncryptedArguments::exportPublicArguments(ClientParameters clientParameters) { - return std::make_unique( - clientParameters, std::move(preparedArgs), std::move(ciphertextBuffers)); + return std::make_unique(clientParameters, + std::move(ciphertextBuffers)); } /// Split the input integer into `size` chunks of `chunkWidth` bits each @@ -49,7 +49,7 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) { } if (!input.encryption.has_value()) { // clear scalar: just push the argument - preparedArgs.push_back((void *)arg); + ciphertextBuffers.push_back(ScalarData(arg)); return outcome::success(); } @@ -63,24 +63,6 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) { OUTCOME_TRYV(keySet.encrypt_lwe( pos, values_and_sizes.getElementPointer(0), arg)); - // Note: Since we bufferized lwe ciphertext take care of memref calling - // convention - // allocated - preparedArgs.push_back(nullptr); - // aligned - preparedArgs.push_back((void *)values_and_sizes.getValuesAsOpaquePointer()); - // offset - preparedArgs.push_back((void *)0); - // sizes - for (auto size : values_and_sizes.getDimensions()) { - preparedArgs.push_back((void *)size); - } - // strides - int64_t stride = TensorData::getNumElements(shape); - for (size_t size : values_and_sizes.getDimensions()) { - stride = (size == 0 ? 0 : (stride / size)); - preparedArgs.push_back((void *)stride); - } return outcome::success(); } diff --git a/compilers/concrete-compiler/compiler/lib/ClientLib/PublicArguments.cpp b/compilers/concrete-compiler/compiler/lib/ClientLib/PublicArguments.cpp index 379614c6d..48c0b80b7 100644 --- a/compilers/concrete-compiler/compiler/lib/ClientLib/PublicArguments.cpp +++ b/compilers/concrete-compiler/compiler/lib/ClientLib/PublicArguments.cpp @@ -15,13 +15,10 @@ namespace clientlib { using concretelang::error::StringError; // TODO: optimize the move -PublicArguments::PublicArguments( - const ClientParameters &clientParameters, - std::vector &&preparedArgs_, - std::vector &&ciphertextBuffers_) +PublicArguments::PublicArguments(const ClientParameters &clientParameters, + std::vector &&arguments_) : clientParameters(clientParameters) { - preparedArgs = std::move(preparedArgs_); - ciphertextBuffers = std::move(ciphertextBuffers_); + arguments = std::move(arguments_); } PublicArguments::~PublicArguments() {} @@ -32,146 +29,41 @@ PublicArguments::serialize(std::ostream &ostream) { return StringError( "PublicArguments::serialize: ostream should be in binary mode"); } - size_t iPreparedArgs = 0; - int iGate = -1; - for (auto gate : clientParameters.inputs) { - iGate++; - size_t rank = gate.shape.dimensions.size(); - if (!gate.encryption.has_value()) { - return StringError("PublicArguments::serialize: Clear arguments " - "are not yet supported. Argument ") - << iGate; - } - - /*auto allocated = */ iPreparedArgs++; - auto aligned = (encrypted_scalars_t)preparedArgs[iPreparedArgs++]; - assert(aligned != nullptr); - auto offset = (size_t)preparedArgs[iPreparedArgs++]; - std::vector sizes; // includes lweSize as last dim - sizes.resize(rank + (gate.encryption->encoding.crt.empty() ? 1 : 2)); - for (auto dim = 0u; dim < sizes.size(); dim++) { - // sizes are part of the client parameters signature - // it's static now but some day it could be dynamic so we serialize - // them. - sizes[dim] = (size_t)preparedArgs[iPreparedArgs++]; - } - std::vector strides(rank + 1); - /* strides should be zero here and are not serialized */ - for (auto dim = 0u; dim < strides.size(); dim++) { - strides[dim] = (size_t)preparedArgs[iPreparedArgs++]; - } - // TODO: STRIDES - auto values = aligned + offset; - - writeWord(ostream, 1); - serializeTensorDataRaw(sizes, - llvm::ArrayRef{ - values, TensorData::getNumElements(sizes)}, - ostream); + serializeVectorOfScalarOrTensorData(arguments, ostream); + if (ostream.bad()) { + return StringError( + "PublicArguments::serialize: cannot serialize public arguments"); } - return outcome::success(); } outcome::checked PublicArguments::unserializeArgs(std::istream &istream) { - int iGate = -1; - for (auto gate : clientParameters.inputs) { - iGate++; - if (!gate.encryption.has_value()) { - return StringError("Clear values are not handled"); - } - - std::vector sizes = gate.shape.dimensions; - if (gate.encryption.has_value() && !gate.encryption->encoding.crt.empty()) { - sizes.push_back(gate.encryption->encoding.crt.size()); - } - auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize(); - sizes.push_back(lweSize); - - auto sotdOrErr = unserializeScalarOrTensorData(sizes, istream); - - if (sotdOrErr.has_error()) - return sotdOrErr.error(); - - ciphertextBuffers.push_back(std::move(sotdOrErr.value())); - auto &buffer = ciphertextBuffers.back(); - - if (istream.fail()) { - return StringError( - "PublicArguments::unserializeArgs: Failed to read argument ") - << iGate; - } - - if (buffer.isTensor()) { - TensorData &td = buffer.getTensor(); - preparedArgs.push_back(/*allocated*/ nullptr); - preparedArgs.push_back(td.getValuesAsOpaquePointer()); - preparedArgs.push_back(/*offset*/ 0); - // sizes - for (auto size : td.getDimensions()) { - preparedArgs.push_back((void *)size); - } - // strides has been removed by serialization - auto stride = td.length(); - for (auto size : sizes) { - stride /= size; - preparedArgs.push_back((void *)stride); - } - } else { - ScalarData &sd = buffer.getScalar(); - preparedArgs.push_back((void *)sd.getValueAsU64()); - } - } + OUTCOME_TRY(arguments, unserializeVectorOfScalarOrTensorData(istream)); return outcome::success(); } outcome::checked, StringError> PublicArguments::unserialize(ClientParameters &clientParameters, std::istream &istream) { - std::vector empty; std::vector emptyBuffers; - auto sArguments = std::make_unique( - clientParameters, std::move(empty), std::move(emptyBuffers)); + auto sArguments = std::make_unique(clientParameters, + std::move(emptyBuffers)); OUTCOME_TRYV(sArguments->unserializeArgs(istream)); return std::move(sArguments); } outcome::checked PublicResult::unserialize(std::istream &istream) { - for (auto gate : clientParameters.outputs) { - if (!gate.encryption.has_value()) { - return StringError("Clear values are not handled"); - } - - std::vector sizes = gate.shape.dimensions; - if (gate.encryption.has_value() && !gate.encryption->encoding.crt.empty()) { - sizes.push_back(gate.encryption->encoding.crt.size()); - } - auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize(); - sizes.push_back(lweSize); - - auto sotd = unserializeScalarOrTensorData(sizes, istream); - - if (sotd.has_error()) - return sotd.error(); - - buffers.push_back(std::move(sotd.value())); - } + OUTCOME_TRY(buffers, unserializeVectorOfScalarOrTensorData(istream)); return outcome::success(); } outcome::checked PublicResult::serialize(std::ostream &ostream) { - if (incorrectMode(ostream)) { - return StringError( - "PublicResult::serialize: ostream should be in binary mode"); - } - for (const ScalarOrTensorData &sotd : buffers) { - serializeScalarOrTensorData(sotd, ostream); - if (ostream.fail()) { - return StringError("Cannot write data"); - } + serializeVectorOfScalarOrTensorData(buffers, ostream); + if (ostream.bad()) { + return StringError("PublicResult::serialize: cannot serialize"); } return outcome::success(); } diff --git a/compilers/concrete-compiler/compiler/lib/ClientLib/Serializers.cpp b/compilers/concrete-compiler/compiler/lib/ClientLib/Serializers.cpp index ea7d81165..62a691873 100644 --- a/compilers/concrete-compiler/compiler/lib/ClientLib/Serializers.cpp +++ b/compilers/concrete-compiler/compiler/lib/ClientLib/Serializers.cpp @@ -444,10 +444,8 @@ std::ostream &serializeTensorData(const TensorData &values_and_sizes, assert(false && "Unhandled element type"); } -outcome::checked unserializeTensorData( - const std::vector &expectedSizes, // includes lweSize, unsigned to - // accomodate non static sizes - std::istream &istream) { +outcome::checked +unserializeTensorData(std::istream &istream) { if (incorrectMode(istream)) { return StringError("Stream is in incorrect mode"); @@ -461,13 +459,6 @@ outcome::checked unserializeTensorData( for (uint64_t i = 0; i < numDimensions; i++) { int64_t dimSize; readWord(istream, dimSize); - - if (dimSize != expectedSizes[i]) { - istream.setstate(std::ios::badbit); - return StringError("Number of dimensions did not match the number of " - "expected dimensions"); - } - dims.push_back(dimSize); } @@ -537,8 +528,7 @@ std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd, } outcome::checked -unserializeScalarOrTensorData(const std::vector &expectedSizes, - std::istream &istream) { +unserializeScalarOrTensorData(std::istream &istream) { uint8_t isTensor; readWord(istream, isTensor); @@ -549,7 +539,7 @@ unserializeScalarOrTensorData(const std::vector &expectedSizes, } if (isTensor) { - auto tdOrErr = unserializeTensorData(expectedSizes, istream); + auto tdOrErr = unserializeTensorData(istream); if (tdOrErr.has_error()) return std::move(tdOrErr.error()); @@ -565,5 +555,29 @@ unserializeScalarOrTensorData(const std::vector &expectedSizes, } } +std::ostream & +serializeVectorOfScalarOrTensorData(const std::vector &v, + std::ostream &ostream) { + writeSize(ostream, v.size()); + for (auto &sotd : v) { + serializeScalarOrTensorData(sotd, ostream); + if (!ostream.good()) { + return ostream; + } + } + return ostream; +} +outcome::checked, StringError> +unserializeVectorOfScalarOrTensorData(std::istream &istream) { + uint64_t nbElt; + readSize(istream, nbElt); + std::vector v; + for (uint64_t i = 0; i < nbElt; i++) { + OUTCOME_TRY(auto elt, unserializeScalarOrTensorData(istream)); + v.push_back(std::move(elt)); + } + return v; +} + } // namespace clientlib } // namespace concretelang diff --git a/compilers/concrete-compiler/compiler/lib/ServerLib/ServerLambda.cpp b/compilers/concrete-compiler/compiler/lib/ServerLib/ServerLambda.cpp index a4274ea0d..a24e7dbd3 100644 --- a/compilers/concrete-compiler/compiler/lib/ServerLib/ServerLambda.cpp +++ b/compilers/concrete-compiler/compiler/lib/ServerLib/ServerLambda.cpp @@ -79,8 +79,7 @@ llvm::Error ServerLambda::invokeRaw(llvm::MutableArrayRef args) { llvm::Expected> ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) { - return invokeRawOnLambda(this, args.clientParameters, args.preparedArgs, - evaluationKeys); + return invokeRawOnLambda(this, args, evaluationKeys); } } // namespace serverlib diff --git a/compilers/concrete-compiler/compiler/lib/Support/Jit.cpp b/compilers/concrete-compiler/compiler/lib/Support/Jit.cpp index 614f70a1b..2561fc249 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Jit.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Jit.cpp @@ -101,8 +101,7 @@ JITLambda::call(clientlib::PublicArguments &args, } #endif - return ::concretelang::invokeRawOnLambda(this, args.clientParameters, - args.preparedArgs, evaluationKeys); + return ::concretelang::invokeRawOnLambda(this, args, evaluationKeys); } } // namespace concretelang diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.cc b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.cc index f40ce562c..51948c067 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.cc +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.cc @@ -79,12 +79,34 @@ public: stream << evaluationKeys; stream.seekg(0, std::ios::beg); evaluationKeys = concretelang::clientlib::readEvaluationKeys(stream); + stream.str(""); + stream.clear(); + + /* Serialize and unserialize public arguments */ + auto serializeRes = publicArguments->serialize(stream); + ASSERT_FALSE(serializeRes.has_error()); + stream.seekg(0, std::ios::beg); + auto unserializedArgs = + concretelang::clientlib::PublicArguments::unserialize(clientParameters, + stream); + stream.str(""); + stream.clear(); + ASSERT_FALSE(unserializedArgs.has_error()); /* Call the server lambda */ - auto publicResult = - support.serverCall(serverLambda, *publicArguments, evaluationKeys); + auto publicResult = support.serverCall( + serverLambda, *unserializedArgs.value(), evaluationKeys); ASSERT_EXPECTED_SUCCESS(publicResult); + /* Serialize and unserialize public result */ + serializeRes = (*publicResult)->serialize(stream); + ASSERT_FALSE(serializeRes.has_error()); + + auto unserializedResult = + concretelang::clientlib::PublicResult::unserialize(clientParameters, + stream); + ASSERT_FALSE(unserializedResult.has_error()); + /* Decrypt the public result */ auto result = mlir::concretelang::typedResult< std::unique_ptr>(*keySet,