From 90c95e380ceb41852af36900a5e7f02ba5d9323e Mon Sep 17 00:00:00 2001 From: Umut Date: Thu, 12 May 2022 09:53:14 +0200 Subject: [PATCH] feat: implement client server architecture --- concrete/numpy/__init__.py | 3 + concrete/numpy/compilation/__init__.py | 3 + concrete/numpy/compilation/circuit.py | 334 ++----------------------- concrete/numpy/compilation/client.py | 210 ++++++++++++++++ concrete/numpy/compilation/compiler.py | 8 +- concrete/numpy/compilation/server.py | 182 ++++++++++++++ concrete/numpy/compilation/specs.py | 59 +++++ tests/compilation/test_circuit.py | 142 ++++++----- 8 files changed, 559 insertions(+), 382 deletions(-) create mode 100644 concrete/numpy/compilation/client.py create mode 100644 concrete/numpy/compilation/server.py create mode 100644 concrete/numpy/compilation/specs.py diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py index 3f8d73b3d..5b1bae7e1 100644 --- a/concrete/numpy/__init__.py +++ b/concrete/numpy/__init__.py @@ -4,10 +4,13 @@ Export everything that users might need. from .compilation import ( Circuit, + Client, + ClientSpecs, Compiler, Configuration, DebugArtifacts, EncryptionStatus, + Server, compiler, ) from .extensions import LookupTable diff --git a/concrete/numpy/compilation/__init__.py b/concrete/numpy/compilation/__init__.py index 79a725374..147137419 100644 --- a/concrete/numpy/compilation/__init__.py +++ b/concrete/numpy/compilation/__init__.py @@ -4,6 +4,9 @@ Glue the compilation process together. from .artifacts import DebugArtifacts from .circuit import Circuit +from .client import Client from .compiler import Compiler, EncryptionStatus from .configuration import Configuration from .decorator import compiler +from .server import Server +from .specs import ClientSpecs diff --git a/concrete/numpy/compilation/circuit.py b/concrete/numpy/compilation/circuit.py index 1d506c3cb..cc37de1c1 100644 --- a/concrete/numpy/compilation/circuit.py +++ b/concrete/numpy/compilation/circuit.py @@ -2,267 +2,60 @@ Declaration of `Circuit` class. """ -import pickle -import shutil -import tempfile from pathlib import Path -from typing import Any, List, Optional, Tuple, Union, cast +from typing import Any, Optional, Tuple, Union, cast import numpy as np -from concrete.compiler import ( - ClientParameters, - ClientSupport, - CompilationOptions, - JITCompilationResult, - JITLambda, - JITSupport, - KeySet, - KeySetCache, - LibraryCompilationResult, - LibraryLambda, - LibrarySupport, - PublicArguments, - PublicResult, -) +from concrete.compiler import PublicArguments, PublicResult from ..dtypes import Integer from ..internal.utils import assert_that from ..representation import Graph -from ..values import Value +from .client import Client from .configuration import Configuration +from .server import Server class Circuit: """ - Circuit class, to combine computation graph and compiler engine into a single object. + Circuit class, to combine computation graph, mlir, client and server into a single object. """ - # pylint: disable=too-many-instance-attributes - configuration: Configuration graph: Graph mlir: str - client_parameters: Optional[ClientParameters] - _support: Union[JITSupport, LibrarySupport] - _compilation_result: Union[JITCompilationResult, LibraryCompilationResult] - _server_lambda: Union[JITLambda, LibraryLambda] + client: Client + server: Server - _output_dir: Optional[tempfile.TemporaryDirectory] - - _keyset: KeySet - _keyset_cache: KeySetCache - - # pylint: enable=too-many-instance-attributes - - def __init__( - self, - configuration: Configuration, - graph: Graph, - mlir: str, - support: Optional[Union[JITSupport, LibrarySupport]] = None, - compilation_result: Optional[Union[JITCompilationResult, LibraryCompilationResult]] = None, - server_lambda: Optional[Union[JITLambda, LibraryLambda]] = None, - output_dir: Optional[tempfile.TemporaryDirectory] = None, - ): - self.configuration = configuration - self._output_dir = output_dir + def __init__(self, graph: Graph, mlir: str, configuration: Optional[Configuration] = None): + self.configuration = configuration if configuration is not None else Configuration() self.graph = graph self.mlir = mlir - self.client_parameters = None - if configuration.virtual: - assert_that(configuration.enable_unsafe_features) + if self.configuration.virtual: + assert_that(self.configuration.enable_unsafe_features) return - assert support is not None - assert compilation_result is not None - assert server_lambda is not None + output_signs = [] + for i in range(len(graph.output_nodes)): # pylint: disable=consider-using-enumerate + output_value = graph.output_nodes[i].output + assert_that(isinstance(output_value.dtype, Integer)) + output_dtype = cast(Integer, output_value.dtype) + output_signs.append(output_dtype.is_signed) - assert_that( - ( - isinstance(support, JITSupport) - and isinstance(compilation_result, JITCompilationResult) - and isinstance(server_lambda, JITLambda) - ) - or ( - isinstance(support, LibrarySupport) - and isinstance(compilation_result, LibraryCompilationResult) - and isinstance(server_lambda, LibraryLambda) - ) - ) + self.server = Server.create(mlir, output_signs, self.configuration) - self._support = support - self._compilation_result = compilation_result - self._server_lambda = server_lambda - - self._output_dir = output_dir - if isinstance(support, LibrarySupport): - assert output_dir is not None - assert_that(support.output_dir_path == str(output_dir.name)) - - self.client_parameters = support.load_client_parameters(compilation_result) - keyset = None keyset_cache = None - - if configuration.use_insecure_key_cache: - assert_that(configuration.enable_unsafe_features) + if self.configuration.use_insecure_key_cache: + assert_that(self.configuration.enable_unsafe_features) location = Configuration.insecure_key_cache_location() if location is not None: - keyset_cache = KeySetCache.new(str(location)) + keyset_cache = str(location) - self._keyset = keyset - self._keyset_cache = keyset_cache - - @staticmethod - def create(graph: Graph, mlir: str, configuration: Optional[Configuration] = None) -> "Circuit": - """ - Create a circuit from a graph and its MLIR. - - Args: - graph (Graph): - graph of the circuit - - mlir (str): - mlir of the circuit - - configuration (Optional[Configuration], default = None): - configuration to use - - Returns: - Circuit: - circuit of graph - """ - - configuration = configuration if configuration is not None else Configuration() - if configuration.virtual: - return Circuit(configuration, graph, mlir) - - options = CompilationOptions.new("main") - - options.set_loop_parallelize(configuration.loop_parallelize) - options.set_dataflow_parallelize(configuration.dataflow_parallelize) - options.set_auto_parallelize(configuration.auto_parallelize) - options.set_p_error(configuration.p_error) - - if configuration.jit: - - output_dir = None - - support = JITSupport.new() - compilation_result = support.compile(mlir, options) - server_lambda = support.load_server_lambda(compilation_result) - - else: - - # pylint: disable=consider-using-with - output_dir = tempfile.TemporaryDirectory() - output_dir_path = Path(output_dir.name) - # pylint: enable=consider-using-with - - support = LibrarySupport.new( - str(output_dir_path), generateCppHeader=False, generateStaticLib=False - ) - compilation_result = support.compile(mlir, options) - server_lambda = support.load_server_lambda(compilation_result) - - return Circuit( - configuration, - graph, - mlir, - support, - compilation_result, - server_lambda, - output_dir, - ) - - def save(self, path: Union[str, Path]): - """ - Save the circuit into the given path in zip format. - - Args: - path (Union[str, Path]): - path to save the circuit - """ - - if not self.configuration.virtual and self.configuration.jit: - raise RuntimeError("JIT Circuits cannot be saved") - - if self.configuration.virtual: - # pylint: disable=consider-using-with - self._output_dir = tempfile.TemporaryDirectory() - # pylint: enable=consider-using-with - - assert self._output_dir is not None - output_dir_path = Path(self._output_dir.name) - - with open(output_dir_path / "out.pickle", "wb") as f: - attributes = { - "configuration": self.configuration, - "graph": self.graph, - "mlir": self.mlir, - } - pickle.dump(attributes, f) - - path = str(path) - if path.endswith(".zip"): - path = path[: len(path) - 4] - - shutil.make_archive(path, "zip", str(output_dir_path)) - - if self.configuration.virtual: - self.cleanup() - self._output_dir = None - - @staticmethod - def load(path: Union[str, Path]) -> "Circuit": - """ - Load the circuit from the given path in zip format. - - Args: - path (Union[str, Path]): - path to load the circuit from - - Returns: - Circuit: - circuit loaded from the filesystem - """ - - # pylint: disable=consider-using-with - output_dir = tempfile.TemporaryDirectory() - output_dir_path = Path(output_dir.name) - # pylint: enable=consider-using-with - - shutil.unpack_archive(path, str(output_dir_path), "zip") - - with open(output_dir_path / "out.pickle", "rb") as f: - attributes = pickle.load(f) - - configuration = attributes["configuration"] - graph = attributes["graph"] - mlir = attributes["mlir"] - - if configuration.virtual: - output_dir.cleanup() - return Circuit(configuration, graph, mlir) - - support = LibrarySupport.new( - str(output_dir_path), generateCppHeader=False, generateStaticLib=False - ) - compilation_result = support.reload("main") - server_lambda = support.load_server_lambda(compilation_result) - - return Circuit( - configuration, - graph, - mlir, - support, - compilation_result, - server_lambda, - output_dir, - ) + self.client = Client(self.server.client_specs, keyset_cache) def __str__(self): return self.graph.format() @@ -310,8 +103,7 @@ class Circuit: if self.configuration.virtual: raise RuntimeError("Virtual circuits cannot use `keygen` method") - if self._keyset is None or force: - self._keyset = ClientSupport.key_set(self.client_parameters, self._keyset_cache) + self.client.keygen(force) def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments: """ @@ -329,51 +121,7 @@ class Circuit: if self.configuration.virtual: raise RuntimeError("Virtual circuits cannot use `encrypt` method") - if len(args) != len(self.graph.input_nodes): - raise ValueError(f"Expected {len(self.graph.input_nodes)} inputs but got {len(args)}") - - sanitized_args = {} - for index, node in self.graph.input_nodes.items(): - arg = args[index] - is_valid = isinstance(arg, (int, np.integer)) or ( - isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer) - ) - - expected_value = node.output - - assert_that(isinstance(expected_value.dtype, Integer)) - expected_dtype = cast(Integer, expected_value.dtype) - - if is_valid: - expected_min = expected_dtype.min() - expected_max = expected_dtype.max() - expected_shape = expected_value.shape - - actual_min = arg if isinstance(arg, int) else arg.min() - actual_max = arg if isinstance(arg, int) else arg.max() - actual_shape = () if isinstance(arg, int) else arg.shape - - is_valid = ( - actual_min >= expected_min - and actual_max <= expected_max - and actual_shape == expected_shape - ) - - if is_valid: - sanitized_args[index] = arg if isinstance(arg, int) else arg.astype(np.uint8) - - if not is_valid: - actual_value = Value.of(arg, is_encrypted=expected_value.is_encrypted) - raise ValueError( - f"Expected argument {index} to be {expected_value} but it's {actual_value}" - ) - - self.keygen(force=False) - return ClientSupport.encrypt_arguments( - self.client_parameters, - self._keyset, - [sanitized_args[i] for i in range(len(sanitized_args))], - ) + return self.client.encrypt(*args) def run(self, args: PublicArguments) -> PublicResult: """ @@ -391,7 +139,7 @@ class Circuit: if self.configuration.virtual: raise RuntimeError("Virtual circuits cannot use `run` method") - return self._support.server_call(self._server_lambda, args) + return self.server.run(args) def decrypt( self, @@ -412,34 +160,7 @@ class Circuit: if self.configuration.virtual: raise RuntimeError("Virtual circuits cannot use `decrypt` method") - results = ClientSupport.decrypt_result(self._keyset, result) - if not isinstance(results, tuple): - results = (results,) - - sanitized_results: List[Union[int, np.ndarray]] = [] - - for index, node in self.graph.output_nodes.items(): - expected_value = node.output - assert_that(isinstance(expected_value.dtype, Integer)) - - expected_dtype = cast(Integer, expected_value.dtype) - n = expected_dtype.bit_width - - result = results[index] % (2**n) - if expected_dtype.is_signed: - if isinstance(result, int): - sanititzed_result = result if result < (2 ** (n - 1)) else result - (2**n) - sanitized_results.append(sanititzed_result) - else: - result = result.astype(np.longlong) # to prevent overflows in numpy - sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2**n)) - sanitized_results.append(sanititzed_result.astype(np.int8)) - else: - sanitized_results.append( - result if isinstance(result, int) else result.astype(np.uint8) - ) - - return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results) + return self.client.decrypt(result) def encrypt_run_decrypt(self, *args: Any) -> Any: """ @@ -464,5 +185,4 @@ class Circuit: Cleanup the temporary library output directory. """ - if self._output_dir is not None: - self._output_dir.cleanup() + self.server.cleanup() diff --git a/concrete/numpy/compilation/client.py b/concrete/numpy/compilation/client.py new file mode 100644 index 000000000..299590805 --- /dev/null +++ b/concrete/numpy/compilation/client.py @@ -0,0 +1,210 @@ +""" +Declaration of `Client` class. +""" + +import json +import shutil +import tempfile +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import numpy as np +from concrete.compiler import ClientSupport, KeySet, KeySetCache, PublicArguments, PublicResult + +from ..dtypes.integer import SignedInteger, UnsignedInteger +from ..internal.utils import assert_that +from ..values.value import Value +from .specs import ClientSpecs + + +class Client: + """ + Client class, which can be used to manage keys, encrypt arguments and decrypt results. + """ + + specs: ClientSpecs + + _keyset: Optional[KeySet] + _keyset_cache: Optional[KeySetCache] + + def __init__( + self, + client_specs: ClientSpecs, + keyset_cache_directory: Optional[Union[str, Path]] = None, + ): + self.specs = client_specs + + self._keyset = None + self._keyset_cache = None + + if keyset_cache_directory is not None: + self._keyset_cache = KeySetCache.new(str(keyset_cache_directory)) + + def save(self, path: Union[str, Path]): + """ + Save the client into the given path in zip format. + + Args: + path (Union[str, Path]): + path to save the client + """ + + with tempfile.TemporaryDirectory() as tmp_dir: + with open(Path(tmp_dir) / "client.specs.json", "w", encoding="utf-8") as f: + f.write(self.specs.serialize()) + + path = str(path) + if path.endswith(".zip"): + path = path[: len(path) - 4] + + shutil.make_archive(path, "zip", tmp_dir) + + @staticmethod + def load( + path: Union[str, Path], + keyset_cache_directory: Optional[Union[str, Path]] = None, + ) -> "Client": + """ + Load the client from the given path in zip format. + + Args: + path (Union[str, Path]): + path to load the client from + + keyset_cache_directory (Optional[Union[str, Path]], default = None): + keyset cache directory to use + + Returns: + Client: + client loaded from the filesystem + """ + + with tempfile.TemporaryDirectory() as tmp_dir: + shutil.unpack_archive(path, tmp_dir, "zip") + with open(Path(tmp_dir) / "client.specs.json", "r", encoding="utf-8") as f: + client_specs = ClientSpecs.unserialize(f.read()) + + return Client(client_specs, keyset_cache_directory) + + def keygen(self, force: bool = False): + """ + Generate keys required for homomorphic evaluation. + + Args: + force (bool, default = False): + whether to generate new keys even if keys are already generated + """ + + if self._keyset is None or force: + self._keyset = ClientSupport.key_set(self.specs.client_parameters, self._keyset_cache) + + def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments: + """ + Prepare inputs to be run on the circuit. + + Args: + *args (Union[int, numpy.ndarray]): + inputs to the circuit + + Returns: + PublicArguments: + encrypted and plain arguments as well as public keys + """ + + client_parameters_json = json.loads(self.specs.client_parameters.serialize()) + assert_that("inputs" in client_parameters_json) + input_specs = client_parameters_json["inputs"] + + if len(args) != len(input_specs): + raise ValueError(f"Expected {len(input_specs)} inputs but got {len(args)}") + + sanitized_args = {} + for index, spec in enumerate(input_specs): + arg = args[index] + is_valid = isinstance(arg, (int, np.integer)) or ( + isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer) + ) + + width = spec["shape"]["width"] + shape = tuple(spec["shape"]["dimensions"]) + is_encrypted = spec["encryption"] is not None + + expected_dtype = UnsignedInteger(width) + expected_value = Value(expected_dtype, shape, is_encrypted) + if is_valid: + expected_min = expected_dtype.min() + expected_max = expected_dtype.max() + + actual_min = arg if isinstance(arg, int) else arg.min() + actual_max = arg if isinstance(arg, int) else arg.max() + actual_shape = () if isinstance(arg, int) else arg.shape + + is_valid = ( + actual_min >= expected_min + and actual_max <= expected_max + and actual_shape == expected_value.shape + ) + + if is_valid: + sanitized_args[index] = arg if isinstance(arg, int) else arg.astype(np.uint8) + + if not is_valid: + actual_value = Value.of(arg, is_encrypted=is_encrypted) + raise ValueError( + f"Expected argument {index} to be {expected_value} but it's {actual_value}" + ) + + self.keygen(force=False) + return ClientSupport.encrypt_arguments( + self.specs.client_parameters, + self._keyset, + [sanitized_args[i] for i in range(len(sanitized_args))], + ) + + def decrypt( + self, + result: PublicResult, + ) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]: + """ + Decrypt result of homomorphic evaluaton. + + Args: + result (PublicResult): + encrypted result of homomorphic evaluaton + + Returns: + Union[int, numpy.ndarray]: + clear result of homomorphic evaluaton + """ + + results = ClientSupport.decrypt_result(self._keyset, result) + if not isinstance(results, tuple): + results = (results,) + + sanitized_results: List[Union[int, np.ndarray]] = [] + + client_parameters_json = json.loads(self.specs.client_parameters.serialize()) + assert_that("outputs" in client_parameters_json) + output_specs = client_parameters_json["outputs"] + + for index, spec in enumerate(output_specs): + n = spec["shape"]["width"] + expected_dtype = ( + SignedInteger(n) if self.specs.output_signs[index] else UnsignedInteger(n) + ) + + result = results[index] % (2**n) + if expected_dtype.is_signed: + if isinstance(result, int): + sanititzed_result = result if result < (2 ** (n - 1)) else result - (2**n) + sanitized_results.append(sanititzed_result) + else: + result = result.astype(np.longlong) # to prevent overflows in numpy + sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2**n)) + sanitized_results.append(sanititzed_result.astype(np.int8)) + else: + sanitized_results.append( + result if isinstance(result, int) else result.astype(np.uint8) + ) + + return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results) diff --git a/concrete/numpy/compilation/compiler.py b/concrete/numpy/compilation/compiler.py index 7978686ee..6189d473e 100644 --- a/concrete/numpy/compilation/compiler.py +++ b/concrete/numpy/compilation/compiler.py @@ -375,10 +375,12 @@ class Compiler: print() - circuit = Circuit.create(self.graph, mlir, self.configuration) + circuit = Circuit(self.graph, mlir, self.configuration) if not self.configuration.virtual: - assert circuit.client_parameters is not None - self.artifacts.add_client_parameters(circuit.client_parameters.serialize()) + assert circuit.client.specs.client_parameters is not None + self.artifacts.add_client_parameters( + circuit.client.specs.client_parameters.serialize() + ) return circuit except Exception: # pragma: no cover diff --git a/concrete/numpy/compilation/server.py b/concrete/numpy/compilation/server.py new file mode 100644 index 000000000..d0d5afeeb --- /dev/null +++ b/concrete/numpy/compilation/server.py @@ -0,0 +1,182 @@ +""" +Declaration of `Server` class. +""" + +import shutil +import tempfile +from pathlib import Path +from typing import List, Optional, Union + +from concrete.compiler import ( + CompilationOptions, + JITCompilationResult, + JITLambda, + JITSupport, + LibraryCompilationResult, + LibraryLambda, + LibrarySupport, + PublicArguments, + PublicResult, +) + +from ..internal.utils import assert_that +from .configuration import Configuration +from .specs import ClientSpecs + + +class Server: + """ + Client class, which can be used to perform homomorphic computation. + """ + + client_specs: ClientSpecs + + _output_dir: Optional[tempfile.TemporaryDirectory] + _support: Union[JITSupport, LibrarySupport] + _compilation_result: Union[JITCompilationResult, LibraryCompilationResult] + _server_lambda: Union[JITLambda, LibraryLambda] + + def __init__( + self, + client_specs: ClientSpecs, + output_dir: Optional[tempfile.TemporaryDirectory], + support: Union[JITSupport, LibrarySupport], + compilation_result: Union[JITCompilationResult, LibraryCompilationResult], + server_lambda: Union[JITLambda, LibraryLambda], + ): + self.client_specs = client_specs + + self._output_dir = output_dir + self._support = support + self._compilation_result = compilation_result + self._server_lambda = server_lambda + + assert_that( + support.load_client_parameters(compilation_result).serialize() + == client_specs.client_parameters.serialize() + ) + + @staticmethod + def create(mlir: str, output_signs: List[bool], configuration: Configuration) -> "Server": + """ + Create a server using MLIR and output sign information. + + Args: + mlir (str): + mlir to compile + + output_signs (List[bool]): + sign status of the outputs + + configuration (Optional[Configuration], default = None): + configuration to use + """ + + options = CompilationOptions.new("main") + + options.set_loop_parallelize(configuration.loop_parallelize) + options.set_dataflow_parallelize(configuration.dataflow_parallelize) + options.set_auto_parallelize(configuration.auto_parallelize) + options.set_p_error(configuration.p_error) + + if configuration.jit: + + output_dir = None + + support = JITSupport.new() + compilation_result = support.compile(mlir, options) + server_lambda = support.load_server_lambda(compilation_result) + + else: + + # pylint: disable=consider-using-with + output_dir = tempfile.TemporaryDirectory() + output_dir_path = Path(output_dir.name) + # pylint: enable=consider-using-with + + support = LibrarySupport.new( + str(output_dir_path), generateCppHeader=False, generateStaticLib=False + ) + compilation_result = support.compile(mlir, options) + server_lambda = support.load_server_lambda(compilation_result) + + client_specs = ClientSpecs(support.load_client_parameters(compilation_result), output_signs) + return Server(client_specs, output_dir, support, compilation_result, server_lambda) + + def save(self, path: Union[str, Path]): + """ + Save the server into the given path in zip format. + + Args: + path (Union[str, Path]): + path to save the server + """ + + if self._output_dir is None: + raise RuntimeError("Just-in-Time compilation cannot be saved") + + with open(Path(self._output_dir.name) / "client.specs.json", "w", encoding="utf-8") as f: + f.write(self.client_specs.serialize()) + + path = str(path) + if path.endswith(".zip"): + path = path[: len(path) - 4] + + shutil.make_archive(path, "zip", self._output_dir.name) + + @staticmethod + def load(path: Union[str, Path]) -> "Server": + """ + Load the server from the given path in zip format. + + Args: + path (Union[str, Path]): + path to load the server from + + Returns: + Server: + server loaded from the filesystem + """ + + # pylint: disable=consider-using-with + output_dir = tempfile.TemporaryDirectory() + output_dir_path = Path(output_dir.name) + # pylint: enable=consider-using-with + + shutil.unpack_archive(path, str(output_dir_path), "zip") + + with open(output_dir_path / "client.specs.json", "r", encoding="utf-8") as f: + client_specs = ClientSpecs.unserialize(f.read()) + + support = LibrarySupport.new( + str(output_dir_path), + generateCppHeader=False, + generateStaticLib=False, + ) + compilation_result = support.reload("main") + server_lambda = support.load_server_lambda(compilation_result) + + return Server(client_specs, output_dir, support, compilation_result, server_lambda) + + def run(self, args: PublicArguments) -> PublicResult: + """ + Evaluate using encrypted arguments. + + Args: + args (PublicArguments): + encrypted arguments of the computation + + Returns: + PublicResult: + encrypted result of the computation + """ + + return self._support.server_call(self._server_lambda, args) + + def cleanup(self): + """ + Cleanup the temporary library output directory. + """ + + if self._output_dir is not None: + self._output_dir.cleanup() diff --git a/concrete/numpy/compilation/specs.py b/concrete/numpy/compilation/specs.py new file mode 100644 index 000000000..b5e70df5c --- /dev/null +++ b/concrete/numpy/compilation/specs.py @@ -0,0 +1,59 @@ +""" +Declaration of `ClientSpecs` class. +""" + +import json +from typing import List + +from concrete.compiler import ClientParameters + + +class ClientSpecs: + """ + ClientSpecs class, to create Client objects. + """ + + client_parameters: ClientParameters + output_signs: List[bool] + + def __init__(self, client_parameters: ClientParameters, output_signs: List[bool]): + self.client_parameters = client_parameters + self.output_signs = output_signs + + def serialize(self) -> str: + """ + Serialize client specs into a string representation. + + Returns: + str: + string representation of the client specs + """ + + client_parameters_json = json.loads(self.client_parameters.serialize()) + return json.dumps( + { + "client_parameters": client_parameters_json, + "output_signs": self.output_signs, + } + ) + + @staticmethod + def unserialize(serialized_client_specs: str) -> "ClientSpecs": + """ + Create client specs from its string representation. + + Args: + serialized_client_specs (str): + client specs to unserialize + + Returns: + ClientSpecs: + unserialized client specs + """ + + raw_specs = json.loads(serialized_client_specs) + + client_parameters_bytes = json.dumps(raw_specs["client_parameters"]).encode("utf-8") + client_parameters = ClientParameters.unserialize(client_parameters_bytes) + + return ClientSpecs(client_parameters, raw_specs["output_signs"]) diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index cdd4e507f..100ede9e8 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -7,8 +7,9 @@ from pathlib import Path import numpy as np import pytest +from concrete.compiler import PublicArguments, PublicResult -from concrete.numpy import Circuit +from concrete.numpy import Client, ClientSpecs, Configuration, Server from concrete.numpy.compilation import compiler @@ -98,7 +99,7 @@ def test_circuit_bad_run(helpers): circuit.encrypt_run_decrypt(-1, 11) assert str(excinfo.value) == ( - "Expected argument 0 to be EncryptedScalar but it's EncryptedScalar" + "Expected argument 0 to be EncryptedScalar but it's EncryptedScalar" ) # with negative argument 1 @@ -108,7 +109,7 @@ def test_circuit_bad_run(helpers): circuit.encrypt_run_decrypt(1, -11) assert str(excinfo.value) == ( - "Expected argument 1 to be EncryptedScalar but it's EncryptedScalar" + "Expected argument 1 to be EncryptedScalar but it's EncryptedScalar" ) # with large argument 0 @@ -118,7 +119,7 @@ def test_circuit_bad_run(helpers): circuit.encrypt_run_decrypt(100, 10) assert str(excinfo.value) == ( - "Expected argument 0 to be EncryptedScalar but it's EncryptedScalar" + "Expected argument 0 to be EncryptedScalar but it's EncryptedScalar" ) # with large argument 1 @@ -128,7 +129,7 @@ def test_circuit_bad_run(helpers): circuit.encrypt_run_decrypt(1, 100) assert str(excinfo.value) == ( - "Expected argument 1 to be EncryptedScalar but it's EncryptedScalar" + "Expected argument 1 to be EncryptedScalar but it's EncryptedScalar" ) @@ -167,9 +168,69 @@ def test_circuit_virtual_explicit_api(helpers): assert str(excinfo.value) == "Virtual circuits cannot use `decrypt` method" -def test_circuit_bad_save(helpers): +def test_client_server_api(helpers): """ - Test `save` method of `Circuit` class with bad parameters. + Test client/server API. + """ + + configuration = helpers.configuration() + + @compiler({"x": "encrypted"}) + def function(x): + return x + 42 + + inputset = range(10) + circuit = function.compile(inputset, configuration.fork(jit=False)) + + # for coverage + circuit.keygen() + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_dir_path = Path(tmp_dir) + + server_path = tmp_dir_path / "server.zip" + circuit.server.save(server_path) + + client_path = tmp_dir_path / "client.zip" + circuit.client.save(client_path) + + circuit.cleanup() + + server = Server.load(server_path) + + serialized_client_specs = server.client_specs.serialize() + client_specs = ClientSpecs.unserialize(serialized_client_specs) + + clients = [ + Client(client_specs, Configuration.insecure_key_cache_location()), + Client.load(client_path, Configuration.insecure_key_cache_location()), + ] + + for client in clients: + raw_input = client.encrypt(4) + serialized_input = raw_input.serialize() + + unserialized_input = PublicArguments.unserialize( + server.client_specs.client_parameters, + serialized_input, + ) + evaluation = server.run(unserialized_input) + serialized_evaluation = evaluation.serialize() + + unserialized_evaluation = PublicResult.unserialize( + client.specs.client_parameters, + serialized_evaluation, + ) + output = client.decrypt(unserialized_evaluation) + + assert output == 46 + + server.cleanup() + + +def test_bad_server_save(helpers): + """ + Test `save` method of `Server` class with bad parameters. """ configuration = helpers.configuration() @@ -182,69 +243,6 @@ def test_circuit_bad_save(helpers): circuit = function.compile(inputset, configuration) with pytest.raises(RuntimeError) as excinfo: - circuit.save("circuit.zip") + circuit.server.save("test.zip") - assert str(excinfo.value) == "JIT Circuits cannot be saved" - - -@pytest.mark.parametrize( - "virtual", - [False, True], -) -def test_circuit_save_load(virtual, helpers): - """ - Test `save`, `load`, and `cleanup` methods of `Circuit` class. - """ - - configuration = helpers.configuration().fork(jit=False, virtual=virtual) - - def save(base): - @compiler({"x": "encrypted"}) - def function(x): - return x + 42 - - inputset = range(10) - circuit = function.compile(inputset, configuration) - - circuit.save(base / "circuit.zip") - circuit.cleanup() - - def load(base): - circuit = Circuit.load(base / "circuit.zip") - - helpers.check_str( - """ - -%0 = x # EncryptedScalar -%1 = 42 # ClearScalar -%2 = add(%0, %1) # EncryptedScalar -return %2 - - """, - str(circuit), - ) - if virtual: - helpers.check_str("Virtual circuits doesn't have MLIR.", circuit.mlir) - else: - helpers.check_str( - """ - -module { - func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> { - %c42_i7 = arith.constant 42 : i7 - %0 = "FHE.add_eint_int"(%arg0, %c42_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6> - return %0 : !FHE.eint<6> - } -} - - """, - circuit.mlir, - ) - helpers.check_execution(circuit, lambda x: x + 42, 4) - - circuit.cleanup() - - with tempfile.TemporaryDirectory() as tmp: - path = Path(tmp) - save(path) - load(path) + assert str(excinfo.value) == "Just-in-Time compilation cannot be saved"