feat(frontend-python): separate arguments

This commit is contained in:
Umut
2023-06-09 11:09:47 +02:00
parent 71d511756a
commit c8cc8a811d
13 changed files with 283 additions and 150 deletions

View File

@@ -14,6 +14,7 @@ from .compilation import (
ClientSpecs,
Compiler,
Configuration,
Data,
DebugArtifacts,
EncryptionStatus,
Keys,

View File

@@ -12,6 +12,7 @@ from .configuration import (
Configuration,
ParameterSelectionStrategy,
)
from .data import Data
from .keys import Keys
from .server import Server
from .specs import ClientSpecs

View File

@@ -4,17 +4,15 @@ Declaration of `Circuit` class.
# pylint: disable=import-error,no-member,no-name-in-module
from typing import Any, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union
import numpy as np
# mypy: disable-error-code=attr-defined
from concrete.compiler import PublicArguments, PublicResult
from ..internal.utils import assert_that
from ..representation import Graph
from .client import Client
from .configuration import Configuration
from .data import Data
from .keys import Keys
from .server import Server
@@ -99,54 +97,60 @@ class Circuit:
self.client.keygen(force, seed)
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
def encrypt(
self,
*args: Optional[Union[int, np.ndarray, List]],
) -> Optional[Union[Data, Tuple[Optional[Data], ...]]]:
"""
Prepare inputs to be run on the circuit.
Encrypt argument(s) to for evaluation.
Args:
*args (Union[int, numpy.ndarray]):
inputs to the circuit
*args (Optional[Union[int, numpy.ndarray, List]]):
argument(s) for evaluation
Returns:
PublicArguments:
encrypted and plain arguments as well as public keys
Optional[Union[Data, Tuple[Optional[Data], ...]]]:
encrypted argument(s) for evaluation
"""
return self.client.encrypt(*args)
def run(self, args: PublicArguments) -> PublicResult:
def run(
self,
*args: Optional[Union[Data, Tuple[Optional[Data], ...]]],
) -> Union[Data, Tuple[Data, ...]]:
"""
Evaluate circuit using encrypted arguments.
Evaluate the circuit.
Args:
args (PublicArguments):
arguments to the circuit (can be obtained with `encrypt` method of `Circuit`)
*args (Data):
argument(s) for evaluation
Returns:
PublicResult:
encrypted result of homomorphic evaluaton
Union[Data, Tuple[Data, ...]]:
result(s) of evaluation
"""
self.keygen(force=False)
return self.server.run(args, self.client.evaluation_keys)
return self.server.run(*args, evaluation_keys=self.client.evaluation_keys)
def decrypt(
self,
result: PublicResult,
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
*results: Union[Data, Tuple[Data, ...]],
) -> Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]:
"""
Decrypt result of homomorphic evaluaton.
Decrypt result(s) of evaluation.
Args:
result (PublicResult):
encrypted result of homomorphic evaluaton
*results (Union[Data, Tuple[Data, ...]]):
result(s) of evaluation
Returns:
Union[int, numpy.ndarray]:
clear result of homomorphic evaluaton
Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]:
decrypted result(s) of evaluation
"""
return self.client.decrypt(result)
return self.client.decrypt(*results)
def encrypt_run_decrypt(self, *args: Any) -> Any:
"""

View File

@@ -8,14 +8,15 @@ import json
import shutil
import tempfile
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from concrete.compiler import ClientSupport, EvaluationKeys, PublicArguments, PublicResult
from concrete.compiler import EvaluationKeys, ValueDecrypter, ValueExporter
from ..dtypes.integer import SignedInteger, UnsignedInteger
from ..internal.utils import assert_that
from ..values.value import Value
from .data import Data
from .keys import Keys
from .specs import ClientSpecs
@@ -116,17 +117,20 @@ class Client:
self.keys.generate(force=force, seed=seed)
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
def encrypt(
self,
*args: Optional[Union[int, np.ndarray, List]],
) -> Optional[Union[Data, Tuple[Optional[Data], ...]]]:
"""
Prepare inputs to be run on the circuit.
Encrypt argument(s) to for evaluation.
Args:
*args (Union[int, numpy.ndarray]):
inputs to the circuit
*args (Optional[Union[int, np.ndarray, List]]):
argument(s) for evaluation
Returns:
PublicArguments:
encrypted and plain arguments as well as public keys
Optional[Union[Data, Tuple[Optional[Data], ...]]]:
encrypted argument(s) for evaluation
"""
client_parameters_json = json.loads(self.specs.client_parameters.serialize())
@@ -137,9 +141,12 @@ class Client:
message = f"Expected {len(input_specs)} inputs but got {len(args)}"
raise ValueError(message)
sanitized_args: Dict[int, Union[int, np.ndarray]] = {}
for index, spec in enumerate(input_specs):
arg = args[index]
sanitized_args: Dict[int, Optional[Union[int, np.ndarray]]] = {}
for index, (arg, spec) in enumerate(zip(args, input_specs)):
if arg is None:
sanitized_args[index] = None
continue
if isinstance(arg, list):
arg = np.array(arg)
@@ -183,35 +190,59 @@ class Client:
)
raise ValueError(message)
ordered_sanitized_args = [sanitized_args[i] for i in range(len(sanitized_args))]
self.keygen(force=False)
keyset = self.keys._keyset # pylint: disable=protected-access
return ClientSupport.encrypt_arguments(
self.specs.client_parameters,
keyset,
[sanitized_args[i] for i in range(len(sanitized_args))],
)
exporter = ValueExporter.create(keyset, self.specs.client_parameters)
exported = [
None
if arg is None
else Data(
exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape))
if isinstance(arg, np.ndarray) and arg.shape != ()
else exporter.export_scalar(position, int(arg))
)
for position, arg in enumerate(ordered_sanitized_args)
]
return tuple(exported) if len(exported) != 1 else exported[0]
def decrypt(
self,
result: PublicResult,
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
*results: Union[Data, Tuple[Data, ...]],
) -> Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]:
"""
Decrypt result of homomorphic evaluation.
Decrypt result(s) of evaluation.
Args:
result (PublicResult):
encrypted result of homomorphic evaluation
*results (Union[Data, Tuple[Data, ...]]):
result(s) of evaluation
Returns:
Union[int, numpy.ndarray]:
clear result of homomorphic evaluation
Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]:
decrypted result(s) of evaluation
"""
flattened_results: List[Data] = []
for result in results:
if isinstance(result, tuple): # pragma: no cover
# this branch is impossible to cover without multiple outputs
flattened_results.extend(result)
else:
flattened_results.append(result)
self.keygen(force=False)
keyset = self.keys._keyset # pylint: disable=protected-access
outputs = ClientSupport.decrypt_result(self.specs.client_parameters, keyset, result)
return outputs
decrypter = ValueDecrypter.create(keyset, self.specs.client_parameters)
decrypted = tuple(
decrypter.decrypt(position, result.inner)
for position, result in enumerate(flattened_results)
)
return decrypted if len(decrypted) != 1 else decrypted[0]
@property
def evaluation_keys(self) -> EvaluationKeys:

View File

@@ -0,0 +1,47 @@
"""
Declaration of `Data` class.
"""
# pylint: disable=import-error,no-name-in-module
from concrete.compiler import Value as NativeData
# pylint: enable=import-error,no-name-in-module
class Data:
"""
Data class, to store scalar or tensor data which can be encrypted or clear.
"""
inner: NativeData
def __init__(self, inner: NativeData):
self.inner = inner
def serialize(self) -> bytes:
"""
Serialize data into bytes.
Returns:
bytes:
serialized data
"""
return self.inner.serialize()
@staticmethod
def deserialize(serialized_data: bytes) -> "Data":
"""
Deserialize data from bytes.
Args:
serialized_data (bytes):
previously serialized data
Returns:
Data:
deserialized data
"""
return Data(NativeData.deserialize(serialized_data))

View File

@@ -8,7 +8,7 @@ import json
import shutil
import tempfile
from pathlib import Path
from typing import Optional, Union
from typing import List, Optional, Tuple, Union
# mypy: disable-error-code=attr-defined
import concrete.compiler
@@ -23,7 +23,6 @@ from concrete.compiler import (
LibraryLambda,
LibrarySupport,
PublicArguments,
PublicResult,
)
from mlir._mlir_libs._concretelang._compiler import OptimizerStrategy
@@ -34,6 +33,7 @@ from .configuration import (
Configuration,
ParameterSelectionStrategy,
)
from .data import Data
from .specs import ClientSpecs
# pylint: enable=import-error,no-member,no-name-in-module
@@ -258,23 +258,50 @@ class Server:
return Server(client_specs, output_dir, support, compilation_result, server_lambda)
def run(self, args: PublicArguments, evaluation_keys: EvaluationKeys) -> PublicResult:
def run(
self,
*args: Optional[Union[Data, Tuple[Optional[Data], ...]]],
evaluation_keys: EvaluationKeys,
) -> Union[Data, Tuple[Data, ...]]:
"""
Evaluate using encrypted arguments.
Evaluate.
Args:
args (PublicArguments):
encrypted arguments of the computation
*args (Optional[Union[Data, Tuple[Optional[Data], ...]]]):
argument(s) for evaluation
evaluation_keys (EvaluationKeys):
evaluation keys for encrypted computation
evaluation keys
Returns:
PublicResult:
encrypted result of the computation
Union[Data, Tuple[Data, ...]]:
result(s) of evaluation
"""
return self._support.server_call(self._server_lambda, args, evaluation_keys)
flattened_args: List[Optional[Data]] = []
for arg in args:
if isinstance(arg, tuple):
flattened_args.extend(arg)
else:
flattened_args.append(arg)
buffers = []
for i, arg in enumerate(flattened_args):
if arg is None:
message = f"Expected argument {i} to be an fhe.Data but it's None"
raise ValueError(message)
if not isinstance(arg, Data):
message = f"Expected argument {i} to be an fhe.Data but it's {type(arg).__name__}"
raise ValueError(message)
buffers.append(arg.inner)
public_args = PublicArguments.create(self.client_specs.client_parameters, buffers)
public_result = self._support.server_call(self._server_lambda, public_args, evaluation_keys)
result = tuple(Data(public_result.get_value(i)) for i in range(public_result.n_values()))
return result if len(result) > 1 else result[0]
def cleanup(self):
"""

