From ee5fc138abf9fff06f18912df40340f5f61ebcbe Mon Sep 17 00:00:00 2001 From: Umut Date: Fri, 30 Jun 2023 14:29:21 +0200 Subject: [PATCH] feat(frontend-python): lazily enable simulation and execution when needed --- .../concrete/fhe/compilation/circuit.py | 88 +++++++++++++++---- .../tests/compilation/test_circuit.py | 17 +--- 2 files changed, 75 insertions(+), 30 deletions(-) diff --git a/frontends/concrete-python/concrete/fhe/compilation/circuit.py b/frontends/concrete-python/concrete/fhe/compilation/circuit.py index c5dda446a..38ff63bd8 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/circuit.py +++ b/frontends/concrete-python/concrete/fhe/compilation/circuit.py @@ -41,25 +41,28 @@ class Circuit: self.graph = graph self.mlir = mlir - self._initialize_circuit() - - def _initialize_circuit(self): - if self.configuration.fhe_execution: - self.enable_fhe_execution() - if self.configuration.fhe_simulation: self.enable_fhe_simulation() + if self.configuration.fhe_execution: + self.enable_fhe_execution() + def __str__(self): return self.graph.format() def enable_fhe_simulation(self): - """Enable fhe simulation mode.""" + """ + Enable FHE simulation. + """ + if not hasattr(self, "simulator"): self.simulator = Server.create(self.mlir, self.configuration, is_simulated=True) def enable_fhe_execution(self): - """Enable fhe execution mode.""" + """ + Enable FHE execution. + """ + if not hasattr(self, "server"): self.server = Server.create(self.mlir, self.configuration) @@ -83,9 +86,9 @@ class Circuit: Any: result of the simulation """ - if not hasattr(self, "simulator"): - message = "Simulation isn't enabled. You can call enable_fhe_simulation() to enable it" - raise RuntimeError(message) + + if not hasattr(self, "simulator"): # pragma: no cover + self.enable_fhe_simulation() ordered_validated_args = validate_input_args(self.simulator.client_specs, *args) @@ -100,13 +103,16 @@ class Circuit: ) for position, arg in enumerate(ordered_validated_args) ] + results = self.simulator.run(*exported) if not isinstance(results, tuple): results = (results,) + decrypter = SimulatedValueDecrypter.new(self.simulator.client_specs.client_parameters) decrypted = tuple( decrypter.decrypt(position, result.inner) for position, result in enumerate(results) ) + return decrypted if len(decrypted) != 1 else decrypted[0] @property @@ -114,6 +120,10 @@ class Circuit: """ Get the keys of the circuit. """ + + if not hasattr(self, "client"): # pragma: no cover + self.enable_fhe_execution() + return self.client.keys @keys.setter @@ -121,6 +131,10 @@ class Circuit: """ Set the keys of the circuit. """ + + if not hasattr(self, "client"): # pragma: no cover + self.enable_fhe_execution() + self.client.keys = new_keys def keygen(self, force: bool = False, seed: Optional[int] = None): @@ -135,6 +149,9 @@ class Circuit: seed for randomness """ + if not hasattr(self, "client"): # pragma: no cover + self.enable_fhe_execution() + self.client.keygen(force, seed) def encrypt( @@ -153,6 +170,9 @@ class Circuit: encrypted argument(s) for evaluation """ + if not hasattr(self, "client"): # pragma: no cover + self.enable_fhe_execution() + return self.client.encrypt(*args) def run( @@ -170,11 +190,9 @@ class Circuit: Union[Value, Tuple[Value, ...]]: result(s) of evaluation """ - if not hasattr(self, "server"): - message = ( - "FHE execution isn't enabled. You can call enable_fhe_execution() to enable it" - ) - raise RuntimeError(message) + + if not hasattr(self, "server"): # pragma: no cover + self.enable_fhe_execution() self.keygen(force=False) return self.server.run(*args, evaluation_keys=self.client.evaluation_keys) @@ -195,6 +213,9 @@ class Circuit: decrypted result(s) of evaluation """ + if not hasattr(self, "client"): # pragma: no cover + self.enable_fhe_execution() + return self.client.decrypt(*results) def encrypt_run_decrypt(self, *args: Any) -> Any: @@ -217,13 +238,18 @@ class Circuit: Cleanup the temporary library output directory. """ - self.server.cleanup() + if hasattr(self, "server"): # pragma: no cover + self.server.cleanup() @property def complexity(self) -> float: """ Get complexity of the circuit. """ + + if not hasattr(self, "server"): # pragma: no cover + self.enable_fhe_execution() + return self.server.complexity @property @@ -231,6 +257,10 @@ class Circuit: """ Get size of the secret keys of the circuit. """ + + if not hasattr(self, "server"): # pragma: no cover + self.enable_fhe_execution() + return self.server.size_of_secret_keys @property @@ -238,6 +268,10 @@ class Circuit: """ Get size of the bootstrap keys of the circuit. """ + + if not hasattr(self, "server"): # pragma: no cover + self.enable_fhe_execution() + return self.server.size_of_bootstrap_keys @property @@ -245,6 +279,10 @@ class Circuit: """ Get size of the key switch keys of the circuit. """ + + if not hasattr(self, "server"): # pragma: no cover + self.enable_fhe_execution() + return self.server.size_of_keyswitch_keys @property @@ -252,6 +290,10 @@ class Circuit: """ Get size of the inputs of the circuit. """ + + if not hasattr(self, "server"): # pragma: no cover + self.enable_fhe_execution() + return self.server.size_of_inputs @property @@ -259,6 +301,10 @@ class Circuit: """ Get size of the outputs of the circuit. """ + + if not hasattr(self, "server"): # pragma: no cover + self.enable_fhe_execution() + return self.server.size_of_outputs @property @@ -266,6 +312,10 @@ class Circuit: """ Get probability of error for each simple TLU (on a scalar). """ + + if not hasattr(self, "server"): # pragma: no cover + self.enable_fhe_execution() + return self.server.p_error @property @@ -273,4 +323,8 @@ class Circuit: """ Get the probability of having at least one simple TLU error during the entire execution. """ + + if not hasattr(self, "server"): # pragma: no cover + self.enable_fhe_execution() + return self.server.global_p_error diff --git a/frontends/concrete-python/tests/compilation/test_circuit.py b/frontends/concrete-python/tests/compilation/test_circuit.py index 419c1321b..34e96ac49 100644 --- a/frontends/concrete-python/tests/compilation/test_circuit.py +++ b/frontends/concrete-python/tests/compilation/test_circuit.py @@ -422,12 +422,8 @@ def test_circuit_sim_disabled(helpers): inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)] circuit = f.compile(inputset, configuration) - with pytest.raises(RuntimeError) as excinfo: - circuit.simulate(*inputset[0]) - assert ( - str(excinfo.value) - == "Simulation isn't enabled. You can call enable_fhe_simulation() to enable it" - ) + + assert circuit.simulate(*inputset[0]) == f(*inputset[0]) def test_circuit_fhe_exec_disabled(helpers): @@ -443,13 +439,8 @@ def test_circuit_fhe_exec_disabled(helpers): inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)] circuit = f.compile(inputset, configuration.fork(fhe_execution=False)) - with pytest.raises(RuntimeError) as excinfo: - # as we can't encrypt, we just pass plain inputs, and it should lead to the expected error - circuit.run(*inputset[0], None) - assert ( - str(excinfo.value) - == "FHE execution isn't enabled. You can call enable_fhe_execution() to enable it" - ) + + assert circuit.encrypt_run_decrypt(*inputset[0]) == f(*inputset[0]) def test_circuit_fhe_exec_no_eval_keys(helpers):