refactor: remove virtual option

This commit is contained in:
Umut
2023-02-23 09:41:06 +01:00
parent ddba09fa32
commit 3bbb0c2aa3
11 changed files with 63 additions and 239 deletions

View File

@@ -2,7 +2,6 @@
Declaration of `Circuit` class.
"""
from copy import deepcopy
from typing import Any, Optional, Tuple, Union, cast
import numpy as np
@@ -10,7 +9,6 @@ from concrete.compiler import PublicArguments, PublicResult
from ..dtypes import Integer
from ..internal.utils import assert_that
from ..mlir import GraphConverter
from ..representation import Graph
from .client import Client
from .configuration import Configuration
@@ -36,9 +34,6 @@ class Circuit:
self.graph = graph
self.mlir = mlir
if self.configuration.virtual:
return
self._initialize_client_and_server()
def _initialize_client_and_server(self):
@@ -82,35 +77,7 @@ class Circuit:
result of the simulation
"""
p_error = self.p_error if not self.configuration.virtual else self.configuration.p_error
return self.graph(*args, p_error=p_error)
def enable_fhe(self):
"""
Enable fully homomorphic encryption features.
When called on a virtual circuit, it'll enable access to the following methods:
- encrypt
- run
- decrypt
- encrypt_run_decrypt
When called on a normal circuit, it'll do nothing.
Raises:
RuntimeError:
if the circuit is not supported in fhe
"""
if not self.configuration.virtual:
return
new_configuration = deepcopy(self.configuration)
new_configuration.virtual = False
self.configuration = new_configuration
self.mlir = GraphConverter.convert(self.graph)
self._initialize_client_and_server()
return self.graph(*args, p_error=self.p_error)
def keygen(self, force: bool = False):
"""
@@ -121,10 +88,6 @@ class Circuit:
whether to generate new keys even if keys are already generated
"""
if self.configuration.virtual:
message = "Virtual circuits cannot use `keygen` method"
raise RuntimeError(message)
self.client.keygen(force)
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
@@ -140,10 +103,6 @@ class Circuit:
encrypted and plain arguments as well as public keys
"""
if self.configuration.virtual:
message = "Virtual circuits cannot use `encrypt` method"
raise RuntimeError(message)
return self.client.encrypt(*args)
def run(self, args: PublicArguments) -> PublicResult:
@@ -159,10 +118,6 @@ class Circuit:
encrypted result of homomorphic evaluaton
"""
if self.configuration.virtual:
message = "Virtual circuits cannot use `run` method"
raise RuntimeError(message)
self.keygen(force=False)
return self.server.run(args, self.client.evaluation_keys)
@@ -182,10 +137,6 @@ class Circuit:
clear result of homomorphic evaluaton
"""
if self.configuration.virtual:
message = "Virtual circuits cannot use `decrypt` method"
raise RuntimeError(message)
return self.client.decrypt(result)
def encrypt_run_decrypt(self, *args: Any) -> Any:

View File

@@ -439,7 +439,7 @@ class Compiler:
self._evaluate("Compiling", inputset)
assert self.graph is not None
mlir = GraphConverter.convert(self.graph, virtual=self.configuration.virtual)
mlir = GraphConverter.convert(self.graph)
if self.artifacts is not None:
self.artifacts.add_mlir_to_compile(mlir)
@@ -513,17 +513,12 @@ class Compiler:
print("-" * columns)
circuit = Circuit(self.graph, mlir, self.configuration)
if not self.configuration.virtual:
assert circuit.client.specs.client_parameters is not None
if self.artifacts is not None:
self.artifacts.add_client_parameters(
circuit.client.specs.client_parameters.serialize()
)
client_parameters = circuit.client.specs.client_parameters
if self.artifacts is not None:
self.artifacts.add_client_parameters(client_parameters.serialize())
if show_optimizer:
if self.configuration.virtual:
print("Virtual circuits don't have optimizer output.")
print("-" * columns)
print()

View File

@@ -23,7 +23,6 @@ class Configuration:
show_optimizer: Optional[bool]
dump_artifacts_on_unexpected_failures: bool
enable_unsafe_features: bool
virtual: bool
use_insecure_key_cache: bool
loop_parallelize: bool
dataflow_parallelize: bool
@@ -61,7 +60,6 @@ class Configuration:
show_optimizer: Optional[bool] = None,
dump_artifacts_on_unexpected_failures: bool = True,
enable_unsafe_features: bool = False,
virtual: bool = False,
use_insecure_key_cache: bool = False,
insecure_key_cache_location: Optional[Union[Path, str]] = None,
loop_parallelize: bool = True,
@@ -78,7 +76,6 @@ class Configuration:
self.show_optimizer = show_optimizer
self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures
self.enable_unsafe_features = enable_unsafe_features
self.virtual = virtual
self.use_insecure_key_cache = use_insecure_key_cache
self.insecure_key_cache_location = (
str(insecure_key_cache_location) if insecure_key_cache_location is not None else None

View File

@@ -40,7 +40,7 @@ class GraphConverter:
"""
@staticmethod
def _check_node_convertibility(graph: Graph, node: Node, virtual: bool) -> Optional[str]:
def _check_node_convertibility(graph: Graph, node: Node) -> Optional[str]:
"""
Check node convertibility to MLIR.
@@ -51,9 +51,6 @@ class GraphConverter:
node (Node):
node to be checked
virtual (bool):
whether the circuit will be virtual
Returns:
Optional[str]:
None if node is convertible to MLIR, the reason for inconvertibility otherwise
@@ -153,7 +150,7 @@ class GraphConverter:
elif name == "multiply":
assert_that(len(inputs) == 2)
if not virtual and inputs[0].is_encrypted and inputs[1].is_encrypted:
if inputs[0].is_encrypted and inputs[1].is_encrypted:
return "only multiplication between encrypted and clear is supported"
elif name == "negative":
@@ -205,7 +202,7 @@ class GraphConverter:
# pylint: enable=too-many-branches,too-many-return-statements,too-many-statements
@staticmethod
def _check_graph_convertibility(graph: Graph, virtual: bool):
def _check_graph_convertibility(graph: Graph):
"""
Check graph convertibility to MLIR.
@@ -213,9 +210,6 @@ class GraphConverter:
graph (Graph):
computation graph to be checked
virtual (bool):
whether the circuit will be virtual
Raises:
RuntimeError:
if `graph` is not convertible to MLIR
@@ -233,7 +227,7 @@ class GraphConverter:
if len(offending_nodes) == 0:
for node in graph.graph.nodes:
reason = GraphConverter._check_node_convertibility(graph, node, virtual)
reason = GraphConverter._check_node_convertibility(graph, node)
if reason is not None:
offending_nodes[node] = [reason, node.location]
@@ -665,7 +659,7 @@ class GraphConverter:
return sanitized_args
@staticmethod
def convert(graph: Graph, virtual: bool = False) -> str:
def convert(graph: Graph) -> str:
"""
Convert a computation graph to its corresponding MLIR representation.
@@ -673,9 +667,6 @@ class GraphConverter:
graph (Graph):
computation graph to be converted
virtual (bool, default = False):
whether the circuit will be virtual
Returns:
str:
textual MLIR representation corresponding to `graph`
@@ -683,10 +674,7 @@ class GraphConverter:
graph = deepcopy(graph)
GraphConverter._check_graph_convertibility(graph, virtual)
if virtual:
return "Virtual circuits don't have MLIR."
GraphConverter._check_graph_convertibility(graph)
GraphConverter._update_bit_widths(graph)
GraphConverter._offset_negative_lookup_table_inputs(graph)
GraphConverter._broadcast_assignments(graph)

View File

@@ -85,7 +85,7 @@ class Graph:
nodes and their values during computation
"""
# pylint: disable=no-member,too-many-nested-blocks
# pylint: disable=no-member,too-many-nested-blocks,too-many-branches,too-many-statements
if p_error is None:
p_error = 0.0
@@ -153,19 +153,26 @@ class Graph:
error_sign = np.random.rand(*pred_results[index].shape)
error_sign = np.where(error_sign < 0.5, 1, -1).astype(np.int64)
new_results = pred_results[index] + (error * error_sign)
new_result = pred_results[index] + (error * error_sign)
underflow_indices = np.where(new_results < dtype.min())
new_results[underflow_indices] = (
dtype.max() - (dtype.min() - new_results[underflow_indices]) + 1
)
if new_result.shape == (): # pragma: no cover
if new_result < dtype.min():
new_result = dtype.max() - (dtype.min() - new_result) + 1
elif new_result > dtype.max():
new_result = dtype.min() - (new_result - dtype.max()) - 1
overflow_indices = np.where(new_results > dtype.max())
new_results[overflow_indices] = (
dtype.min() + (new_results[overflow_indices] - dtype.max()) - 1
)
else:
underflow_indices = np.where(new_result < dtype.min())
new_result[underflow_indices] = (
dtype.max() - (dtype.min() - new_result[underflow_indices]) + 1
)
pred_results[index] = new_results
overflow_indices = np.where(new_result > dtype.max())
new_result[overflow_indices] = (
dtype.min() + (new_result[overflow_indices] - dtype.max()) - 1
)
pred_results[index] = new_result
try:
node_results[node] = node(*pred_results)

