Files
concrete/tests/conftest.py

302 lines
8.4 KiB
Python

"""
Configuration of `pytest`.
"""
import json
import random
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple, Union
import numpy as np
import pytest
import concrete.numpy as cnp
from concrete.numpy.compilation import configuration as configuration_
def pytest_addoption(parser):
    """
    Register the custom command-line options understood by this test suite.

    Both options are plain string options that default to ``None``.

    Args:
        parser:
            the option parser pytest passes to this hook
    """

    custom_options = (
        ("--global-coverage", "JSON file to dump pytest-cov terminal report."),
        ("--key-cache", "Specify the location of the key cache"),
    )

    for flag, description in custom_options:
        parser.addoption(
            flag,
            action="store",
            default=None,
            type=str,
            help=description,
        )
def pytest_sessionstart(session):
    """
    Prepare the insecure key cache directory at the start of the test session.

    The ``--key-cache`` option chooses the directory. The special value
    ``disable`` (any letter case) turns the cache off. Without the option,
    ``~/.cache/concrete-numpy/pytest`` is used. When enabled, the directory is
    created if missing, its location is printed, and the location is stored in
    ``configuration_._INSECURE_KEY_CACHE_LOCATION``.

    Args:
        session:
            the pytest session object passed to this hook
    """

    requested = session.config.getoption("--key-cache", default=None)

    if requested is None:
        cache_dir = Path.home().resolve() / ".cache" / "concrete-numpy" / "pytest"
    elif requested.lower() == "disable":
        cache_dir = None
    else:
        cache_dir = Path(requested).expanduser().resolve()

    # a disabled cache leaves the configuration untouched
    if not cache_dir:
        return

    cache_dir.mkdir(parents=True, exist_ok=True)
    print(f"INSECURE_KEY_CACHE_LOCATION={str(cache_dir)}")

    # pylint: disable=protected-access
    configuration_._INSECURE_KEY_CACHE_LOCATION = str(cache_dir)
    # pylint: enable=protected-access
def pytest_sessionfinish(session, exitstatus):  # pylint: disable=unused-argument
    """
    Write the pytest-cov terminal report to the ``--global-coverage`` JSON file.

    pytest-cov has no option to export its terminal report to a file, so this
    reads the plugin's internals directly. Nothing happens unless the ``_cov``
    plugin is active and ``--global-coverage`` was given.

    The written JSON object has two keys:

    - ``exit_code``: 1 when ``cov_fail_under`` is set, positive, and the total
      coverage is below it; otherwise 0.
    - ``content``: the terminal report text.

    Args:
        session:
            the pytest session object passed to this hook
        exitstatus:
            unused
    """

    # Hacked together from the source code, they don't have an option to export to file,
    # and it's too much work to get a PR in for such a little thing.
    # https://github.com/pytest-dev/pytest-cov/blob/ec344d8adf2d78238d8f07cb20ed2463d7536970/src/pytest_cov/plugin.py#L329

    plugin_manager = session.config.pluginmanager
    if not plugin_manager.hasplugin("_cov"):
        return

    output_option = session.config.getoption("--global-coverage", default=None)
    if output_option is None:
        return

    cov = plugin_manager.getplugin("_cov")
    report_text = cov.cov_report.getvalue()

    threshold = cov.options.cov_fail_under
    below_threshold = threshold is not None and threshold > 0 and cov.cov_total < threshold

    output_path = Path(output_option).resolve()
    with open(output_path, "w", encoding="utf-8") as output_file:
        json.dump({"exit_code": 1 if below_threshold else 0, "content": report_text}, output_file)
