mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: simplify (un)serialization of public args/result
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
Export everything that users might need.
|
||||
"""
|
||||
|
||||
from concrete.compiler import PublicArguments, PublicResult
|
||||
|
||||
from .compilation import (
|
||||
Circuit,
|
||||
Client,
|
||||
|
||||
@@ -208,3 +208,22 @@ class Client:
|
||||
)
|
||||
|
||||
return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)
|
||||
|
||||
def unserialize_and_decrypt(
|
||||
self,
|
||||
result: bytes,
|
||||
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
|
||||
"""
|
||||
Decrypt serialized result of homomorphic evaluaton.
|
||||
|
||||
Args:
|
||||
result (bytes):
|
||||
serialized encrypted result of homomorphic evaluaton
|
||||
|
||||
Returns:
|
||||
Union[int, numpy.ndarray]:
|
||||
clear result of homomorphic evaluaton
|
||||
"""
|
||||
|
||||
unserialized_result = self.specs.unserialize_public_result(result)
|
||||
return self.decrypt(unserialized_result)
|
||||
|
||||
@@ -173,6 +173,22 @@ class Server:
|
||||
|
||||
return self._support.server_call(self._server_lambda, args)
|
||||
|
||||
def unserialize_and_run(self, args: bytes) -> PublicResult:
|
||||
"""
|
||||
Evaluate using serialized encrypted arguments.
|
||||
|
||||
Args:
|
||||
args (bytes):
|
||||
serialized encrypted arguments of the computation
|
||||
|
||||
Returns:
|
||||
PublicResult:
|
||||
encrypted result of the computation
|
||||
"""
|
||||
|
||||
unserialized_args = self.client_specs.unserialize_public_args(args)
|
||||
return self.run(unserialized_args)
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
Cleanup the temporary library output directory.
|
||||
|
||||
@@ -5,7 +5,7 @@ Declaration of `ClientSpecs` class.
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from concrete.compiler import ClientParameters
|
||||
from concrete.compiler import ClientParameters, PublicArguments, PublicResult
|
||||
|
||||
|
||||
class ClientSpecs:
|
||||
@@ -57,3 +57,63 @@ class ClientSpecs:
|
||||
client_parameters = ClientParameters.unserialize(client_parameters_bytes)
|
||||
|
||||
return ClientSpecs(client_parameters, raw_specs["output_signs"])
|
||||
|
||||
def serialize_public_args(self, args: PublicArguments) -> bytes: # pylint: disable=no-self-use
|
||||
"""
|
||||
Serialize public arguments to bytes.
|
||||
|
||||
Args:
|
||||
args (PublicArguments):
|
||||
public arguments to serialize
|
||||
|
||||
Returns:
|
||||
bytes:
|
||||
serialized public arguments
|
||||
"""
|
||||
|
||||
return args.serialize()
|
||||
|
||||
def unserialize_public_args(self, serialized_args: bytes) -> PublicArguments:
|
||||
"""
|
||||
Unserialize public arguments from bytes.
|
||||
|
||||
Args:
|
||||
serialized_args (bytes):
|
||||
serialized public arguments
|
||||
|
||||
Returns:
|
||||
PublicArguments:
|
||||
unserialized public arguments
|
||||
"""
|
||||
|
||||
return PublicArguments.unserialize(self.client_parameters, serialized_args)
|
||||
|
||||
def serialize_public_result(self, result: PublicResult) -> bytes: # pylint: disable=no-self-use
|
||||
"""
|
||||
Serialize public result to bytes.
|
||||
|
||||
Args:
|
||||
result (PublicResult):
|
||||
public result to serialize
|
||||
|
||||
Returns:
|
||||
bytes:
|
||||
serialized public result
|
||||
"""
|
||||
|
||||
return result.serialize()
|
||||
|
||||
def unserialize_public_result(self, serialized_result: bytes) -> PublicResult:
|
||||
"""
|
||||
Unserialize public result from bytes.
|
||||
|
||||
Args:
|
||||
serialized_result (bytes):
|
||||
serialized public result
|
||||
|
||||
Returns:
|
||||
PublicResult:
|
||||
unserialized public result
|
||||
"""
|
||||
|
||||
return PublicResult.unserialize(self.client_parameters, serialized_result)
|
||||
|
||||
@@ -7,7 +7,6 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from concrete.compiler import PublicArguments, PublicResult
|
||||
|
||||
from concrete.numpy import Client, ClientSpecs, Configuration, Server
|
||||
from concrete.numpy.compilation import compiler
|
||||
@@ -207,22 +206,13 @@ def test_client_server_api(helpers):
|
||||
]
|
||||
|
||||
for client in clients:
|
||||
raw_input = client.encrypt(4)
|
||||
serialized_input = raw_input.serialize()
|
||||
args = client.encrypt(4)
|
||||
serialized_args = client.specs.serialize_public_args(args)
|
||||
|
||||
unserialized_input = PublicArguments.unserialize(
|
||||
server.client_specs.client_parameters,
|
||||
serialized_input,
|
||||
)
|
||||
evaluation = server.run(unserialized_input)
|
||||
serialized_evaluation = evaluation.serialize()
|
||||
|
||||
unserialized_evaluation = PublicResult.unserialize(
|
||||
client.specs.client_parameters,
|
||||
serialized_evaluation,
|
||||
)
|
||||
output = client.decrypt(unserialized_evaluation)
|
||||
result = server.unserialize_and_run(serialized_args)
|
||||
serialized_result = server.client_specs.serialize_public_result(result)
|
||||
|
||||
output = client.unserialize_and_decrypt(serialized_result)
|
||||
assert output == 46
|
||||
|
||||
server.cleanup()
|
||||
|
||||
Reference in New Issue
Block a user