View File

@@ -7,7 +7,7 @@ Declaration of `ClientSpecs` class.
from typing import Any
# mypy: disable-error-code=attr-defined
from concrete.compiler import ClientParameters, PublicArguments, PublicResult
from concrete.compiler import ClientParameters
# pylint: enable=import-error,no-member,no-name-in-module
@@ -55,63 +55,3 @@ class ClientSpecs:
client_parameters = ClientParameters.deserialize(serialized_client_specs)
return ClientSpecs(client_parameters)
def serialize_public_args(self, args: PublicArguments) -> bytes:
"""
Serialize public arguments to bytes.
Args:
args (PublicArguments):
public arguments to serialize
Returns:
bytes:
serialized public arguments
"""
return args.serialize()
def deserialize_public_args(self, serialized_args: bytes) -> PublicArguments:
"""
Deserialize public arguments from bytes.
Args:
serialized_args (bytes):
serialized public arguments
Returns:
PublicArguments:
deserialized public arguments
"""
return PublicArguments.deserialize(self.client_parameters, serialized_args)
def serialize_public_result(self, result: PublicResult) -> bytes:
"""
Serialize public result to bytes.
Args:
result (PublicResult):
public result to serialize
Returns:
bytes:
serialized public result
"""
return result.serialize()
def deserialize_public_result(self, serialized_result: bytes) -> PublicResult:
"""
Deserialize public result from bytes.
Args:
serialized_result (bytes):
serialized public result
Returns:
PublicResult:
deserialized public result
"""
return PublicResult.deserialize(self.client_parameters, serialized_result)

