From c9bb05df8221242cd4cd415bd323e1fc099f606b Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 16 May 2022 13:32:58 +0200 Subject: [PATCH] feat: simplify (un)serialization of public args/result --- concrete/numpy/__init__.py | 2 + concrete/numpy/compilation/client.py | 19 +++++++++ concrete/numpy/compilation/server.py | 16 +++++++ concrete/numpy/compilation/specs.py | 62 +++++++++++++++++++++++++++- tests/compilation/test_circuit.py | 20 +++------ 5 files changed, 103 insertions(+), 16 deletions(-) diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py index 5b1bae7e1..b37898cb5 100644 --- a/concrete/numpy/__init__.py +++ b/concrete/numpy/__init__.py @@ -2,6 +2,8 @@ Export everything that users might need. """ +from concrete.compiler import PublicArguments, PublicResult + from .compilation import ( Circuit, Client, diff --git a/concrete/numpy/compilation/client.py b/concrete/numpy/compilation/client.py index 299590805..8d2505076 100644 --- a/concrete/numpy/compilation/client.py +++ b/concrete/numpy/compilation/client.py @@ -208,3 +208,22 @@ class Client: ) return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results) + + def unserialize_and_decrypt( + self, + result: bytes, + ) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]: + """ + Decrypt serialized result of homomorphic evaluaton. + + Args: + result (bytes): + serialized encrypted result of homomorphic evaluaton + + Returns: + Union[int, numpy.ndarray]: + clear result of homomorphic evaluaton + """ + + unserialized_result = self.specs.unserialize_public_result(result) + return self.decrypt(unserialized_result) diff --git a/concrete/numpy/compilation/server.py b/concrete/numpy/compilation/server.py index d0d5afeeb..e69b52c32 100644 --- a/concrete/numpy/compilation/server.py +++ b/concrete/numpy/compilation/server.py @@ -173,6 +173,22 @@ class Server: return self._support.server_call(self._server_lambda, args) + def unserialize_and_run(self, args: bytes) -> PublicResult: + """ + Evaluate using serialized encrypted arguments. + + Args: + args (bytes): + serialized encrypted arguments of the computation + + Returns: + PublicResult: + encrypted result of the computation + """ + + unserialized_args = self.client_specs.unserialize_public_args(args) + return self.run(unserialized_args) + def cleanup(self): """ Cleanup the temporary library output directory. diff --git a/concrete/numpy/compilation/specs.py b/concrete/numpy/compilation/specs.py index b5e70df5c..12c26a0bf 100644 --- a/concrete/numpy/compilation/specs.py +++ b/concrete/numpy/compilation/specs.py @@ -5,7 +5,7 @@ Declaration of `ClientSpecs` class. import json from typing import List -from concrete.compiler import ClientParameters +from concrete.compiler import ClientParameters, PublicArguments, PublicResult class ClientSpecs: @@ -57,3 +57,63 @@ class ClientSpecs: client_parameters = ClientParameters.unserialize(client_parameters_bytes) return ClientSpecs(client_parameters, raw_specs["output_signs"]) + + def serialize_public_args(self, args: PublicArguments) -> bytes: # pylint: disable=no-self-use + """ + Serialize public arguments to bytes. + + Args: + args (PublicArguments): + public arguments to serialize + + Returns: + bytes: + serialized public arguments + """ + + return args.serialize() + + def unserialize_public_args(self, serialized_args: bytes) -> PublicArguments: + """ + Unserialize public arguments from bytes. + + Args: + serialized_args (bytes): + serialized public arguments + + Returns: + PublicArguments: + unserialized public arguments + """ + + return PublicArguments.unserialize(self.client_parameters, serialized_args) + + def serialize_public_result(self, result: PublicResult) -> bytes: # pylint: disable=no-self-use + """ + Serialize public result to bytes. + + Args: + result (PublicResult): + public result to serialize + + Returns: + bytes: + serialized public result + """ + + return result.serialize() + + def unserialize_public_result(self, serialized_result: bytes) -> PublicResult: + """ + Unserialize public result from bytes. + + Args: + serialized_result (bytes): + serialized public result + + Returns: + PublicResult: + unserialized public result + """ + + return PublicResult.unserialize(self.client_parameters, serialized_result) diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index 100ede9e8..e95df7fe2 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -7,7 +7,6 @@ from pathlib import Path import numpy as np import pytest -from concrete.compiler import PublicArguments, PublicResult from concrete.numpy import Client, ClientSpecs, Configuration, Server from concrete.numpy.compilation import compiler @@ -207,22 +206,13 @@ def test_client_server_api(helpers): ] for client in clients: - raw_input = client.encrypt(4) - serialized_input = raw_input.serialize() + args = client.encrypt(4) + serialized_args = client.specs.serialize_public_args(args) - unserialized_input = PublicArguments.unserialize( - server.client_specs.client_parameters, - serialized_input, - ) - evaluation = server.run(unserialized_input) - serialized_evaluation = evaluation.serialize() - - unserialized_evaluation = PublicResult.unserialize( - client.specs.client_parameters, - serialized_evaluation, - ) - output = client.decrypt(unserialized_evaluation) + result = server.unserialize_and_run(serialized_args) + serialized_result = server.client_specs.serialize_public_result(result) + output = client.unserialize_and_decrypt(serialized_result) assert output == 46 server.cleanup()