feat: make p_error simulation in virtual circuits faster and a bit more realistic

This commit is contained in:
Umut
2022-12-21 10:36:37 +01:00
parent 1472c8f020
commit 39958cf02a
2 changed files with 26 additions and 19 deletions

View File

@@ -102,22 +102,34 @@ class Graph:
]
for index in variable_input_indices:
dtype = node.inputs[index].dtype
if isinstance(dtype, Integer):
# this is fine as we only call the function in the loop, and it's tested
pred_node = self.ordered_preds_of(node)[index]
if pred_node.operation != Operation.Input:
dtype = node.inputs[index].dtype
if isinstance(dtype, Integer):
# this is not the real behavior of FHE
# it's a simplified model, and it will be replaced at one point
# pylint: disable=cell-var-from-loop
error = np.random.rand(*pred_results[index].shape)
error = np.where(error < p_error**3, 3, error)
error = np.where(error < p_error**2, 2, error)
error = np.where(error < p_error, 1, np.where(error > 1, error, 0))
def introduce_error(value):
if np.random.rand() < p_error:
value += 1 if np.random.rand() < 0.5 else -1
value = value if value >= dtype.min() else dtype.max()
value = value if value <= dtype.max() else dtype.min()
return value
error_sign = np.random.rand(*pred_results[index].shape)
error_sign = np.where(error < 0.5, 1, -1)
# pylint: enable=cell-var-from-loop
new_results = pred_results[index] + (error * error_sign)
pred_results[index] = np.vectorize(introduce_error)(pred_results[index])
underflow_indices = np.where(new_results < dtype.min())
new_results[underflow_indices] = (
dtype.max() - (dtype.min() - new_results[underflow_indices]) + 1
)
overflow_indices = np.where(new_results > dtype.max())
new_results[overflow_indices] = (
dtype.min() + (new_results[overflow_indices] - dtype.max()) - 1
)
pred_results[index] = new_results
try:
node_results[node] = node(*pred_results)

View File

@@ -307,7 +307,7 @@ def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers):
@compiler({"x": "encrypted"})
def function(x):
return x**2
return (-x) ** 2
inputset = [np.random.randint(0, 2**bit_width, size=(sample_size,)) for _ in range(100)]
circuit = function.compile(inputset, configuration=configuration, virtual=True, p_error=p_error)
@@ -317,12 +317,7 @@ def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers):
errors = 0
for i in range(sample_size):
if output[i] != sample[i] ** 2:
possible_inputs = [
(sample[i] + 1) if sample[i] != (2**bit_width) - 1 else 0,
(sample[i] - 1) if sample[i] != 0 else (2**bit_width) - 1,
]
assert output[i] in [possible_inputs[0] ** 2, possible_inputs[1] ** 2]
if output[i] != (-sample[i]) ** 2:
errors += 1
expected_number_of_errors_on_average = sample_size * p_error