diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Bindings/Python/CompilerEngine.h b/compilers/concrete-compiler/compiler/include/concretelang/Bindings/Python/CompilerEngine.h index a65a9712b..c0fcb758d 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Bindings/Python/CompilerEngine.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Bindings/Python/CompilerEngine.h @@ -151,6 +151,12 @@ keySetUnserialize(const std::string &buffer); MLIR_CAPI_EXPORTED std::string keySetSerialize(concretelang::clientlib::KeySet &keySet); +MLIR_CAPI_EXPORTED concretelang::clientlib::SharedScalarOrTensorData +valueUnserialize(const std::string &buffer); + +MLIR_CAPI_EXPORTED std::string +valueSerialize(const concretelang::clientlib::SharedScalarOrTensorData &value); + /// Parse then print a textual representation of an MLIR module MLIR_CAPI_EXPORTED std::string roundTrip(const char *module); diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/PublicArguments.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/PublicArguments.h index a8376b148..bbc7625e2 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/PublicArguments.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/PublicArguments.h @@ -112,17 +112,15 @@ private: class PublicArguments { public: PublicArguments(const ClientParameters &clientParameters, - std::vector &&ciphertextBuffers); + std::vector &buffers); ~PublicArguments(); - PublicArguments(PublicArguments &other) = delete; - PublicArguments(PublicArguments &&other) = delete; static outcome::checked, StringError> - unserialize(ClientParameters &expectedParams, std::istream &istream); + unserialize(const ClientParameters &expectedParams, std::istream &istream); outcome::checked serialize(std::ostream &ostream); - std::vector &getArguments() { return arguments; } + std::vector &getArguments() { return arguments; } ClientParameters &getClientParameters() { return clientParameters; } friend class ::concretelang::serverlib::ServerLambda; @@ -133,7 +131,7 @@ private: ClientParameters clientParameters; /// Store buffers of ciphertexts - std::vector arguments; + std::vector arguments; }; /// PublicResult is a result of a ServerLambda call which contains encrypted @@ -141,7 +139,7 @@ private: struct PublicResult { PublicResult(const ClientParameters &clientParameters, - std::vector &&buffers = {}) + std::vector &&buffers = {}) : clientParameters(clientParameters), buffers(std::move(buffers)){}; PublicResult(PublicResult &) = delete; @@ -150,17 +148,18 @@ struct PublicResult { /// @param argPos The position of the value in the PublicResult /// @return Either the value or an error if there are no value at this /// position - outcome::checked getValue(size_t argPos) { + outcome::checked + getValue(size_t argPos) { if (argPos >= buffers.size()) { return StringError("result #") << argPos << " does not exists"; } - return std::move(buffers[argPos]); + return buffers[argPos]; } /// Create a public result from buffers. static std::unique_ptr fromBuffers(const ClientParameters &clientParameters, - std::vector &&buffers) { + std::vector &&buffers) { return std::make_unique(clientParameters, std::move(buffers)); } @@ -182,7 +181,7 @@ struct PublicResult { outcome::checked asClearTextScalar(KeySet &keySet, size_t pos) { ValueDecrypter decrypter(keySet, clientParameters); - auto &data = buffers[pos]; + auto &data = buffers[pos].get(); return decrypter.template decrypt(data, pos); } @@ -192,7 +191,7 @@ struct PublicResult { outcome::checked, StringError> asClearTextVector(KeySet &keySet, size_t pos) { ValueDecrypter decrypter(keySet, clientParameters); - return decrypter.template decryptTensor(buffers[pos], pos); + return decrypter.template decryptTensor(buffers[pos].get(), pos); } /// Return the shape of the clear tensor of a result. @@ -205,7 +204,7 @@ struct PublicResult { // private: TODO tmp friend class ::concretelang::serverlib::ServerLambda; ClientParameters clientParameters; - std::vector buffers; + std::vector buffers; }; /// Helper function to convert from MemRefDescriptor to diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Serializers.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Serializers.h index cea56e4f5..ee09f71f5 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Serializers.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Serializers.h @@ -96,10 +96,9 @@ std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd, outcome::checked unserializeScalarOrTensorData(std::istream &istream); -std::ostream & -serializeVectorOfScalarOrTensorData(const std::vector &sotd, - std::ostream &ostream); -outcome::checked, StringError> +std::ostream &serializeVectorOfScalarOrTensorData( + const std::vector &sotd, std::ostream &ostream); +outcome::checked, StringError> unserializeVectorOfScalarOrTensorData(std::istream &istream); std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &wrappedKsk); diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Types.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Types.h index bd76135e5..6ae80049d 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Types.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/Types.h @@ -689,7 +689,7 @@ public: // Returns a void pointer to the first element of a flat // representation of the tensor - void *getValuesAsOpaquePointer() { + void *getValuesAsOpaquePointer() const { switch (this->elementType) { case ElementType::u64: return static_cast(values.u64->data()); @@ -879,6 +879,19 @@ public: return *tensor; } }; + +struct SharedScalarOrTensorData { + std::shared_ptr inner; + + SharedScalarOrTensorData(std::shared_ptr inner) + : inner{inner} {} + + SharedScalarOrTensorData(ScalarOrTensorData &&inner) + : inner{std::make_shared(std::move(inner))} {} + + ScalarOrTensorData &get() const { return *this->inner; } +}; + } // namespace clientlib } // namespace concretelang diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/Utils.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/Utils.h index 003ec51cd..c1e68c141 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/Utils.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/Utils.h @@ -75,17 +75,21 @@ invokeRawOnLambda(Lambda *lambda, clientlib::ClientParameters clientParameters, } // Store the result to the PublicResult - std::vector buffers; + std::vector buffers; { size_t outputOffset = 0; for (auto &output : clientParameters.outputs) { auto shape = clientParameters.bufferShape(output); if (shape.size() == 0) { // scalar scalar - buffers.push_back(concretelang::clientlib::ScalarOrTensorData( + auto value = concretelang::clientlib::ScalarOrTensorData( concretelang::clientlib::ScalarData(outputs[outputOffset++], output.shape.sign, - output.shape.width))); + output.shape.width)); + auto sharedValue = + clientlib::SharedScalarOrTensorData(std::move(value)); + + buffers.push_back(sharedValue); } else { // buffer gate auto rank = shape.size(); @@ -102,14 +106,18 @@ invokeRawOnLambda(Lambda *lambda, clientlib::ClientParameters clientParameters, : output.shape.width; bool sign = (output.isEncrypted()) ? false : output.shape.sign; - concretelang::clientlib::TensorData td = + + auto value = concretelang::clientlib::ScalarOrTensorData( clientlib::tensorDataFromMemRef(rank, elementWidth, sign, allocated, - aligned, offset, sizes, strides); - buffers.push_back( - concretelang::clientlib::ScalarOrTensorData(std::move(td))); + aligned, offset, sizes, strides)); + auto sharedValue = + clientlib::SharedScalarOrTensorData(std::move(value)); + + buffers.push_back(sharedValue); } } } + return clientlib::PublicResult::fromBuffers(clientParameters, std::move(buffers)); } @@ -120,7 +128,8 @@ invokeRawOnLambda(Lambda *lambda, clientlib::PublicArguments &arguments, clientlib::EvaluationKeys &evaluationKeys) { // Prepare arguments with the right calling convention std::vector preparedArgs; - for (auto &arg : arguments.getArguments()) { + for (auto &sharedArg : arguments.getArguments()) { + clientlib::ScalarOrTensorData &arg = sharedArg.get(); if (arg.isScalar()) { auto scalar = arg.getScalar().getValueAsU64(); preparedArgs.push_back((void *)scalar); diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt index 1c08dbf54..a866982b4 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt @@ -60,6 +60,9 @@ declare_mlir_python_sources( concrete/compiler/public_result.py concrete/compiler/evaluation_keys.py concrete/compiler/utils.py + concrete/compiler/value.py + concrete/compiler/value_decrypter.py + concrete/compiler/value_exporter.py concrete/compiler/wrapper.py concrete/__init__.py concrete/lang/__init__.py diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 8cbd61800..3d6769444 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -285,9 +285,98 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .def("get_evaluation_keys", [](clientlib::KeySet &keySet) { return keySet.evaluationKeys(); }); + pybind11::class_(m, "Value") + .def_static("deserialize", + [](const pybind11::bytes &buffer) { + return valueUnserialize(buffer); + }) + .def("serialize", [](const clientlib::SharedScalarOrTensorData &value) { + return pybind11::bytes(valueSerialize(value)); + }); + + pybind11::class_(m, "ValueExporter") + .def_static("create", + [](clientlib::KeySet &keySet, + mlir::concretelang::ClientParameters &clientParameters) { + return clientlib::ValueExporter(keySet, clientParameters); + }) + .def("export_scalar", + [](clientlib::ValueExporter &exporter, size_t position, + int64_t value) { + outcome::checked + result = exporter.exportValue(value, position); + + if (result.has_error()) { + throw std::runtime_error(result.error().mesg); + } + + return clientlib::SharedScalarOrTensorData( + std::move(result.value())); + }) + .def("export_tensor", [](clientlib::ValueExporter &exporter, + size_t position, std::vector values, + std::vector shape) { + outcome::checked result = + exporter.exportValue(values.data(), shape, position); + + if (result.has_error()) { + throw std::runtime_error(result.error().mesg); + } + + return clientlib::SharedScalarOrTensorData(std::move(result.value())); + }); + + pybind11::class_(m, "ValueDecrypter") + .def_static("create", + [](clientlib::KeySet &keySet, + mlir::concretelang::ClientParameters &clientParameters) { + return clientlib::ValueDecrypter(keySet, clientParameters); + }) + .def("get_shape", + [](clientlib::ValueDecrypter &decrypter, size_t position) { + outcome::checked, StringError> result = + decrypter.getShape(position); + + if (result.has_error()) { + throw std::runtime_error(result.error().mesg); + } + + return result.value(); + }) + .def("decrypt_scalar", + [](clientlib::ValueDecrypter &decrypter, size_t position, + clientlib::SharedScalarOrTensorData &value) { + outcome::checked result = + decrypter.decrypt(value.get(), position); + + if (result.has_error()) { + throw std::runtime_error(result.error().mesg); + } + + return result.value(); + }) + .def("decrypt_tensor", + [](clientlib::ValueDecrypter &decrypter, size_t position, + clientlib::SharedScalarOrTensorData &value) { + outcome::checked, StringError> result = + decrypter.decryptTensor(value.get(), position); + + if (result.has_error()) { + throw std::runtime_error(result.error().mesg); + } + + return result.value(); + }); + pybind11::class_>( m, "PublicArguments") + .def_static( + "create", + [](const mlir::concretelang::ClientParameters &clientParameters, + std::vector &buffers) { + return clientlib::PublicArguments(clientParameters, buffers); + }) .def_static("deserialize", [](mlir::concretelang::ClientParameters &clientParameters, const pybind11::bytes &buffer) { @@ -302,9 +391,25 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( const pybind11::bytes &buffer) { return publicResultUnserialize(clientParameters, buffer); }) - .def("serialize", [](clientlib::PublicResult &publicResult) { - return pybind11::bytes(publicResultSerialize(publicResult)); - }); + .def("serialize", + [](clientlib::PublicResult &publicResult) { + return pybind11::bytes(publicResultSerialize(publicResult)); + }) + .def("n_values", + [](const clientlib::PublicResult &publicResult) { + return publicResult.buffers.size(); + }) + .def("get_value", + [](clientlib::PublicResult &publicResult, size_t position) { + outcome::checked + result = publicResult.getValue(position); + + if (result.has_error()) { + throw std::runtime_error(result.error().mesg); + } + + return result.value(); + }); pybind11::class_(m, "EvaluationKeys") .def_static("deserialize", diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerEngine.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerEngine.cpp index ec949810e..560fbb399 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerEngine.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerEngine.cpp @@ -257,6 +257,26 @@ keySetSerialize(concretelang::clientlib::KeySet &keySet) { return buffer.str(); } +MLIR_CAPI_EXPORTED concretelang::clientlib::SharedScalarOrTensorData +valueUnserialize(const std::string &buffer) { + std::stringstream istream(buffer); + + auto value = concretelang::clientlib::unserializeScalarOrTensorData(istream); + if (istream.fail() || value.has_error()) { + throw std::runtime_error("Cannot read data"); + } + + return concretelang::clientlib::SharedScalarOrTensorData( + std::move(value.value())); +} + +MLIR_CAPI_EXPORTED std::string +valueSerialize(const concretelang::clientlib::SharedScalarOrTensorData &value) { + std::ostringstream buffer(std::ios::binary); + serializeScalarOrTensorData(value.get(), buffer); + return buffer.str(); +} + MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters clientParametersUnserialize(const std::string &json) { GET_OR_THROW_LLVM_EXPECTED( diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py index 48d9cc72a..14710e8f5 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py @@ -29,6 +29,9 @@ from .client_support import ClientSupport from .jit_support import JITSupport from .library_support import LibrarySupport from .evaluation_keys import EvaluationKeys +from .value import Value +from .value_decrypter import ValueDecrypter +from .value_exporter import ValueExporter def init_dfr(): diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py index acbbf59a3..20a34c6e8 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py @@ -3,6 +3,8 @@ """PublicArguments.""" +from typing import List + # pylint: disable=no-name-in-module,import-error from mlir._mlir_libs._concretelang._compiler import ( PublicArguments as _PublicArguments, @@ -10,6 +12,7 @@ from mlir._mlir_libs._concretelang._compiler import ( # pylint: enable=no-name-in-module,import-error from .client_parameters import ClientParameters +from .value import Value from .wrapper import WrapperCpp @@ -35,6 +38,20 @@ class PublicArguments(WrapperCpp): ) super().__init__(public_arguments) + @staticmethod + def create( + client_parameters: ClientParameters, values: List[Value] + ) -> "PublicArguments": + """ + Create public arguments from individual values. + """ + return PublicArguments( + _PublicArguments.create( + client_parameters.cpp(), + [value.cpp() for value in values], + ) + ) + def serialize(self) -> bytes: """Serialize the PublicArguments. diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_result.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_result.py index 33d15c27e..cf404c53d 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_result.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/public_result.py @@ -10,6 +10,7 @@ from mlir._mlir_libs._concretelang._compiler import ( from .client_parameters import ClientParameters # pylint: enable=no-name-in-module,import-error +from .value import Value from .wrapper import WrapperCpp @@ -31,6 +32,18 @@ class PublicResult(WrapperCpp): ) super().__init__(public_result) + def n_values(self) -> int: + """ + Get number of values in the result. + """ + return self.cpp().n_values() + + def get_value(self, position: int) -> Value: + """ + Get a specific value in the result. + """ + return Value(self.cpp().get_value(position)) + def serialize(self) -> bytes: """Serialize the PublicResult. diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value.py new file mode 100644 index 000000000..5a7e60867 --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value.py @@ -0,0 +1,68 @@ +"""Value.""" + +# pylint: disable=no-name-in-module,import-error + +from mlir._mlir_libs._concretelang._compiler import ( + Value as _Value, +) + +from .wrapper import WrapperCpp + +# pylint: enable=no-name-in-module,import-error + + +class Value(WrapperCpp): + """An encrypted/clear value which can be scalar/tensor.""" + + def __init__(self, value: _Value): + """ + Wrap the native C++ object. + + Args: + value (_Value): + object to wrap + + Raises: + TypeError: + if `value` is not of type `_Value` + """ + + if not isinstance(value, _Value): + raise TypeError(f"value must be of type _Value, not {type(value)}") + + super().__init__(value) + + def serialize(self) -> bytes: + """ + Serialize value into bytes. + + Returns: + bytes: serialized value + """ + + return self.cpp().serialize() + + @staticmethod + def deserialize(serialized_value: bytes) -> "Value": + """ + Deserialize value from bytes. + + Args: + serialized_value (bytes): + previously serialized value + + Returns: + Value: + deserialized value + + Raises: + TypeError: + if `serialized_value` is not of type `bytes` + """ + + if not isinstance(serialized_value, bytes): + raise TypeError( + f"serialized_value must be of type bytes, not {type(serialized_value)}" + ) + + return Value.wrap(_Value.deserialize(serialized_value)) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_decrypter.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_decrypter.py new file mode 100644 index 000000000..39d998dd6 --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_decrypter.py @@ -0,0 +1,111 @@ +"""ValueDecrypter.""" + +# pylint: disable=no-name-in-module,import-error + +from typing import List, Union + +import numpy as np +from mlir._mlir_libs._concretelang._compiler import ( + ValueDecrypter as _ValueDecrypter, +) + +from .client_parameters import ClientParameters +from .key_set import KeySet +from .value import Value +from .wrapper import WrapperCpp + +# pylint: enable=no-name-in-module,import-error + + +class ValueDecrypter(WrapperCpp): + """A helper class to decrypt `Value`s.""" + + def __init__(self, value_decrypter: _ValueDecrypter): + """ + Wrap the native C++ object. + + Args: + value_decrypter (_ValueDecrypter): + object to wrap + + Raises: + TypeError: + if `value_decrypter` is not of type `_ValueDecrypter` + """ + + if not isinstance(value_decrypter, _ValueDecrypter): + raise TypeError( + f"value_decrypter must be of type _ValueDecrypter, not {type(value_decrypter)}" + ) + + super().__init__(value_decrypter) + + @staticmethod + def create(keyset: KeySet, client_parameters: ClientParameters): + """ + Create a value decrypter. + """ + return ValueDecrypter( + _ValueDecrypter.create(keyset.cpp(), client_parameters.cpp()) + ) + + def decrypt(self, position: int, value: Value) -> Union[int, np.ndarray]: + """ + Decrypt value. + + Args: + position (int): + position of the argument within the circuit + + value (Value): + value to decrypt + + Returns: + Union[int, np.ndarray]: + decrypted value + """ + + shape = tuple(self.cpp().get_shape(position)) + + if len(shape) == 0: + return self.decrypt_scalar(position, value) + + return np.array(self.decrypt_tensor(position, value), dtype=np.int64).reshape( + shape + ) + + def decrypt_scalar(self, position: int, value: Value) -> int: + """ + Decrypt scalar. + + Args: + position (int): + position of the argument within the circuit + + value (Value): + scalar value to decrypt + + Returns: + int: + decrypted scalar + """ + + return self.cpp().decrypt_scalar(position, value.cpp()) + + def decrypt_tensor(self, position: int, value: Value) -> List[int]: + """ + Decrypt tensor. + + Args: + position (int): + position of the argument within the circuit + + value (Value): + tensor value to decrypt + + Returns: + List[int]: + decrypted tensor + """ + + return self.cpp().decrypt_tensor(position, value.cpp()) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_exporter.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_exporter.py new file mode 100644 index 000000000..d6692c25a --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/value_exporter.py @@ -0,0 +1,90 @@ +"""ValueExporter.""" + +# pylint: disable=no-name-in-module,import-error + +from typing import List + +from mlir._mlir_libs._concretelang._compiler import ( + ValueExporter as _ValueExporter, +) + +from .client_parameters import ClientParameters +from .key_set import KeySet +from .value import Value +from .wrapper import WrapperCpp + +# pylint: enable=no-name-in-module,import-error + + +class ValueExporter(WrapperCpp): + """A helper class to create `Value`s.""" + + def __init__(self, value_exporter: _ValueExporter): + """ + Wrap the native C++ object. + + Args: + value_exporter (_ValueExporter): + object to wrap + + Raises: + TypeError: + if `value_exporter` is not of type `_ValueExporter` + """ + + if not isinstance(value_exporter, _ValueExporter): + raise TypeError( + f"value_exporter must be of type _ValueExporter, not {type(value_exporter)}" + ) + + super().__init__(value_exporter) + + @staticmethod + def create(keyset: KeySet, client_parameters: ClientParameters) -> "ValueExporter": + """ + Create a value exporter. + """ + return ValueExporter( + _ValueExporter.create(keyset.cpp(), client_parameters.cpp()) + ) + + def export_scalar(self, position: int, value: int) -> Value: + """ + Export scalar. + + Args: + position (int): + position of the argument within the circuit + + value (int): + scalar to export + + Returns: + Value: + exported scalar + """ + + return Value(self.cpp().export_scalar(position, value)) + + def export_tensor( + self, position: int, values: List[int], shape: List[int] + ) -> Value: + """ + Export tensor. + + Args: + position (int): + position of the argument within the circuit + + values (List[int]): + tensor elements to export + + shape (List[int]): + tensor shape to export + + Returns: + Value: + exported tensor + """ + + return Value(self.cpp().export_tensor(position, values, shape)) diff --git a/compilers/concrete-compiler/compiler/lib/ClientLib/EncryptedArguments.cpp b/compilers/concrete-compiler/compiler/lib/ClientLib/EncryptedArguments.cpp index dc512cf89..2fea24a56 100644 --- a/compilers/concrete-compiler/compiler/lib/ClientLib/EncryptedArguments.cpp +++ b/compilers/concrete-compiler/compiler/lib/ClientLib/EncryptedArguments.cpp @@ -13,7 +13,14 @@ using StringError = concretelang::error::StringError; outcome::checked, StringError> EncryptedArguments::exportPublicArguments(ClientParameters clientParameters) { - return std::make_unique(clientParameters, std::move(values)); + auto sharedValues = std::vector(); + sharedValues.reserve(this->values.size()); + + for (auto &&value : this->values) { + sharedValues.push_back(SharedScalarOrTensorData(std::move(value))); + } + + return std::make_unique(clientParameters, sharedValues); } /// Split the input integer into `size` chunks of `chunkWidth` bits each diff --git a/compilers/concrete-compiler/compiler/lib/ClientLib/PublicArguments.cpp b/compilers/concrete-compiler/compiler/lib/ClientLib/PublicArguments.cpp index 48c0b80b7..a69ce4622 100644 --- a/compilers/concrete-compiler/compiler/lib/ClientLib/PublicArguments.cpp +++ b/compilers/concrete-compiler/compiler/lib/ClientLib/PublicArguments.cpp @@ -15,10 +15,11 @@ namespace clientlib { using concretelang::error::StringError; // TODO: optimize the move -PublicArguments::PublicArguments(const ClientParameters &clientParameters, - std::vector &&arguments_) +PublicArguments::PublicArguments( + const ClientParameters &clientParameters, + std::vector &buffers) : clientParameters(clientParameters) { - arguments = std::move(arguments_); + arguments = buffers; } PublicArguments::~PublicArguments() {} @@ -44,11 +45,11 @@ PublicArguments::unserializeArgs(std::istream &istream) { } outcome::checked, StringError> -PublicArguments::unserialize(ClientParameters &clientParameters, +PublicArguments::unserialize(const ClientParameters &expectedParams, std::istream &istream) { - std::vector emptyBuffers; - auto sArguments = std::make_unique(clientParameters, - std::move(emptyBuffers)); + std::vector emptyBuffers; + auto sArguments = + std::make_unique(expectedParams, emptyBuffers); OUTCOME_TRYV(sArguments->unserializeArgs(istream)); return std::move(sArguments); } diff --git a/compilers/concrete-compiler/compiler/lib/ClientLib/Serializers.cpp b/compilers/concrete-compiler/compiler/lib/ClientLib/Serializers.cpp index 62a691873..4cf5a5031 100644 --- a/compilers/concrete-compiler/compiler/lib/ClientLib/Serializers.cpp +++ b/compilers/concrete-compiler/compiler/lib/ClientLib/Serializers.cpp @@ -398,8 +398,11 @@ unserializeScalarData(std::istream &istream) { template static std::istream &unserializeTensorDataElements(TensorData &values_and_sizes, std::istream &istream) { - readWords(istream, values_and_sizes.getElementPointer(0), - values_and_sizes.getNumElements()); + // getElementPointer is not valid if the tensor contains no data + if (values_and_sizes.getNumElements() > 0) { + readWords(istream, values_and_sizes.getElementPointer(0), + values_and_sizes.getNumElements()); + } return istream; } @@ -555,26 +558,25 @@ unserializeScalarOrTensorData(std::istream &istream) { } } -std::ostream & -serializeVectorOfScalarOrTensorData(const std::vector &v, - std::ostream &ostream) { +std::ostream &serializeVectorOfScalarOrTensorData( + const std::vector &v, std::ostream &ostream) { writeSize(ostream, v.size()); for (auto &sotd : v) { - serializeScalarOrTensorData(sotd, ostream); + serializeScalarOrTensorData(sotd.get(), ostream); if (!ostream.good()) { return ostream; } } return ostream; } -outcome::checked, StringError> +outcome::checked, StringError> unserializeVectorOfScalarOrTensorData(std::istream &istream) { uint64_t nbElt; readSize(istream, nbElt); - std::vector v; + std::vector v; for (uint64_t i = 0; i < nbElt; i++) { OUTCOME_TRY(auto elt, unserializeScalarOrTensorData(istream)); - v.push_back(std::move(elt)); + v.push_back(SharedScalarOrTensorData(std::move(elt))); } return v; } diff --git a/compilers/concrete-compiler/compiler/lib/Support/Jit.cpp b/compilers/concrete-compiler/compiler/lib/Support/Jit.cpp index 2561fc249..85e9d5f01 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Jit.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Jit.cpp @@ -95,7 +95,7 @@ JITLambda::call(clientlib::PublicArguments &args, if (auto err = invokeRaw(rawArgs)) { return std::move(err); } - std::vector buffers; + std::vector buffers; return clientlib::PublicResult::fromBuffers(args.clientParameters, std::move(buffers)); }