From e9345b385983ebe6127cff7bc26e1bf8af9e78da Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 21 Apr 2022 14:00:17 +0100 Subject: [PATCH] feat: serialize public result w/ python bindings --- .../concretelang-c/Support/CompilerEngine.h | 7 ++++ .../concretelang/ClientLib/PublicArguments.h | 10 ++++- .../lib/Bindings/Python/CompilerAPIModule.cpp | 10 ++++- .../concrete/compiler/public_arguments.py | 18 +++++---- .../Python/concrete/compiler/public_result.py | 38 +++++++++++++++++++ compiler/lib/CAPI/Support/CompilerEngine.cpp | 22 +++++++++++ compiler/lib/ClientLib/PublicArguments.cpp | 32 ++++++++++++++++ compiler/tests/python/test_serialization.py | 5 ++- 8 files changed, 130 insertions(+), 12 deletions(-) diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index f433d0d05..b9d245aa4 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -107,6 +107,13 @@ publicArgumentsUnserialize( MLIR_CAPI_EXPORTED std::string publicArgumentsSerialize( concretelang::clientlib::PublicArguments &publicArguments); +MLIR_CAPI_EXPORTED std::unique_ptr +publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters, + const std::string &buffer); + +MLIR_CAPI_EXPORTED std::string +publicResultSerialize(concretelang::clientlib::PublicResult &publicResult); + // Parse then print a textual representation of an MLIR module MLIR_CAPI_EXPORTED std::string roundTrip(const char *module); diff --git a/compiler/include/concretelang/ClientLib/PublicArguments.h b/compiler/include/concretelang/ClientLib/PublicArguments.h index 855d416ea..3d79ec63a 100644 --- a/compiler/include/concretelang/ClientLib/PublicArguments.h +++ b/compiler/include/concretelang/ClientLib/PublicArguments.h @@ -82,9 +82,15 @@ struct PublicResult { return std::make_unique(clientParameters, buffers); } - /// Unserialize from a input stream. + /// Unserialize from an input stream inplace. outcome::checked unserialize(std::istream &istream); - + /// Unserialize from an input stream returning a new PublicResult. + static outcome::checked, StringError> + unserialize(ClientParameters &expectedParams, std::istream &istream) { + auto publicResult = std::make_unique(expectedParams); + OUTCOME_TRYV(publicResult->unserialize(istream)); + return std::move(publicResult); + } /// Serialize into an output stream. outcome::checked serialize(std::ostream &ostream); diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 49bad68ab..43d82daa4 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -166,7 +166,15 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .def("serialize", [](clientlib::PublicArguments &publicArgument) { return pybind11::bytes(publicArgumentsSerialize(publicArgument)); }); - pybind11::class_(m, "PublicResult"); + pybind11::class_(m, "PublicResult") + .def_static("unserialize", + [](mlir::concretelang::ClientParameters &clientParameters, + const pybind11::bytes &buffer) { + return publicResultUnserialize(clientParameters, buffer); + }) + .def("serialize", [](clientlib::PublicResult &publicResult) { + return pybind11::bytes(publicResultSerialize(publicResult)); + }); pybind11::class_(m, "LambdaArgument") .def_static("from_tensor", diff --git a/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py b/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py index 9c7db0e5a..a4e888c16 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/public_arguments.py @@ -36,7 +36,7 @@ class PublicArguments(WrapperCpp): super().__init__(public_arguments) def serialize(self) -> bytes: - """Serialize the PublicArguments into a buffer. + """Serialize the PublicArguments. Returns: bytes: serialized object @@ -45,17 +45,17 @@ class PublicArguments(WrapperCpp): @staticmethod def unserialize( - client_parameters: ClientParameters, buffer: bytes + client_parameters: ClientParameters, serialized_args: bytes ) -> "PublicArguments": - """Unserialize PublicArguments from a buffer. + """Unserialize PublicArguments from bytes of serialized_args. Args: client_parameters (ClientParameters): client parameters of the compiled circuit - buffer (bytes): previously serialized PublicArguments + serialized_args (bytes): previously serialized PublicArguments Raises: TypeError: if client_parameters is not of type ClientParameters - TypeError: if buffer is not of type bytes + TypeError: if serialized_args is not of type bytes Returns: PublicArguments: unserialized object @@ -64,8 +64,10 @@ class PublicArguments(WrapperCpp): raise TypeError( f"client_parameters must be of type ClientParameters, not {type(client_parameters)}" ) - if not isinstance(buffer, bytes): - raise TypeError(f"buffer must be of type bytes, not {type(buffer)}") + if not isinstance(serialized_args, bytes): + raise TypeError( + f"serialized_args must be of type bytes, not {type(serialized_args)}" + ) return PublicArguments.wrap( - _PublicArguments.unserialize(client_parameters.cpp(), buffer) + _PublicArguments.unserialize(client_parameters.cpp(), serialized_args) ) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/public_result.py b/compiler/lib/Bindings/Python/concrete/compiler/public_result.py index 36997275b..02fc8f1fa 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/public_result.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/public_result.py @@ -7,6 +7,7 @@ from mlir._mlir_libs._concretelang._compiler import ( PublicResult as _PublicResult, ) +from .client_parameters import ClientParameters # pylint: enable=no-name-in-module,import-error from .wrapper import WrapperCpp @@ -29,3 +30,40 @@ class PublicResult(WrapperCpp): f"public_result must be of type _PublicResult, not {type(public_result)}" ) super().__init__(public_result) + + def serialize(self) -> bytes: + """Serialize the PublicResult. + + Returns: + bytes: serialized object + """ + return self.cpp().serialize() + + @staticmethod + def unserialize( + client_parameters: ClientParameters, serialized_result: bytes + ) -> "PublicResult": + """Unserialize PublicResult from bytes of serialized_result. + + Args: + client_parameters (ClientParameters): client parameters of the compiled circuit + serialized_result (bytes): previously serialized PublicResult + + Raises: + TypeError: if client_parameters is not of type ClientParameters + TypeError: if serialized_result is not of type bytes + + Returns: + PublicResult: unserialized object + """ + if not isinstance(client_parameters, ClientParameters): + raise TypeError( + f"client_parameters must be of type ClientParameters, not {type(client_parameters)}" + ) + if not isinstance(serialized_result, bytes): + raise TypeError( + f"serialized_result must be of type bytes, not {type(serialized_result)}" + ) + return PublicResult.wrap( + _PublicResult.unserialize(client_parameters.cpp(), serialized_result) + ) diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 31e33ec26..d0e8adde7 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -157,6 +157,28 @@ MLIR_CAPI_EXPORTED std::string publicArgumentsSerialize( return buffer.str(); } +MLIR_CAPI_EXPORTED std::unique_ptr +publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters, + const std::string &buffer) { + std::stringstream istream(buffer); + auto publicResultOrError = concretelang::clientlib::PublicResult::unserialize( + clientParameters, istream); + if (!publicResultOrError) { + throw std::runtime_error(publicResultOrError.error().mesg); + } + return std::move(publicResultOrError.value()); +} + +MLIR_CAPI_EXPORTED std::string +publicResultSerialize(concretelang::clientlib::PublicResult &publicResult) { + std::ostringstream buffer(std::ios::binary); + auto voidOrError = publicResult.serialize(buffer); + if (!voidOrError) { + throw std::runtime_error(voidOrError.error().mesg); + } + return buffer.str(); +} + void terminateParallelization() { #ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED _dfr_terminate(); diff --git a/compiler/lib/ClientLib/PublicArguments.cpp b/compiler/lib/ClientLib/PublicArguments.cpp index 768fa318b..ad334d136 100644 --- a/compiler/lib/ClientLib/PublicArguments.cpp +++ b/compiler/lib/ClientLib/PublicArguments.cpp @@ -136,6 +136,38 @@ PublicArguments::unserialize(ClientParameters &clientParameters, return std::move(sArguments); } +outcome::checked +PublicResult::unserialize(std::istream &istream) { + for (auto gate : clientParameters.outputs) { + if (!gate.encryption.hasValue()) { + return StringError("Clear values are not handled"); + } + auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize(); + std::vector sizes = gate.shape.dimensions; + sizes.push_back(lweSize); + buffers.push_back(unserializeTensorData(sizes, istream)); + if (istream.fail()) { + return StringError("Cannot read tensor data"); + } + } + return outcome::success(); +} + +outcome::checked +PublicResult::serialize(std::ostream &ostream) { + if (incorrectMode(ostream)) { + return StringError( + "PublicResult::serialize: ostream should be in binary mode"); + } + for (auto tensorData : buffers) { + serializeTensorData(tensorData, ostream); + if (ostream.fail()) { + return StringError("Cannot write tensor data"); + } + } + return outcome::success(); +} + void next_coord_index(size_t index[], size_t sizes[], size_t rank) { // increase multi dim index for (int r = rank - 1; r >= 0; r--) { diff --git a/compiler/tests/python/test_serialization.py b/compiler/tests/python/test_serialization.py index ad66f32ee..6019ff57d 100644 --- a/compiler/tests/python/test_serialization.py +++ b/compiler/tests/python/test_serialization.py @@ -7,6 +7,7 @@ from concrete.compiler import ( CompilationOptions, PublicArguments, ) +from concrete.compiler.public_result import PublicResult def assert_result(result, expected_result): @@ -21,7 +22,6 @@ def assert_result(result, expected_result): assert np.all(result == expected_result) -# TODO(#541): add result serialization def run_with_serialization( engine, args, @@ -43,7 +43,10 @@ def run_with_serialization( del public_arguments_buffer server_lambda = engine.load_server_lambda(compilation_result) public_result = engine.server_call(server_lambda, public_arguments) + public_result_buffer = public_result.serialize() # Client + public_result = PublicResult.unserialize(client_parameters, public_result_buffer) + del public_result_buffer result = ClientSupport.decrypt_result(key_set, public_result) return result