diff --git a/concrete/numpy/compilation/circuit.py b/concrete/numpy/compilation/circuit.py index 76538e385..f5adede82 100644 --- a/concrete/numpy/compilation/circuit.py +++ b/concrete/numpy/compilation/circuit.py @@ -50,14 +50,13 @@ class Circuit: graph: Graph, mlir: str, configuration: Optional[Configuration] = None, - virtual: bool = False, ): configuration = configuration if configuration is not None else Configuration() self.graph = graph self.mlir = mlir - self.virtual = virtual + self.virtual = configuration.virtual if self.virtual: return diff --git a/concrete/numpy/compilation/compiler.py b/concrete/numpy/compilation/compiler.py index 5b78432c4..b7fc41e64 100644 --- a/concrete/numpy/compilation/compiler.py +++ b/concrete/numpy/compilation/compiler.py @@ -5,6 +5,7 @@ Declaration of `Compiler` class. import inspect import os import traceback +from copy import deepcopy from enum import Enum, unique from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union @@ -48,8 +49,6 @@ class Compiler: self, function: Callable, parameter_encryption_statuses: Dict[str, Union[str, EncryptionStatus]], - configuration: Optional[Configuration] = None, - artifacts: Optional[DebugArtifacts] = None, ): signature = inspect.signature(function) @@ -82,16 +81,12 @@ class Compiler: for param, status in parameter_encryption_statuses.items() } - self.configuration = configuration if configuration is not None else Configuration() - self.artifacts = artifacts if artifacts is not None else DebugArtifacts() + self.configuration = Configuration() + self.artifacts = DebugArtifacts() self.inputset = [] self.graph = None - self.artifacts.add_source_code(function) - for param, encryption_status in parameter_encryption_statuses.items(): - self.artifacts.add_parameter_encryption_status(param, encryption_status) - def __call__( self, *args: Any, @@ -126,6 +121,10 @@ class Compiler: sample to use for tracing """ + self.artifacts.add_source_code(self.function) + for param, encryption_status in self.parameter_encryption_statuses.items(): + self.artifacts.add_parameter_encryption_status(param, encryption_status) + parameters = { param: Value.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED)) for arg, (param, status) in zip( @@ -180,7 +179,9 @@ class Compiler: def trace( self, inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None, - show_graph: bool = False, + configuration: Optional[Configuration] = None, + artifacts: Optional[DebugArtifacts] = None, + **kwargs, ) -> Graph: """ Trace the function using an inputset. @@ -189,19 +190,37 @@ class Compiler: inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]): optional inputset to extend accumulated inputset before bounds measurement - show_graph (bool, default = False): - whether to print the computation graph + configuration(Optional[Configuration], default = None): + configuration to use + + artifacts (Optional[DebugArtifacts], default = None): + artifacts to store information about the process + + kwargs (Dict[str, Any]): + configuration options to overwrite Returns: Graph: computation graph representing the function prior to MLIR conversion """ + old_configuration = deepcopy(self.configuration) + old_artifacts = deepcopy(self.artifacts) + + if configuration is not None: + self.configuration = configuration + if artifacts is not None: + self.artifacts = artifacts + + if len(kwargs) != 0: + self.configuration = self.configuration.fork(**kwargs) + try: + self._evaluate("Tracing", inputset) assert self.graph is not None - if show_graph: + if self.configuration.verbose or self.configuration.show_graph: graph = self.graph.format() longest_line = max([len(line) for line in graph.split("\n")]) @@ -247,12 +266,19 @@ class Compiler: raise + finally: + + self.configuration = old_configuration + self.artifacts = old_artifacts + + # pylint: disable=too-many-branches,too-many-statements + def compile( self, inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None, - show_graph: bool = False, - show_mlir: bool = False, - virtual: bool = False, + configuration: Optional[Configuration] = None, + artifacts: Optional[DebugArtifacts] = None, + **kwargs, ) -> Circuit: """ Compile the function using an inputset. @@ -261,34 +287,50 @@ class Compiler: inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]): optional inputset to extend accumulated inputset before bounds measurement - show_graph (bool, default = False): - whether to print the computation graph + configuration(Optional[Configuration], default = None): + configuration to use - show_mlir (bool, default = False): - whether to print the compiled mlir + artifacts (Optional[DebugArtifacts], default = None): + artifacts to store information about the process - virtual (bool, default = False): - whether to simulate the computation to allow large bit-widths + kwargs (Dict[str, Any]): + configuration options to overwrite Returns: Circuit: compiled circuit """ + old_configuration = deepcopy(self.configuration) + old_artifacts = deepcopy(self.artifacts) + + if configuration is not None: + self.configuration = configuration + if artifacts is not None: + self.artifacts = artifacts + + if len(kwargs) != 0: + self.configuration = self.configuration.fork(**kwargs) + try: - if virtual and not self.configuration.enable_unsafe_features: - raise RuntimeError( - "Virtual compilation is not allowed without enabling unsafe features" - ) self._evaluate("Compiling", inputset) assert self.graph is not None - mlir = GraphConverter.convert(self.graph, virtual=virtual) + mlir = GraphConverter.convert(self.graph, virtual=self.configuration.virtual) self.artifacts.add_mlir_to_compile(mlir) - if show_graph or show_mlir: - graph = self.graph.format() if show_graph else "" + if ( + self.configuration.verbose + or self.configuration.show_graph + or self.configuration.show_mlir + ): + + graph = ( + self.graph.format() + if self.configuration.verbose or self.configuration.show_graph + else "" + ) longest_graph_line = max([len(line) for line in graph.split("\n")]) longest_mlir_line = max([len(line) for line in mlir.split("\n")]) @@ -308,7 +350,7 @@ class Compiler: except OSError: # pragma: no cover columns = min(longest_line, 80) - if show_graph: + if self.configuration.verbose or self.configuration.show_graph: print() print("Computation Graph") @@ -318,8 +360,13 @@ class Compiler: print() - if show_mlir: - print("\n" if not show_graph else "", end="") + if self.configuration.verbose or self.configuration.show_mlir: + print( + "\n" + if not (self.configuration.verbose or self.configuration.show_graph) + else "", + end="", + ) print("MLIR") print("-" * columns) @@ -328,7 +375,7 @@ class Compiler: print() - return Circuit(self.graph, mlir, self.configuration, virtual=virtual) + return Circuit(self.graph, mlir, self.configuration) except Exception: # pragma: no cover @@ -346,3 +393,10 @@ class Compiler: f.write(traceback.format_exc()) raise + + finally: + + self.configuration = old_configuration + self.artifacts = old_artifacts + + # pylint: enable=too-many-branches,too-many-statements diff --git a/concrete/numpy/compilation/configuration.py b/concrete/numpy/compilation/configuration.py index fdd87d55c..01efcfdb1 100644 --- a/concrete/numpy/compilation/configuration.py +++ b/concrete/numpy/compilation/configuration.py @@ -13,31 +13,63 @@ class Configuration: Configuration class, to allow the compilation process to be customized. """ + verbose: bool + show_graph: bool + show_mlir: bool dump_artifacts_on_unexpected_failures: bool enable_unsafe_features: bool + virtual: bool use_insecure_key_cache: bool loop_parallelize: bool dataflow_parallelize: bool auto_parallelize: bool + def _validate(self): + """ + Validate configuration. + """ + + if not self.enable_unsafe_features: + + if self.use_insecure_key_cache: + raise RuntimeError( + "Insecure key cache cannot be used without enabling unsafe features" + ) + + if self.virtual: + raise RuntimeError( + "Virtual compilation is not allowed without enabling unsafe features" + ) + + # pylint: disable=too-many-arguments + def __init__( self, + verbose: bool = False, + show_graph: bool = False, + show_mlir: bool = False, dump_artifacts_on_unexpected_failures: bool = True, enable_unsafe_features: bool = False, + virtual: bool = False, use_insecure_key_cache: bool = False, loop_parallelize: bool = True, dataflow_parallelize: bool = False, auto_parallelize: bool = False, ): + self.verbose = verbose + self.show_graph = show_graph + self.show_mlir = show_mlir self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures self.enable_unsafe_features = enable_unsafe_features + self.virtual = virtual self.use_insecure_key_cache = use_insecure_key_cache self.loop_parallelize = loop_parallelize self.dataflow_parallelize = dataflow_parallelize self.auto_parallelize = auto_parallelize - if not enable_unsafe_features and use_insecure_key_cache: - raise RuntimeError("Insecure key cache cannot be used without enabling unsafe features") + self._validate() + + # pylint: enable=too-many-arguments @staticmethod def insecure_key_cache_location() -> Optional[str]: @@ -80,4 +112,8 @@ class Configuration: setattr(result, name, value) + # pylint: disable=protected-access + result._validate() + # pylint: enable=protected-access + return result diff --git a/concrete/numpy/compilation/decorator.py b/concrete/numpy/compilation/decorator.py index d2153eb00..8bdbffdc0 100644 --- a/concrete/numpy/compilation/decorator.py +++ b/concrete/numpy/compilation/decorator.py @@ -11,23 +11,13 @@ from .compiler import Compiler, EncryptionStatus from .configuration import Configuration -def compiler( - parameters: Mapping[str, EncryptionStatus], - configuration: Optional[Configuration] = None, - artifacts: Optional[DebugArtifacts] = None, -): +def compiler(parameters: Mapping[str, EncryptionStatus]): """ Provide an easy interface for compilation. Args: parameters (Dict[str, EncryptionStatus]): encryption statuses of the parameters of the function to compile - - configuration(Optional[Configuration], default = None): - configuration to use for compilation - - artifacts (Optional[DebugArtifacts], default = None): - artifacts to store information about compilation """ def decoration(function: Callable): @@ -41,12 +31,7 @@ def compiler( def __init__(self, function: Callable): self.function = function # type: ignore - self.compiler = Compiler( - self.function, - dict(parameters), - configuration, - artifacts, - ) + self.compiler = Compiler(self.function, dict(parameters)) def __call__(self, *args, **kwargs) -> Any: self.compiler(*args, **kwargs) @@ -55,7 +40,9 @@ def compiler( def trace( self, inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None, - show_graph: bool = False, + configuration: Optional[Configuration] = None, + artifacts: Optional[DebugArtifacts] = None, + **kwargs, ) -> Graph: """ Trace the function into computation graph. @@ -64,22 +51,28 @@ def compiler( inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]): optional inputset to extend accumulated inputset before bounds measurement - show_graph (bool, default = False): - whether to print the computation graph + configuration(Optional[Configuration], default = None): + configuration to use + + artifacts (Optional[DebugArtifacts], default = None): + artifacts to store information about the process + + kwargs (Dict[str, Any]): + configuration options to overwrite Returns: Graph: computation graph representing the function prior to MLIR conversion """ - return self.compiler.trace(inputset, show_graph) + return self.compiler.trace(inputset, configuration, artifacts, **kwargs) def compile( self, inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None, - show_graph: bool = False, - show_mlir: bool = False, - virtual: bool = False, + configuration: Optional[Configuration] = None, + artifacts: Optional[DebugArtifacts] = None, + **kwargs, ) -> Circuit: """ Compile the function into a circuit. @@ -88,21 +81,21 @@ def compiler( inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]): optional inputset to extend accumulated inputset before bounds measurement - show_graph (bool, default = False): - whether to print the computation graph + configuration(Optional[Configuration], default = None): + configuration to use - show_mlir (bool, default = False): - whether to print the compiled mlir + artifacts (Optional[DebugArtifacts], default = None): + artifacts to store information about the process - virtual (bool, default = False): - whether to simulate the computation to allow large bit-widths + kwargs (Dict[str, Any]): + configuration options to overwrite Returns: Circuit: compiled circuit """ - return self.compiler.compile(inputset, show_graph, show_mlir, virtual) + return self.compiler.compile(inputset, configuration, artifacts, **kwargs) return Compilable(function) diff --git a/docs/user/tutorial/compilation_artifacts.md b/docs/user/tutorial/compilation_artifacts.md index f6187fd2d..a977aac09 100644 --- a/docs/user/tutorial/compilation_artifacts.md +++ b/docs/user/tutorial/compilation_artifacts.md @@ -121,11 +121,11 @@ import pathlib artifacts = cnp.DebugArtifacts("/tmp/custom/export/path") -@cnp.compiler({"x": "encrypted"}, artifacts=artifacts) +@cnp.compiler({"x": "encrypted"}) def f(x): return 127 - (50 * (np.sin(x) + 1)).astype(np.int64) -f.compile(range(2 ** 3)) +f.compile(range(2 ** 3), artifacts=artifacts) artifacts.export() ``` diff --git a/tests/compilation/test_artifacts.py b/tests/compilation/test_artifacts.py index e1840bcad..35f6d035a 100644 --- a/tests/compilation/test_artifacts.py +++ b/tests/compilation/test_artifacts.py @@ -19,12 +19,12 @@ def test_artifacts_export(helpers): configuration = helpers.configuration() artifacts = DebugArtifacts(tmpdir) - @compiler({"x": "encrypted"}, configuration=configuration, artifacts=artifacts) + @compiler({"x": "encrypted"}) def f(x): return x + 10 inputset = range(100) - f.compile(inputset) + f.compile(inputset, configuration, artifacts) artifacts.export() diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index 59396e9e2..0f2aa1504 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -18,12 +18,12 @@ def test_circuit_str(helpers): configuration = helpers.configuration() - @compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration) + @compiler({"x": "encrypted", "y": "encrypted"}) def f(x, y): return x + y inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)] - circuit = f.compile(inputset) + circuit = f.compile(inputset, configuration) assert str(circuit) == ( """ @@ -44,12 +44,12 @@ def test_circuit_draw(helpers): configuration = helpers.configuration() - @compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration) + @compiler({"x": "encrypted", "y": "encrypted"}) def f(x, y): return x + y inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)] - circuit = f.compile(inputset) + circuit = f.compile(inputset, configuration) with tempfile.TemporaryDirectory() as path: tmpdir = Path(path) @@ -67,12 +67,12 @@ def test_circuit_bad_run(helpers): configuration = helpers.configuration() - @compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration) + @compiler({"x": "encrypted", "y": "encrypted"}) def f(x, y): return x + y inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)] - circuit = f.compile(inputset) + circuit = f.compile(inputset, configuration) # with 1 argument # --------------- @@ -138,12 +138,12 @@ def test_circuit_virtual_explicit_api(helpers): configuration = helpers.configuration() - @compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration) + @compiler({"x": "encrypted", "y": "encrypted"}) def f(x, y): return x + y inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)] - circuit = f.compile(inputset, virtual=True) + circuit = f.compile(inputset, configuration, virtual=True) with pytest.raises(RuntimeError) as excinfo: circuit.keygen() diff --git a/tests/compilation/test_compiler.py b/tests/compilation/test_compiler.py index 07d57b446..486f99e38 100644 --- a/tests/compilation/test_compiler.py +++ b/tests/compilation/test_compiler.py @@ -7,13 +7,11 @@ import pytest from concrete.numpy.compilation import Compiler -def test_compiler_bad_init(helpers): +def test_compiler_bad_init(): """ Test `__init__` method of `Compiler` class with bad parameters. """ - configuration = helpers.configuration() - def f(x, y, z): return x + y + z @@ -21,7 +19,7 @@ def test_compiler_bad_init(helpers): # ----------- with pytest.raises(ValueError) as excinfo: - Compiler(f, {}, configuration=configuration) + Compiler(f, {}) assert str(excinfo.value) == ( "Encryption statuses of parameters 'x', 'y' and 'z' of function 'f' are not provided" @@ -31,7 +29,7 @@ def test_compiler_bad_init(helpers): # --------------- with pytest.raises(ValueError) as excinfo: - Compiler(f, {"z": "clear"}, configuration=configuration) + Compiler(f, {"z": "clear"}) assert str(excinfo.value) == ( "Encryption statuses of parameters 'x' and 'y' of function 'f' are not provided" @@ -41,7 +39,7 @@ def test_compiler_bad_init(helpers): # --------- with pytest.raises(ValueError) as excinfo: - Compiler(f, {"y": "encrypted", "z": "clear"}, configuration=configuration) + Compiler(f, {"y": "encrypted", "z": "clear"}) assert str(excinfo.value) == ( "Encryption status of parameter 'x' of function 'f' is not provided" @@ -52,29 +50,19 @@ def test_compiler_bad_init(helpers): # this is fine and `p` is just ignored - Compiler( - f, - {"x": "encrypted", "y": "encrypted", "z": "clear", "p": "clear"}, - configuration=configuration, - ) + Compiler(f, {"x": "encrypted", "y": "encrypted", "z": "clear", "p": "clear"}) -def test_compiler_bad_call(helpers): +def test_compiler_bad_call(): """ Test `__call__` method of `Compiler` class with bad parameters. """ - configuration = helpers.configuration() - def f(x, y, z): return x + y + z with pytest.raises(RuntimeError) as excinfo: - compiler = Compiler( - f, - {"x": "encrypted", "y": "encrypted", "z": "clear"}, - configuration=configuration, - ) + compiler = Compiler(f, {"x": "encrypted", "y": "encrypted", "z": "clear"}) compiler(1, 2, 3, invalid=4) assert str(excinfo.value) == "Calling function 'f' with kwargs is not supported" @@ -94,9 +82,8 @@ def test_compiler_bad_trace(helpers): compiler = Compiler( f, {"x": "encrypted", "y": "encrypted", "z": "clear"}, - configuration=configuration, ) - compiler.trace() + compiler.trace(configuration=configuration) assert str(excinfo.value) == "Tracing function 'f' without an inputset is not supported" @@ -115,17 +102,18 @@ def test_compiler_bad_compile(helpers): compiler = Compiler( f, {"x": "encrypted", "y": "encrypted", "z": "clear"}, - configuration=configuration, ) - compiler.compile() + compiler.compile(configuration=configuration) assert str(excinfo.value) == "Compiling function 'f' without an inputset is not supported" - configuration.enable_unsafe_features = False - with pytest.raises(RuntimeError) as excinfo: - compiler = Compiler(lambda x: x, {"x": "encrypted"}, configuration=configuration) - compiler.compile(virtual=True) + compiler = Compiler(lambda x: x, {"x": "encrypted"}) + compiler.compile( + range(10), + configuration.fork(enable_unsafe_features=False, use_insecure_key_cache=False), + virtual=True, + ) assert str(excinfo.value) == ( "Virtual compilation is not allowed without enabling unsafe features" @@ -142,7 +130,7 @@ def test_compiler_virtual_compile(helpers): def f(x): return x + 400 - compiler = Compiler(f, {"x": "encrypted"}, configuration=configuration) - circuit = compiler.compile(inputset=range(400), virtual=True) + compiler = Compiler(f, {"x": "encrypted"}) + circuit = compiler.compile(inputset=range(400), configuration=configuration, virtual=True) assert circuit.encrypt_run_decrypt(200) == 600 diff --git a/tests/compilation/test_decorator.py b/tests/compilation/test_decorator.py index cd76ae44c..dae690a85 100644 --- a/tests/compilation/test_decorator.py +++ b/tests/compilation/test_decorator.py @@ -12,14 +12,14 @@ def test_call_compile(helpers): configuration = helpers.configuration() - @compiler({"x": "encrypted"}, configuration=configuration) + @compiler({"x": "encrypted"}) def function(x): return x + 42 for i in range(10): function(i) - circuit = function.compile() + circuit = function.compile(configuration=configuration) sample = 5 helpers.check_execution(circuit, function, sample) @@ -33,12 +33,12 @@ def test_compiler_verbose_trace(helpers, capsys): configuration = helpers.configuration() artifacts = DebugArtifacts() - @compiler({"x": "encrypted"}, configuration=configuration, artifacts=artifacts) + @compiler({"x": "encrypted"}) def function(x): return x + 42 inputset = range(10) - function.trace(inputset, show_graph=True) + function.trace(inputset, configuration, artifacts, show_graph=True) captured = capsys.readouterr() assert captured.out.strip() == ( @@ -61,12 +61,12 @@ def test_compiler_verbose_compile(helpers, capsys): configuration = helpers.configuration() artifacts = DebugArtifacts() - @compiler({"x": "encrypted"}, configuration=configuration, artifacts=artifacts) + @compiler({"x": "encrypted"}) def function(x): return x + 42 inputset = range(10) - function.compile(inputset, show_graph=True, show_mlir=True) + function.compile(inputset, configuration, artifacts, show_graph=True, show_mlir=True) captured = capsys.readouterr() assert captured.out.strip() == ( diff --git a/tests/execution/test_add.py b/tests/execution/test_add.py index 7a0abef00..f25379fe4 100644 --- a/tests/execution/test_add.py +++ b/tests/execution/test_add.py @@ -59,10 +59,10 @@ def test_constant_add(function, parameters, helpers): parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) configuration = helpers.configuration() - compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration) + compiler = cnp.Compiler(function, parameter_encryption_statuses) inputset = helpers.generate_inputset(parameters) - circuit = compiler.compile(inputset) + circuit = compiler.compile(inputset, configuration) sample = helpers.generate_sample(parameters) helpers.check_execution(circuit, function, sample) @@ -150,10 +150,10 @@ def test_add(function, parameters, helpers): parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) configuration = helpers.configuration() - compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration) + compiler = cnp.Compiler(function, parameter_encryption_statuses) inputset = helpers.generate_inputset(parameters) - circuit = compiler.compile(inputset) + circuit = compiler.compile(inputset, configuration) sample = helpers.generate_sample(parameters) helpers.check_execution(circuit, function, sample) diff --git a/tests/execution/test_concat.py b/tests/execution/test_concat.py index 6a42f33fa..0eab4ce45 100644 --- a/tests/execution/test_concat.py +++ b/tests/execution/test_concat.py @@ -167,10 +167,10 @@ def test_concatenate(function, parameters, helpers): parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) configuration = helpers.configuration() - compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration) + compiler = cnp.Compiler(function, parameter_encryption_statuses) inputset = helpers.generate_inputset(parameters) - circuit = compiler.compile(inputset) + circuit = compiler.compile(inputset, configuration) sample = helpers.generate_sample(parameters) helpers.check_execution(circuit, function, sample) diff --git a/tests/execution/test_convolution.py b/tests/execution/test_convolution.py index e1a186d3d..2364c920a 100644 --- a/tests/execution/test_convolution.py +++ b/tests/execution/test_convolution.py @@ -57,12 +57,12 @@ def test_conv2d(input_shape, weight_shape, strides, dilations, has_bias, helpers else: bias = None - @cnp.compiler({"x": "encrypted"}, configuration=configuration) + @cnp.compiler({"x": "encrypted"}) def function(x): return connx.conv(x, weight, bias, strides=strides, dilations=dilations) inputset = [np.random.randint(0, 4, size=input_shape) for i in range(100)] - circuit = function.compile(inputset) + circuit = function.compile(inputset, configuration) sample = np.random.randint(0, 4, size=input_shape, dtype=np.uint8) helpers.check_execution(circuit, function, sample) @@ -373,7 +373,7 @@ def test_bad_conv_compilation( else: bias = None - @cnp.compiler({"x": "encrypted"}, configuration=configuration) + @cnp.compiler({"x": "encrypted"}) def function(x): return connx.conv( x, @@ -389,7 +389,7 @@ def test_bad_conv_compilation( inputset = [np.random.randint(0, 4, size=input_shape) for i in range(100)] with pytest.raises(expected_error) as excinfo: - function.compile(inputset) + function.compile(inputset, configuration) assert str(excinfo.value) == expected_message diff --git a/tests/execution/test_direct_table_lookup.py b/tests/execution/test_direct_table_lookup.py index 60f866446..11d0edda1 100644 --- a/tests/execution/test_direct_table_lookup.py +++ b/tests/execution/test_direct_table_lookup.py @@ -166,10 +166,10 @@ def test_direct_table_lookup(bits, function, helpers): # scalar # ------ - compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration) + compiler = cnp.Compiler(function, {"x": "encrypted"}) inputset = range(2 ** bits) - circuit = compiler.compile(inputset) + circuit = compiler.compile(inputset, configuration) sample = int(np.random.randint(0, 2 ** bits)) helpers.check_execution(circuit, function, sample, retries=10) @@ -177,10 +177,10 @@ def test_direct_table_lookup(bits, function, helpers): # tensor # ------ - compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration) + compiler = cnp.Compiler(function, {"x": "encrypted"}) inputset = [np.random.randint(0, 2 ** bits, size=(3, 2), dtype=np.uint8) for _ in range(100)] - circuit = compiler.compile(inputset) + circuit = compiler.compile(inputset, configuration) sample = np.random.randint(0, 2 ** bits, size=(3, 2), dtype=np.uint8) helpers.check_execution(circuit, function, sample, retries=10) @@ -207,10 +207,10 @@ def test_direct_multi_table_lookup(helpers): def function(x): return table[x] - compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration) + compiler = cnp.Compiler(function, {"x": "encrypted"}) inputset = [np.random.randint(0, 2 ** 2, size=(3, 2), dtype=np.uint8) for _ in range(100)] - circuit = compiler.compile(inputset) + circuit = compiler.compile(inputset, configuration) sample = np.random.randint(0, 2 ** 2, size=(3, 2), dtype=np.uint8) helpers.check_execution(circuit, function, sample, retries=10) @@ -277,22 +277,22 @@ def test_bad_direct_table_lookup(helpers): # compilation with float value # ---------------------------- - compiler = cnp.Compiler(random_table_lookup_3b, {"x": "encrypted"}, configuration) + compiler = cnp.Compiler(random_table_lookup_3b, {"x": "encrypted"}) inputset = [1.5] with pytest.raises(ValueError) as excinfo: - compiler.compile(inputset) + compiler.compile(inputset, configuration) assert str(excinfo.value) == "LookupTable cannot be looked up with EncryptedScalar" # compilation with invalid shape # ------------------------------ - compiler = cnp.Compiler(lambda x: table[x], {"x": "encrypted"}, configuration) + compiler = cnp.Compiler(lambda x: table[x], {"x": "encrypted"}) inputset = [10, 5, 6, 2] with pytest.raises(ValueError) as excinfo: - compiler.compile(inputset) + compiler.compile(inputset, configuration) assert str(excinfo.value) == ( "LookupTable of shape (3, 2) cannot be looked up with EncryptedScalar" diff --git a/tests/execution/test_dot.py b/tests/execution/test_dot.py index d9c1b05c1..0dfb8caf8 100644 --- a/tests/execution/test_dot.py +++ b/tests/execution/test_dot.py @@ -22,23 +22,23 @@ def test_dot(size, helpers): bound = int(np.floor(np.sqrt(127 / size))) cst = np.random.randint(0, bound, size=(size,)) - @cnp.compiler({"x": "encrypted"}, configuration=configuration) + @cnp.compiler({"x": "encrypted"}) def left_function(x): return np.dot(x, cst) - @cnp.compiler({"x": "encrypted"}, configuration=configuration) + @cnp.compiler({"x": "encrypted"}) def right_function(x): return np.dot(cst, x) - @cnp.compiler({"x": "encrypted"}, configuration=configuration) + @cnp.compiler({"x": "encrypted"}) def method(x): return x.dot(cst) inputset = [np.random.randint(0, bound, size=(size,)) for i in range(100)] - left_function_circuit = left_function.compile(inputset) - right_function_circuit = right_function.compile(inputset) - method_circuit = method.compile(inputset) + left_function_circuit = left_function.compile(inputset, configuration) + right_function_circuit = right_function.compile(inputset, configuration) + method_circuit = method.compile(inputset, configuration) sample = np.random.randint(0, bound, size=(size,), dtype=np.uint8) diff --git a/tests/execution/test_matmul.py b/tests/execution/test_matmul.py index b035a061b..c921254f1 100644 --- a/tests/execution/test_matmul.py +++ b/tests/execution/test_matmul.py @@ -125,29 +125,29 @@ def test_matmul(lhs_shape, rhs_shape, bounds, helpers): lhs_cst = list(np.random.randint(minimum, maximum, size=lhs_shape)) rhs_cst = list(np.random.randint(minimum, maximum, size=rhs_shape)) - @cnp.compiler({"x": "encrypted"}, configuration=configuration) + @cnp.compiler({"x": "encrypted"}) def lhs_operator(x): return x @ rhs_cst - @cnp.compiler({"x": "encrypted"}, configuration=configuration) + @cnp.compiler({"x": "encrypted"}) def rhs_operator(x): return lhs_cst @ x - @cnp.compiler({"x": "encrypted"}, configuration=configuration) + @cnp.compiler({"x": "encrypted"}) def lhs_function(x): return np.matmul(x, rhs_cst) - @cnp.compiler({"x": "encrypted"}, configuration=configuration) + @cnp.compiler({"x": "encrypted"}) def rhs_function(x): return np.matmul(lhs_cst, x) lhs_inputset = [np.random.randint(minimum, maximum, size=lhs_shape) for i in range(100)] rhs_inputset = [np.random.randint(minimum, maximum, size=rhs_shape) for i in range(100)] - lhs_operator_circuit = lhs_operator.compile(lhs_inputset) - rhs_operator_circuit = rhs_operator.compile(rhs_inputset) - lhs_function_circuit = lhs_function.compile(lhs_inputset) - rhs_function_circuit = rhs_function.compile(rhs_inputset) + lhs_operator_circuit = lhs_operator.compile(lhs_inputset, configuration) + rhs_operator_circuit = rhs_operator.compile(rhs_inputset, configuration) + lhs_function_circuit = lhs_function.compile(lhs_inputset, configuration) + rhs_function_circuit = rhs_function.compile(rhs_inputset, configuration) lhs_sample = np.random.randint(minimum, maximum, size=lhs_shape, dtype=np.uint8) rhs_sample = np.random.randint(minimum, maximum, size=rhs_shape, dtype=np.uint8) diff --git a/tests/execution/test_mul.py b/tests/execution/test_mul.py index 77af5fd33..925236967 100644 --- a/tests/execution/test_mul.py +++ b/tests/execution/test_mul.py @@ -67,10 +67,10 @@ def test_constant_mul(function, parameters, helpers): parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) configuration = helpers.configuration() - compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration) + compiler = cnp.Compiler(function, parameter_encryption_statuses) inputset = helpers.generate_inputset(parameters) - circuit = compiler.compile(inputset) + circuit = compiler.compile(inputset, configuration) sample = helpers.generate_sample(parameters) helpers.check_execution(circuit, function, sample) diff --git a/tests/execution/test_neg.py b/tests/execution/test_neg.py index 81af6fa3b..0bf0d5a8e 100644 --- a/tests/execution/test_neg.py +++ b/tests/execution/test_neg.py @@ -27,18 +27,18 @@ def test_neg(parameters, helpers): parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) configuration = helpers.configuration() - @cnp.compiler(parameter_encryption_statuses, configuration=configuration) + @cnp.compiler(parameter_encryption_statuses) def operator(x): return -x - @cnp.compiler(parameter_encryption_statuses, configuration=configuration) + @cnp.compiler(parameter_encryption_statuses) def function(x): return np.negative(x) inputset = helpers.generate_inputset(parameters) - operator_circuit = operator.compile(inputset) - function_circuit = function.compile(inputset) + operator_circuit = operator.compile(inputset, configuration) + function_circuit = function.compile(inputset, configuration) sample = helpers.generate_sample(parameters) diff --git a/tests/execution/test_others.py b/tests/execution/test_others.py index 726525d4c..a31223bd0 100644 --- a/tests/execution/test_others.py +++ b/tests/execution/test_others.py @@ -447,10 +447,10 @@ def test_others(function, parameters, helpers): parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) configuration = helpers.configuration() - compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration) + compiler = cnp.Compiler(function, parameter_encryption_statuses) inputset = helpers.generate_inputset(parameters) - circuit = compiler.compile(inputset) + circuit = compiler.compile(inputset, configuration) sample = helpers.generate_sample(parameters) helpers.check_execution(circuit, function, sample, retries=10) @@ -464,10 +464,10 @@ def test_others(function, parameters, helpers): parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) configuration = helpers.configuration() - compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration) + compiler = cnp.Compiler(function, parameter_encryption_statuses) inputset = helpers.generate_inputset(parameters) - circuit = compiler.compile(inputset) + circuit = compiler.compile(inputset, configuration) sample = helpers.generate_sample(parameters) helpers.check_execution(circuit, function, sample, retries=10) @@ -483,13 +483,13 @@ def test_others_bad_fusing(helpers): # two variable inputs # ------------------- - @cnp.compiler({"x": "encrypted", "y": "clear"}, configuration=configuration) + @cnp.compiler({"x": "encrypted", "y": "clear"}) def function1(x, y): return (10 * (np.sin(x) ** 2) + 10 * (np.cos(y) ** 2)).astype(np.int64) with pytest.raises(RuntimeError) as excinfo: inputset = [(i, i) for i in range(100)] - function1.compile(inputset) + function1.compile(inputset, configuration) helpers.check_str( # pylint: disable=line-too-long @@ -528,13 +528,13 @@ return %13 # big intermediate constants # -------------------------- - @cnp.compiler({"x": "encrypted"}, configuration=configuration) + @cnp.compiler({"x": "encrypted"}) def function2(x): return (np.sin(x) * [[1, 2], [3, 4]]).astype(np.int64) with pytest.raises(RuntimeError) as excinfo: inputset = range(100) - function2.compile(inputset) + function2.compile(inputset, configuration) helpers.check_str( # pylint: disable=line-too-long @@ -559,13 +559,13 @@ return %4 # intermediates with different shape # ---------------------------------- - @cnp.compiler({"x": "encrypted"}, configuration=configuration) + @cnp.compiler({"x": "encrypted"}) def function3(x): return np.abs(np.sin(x)).reshape((2, 3)).astype(np.int64) with pytest.raises(RuntimeError) as excinfo: inputset = [np.random.randint(0, 2 ** 7, size=(3, 2)) for _ in range(100)] - function3.compile(inputset) + function3.compile(inputset, configuration) helpers.check_str( # pylint: disable=line-too-long diff --git a/tests/execution/test_reshape.py b/tests/execution/test_reshape.py index ac4b176d6..41ebdeb32 100644 --- a/tests/execution/test_reshape.py +++ b/tests/execution/test_reshape.py @@ -116,18 +116,18 @@ def test_reshape(shape, newshape, helpers): configuration = helpers.configuration() - @cnp.compiler({"x": "encrypted"}, configuration=configuration) + @cnp.compiler({"x": "encrypted"}) def function(x): return np.reshape(x, newshape) - @cnp.compiler({"x": "encrypted"}, configuration=configuration) + @cnp.compiler({"x": "encrypted"}) def method(x): return x.reshape(newshape) inputset = [np.random.randint(0, 2 ** 5, size=shape) for i in range(100)] - function_circuit = function.compile(inputset) - method_circuit = method.compile(inputset) + function_circuit = function.compile(inputset, configuration) + method_circuit = method.compile(inputset, configuration) sample = np.random.randint(0, 2 ** 5, size=shape, dtype=np.uint8) @@ -159,12 +159,12 @@ def test_flatten(shape, helpers): configuration = helpers.configuration() - @cnp.compiler({"x": "encrypted"}, configuration=configuration) + @cnp.compiler({"x": "encrypted"}) def function(x): return x.flatten() inputset = [np.random.randint(0, 2 ** 5, size=shape) for i in range(100)] - circuit = function.compile(inputset) + circuit = function.compile(inputset, configuration) sample = np.random.randint(0, 2 ** 5, size=shape, dtype=np.uint8) helpers.check_execution(circuit, function, sample) diff --git a/tests/execution/test_static_indexing.py b/tests/execution/test_static_indexing.py index 032251bc6..fd24360c9 100644 --- a/tests/execution/test_static_indexing.py +++ b/tests/execution/test_static_indexing.py @@ -154,10 +154,10 @@ def test_static_indexing(shape, function, helpers): """ configuration = helpers.configuration() - compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration) + compiler = cnp.Compiler(function, {"x": "encrypted"}) inputset = [np.random.randint(0, 2 ** 5, size=shape) for _ in range(100)] - circuit = compiler.compile(inputset) + circuit = compiler.compile(inputset, configuration) sample = np.random.randint(0, 2 ** 5, size=shape, dtype=np.uint8) helpers.check_execution(circuit, function, sample) @@ -173,21 +173,21 @@ def test_bad_static_indexing(helpers): # with float # ---------- - compiler = cnp.Compiler(lambda x: x[1.5], {"x": "encrypted"}, configuration) + compiler = cnp.Compiler(lambda x: x[1.5], {"x": "encrypted"}) inputset = [np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(100)] with pytest.raises(ValueError) as excinfo: - compiler.compile(inputset) + compiler.compile(inputset, configuration) assert str(excinfo.value) == "Indexing with '1.5' is not supported" # with bad slice # -------------- - compiler = cnp.Compiler(lambda x: x[slice(1.5, 2.5, None)], {"x": "encrypted"}, configuration) + compiler = cnp.Compiler(lambda x: x[slice(1.5, 2.5, None)], {"x": "encrypted"}) inputset = [np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(100)] with pytest.raises(ValueError) as excinfo: - compiler.compile(inputset) + compiler.compile(inputset, configuration) assert str(excinfo.value) == "Indexing with '1.5:2.5' is not supported" diff --git a/tests/execution/test_sub.py b/tests/execution/test_sub.py index d94e21bb1..e132153bd 100644 --- a/tests/execution/test_sub.py +++ b/tests/execution/test_sub.py @@ -47,10 +47,10 @@ def test_constant_sub(function, parameters, helpers): parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) configuration = helpers.configuration() - compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration) + compiler = cnp.Compiler(function, parameter_encryption_statuses) inputset = helpers.generate_inputset(parameters) - circuit = compiler.compile(inputset) + circuit = compiler.compile(inputset, configuration) sample = helpers.generate_sample(parameters) helpers.check_execution(circuit, function, sample) diff --git a/tests/execution/test_sum.py b/tests/execution/test_sum.py index 2a9747b8d..ad2322a8f 100644 --- a/tests/execution/test_sum.py +++ b/tests/execution/test_sum.py @@ -105,10 +105,10 @@ def test_sum(function, parameters, helpers): parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) configuration = helpers.configuration() - compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration) + compiler = cnp.Compiler(function, parameter_encryption_statuses) inputset = helpers.generate_inputset(parameters) - circuit = compiler.compile(inputset) + circuit = compiler.compile(inputset, configuration) sample = helpers.generate_sample(parameters) helpers.check_execution(circuit, function, sample) diff --git a/tests/execution/test_transpose.py b/tests/execution/test_transpose.py index d39ed9d17..4402ddae3 100644 --- a/tests/execution/test_transpose.py +++ b/tests/execution/test_transpose.py @@ -39,10 +39,10 @@ def test_transpose(function, parameters, helpers): parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) configuration = helpers.configuration() - compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration) + compiler = cnp.Compiler(function, parameter_encryption_statuses) inputset = helpers.generate_inputset(parameters) - circuit = compiler.compile(inputset) + circuit = compiler.compile(inputset, configuration) sample = helpers.generate_sample(parameters) helpers.check_execution(circuit, function, sample) diff --git a/tests/mlir/test_graph_converter.py b/tests/mlir/test_graph_converter.py index 6fcabc9ec..9b99441d7 100644 --- a/tests/mlir/test_graph_converter.py +++ b/tests/mlir/test_graph_converter.py @@ -389,10 +389,10 @@ def test_graph_converter_bad_convert( """ configuration = helpers.configuration() - compiler = cnp.Compiler(function, encryption_statuses, configuration) + compiler = cnp.Compiler(function, encryption_statuses) with pytest.raises(expected_error) as excinfo: - compiler.compile(inputset) + compiler.compile(inputset, configuration) helpers.check_str(expected_message, str(excinfo.value)) diff --git a/tests/representation/test_graph.py b/tests/representation/test_graph.py index b8e80ead0..79aaf85e4 100644 --- a/tests/representation/test_graph.py +++ b/tests/representation/test_graph.py @@ -39,8 +39,8 @@ def test_graph_maximum_integer_bit_width(function, inputset, expected_result, he configuration = helpers.configuration() - compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration=configuration) - graph = compiler.trace(inputset) + compiler = cnp.Compiler(function, {"x": "encrypted"}) + graph = compiler.trace(inputset, configuration) print(graph.format())