From 85cbd3802168a2287f8b2d79ac3817b5f7ed93a1 Mon Sep 17 00:00:00 2001 From: Umut Date: Fri, 22 Apr 2022 11:51:00 +0200 Subject: [PATCH] feat: add virtual option to compile to simulate fhe without compiling --- concrete/numpy/compilation/circuit.py | 41 +++++++++++++++++++++---- concrete/numpy/compilation/compiler.py | 12 +++++--- concrete/numpy/compilation/decorator.py | 6 +++- concrete/numpy/mlir/graph_converter.py | 8 ++++- tests/compilation/test_circuit.py | 35 +++++++++++++++++++++ tests/compilation/test_compiler.py | 22 +++++++++++++ 6 files changed, 112 insertions(+), 12 deletions(-) diff --git a/concrete/numpy/compilation/circuit.py b/concrete/numpy/compilation/circuit.py index 6bcedbf64..5683ec7db 100644 --- a/concrete/numpy/compilation/circuit.py +++ b/concrete/numpy/compilation/circuit.py @@ -3,7 +3,7 @@ Declaration of `Circuit` class. """ from pathlib import Path -from typing import List, Optional, Tuple, Union, cast +from typing import Any, List, Optional, Tuple, Union, cast import numpy as np from concrete.compiler import ( @@ -33,6 +33,7 @@ class Circuit: graph: Graph mlir: str + virtual: bool _jit_support: JITSupport _compilation_result: JITCompilationResult @@ -44,10 +45,26 @@ class Circuit: _server_lambda: JITLambda - def __init__(self, graph: Graph, mlir: str, configuration: CompilationConfiguration): + def __init__( + self, + graph: Graph, + mlir: str, + configuration: Optional[CompilationConfiguration] = None, + virtual: bool = False, + ): + configuration = configuration if configuration is not None else CompilationConfiguration() + self.graph = graph self.mlir = mlir + self.virtual = virtual + if self.virtual: + print( + "Warning: You are using virtual compilation, " + "which means the evaluation will not be homomorphic." + ) + return + options = CompilationOptions.new("main") options.set_loop_parallelize(configuration.loop_parallelize) @@ -111,6 +128,9 @@ class Circuit: whether to generate new keys even if keys are already generated """ + if self.virtual: + raise RuntimeError("Virtual circuits cannot use `keygen` method") + if self._keyset is None or force: self._keyset = ClientSupport.key_set(self._client_parameters, self._keyset_cache) @@ -127,6 +147,9 @@ class Circuit: encrypted and plain arguments as well as public keys """ + if self.virtual: + raise RuntimeError("Virtual circuits cannot use `encrypt` method") + if len(args) != len(self.graph.input_nodes): raise ValueError(f"Expected {len(self.graph.input_nodes)} inputs but got {len(args)}") @@ -186,6 +209,9 @@ class Circuit: encrypted result of homomorphic evaluaton """ + if self.virtual: + raise RuntimeError("Virtual circuits cannot use `run` method") + return self._jit_support.server_call(self._server_lambda, args) def decrypt( @@ -204,6 +230,9 @@ class Circuit: clear result of homomorphic evaluaton """ + if self.virtual: + raise RuntimeError("Virtual circuits cannot use `decrypt` method") + results = ClientSupport.decrypt_result(self._keyset, result) if not isinstance(results, tuple): results = (results,) @@ -233,10 +262,7 @@ class Circuit: return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results) - def encrypt_run_decrypt( - self, - *args: Union[int, np.ndarray], - ) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]: + def encrypt_run_decrypt(self, *args: Any) -> Any: """ Encrypt inputs, run the circuit, and decrypt the outputs in one go. @@ -249,4 +275,7 @@ class Circuit: clear result of homomorphic evaluation """ + if self.virtual: + return self.graph(*args) + return self.decrypt(self.run(self.encrypt(*args))) diff --git a/concrete/numpy/compilation/compiler.py b/concrete/numpy/compilation/compiler.py index a1ba47287..85df89d5b 100644 --- a/concrete/numpy/compilation/compiler.py +++ b/concrete/numpy/compilation/compiler.py @@ -218,7 +218,7 @@ class Compiler: columns = min(longest_line, 80) else: columns = min(longest_line, columns) - except OSError: + except OSError: # pragma: no cover columns = min(longest_line, 80) print() @@ -254,6 +254,7 @@ class Compiler: inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None, show_graph: bool = False, show_mlir: bool = False, + virtual: bool = False, ) -> Circuit: """ Compile the function using an inputset. @@ -268,6 +269,9 @@ class Compiler: show_mlir (bool, default = False): whether to print the compiled mlir + virtual (bool, default = False): + whether to simulate the computation to allow large bit-widths + Returns: Circuit: compiled circuit @@ -277,7 +281,7 @@ class Compiler: self._evaluate("Compiling", inputset) assert self.graph is not None - mlir = GraphConverter.convert(self.graph) + mlir = GraphConverter.convert(self.graph, virtual=virtual) self.artifacts.add_mlir_to_compile(mlir) if show_graph or show_mlir: @@ -298,7 +302,7 @@ class Compiler: columns = min(longest_line, 80) else: columns = min(longest_line, columns) - except OSError: + except OSError: # pragma: no cover columns = min(longest_line, 80) if show_graph: @@ -321,7 +325,7 @@ class Compiler: print() - return Circuit(self.graph, mlir, self.configuration) + return Circuit(self.graph, mlir, self.configuration, virtual=virtual) except Exception: # pragma: no cover diff --git a/concrete/numpy/compilation/decorator.py b/concrete/numpy/compilation/decorator.py index 0951406b0..da201d0b6 100644 --- a/concrete/numpy/compilation/decorator.py +++ b/concrete/numpy/compilation/decorator.py @@ -79,6 +79,7 @@ def compiler( inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None, show_graph: bool = False, show_mlir: bool = False, + virtual: bool = False, ) -> Circuit: """ Compile the function into a circuit. @@ -93,12 +94,15 @@ def compiler( show_mlir (bool, default = False): whether to print the compiled mlir + virtual (bool, default = False): + whether to simulate the computation to allow large bit-widths + Returns: Circuit: compiled circuit """ - return self.compiler.compile(inputset, show_graph, show_mlir) + return self.compiler.compile(inputset, show_graph, show_mlir, virtual) return Compilable(function) diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index 82d020b2c..e515147d9 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -277,7 +277,7 @@ class GraphConverter: nx_graph.add_edge(add_offset, node, input_idx=variable_input_index) @staticmethod - def convert(graph: Graph) -> str: + def convert(graph: Graph, virtual: bool = False) -> str: """ Convert a computation graph to its corresponding MLIR representation. @@ -285,6 +285,9 @@ class GraphConverter: graph (Graph): computation graph to be converted + virtual (bool, default = False): + whether to circuit will be virtual + Returns: str: textual MLIR representation corresponding to `graph` @@ -293,6 +296,9 @@ class GraphConverter: graph = deepcopy(graph) GraphConverter._check_graph_convertibility(graph) + if virtual: + return "Virtual circuits doesn't have MLIR." + GraphConverter._update_bit_widths(graph) GraphConverter._offset_negative_lookup_table_inputs(graph) diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index 6c36456e0..59396e9e2 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -129,3 +129,38 @@ def test_circuit_bad_run(helpers): assert str(excinfo.value) == ( "Expected argument 1 to be EncryptedScalar but it's EncryptedScalar" ) + + +def test_circuit_virtual_explicit_api(helpers): + """ + Test `keygen`, `encrypt`, `run`, and `decrypt` methods of `Circuit` class with virtual circuit. + """ + + configuration = helpers.configuration() + + @compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration) + def f(x, y): + return x + y + + inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)] + circuit = f.compile(inputset, virtual=True) + + with pytest.raises(RuntimeError) as excinfo: + circuit.keygen() + + assert str(excinfo.value) == "Virtual circuits cannot use `keygen` method" + + with pytest.raises(RuntimeError) as excinfo: + circuit.encrypt(1, 2) + + assert str(excinfo.value) == "Virtual circuits cannot use `encrypt` method" + + with pytest.raises(RuntimeError) as excinfo: + circuit.run(None) + + assert str(excinfo.value) == "Virtual circuits cannot use `run` method" + + with pytest.raises(RuntimeError) as excinfo: + circuit.decrypt(None) + + assert str(excinfo.value) == "Virtual circuits cannot use `decrypt` method" diff --git a/tests/compilation/test_compiler.py b/tests/compilation/test_compiler.py index 069793340..7c6d88ca9 100644 --- a/tests/compilation/test_compiler.py +++ b/tests/compilation/test_compiler.py @@ -120,3 +120,25 @@ def test_compiler_bad_compile(helpers): compiler.compile() assert str(excinfo.value) == "Compiling function 'f' without an inputset is not supported" + + +def test_compiler_virtual_compile(helpers, capsys): + """ + Test `compile` method of `Compiler` class with virtual=True. + """ + + configuration = helpers.configuration() + + def f(x): + return x + 400 + + compiler = Compiler(f, {"x": "encrypted"}, configuration=configuration) + circuit = compiler.compile(inputset=range(400), virtual=True) + + captured = capsys.readouterr() + assert captured.out.strip() == ( + "Warning: You are using virtual compilation, " + "which means the evaluation will not be homomorphic." + ) + + assert circuit.encrypt_run_decrypt(200) == 600