mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: add virtual option to compile to simulate fhe without compiling
This commit is contained in:
@@ -3,7 +3,7 @@ Declaration of `Circuit` class.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
from typing import Any, List, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import (
|
||||
@@ -33,6 +33,7 @@ class Circuit:
|
||||
|
||||
graph: Graph
|
||||
mlir: str
|
||||
virtual: bool
|
||||
|
||||
_jit_support: JITSupport
|
||||
_compilation_result: JITCompilationResult
|
||||
@@ -44,10 +45,26 @@ class Circuit:
|
||||
|
||||
_server_lambda: JITLambda
|
||||
|
||||
def __init__(self, graph: Graph, mlir: str, configuration: CompilationConfiguration):
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
mlir: str,
|
||||
configuration: Optional[CompilationConfiguration] = None,
|
||||
virtual: bool = False,
|
||||
):
|
||||
configuration = configuration if configuration is not None else CompilationConfiguration()
|
||||
|
||||
self.graph = graph
|
||||
self.mlir = mlir
|
||||
|
||||
self.virtual = virtual
|
||||
if self.virtual:
|
||||
print(
|
||||
"Warning: You are using virtual compilation, "
|
||||
"which means the evaluation will not be homomorphic."
|
||||
)
|
||||
return
|
||||
|
||||
options = CompilationOptions.new("main")
|
||||
|
||||
options.set_loop_parallelize(configuration.loop_parallelize)
|
||||
@@ -111,6 +128,9 @@ class Circuit:
|
||||
whether to generate new keys even if keys are already generated
|
||||
"""
|
||||
|
||||
if self.virtual:
|
||||
raise RuntimeError("Virtual circuits cannot use `keygen` method")
|
||||
|
||||
if self._keyset is None or force:
|
||||
self._keyset = ClientSupport.key_set(self._client_parameters, self._keyset_cache)
|
||||
|
||||
@@ -127,6 +147,9 @@ class Circuit:
|
||||
encrypted and plain arguments as well as public keys
|
||||
"""
|
||||
|
||||
if self.virtual:
|
||||
raise RuntimeError("Virtual circuits cannot use `encrypt` method")
|
||||
|
||||
if len(args) != len(self.graph.input_nodes):
|
||||
raise ValueError(f"Expected {len(self.graph.input_nodes)} inputs but got {len(args)}")
|
||||
|
||||
@@ -186,6 +209,9 @@ class Circuit:
|
||||
encrypted result of homomorphic evaluaton
|
||||
"""
|
||||
|
||||
if self.virtual:
|
||||
raise RuntimeError("Virtual circuits cannot use `run` method")
|
||||
|
||||
return self._jit_support.server_call(self._server_lambda, args)
|
||||
|
||||
def decrypt(
|
||||
@@ -204,6 +230,9 @@ class Circuit:
|
||||
clear result of homomorphic evaluaton
|
||||
"""
|
||||
|
||||
if self.virtual:
|
||||
raise RuntimeError("Virtual circuits cannot use `decrypt` method")
|
||||
|
||||
results = ClientSupport.decrypt_result(self._keyset, result)
|
||||
if not isinstance(results, tuple):
|
||||
results = (results,)
|
||||
@@ -233,10 +262,7 @@ class Circuit:
|
||||
|
||||
return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)
|
||||
|
||||
def encrypt_run_decrypt(
|
||||
self,
|
||||
*args: Union[int, np.ndarray],
|
||||
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
|
||||
def encrypt_run_decrypt(self, *args: Any) -> Any:
|
||||
"""
|
||||
Encrypt inputs, run the circuit, and decrypt the outputs in one go.
|
||||
|
||||
@@ -249,4 +275,7 @@ class Circuit:
|
||||
clear result of homomorphic evaluation
|
||||
"""
|
||||
|
||||
if self.virtual:
|
||||
return self.graph(*args)
|
||||
|
||||
return self.decrypt(self.run(self.encrypt(*args)))
|
||||
|
||||
@@ -218,7 +218,7 @@ class Compiler:
|
||||
columns = min(longest_line, 80)
|
||||
else:
|
||||
columns = min(longest_line, columns)
|
||||
except OSError:
|
||||
except OSError: # pragma: no cover
|
||||
columns = min(longest_line, 80)
|
||||
|
||||
print()
|
||||
@@ -254,6 +254,7 @@ class Compiler:
|
||||
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
|
||||
show_graph: bool = False,
|
||||
show_mlir: bool = False,
|
||||
virtual: bool = False,
|
||||
) -> Circuit:
|
||||
"""
|
||||
Compile the function using an inputset.
|
||||
@@ -268,6 +269,9 @@ class Compiler:
|
||||
show_mlir (bool, default = False):
|
||||
whether to print the compiled mlir
|
||||
|
||||
virtual (bool, default = False):
|
||||
whether to simulate the computation to allow large bit-widths
|
||||
|
||||
Returns:
|
||||
Circuit:
|
||||
compiled circuit
|
||||
@@ -277,7 +281,7 @@ class Compiler:
|
||||
self._evaluate("Compiling", inputset)
|
||||
assert self.graph is not None
|
||||
|
||||
mlir = GraphConverter.convert(self.graph)
|
||||
mlir = GraphConverter.convert(self.graph, virtual=virtual)
|
||||
self.artifacts.add_mlir_to_compile(mlir)
|
||||
|
||||
if show_graph or show_mlir:
|
||||
@@ -298,7 +302,7 @@ class Compiler:
|
||||
columns = min(longest_line, 80)
|
||||
else:
|
||||
columns = min(longest_line, columns)
|
||||
except OSError:
|
||||
except OSError: # pragma: no cover
|
||||
columns = min(longest_line, 80)
|
||||
|
||||
if show_graph:
|
||||
@@ -321,7 +325,7 @@ class Compiler:
|
||||
|
||||
print()
|
||||
|
||||
return Circuit(self.graph, mlir, self.configuration)
|
||||
return Circuit(self.graph, mlir, self.configuration, virtual=virtual)
|
||||
|
||||
except Exception: # pragma: no cover
|
||||
|
||||
|
||||
@@ -79,6 +79,7 @@ def compiler(
|
||||
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
|
||||
show_graph: bool = False,
|
||||
show_mlir: bool = False,
|
||||
virtual: bool = False,
|
||||
) -> Circuit:
|
||||
"""
|
||||
Compile the function into a circuit.
|
||||
@@ -93,12 +94,15 @@ def compiler(
|
||||
show_mlir (bool, default = False):
|
||||
whether to print the compiled mlir
|
||||
|
||||
virtual (bool, default = False):
|
||||
whether to simulate the computation to allow large bit-widths
|
||||
|
||||
Returns:
|
||||
Circuit:
|
||||
compiled circuit
|
||||
"""
|
||||
|
||||
return self.compiler.compile(inputset, show_graph, show_mlir)
|
||||
return self.compiler.compile(inputset, show_graph, show_mlir, virtual)
|
||||
|
||||
return Compilable(function)
|
||||
|
||||
|
||||
@@ -277,7 +277,7 @@ class GraphConverter:
|
||||
nx_graph.add_edge(add_offset, node, input_idx=variable_input_index)
|
||||
|
||||
@staticmethod
|
||||
def convert(graph: Graph) -> str:
|
||||
def convert(graph: Graph, virtual: bool = False) -> str:
|
||||
"""
|
||||
Convert a computation graph to its corresponding MLIR representation.
|
||||
|
||||
@@ -285,6 +285,9 @@ class GraphConverter:
|
||||
graph (Graph):
|
||||
computation graph to be converted
|
||||
|
||||
virtual (bool, default = False):
|
||||
whether to circuit will be virtual
|
||||
|
||||
Returns:
|
||||
str:
|
||||
textual MLIR representation corresponding to `graph`
|
||||
@@ -293,6 +296,9 @@ class GraphConverter:
|
||||
graph = deepcopy(graph)
|
||||
|
||||
GraphConverter._check_graph_convertibility(graph)
|
||||
if virtual:
|
||||
return "Virtual circuits doesn't have MLIR."
|
||||
|
||||
GraphConverter._update_bit_widths(graph)
|
||||
GraphConverter._offset_negative_lookup_table_inputs(graph)
|
||||
|
||||
|
||||
@@ -129,3 +129,38 @@ def test_circuit_bad_run(helpers):
|
||||
assert str(excinfo.value) == (
|
||||
"Expected argument 1 to be EncryptedScalar<uint5> but it's EncryptedScalar<uint7>"
|
||||
)
|
||||
|
||||
|
||||
def test_circuit_virtual_explicit_api(helpers):
|
||||
"""
|
||||
Test `keygen`, `encrypt`, `run`, and `decrypt` methods of `Circuit` class with virtual circuit.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration)
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)]
|
||||
circuit = f.compile(inputset, virtual=True)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
circuit.keygen()
|
||||
|
||||
assert str(excinfo.value) == "Virtual circuits cannot use `keygen` method"
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
circuit.encrypt(1, 2)
|
||||
|
||||
assert str(excinfo.value) == "Virtual circuits cannot use `encrypt` method"
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
circuit.run(None)
|
||||
|
||||
assert str(excinfo.value) == "Virtual circuits cannot use `run` method"
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
circuit.decrypt(None)
|
||||
|
||||
assert str(excinfo.value) == "Virtual circuits cannot use `decrypt` method"
|
||||
|
||||
@@ -120,3 +120,25 @@ def test_compiler_bad_compile(helpers):
|
||||
compiler.compile()
|
||||
|
||||
assert str(excinfo.value) == "Compiling function 'f' without an inputset is not supported"
|
||||
|
||||
|
||||
def test_compiler_virtual_compile(helpers, capsys):
|
||||
"""
|
||||
Test `compile` method of `Compiler` class with virtual=True.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
def f(x):
|
||||
return x + 400
|
||||
|
||||
compiler = Compiler(f, {"x": "encrypted"}, configuration=configuration)
|
||||
circuit = compiler.compile(inputset=range(400), virtual=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out.strip() == (
|
||||
"Warning: You are using virtual compilation, "
|
||||
"which means the evaluation will not be homomorphic."
|
||||
)
|
||||
|
||||
assert circuit.encrypt_run_decrypt(200) == 600
|
||||
|
||||
Reference in New Issue
Block a user