feat(python): seriliaze client parameters

This commit is contained in:
youben11
2022-04-28 09:43:26 +01:00
committed by Quentin Bourgerie
parent 923a1b58e1
commit 843dd0eb5b
6 changed files with 68 additions and 1 deletions

View File

@@ -99,6 +99,12 @@ decrypt_result(concretelang::clientlib::KeySet &keySet,
// Serialization ////////////////////////////////////////////////////////////
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
clientParametersUnserialize(const std::string &json);
MLIR_CAPI_EXPORTED std::string
clientParametersSerialize(mlir::concretelang::ClientParameters &params);
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
publicArgumentsUnserialize(
mlir::concretelang::ClientParameters &clientParameters,

View File

@@ -207,6 +207,11 @@ static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
return OS << llvm::formatv("{0:2}", toJSON(cp));
}
static inline llvm::raw_ostream &operator<<(llvm::raw_string_ostream &OS,
ClientParameters cp) {
return OS << llvm::formatv("{0:2}", toJSON(cp));
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -152,7 +152,16 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
pybind11::class_<clientlib::KeySetCache>(m, "KeySetCache")
.def(pybind11::init<std::string &>());
pybind11::class_<mlir::concretelang::ClientParameters>(m, "ClientParameters");
pybind11::class_<mlir::concretelang::ClientParameters>(m, "ClientParameters")
.def_static("unserialize",
[](const pybind11::bytes &buffer) {
return clientParametersUnserialize(buffer);
})
.def("serialize",
[](mlir::concretelang::ClientParameters &clientParameters) {
return pybind11::bytes(
clientParametersSerialize(clientParameters));
});
pybind11::class_<clientlib::KeySet>(m, "KeySet");
pybind11::class_<clientlib::PublicArguments,

View File

@@ -34,3 +34,30 @@ class ClientParameters(WrapperCpp):
f"client_parameters must be of type _ClientParameters, not {type(client_parameters)}"
)
super().__init__(client_parameters)
def serialize(self) -> bytes:
"""Serialize the ClientParameters.
Returns:
bytes: serialized object
"""
return self.cpp().serialize()
@staticmethod
def unserialize(serialized_params: bytes) -> "ClientParameters":
"""Unserialize ClientParameters from bytes of serialized_params.
Args:
serialized_params (bytes): previously serialized ClientParameters
Raises:
TypeError: if serialized_params is not of type bytes
Returns:
ClientParameters: unserialized object
"""
if not isinstance(serialized_params, bytes):
raise TypeError(
f"serialized_params must be of type bytes, not {type(serialized_params)}"
)
return ClientParameters.wrap(_ClientParameters.unserialize(serialized_params))

View File

@@ -179,6 +179,23 @@ publicResultSerialize(concretelang::clientlib::PublicResult &publicResult) {
return buffer.str();
}
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
clientParametersUnserialize(const std::string &json) {
GET_OR_THROW_LLVM_EXPECTED(
clientParams,
llvm::json::parse<mlir::concretelang::ClientParameters>(json));
return clientParams.get();
}
MLIR_CAPI_EXPORTED std::string
clientParametersSerialize(mlir::concretelang::ClientParameters &params) {
llvm::json::Value value(params);
std::string jsonParams;
llvm::raw_string_ostream buffer(jsonParams);
buffer << value;
return jsonParams;
}
void terminateParallelization() {
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
_dfr_terminate();

View File

@@ -7,6 +7,7 @@ from concrete.compiler import (
CompilationOptions,
PublicArguments,
)
from concrete.compiler.client_parameters import ClientParameters
from concrete.compiler.public_result import PublicResult
@@ -33,6 +34,8 @@ def run_with_serialization(
Perform required loading, encryption, execution, and decryption."""
# Client
client_parameters = engine.load_client_parameters(compilation_result)
serialized_client_parameters = client_parameters.serialize()
client_parameters = ClientParameters.unserialize(serialized_client_parameters)
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
public_arguments_buffer = public_arguments.serialize()