View File

@@ -699,6 +699,6 @@ def friendly_type_format(type_: type) -> str:
pass
else:
if arg1 == None.__class__:
return f"Optional[{friendly_type_format(arg0)}]"
return f"Optional[{friendly_type_format(arg0)}]" # pragma: no cover
return result

View File

@@ -8,7 +8,7 @@ from pathlib import Path
import numpy as np
import pytest
from concrete.fhe import Client, ClientSpecs, EvaluationKeys, LookupTable, Server, compiler
from concrete.fhe import Client, ClientSpecs, Data, EvaluationKeys, LookupTable, Server, compiler
def test_circuit_str(helpers):
@@ -128,6 +128,55 @@ def test_circuit_bad_run(helpers):
"Expected argument 1 to be EncryptedScalar<uint6> but it's EncryptedScalar<uint7>"
)
# with None
# ---------
with pytest.raises(ValueError) as excinfo:
circuit.encrypt_run_decrypt(None, 10)
assert str(excinfo.value) == "Expected argument 0 to be an fhe.Data but it's None"
# with non Data
# -------------
with pytest.raises(ValueError) as excinfo:
_, b = circuit.encrypt(None, 10)
circuit.run({"yes": "no"}, b)
assert str(excinfo.value) == "Expected argument 0 to be an fhe.Data but it's dict"
def test_circuit_separate_args(helpers):
"""
Test running circuit with separately encrypted args.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
def function(x, y):
return x + y
inputset = [
(
np.random.randint(0, 10, size=()),
np.random.randint(0, 10, size=(3,)),
)
for _ in range(10)
]
circuit = function.compile(inputset, configuration)
x = 4
y = [1, 2, 3]
x_encrypted, _ = circuit.encrypt(x, None)
_, y_encrypted = circuit.encrypt(None, y)
x_plus_y_encrypted = circuit.run(x_encrypted, y_encrypted)
x_plus_y = circuit.decrypt(x_plus_y_encrypted)
assert np.array_equal(x_plus_y, x + np.array(y))
def test_client_server_api(helpers):
"""
@@ -168,18 +217,18 @@ def test_client_server_api(helpers):
]
for client in clients:
args = client.encrypt([3, 8, 1])
arg = client.encrypt([3, 8, 1])
serialized_args = client.specs.serialize_public_args(args)
serialized_arg = arg.serialize()
serialized_evaluation_keys = client.evaluation_keys.serialize()
deserialized_args = server.client_specs.deserialize_public_args(serialized_args)
deserialized_arg = Data.deserialize(serialized_arg)
deserialized_evaluation_keys = EvaluationKeys.deserialize(serialized_evaluation_keys)
result = server.run(deserialized_args, deserialized_evaluation_keys)
serialized_result = server.client_specs.serialize_public_result(result)
result = server.run(deserialized_arg, evaluation_keys=deserialized_evaluation_keys)
serialized_result = result.serialize()
deserialized_result = client.specs.deserialize_public_result(serialized_result)
deserialized_result = Data.deserialize(serialized_result)
output = client.decrypt(deserialized_result)
assert np.array_equal(output, [45, 50, 43])
@@ -226,18 +275,18 @@ def test_client_server_api_crt(helpers):
]
for client in clients:
args = client.encrypt([100, 150, 10])
arg = client.encrypt([100, 150, 10])
serialized_args = client.specs.serialize_public_args(args)
serialized_arg = arg.serialize()
serialized_evaluation_keys = client.evaluation_keys.serialize()
deserialized_args = server.client_specs.deserialize_public_args(serialized_args)
deserialized_arg = Data.deserialize(serialized_arg)
deserialized_evaluation_keys = EvaluationKeys.deserialize(serialized_evaluation_keys)
result = server.run(deserialized_args, deserialized_evaluation_keys)
serialized_result = server.client_specs.serialize_public_result(result)
result = server.run(deserialized_arg, evaluation_keys=deserialized_evaluation_keys)
serialized_result = result.serialize()
deserialized_result = client.specs.deserialize_public_result(serialized_result)
deserialized_result = Data.deserialize(serialized_result)
output = client.decrypt(deserialized_result)
assert np.array_equal(output, [100**2, 150**2, 10**2])
@@ -279,18 +328,18 @@ def test_client_server_api_via_mlir(helpers):
]
for client in clients:
args = client.encrypt([3, 8, 1])
arg = client.encrypt([3, 8, 1])
serialized_args = client.specs.serialize_public_args(args)
serialized_arg = arg.serialize()
serialized_evaluation_keys = client.evaluation_keys.serialize()
deserialized_args = server.client_specs.deserialize_public_args(serialized_args)
deserialized_arg = Data.deserialize(serialized_arg)
deserialized_evaluation_keys = EvaluationKeys.deserialize(serialized_evaluation_keys)
result = server.run(deserialized_args, deserialized_evaluation_keys)
serialized_result = server.client_specs.serialize_public_result(result)
result = server.run(deserialized_arg, evaluation_keys=deserialized_evaluation_keys)
serialized_result = result.serialize()
deserialized_result = client.specs.deserialize_public_result(serialized_result)
deserialized_result = Data.deserialize(serialized_result)
output = client.decrypt(deserialized_result)
assert np.array_equal(output, [45, 50, 43])

View File

@@ -116,7 +116,7 @@ def test_keys_serialize_deserialize(helpers):
client1.keys.generate()
sample = client1.encrypt(5)
evaluation = server.run(sample, client1.evaluation_keys)
evaluation = server.run(sample, evaluation_keys=client1.evaluation_keys)
client2 = fhe.Client(server.client_specs)
client2.keys = fhe.Keys.deserialize(client1.keys.serialize())