feat(frontend): simulate execution using the compiler

This commit is contained in:
youben11
2023-06-22 14:34:12 +01:00
committed by Ayoub Benaissa
parent 648e868ffe
commit 9f54184375
8 changed files with 360 additions and 140 deletions

View File

@@ -7,6 +7,7 @@ Declaration of `Circuit` class.
from typing import Any, List, Optional, Tuple, Union
import numpy as np
from concrete.compiler import SimulatedValueDecrypter, SimulatedValueExporter
from ..internal.utils import assert_that
from ..representation import Graph
@@ -14,6 +15,7 @@ from .client import Client
from .configuration import Configuration
from .keys import Keys
from .server import Server
from .utils import validate_input_args
from .value import Value
# pylint: enable=import-error,no-member,no-name-in-module
@@ -31,6 +33,7 @@ class Circuit:
client: Client
server: Server
simulator: Server
def __init__(self, graph: Graph, mlir: str, configuration: Optional[Configuration] = None):
self.configuration = configuration if configuration is not None else Configuration()
@@ -38,22 +41,36 @@ class Circuit:
self.graph = graph
self.mlir = mlir
self._initialize_client_and_server()
self._initialize_circuit()
def _initialize_client_and_server(self):
self.server = Server.create(self.mlir, self.configuration)
def _initialize_circuit(self):
if self.configuration.fhe_execution:
self.enable_fhe_execution()
keyset_cache_directory = None
if self.configuration.use_insecure_key_cache:
assert_that(self.configuration.enable_unsafe_features)
assert_that(self.configuration.insecure_key_cache_location is not None)
keyset_cache_directory = self.configuration.insecure_key_cache_location
self.client = Client(self.server.client_specs, keyset_cache_directory)
if self.configuration.fhe_simulation:
self.enable_fhe_simulation()
def __str__(self):
return self.graph.format()
def enable_fhe_simulation(self):
"""Enable fhe simulation mode."""
if not hasattr(self, "simulator"):
self.simulator = Server.create(self.mlir, self.configuration, is_simulated=True)
def enable_fhe_execution(self):
"""Enable fhe execution mode."""
if not hasattr(self, "server"):
self.server = Server.create(self.mlir, self.configuration)
keyset_cache_directory = None
if self.configuration.use_insecure_key_cache:
assert_that(self.configuration.enable_unsafe_features)
assert_that(self.configuration.insecure_key_cache_location is not None)
keyset_cache_directory = self.configuration.insecure_key_cache_location
self.client = Client(self.server.client_specs, keyset_cache_directory)
def simulate(self, *args: Any) -> Any:
"""
Simulate execution of the circuit.
@@ -66,8 +83,31 @@ class Circuit:
Any:
result of the simulation
"""
if not hasattr(self, "simulator"):
message = "Simulation isn't enabled. You can call enable_fhe_simulation() to enable it"
raise RuntimeError(message)
return self.graph(*args, p_error=self.p_error)
ordered_validated_args = validate_input_args(self.simulator.client_specs, *args)
exporter = SimulatedValueExporter.new(self.simulator.client_specs.client_parameters)
exported = [
None
if arg is None
else Value(
exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape))
if isinstance(arg, np.ndarray) and arg.shape != ()
else exporter.export_scalar(position, int(arg))
)
for position, arg in enumerate(ordered_validated_args)
]
results = self.simulator.run(*exported)
if not isinstance(results, tuple):
results = (results,)
decrypter = SimulatedValueDecrypter.new(self.simulator.client_specs.client_parameters)
decrypted = tuple(
decrypter.decrypt(position, result.inner) for position, result in enumerate(results)
)
return decrypted if len(decrypted) != 1 else decrypted[0]
@property
def keys(self) -> Keys:
@@ -130,6 +170,11 @@ class Circuit:
Union[Value, Tuple[Value, ...]]:
result(s) of evaluation
"""
if not hasattr(self, "server"):
message = (
"FHE execution isn't enabled. You can call enable_fhe_execution() to enable it"
)
raise RuntimeError(message)
self.keygen(force=False)
return self.server.run(*args, evaluation_keys=self.client.evaluation_keys)

View File

@@ -4,20 +4,17 @@ Declaration of `Client` class.
# pylint: disable=import-error,no-member,no-name-in-module
import json
import shutil
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import numpy as np
from concrete.compiler import EvaluationKeys, ValueDecrypter, ValueExporter
from ..dtypes.integer import SignedInteger, UnsignedInteger
from ..internal.utils import assert_that
from ..values import ValueDescription
from .keys import Keys
from .specs import ClientSpecs
from .utils import validate_input_args
from .value import Value
# pylint: enable=import-error,no-member,no-name-in-module
@@ -133,64 +130,7 @@ class Client:
encrypted argument(s) for evaluation
"""
client_parameters_json = json.loads(self.specs.client_parameters.serialize())
assert_that("inputs" in client_parameters_json)
input_specs = client_parameters_json["inputs"]
if len(args) != len(input_specs):
message = f"Expected {len(input_specs)} inputs but got {len(args)}"
raise ValueError(message)
sanitized_args: Dict[int, Optional[Union[int, np.ndarray]]] = {}
for index, (arg, spec) in enumerate(zip(args, input_specs)):
if arg is None:
sanitized_args[index] = None
continue
if isinstance(arg, list):
arg = np.array(arg)
is_valid = isinstance(arg, (int, np.integer)) or (
isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer)
)
width = spec["shape"]["width"]
is_signed = spec["shape"]["sign"]
shape = tuple(spec["shape"]["dimensions"])
is_encrypted = spec["encryption"] is not None
expected_dtype = SignedInteger(width) if is_signed else UnsignedInteger(width)
expected_value = ValueDescription(expected_dtype, shape, is_encrypted)
if is_valid:
expected_min = expected_dtype.min()
expected_max = expected_dtype.max()
if not is_encrypted:
# clear integers are signless
# (e.g., 8-bit clear integer can be in range -128, 255)
expected_min = -(expected_max // 2) - 1
actual_min = arg if isinstance(arg, int) else arg.min()
actual_max = arg if isinstance(arg, int) else arg.max()
actual_shape = () if isinstance(arg, int) else arg.shape
is_valid = (
actual_min >= expected_min
and actual_max <= expected_max
and actual_shape == expected_value.shape
)
if is_valid:
sanitized_args[index] = arg
if not is_valid:
actual_value = ValueDescription.of(arg, is_encrypted=is_encrypted)
message = (
f"Expected argument {index} to be {expected_value} but it's {actual_value}"
)
raise ValueError(message)
ordered_sanitized_args = [sanitized_args[i] for i in range(len(sanitized_args))]
ordered_sanitized_args = validate_input_args(self.specs, *args)
self.keygen(force=False)
keyset = self.keys._keyset # pylint: disable=protected-access

View File

@@ -434,6 +434,14 @@ class Compiler:
self._evaluate("Compiling", inputset)
assert self.graph is not None
if len(self.graph.output_nodes) > 1:
fmtd_graph = self.graph.format(
highlighted_result=["multiple outputs are not supported"],
show_bounds=False,
)
message = "Function you are trying to compile cannot be compiled\n\n" + fmtd_graph
raise RuntimeError(message)
mlir = GraphConverter().convert(self.graph, self.configuration)
if self.artifacts is not None:
self.artifacts.add_mlir_to_compile(mlir)
@@ -507,9 +515,10 @@ class Compiler:
circuit = Circuit(self.graph, mlir, self.configuration)
client_parameters = circuit.client.specs.client_parameters
if self.artifacts is not None:
self.artifacts.add_client_parameters(client_parameters.serialize())
if hasattr(circuit, "client"):
client_parameters = circuit.client.specs.client_parameters
if self.artifacts is not None:
self.artifacts.add_client_parameters(client_parameters.serialize())
if show_optimizer:
print("-" * columns)
@@ -536,14 +545,6 @@ class Compiler:
self.configuration = old_configuration
self.artifacts = old_artifacts
if self.graph and len(self.graph.output_nodes) > 1:
graph = self.graph.format(
highlighted_result=["multiple outputs are not supported"],
show_bounds=False,
)
message = "Function you are trying to compile cannot be compiled\n\n" + graph
raise RuntimeError(message)
return circuit
# pylint: enable=too-many-branches,too-many-statements

View File

@@ -66,6 +66,8 @@ class Configuration:
show_progress: bool
progress_title: str
progress_tag: Union[bool, int]
fhe_simulation: bool
fhe_execution: bool
def __init__(
self,
@@ -91,6 +93,8 @@ class Configuration:
show_progress: bool = False,
progress_title: str = "",
progress_tag: Union[bool, int] = False,
fhe_simulation: bool = False,
fhe_execution: bool = True,
): # pylint: disable=too-many-arguments
self.verbose = verbose
self.show_graph = show_graph
@@ -114,6 +118,8 @@ class Configuration:
self.show_progress = show_progress
self.progress_title = progress_title
self.progress_tag = progress_tag
self.fhe_simulation = fhe_simulation
self.fhe_execution = fhe_execution
self._validate()
@@ -146,6 +152,8 @@ class Configuration:
show_progress: Union[Keep, bool] = KEEP,
progress_title: Union[Keep, str] = KEEP,
progress_tag: Union[Keep, Union[bool, int]] = KEEP,
fhe_simulation: Union[Keep, bool] = KEEP,
fhe_execution: Union[Keep, bool] = KEEP,
) -> "Configuration":
"""
Get a new configuration from another one specified changes.

View File

@@ -45,6 +45,7 @@ class Server:
"""
client_specs: ClientSpecs
is_simulated: bool
_output_dir: Optional[tempfile.TemporaryDirectory]
_support: Union[JITSupport, LibrarySupport]
@@ -62,8 +63,10 @@ class Server:
support: Union[JITSupport, LibrarySupport],
compilation_result: Union[JITCompilationResult, LibraryCompilationResult],
server_lambda: Union[JITLambda, LibraryLambda],
is_simulated: bool,
):
self.client_specs = client_specs
self.is_simulated = is_simulated
self._output_dir = output_dir
self._support = support
@@ -78,7 +81,7 @@ class Server:
)
@staticmethod
def create(mlir: str, configuration: Configuration) -> "Server":
def create(mlir: str, configuration: Configuration, is_simulated: bool = False) -> "Server":
"""
Create a server using MLIR and output sign information.
@@ -86,11 +89,15 @@ class Server:
mlir (str):
mlir to compile
configuration (Optional[Configuration], default = None):
configuration (Configuration):
configuration to use
is_simulated (bool, default = False):
whether to compile in simulation mode or not
"""
options = CompilationOptions.new("main")
options.simulation(is_simulated)
options.set_loop_parallelize(configuration.loop_parallelize)
options.set_dataflow_parallelize(configuration.dataflow_parallelize)
@@ -142,7 +149,8 @@ class Server:
elif parameter_selection_strategy == ParameterSelectionStrategy.MULTI: # pragma: no cover
options.set_optimizer_strategy(OptimizerStrategy.DAG_MULTI)
if configuration.jit:
if configuration.jit: # pragma: no cover
# JIT to be dropped soon
output_dir = None
support = JITSupport.new()
@@ -164,7 +172,14 @@ class Server:
client_parameters = support.load_client_parameters(compilation_result)
client_specs = ClientSpecs(client_parameters)
result = Server(client_specs, output_dir, support, compilation_result, server_lambda)
result = Server(
client_specs,
output_dir,
support,
compilation_result,
server_lambda,
is_simulated,
)
# pylint: disable=protected-access
result._mlir = mlir
@@ -199,6 +214,9 @@ class Server:
with open(Path(tmp) / "circuit.mlir", "w", encoding="utf-8") as f:
f.write(self._mlir)
with open(Path(tmp) / "is_simulated", "w", encoding="utf-8") as f:
f.write("1" if self.is_simulated else "0")
with open(Path(tmp) / "configuration.json", "w", encoding="utf-8") as f:
f.write(json.dumps(self._configuration.__dict__))
@@ -206,13 +224,17 @@ class Server:
return
if self._output_dir is None:
if self._output_dir is None: # pragma: no cover
# JIT to be dropped soon
message = "Just-in-Time compilation cannot be saved"
raise RuntimeError(message)
with open(Path(self._output_dir.name) / "client.specs.json", "wb") as f:
f.write(self.client_specs.serialize())
with open(Path(self._output_dir.name) / "is_simulated", "w", encoding="utf-8") as f:
f.write("1" if self.is_simulated else "0")
shutil.make_archive(path, "zip", self._output_dir.name)
@staticmethod
@@ -236,6 +258,9 @@ class Server:
shutil.unpack_archive(path, str(output_dir_path), "zip")
with open(output_dir_path / "is_simulated", "r", encoding="utf-8") as f:
is_simulated = f.read() == "1"
if (output_dir_path / "circuit.mlir").exists():
with open(output_dir_path / "circuit.mlir", "r", encoding="utf-8") as f:
mlir = f.read()
@@ -243,7 +268,7 @@ class Server:
with open(output_dir_path / "configuration.json", "r", encoding="utf-8") as f:
configuration = Configuration().fork(**json.load(f))
return Server.create(mlir, configuration)
return Server.create(mlir, configuration, is_simulated)
with open(output_dir_path / "client.specs.json", "rb") as f:
client_specs = ClientSpecs.deserialize(f.read())
@@ -256,12 +281,14 @@ class Server:
compilation_result = support.reload("main")
server_lambda = support.load_server_lambda(compilation_result)
return Server(client_specs, output_dir, support, compilation_result, server_lambda)
return Server(
client_specs, output_dir, support, compilation_result, server_lambda, is_simulated
)
def run(
self,
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
evaluation_keys: EvaluationKeys,
evaluation_keys: Optional[EvaluationKeys] = None,
) -> Union[Value, Tuple[Value, ...]]:
"""
Evaluate.
@@ -270,14 +297,18 @@ class Server:
*args (Optional[Union[Value, Tuple[Optional[Value], ...]]]):
argument(s) for evaluation
evaluation_keys (EvaluationKeys):
evaluation keys
evaluation_keys (Optional[EvaluationKeys], default = None):
evaluation keys required for fhe execution
Returns:
Union[Value, Tuple[Value, ...]]:
result(s) of evaluation
"""
if evaluation_keys is None and not self.is_simulated:
message = "Expected evaluation keys to be provided when not in simulation mode"
raise RuntimeError(message)
flattened_args: List[Optional[Value]] = []
for arg in args:
if isinstance(arg, tuple):
@@ -298,7 +329,17 @@ class Server:
buffers.append(arg.inner)
public_args = PublicArguments.new(self.client_specs.client_parameters, buffers)
public_result = self._support.server_call(self._server_lambda, public_args, evaluation_keys)
if self.is_simulated:
if isinstance(self._support, JITSupport): # pragma: no cover
# JIT to be dropped soon
message = "Can't run simulation while using JIT"
raise RuntimeError(message)
public_result = self._support.simulate(self._server_lambda, public_args)
else:
public_result = self._support.server_call(
self._server_lambda, public_args, evaluation_keys
)
result = tuple(Value(public_result.get_value(i)) for i in range(public_result.n_values()))
return result if len(result) > 1 else result[0]

