diff --git a/concrete/numpy/compilation/circuit.py b/concrete/numpy/compilation/circuit.py index f5adede82..4de1c2d8b 100644 --- a/concrete/numpy/compilation/circuit.py +++ b/concrete/numpy/compilation/circuit.py @@ -2,6 +2,9 @@ Declaration of `Circuit` class. """ +import pickle +import shutil +import tempfile from pathlib import Path from typing import Any, List, Optional, Tuple, Union, cast @@ -15,6 +18,9 @@ from concrete.compiler import ( JITSupport, KeySet, KeySetCache, + LibraryCompilationResult, + LibraryLambda, + LibrarySupport, PublicArguments, PublicResult, ) @@ -31,55 +37,227 @@ class Circuit: Circuit class, to combine computation graph and compiler engine into a single object. """ + # pylint: disable=too-many-instance-attributes + + configuration: Configuration + graph: Graph mlir: str - virtual: bool - _jit_support: JITSupport - _compilation_result: JITCompilationResult + _support: Union[JITSupport, LibrarySupport] + _compilation_result: Union[JITCompilationResult, LibraryCompilationResult] + _server_lambda: Union[JITLambda, LibraryLambda] + + _output_dir: Optional[tempfile.TemporaryDirectory] _client_parameters: ClientParameters - - _keyset_cache: KeySetCache _keyset: KeySet + _keyset_cache: KeySetCache - _server_lambda: JITLambda + # pylint: enable=too-many-instance-attributes def __init__( self, + configuration: Configuration, graph: Graph, mlir: str, - configuration: Optional[Configuration] = None, + support: Optional[Union[JITSupport, LibrarySupport]] = None, + compilation_result: Optional[Union[JITCompilationResult, LibraryCompilationResult]] = None, + server_lambda: Optional[Union[JITLambda, LibraryLambda]] = None, + output_dir: Optional[tempfile.TemporaryDirectory] = None, ): - configuration = configuration if configuration is not None else Configuration() + self.configuration = configuration + self._output_dir = output_dir self.graph = graph self.mlir = mlir - self.virtual = configuration.virtual - if self.virtual: + if configuration.virtual: + assert_that(configuration.enable_unsafe_features) return + assert support is not None + assert compilation_result is not None + assert server_lambda is not None + + assert_that( + ( + isinstance(support, JITSupport) + and isinstance(compilation_result, JITCompilationResult) + and isinstance(server_lambda, JITLambda) + ) + or ( + isinstance(support, LibrarySupport) + and isinstance(compilation_result, LibraryCompilationResult) + and isinstance(server_lambda, LibraryLambda) + ) + ) + + self._support = support + self._compilation_result = compilation_result + self._server_lambda = server_lambda + + self._output_dir = output_dir + if isinstance(support, LibrarySupport): + assert output_dir is not None + assert_that(support.library_path == str(output_dir.name) + "/out") + + client_parameters = support.load_client_parameters(compilation_result) + keyset = None + keyset_cache = None + + if configuration.use_insecure_key_cache: + assert_that(configuration.enable_unsafe_features) + location = Configuration.insecure_key_cache_location() + if location is not None: + keyset_cache = KeySetCache.new(str(location)) + + self._client_parameters = client_parameters + self._keyset = keyset + self._keyset_cache = keyset_cache + + @staticmethod + def create(graph: Graph, mlir: str, configuration: Optional[Configuration] = None) -> "Circuit": + """ + Create a circuit from a graph and its MLIR. + + Args: + graph (Graph): + graph of the circuit + + mlir (str): + mlir of the circuit + + configuration (Optional[Configuration], default = None): + configuration to use + + Returns: + Circuit: + circuit of graph + """ + + configuration = configuration if configuration is not None else Configuration() + if configuration.virtual: + return Circuit(configuration, graph, mlir) + options = CompilationOptions.new("main") options.set_loop_parallelize(configuration.loop_parallelize) options.set_dataflow_parallelize(configuration.dataflow_parallelize) options.set_auto_parallelize(configuration.auto_parallelize) - self._jit_support = JITSupport.new() - self._compilation_result = self._jit_support.compile(mlir, options) + if configuration.jit: - self._client_parameters = self._jit_support.load_client_parameters(self._compilation_result) + output_dir = None - self._keyset_cache = None - if configuration.use_insecure_key_cache: - assert_that(configuration.enable_unsafe_features) - location = Configuration.insecure_key_cache_location() - if location is not None: - self._keyset_cache = KeySetCache.new(str(location)) - self._keyset = None + support = JITSupport.new() + compilation_result = support.compile(mlir, options) + server_lambda = support.load_server_lambda(compilation_result) - self._server_lambda = self._jit_support.load_server_lambda(self._compilation_result) + else: + + # pylint: disable=consider-using-with + output_dir = tempfile.TemporaryDirectory() + output_dir_path = Path(output_dir.name) + # pylint: enable=consider-using-with + + support = LibrarySupport.new(str(output_dir_path / "out")) + compilation_result = support.compile(mlir, options) + server_lambda = support.load_server_lambda(compilation_result) + + return Circuit( + configuration, + graph, + mlir, + support, + compilation_result, + server_lambda, + output_dir, + ) + + def save(self, path: Union[str, Path]): + """ + Save the circuit into the given path in zip format. + + Args: + path (Union[str, Path]): + path to save the circuit + """ + + if not self.configuration.virtual and self.configuration.jit: + raise RuntimeError("JIT Circuits cannot be saved") + + if self.configuration.virtual: + # pylint: disable=consider-using-with + self._output_dir = tempfile.TemporaryDirectory() + # pylint: enable=consider-using-with + + assert self._output_dir is not None + output_dir_path = Path(self._output_dir.name) + + with open(output_dir_path / "out.pickle", "wb") as f: + attributes = { + "configuration": self.configuration, + "graph": self.graph, + "mlir": self.mlir, + } + pickle.dump(attributes, f) + + path = str(path) + if path.endswith(".zip"): + path = path[: len(path) - 4] + + shutil.make_archive(path, "zip", str(output_dir_path)) + + if self.configuration.virtual: + self.cleanup() + self._output_dir = None + + @staticmethod + def load(path: Union[str, Path]) -> "Circuit": + """ + Load the circuit from the given path in zip format. + + Args: + path (Union[str, Path]): + path to load the circuit from + + Returns: + Circuit: + circuit loaded from the filesystem + """ + + # pylint: disable=consider-using-with + output_dir = tempfile.TemporaryDirectory() + output_dir_path = Path(output_dir.name) + # pylint: enable=consider-using-with + + shutil.unpack_archive(path, str(output_dir_path), "zip") + + with open(output_dir_path / "out.pickle", "rb") as f: + attributes = pickle.load(f) + + configuration = attributes["configuration"] + graph = attributes["graph"] + mlir = attributes["mlir"] + + if configuration.virtual: + output_dir.cleanup() + return Circuit(configuration, graph, mlir) + + support = LibrarySupport.new(str(output_dir_path / "out")) + compilation_result = support.reload("main") + server_lambda = support.load_server_lambda(compilation_result) + + return Circuit( + configuration, + graph, + mlir, + support, + compilation_result, + server_lambda, + output_dir, + ) def __str__(self): return self.graph.format() @@ -124,7 +302,7 @@ class Circuit: whether to generate new keys even if keys are already generated """ - if self.virtual: + if self.configuration.virtual: raise RuntimeError("Virtual circuits cannot use `keygen` method") if self._keyset is None or force: @@ -143,7 +321,7 @@ class Circuit: encrypted and plain arguments as well as public keys """ - if self.virtual: + if self.configuration.virtual: raise RuntimeError("Virtual circuits cannot use `encrypt` method") if len(args) != len(self.graph.input_nodes): @@ -205,10 +383,10 @@ class Circuit: encrypted result of homomorphic evaluaton """ - if self.virtual: + if self.configuration.virtual: raise RuntimeError("Virtual circuits cannot use `run` method") - return self._jit_support.server_call(self._server_lambda, args) + return self._support.server_call(self._server_lambda, args) def decrypt( self, @@ -226,7 +404,7 @@ class Circuit: clear result of homomorphic evaluaton """ - if self.virtual: + if self.configuration.virtual: raise RuntimeError("Virtual circuits cannot use `decrypt` method") results = ClientSupport.decrypt_result(self._keyset, result) @@ -271,7 +449,15 @@ class Circuit: clear result of homomorphic evaluation """ - if self.virtual: + if self.configuration.virtual: return self.graph(*args) return self.decrypt(self.run(self.encrypt(*args))) + + def cleanup(self): + """ + Cleanup the temporary library output directory. + """ + + if self._output_dir is not None: + self._output_dir.cleanup() diff --git a/concrete/numpy/compilation/compiler.py b/concrete/numpy/compilation/compiler.py index b7fc41e64..6bac7d54e 100644 --- a/concrete/numpy/compilation/compiler.py +++ b/concrete/numpy/compilation/compiler.py @@ -375,7 +375,7 @@ class Compiler: print() - return Circuit(self.graph, mlir, self.configuration) + return Circuit.create(self.graph, mlir, self.configuration) except Exception: # pragma: no cover diff --git a/concrete/numpy/compilation/configuration.py b/concrete/numpy/compilation/configuration.py index 01efcfdb1..d69cbe2cb 100644 --- a/concrete/numpy/compilation/configuration.py +++ b/concrete/numpy/compilation/configuration.py @@ -13,6 +13,8 @@ class Configuration: Configuration class, to allow the compilation process to be customized. """ + # pylint: disable=too-many-instance-attributes + verbose: bool show_graph: bool show_mlir: bool @@ -23,6 +25,9 @@ class Configuration: loop_parallelize: bool dataflow_parallelize: bool auto_parallelize: bool + jit: bool + + # pylint: enable=too-many-instance-attributes def _validate(self): """ @@ -55,6 +60,7 @@ class Configuration: loop_parallelize: bool = True, dataflow_parallelize: bool = False, auto_parallelize: bool = False, + jit: bool = False, ): self.verbose = verbose self.show_graph = show_graph @@ -66,6 +72,7 @@ class Configuration: self.loop_parallelize = loop_parallelize self.dataflow_parallelize = dataflow_parallelize self.auto_parallelize = auto_parallelize + self.jit = jit self._validate() diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index 0f2aa1504..550ce3b2e 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -8,6 +8,7 @@ from pathlib import Path import numpy as np import pytest +from concrete.numpy import Circuit from concrete.numpy.compilation import compiler @@ -164,3 +165,86 @@ def test_circuit_virtual_explicit_api(helpers): circuit.decrypt(None) assert str(excinfo.value) == "Virtual circuits cannot use `decrypt` method" + + +def test_circuit_bad_save(helpers): + """ + Test `save` method of `Circuit` class with bad parameters. + """ + + configuration = helpers.configuration() + + @compiler({"x": "encrypted"}) + def function(x): + return x + 42 + + inputset = range(10) + circuit = function.compile(inputset, configuration) + + with pytest.raises(RuntimeError) as excinfo: + circuit.save("circuit.zip") + + assert str(excinfo.value) == "JIT Circuits cannot be saved" + + +@pytest.mark.parametrize( + "virtual", + [False, True], +) +def test_circuit_save_load(virtual, helpers): + """ + Test `save`, `load`, and `cleanup` methods of `Circuit` class. + """ + + configuration = helpers.configuration().fork(jit=False, virtual=virtual) + + def save(base): + @compiler({"x": "encrypted"}) + def function(x): + return x + 42 + + inputset = range(10) + circuit = function.compile(inputset, configuration) + + circuit.save(base / "circuit.zip") + circuit.cleanup() + + def load(base): + circuit = Circuit.load(base / "circuit.zip") + + helpers.check_str( + """ + +%0 = x # EncryptedScalar +%1 = 42 # ClearScalar +%2 = add(%0, %1) # EncryptedScalar +return %2 + + """, + str(circuit), + ) + if virtual: + helpers.check_str("Virtual circuits doesn't have MLIR.", circuit.mlir) + else: + helpers.check_str( + """ + +module { + func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> { + %c42_i7 = arith.constant 42 : i7 + %0 = "FHE.add_eint_int"(%arg0, %c42_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6> + return %0 : !FHE.eint<6> + } +} + + """, + circuit.mlir, + ) + helpers.check_execution(circuit, lambda x: x + 42, 4) + + circuit.cleanup() + + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) + save(path) + load(path) diff --git a/tests/conftest.py b/tests/conftest.py index 42c1bd8da..b78545c4a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -107,6 +107,7 @@ class Helpers: loop_parallelize=True, dataflow_parallelize=False, auto_parallelize=False, + jit=True, ) @staticmethod