"""
Configuration of `pytest`.
"""

import json
import random
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple, Union

import numpy as np
import pytest

import concrete.numpy as cnp
from concrete.numpy.compilation import configuration as configuration_


def pytest_addoption(parser):
    """
    Add CLI options.

    Registers:
        --global-coverage: path of a JSON file to dump the pytest-cov terminal report into
        --key-cache: location of the insecure key cache, or "disable" to turn it off
    """

    parser.addoption(
        "--global-coverage",
        type=str,
        default=None,
        action="store",
        help="JSON file to dump pytest-cov terminal report.",
    )
    parser.addoption(
        "--key-cache",
        type=str,
        default=None,
        action="store",
        help="Specify the location of the key cache",
    )


def pytest_sessionstart(session):
    """
    Initialize insecure key cache.

    The location is taken from `--key-cache` when given (the value "disable",
    case-insensitive, turns the cache off); otherwise it defaults to
    `~/.cache/concrete-numpy/pytest`.
    """

    key_cache_location = session.config.getoption("--key-cache", default=None)
    if key_cache_location is not None:
        if key_cache_location.lower() == "disable":
            key_cache_location = None
        else:
            key_cache_location = Path(key_cache_location).expanduser().resolve()
    else:
        key_cache_location = Path.home().resolve() / ".cache" / "concrete-numpy" / "pytest"

    # The configuration is only touched when a cache location is active;
    # otherwise `str(None)` would register the literal path "None".
    if key_cache_location:
        key_cache_location.mkdir(parents=True, exist_ok=True)
        print(f"INSECURE_KEY_CACHE_LOCATION={str(key_cache_location)}")

        # pylint: disable=protected-access
        configuration_._INSECURE_KEY_CACHE_LOCATION = str(key_cache_location)
        # pylint: enable=protected-access


def pytest_sessionfinish(session, exitstatus):  # pylint: disable=unused-argument
    """
    Save global coverage info after testing is finished.

    When the pytest-cov plugin is active and `--global-coverage` is given, writes
    a JSON object `{"exit_code": ..., "content": ...}` to that file, where
    `exit_code` is 1 if the total coverage is below `--cov-fail-under`, else 0.
    """

    # Hacked together from the source code, they don't have an option to export to file,
    # and it's too much work to get a PR in for such a little thing.
    # https://github.com/pytest-dev/pytest-cov/blob/ec344d8adf2d78238d8f07cb20ed2463d7536970/src/pytest_cov/plugin.py#L329
    if session.config.pluginmanager.hasplugin("_cov"):
        global_coverage_option = session.config.getoption("--global-coverage", default=None)
        if global_coverage_option is not None:
            coverage_plugin = session.config.pluginmanager.getplugin("_cov")
            coverage_txt = coverage_plugin.cov_report.getvalue()

            coverage_status = 0
            if (
                coverage_plugin.options.cov_fail_under is not None
                and coverage_plugin.options.cov_fail_under > 0
                and coverage_plugin.cov_total < coverage_plugin.options.cov_fail_under
            ):
                coverage_status = 1

            global_coverage_file_path = Path(global_coverage_option).resolve()
            with open(global_coverage_file_path, "w", encoding="utf-8") as f:
                json.dump({"exit_code": coverage_status, "content": coverage_txt}, f)


