feat: simplify (un)serialization of public args/result

This commit is contained in:
Umut
2022-05-16 13:32:58 +02:00
parent d94812b234
commit c9bb05df82
5 changed files with 103 additions and 16 deletions

View File

@@ -2,6 +2,8 @@
Export everything that users might need.
"""
from concrete.compiler import PublicArguments, PublicResult
from .compilation import (
Circuit,
Client,

View File

@@ -208,3 +208,22 @@ class Client:
)
return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)
def unserialize_and_decrypt(
self,
result: bytes,
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
"""
Decrypt serialized result of homomorphic evaluaton.
Args:
result (bytes):
serialized encrypted result of homomorphic evaluaton
Returns:
Union[int, numpy.ndarray]:
clear result of homomorphic evaluaton
"""
unserialized_result = self.specs.unserialize_public_result(result)
return self.decrypt(unserialized_result)

View File

@@ -173,6 +173,22 @@ class Server:
return self._support.server_call(self._server_lambda, args)
def unserialize_and_run(self, args: bytes) -> PublicResult:
"""
Evaluate using serialized encrypted arguments.
Args:
args (bytes):
serialized encrypted arguments of the computation
Returns:
PublicResult:
encrypted result of the computation
"""
unserialized_args = self.client_specs.unserialize_public_args(args)
return self.run(unserialized_args)
def cleanup(self):
"""
Cleanup the temporary library output directory.

View File

@@ -5,7 +5,7 @@ Declaration of `ClientSpecs` class.
import json
from typing import List
from concrete.compiler import ClientParameters
from concrete.compiler import ClientParameters, PublicArguments, PublicResult
class ClientSpecs:
@@ -57,3 +57,63 @@ class ClientSpecs:
client_parameters = ClientParameters.unserialize(client_parameters_bytes)
return ClientSpecs(client_parameters, raw_specs["output_signs"])
def serialize_public_args(self, args: PublicArguments) -> bytes: # pylint: disable=no-self-use
"""
Serialize public arguments to bytes.
Args:
args (PublicArguments):
public arguments to serialize
Returns:
bytes:
serialized public arguments
"""
return args.serialize()
def unserialize_public_args(self, serialized_args: bytes) -> PublicArguments:
"""
Unserialize public arguments from bytes.
Args:
serialized_args (bytes):
serialized public arguments
Returns:
PublicArguments:
unserialized public arguments
"""
return PublicArguments.unserialize(self.client_parameters, serialized_args)
def serialize_public_result(self, result: PublicResult) -> bytes: # pylint: disable=no-self-use
"""
Serialize public result to bytes.
Args:
result (PublicResult):
public result to serialize
Returns:
bytes:
serialized public result
"""
return result.serialize()
def unserialize_public_result(self, serialized_result: bytes) -> PublicResult:
"""
Unserialize public result from bytes.
Args:
serialized_result (bytes):
serialized public result
Returns:
PublicResult:
unserialized public result
"""
return PublicResult.unserialize(self.client_parameters, serialized_result)

View File

@@ -7,7 +7,6 @@ from pathlib import Path
import numpy as np
import pytest
from concrete.compiler import PublicArguments, PublicResult
from concrete.numpy import Client, ClientSpecs, Configuration, Server
from concrete.numpy.compilation import compiler
@@ -207,22 +206,13 @@ def test_client_server_api(helpers):
]
for client in clients:
raw_input = client.encrypt(4)
serialized_input = raw_input.serialize()
args = client.encrypt(4)
serialized_args = client.specs.serialize_public_args(args)
unserialized_input = PublicArguments.unserialize(
server.client_specs.client_parameters,
serialized_input,
)
evaluation = server.run(unserialized_input)
serialized_evaluation = evaluation.serialize()
unserialized_evaluation = PublicResult.unserialize(
client.specs.client_parameters,
serialized_evaluation,
)
output = client.decrypt(unserialized_evaluation)
result = server.unserialize_and_run(serialized_args)
serialized_result = server.client_specs.serialize_public_result(result)
output = client.unserialize_and_decrypt(serialized_result)
assert output == 46
server.cleanup()