feat: consider p_error in virtual circuits

This commit is contained in:
Umut
2022-11-17 13:20:33 +01:00
parent 25b9a59e21
commit 702375f929
3 changed files with 77 additions and 2 deletions

View File

@@ -151,7 +151,7 @@ class Circuit:
"""
if self.configuration.virtual:
return self.graph(*args)
return self.graph(*args, p_error=self.configuration.p_error)
return self.decrypt(self.run(self.encrypt(*args)))

View File

@@ -49,6 +49,7 @@ class Graph:
def __call__(
self,
*args: Any,
p_error: Optional[float] = None,
) -> Union[
np.bool_,
np.integer,
@@ -56,13 +57,14 @@ class Graph:
np.ndarray,
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
]:
evaluation = self.evaluate(*args)
evaluation = self.evaluate(*args, p_error=p_error)
result = tuple(evaluation[node] for node in self.ordered_outputs())
return result if len(result) > 1 else result[0]
def evaluate(
self,
*args: Any,
p_error: Optional[float] = None,
) -> Dict[Node, Union[np.bool_, np.integer, np.floating, np.ndarray]]:
r"""
Perform the computation `Graph` represents and get resulting values for all nodes.
@@ -71,11 +73,19 @@ class Graph:
*args (List[Any]):
inputs to the computation
p_error (Optional[float]):
probability of error for table lookups
Returns:
Dict[Node, Union[np.bool\_, np.integer, np.floating, np.ndarray]]:
nodes and their values during computation
"""
if p_error is None:
p_error = 0.0
assert isinstance(p_error, float)
node_results: Dict[Node, Union[np.bool_, np.integer, np.floating, np.ndarray]] = {}
for node in nx.topological_sort(self.graph):
if node.operation == Operation.Input:
@@ -83,6 +93,32 @@ class Graph:
continue
pred_results = [node_results[pred] for pred in self.ordered_preds_of(node)]
if p_error > 0.0 and node.converted_to_table_lookup:
variable_input_indices = [
idx
for idx, pred in enumerate(self.ordered_preds_of(node))
if not pred.operation == Operation.Constant
]
for index in variable_input_indices:
dtype = node.inputs[index].dtype
if isinstance(dtype, Integer):
# this is fine as we only call the function in the loop, and it's tested
# pylint: disable=cell-var-from-loop
def introduce_error(value):
if np.random.rand() < p_error:
value += 1 if np.random.rand() < 0.5 else -1
value = value if value >= dtype.min() else dtype.max()
value = value if value <= dtype.max() else dtype.min()
return value
# pylint: enable=cell-var-from-loop
pred_results[index] = np.vectorize(introduce_error)(pred_results[index])
try:
node_results[node] = node(*pred_results)
except Exception as error:

View File

@@ -292,3 +292,42 @@ def test_bad_server_save(helpers):
circuit.server.save("test.zip")
assert str(excinfo.value) == "Just-in-Time compilation cannot be saved"
@pytest.mark.parametrize("p_error", [0.5, 0.1, 0.01])
@pytest.mark.parametrize("bit_width", [10])
@pytest.mark.parametrize("sample_size", [100_000])
@pytest.mark.parametrize("tolerance", [0.05])
def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers):
"""
Test virtual circuits with p_error.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted"})
def function(x):
return x**2
inputset = [np.random.randint(0, 2**bit_width, size=(sample_size,)) for _ in range(100)]
circuit = function.compile(inputset, configuration=configuration, virtual=True, p_error=p_error)
sample = np.random.randint(0, 2**bit_width, size=(sample_size,))
output = circuit.encrypt_run_decrypt(sample)
errors = 0
for i in range(sample_size):
if output[i] != sample[i] ** 2:
possible_inputs = [
(sample[i] + 1) if sample[i] != (2**bit_width) - 1 else 0,
(sample[i] - 1) if sample[i] != 0 else (2**bit_width) - 1,
]
assert output[i] in [possible_inputs[0] ** 2, possible_inputs[1] ** 2]
errors += 1
expected_number_of_errors_on_average = sample_size * p_error
acceptable_number_of_errors = [
expected_number_of_errors_on_average - (expected_number_of_errors_on_average * tolerance),
expected_number_of_errors_on_average + (expected_number_of_errors_on_average * tolerance),
]
assert acceptable_number_of_errors[0] < errors < acceptable_number_of_errors[1]