mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(frontend-python): separate arguments
This commit is contained in:
@@ -14,6 +14,7 @@ from .compilation import (
|
||||
ClientSpecs,
|
||||
Compiler,
|
||||
Configuration,
|
||||
Data,
|
||||
DebugArtifacts,
|
||||
EncryptionStatus,
|
||||
Keys,
|
||||
|
||||
@@ -12,6 +12,7 @@ from .configuration import (
|
||||
Configuration,
|
||||
ParameterSelectionStrategy,
|
||||
)
|
||||
from .data import Data
|
||||
from .keys import Keys
|
||||
from .server import Server
|
||||
from .specs import ClientSpecs
|
||||
|
||||
@@ -4,17 +4,15 @@ Declaration of `Circuit` class.
|
||||
|
||||
# pylint: disable=import-error,no-member,no-name-in-module
|
||||
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
# mypy: disable-error-code=attr-defined
|
||||
from concrete.compiler import PublicArguments, PublicResult
|
||||
|
||||
from ..internal.utils import assert_that
|
||||
from ..representation import Graph
|
||||
from .client import Client
|
||||
from .configuration import Configuration
|
||||
from .data import Data
|
||||
from .keys import Keys
|
||||
from .server import Server
|
||||
|
||||
@@ -99,54 +97,60 @@ class Circuit:
|
||||
|
||||
self.client.keygen(force, seed)
|
||||
|
||||
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
|
||||
def encrypt(
|
||||
self,
|
||||
*args: Optional[Union[int, np.ndarray, List]],
|
||||
) -> Optional[Union[Data, Tuple[Optional[Data], ...]]]:
|
||||
"""
|
||||
Prepare inputs to be run on the circuit.
|
||||
Encrypt argument(s) to for evaluation.
|
||||
|
||||
Args:
|
||||
*args (Union[int, numpy.ndarray]):
|
||||
inputs to the circuit
|
||||
*args (Optional[Union[int, numpy.ndarray, List]]):
|
||||
argument(s) for evaluation
|
||||
|
||||
Returns:
|
||||
PublicArguments:
|
||||
encrypted and plain arguments as well as public keys
|
||||
Optional[Union[Data, Tuple[Optional[Data], ...]]]:
|
||||
encrypted argument(s) for evaluation
|
||||
"""
|
||||
|
||||
return self.client.encrypt(*args)
|
||||
|
||||
def run(self, args: PublicArguments) -> PublicResult:
|
||||
def run(
|
||||
self,
|
||||
*args: Optional[Union[Data, Tuple[Optional[Data], ...]]],
|
||||
) -> Union[Data, Tuple[Data, ...]]:
|
||||
"""
|
||||
Evaluate circuit using encrypted arguments.
|
||||
Evaluate the circuit.
|
||||
|
||||
Args:
|
||||
args (PublicArguments):
|
||||
arguments to the circuit (can be obtained with `encrypt` method of `Circuit`)
|
||||
*args (Data):
|
||||
argument(s) for evaluation
|
||||
|
||||
Returns:
|
||||
PublicResult:
|
||||
encrypted result of homomorphic evaluaton
|
||||
Union[Data, Tuple[Data, ...]]:
|
||||
result(s) of evaluation
|
||||
"""
|
||||
|
||||
self.keygen(force=False)
|
||||
return self.server.run(args, self.client.evaluation_keys)
|
||||
return self.server.run(*args, evaluation_keys=self.client.evaluation_keys)
|
||||
|
||||
def decrypt(
|
||||
self,
|
||||
result: PublicResult,
|
||||
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
|
||||
*results: Union[Data, Tuple[Data, ...]],
|
||||
) -> Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]:
|
||||
"""
|
||||
Decrypt result of homomorphic evaluaton.
|
||||
Decrypt result(s) of evaluation.
|
||||
|
||||
Args:
|
||||
result (PublicResult):
|
||||
encrypted result of homomorphic evaluaton
|
||||
*results (Union[Data, Tuple[Data, ...]]):
|
||||
result(s) of evaluation
|
||||
|
||||
Returns:
|
||||
Union[int, numpy.ndarray]:
|
||||
clear result of homomorphic evaluaton
|
||||
Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]:
|
||||
decrypted result(s) of evaluation
|
||||
"""
|
||||
|
||||
return self.client.decrypt(result)
|
||||
return self.client.decrypt(*results)
|
||||
|
||||
def encrypt_run_decrypt(self, *args: Any) -> Any:
|
||||
"""
|
||||
|
||||
@@ -8,14 +8,15 @@ import json
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import ClientSupport, EvaluationKeys, PublicArguments, PublicResult
|
||||
from concrete.compiler import EvaluationKeys, ValueDecrypter, ValueExporter
|
||||
|
||||
from ..dtypes.integer import SignedInteger, UnsignedInteger
|
||||
from ..internal.utils import assert_that
|
||||
from ..values.value import Value
|
||||
from .data import Data
|
||||
from .keys import Keys
|
||||
from .specs import ClientSpecs
|
||||
|
||||
@@ -116,17 +117,20 @@ class Client:
|
||||
|
||||
self.keys.generate(force=force, seed=seed)
|
||||
|
||||
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
|
||||
def encrypt(
|
||||
self,
|
||||
*args: Optional[Union[int, np.ndarray, List]],
|
||||
) -> Optional[Union[Data, Tuple[Optional[Data], ...]]]:
|
||||
"""
|
||||
Prepare inputs to be run on the circuit.
|
||||
Encrypt argument(s) to for evaluation.
|
||||
|
||||
Args:
|
||||
*args (Union[int, numpy.ndarray]):
|
||||
inputs to the circuit
|
||||
*args (Optional[Union[int, np.ndarray, List]]):
|
||||
argument(s) for evaluation
|
||||
|
||||
Returns:
|
||||
PublicArguments:
|
||||
encrypted and plain arguments as well as public keys
|
||||
Optional[Union[Data, Tuple[Optional[Data], ...]]]:
|
||||
encrypted argument(s) for evaluation
|
||||
"""
|
||||
|
||||
client_parameters_json = json.loads(self.specs.client_parameters.serialize())
|
||||
@@ -137,9 +141,12 @@ class Client:
|
||||
message = f"Expected {len(input_specs)} inputs but got {len(args)}"
|
||||
raise ValueError(message)
|
||||
|
||||
sanitized_args: Dict[int, Union[int, np.ndarray]] = {}
|
||||
for index, spec in enumerate(input_specs):
|
||||
arg = args[index]
|
||||
sanitized_args: Dict[int, Optional[Union[int, np.ndarray]]] = {}
|
||||
for index, (arg, spec) in enumerate(zip(args, input_specs)):
|
||||
if arg is None:
|
||||
sanitized_args[index] = None
|
||||
continue
|
||||
|
||||
if isinstance(arg, list):
|
||||
arg = np.array(arg)
|
||||
|
||||
@@ -183,35 +190,59 @@ class Client:
|
||||
)
|
||||
raise ValueError(message)
|
||||
|
||||
ordered_sanitized_args = [sanitized_args[i] for i in range(len(sanitized_args))]
|
||||
|
||||
self.keygen(force=False)
|
||||
keyset = self.keys._keyset # pylint: disable=protected-access
|
||||
|
||||
return ClientSupport.encrypt_arguments(
|
||||
self.specs.client_parameters,
|
||||
keyset,
|
||||
[sanitized_args[i] for i in range(len(sanitized_args))],
|
||||
)
|
||||
exporter = ValueExporter.create(keyset, self.specs.client_parameters)
|
||||
exported = [
|
||||
None
|
||||
if arg is None
|
||||
else Data(
|
||||
exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape))
|
||||
if isinstance(arg, np.ndarray) and arg.shape != ()
|
||||
else exporter.export_scalar(position, int(arg))
|
||||
)
|
||||
for position, arg in enumerate(ordered_sanitized_args)
|
||||
]
|
||||
|
||||
return tuple(exported) if len(exported) != 1 else exported[0]
|
||||
|
||||
def decrypt(
|
||||
self,
|
||||
result: PublicResult,
|
||||
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
|
||||
*results: Union[Data, Tuple[Data, ...]],
|
||||
) -> Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]:
|
||||
"""
|
||||
Decrypt result of homomorphic evaluation.
|
||||
Decrypt result(s) of evaluation.
|
||||
|
||||
Args:
|
||||
result (PublicResult):
|
||||
encrypted result of homomorphic evaluation
|
||||
*results (Union[Data, Tuple[Data, ...]]):
|
||||
result(s) of evaluation
|
||||
|
||||
Returns:
|
||||
Union[int, numpy.ndarray]:
|
||||
clear result of homomorphic evaluation
|
||||
Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]:
|
||||
decrypted result(s) of evaluation
|
||||
"""
|
||||
|
||||
flattened_results: List[Data] = []
|
||||
for result in results:
|
||||
if isinstance(result, tuple): # pragma: no cover
|
||||
# this branch is impossible to cover without multiple outputs
|
||||
flattened_results.extend(result)
|
||||
else:
|
||||
flattened_results.append(result)
|
||||
|
||||
self.keygen(force=False)
|
||||
keyset = self.keys._keyset # pylint: disable=protected-access
|
||||
outputs = ClientSupport.decrypt_result(self.specs.client_parameters, keyset, result)
|
||||
return outputs
|
||||
|
||||
decrypter = ValueDecrypter.create(keyset, self.specs.client_parameters)
|
||||
decrypted = tuple(
|
||||
decrypter.decrypt(position, result.inner)
|
||||
for position, result in enumerate(flattened_results)
|
||||
)
|
||||
|
||||
return decrypted if len(decrypted) != 1 else decrypted[0]
|
||||
|
||||
@property
|
||||
def evaluation_keys(self) -> EvaluationKeys:
|
||||
|
||||
47
frontends/concrete-python/concrete/fhe/compilation/data.py
Normal file
47
frontends/concrete-python/concrete/fhe/compilation/data.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
Declaration of `Data` class.
|
||||
"""
|
||||
|
||||
# pylint: disable=import-error,no-name-in-module
|
||||
|
||||
from concrete.compiler import Value as NativeData
|
||||
|
||||
# pylint: enable=import-error,no-name-in-module
|
||||
|
||||
|
||||
class Data:
|
||||
"""
|
||||
Data class, to store scalar or tensor data which can be encrypted or clear.
|
||||
"""
|
||||
|
||||
inner: NativeData
|
||||
|
||||
def __init__(self, inner: NativeData):
|
||||
self.inner = inner
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""
|
||||
Serialize data into bytes.
|
||||
|
||||
Returns:
|
||||
bytes:
|
||||
serialized data
|
||||
"""
|
||||
|
||||
return self.inner.serialize()
|
||||
|
||||
@staticmethod
|
||||
def deserialize(serialized_data: bytes) -> "Data":
|
||||
"""
|
||||
Deserialize data from bytes.
|
||||
|
||||
Args:
|
||||
serialized_data (bytes):
|
||||
previously serialized data
|
||||
|
||||
Returns:
|
||||
Data:
|
||||
deserialized data
|
||||
"""
|
||||
|
||||
return Data(NativeData.deserialize(serialized_data))
|
||||
@@ -8,7 +8,7 @@ import json
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
# mypy: disable-error-code=attr-defined
|
||||
import concrete.compiler
|
||||
@@ -23,7 +23,6 @@ from concrete.compiler import (
|
||||
LibraryLambda,
|
||||
LibrarySupport,
|
||||
PublicArguments,
|
||||
PublicResult,
|
||||
)
|
||||
from mlir._mlir_libs._concretelang._compiler import OptimizerStrategy
|
||||
|
||||
@@ -34,6 +33,7 @@ from .configuration import (
|
||||
Configuration,
|
||||
ParameterSelectionStrategy,
|
||||
)
|
||||
from .data import Data
|
||||
from .specs import ClientSpecs
|
||||
|
||||
# pylint: enable=import-error,no-member,no-name-in-module
|
||||
@@ -258,23 +258,50 @@ class Server:
|
||||
|
||||
return Server(client_specs, output_dir, support, compilation_result, server_lambda)
|
||||
|
||||
def run(self, args: PublicArguments, evaluation_keys: EvaluationKeys) -> PublicResult:
|
||||
def run(
|
||||
self,
|
||||
*args: Optional[Union[Data, Tuple[Optional[Data], ...]]],
|
||||
evaluation_keys: EvaluationKeys,
|
||||
) -> Union[Data, Tuple[Data, ...]]:
|
||||
"""
|
||||
Evaluate using encrypted arguments.
|
||||
Evaluate.
|
||||
|
||||
Args:
|
||||
args (PublicArguments):
|
||||
encrypted arguments of the computation
|
||||
*args (Optional[Union[Data, Tuple[Optional[Data], ...]]]):
|
||||
argument(s) for evaluation
|
||||
|
||||
evaluation_keys (EvaluationKeys):
|
||||
evaluation keys for encrypted computation
|
||||
evaluation keys
|
||||
|
||||
Returns:
|
||||
PublicResult:
|
||||
encrypted result of the computation
|
||||
Union[Data, Tuple[Data, ...]]:
|
||||
result(s) of evaluation
|
||||
"""
|
||||
|
||||
return self._support.server_call(self._server_lambda, args, evaluation_keys)
|
||||
flattened_args: List[Optional[Data]] = []
|
||||
for arg in args:
|
||||
if isinstance(arg, tuple):
|
||||
flattened_args.extend(arg)
|
||||
else:
|
||||
flattened_args.append(arg)
|
||||
|
||||
buffers = []
|
||||
for i, arg in enumerate(flattened_args):
|
||||
if arg is None:
|
||||
message = f"Expected argument {i} to be an fhe.Data but it's None"
|
||||
raise ValueError(message)
|
||||
|
||||
if not isinstance(arg, Data):
|
||||
message = f"Expected argument {i} to be an fhe.Data but it's {type(arg).__name__}"
|
||||
raise ValueError(message)
|
||||
|
||||
buffers.append(arg.inner)
|
||||
|
||||
public_args = PublicArguments.create(self.client_specs.client_parameters, buffers)
|
||||
public_result = self._support.server_call(self._server_lambda, public_args, evaluation_keys)
|
||||
|
||||
result = tuple(Data(public_result.get_value(i)) for i in range(public_result.n_values()))
|
||||
return result if len(result) > 1 else result[0]
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
|
||||
@@ -7,7 +7,7 @@ Declaration of `ClientSpecs` class.
|
||||
from typing import Any
|
||||
|
||||
# mypy: disable-error-code=attr-defined
|
||||
from concrete.compiler import ClientParameters, PublicArguments, PublicResult
|
||||
from concrete.compiler import ClientParameters
|
||||
|
||||
# pylint: enable=import-error,no-member,no-name-in-module
|
||||
|
||||
@@ -55,63 +55,3 @@ class ClientSpecs:
|
||||
|
||||
client_parameters = ClientParameters.deserialize(serialized_client_specs)
|
||||
return ClientSpecs(client_parameters)
|
||||
|
||||
def serialize_public_args(self, args: PublicArguments) -> bytes:
|
||||
"""
|
||||
Serialize public arguments to bytes.
|
||||
|
||||
Args:
|
||||
args (PublicArguments):
|
||||
public arguments to serialize
|
||||
|
||||
Returns:
|
||||
bytes:
|
||||
serialized public arguments
|
||||
"""
|
||||
|
||||
return args.serialize()
|
||||
|
||||
def deserialize_public_args(self, serialized_args: bytes) -> PublicArguments:
|
||||
"""
|
||||
Deserialize public arguments from bytes.
|
||||
|
||||
Args:
|
||||
serialized_args (bytes):
|
||||
serialized public arguments
|
||||
|
||||
Returns:
|
||||
PublicArguments:
|
||||
deserialized public arguments
|
||||
"""
|
||||
|
||||
return PublicArguments.deserialize(self.client_parameters, serialized_args)
|
||||
|
||||
def serialize_public_result(self, result: PublicResult) -> bytes:
|
||||
"""
|
||||
Serialize public result to bytes.
|
||||
|
||||
Args:
|
||||
result (PublicResult):
|
||||
public result to serialize
|
||||
|
||||
Returns:
|
||||
bytes:
|
||||
serialized public result
|
||||
"""
|
||||
|
||||
return result.serialize()
|
||||
|
||||
def deserialize_public_result(self, serialized_result: bytes) -> PublicResult:
|
||||
"""
|
||||
Deserialize public result from bytes.
|
||||
|
||||
Args:
|
||||
serialized_result (bytes):
|
||||
serialized public result
|
||||
|
||||
Returns:
|
||||
PublicResult:
|
||||
deserialized public result
|
||||
"""
|
||||
|
||||
return PublicResult.deserialize(self.client_parameters, serialized_result)
|
||||
|
||||
@@ -699,6 +699,6 @@ def friendly_type_format(type_: type) -> str:
|
||||
pass
|
||||
else:
|
||||
if arg1 == None.__class__:
|
||||
return f"Optional[{friendly_type_format(arg0)}]"
|
||||
return f"Optional[{friendly_type_format(arg0)}]" # pragma: no cover
|
||||
|
||||
return result
|
||||
|
||||
@@ -8,7 +8,7 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from concrete.fhe import Client, ClientSpecs, EvaluationKeys, LookupTable, Server, compiler
|
||||
from concrete.fhe import Client, ClientSpecs, Data, EvaluationKeys, LookupTable, Server, compiler
|
||||
|
||||
|
||||
def test_circuit_str(helpers):
|
||||
@@ -128,6 +128,55 @@ def test_circuit_bad_run(helpers):
|
||||
"Expected argument 1 to be EncryptedScalar<uint6> but it's EncryptedScalar<uint7>"
|
||||
)
|
||||
|
||||
# with None
|
||||
# ---------
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
circuit.encrypt_run_decrypt(None, 10)
|
||||
|
||||
assert str(excinfo.value) == "Expected argument 0 to be an fhe.Data but it's None"
|
||||
|
||||
# with non Data
|
||||
# -------------
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
_, b = circuit.encrypt(None, 10)
|
||||
circuit.run({"yes": "no"}, b)
|
||||
|
||||
assert str(excinfo.value) == "Expected argument 0 to be an fhe.Data but it's dict"
|
||||
|
||||
|
||||
def test_circuit_separate_args(helpers):
|
||||
"""
|
||||
Test running circuit with separately encrypted args.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def function(x, y):
|
||||
return x + y
|
||||
|
||||
inputset = [
|
||||
(
|
||||
np.random.randint(0, 10, size=()),
|
||||
np.random.randint(0, 10, size=(3,)),
|
||||
)
|
||||
for _ in range(10)
|
||||
]
|
||||
circuit = function.compile(inputset, configuration)
|
||||
|
||||
x = 4
|
||||
y = [1, 2, 3]
|
||||
|
||||
x_encrypted, _ = circuit.encrypt(x, None)
|
||||
_, y_encrypted = circuit.encrypt(None, y)
|
||||
|
||||
x_plus_y_encrypted = circuit.run(x_encrypted, y_encrypted)
|
||||
x_plus_y = circuit.decrypt(x_plus_y_encrypted)
|
||||
|
||||
assert np.array_equal(x_plus_y, x + np.array(y))
|
||||
|
||||
|
||||
def test_client_server_api(helpers):
|
||||
"""
|
||||
@@ -168,18 +217,18 @@ def test_client_server_api(helpers):
|
||||
]
|
||||
|
||||
for client in clients:
|
||||
args = client.encrypt([3, 8, 1])
|
||||
arg = client.encrypt([3, 8, 1])
|
||||
|
||||
serialized_args = client.specs.serialize_public_args(args)
|
||||
serialized_arg = arg.serialize()
|
||||
serialized_evaluation_keys = client.evaluation_keys.serialize()
|
||||
|
||||
deserialized_args = server.client_specs.deserialize_public_args(serialized_args)
|
||||
deserialized_arg = Data.deserialize(serialized_arg)
|
||||
deserialized_evaluation_keys = EvaluationKeys.deserialize(serialized_evaluation_keys)
|
||||
|
||||
result = server.run(deserialized_args, deserialized_evaluation_keys)
|
||||
serialized_result = server.client_specs.serialize_public_result(result)
|
||||
result = server.run(deserialized_arg, evaluation_keys=deserialized_evaluation_keys)
|
||||
serialized_result = result.serialize()
|
||||
|
||||
deserialized_result = client.specs.deserialize_public_result(serialized_result)
|
||||
deserialized_result = Data.deserialize(serialized_result)
|
||||
output = client.decrypt(deserialized_result)
|
||||
|
||||
assert np.array_equal(output, [45, 50, 43])
|
||||
@@ -226,18 +275,18 @@ def test_client_server_api_crt(helpers):
|
||||
]
|
||||
|
||||
for client in clients:
|
||||
args = client.encrypt([100, 150, 10])
|
||||
arg = client.encrypt([100, 150, 10])
|
||||
|
||||
serialized_args = client.specs.serialize_public_args(args)
|
||||
serialized_arg = arg.serialize()
|
||||
serialized_evaluation_keys = client.evaluation_keys.serialize()
|
||||
|
||||
deserialized_args = server.client_specs.deserialize_public_args(serialized_args)
|
||||
deserialized_arg = Data.deserialize(serialized_arg)
|
||||
deserialized_evaluation_keys = EvaluationKeys.deserialize(serialized_evaluation_keys)
|
||||
|
||||
result = server.run(deserialized_args, deserialized_evaluation_keys)
|
||||
serialized_result = server.client_specs.serialize_public_result(result)
|
||||
result = server.run(deserialized_arg, evaluation_keys=deserialized_evaluation_keys)
|
||||
serialized_result = result.serialize()
|
||||
|
||||
deserialized_result = client.specs.deserialize_public_result(serialized_result)
|
||||
deserialized_result = Data.deserialize(serialized_result)
|
||||
output = client.decrypt(deserialized_result)
|
||||
|
||||
assert np.array_equal(output, [100**2, 150**2, 10**2])
|
||||
@@ -279,18 +328,18 @@ def test_client_server_api_via_mlir(helpers):
|
||||
]
|
||||
|
||||
for client in clients:
|
||||
args = client.encrypt([3, 8, 1])
|
||||
arg = client.encrypt([3, 8, 1])
|
||||
|
||||
serialized_args = client.specs.serialize_public_args(args)
|
||||
serialized_arg = arg.serialize()
|
||||
serialized_evaluation_keys = client.evaluation_keys.serialize()
|
||||
|
||||
deserialized_args = server.client_specs.deserialize_public_args(serialized_args)
|
||||
deserialized_arg = Data.deserialize(serialized_arg)
|
||||
deserialized_evaluation_keys = EvaluationKeys.deserialize(serialized_evaluation_keys)
|
||||
|
||||
result = server.run(deserialized_args, deserialized_evaluation_keys)
|
||||
serialized_result = server.client_specs.serialize_public_result(result)
|
||||
result = server.run(deserialized_arg, evaluation_keys=deserialized_evaluation_keys)
|
||||
serialized_result = result.serialize()
|
||||
|
||||
deserialized_result = client.specs.deserialize_public_result(serialized_result)
|
||||
deserialized_result = Data.deserialize(serialized_result)
|
||||
output = client.decrypt(deserialized_result)
|
||||
|
||||
assert np.array_equal(output, [45, 50, 43])
|
||||
|
||||
@@ -116,7 +116,7 @@ def test_keys_serialize_deserialize(helpers):
|
||||
client1.keys.generate()
|
||||
|
||||
sample = client1.encrypt(5)
|
||||
evaluation = server.run(sample, client1.evaluation_keys)
|
||||
evaluation = server.run(sample, evaluation_keys=client1.evaluation_keys)
|
||||
|
||||
client2 = fhe.Client(server.client_specs)
|
||||
client2.keys = fhe.Keys.deserialize(client1.keys.serialize())
|
||||
|
||||
Reference in New Issue
Block a user