feat: add virtual option to compile to simulate fhe without compiling

This commit is contained in:
Umut
2022-04-22 11:51:00 +02:00
parent 64234ee761
commit 85cbd38021
6 changed files with 112 additions and 12 deletions

View File

@@ -3,7 +3,7 @@ Declaration of `Circuit` class.
"""
from pathlib import Path
from typing import List, Optional, Tuple, Union, cast
from typing import Any, List, Optional, Tuple, Union, cast
import numpy as np
from concrete.compiler import (
@@ -33,6 +33,7 @@ class Circuit:
graph: Graph
mlir: str
virtual: bool
_jit_support: JITSupport
_compilation_result: JITCompilationResult
@@ -44,10 +45,26 @@ class Circuit:
_server_lambda: JITLambda
def __init__(self, graph: Graph, mlir: str, configuration: CompilationConfiguration):
def __init__(
self,
graph: Graph,
mlir: str,
configuration: Optional[CompilationConfiguration] = None,
virtual: bool = False,
):
configuration = configuration if configuration is not None else CompilationConfiguration()
self.graph = graph
self.mlir = mlir
self.virtual = virtual
if self.virtual:
print(
"Warning: You are using virtual compilation, "
"which means the evaluation will not be homomorphic."
)
return
options = CompilationOptions.new("main")
options.set_loop_parallelize(configuration.loop_parallelize)
@@ -111,6 +128,9 @@ class Circuit:
whether to generate new keys even if keys are already generated
"""
if self.virtual:
raise RuntimeError("Virtual circuits cannot use `keygen` method")
if self._keyset is None or force:
self._keyset = ClientSupport.key_set(self._client_parameters, self._keyset_cache)
@@ -127,6 +147,9 @@ class Circuit:
encrypted and plain arguments as well as public keys
"""
if self.virtual:
raise RuntimeError("Virtual circuits cannot use `encrypt` method")
if len(args) != len(self.graph.input_nodes):
raise ValueError(f"Expected {len(self.graph.input_nodes)} inputs but got {len(args)}")
@@ -186,6 +209,9 @@ class Circuit:
encrypted result of homomorphic evaluaton
"""
if self.virtual:
raise RuntimeError("Virtual circuits cannot use `run` method")
return self._jit_support.server_call(self._server_lambda, args)
def decrypt(
@@ -204,6 +230,9 @@ class Circuit:
clear result of homomorphic evaluaton
"""
if self.virtual:
raise RuntimeError("Virtual circuits cannot use `decrypt` method")
results = ClientSupport.decrypt_result(self._keyset, result)
if not isinstance(results, tuple):
results = (results,)
@@ -233,10 +262,7 @@ class Circuit:
return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)
def encrypt_run_decrypt(
self,
*args: Union[int, np.ndarray],
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
def encrypt_run_decrypt(self, *args: Any) -> Any:
"""
Encrypt inputs, run the circuit, and decrypt the outputs in one go.
@@ -249,4 +275,7 @@ class Circuit:
clear result of homomorphic evaluation
"""
if self.virtual:
return self.graph(*args)
return self.decrypt(self.run(self.encrypt(*args)))

View File

@@ -218,7 +218,7 @@ class Compiler:
columns = min(longest_line, 80)
else:
columns = min(longest_line, columns)
except OSError:
except OSError: # pragma: no cover
columns = min(longest_line, 80)
print()
@@ -254,6 +254,7 @@ class Compiler:
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
show_graph: bool = False,
show_mlir: bool = False,
virtual: bool = False,
) -> Circuit:
"""
Compile the function using an inputset.
@@ -268,6 +269,9 @@ class Compiler:
show_mlir (bool, default = False):
whether to print the compiled mlir
virtual (bool, default = False):
whether to simulate the computation to allow large bit-widths
Returns:
Circuit:
compiled circuit
@@ -277,7 +281,7 @@ class Compiler:
self._evaluate("Compiling", inputset)
assert self.graph is not None
mlir = GraphConverter.convert(self.graph)
mlir = GraphConverter.convert(self.graph, virtual=virtual)
self.artifacts.add_mlir_to_compile(mlir)
if show_graph or show_mlir:
@@ -298,7 +302,7 @@ class Compiler:
columns = min(longest_line, 80)
else:
columns = min(longest_line, columns)
except OSError:
except OSError: # pragma: no cover
columns = min(longest_line, 80)
if show_graph:
@@ -321,7 +325,7 @@ class Compiler:
print()
return Circuit(self.graph, mlir, self.configuration)
return Circuit(self.graph, mlir, self.configuration, virtual=virtual)
except Exception: # pragma: no cover

View File

@@ -79,6 +79,7 @@ def compiler(
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
show_graph: bool = False,
show_mlir: bool = False,
virtual: bool = False,
) -> Circuit:
"""
Compile the function into a circuit.
@@ -93,12 +94,15 @@ def compiler(
show_mlir (bool, default = False):
whether to print the compiled mlir
virtual (bool, default = False):
whether to simulate the computation to allow large bit-widths
Returns:
Circuit:
compiled circuit
"""
return self.compiler.compile(inputset, show_graph, show_mlir)
return self.compiler.compile(inputset, show_graph, show_mlir, virtual)
return Compilable(function)

View File

@@ -277,7 +277,7 @@ class GraphConverter:
nx_graph.add_edge(add_offset, node, input_idx=variable_input_index)
@staticmethod
def convert(graph: Graph) -> str:
def convert(graph: Graph, virtual: bool = False) -> str:
"""
Convert a computation graph to its corresponding MLIR representation.
@@ -285,6 +285,9 @@ class GraphConverter:
graph (Graph):
computation graph to be converted
virtual (bool, default = False):
whether to circuit will be virtual
Returns:
str:
textual MLIR representation corresponding to `graph`
@@ -293,6 +296,9 @@ class GraphConverter:
graph = deepcopy(graph)
GraphConverter._check_graph_convertibility(graph)
if virtual:
return "Virtual circuits doesn't have MLIR."
GraphConverter._update_bit_widths(graph)
GraphConverter._offset_negative_lookup_table_inputs(graph)

View File

@@ -129,3 +129,38 @@ def test_circuit_bad_run(helpers):
assert str(excinfo.value) == (
"Expected argument 1 to be EncryptedScalar<uint5> but it's EncryptedScalar<uint7>"
)
def test_circuit_virtual_explicit_api(helpers):
"""
Test `keygen`, `encrypt`, `run`, and `decrypt` methods of `Circuit` class with virtual circuit.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration)
def f(x, y):
return x + y
inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)]
circuit = f.compile(inputset, virtual=True)
with pytest.raises(RuntimeError) as excinfo:
circuit.keygen()
assert str(excinfo.value) == "Virtual circuits cannot use `keygen` method"
with pytest.raises(RuntimeError) as excinfo:
circuit.encrypt(1, 2)
assert str(excinfo.value) == "Virtual circuits cannot use `encrypt` method"
with pytest.raises(RuntimeError) as excinfo:
circuit.run(None)
assert str(excinfo.value) == "Virtual circuits cannot use `run` method"
with pytest.raises(RuntimeError) as excinfo:
circuit.decrypt(None)
assert str(excinfo.value) == "Virtual circuits cannot use `decrypt` method"

View File

@@ -120,3 +120,25 @@ def test_compiler_bad_compile(helpers):
compiler.compile()
assert str(excinfo.value) == "Compiling function 'f' without an inputset is not supported"
def test_compiler_virtual_compile(helpers, capsys):
"""
Test `compile` method of `Compiler` class with virtual=True.
"""
configuration = helpers.configuration()
def f(x):
return x + 400
compiler = Compiler(f, {"x": "encrypted"}, configuration=configuration)
circuit = compiler.compile(inputset=range(400), virtual=True)
captured = capsys.readouterr()
assert captured.out.strip() == (
"Warning: You are using virtual compilation, "
"which means the evaluation will not be homomorphic."
)
assert circuit.encrypt_run_decrypt(200) == 600