diff --git a/concrete/numpy/representation/graph.py b/concrete/numpy/representation/graph.py index a67c2a9a8..c576f9720 100644 --- a/concrete/numpy/representation/graph.py +++ b/concrete/numpy/representation/graph.py @@ -102,22 +102,34 @@ class Graph: ] for index in variable_input_indices: - dtype = node.inputs[index].dtype - if isinstance(dtype, Integer): - # this is fine as we only call the function in the loop, and it's tested + pred_node = self.ordered_preds_of(node)[index] + if pred_node.operation != Operation.Input: + dtype = node.inputs[index].dtype + if isinstance(dtype, Integer): + # this is not the real behavior of FHE + # it's a simplified model, and it will be replaced at one point - # pylint: disable=cell-var-from-loop + error = np.random.rand(*pred_results[index].shape) + error = np.where(error < p_error**3, 3, error) + error = np.where(error < p_error**2, 2, error) + error = np.where(error < p_error, 1, np.where(error > 1, error, 0)) - def introduce_error(value): - if np.random.rand() < p_error: - value += 1 if np.random.rand() < 0.5 else -1 - value = value if value >= dtype.min() else dtype.max() - value = value if value <= dtype.max() else dtype.min() - return value + error_sign = np.random.rand(*pred_results[index].shape) + error_sign = np.where(error < 0.5, 1, -1) - # pylint: enable=cell-var-from-loop + new_results = pred_results[index] + (error * error_sign) - pred_results[index] = np.vectorize(introduce_error)(pred_results[index]) + underflow_indices = np.where(new_results < dtype.min()) + new_results[underflow_indices] = ( + dtype.max() - (dtype.min() - new_results[underflow_indices]) + 1 + ) + + overflow_indices = np.where(new_results > dtype.max()) + new_results[overflow_indices] = ( + dtype.min() + (new_results[overflow_indices] - dtype.max()) - 1 + ) + + pred_results[index] = new_results try: node_results[node] = node(*pred_results) diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index 1f6e384e6..224e0adb9 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -307,7 +307,7 @@ def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers): @compiler({"x": "encrypted"}) def function(x): - return x**2 + return (-x) ** 2 inputset = [np.random.randint(0, 2**bit_width, size=(sample_size,)) for _ in range(100)] circuit = function.compile(inputset, configuration=configuration, virtual=True, p_error=p_error) @@ -317,12 +317,7 @@ def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers): errors = 0 for i in range(sample_size): - if output[i] != sample[i] ** 2: - possible_inputs = [ - (sample[i] + 1) if sample[i] != (2**bit_width) - 1 else 0, - (sample[i] - 1) if sample[i] != 0 else (2**bit_width) - 1, - ] - assert output[i] in [possible_inputs[0] ** 2, possible_inputs[1] ** 2] + if output[i] != (-sample[i]) ** 2: errors += 1 expected_number_of_errors_on_average = sample_size * p_error