View File

@@ -2,19 +2,96 @@
Declaration of various functions and constants related to compilation.
"""
import json
import re
from copy import deepcopy
from typing import Dict, Iterable, List, Optional, Set, Tuple
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
import networkx as nx
import numpy as np
from ..dtypes import Float, Integer
from ..dtypes import Float, Integer, SignedInteger, UnsignedInteger
from ..representation import Graph, Node, Operation
from ..values import ValueDescription
from .artifacts import DebugArtifacts
from .specs import ClientSpecs
# ruff: noqa: ERA001
def validate_input_args(
client_specs: ClientSpecs,
*args: Optional[Union[int, np.ndarray, List]],
) -> List[Optional[Union[int, np.ndarray]]]:
"""Validate input arguments.
Args:
client_specs (ClientSpecs):
client specification
*args (Optional[Union[int, np.ndarray, List]]):
argument(s) for evaluation
Returns:
List[Optional[Union[int, np.ndarray]]]: ordered validated args
"""
client_parameters_json = json.loads(client_specs.client_parameters.serialize())
assert "inputs" in client_parameters_json
input_specs = client_parameters_json["inputs"]
if len(args) != len(input_specs):
message = f"Expected {len(input_specs)} inputs but got {len(args)}"
raise ValueError(message)
sanitized_args: Dict[int, Optional[Union[int, np.ndarray]]] = {}
for index, (arg, spec) in enumerate(zip(args, input_specs)):
if arg is None:
sanitized_args[index] = None
continue
if isinstance(arg, list):
arg = np.array(arg)
is_valid = isinstance(arg, (int, np.integer)) or (
isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer)
)
width = spec["shape"]["width"]
is_signed = spec["shape"]["sign"]
shape = tuple(spec["shape"]["dimensions"])
is_encrypted = spec["encryption"] is not None
expected_dtype = SignedInteger(width) if is_signed else UnsignedInteger(width)
expected_value = ValueDescription(expected_dtype, shape, is_encrypted)
if is_valid:
expected_min = expected_dtype.min()
expected_max = expected_dtype.max()
if not is_encrypted:
# clear integers are signless
# (e.g., 8-bit clear integer can be in range -128, 255)
expected_min = -(expected_max // 2) - 1
actual_min = arg if isinstance(arg, int) else arg.min()
actual_max = arg if isinstance(arg, int) else arg.max()
actual_shape = () if isinstance(arg, int) else arg.shape
is_valid = (
actual_min >= expected_min
and actual_max <= expected_max
and actual_shape == expected_value.shape
)
if is_valid:
sanitized_args[index] = arg
if not is_valid:
actual_value = ValueDescription.of(arg, is_encrypted=is_encrypted)
message = f"Expected argument {index} to be {expected_value} but it's {actual_value}"
raise ValueError(message)
ordered_sanitized_args = [sanitized_args[i] for i in range(len(sanitized_args))]
return ordered_sanitized_args
def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
"""
Fuse appropriate subgraphs in a graph to a single Operation.Generic node.

