diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index b9d245aa4..eb447abc8 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -99,6 +99,12 @@ decrypt_result(concretelang::clientlib::KeySet &keySet, // Serialization //////////////////////////////////////////////////////////// +MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters +clientParametersUnserialize(const std::string &json); + +MLIR_CAPI_EXPORTED std::string +clientParametersSerialize(mlir::concretelang::ClientParameters ¶ms); + MLIR_CAPI_EXPORTED std::unique_ptr publicArgumentsUnserialize( mlir::concretelang::ClientParameters &clientParameters, diff --git a/compiler/include/concretelang/ClientLib/ClientParameters.h b/compiler/include/concretelang/ClientLib/ClientParameters.h index f76195a02..a679f8805 100644 --- a/compiler/include/concretelang/ClientLib/ClientParameters.h +++ b/compiler/include/concretelang/ClientLib/ClientParameters.h @@ -207,6 +207,11 @@ static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, return OS << llvm::formatv("{0:2}", toJSON(cp)); } +static inline llvm::raw_ostream &operator<<(llvm::raw_string_ostream &OS, + ClientParameters cp) { + return OS << llvm::formatv("{0:2}", toJSON(cp)); +} + } // namespace clientlib } // namespace concretelang diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 43d82daa4..c5a2fa489 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -152,7 +152,16 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( pybind11::class_(m, "KeySetCache") .def(pybind11::init()); - pybind11::class_(m, "ClientParameters"); + pybind11::class_(m, "ClientParameters") + .def_static("unserialize", + [](const pybind11::bytes &buffer) { + return clientParametersUnserialize(buffer); + }) + .def("serialize", + [](mlir::concretelang::ClientParameters &clientParameters) { + return pybind11::bytes( + clientParametersSerialize(clientParameters)); + }); pybind11::class_(m, "KeySet"); pybind11::class_ bytes: + """Serialize the ClientParameters. + + Returns: + bytes: serialized object + """ + return self.cpp().serialize() + + @staticmethod + def unserialize(serialized_params: bytes) -> "ClientParameters": + """Unserialize ClientParameters from bytes of serialized_params. + + Args: + serialized_params (bytes): previously serialized ClientParameters + + Raises: + TypeError: if serialized_params is not of type bytes + + Returns: + ClientParameters: unserialized object + """ + if not isinstance(serialized_params, bytes): + raise TypeError( + f"serialized_params must be of type bytes, not {type(serialized_params)}" + ) + return ClientParameters.wrap(_ClientParameters.unserialize(serialized_params)) diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index d0e8adde7..9c3868a26 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -179,6 +179,23 @@ publicResultSerialize(concretelang::clientlib::PublicResult &publicResult) { return buffer.str(); } +MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters +clientParametersUnserialize(const std::string &json) { + GET_OR_THROW_LLVM_EXPECTED( + clientParams, + llvm::json::parse(json)); + return clientParams.get(); +} + +MLIR_CAPI_EXPORTED std::string +clientParametersSerialize(mlir::concretelang::ClientParameters ¶ms) { + llvm::json::Value value(params); + std::string jsonParams; + llvm::raw_string_ostream buffer(jsonParams); + buffer << value; + return jsonParams; +} + void terminateParallelization() { #ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED _dfr_terminate(); diff --git a/compiler/tests/python/test_serialization.py b/compiler/tests/python/test_serialization.py index 6019ff57d..cc543e149 100644 --- a/compiler/tests/python/test_serialization.py +++ b/compiler/tests/python/test_serialization.py @@ -7,6 +7,7 @@ from concrete.compiler import ( CompilationOptions, PublicArguments, ) +from concrete.compiler.client_parameters import ClientParameters from concrete.compiler.public_result import PublicResult @@ -33,6 +34,8 @@ def run_with_serialization( Perform required loading, encryption, execution, and decryption.""" # Client client_parameters = engine.load_client_parameters(compilation_result) + serialized_client_parameters = client_parameters.serialize() + client_parameters = ClientParameters.unserialize(serialized_client_parameters) key_set = ClientSupport.key_set(client_parameters, keyset_cache) public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args) public_arguments_buffer = public_arguments.serialize()