class Helpers:
    """
    Helpers class, which provides various helpers to tests.
    """

    @staticmethod
    def configuration() -> cnp.Configuration:
        """
        Get the test configuration to use during testing.

        Returns:
            cnp.Configuration:
                test configuration
        """

        return cnp.Configuration(
            dump_artifacts_on_unexpected_failures=False,
            enable_unsafe_features=True,
            use_insecure_key_cache=True,
            loop_parallelize=True,
            dataflow_parallelize=False,
            auto_parallelize=False,
        )

    @staticmethod
    def generate_encryption_statuses(parameters: Dict[str, Dict[str, Any]]) -> Dict[str, str]:
        """
        Generate parameter encryption statuses according to a parameter specification.

        Parameters without an explicit ``"status"`` default to ``"encrypted"``.

        Args:
            parameters (Dict[str, Dict[str, Any]]):
                parameter specification to use

                e.g.,

                {
                    "param1": {"range": [0, 10], "status": "clear"},
                    "param2": {"range": [3, 30], "status": "encrypted", "shape": (3,)},
                }

        Returns:
            Dict[str, str]:
                parameter encryption statuses
                generated according to the given parameter specification
        """

        return {
            parameter: details.get("status", "encrypted")
            for parameter, details in parameters.items()
        }

    @staticmethod
    def generate_inputset(
        parameters: Dict[str, Dict[str, Any]],
        size: int = 128,
    ) -> List[Union[Tuple[Union[int, np.ndarray], ...], Union[int, np.ndarray]]]:
        """
        Generate a random inputset of desired size according to a parameter specification.

        Args:
            parameters (Dict[str, Dict[str, Any]]):
                parameter specification to use

                e.g.,

                {
                    "param1": {"range": [0, 10], "status": "clear"},
                    "param2": {"range": [3, 30], "status": "encrypted", "shape": (3,)},
                }

            size (int):
                size of the resulting inputset

        Returns:
            List[Union[Tuple[Union[int, np.ndarray], ...], Union[int, np.ndarray]]]:
                random inputset of desired size
                generated according to the given parameter specification
        """

        inputset = []
        for _ in range(size):
            sample = Helpers.generate_sample(parameters)
            # multi-parameter samples are stored as tuples, single-parameter ones unwrapped
            inputset.append(tuple(sample) if len(sample) > 1 else sample[0])
        return inputset

    @staticmethod
    def generate_sample(parameters: Dict[str, Dict[str, Any]]) -> List[Union[int, np.ndarray]]:
        """
        Generate a random sample according to a parameter specification.

        Each parameter's ``"range"`` is inclusive on both ends, defaults to
        ``[0, 127]`` and must stay within ``[0, 127]``. Parameters with a
        ``"shape"`` produce an ``np.int64`` array of that shape. Other
        parameters produce a single ``np.int64`` scalar.

        Args:
            parameters (Dict[str, Dict[str, Any]]):
                parameter specification to use

                e.g.,

                {
                    "param1": {"range": [0, 10], "status": "clear"},
                    "param2": {"range": [3, 30], "status": "encrypted", "shape": (3,)},
                }

        Returns:
            List[Union[int, np.ndarray]]:
                random sample
                generated according to the given parameter specification
        """

        sample = []
        for description in parameters.values():
            minimum, maximum = description.get("range", [0, 127])

            assert minimum >= 0
            assert maximum <= 127

            if "shape" in description:
                shape = description["shape"]
                # np.random.randint excludes its upper bound, hence `maximum + 1`
                sample.append(np.random.randint(minimum, maximum + 1, size=shape, dtype=np.int64))
            else:
                # random.randint includes its upper bound
                sample.append(np.int64(random.randint(minimum, maximum)))

        return sample

    @staticmethod
    def check_execution(
        circuit: cnp.Circuit,
        function: Callable,
        sample: Union[Any, List[Any]],
        retries: int = 1,
    ):
        """
        Assert that `circuit` behaves the same as `function` on `sample`.

        Both are evaluated on the same inputs. Their outputs, normalized to
        tuples, must have the same number of elements, and each pair must be
        equal under ``np.array_equal``.

        Args:
            circuit (cnp.Circuit):
                compiled circuit

            function (Callable):
                original function

            sample (Union[Any, List[Any]]):
                inputs (a non-list value is treated as a single input)

            retries (int):
                total number of attempts (for probabilistic execution);
                values below 1 still result in one attempt

        Raises:
            AssertionError:
                if no attempt produces matching outputs
        """

        if not isinstance(sample, list):
            sample = [sample]

        # Without this, `retries <= 0` would skip the loop and pass silently.
        attempts = max(retries, 1)

        for i in range(attempts):
            expected = function(*sample)
            actual = circuit.encrypt_run_decrypt(*sample)

            if not isinstance(expected, tuple):
                expected = (expected,)
            if not isinstance(actual, tuple):
                actual = (actual,)

            # `zip` stops at the shorter tuple, so the lengths are compared explicitly.
            # This catches a circuit with a different number of outputs than `function`.
            if len(expected) == len(actual) and all(
                np.array_equal(e, a) for e, a in zip(expected, actual)
            ):
                break

            if i == attempts - 1:
                raise AssertionError(
                    f"""
Expected Output
===============
{expected}
Actual Output
=============
{actual}
"""
                )

    @staticmethod
    def check_str(expected: str, actual: str):
        """
        Assert that `actual` equals `expected`, ignoring leading and trailing whitespace.

        Args:
            expected (str):
                expected str

            actual (str):
                actual str

        Raises:
            AssertionError:
                if the stripped strings differ
        """

        assert (
            actual.strip() == expected.strip()
        ), f"""
Expected Output
===============
{expected}
Actual Output
=============
{actual}
"""
@pytest.fixture
def helpers():
    """
    Hand the `Helpers` class to any test that requests the `helpers` fixture.

    Returns:
        type:
            the `Helpers` class itself (not an instance)
    """

    helper_class = Helpers
    return helper_class