View File

@@ -8,7 +8,7 @@ from pathlib import Path
import numpy as np
import pytest
from concrete.numpy import Client, ClientSpecs, EvaluationKeys, Server, compiler
from concrete.numpy import Client, ClientSpecs, EvaluationKeys, LookupTable, Server, compiler
def test_circuit_str(helpers):
@@ -129,41 +129,6 @@ def test_circuit_bad_run(helpers):
)
def test_circuit_virtual_explicit_api(helpers):
"""
Test `keygen`, `encrypt`, `run`, and `decrypt` methods of `Circuit` class with virtual circuit.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(100)]
circuit = f.compile(inputset, configuration, virtual=True)
with pytest.raises(RuntimeError) as excinfo:
circuit.keygen()
assert str(excinfo.value) == "Virtual circuits cannot use `keygen` method"
with pytest.raises(RuntimeError) as excinfo:
circuit.encrypt(1, 2)
assert str(excinfo.value) == "Virtual circuits cannot use `encrypt` method"
with pytest.raises(RuntimeError) as excinfo:
circuit.run(None)
assert str(excinfo.value) == "Virtual circuits cannot use `run` method"
with pytest.raises(RuntimeError) as excinfo:
circuit.decrypt(None)
assert str(excinfo.value) == "Virtual circuits cannot use `decrypt` method"
def test_client_server_api(helpers):
"""
Test client/server API.
@@ -303,35 +268,38 @@ def test_bad_server_save(helpers):
@pytest.mark.parametrize("p_error", [0.75, 0.5, 0.4, 0.25, 0.2, 0.1, 0.01, 0.001])
@pytest.mark.parametrize("bit_width", [10])
@pytest.mark.parametrize("bit_width", [5])
@pytest.mark.parametrize("sample_size", [1_000_000])
@pytest.mark.parametrize("tolerance", [0.075])
def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers):
def test_p_error_simulation(p_error, bit_width, sample_size, tolerance, helpers):
"""
Test virtual circuits with p_error.
Test p_error simulation.
"""
configuration = helpers.configuration()
configuration = helpers.configuration().fork(global_p_error=None)
table = LookupTable([0] + [x - 1 for x in range(1, 2**bit_width)])
@compiler({"x": "encrypted"})
def function(x):
return (-x) ** 2
return table[x + 1]
inputset = [np.random.randint(0, 2**bit_width, size=(sample_size,)) for _ in range(100)]
circuit = function.compile(inputset, configuration=configuration, virtual=True, p_error=p_error)
inputset = [np.random.randint(0, (2**bit_width) - 1, size=(sample_size,)) for _ in range(100)]
circuit = function.compile(inputset, configuration=configuration, p_error=p_error)
sample = np.random.randint(0, 2**bit_width, size=(sample_size,))
assert circuit.p_error < p_error
sample = np.random.randint(0, (2**bit_width) - 1, size=(sample_size,))
output = circuit.simulate(sample)
errors = 0
for i in range(sample_size):
if output[i] != (-sample[i]) ** 2:
errors += 1
errors = np.sum(output != sample)
expected_number_of_errors_on_average = sample_size * circuit.p_error
tolerated_difference = expected_number_of_errors_on_average * tolerance
expected_number_of_errors_on_average = sample_size * p_error
acceptable_number_of_errors = [
expected_number_of_errors_on_average - (expected_number_of_errors_on_average * tolerance),
expected_number_of_errors_on_average + (expected_number_of_errors_on_average * tolerance),
round(expected_number_of_errors_on_average - tolerated_difference),
round(expected_number_of_errors_on_average + tolerated_difference),
]
assert acceptable_number_of_errors[0] <= errors <= acceptable_number_of_errors[1]
@@ -358,30 +326,3 @@ def test_circuit_run_with_unused_arg(helpers):
assert circuit.encrypt_run_decrypt(10, 0) == 20
assert circuit.encrypt_run_decrypt(10, 10) == 20
assert circuit.encrypt_run_decrypt(10, 20) == 20
def test_circuit_virtual_then_fhe(helpers):
"""
Test compiling to virtual and then fhe.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(100)]
circuit = f.compile(inputset, configuration, virtual=True)
assert circuit.simulate(3, 5) == 8
circuit.enable_fhe()
assert circuit.simulate(3, 5) == 8
assert circuit.encrypt_run_decrypt(3, 5) == 8
circuit.enable_fhe()
assert circuit.simulate(3, 5) == 8
assert circuit.encrypt_run_decrypt(3, 5) == 8

View File

@@ -210,24 +210,6 @@ def test_compiler_bad_compile(helpers):
)
def test_compiler_virtual_compile(helpers):
"""
Test `compile` method of `Compiler` class with virtual=True.
"""
configuration = helpers.configuration()
def f(x, y):
return x * y
compiler = Compiler(f, {"x": "encrypted", "y": "encrypted"})
inputset = [(100_000, 1_000_000)]
circuit = compiler.compile(inputset, configuration=configuration, virtual=True)
assert circuit.simulate(100_000, 1_000_000) == 100_000_000_000
def test_compiler_compile_bad_inputset(helpers):
"""
Test `compile` method of `Compiler` class with bad inputset.

