diff --git a/concrete/numpy/compilation/circuit.py b/concrete/numpy/compilation/circuit.py index 685ecfc26..83f48df3e 100644 --- a/concrete/numpy/compilation/circuit.py +++ b/concrete/numpy/compilation/circuit.py @@ -2,6 +2,7 @@ Declaration of `Circuit` class. """ +from copy import deepcopy from typing import Any, Optional, Tuple, Union, cast import numpy as np @@ -9,6 +10,7 @@ from concrete.compiler import PublicArguments, PublicResult from ..dtypes import Integer from ..internal.utils import assert_that +from ..mlir import GraphConverter from ..representation import Graph from .client import Client from .configuration import Configuration @@ -35,24 +37,26 @@ class Circuit: self.mlir = mlir if self.configuration.virtual: - assert_that(self.configuration.enable_unsafe_features) return + self._initialize_client_and_server() + + def _initialize_client_and_server(self): input_signs = [] - for i in range(len(graph.input_nodes)): # pylint: disable=consider-using-enumerate - input_value = graph.input_nodes[i].output + for i in range(len(self.graph.input_nodes)): # pylint: disable=consider-using-enumerate + input_value = self.graph.input_nodes[i].output assert_that(isinstance(input_value.dtype, Integer)) input_dtype = cast(Integer, input_value.dtype) input_signs.append(input_dtype.is_signed) output_signs = [] - for i in range(len(graph.output_nodes)): # pylint: disable=consider-using-enumerate - output_value = graph.output_nodes[i].output + for i in range(len(self.graph.output_nodes)): # pylint: disable=consider-using-enumerate + output_value = self.graph.output_nodes[i].output assert_that(isinstance(output_value.dtype, Integer)) output_dtype = cast(Integer, output_value.dtype) output_signs.append(output_dtype.is_signed) - self.server = Server.create(mlir, input_signs, output_signs, self.configuration) + self.server = Server.create(self.mlir, input_signs, output_signs, self.configuration) keyset_cache_directory = None if self.configuration.use_insecure_key_cache: @@ -65,6 +69,49 @@ class Circuit: def __str__(self): return self.graph.format() + def simulate(self, *args: Any) -> Any: + """ + Simulate execution of the circuit. + + Args: + *args (Any): + inputs to the circuit + + Returns: + Any: + result of the simulation + """ + + p_error = self.p_error if not self.configuration.virtual else self.configuration.p_error + return self.graph(*args, p_error=p_error) + + def enable_fhe(self): + """ + Enable fully homomorphic encryption features. + + When called on a virtual circuit, it'll enable access to the following methods: + - encrypt + - run + - decrypt + - encrypt_run_decrypt + + When called on a normal circuit, it'll do nothing. + + Raises: + RuntimeError: + if the circuit is not supported in fhe + """ + + if not self.configuration.virtual: + return + + new_configuration = deepcopy(self.configuration) + new_configuration.virtual = False + self.configuration = new_configuration + + self.mlir = GraphConverter.convert(self.graph) + self._initialize_client_and_server() + def keygen(self, force: bool = False): """ Generate keys required for homomorphic evaluation. @@ -154,9 +201,6 @@ class Circuit: clear result of homomorphic evaluation """ - if self.configuration.virtual: - return self.graph(*args, p_error=self.configuration.p_error) - return self.decrypt(self.run(self.encrypt(*args))) def cleanup(self): diff --git a/concrete/numpy/compilation/configuration.py b/concrete/numpy/compilation/configuration.py index 6b9fa4182..7e10017ca 100644 --- a/concrete/numpy/compilation/configuration.py +++ b/concrete/numpy/compilation/configuration.py @@ -47,10 +47,6 @@ class Configuration: message = "Insecure key cache cannot be used without enabling unsafe features" raise RuntimeError(message) - if self.virtual: - message = "Virtual compilation is not allowed without enabling unsafe features" - raise RuntimeError(message) - if self.use_insecure_key_cache and self.insecure_key_cache_location is None: message = "Insecure key cache cannot be enabled without specifying its location" raise RuntimeError(message) diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index 641e406e3..c3371955e 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -321,7 +321,7 @@ def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers): circuit = function.compile(inputset, configuration=configuration, virtual=True, p_error=p_error) sample = np.random.randint(0, 2**bit_width, size=(sample_size,)) - output = circuit.encrypt_run_decrypt(sample) + output = circuit.simulate(sample) errors = 0 for i in range(sample_size): @@ -358,3 +358,30 @@ def test_circuit_run_with_unused_arg(helpers): assert circuit.encrypt_run_decrypt(10, 0) == 20 assert circuit.encrypt_run_decrypt(10, 10) == 20 assert circuit.encrypt_run_decrypt(10, 20) == 20 + + +def test_circuit_virtual_then_fhe(helpers): + """ + Test compiling to virtual and then fhe. + """ + + configuration = helpers.configuration() + + @compiler({"x": "encrypted", "y": "encrypted"}) + def f(x, y): + return x + y + + inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(100)] + circuit = f.compile(inputset, configuration, virtual=True) + + assert circuit.simulate(3, 5) == 8 + + circuit.enable_fhe() + + assert circuit.simulate(3, 5) == 8 + assert circuit.encrypt_run_decrypt(3, 5) == 8 + + circuit.enable_fhe() + + assert circuit.simulate(3, 5) == 8 + assert circuit.encrypt_run_decrypt(3, 5) == 8 diff --git a/tests/compilation/test_compiler.py b/tests/compilation/test_compiler.py index f8ebbfec5..4e205755e 100644 --- a/tests/compilation/test_compiler.py +++ b/tests/compilation/test_compiler.py @@ -209,21 +209,6 @@ def test_compiler_bad_compile(helpers): "(expected a tuple of 3 values got a tuple of 2 values)" ) - # with bad configuration - # ---------------------- - - with pytest.raises(RuntimeError) as excinfo: - compiler = Compiler(lambda x: x, {"x": "encrypted"}) - compiler.compile( - range(10), - configuration.fork(enable_unsafe_features=False, use_insecure_key_cache=False), - virtual=True, - ) - - assert str(excinfo.value) == ( - "Virtual compilation is not allowed without enabling unsafe features" - ) - def test_compiler_virtual_compile(helpers): """ @@ -240,7 +225,7 @@ def test_compiler_virtual_compile(helpers): inputset = [(100_000, 1_000_000)] circuit = compiler.compile(inputset, configuration=configuration, virtual=True) - assert circuit.encrypt_run_decrypt(100_000, 1_000_000) == 100_000_000_000 + assert circuit.simulate(100_000, 1_000_000) == 100_000_000_000 def test_compiler_compile_bad_inputset(helpers): diff --git a/tests/conftest.py b/tests/conftest.py index 37456403c..baaa0d65f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -261,7 +261,11 @@ class Helpers: for i in range(retries): expected = sanitize(function(*sample)) - actual = sanitize(circuit.encrypt_run_decrypt(*sample)) + actual = sanitize( + circuit.simulate(*sample) + if circuit.configuration.virtual + else circuit.encrypt_run_decrypt(*sample) + ) if all(np.array_equal(e, a) for e, a in zip(expected, actual)): break