""" Configuration of `pytest`. """ import json import random from pathlib import Path from typing import Any, Callable, Dict, List, Tuple, Union import numpy as np import pytest import concrete.numpy as cnp INSECURE_KEY_CACHE_LOCATION = None def pytest_addoption(parser): """ Add CLI options. """ parser.addoption( "--global-coverage", type=str, default=None, action="store", help="JSON file to dump pytest-cov terminal report.", ) parser.addoption( "--key-cache", type=str, default=None, action="store", help="Specify the location of the key cache", ) def pytest_sessionstart(session): """ Initialize insecure key cache. """ # pylint: disable=global-statement global INSECURE_KEY_CACHE_LOCATION # pylint: enable=global-statement key_cache_location = session.config.getoption("--key-cache", default=None) if key_cache_location is not None: if key_cache_location.lower() == "disable": key_cache_location = None else: key_cache_location = Path(key_cache_location).expanduser().resolve() else: key_cache_location = Path.home().resolve() / ".cache" / "concrete-numpy" / "pytest" if key_cache_location: key_cache_location.mkdir(parents=True, exist_ok=True) print(f"INSECURE_KEY_CACHE_LOCATION={str(key_cache_location)}") INSECURE_KEY_CACHE_LOCATION = str(key_cache_location) def pytest_sessionfinish(session, exitstatus): # pylint: disable=unused-argument """ Save global coverage info after testing is finished. """ # Hacked together from the source code, they don't have an option to export to file, # and it's too much work to get a PR in for such a little thing. # https://github.com/pytest-dev/pytest-cov/blob/ec344d8adf2d78238d8f07cb20ed2463d7536970/src/pytest_cov/plugin.py#L329 if session.config.pluginmanager.hasplugin("_cov"): global_coverage_option = session.config.getoption("--global-coverage", default=None) if global_coverage_option is not None: coverage_plugin = session.config.pluginmanager.getplugin("_cov") coverage_txt = coverage_plugin.cov_report.getvalue() coverage_status = 0 if ( coverage_plugin.options.cov_fail_under is not None and coverage_plugin.options.cov_fail_under > 0 and coverage_plugin.cov_total < coverage_plugin.options.cov_fail_under ): coverage_status = 1 global_coverage_file_path = Path(global_coverage_option).resolve() with open(global_coverage_file_path, "w", encoding="utf-8") as f: json.dump({"exit_code": coverage_status, "content": coverage_txt}, f) class Helpers: """ Helpers class, which provides various helpers to tests. """ @staticmethod def configuration() -> cnp.Configuration: """ Get the test configuration to use during testing. Returns: cnp.Configuration: test configuration """ return cnp.Configuration( dump_artifacts_on_unexpected_failures=False, enable_unsafe_features=True, use_insecure_key_cache=True, loop_parallelize=True, dataflow_parallelize=False, auto_parallelize=False, jit=True, insecure_key_cache_location=INSECURE_KEY_CACHE_LOCATION, ) @staticmethod def generate_encryption_statuses(parameters: Dict[str, Dict[str, Any]]) -> Dict[str, str]: """ Generate parameter encryption statuses accoring to a parameter specification. Args: parameters (Dict[str, Dict[str, Any]]): parameter specification to use e.g., { "param1": {"range": [0, 10], "status": "clear"}, "param2": {"range": [3, 30], "status": "encrypted", "shape": (3,)}, } Returns: Dict[str, str]: parameter encryption statuses generated according to the given parameter specification """ return { parameter: details["status"] if "status" in details else "encrypted" for parameter, details in parameters.items() } @staticmethod def generate_inputset( parameters: Dict[str, Dict[str, Any]], size: int = 128, ) -> List[Union[Tuple[Union[int, np.ndarray], ...], Union[int, np.ndarray]]]: """ Generate a random inputset of desired size accoring to a parameter specification. Args: parameters (Dict[str, Dict[str, Any]]): parameter specification to use e.g., { "param1": {"range": [0, 10], "status": "clear"}, "param2": {"range": [3, 30], "status": "encrypted", "shape": (3,)}, } size (int): size of the resulting inputset Returns: List[Union[Tuple[Union[int, np.ndarray], ...], Union[int, np.ndarray]]]: random inputset of desired size generated according to the given parameter specification """ inputset = [] for _ in range(size): sample = Helpers.generate_sample(parameters) inputset.append(tuple(sample) if len(sample) > 1 else sample[0]) return inputset @staticmethod def generate_sample(parameters: Dict[str, Dict[str, Any]]) -> List[Union[int, np.ndarray]]: """ Generate a random sample accoring to a parameter specification. Args: parameters (Dict[str, Dict[str, Any]]): parameter specification to use e.g., { "param1": {"range": [0, 10], "status": "clear"}, "param2": {"range": [3, 30], "status": "encrypted", "shape": (3,)}, } Returns: List[Union[int, np.ndarray]]: random sample generated according to the given parameter specification """ sample = [] for description in parameters.values(): minimum, maximum = description.get("range", [0, 127]) if "shape" in description: shape = description["shape"] sample.append(np.random.randint(minimum, maximum + 1, size=shape, dtype=np.int64)) else: sample.append(np.int64(random.randint(minimum, maximum))) return sample @staticmethod def check_execution( circuit: cnp.Circuit, function: Callable, sample: Union[Any, List[Any]], retries: int = 1, ): """ Assert that `circuit` is behaves the same as `function` on `sample`. Args: circuit (cnp.Circuit): compiled circuit function (Callable): original function sample (List[Any]): inputs retries (int): number of times to retry (for probabilistic execution) """ if not isinstance(sample, list): sample = [sample] for i in range(retries): expected = function(*sample) actual = circuit.encrypt_run_decrypt(*sample) if not isinstance(expected, tuple): expected = (expected,) if not isinstance(actual, tuple): actual = (actual,) if all(np.array_equal(e, a) for e, a in zip(expected, actual)): break if i == retries - 1: raise AssertionError( f""" Expected Output =============== {expected} Actual Output ============= {actual} """ ) @staticmethod def check_str(expected: str, actual: str): """ Assert that `circuit` is behaves the same as `function` on `sample`. Args: expected (str): expected str actual (str): actual str """ assert ( actual.strip() == expected.strip() ), f""" Expected Output =============== {expected} Actual Output ============= {actual} """ @pytest.fixture def helpers(): """ Fixture that provides `Helpers` class to tests. """ return Helpers