View File

@@ -352,7 +352,7 @@ def test_bad_server_save(helpers):
Test `save` method of `Server` class with bad parameters.
"""
configuration = helpers.configuration()
configuration = helpers.configuration().fork(jit=True)
@compiler({"x": "encrypted"})
def function(x):
@@ -367,43 +367,6 @@ def test_bad_server_save(helpers):
assert str(excinfo.value) == "Just-in-Time compilation cannot be saved"
@pytest.mark.parametrize("p_error", [0.75, 0.5, 0.4, 0.25, 0.2, 0.1, 0.01, 0.001])
@pytest.mark.parametrize("bit_width", [5])
@pytest.mark.parametrize("sample_size", [1_000_000])
@pytest.mark.parametrize("tolerance", [0.1])
def test_p_error_simulation(p_error, bit_width, sample_size, tolerance, helpers):
"""
Test p_error simulation.
"""
configuration = helpers.configuration().fork(global_p_error=None)
table = LookupTable([0] + [x - 1 for x in range(1, 2**bit_width)])
@compiler({"x": "encrypted"})
def function(x):
return table[x + 1]
inputset = [np.random.randint(0, (2**bit_width) - 1, size=(sample_size,)) for _ in range(100)]
circuit = function.compile(inputset, configuration=configuration, p_error=p_error)
assert circuit.p_error < p_error
sample = np.random.randint(0, (2**bit_width) - 1, size=(sample_size,))
output = circuit.simulate(sample)
errors = np.sum(output != sample)
expected_number_of_errors_on_average = sample_size * circuit.p_error
tolerated_difference = expected_number_of_errors_on_average * tolerance
acceptable_number_of_errors = [
round(expected_number_of_errors_on_average - tolerated_difference),
round(expected_number_of_errors_on_average + tolerated_difference),
]
assert acceptable_number_of_errors[0] <= errors <= acceptable_number_of_errors[1]
def test_circuit_run_with_unused_arg(helpers):
"""
Test `encrypt_run_decrypt` method of `Circuit` class with unused arguments.
@@ -444,3 +407,125 @@ def test_dataflow_circuit(helpers):
circuit = f.compile(inputset, configuration)
assert circuit.encrypt_run_decrypt(5, 6) == 28
def test_circuit_sim_disabled(helpers):
"""
Test attempt to simulate without enabling fhe simulation.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)]
circuit = f.compile(inputset, configuration)
with pytest.raises(RuntimeError) as excinfo:
circuit.simulate(*inputset[0])
assert (
str(excinfo.value)
== "Simulation isn't enabled. You can call enable_fhe_simulation() to enable it"
)
def test_circuit_fhe_exec_disabled(helpers):
"""
Test attempt to run fhe execution without it being enabled.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)]
circuit = f.compile(inputset, configuration.fork(fhe_execution=False))
with pytest.raises(RuntimeError) as excinfo:
# as we can't encrypt, we just pass plain inputs, and it should lead to the expected error
circuit.run(*inputset[0], None)
assert (
str(excinfo.value)
== "FHE execution isn't enabled. You can call enable_fhe_execution() to enable it"
)
def test_circuit_fhe_exec_no_eval_keys(helpers):
"""
Test attempt to run fhe execution without eval keys.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)]
circuit = f.compile(inputset, configuration)
with pytest.raises(RuntimeError) as excinfo:
# as we can't encrypt, we just pass plain inputs, and it should lead to the expected error
encrypted_args = inputset[0]
circuit.server.run(*encrypted_args)
assert (
str(excinfo.value) == "Expected evaluation keys to be provided when not in simulation mode"
)
def test_circuit_eval_graph_scalar(helpers):
"""
Test evaluation of the graph.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
lut = LookupTable(list(range(128)))
return lut[x + y]
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)]
circuit = f.compile(inputset, configuration.fork(fhe_simulation=False, fhe_execution=False))
assert f(*inputset[0]) == circuit.graph(*inputset[0], p_error=0.01)
def test_circuit_eval_graph_tensor(helpers):
"""
Test evaluation of the graph.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
lut = LookupTable(list(range(128)))
return lut[x + y]
inputset = [
(
np.random.randint(0, 2**4, size=[2, 2]),
np.random.randint(0, 2**5, size=[2, 2]),
)
for _ in range(2)
]
circuit = f.compile(inputset, configuration.fork(fhe_simulation=False, fhe_execution=False))
assert np.all(f(*inputset[0]) == circuit.graph(*inputset[0], p_error=0.01))
def test_circuit_compile_sim_only(helpers):
"""
Test compiling with simulation only.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
lut = LookupTable(list(range(128)))
return lut[x + y]
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)]
circuit = f.compile(inputset, configuration.fork(fhe_simulation=True, fhe_execution=False))
assert f(*inputset[0]) == circuit.simulate(*inputset[0])

