diff --git a/concrete/numpy/compilation/circuit.py b/concrete/numpy/compilation/circuit.py index 6adb526e1..c5c633326 100644 --- a/concrete/numpy/compilation/circuit.py +++ b/concrete/numpy/compilation/circuit.py @@ -151,7 +151,7 @@ class Circuit: """ if self.configuration.virtual: - return self.graph(*args) + return self.graph(*args, p_error=self.configuration.p_error) return self.decrypt(self.run(self.encrypt(*args))) diff --git a/concrete/numpy/representation/graph.py b/concrete/numpy/representation/graph.py index 0c21b741a..a67c2a9a8 100644 --- a/concrete/numpy/representation/graph.py +++ b/concrete/numpy/representation/graph.py @@ -49,6 +49,7 @@ class Graph: def __call__( self, *args: Any, + p_error: Optional[float] = None, ) -> Union[ np.bool_, np.integer, @@ -56,13 +57,14 @@ class Graph: np.ndarray, Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...], ]: - evaluation = self.evaluate(*args) + evaluation = self.evaluate(*args, p_error=p_error) result = tuple(evaluation[node] for node in self.ordered_outputs()) return result if len(result) > 1 else result[0] def evaluate( self, *args: Any, + p_error: Optional[float] = None, ) -> Dict[Node, Union[np.bool_, np.integer, np.floating, np.ndarray]]: r""" Perform the computation `Graph` represents and get resulting values for all nodes. @@ -71,11 +73,19 @@ class Graph: *args (List[Any]): inputs to the computation + p_error (Optional[float]): + probability of error for table lookups + Returns: Dict[Node, Union[np.bool\_, np.integer, np.floating, np.ndarray]]: nodes and their values during computation """ + if p_error is None: + p_error = 0.0 + + assert isinstance(p_error, float) + node_results: Dict[Node, Union[np.bool_, np.integer, np.floating, np.ndarray]] = {} for node in nx.topological_sort(self.graph): if node.operation == Operation.Input: @@ -83,6 +93,32 @@ class Graph: continue pred_results = [node_results[pred] for pred in self.ordered_preds_of(node)] + + if p_error > 0.0 and node.converted_to_table_lookup: + variable_input_indices = [ + idx + for idx, pred in enumerate(self.ordered_preds_of(node)) + if not pred.operation == Operation.Constant + ] + + for index in variable_input_indices: + dtype = node.inputs[index].dtype + if isinstance(dtype, Integer): + # this is fine as we only call the function in the loop, and it's tested + + # pylint: disable=cell-var-from-loop + + def introduce_error(value): + if np.random.rand() < p_error: + value += 1 if np.random.rand() < 0.5 else -1 + value = value if value >= dtype.min() else dtype.max() + value = value if value <= dtype.max() else dtype.min() + return value + + # pylint: enable=cell-var-from-loop + + pred_results[index] = np.vectorize(introduce_error)(pred_results[index]) + try: node_results[node] = node(*pred_results) except Exception as error: diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index 9574add59..7587472a5 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -292,3 +292,42 @@ def test_bad_server_save(helpers): circuit.server.save("test.zip") assert str(excinfo.value) == "Just-in-Time compilation cannot be saved" + + +@pytest.mark.parametrize("p_error", [0.5, 0.1, 0.01]) +@pytest.mark.parametrize("bit_width", [10]) +@pytest.mark.parametrize("sample_size", [100_000]) +@pytest.mark.parametrize("tolerance", [0.05]) +def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers): + """ + Test virtual circuits with p_error. + """ + + configuration = helpers.configuration() + + @compiler({"x": "encrypted"}) + def function(x): + return x**2 + + inputset = [np.random.randint(0, 2**bit_width, size=(sample_size,)) for _ in range(100)] + circuit = function.compile(inputset, configuration=configuration, virtual=True, p_error=p_error) + + sample = np.random.randint(0, 2**bit_width, size=(sample_size,)) + output = circuit.encrypt_run_decrypt(sample) + + errors = 0 + for i in range(sample_size): + if output[i] != sample[i] ** 2: + possible_inputs = [ + (sample[i] + 1) if sample[i] != (2**bit_width) - 1 else 0, + (sample[i] - 1) if sample[i] != 0 else (2**bit_width) - 1, + ] + assert output[i] in [possible_inputs[0] ** 2, possible_inputs[1] ** 2] + errors += 1 + + expected_number_of_errors_on_average = sample_size * p_error + acceptable_number_of_errors = [ + expected_number_of_errors_on_average - (expected_number_of_errors_on_average * tolerance), + expected_number_of_errors_on_average + (expected_number_of_errors_on_average * tolerance), + ] + assert acceptable_number_of_errors[0] < errors < acceptable_number_of_errors[1]