From 05282285a34c1a81bef9423d302b79b521cdeb42 Mon Sep 17 00:00:00 2001 From: Umut Date: Tue, 23 Aug 2022 16:20:38 +0200 Subject: [PATCH] feat: allow saving servers via MLIR --- concrete/numpy/compilation/configuration.py | 14 ++++- concrete/numpy/compilation/server.py | 70 +++++++++++++++++++-- tests/compilation/test_circuit.py | 60 ++++++++++++++++++ tests/compilation/test_configuration.py | 6 ++ 4 files changed, 142 insertions(+), 8 deletions(-) diff --git a/concrete/numpy/compilation/configuration.py b/concrete/numpy/compilation/configuration.py index b14f41530..a5c168df0 100644 --- a/concrete/numpy/compilation/configuration.py +++ b/concrete/numpy/compilation/configuration.py @@ -114,10 +114,20 @@ class Configuration: raise TypeError(f"Unexpected keyword argument '{name}'") hint = hints[name] - if not isinstance(value, hint): # type: ignore + is_correctly_typed = True + + if name == "insecure_key_cache_location": + if not (value is None or isinstance(value, str)): + is_correctly_typed = False + + elif not isinstance(value, hint): # type: ignore + is_correctly_typed = False + + if not is_correctly_typed: + expected = hint.__name__ if hasattr(hint, "__name__") else str(hint) raise TypeError( f"Unexpected type for keyword argument '{name}' " - f"(expected '{hint.__name__}', got '{type(value).__name__}')" + f"(expected '{expected}', got '{type(value).__name__}')" ) setattr(result, name, value) diff --git a/concrete/numpy/compilation/server.py b/concrete/numpy/compilation/server.py index b4adff672..fdb282727 100644 --- a/concrete/numpy/compilation/server.py +++ b/concrete/numpy/compilation/server.py @@ -2,6 +2,7 @@ Declaration of `Server` class. """ +import json import shutil import tempfile from pathlib import Path @@ -37,6 +38,9 @@ class Server: _compilation_result: Union[JITCompilationResult, LibraryCompilationResult] _server_lambda: Union[JITLambda, LibraryLambda] + _mlir: Optional[str] + _configuration: Optional[Configuration] + def __init__( self, client_specs: ClientSpecs, @@ -51,6 +55,7 @@ class Server: self._support = support self._compilation_result = compilation_result self._server_lambda = server_lambda + self._mlir = None assert_that( support.load_client_parameters(compilation_result).serialize() @@ -112,27 +117,61 @@ class Server: client_parameters = support.load_client_parameters(compilation_result) client_specs = ClientSpecs(input_signs, client_parameters, output_signs) - return Server(client_specs, output_dir, support, compilation_result, server_lambda) - def save(self, path: Union[str, Path]): + result = Server(client_specs, output_dir, support, compilation_result, server_lambda) + + # pylint: disable=protected-access + result._mlir = mlir + result._configuration = configuration + # pylint: enable=protected-access + + return result + + def save(self, path: Union[str, Path], via_mlir: bool = False): """ Save the server into the given path in zip format. Args: path (Union[str, Path]): path to save the server + + via_mlir (bool, default = False) + export using the MLIR code of the program, + this will make the export cross-platform """ + path = str(path) + if path.endswith(".zip"): + path = path[: len(path) - 4] + + if via_mlir: + if self._mlir is None or self._configuration is None: + raise RuntimeError("Loaded server objects cannot be saved again via MLIR") + + with tempfile.TemporaryDirectory() as tmp: + + with open(Path(tmp) / "circuit.mlir", "w", encoding="utf-8") as f: + f.write(self._mlir) + + with open(Path(tmp) / "input_signs.json", "w", encoding="utf-8") as f: + f.write(json.dumps(self.client_specs.input_signs)) + + with open(Path(tmp) / "output_signs.json", "w", encoding="utf-8") as f: + f.write(json.dumps(self.client_specs.output_signs)) + + with open(Path(tmp) / "configuration.json", "w", encoding="utf-8") as f: + f.write(json.dumps(self._configuration.__dict__)) + + shutil.make_archive(path, "zip", tmp) + + return + if self._output_dir is None: raise RuntimeError("Just-in-Time compilation cannot be saved") with open(Path(self._output_dir.name) / "client.specs.json", "w", encoding="utf-8") as f: f.write(self.client_specs.serialize()) - path = str(path) - if path.endswith(".zip"): - path = path[: len(path) - 4] - shutil.make_archive(path, "zip", self._output_dir.name) @staticmethod @@ -156,6 +195,25 @@ class Server: shutil.unpack_archive(path, str(output_dir_path), "zip") + if (output_dir_path / "circuit.mlir").exists(): + with open(output_dir_path / "circuit.mlir", "r", encoding="utf-8") as f: + mlir = f.read() + + with open(output_dir_path / "input_signs.json", "r", encoding="utf-8") as f: + input_signs = json.load(f) + assert_that(isinstance(input_signs, list)) + assert_that(all(isinstance(sign, bool) for sign in input_signs)) + + with open(output_dir_path / "output_signs.json", "r", encoding="utf-8") as f: + output_signs = json.load(f) + assert_that(isinstance(output_signs, list)) + assert_that(all(isinstance(sign, bool) for sign in output_signs)) + + with open(output_dir_path / "configuration.json", "r", encoding="utf-8") as f: + configuration = Configuration().fork(**json.load(f)) + + return Server.create(mlir, input_signs, output_signs, configuration) + with open(output_dir_path / "client.specs.json", "r", encoding="utf-8") as f: client_specs = ClientSpecs.unserialize(f.read()) diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index 2bfcb5b3a..7f3766a9d 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -182,6 +182,66 @@ def test_client_server_api(helpers): Client.load(client_path, configuration.insecure_key_cache_location), ] + for client in clients: + args = client.encrypt([3, 8, 1]) + + serialized_args = client.specs.serialize_public_args(args) + serialized_evaluation_keys = client.evaluation_keys.serialize() + + unserialized_args = server.client_specs.unserialize_public_args(serialized_args) + unserialized_evaluation_keys = EvaluationKeys.unserialize(serialized_evaluation_keys) + + result = server.run(unserialized_args, unserialized_evaluation_keys) + serialized_result = server.client_specs.serialize_public_result(result) + + unserialized_result = client.specs.unserialize_public_result(serialized_result) + output = client.decrypt(unserialized_result) + + assert np.array_equal(output, [45, 50, 43]) + + with pytest.raises(RuntimeError) as excinfo: + server.save("UNUSED", via_mlir=True) + + assert str(excinfo.value) == "Loaded server objects cannot be saved again via MLIR" + + server.cleanup() + + +def test_client_server_api_via_mlir(helpers): + """ + Test client/server API. + """ + + configuration = helpers.configuration() + + @compiler({"x": "encrypted"}) + def function(x): + return x + 42 + + inputset = [np.random.randint(0, 10, size=(3,)) for _ in range(10)] + circuit = function.compile(inputset, configuration.fork(jit=False)) + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_dir_path = Path(tmp_dir) + + server_path = tmp_dir_path / "server.zip" + circuit.server.save(server_path, via_mlir=True) + + client_path = tmp_dir_path / "client.zip" + circuit.client.save(client_path) + + circuit.cleanup() + + server = Server.load(server_path) + + serialized_client_specs = server.client_specs.serialize() + client_specs = ClientSpecs.unserialize(serialized_client_specs) + + clients = [ + Client(client_specs, configuration.insecure_key_cache_location), + Client.load(client_path, configuration.insecure_key_cache_location), + ] + for client in clients: args = client.encrypt([3, 8, 1]) diff --git a/tests/compilation/test_configuration.py b/tests/compilation/test_configuration.py index cf1bf1242..28b7ab329 100644 --- a/tests/compilation/test_configuration.py +++ b/tests/compilation/test_configuration.py @@ -68,6 +68,12 @@ def test_configuration_fork(): "Unexpected type for keyword argument 'dump_artifacts_on_unexpected_failures' " "(expected 'bool', got 'str')", ), + pytest.param( + {"insecure_key_cache_location": 3}, + TypeError, + "Unexpected type for keyword argument 'insecure_key_cache_location' " + "(expected 'typing.Union[str, NoneType]', got 'int')", + ), ], ) def test_configuration_bad_fork(kwargs, expected_error, expected_message):