View File

@@ -92,44 +92,6 @@ Optimizer
)
def test_compiler_verbose_virtual_compile(helpers, capsys):
"""
Test `compile` method of `compiler` decorator with verbose flag.
"""
configuration = helpers.configuration()
artifacts = cnp.DebugArtifacts()
@cnp.compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = range(10)
function.compile(inputset, configuration, artifacts, verbose=True, virtual=True)
captured = capsys.readouterr()
assert captured.out.strip() == (
f"""
Computation Graph
------------------------------------------------------------------
{list(artifacts.textual_representations_of_graphs.values())[-1][-1]}
------------------------------------------------------------------
MLIR
------------------------------------------------------------------
Virtual circuits don't have MLIR.
------------------------------------------------------------------
Optimizer
------------------------------------------------------------------
Virtual circuits don't have optimizer output.
------------------------------------------------------------------
""".strip()
)
def test_circuit(helpers):
"""
Test circuit decorator.

View File

@@ -223,6 +223,7 @@ class Helpers:
function: Callable,
sample: Union[Any, List[Any]],
retries: int = 1,
simulate: bool = False,
):
"""
Assert that `circuit` is behaves the same as `function` on `sample`.
@@ -237,8 +238,11 @@ class Helpers:
sample (List[Any]):
inputs
retries (int):
retries (int, default = 1):
number of times to retry (for probabilistic execution)
simulate (bool, default = False):
whether to simulate instead of fhe execution
"""
if not isinstance(sample, list):
@@ -262,9 +266,7 @@ class Helpers:
for i in range(retries):
expected = sanitize(function(*sample))
actual = sanitize(
circuit.simulate(*sample)
if circuit.configuration.virtual
else circuit.encrypt_run_decrypt(*sample)
circuit.simulate(*sample) if simulate else circuit.encrypt_run_decrypt(*sample)
)
if all(np.array_equal(e, a) for e, a in zip(expected, actual)):

View File

@@ -75,8 +75,8 @@ def test_maxpool(
def function(x):
return connx.maxpool(x, **operation)
circuit = function.compile([sample_input], helpers.configuration(), virtual=True)
helpers.check_execution(circuit, function, sample_input)
graph = function.trace([sample_input], helpers.configuration())
assert np.array_equal(graph(sample_input), expected_output)
@pytest.mark.parametrize(
@@ -318,16 +318,16 @@ def test_bad_maxpool_special(helpers):
Test maxpool with bad parameters for special cases.
"""
# without virtual
# ---------------
# compile
# -------
@cnp.compiler({"x": "encrypted"})
def without_virtual(x):
def not_compilable(x):
return connx.maxpool(x, kernel_shape=(4, 3))
inputset = [np.random.randint(0, 10, size=(1, 1, 10, 10)) for i in range(100)]
with pytest.raises(NotImplementedError) as excinfo:
without_virtual.compile(inputset, helpers.configuration())
not_compilable.compile(inputset, helpers.configuration())
helpers.check_str("MaxPool operation cannot be compiled yet", str(excinfo.value))

View File

@@ -96,8 +96,8 @@ def test_round_bit_pattern(input_bits, lsbs_to_remove, helpers):
x_rounded = cnp.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
return np.abs(50 * np.sin(x_rounded)).astype(np.int64)
circuit = function.compile([(2**input_bits) - 1], helpers.configuration(), virtual=True)
helpers.check_execution(circuit, function, np.random.randint(0, 2**input_bits))
circuit = function.compile([(2**input_bits) - 1], helpers.configuration())
helpers.check_execution(circuit, function, np.random.randint(0, 2**input_bits), simulate=True)
def test_auto_rounding(helpers):
@@ -181,7 +181,6 @@ def test_auto_rounding(helpers):
inputset3,
helpers.configuration(),
auto_adjust_rounders=True,
virtual=True,
)
assert rounder3.lsbs_to_remove == 3