mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: make p_error simulation in virtual circuits faster and a bit more realistic
This commit is contained in:
@@ -102,22 +102,34 @@ class Graph:
|
||||
]
|
||||
|
||||
for index in variable_input_indices:
|
||||
dtype = node.inputs[index].dtype
|
||||
if isinstance(dtype, Integer):
|
||||
# this is fine as we only call the function in the loop, and it's tested
|
||||
pred_node = self.ordered_preds_of(node)[index]
|
||||
if pred_node.operation != Operation.Input:
|
||||
dtype = node.inputs[index].dtype
|
||||
if isinstance(dtype, Integer):
|
||||
# this is not the real behavior of FHE
|
||||
# it's a simplified model, and it will be replaced at one point
|
||||
|
||||
# pylint: disable=cell-var-from-loop
|
||||
error = np.random.rand(*pred_results[index].shape)
|
||||
error = np.where(error < p_error**3, 3, error)
|
||||
error = np.where(error < p_error**2, 2, error)
|
||||
error = np.where(error < p_error, 1, np.where(error > 1, error, 0))
|
||||
|
||||
def introduce_error(value):
|
||||
if np.random.rand() < p_error:
|
||||
value += 1 if np.random.rand() < 0.5 else -1
|
||||
value = value if value >= dtype.min() else dtype.max()
|
||||
value = value if value <= dtype.max() else dtype.min()
|
||||
return value
|
||||
error_sign = np.random.rand(*pred_results[index].shape)
|
||||
error_sign = np.where(error < 0.5, 1, -1)
|
||||
|
||||
# pylint: enable=cell-var-from-loop
|
||||
new_results = pred_results[index] + (error * error_sign)
|
||||
|
||||
pred_results[index] = np.vectorize(introduce_error)(pred_results[index])
|
||||
underflow_indices = np.where(new_results < dtype.min())
|
||||
new_results[underflow_indices] = (
|
||||
dtype.max() - (dtype.min() - new_results[underflow_indices]) + 1
|
||||
)
|
||||
|
||||
overflow_indices = np.where(new_results > dtype.max())
|
||||
new_results[overflow_indices] = (
|
||||
dtype.min() + (new_results[overflow_indices] - dtype.max()) - 1
|
||||
)
|
||||
|
||||
pred_results[index] = new_results
|
||||
|
||||
try:
|
||||
node_results[node] = node(*pred_results)
|
||||
|
||||
@@ -307,7 +307,7 @@ def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers):
|
||||
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x**2
|
||||
return (-x) ** 2
|
||||
|
||||
inputset = [np.random.randint(0, 2**bit_width, size=(sample_size,)) for _ in range(100)]
|
||||
circuit = function.compile(inputset, configuration=configuration, virtual=True, p_error=p_error)
|
||||
@@ -317,12 +317,7 @@ def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers):
|
||||
|
||||
errors = 0
|
||||
for i in range(sample_size):
|
||||
if output[i] != sample[i] ** 2:
|
||||
possible_inputs = [
|
||||
(sample[i] + 1) if sample[i] != (2**bit_width) - 1 else 0,
|
||||
(sample[i] - 1) if sample[i] != 0 else (2**bit_width) - 1,
|
||||
]
|
||||
assert output[i] in [possible_inputs[0] ** 2, possible_inputs[1] ** 2]
|
||||
if output[i] != (-sample[i]) ** 2:
|
||||
errors += 1
|
||||
|
||||
expected_number_of_errors_on_average = sample_size * p_error
|
||||
|
||||
Reference in New Issue
Block a user