mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: consider p_error in virtual circuits
This commit is contained in:
@@ -151,7 +151,7 @@ class Circuit:
|
||||
"""
|
||||
|
||||
if self.configuration.virtual:
|
||||
return self.graph(*args)
|
||||
return self.graph(*args, p_error=self.configuration.p_error)
|
||||
|
||||
return self.decrypt(self.run(self.encrypt(*args)))
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ class Graph:
|
||||
def __call__(
|
||||
self,
|
||||
*args: Any,
|
||||
p_error: Optional[float] = None,
|
||||
) -> Union[
|
||||
np.bool_,
|
||||
np.integer,
|
||||
@@ -56,13 +57,14 @@ class Graph:
|
||||
np.ndarray,
|
||||
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
|
||||
]:
|
||||
evaluation = self.evaluate(*args)
|
||||
evaluation = self.evaluate(*args, p_error=p_error)
|
||||
result = tuple(evaluation[node] for node in self.ordered_outputs())
|
||||
return result if len(result) > 1 else result[0]
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
*args: Any,
|
||||
p_error: Optional[float] = None,
|
||||
) -> Dict[Node, Union[np.bool_, np.integer, np.floating, np.ndarray]]:
|
||||
r"""
|
||||
Perform the computation `Graph` represents and get resulting values for all nodes.
|
||||
@@ -71,11 +73,19 @@ class Graph:
|
||||
*args (List[Any]):
|
||||
inputs to the computation
|
||||
|
||||
p_error (Optional[float]):
|
||||
probability of error for table lookups
|
||||
|
||||
Returns:
|
||||
Dict[Node, Union[np.bool\_, np.integer, np.floating, np.ndarray]]:
|
||||
nodes and their values during computation
|
||||
"""
|
||||
|
||||
if p_error is None:
|
||||
p_error = 0.0
|
||||
|
||||
assert isinstance(p_error, float)
|
||||
|
||||
node_results: Dict[Node, Union[np.bool_, np.integer, np.floating, np.ndarray]] = {}
|
||||
for node in nx.topological_sort(self.graph):
|
||||
if node.operation == Operation.Input:
|
||||
@@ -83,6 +93,32 @@ class Graph:
|
||||
continue
|
||||
|
||||
pred_results = [node_results[pred] for pred in self.ordered_preds_of(node)]
|
||||
|
||||
if p_error > 0.0 and node.converted_to_table_lookup:
|
||||
variable_input_indices = [
|
||||
idx
|
||||
for idx, pred in enumerate(self.ordered_preds_of(node))
|
||||
if not pred.operation == Operation.Constant
|
||||
]
|
||||
|
||||
for index in variable_input_indices:
|
||||
dtype = node.inputs[index].dtype
|
||||
if isinstance(dtype, Integer):
|
||||
# this is fine as we only call the function in the loop, and it's tested
|
||||
|
||||
# pylint: disable=cell-var-from-loop
|
||||
|
||||
def introduce_error(value):
|
||||
if np.random.rand() < p_error:
|
||||
value += 1 if np.random.rand() < 0.5 else -1
|
||||
value = value if value >= dtype.min() else dtype.max()
|
||||
value = value if value <= dtype.max() else dtype.min()
|
||||
return value
|
||||
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
pred_results[index] = np.vectorize(introduce_error)(pred_results[index])
|
||||
|
||||
try:
|
||||
node_results[node] = node(*pred_results)
|
||||
except Exception as error:
|
||||
|
||||
@@ -292,3 +292,42 @@ def test_bad_server_save(helpers):
|
||||
circuit.server.save("test.zip")
|
||||
|
||||
assert str(excinfo.value) == "Just-in-Time compilation cannot be saved"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("p_error", [0.5, 0.1, 0.01])
|
||||
@pytest.mark.parametrize("bit_width", [10])
|
||||
@pytest.mark.parametrize("sample_size", [100_000])
|
||||
@pytest.mark.parametrize("tolerance", [0.05])
|
||||
def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers):
|
||||
"""
|
||||
Test virtual circuits with p_error.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x**2
|
||||
|
||||
inputset = [np.random.randint(0, 2**bit_width, size=(sample_size,)) for _ in range(100)]
|
||||
circuit = function.compile(inputset, configuration=configuration, virtual=True, p_error=p_error)
|
||||
|
||||
sample = np.random.randint(0, 2**bit_width, size=(sample_size,))
|
||||
output = circuit.encrypt_run_decrypt(sample)
|
||||
|
||||
errors = 0
|
||||
for i in range(sample_size):
|
||||
if output[i] != sample[i] ** 2:
|
||||
possible_inputs = [
|
||||
(sample[i] + 1) if sample[i] != (2**bit_width) - 1 else 0,
|
||||
(sample[i] - 1) if sample[i] != 0 else (2**bit_width) - 1,
|
||||
]
|
||||
assert output[i] in [possible_inputs[0] ** 2, possible_inputs[1] ** 2]
|
||||
errors += 1
|
||||
|
||||
expected_number_of_errors_on_average = sample_size * p_error
|
||||
acceptable_number_of_errors = [
|
||||
expected_number_of_errors_on_average - (expected_number_of_errors_on_average * tolerance),
|
||||
expected_number_of_errors_on_average + (expected_number_of_errors_on_average * tolerance),
|
||||
]
|
||||
assert acceptable_number_of_errors[0] < errors < acceptable_number_of_errors[1]
|
||||
|
||||
Reference in New Issue
Block a user