class Helpers:
    """
    Helpers class, which provides various helpers to tests.
    """

    @staticmethod
    def configuration() -> cnp.Configuration:
        """
        Get the test configuration to use during testing.

        Returns:
            cnp.Configuration:
                test configuration
        """

        return cnp.Configuration(
            dump_artifacts_on_unexpected_failures=False,
            enable_unsafe_features=True,
            use_insecure_key_cache=True,
            loop_parallelize=True,
            dataflow_parallelize=False,
            auto_parallelize=False,
        )

    @staticmethod
    def generate_encryption_statuses(parameters: Dict[str, Dict[str, Any]]) -> Dict[str, str]:
        """
        Generate parameter encryption statuses according to a parameter specification.

        Args:
            parameters (Dict[str, Dict[str, Any]]):
                parameter specification to use

                e.g.,
                {
                    "param1": {"range": [0, 10], "status": "clear"},
                    "param2": {"range": [3, 30], "status": "encrypted", "shape": (3,)},
                }

        Returns:
            Dict[str, str]:
                parameter encryption statuses
                generated according to the given parameter specification
                (parameters without a "status" entry default to "encrypted")
        """

        return {
            parameter: details.get("status", "encrypted")
            for parameter, details in parameters.items()
        }

    @staticmethod
    def generate_inputset(
        parameters: Dict[str, Dict[str, Any]],
        size: int = 128,
    ) -> List[Union[Tuple[Union[int, np.ndarray], ...], Union[int, np.ndarray]]]:
        """
        Generate a random inputset of desired size according to a parameter specification.

        Args:
            parameters (Dict[str, Dict[str, Any]]):
                parameter specification to use

                e.g.,
                {
                    "param1": {"range": [0, 10], "status": "clear"},
                    "param2": {"range": [3, 30], "status": "encrypted", "shape": (3,)},
                }

            size (int):
                size of the resulting inputset

        Returns:
            List[Union[Tuple[Union[int, np.ndarray], ...], Union[int, np.ndarray]]]:
                random inputset of desired size
                generated according to the given parameter specification
                (each entry is a tuple when there are several parameters,
                otherwise the single value itself)
        """

        inputset = []
        for _ in range(size):
            sample = Helpers.generate_sample(parameters)
            inputset.append(tuple(sample) if len(sample) > 1 else sample[0])
        return inputset

    @staticmethod
    def generate_sample(parameters: Dict[str, Dict[str, Any]]) -> List[Union[int, np.ndarray]]:
        """
        Generate a random sample according to a parameter specification.

        Args:
            parameters (Dict[str, Dict[str, Any]]):
                parameter specification to use

                e.g.,
                {
                    "param1": {"range": [0, 10], "status": "clear"},
                    "param2": {"range": [3, 30], "status": "encrypted", "shape": (3,)},
                }

        Returns:
            List[Union[int, np.ndarray]]:
                random sample
                generated according to the given parameter specification
                (ranges are inclusive and default to [0, 127])
        """

        sample = []

        for description in parameters.values():
            minimum, maximum = description.get("range", [0, 127])

            assert minimum >= 0
            assert maximum <= 127

            if "shape" in description:
                shape = description["shape"]
                # `np.random.randint` has an exclusive upper bound, hence `+ 1`
                sample.append(np.random.randint(minimum, maximum + 1, size=shape, dtype=np.int64))
            else:
                # `random.randint` has an inclusive upper bound
                sample.append(np.int64(random.randint(minimum, maximum)))

        return sample

    @staticmethod
    def check_execution(
        circuit: cnp.Circuit,
        function: Callable,
        sample: Union[Any, List[Any]],
        retries: int = 1,
    ):
        """
        Assert that `circuit` behaves the same as `function` on `sample`.

        Args:
            circuit (cnp.Circuit):
                compiled circuit

            function (Callable):
                original function

            sample (Union[Any, List[Any]]):
                inputs (a non-list value is treated as a single argument)

            retries (int):
                number of attempts (for probabilistic execution);
                values below 1 are treated as 1 so the check always runs

        Raises:
            AssertionError:
                if no attempt produced outputs equal to those of `function`
        """

        if not isinstance(sample, list):
            sample = [sample]

        # Always run at least once; otherwise `retries <= 0` would pass vacuously.
        attempts = max(retries, 1)
        for i in range(attempts):
            expected = function(*sample)
            actual = circuit.encrypt_run_decrypt(*sample)

            if not isinstance(expected, tuple):
                expected = (expected,)
            if not isinstance(actual, tuple):
                actual = (actual,)

            # `zip` silently truncates, so a differing number of outputs must be
            # rejected explicitly before the element-wise comparison.
            if len(expected) == len(actual) and all(
                np.array_equal(e, a) for e, a in zip(expected, actual)
            ):
                break

            if i == attempts - 1:
                raise AssertionError(
                    f"""

Expected Output
===============
{expected}

Actual Output
=============
{actual}

                    """
                )

    @staticmethod
    def check_str(expected: str, actual: str):
        """
        Assert that `actual` equals `expected`, ignoring leading and trailing whitespace.

        Args:
            expected (str):
                expected str

            actual (str):
                actual str

        Raises:
            AssertionError:
                if the stripped strings differ
        """

        assert (
            actual.strip() == expected.strip()
        ), f"""

Expected Output
===============
{expected}

Actual Output
=============
{actual}

            """


@pytest.fixture
def helpers():
    """
    Fixture that provides `Helpers` class to tests.
    """

    return Helpers