mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(frontend): simulate execution using the compiler
This commit is contained in:
@@ -7,6 +7,7 @@ Declaration of `Circuit` class.
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import SimulatedValueDecrypter, SimulatedValueExporter
|
||||
|
||||
from ..internal.utils import assert_that
|
||||
from ..representation import Graph
|
||||
@@ -14,6 +15,7 @@ from .client import Client
|
||||
from .configuration import Configuration
|
||||
from .keys import Keys
|
||||
from .server import Server
|
||||
from .utils import validate_input_args
|
||||
from .value import Value
|
||||
|
||||
# pylint: enable=import-error,no-member,no-name-in-module
|
||||
@@ -31,6 +33,7 @@ class Circuit:
|
||||
|
||||
client: Client
|
||||
server: Server
|
||||
simulator: Server
|
||||
|
||||
def __init__(self, graph: Graph, mlir: str, configuration: Optional[Configuration] = None):
|
||||
self.configuration = configuration if configuration is not None else Configuration()
|
||||
@@ -38,22 +41,36 @@ class Circuit:
|
||||
self.graph = graph
|
||||
self.mlir = mlir
|
||||
|
||||
self._initialize_client_and_server()
|
||||
self._initialize_circuit()
|
||||
|
||||
def _initialize_client_and_server(self):
|
||||
self.server = Server.create(self.mlir, self.configuration)
|
||||
def _initialize_circuit(self):
|
||||
if self.configuration.fhe_execution:
|
||||
self.enable_fhe_execution()
|
||||
|
||||
keyset_cache_directory = None
|
||||
if self.configuration.use_insecure_key_cache:
|
||||
assert_that(self.configuration.enable_unsafe_features)
|
||||
assert_that(self.configuration.insecure_key_cache_location is not None)
|
||||
keyset_cache_directory = self.configuration.insecure_key_cache_location
|
||||
|
||||
self.client = Client(self.server.client_specs, keyset_cache_directory)
|
||||
if self.configuration.fhe_simulation:
|
||||
self.enable_fhe_simulation()
|
||||
|
||||
def __str__(self):
|
||||
return self.graph.format()
|
||||
|
||||
def enable_fhe_simulation(self):
|
||||
"""Enable fhe simulation mode."""
|
||||
if not hasattr(self, "simulator"):
|
||||
self.simulator = Server.create(self.mlir, self.configuration, is_simulated=True)
|
||||
|
||||
def enable_fhe_execution(self):
|
||||
"""Enable fhe execution mode."""
|
||||
if not hasattr(self, "server"):
|
||||
self.server = Server.create(self.mlir, self.configuration)
|
||||
|
||||
keyset_cache_directory = None
|
||||
if self.configuration.use_insecure_key_cache:
|
||||
assert_that(self.configuration.enable_unsafe_features)
|
||||
assert_that(self.configuration.insecure_key_cache_location is not None)
|
||||
keyset_cache_directory = self.configuration.insecure_key_cache_location
|
||||
|
||||
self.client = Client(self.server.client_specs, keyset_cache_directory)
|
||||
|
||||
def simulate(self, *args: Any) -> Any:
|
||||
"""
|
||||
Simulate execution of the circuit.
|
||||
@@ -66,8 +83,31 @@ class Circuit:
|
||||
Any:
|
||||
result of the simulation
|
||||
"""
|
||||
if not hasattr(self, "simulator"):
|
||||
message = "Simulation isn't enabled. You can call enable_fhe_simulation() to enable it"
|
||||
raise RuntimeError(message)
|
||||
|
||||
return self.graph(*args, p_error=self.p_error)
|
||||
ordered_validated_args = validate_input_args(self.simulator.client_specs, *args)
|
||||
|
||||
exporter = SimulatedValueExporter.new(self.simulator.client_specs.client_parameters)
|
||||
exported = [
|
||||
None
|
||||
if arg is None
|
||||
else Value(
|
||||
exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape))
|
||||
if isinstance(arg, np.ndarray) and arg.shape != ()
|
||||
else exporter.export_scalar(position, int(arg))
|
||||
)
|
||||
for position, arg in enumerate(ordered_validated_args)
|
||||
]
|
||||
results = self.simulator.run(*exported)
|
||||
if not isinstance(results, tuple):
|
||||
results = (results,)
|
||||
decrypter = SimulatedValueDecrypter.new(self.simulator.client_specs.client_parameters)
|
||||
decrypted = tuple(
|
||||
decrypter.decrypt(position, result.inner) for position, result in enumerate(results)
|
||||
)
|
||||
return decrypted if len(decrypted) != 1 else decrypted[0]
|
||||
|
||||
@property
|
||||
def keys(self) -> Keys:
|
||||
@@ -130,6 +170,11 @@ class Circuit:
|
||||
Union[Value, Tuple[Value, ...]]:
|
||||
result(s) of evaluation
|
||||
"""
|
||||
if not hasattr(self, "server"):
|
||||
message = (
|
||||
"FHE execution isn't enabled. You can call enable_fhe_execution() to enable it"
|
||||
)
|
||||
raise RuntimeError(message)
|
||||
|
||||
self.keygen(force=False)
|
||||
return self.server.run(*args, evaluation_keys=self.client.evaluation_keys)
|
||||
|
||||
@@ -4,20 +4,17 @@ Declaration of `Client` class.
|
||||
|
||||
# pylint: disable=import-error,no-member,no-name-in-module
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import EvaluationKeys, ValueDecrypter, ValueExporter
|
||||
|
||||
from ..dtypes.integer import SignedInteger, UnsignedInteger
|
||||
from ..internal.utils import assert_that
|
||||
from ..values import ValueDescription
|
||||
from .keys import Keys
|
||||
from .specs import ClientSpecs
|
||||
from .utils import validate_input_args
|
||||
from .value import Value
|
||||
|
||||
# pylint: enable=import-error,no-member,no-name-in-module
|
||||
@@ -133,64 +130,7 @@ class Client:
|
||||
encrypted argument(s) for evaluation
|
||||
"""
|
||||
|
||||
client_parameters_json = json.loads(self.specs.client_parameters.serialize())
|
||||
assert_that("inputs" in client_parameters_json)
|
||||
input_specs = client_parameters_json["inputs"]
|
||||
|
||||
if len(args) != len(input_specs):
|
||||
message = f"Expected {len(input_specs)} inputs but got {len(args)}"
|
||||
raise ValueError(message)
|
||||
|
||||
sanitized_args: Dict[int, Optional[Union[int, np.ndarray]]] = {}
|
||||
for index, (arg, spec) in enumerate(zip(args, input_specs)):
|
||||
if arg is None:
|
||||
sanitized_args[index] = None
|
||||
continue
|
||||
|
||||
if isinstance(arg, list):
|
||||
arg = np.array(arg)
|
||||
|
||||
is_valid = isinstance(arg, (int, np.integer)) or (
|
||||
isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer)
|
||||
)
|
||||
|
||||
width = spec["shape"]["width"]
|
||||
is_signed = spec["shape"]["sign"]
|
||||
shape = tuple(spec["shape"]["dimensions"])
|
||||
is_encrypted = spec["encryption"] is not None
|
||||
|
||||
expected_dtype = SignedInteger(width) if is_signed else UnsignedInteger(width)
|
||||
expected_value = ValueDescription(expected_dtype, shape, is_encrypted)
|
||||
if is_valid:
|
||||
expected_min = expected_dtype.min()
|
||||
expected_max = expected_dtype.max()
|
||||
|
||||
if not is_encrypted:
|
||||
# clear integers are signless
|
||||
# (e.g., 8-bit clear integer can be in range -128, 255)
|
||||
expected_min = -(expected_max // 2) - 1
|
||||
|
||||
actual_min = arg if isinstance(arg, int) else arg.min()
|
||||
actual_max = arg if isinstance(arg, int) else arg.max()
|
||||
actual_shape = () if isinstance(arg, int) else arg.shape
|
||||
|
||||
is_valid = (
|
||||
actual_min >= expected_min
|
||||
and actual_max <= expected_max
|
||||
and actual_shape == expected_value.shape
|
||||
)
|
||||
|
||||
if is_valid:
|
||||
sanitized_args[index] = arg
|
||||
|
||||
if not is_valid:
|
||||
actual_value = ValueDescription.of(arg, is_encrypted=is_encrypted)
|
||||
message = (
|
||||
f"Expected argument {index} to be {expected_value} but it's {actual_value}"
|
||||
)
|
||||
raise ValueError(message)
|
||||
|
||||
ordered_sanitized_args = [sanitized_args[i] for i in range(len(sanitized_args))]
|
||||
ordered_sanitized_args = validate_input_args(self.specs, *args)
|
||||
|
||||
self.keygen(force=False)
|
||||
keyset = self.keys._keyset # pylint: disable=protected-access
|
||||
|
||||
@@ -434,6 +434,14 @@ class Compiler:
|
||||
self._evaluate("Compiling", inputset)
|
||||
assert self.graph is not None
|
||||
|
||||
if len(self.graph.output_nodes) > 1:
|
||||
fmtd_graph = self.graph.format(
|
||||
highlighted_result=["multiple outputs are not supported"],
|
||||
show_bounds=False,
|
||||
)
|
||||
message = "Function you are trying to compile cannot be compiled\n\n" + fmtd_graph
|
||||
raise RuntimeError(message)
|
||||
|
||||
mlir = GraphConverter().convert(self.graph, self.configuration)
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_mlir_to_compile(mlir)
|
||||
@@ -507,9 +515,10 @@ class Compiler:
|
||||
|
||||
circuit = Circuit(self.graph, mlir, self.configuration)
|
||||
|
||||
client_parameters = circuit.client.specs.client_parameters
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_client_parameters(client_parameters.serialize())
|
||||
if hasattr(circuit, "client"):
|
||||
client_parameters = circuit.client.specs.client_parameters
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_client_parameters(client_parameters.serialize())
|
||||
|
||||
if show_optimizer:
|
||||
print("-" * columns)
|
||||
@@ -536,14 +545,6 @@ class Compiler:
|
||||
self.configuration = old_configuration
|
||||
self.artifacts = old_artifacts
|
||||
|
||||
if self.graph and len(self.graph.output_nodes) > 1:
|
||||
graph = self.graph.format(
|
||||
highlighted_result=["multiple outputs are not supported"],
|
||||
show_bounds=False,
|
||||
)
|
||||
message = "Function you are trying to compile cannot be compiled\n\n" + graph
|
||||
raise RuntimeError(message)
|
||||
|
||||
return circuit
|
||||
|
||||
# pylint: enable=too-many-branches,too-many-statements
|
||||
|
||||
@@ -66,6 +66,8 @@ class Configuration:
|
||||
show_progress: bool
|
||||
progress_title: str
|
||||
progress_tag: Union[bool, int]
|
||||
fhe_simulation: bool
|
||||
fhe_execution: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -91,6 +93,8 @@ class Configuration:
|
||||
show_progress: bool = False,
|
||||
progress_title: str = "",
|
||||
progress_tag: Union[bool, int] = False,
|
||||
fhe_simulation: bool = False,
|
||||
fhe_execution: bool = True,
|
||||
): # pylint: disable=too-many-arguments
|
||||
self.verbose = verbose
|
||||
self.show_graph = show_graph
|
||||
@@ -114,6 +118,8 @@ class Configuration:
|
||||
self.show_progress = show_progress
|
||||
self.progress_title = progress_title
|
||||
self.progress_tag = progress_tag
|
||||
self.fhe_simulation = fhe_simulation
|
||||
self.fhe_execution = fhe_execution
|
||||
|
||||
self._validate()
|
||||
|
||||
@@ -146,6 +152,8 @@ class Configuration:
|
||||
show_progress: Union[Keep, bool] = KEEP,
|
||||
progress_title: Union[Keep, str] = KEEP,
|
||||
progress_tag: Union[Keep, Union[bool, int]] = KEEP,
|
||||
fhe_simulation: Union[Keep, bool] = KEEP,
|
||||
fhe_execution: Union[Keep, bool] = KEEP,
|
||||
) -> "Configuration":
|
||||
"""
|
||||
Get a new configuration from another one specified changes.
|
||||
|
||||
@@ -45,6 +45,7 @@ class Server:
|
||||
"""
|
||||
|
||||
client_specs: ClientSpecs
|
||||
is_simulated: bool
|
||||
|
||||
_output_dir: Optional[tempfile.TemporaryDirectory]
|
||||
_support: Union[JITSupport, LibrarySupport]
|
||||
@@ -62,8 +63,10 @@ class Server:
|
||||
support: Union[JITSupport, LibrarySupport],
|
||||
compilation_result: Union[JITCompilationResult, LibraryCompilationResult],
|
||||
server_lambda: Union[JITLambda, LibraryLambda],
|
||||
is_simulated: bool,
|
||||
):
|
||||
self.client_specs = client_specs
|
||||
self.is_simulated = is_simulated
|
||||
|
||||
self._output_dir = output_dir
|
||||
self._support = support
|
||||
@@ -78,7 +81,7 @@ class Server:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create(mlir: str, configuration: Configuration) -> "Server":
|
||||
def create(mlir: str, configuration: Configuration, is_simulated: bool = False) -> "Server":
|
||||
"""
|
||||
Create a server using MLIR and output sign information.
|
||||
|
||||
@@ -86,11 +89,15 @@ class Server:
|
||||
mlir (str):
|
||||
mlir to compile
|
||||
|
||||
configuration (Optional[Configuration], default = None):
|
||||
configuration (Configuration):
|
||||
configuration to use
|
||||
|
||||
is_simulated (bool, default = False):
|
||||
whether to compile in simulation mode or not
|
||||
"""
|
||||
|
||||
options = CompilationOptions.new("main")
|
||||
options.simulation(is_simulated)
|
||||
|
||||
options.set_loop_parallelize(configuration.loop_parallelize)
|
||||
options.set_dataflow_parallelize(configuration.dataflow_parallelize)
|
||||
@@ -142,7 +149,8 @@ class Server:
|
||||
elif parameter_selection_strategy == ParameterSelectionStrategy.MULTI: # pragma: no cover
|
||||
options.set_optimizer_strategy(OptimizerStrategy.DAG_MULTI)
|
||||
|
||||
if configuration.jit:
|
||||
if configuration.jit: # pragma: no cover
|
||||
# JIT to be dropped soon
|
||||
output_dir = None
|
||||
|
||||
support = JITSupport.new()
|
||||
@@ -164,7 +172,14 @@ class Server:
|
||||
client_parameters = support.load_client_parameters(compilation_result)
|
||||
client_specs = ClientSpecs(client_parameters)
|
||||
|
||||
result = Server(client_specs, output_dir, support, compilation_result, server_lambda)
|
||||
result = Server(
|
||||
client_specs,
|
||||
output_dir,
|
||||
support,
|
||||
compilation_result,
|
||||
server_lambda,
|
||||
is_simulated,
|
||||
)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
result._mlir = mlir
|
||||
@@ -199,6 +214,9 @@ class Server:
|
||||
with open(Path(tmp) / "circuit.mlir", "w", encoding="utf-8") as f:
|
||||
f.write(self._mlir)
|
||||
|
||||
with open(Path(tmp) / "is_simulated", "w", encoding="utf-8") as f:
|
||||
f.write("1" if self.is_simulated else "0")
|
||||
|
||||
with open(Path(tmp) / "configuration.json", "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(self._configuration.__dict__))
|
||||
|
||||
@@ -206,13 +224,17 @@ class Server:
|
||||
|
||||
return
|
||||
|
||||
if self._output_dir is None:
|
||||
if self._output_dir is None: # pragma: no cover
|
||||
# JIT to be dropped soon
|
||||
message = "Just-in-Time compilation cannot be saved"
|
||||
raise RuntimeError(message)
|
||||
|
||||
with open(Path(self._output_dir.name) / "client.specs.json", "wb") as f:
|
||||
f.write(self.client_specs.serialize())
|
||||
|
||||
with open(Path(self._output_dir.name) / "is_simulated", "w", encoding="utf-8") as f:
|
||||
f.write("1" if self.is_simulated else "0")
|
||||
|
||||
shutil.make_archive(path, "zip", self._output_dir.name)
|
||||
|
||||
@staticmethod
|
||||
@@ -236,6 +258,9 @@ class Server:
|
||||
|
||||
shutil.unpack_archive(path, str(output_dir_path), "zip")
|
||||
|
||||
with open(output_dir_path / "is_simulated", "r", encoding="utf-8") as f:
|
||||
is_simulated = f.read() == "1"
|
||||
|
||||
if (output_dir_path / "circuit.mlir").exists():
|
||||
with open(output_dir_path / "circuit.mlir", "r", encoding="utf-8") as f:
|
||||
mlir = f.read()
|
||||
@@ -243,7 +268,7 @@ class Server:
|
||||
with open(output_dir_path / "configuration.json", "r", encoding="utf-8") as f:
|
||||
configuration = Configuration().fork(**json.load(f))
|
||||
|
||||
return Server.create(mlir, configuration)
|
||||
return Server.create(mlir, configuration, is_simulated)
|
||||
|
||||
with open(output_dir_path / "client.specs.json", "rb") as f:
|
||||
client_specs = ClientSpecs.deserialize(f.read())
|
||||
@@ -256,12 +281,14 @@ class Server:
|
||||
compilation_result = support.reload("main")
|
||||
server_lambda = support.load_server_lambda(compilation_result)
|
||||
|
||||
return Server(client_specs, output_dir, support, compilation_result, server_lambda)
|
||||
return Server(
|
||||
client_specs, output_dir, support, compilation_result, server_lambda, is_simulated
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
|
||||
evaluation_keys: EvaluationKeys,
|
||||
evaluation_keys: Optional[EvaluationKeys] = None,
|
||||
) -> Union[Value, Tuple[Value, ...]]:
|
||||
"""
|
||||
Evaluate.
|
||||
@@ -270,14 +297,18 @@ class Server:
|
||||
*args (Optional[Union[Value, Tuple[Optional[Value], ...]]]):
|
||||
argument(s) for evaluation
|
||||
|
||||
evaluation_keys (EvaluationKeys):
|
||||
evaluation keys
|
||||
evaluation_keys (Optional[EvaluationKeys], default = None):
|
||||
evaluation keys required for fhe execution
|
||||
|
||||
Returns:
|
||||
Union[Value, Tuple[Value, ...]]:
|
||||
result(s) of evaluation
|
||||
"""
|
||||
|
||||
if evaluation_keys is None and not self.is_simulated:
|
||||
message = "Expected evaluation keys to be provided when not in simulation mode"
|
||||
raise RuntimeError(message)
|
||||
|
||||
flattened_args: List[Optional[Value]] = []
|
||||
for arg in args:
|
||||
if isinstance(arg, tuple):
|
||||
@@ -298,7 +329,17 @@ class Server:
|
||||
buffers.append(arg.inner)
|
||||
|
||||
public_args = PublicArguments.new(self.client_specs.client_parameters, buffers)
|
||||
public_result = self._support.server_call(self._server_lambda, public_args, evaluation_keys)
|
||||
|
||||
if self.is_simulated:
|
||||
if isinstance(self._support, JITSupport): # pragma: no cover
|
||||
# JIT to be dropped soon
|
||||
message = "Can't run simulation while using JIT"
|
||||
raise RuntimeError(message)
|
||||
public_result = self._support.simulate(self._server_lambda, public_args)
|
||||
else:
|
||||
public_result = self._support.server_call(
|
||||
self._server_lambda, public_args, evaluation_keys
|
||||
)
|
||||
|
||||
result = tuple(Value(public_result.get_value(i)) for i in range(public_result.n_values()))
|
||||
return result if len(result) > 1 else result[0]
|
||||
|
||||
@@ -2,19 +2,96 @@
|
||||
Declaration of various functions and constants related to compilation.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
from ..dtypes import Float, Integer
|
||||
from ..dtypes import Float, Integer, SignedInteger, UnsignedInteger
|
||||
from ..representation import Graph, Node, Operation
|
||||
from ..values import ValueDescription
|
||||
from .artifacts import DebugArtifacts
|
||||
from .specs import ClientSpecs
|
||||
|
||||
# ruff: noqa: ERA001
|
||||
|
||||
|
||||
def validate_input_args(
|
||||
client_specs: ClientSpecs,
|
||||
*args: Optional[Union[int, np.ndarray, List]],
|
||||
) -> List[Optional[Union[int, np.ndarray]]]:
|
||||
"""Validate input arguments.
|
||||
|
||||
Args:
|
||||
client_specs (ClientSpecs):
|
||||
client specification
|
||||
*args (Optional[Union[int, np.ndarray, List]]):
|
||||
argument(s) for evaluation
|
||||
|
||||
Returns:
|
||||
List[Optional[Union[int, np.ndarray]]]: ordered validated args
|
||||
"""
|
||||
client_parameters_json = json.loads(client_specs.client_parameters.serialize())
|
||||
assert "inputs" in client_parameters_json
|
||||
input_specs = client_parameters_json["inputs"]
|
||||
if len(args) != len(input_specs):
|
||||
message = f"Expected {len(input_specs)} inputs but got {len(args)}"
|
||||
raise ValueError(message)
|
||||
|
||||
sanitized_args: Dict[int, Optional[Union[int, np.ndarray]]] = {}
|
||||
for index, (arg, spec) in enumerate(zip(args, input_specs)):
|
||||
if arg is None:
|
||||
sanitized_args[index] = None
|
||||
continue
|
||||
|
||||
if isinstance(arg, list):
|
||||
arg = np.array(arg)
|
||||
|
||||
is_valid = isinstance(arg, (int, np.integer)) or (
|
||||
isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer)
|
||||
)
|
||||
|
||||
width = spec["shape"]["width"]
|
||||
is_signed = spec["shape"]["sign"]
|
||||
shape = tuple(spec["shape"]["dimensions"])
|
||||
is_encrypted = spec["encryption"] is not None
|
||||
|
||||
expected_dtype = SignedInteger(width) if is_signed else UnsignedInteger(width)
|
||||
expected_value = ValueDescription(expected_dtype, shape, is_encrypted)
|
||||
if is_valid:
|
||||
expected_min = expected_dtype.min()
|
||||
expected_max = expected_dtype.max()
|
||||
|
||||
if not is_encrypted:
|
||||
# clear integers are signless
|
||||
# (e.g., 8-bit clear integer can be in range -128, 255)
|
||||
expected_min = -(expected_max // 2) - 1
|
||||
|
||||
actual_min = arg if isinstance(arg, int) else arg.min()
|
||||
actual_max = arg if isinstance(arg, int) else arg.max()
|
||||
actual_shape = () if isinstance(arg, int) else arg.shape
|
||||
|
||||
is_valid = (
|
||||
actual_min >= expected_min
|
||||
and actual_max <= expected_max
|
||||
and actual_shape == expected_value.shape
|
||||
)
|
||||
|
||||
if is_valid:
|
||||
sanitized_args[index] = arg
|
||||
|
||||
if not is_valid:
|
||||
actual_value = ValueDescription.of(arg, is_encrypted=is_encrypted)
|
||||
message = f"Expected argument {index} to be {expected_value} but it's {actual_value}"
|
||||
raise ValueError(message)
|
||||
|
||||
ordered_sanitized_args = [sanitized_args[i] for i in range(len(sanitized_args))]
|
||||
return ordered_sanitized_args
|
||||
|
||||
|
||||
def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
|
||||
"""
|
||||
Fuse appropriate subgraphs in a graph to a single Operation.Generic node.
|
||||
|
||||
@@ -352,7 +352,7 @@ def test_bad_server_save(helpers):
|
||||
Test `save` method of `Server` class with bad parameters.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
configuration = helpers.configuration().fork(jit=True)
|
||||
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
@@ -367,43 +367,6 @@ def test_bad_server_save(helpers):
|
||||
assert str(excinfo.value) == "Just-in-Time compilation cannot be saved"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("p_error", [0.75, 0.5, 0.4, 0.25, 0.2, 0.1, 0.01, 0.001])
|
||||
@pytest.mark.parametrize("bit_width", [5])
|
||||
@pytest.mark.parametrize("sample_size", [1_000_000])
|
||||
@pytest.mark.parametrize("tolerance", [0.1])
|
||||
def test_p_error_simulation(p_error, bit_width, sample_size, tolerance, helpers):
|
||||
"""
|
||||
Test p_error simulation.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration().fork(global_p_error=None)
|
||||
|
||||
table = LookupTable([0] + [x - 1 for x in range(1, 2**bit_width)])
|
||||
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return table[x + 1]
|
||||
|
||||
inputset = [np.random.randint(0, (2**bit_width) - 1, size=(sample_size,)) for _ in range(100)]
|
||||
circuit = function.compile(inputset, configuration=configuration, p_error=p_error)
|
||||
|
||||
assert circuit.p_error < p_error
|
||||
|
||||
sample = np.random.randint(0, (2**bit_width) - 1, size=(sample_size,))
|
||||
output = circuit.simulate(sample)
|
||||
|
||||
errors = np.sum(output != sample)
|
||||
|
||||
expected_number_of_errors_on_average = sample_size * circuit.p_error
|
||||
tolerated_difference = expected_number_of_errors_on_average * tolerance
|
||||
|
||||
acceptable_number_of_errors = [
|
||||
round(expected_number_of_errors_on_average - tolerated_difference),
|
||||
round(expected_number_of_errors_on_average + tolerated_difference),
|
||||
]
|
||||
assert acceptable_number_of_errors[0] <= errors <= acceptable_number_of_errors[1]
|
||||
|
||||
|
||||
def test_circuit_run_with_unused_arg(helpers):
|
||||
"""
|
||||
Test `encrypt_run_decrypt` method of `Circuit` class with unused arguments.
|
||||
@@ -444,3 +407,125 @@ def test_dataflow_circuit(helpers):
|
||||
circuit = f.compile(inputset, configuration)
|
||||
|
||||
assert circuit.encrypt_run_decrypt(5, 6) == 28
|
||||
|
||||
|
||||
def test_circuit_sim_disabled(helpers):
|
||||
"""
|
||||
Test attempt to simulate without enabling fhe simulation.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)]
|
||||
circuit = f.compile(inputset, configuration)
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
circuit.simulate(*inputset[0])
|
||||
assert (
|
||||
str(excinfo.value)
|
||||
== "Simulation isn't enabled. You can call enable_fhe_simulation() to enable it"
|
||||
)
|
||||
|
||||
|
||||
def test_circuit_fhe_exec_disabled(helpers):
|
||||
"""
|
||||
Test attempt to run fhe execution without it being enabled.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)]
|
||||
circuit = f.compile(inputset, configuration.fork(fhe_execution=False))
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
# as we can't encrypt, we just pass plain inputs, and it should lead to the expected error
|
||||
circuit.run(*inputset[0], None)
|
||||
assert (
|
||||
str(excinfo.value)
|
||||
== "FHE execution isn't enabled. You can call enable_fhe_execution() to enable it"
|
||||
)
|
||||
|
||||
|
||||
def test_circuit_fhe_exec_no_eval_keys(helpers):
|
||||
"""
|
||||
Test attempt to run fhe execution without eval keys.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)]
|
||||
circuit = f.compile(inputset, configuration)
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
# as we can't encrypt, we just pass plain inputs, and it should lead to the expected error
|
||||
encrypted_args = inputset[0]
|
||||
circuit.server.run(*encrypted_args)
|
||||
assert (
|
||||
str(excinfo.value) == "Expected evaluation keys to be provided when not in simulation mode"
|
||||
)
|
||||
|
||||
|
||||
def test_circuit_eval_graph_scalar(helpers):
|
||||
"""
|
||||
Test evaluation of the graph.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def f(x, y):
|
||||
lut = LookupTable(list(range(128)))
|
||||
return lut[x + y]
|
||||
|
||||
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)]
|
||||
circuit = f.compile(inputset, configuration.fork(fhe_simulation=False, fhe_execution=False))
|
||||
assert f(*inputset[0]) == circuit.graph(*inputset[0], p_error=0.01)
|
||||
|
||||
|
||||
def test_circuit_eval_graph_tensor(helpers):
|
||||
"""
|
||||
Test evaluation of the graph.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def f(x, y):
|
||||
lut = LookupTable(list(range(128)))
|
||||
return lut[x + y]
|
||||
|
||||
inputset = [
|
||||
(
|
||||
np.random.randint(0, 2**4, size=[2, 2]),
|
||||
np.random.randint(0, 2**5, size=[2, 2]),
|
||||
)
|
||||
for _ in range(2)
|
||||
]
|
||||
circuit = f.compile(inputset, configuration.fork(fhe_simulation=False, fhe_execution=False))
|
||||
assert np.all(f(*inputset[0]) == circuit.graph(*inputset[0], p_error=0.01))
|
||||
|
||||
|
||||
def test_circuit_compile_sim_only(helpers):
|
||||
"""
|
||||
Test compiling with simulation only.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def f(x, y):
|
||||
lut = LookupTable(list(range(128)))
|
||||
return lut[x + y]
|
||||
|
||||
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)]
|
||||
circuit = f.compile(inputset, configuration.fork(fhe_simulation=True, fhe_execution=False))
|
||||
assert f(*inputset[0]) == circuit.simulate(*inputset[0])
|
||||
|
||||
@@ -254,7 +254,6 @@ class Helpers:
|
||||
function: Callable,
|
||||
sample: Union[Any, List[Any]],
|
||||
retries: int = 1,
|
||||
simulate: bool = False,
|
||||
):
|
||||
"""
|
||||
Assert that `circuit` is behaves the same as `function` on `sample`.
|
||||
@@ -296,9 +295,7 @@ class Helpers:
|
||||
|
||||
for i in range(retries):
|
||||
expected = sanitize(function(*sample))
|
||||
actual = sanitize(
|
||||
circuit.simulate(*sample) if simulate else circuit.encrypt_run_decrypt(*sample)
|
||||
)
|
||||
actual = sanitize(circuit.encrypt_run_decrypt(*sample))
|
||||
|
||||
if all(np.array_equal(e, a) for e, a in zip(expected, actual)):
|
||||
break
|
||||
@@ -317,6 +314,32 @@ Actual Output
|
||||
"""
|
||||
raise AssertionError(message)
|
||||
|
||||
try:
|
||||
circuit.enable_fhe_simulation()
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
print(f"Catched exception while enabling simulation: {e}")
|
||||
return
|
||||
for i in range(retries):
|
||||
expected = sanitize(function(*sample))
|
||||
actual = sanitize(circuit.simulate(*sample))
|
||||
|
||||
if all(np.array_equal(e, a) for e, a in zip(expected, actual)):
|
||||
break
|
||||
|
||||
if i == retries - 1:
|
||||
message = f"""
|
||||
|
||||
Expected Output During Simulation
|
||||
=================================
|
||||
{expected}
|
||||
|
||||
Actual Output During Simulation
|
||||
===============================
|
||||
{actual}
|
||||
|
||||
"""
|
||||
raise AssertionError(message)
|
||||
|
||||
@staticmethod
|
||||
def check_str(expected: str, actual: str):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user