View File

@@ -254,7 +254,6 @@ class Helpers:
function: Callable,
sample: Union[Any, List[Any]],
retries: int = 1,
simulate: bool = False,
):
"""
Assert that `circuit` is behaves the same as `function` on `sample`.
@@ -296,9 +295,7 @@ class Helpers:
for i in range(retries):
expected = sanitize(function(*sample))
actual = sanitize(
circuit.simulate(*sample) if simulate else circuit.encrypt_run_decrypt(*sample)
)
actual = sanitize(circuit.encrypt_run_decrypt(*sample))
if all(np.array_equal(e, a) for e, a in zip(expected, actual)):
break
@@ -317,6 +314,32 @@ Actual Output
"""
raise AssertionError(message)
try:
circuit.enable_fhe_simulation()
except Exception as e: # pylint: disable=broad-exception-caught
print(f"Catched exception while enabling simulation: {e}")
return
for i in range(retries):
expected = sanitize(function(*sample))
actual = sanitize(circuit.simulate(*sample))
if all(np.array_equal(e, a) for e, a in zip(expected, actual)):
break
if i == retries - 1:
message = f"""
Expected Output During Simulation
=================================
{expected}
Actual Output During Simulation
===============================
{actual}
"""
raise AssertionError(message)
@staticmethod
def check_str(expected: str, actual: str):
"""