Files
concrete/frontends/concrete-python/tests/conftest.py
Alexandre Péré e8ef48ffd8 feat(compiler): introduce concrete-protocol
This commit:
 + Adds support for a protocol which enables inter-op between concrete,
   tfhe-rs and potentially other contributors to the fhe ecosystem.
 + Gets rid of hand-made serialization in the compiler, and
   client/server libs.
 + Refactors client/server libs to allow more pre/post processing of
   circuit inputs/outputs.

The protocol is supported by a definition in the shape of a capnp file,
which defines different types of objects among which:
 + ProgramInfo object, which is a precise description of a set of fhe
   circuits coming from the same compilation (i.e. their function type
   information), and the associated key set.
 + *Key objects, which represent secret/public keys used to
   encrypt/execute fhe circuits.
 + Value objects, which represent values that can be transferred between
   client and server to support calls to fhe circuits.

The hand-rolled serialization that was previously used is completely
dropped in favor of capnp in the whole codebase.

The client/server libs are refactored to introduce a modular design for
pre-post processing. Reading the ProgramInfo file associated with a
compilation, the client and server libs assemble a pipeline of
transformers (functions) for pre and post processing of values coming in
and out of a circuit. This design properly decouples various aspects of
the processing, and allows these capabilities to be safely extended.

In practice this commit includes the following:
 + Defines the specification in a concreteprotocol package
 + Integrates the compilation of this package as a compiler dependency
   via cmake
 + Modifies the compiler to use the Encodings objects defined in the
   protocol
 + Modifies the compiler to emit ProgramInfo files as compilation
   artifacts, and gets rid of the bloated ClientParameters.
 + Introduces a new Common library containing the functionalities shared
   between the compiler and the client/server libs.
 + Introduces a functional pre-post processing pipeline to this common
   library
 + Modifies the client/server libs to support loading ProgramInfo objects,
   and calling circuits using Value messages.
 + Drops support for JIT.
 + Drops support for C-api.
 + Drops support for Rust bindings.

Co-authored-by: Nikita Frolov <nf@mkmks.org>
2023-11-09 17:09:04 +01:00

383 lines
11 KiB
Python

"""
Configuration of `pytest`.
"""
import json
import os
import random
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple, Union
import numpy as np
import pytest
import tests
from concrete import fhe
# Directory of the `tests` package.
# `Helpers.check_str` drops output lines that start with this path.
tests_directory = os.path.split(tests.__file__)[0]

# Session-wide settings.
# These are the defaults; `pytest_sessionstart` overwrites them from the CLI options.
INSECURE_KEY_CACHE_LOCATION: Union[str, None] = None
USE_MULTI_PRECISION: bool = False
OPTIMIZATION_STRATEGY = fhe.ParameterSelectionStrategy.MONO
def pytest_addoption(parser):
    """
    Add CLI options.

    `--precision` and `--strategy` are restricted to their supported values. A typo
    therefore fails at argument parsing instead of silently running the tests with
    single precision or the `mono` strategy. Their defaults are spelled out so they
    show up in `--help`.
    """

    parser.addoption(
        "--global-coverage",
        type=str,
        default=None,
        action="store",
        help="JSON file to dump pytest-cov terminal report.",
    )
    parser.addoption(
        "--key-cache",
        type=str,
        default=None,
        action="store",
        help="Specify the location of the key cache",
    )
    parser.addoption(
        "--precision",
        type=str,
        default="single",
        choices=["single", "multi"],
        action="store",
        help="Which precision strategy to use in execution tests (single or multi)",
    )
    parser.addoption(
        "--strategy",
        type=str,
        default="mono",
        choices=["v0", "mono", "multi"],
        action="store",
        help="Which optimization strategy to use in execution tests (v0, mono or multi)",
    )
def pytest_sessionstart(session):
    """
    Initialize the session-wide test settings from the command line options.

    - `--key-cache`: location of the insecure key cache.
      `disable` (case-insensitive) disables the cache.
      An empty or missing value selects `~/.cache/concrete-python/pytest`.
      An enabled cache directory is created and its location is printed.
    - `--precision`: `multi` enables multi precision; anything else keeps single precision.
    - `--strategy`: `v0` or `multi` select the corresponding parameter selection
      strategy; anything else keeps `mono`.
    """

    # pylint: disable=global-statement
    global INSECURE_KEY_CACHE_LOCATION
    global USE_MULTI_PRECISION
    global OPTIMIZATION_STRATEGY
    # pylint: enable=global-statement

    config = session.config

    # `getoption`'s `default` only applies to options that were never declared.
    # A declared option without a value comes back as its declared default (possibly
    # `None`), so fallbacks are applied explicitly with `or` below.

    key_cache_option = config.getoption("--key-cache", default=None)
    if not key_cache_option:
        key_cache_location = Path.home().resolve() / ".cache" / "concrete-python" / "pytest"
    elif key_cache_option.lower() == "disable":
        key_cache_location = None
    else:
        key_cache_location = Path(key_cache_option).expanduser().resolve()

    if key_cache_location is not None:
        key_cache_location.mkdir(parents=True, exist_ok=True)
        print(f"INSECURE_KEY_CACHE_LOCATION={key_cache_location}")
        # Only set when the cache is enabled.
        # A disabled cache keeps `None` and never becomes the string "None".
        INSECURE_KEY_CACHE_LOCATION = str(key_cache_location)

    precision = config.getoption("--precision", default=None) or "single"
    USE_MULTI_PRECISION = precision == "multi"

    strategy = config.getoption("--strategy", default=None) or "mono"
    OPTIMIZATION_STRATEGY = {
        "v0": fhe.ParameterSelectionStrategy.V0,
        "multi": fhe.ParameterSelectionStrategy.MULTI,
    }.get(strategy, fhe.ParameterSelectionStrategy.MONO)
def pytest_sessionfinish(session, exitstatus):  # pylint: disable=unused-argument
    """
    Dump the pytest-cov terminal report into a JSON file once testing is finished.

    Nothing is written unless the pytest-cov plugin (`_cov`) is registered and
    `--global-coverage` was given.

    The JSON object contains:
    - `"content"`: the report text.
    - `"exit_code"`: `1` when a positive `--cov-fail-under` threshold was not met,
      `0` otherwise.
    """

    # pytest-cov has no option to export its terminal report to a file.
    # This mirrors what the plugin does internally:
    # https://github.com/pytest-dev/pytest-cov/blob/ec344d8adf2d78238d8f07cb20ed2463d7536970/src/pytest_cov/plugin.py#L329
    plugin_manager = session.config.pluginmanager
    if not plugin_manager.hasplugin("_cov"):
        return

    output_option = session.config.getoption("--global-coverage", default=None)
    if output_option is None:
        return

    cov = plugin_manager.getplugin("_cov")
    report_text = cov.cov_report.getvalue()

    threshold = cov.options.cov_fail_under
    below_threshold = threshold is not None and threshold > 0 and cov.cov_total < threshold

    payload = {"exit_code": 1 if below_threshold else 0, "content": report_text}
    output_path = Path(output_option).resolve()
    with open(output_path, "w", encoding="utf-8") as handle:
        json.dump(payload, handle)
class Helpers:
    """
    Helpers class, which provides various helpers to tests.
    """

    @staticmethod
    def configuration() -> fhe.Configuration:
        """
        Get the test configuration to use during testing.

        The key cache location, precision and parameter selection strategy are taken
        from the module-level settings, which `pytest_sessionstart` initializes.

        Returns:
            fhe.Configuration:
                test configuration
        """

        return fhe.Configuration(
            dump_artifacts_on_unexpected_failures=False,
            enable_unsafe_features=True,
            use_insecure_key_cache=True,
            loop_parallelize=True,
            dataflow_parallelize=False,
            auto_parallelize=False,
            insecure_key_cache_location=INSECURE_KEY_CACHE_LOCATION,
            global_p_error=(1 / 10_000),
            single_precision=(not USE_MULTI_PRECISION),
            parameter_selection_strategy=OPTIMIZATION_STRATEGY,
        )

    @staticmethod
    def generate_encryption_statuses(parameters: Dict[str, Dict[str, Any]]) -> Dict[str, str]:
        """
        Generate parameter encryption statuses according to a parameter specification.

        Parameters without a `"status"` entry are considered `"encrypted"`.

        Args:
            parameters (Dict[str, Dict[str, Any]]):
                parameter specification to use

                e.g.,

                {
                    "param1": {"range": [0, 10], "status": "clear"},
                    "param2": {"range": [3, 30], "status": "encrypted", "shape": (3,)},
                }

        Returns:
            Dict[str, str]:
                parameter encryption statuses
                generated according to the given parameter specification
        """

        return {
            parameter: details.get("status", "encrypted")
            for parameter, details in parameters.items()
        }

    @staticmethod
    def generate_inputset(
        parameters: Dict[str, Dict[str, Any]],
        size: int = 128,
    ) -> List[Union[Tuple[Union[int, np.ndarray], ...], Union[int, np.ndarray]]]:
        """
        Generate a random inputset of desired size according to a parameter specification.

        Entries for a single parameter are the bare value; otherwise entries are tuples of
        values, in the order of the specification.

        Args:
            parameters (Dict[str, Dict[str, Any]]):
                parameter specification to use

                e.g.,

                {
                    "param1": {"range": [0, 10], "status": "clear"},
                    "param2": {"range": [3, 30], "status": "encrypted", "shape": (3,)},
                }

            size (int):
                size of the resulting inputset

        Returns:
            List[Union[Tuple[Union[int, np.ndarray], ...], Union[int, np.ndarray]]]:
                random inputset of desired size
                generated according to the given parameter specification
        """

        inputset = []
        for _ in range(size):
            sample = Helpers.generate_sample(parameters)
            # `!= 1` (rather than `> 1`) makes an empty specification yield `()`
            # instead of raising an IndexError on `sample[0]`.
            inputset.append(sample[0] if len(sample) == 1 else tuple(sample))
        return inputset

    @staticmethod
    def generate_sample(parameters: Dict[str, Dict[str, Any]]) -> List[Union[int, np.ndarray]]:
        """
        Generate a random sample according to a parameter specification.

        Each parameter is drawn uniformly from its inclusive `"range"`, which defaults to
        `[0, 2**16 - 1]`. A parameter with a `"shape"` becomes an `np.int64` array of that
        shape; any other parameter becomes a scalar `np.int64`.

        Args:
            parameters (Dict[str, Dict[str, Any]]):
                parameter specification to use

                e.g.,

                {
                    "param1": {"range": [0, 10], "status": "clear"},
                    "param2": {"range": [3, 30], "status": "encrypted", "shape": (3,)},
                }

        Returns:
            List[Union[int, np.ndarray]]:
                random sample
                generated according to the given parameter specification
        """

        # NOTE: arrays come from `np.random` and scalars from `random`.
        # Reproducing a sample requires seeding both generators.
        sample = []
        for description in parameters.values():
            minimum, maximum = description.get("range", [0, (2**16) - 1])

            if "shape" in description:
                shape = description["shape"]
                # `high` is exclusive in numpy, hence `maximum + 1`.
                sample.append(np.random.randint(minimum, maximum + 1, size=shape, dtype=np.int64))
            else:
                sample.append(np.int64(random.randint(minimum, maximum)))

        return sample

    @staticmethod
    def check_execution(
        circuit: fhe.Circuit,
        function: Callable,
        sample: Union[Any, List[Any]],
        retries: int = 1,
        only_simulation: bool = False,
    ):
        """
        Assert that `circuit` behaves the same as `function` on `sample`.

        Encrypted execution is checked first, unless `only_simulation` is set. Then FHE
        simulation is enabled on `circuit` and simulation is checked as well.

        Args:
            circuit (fhe.Circuit):
                compiled circuit

            function (Callable):
                original function

            sample (Union[Any, List[Any]]):
                inputs, either a list of arguments or a single (non-list) argument

            retries (int, default = 1):
                maximum number of attempts per check (for probabilistic execution)

            only_simulation (bool, default = False):
                whether to just check simulation but not execution

        Raises:
            ValueError:
                if `retries` is smaller than 1

            AssertionError:
                if every attempt of a check produced outputs different from `function`'s
        """

        # With no attempts, nothing would be compared and the check would pass vacuously.
        if retries < 1:
            raise ValueError(f"retries should be at least 1 but it's {retries}")

        if not isinstance(sample, list):
            sample = [sample]

        if not only_simulation:
            Helpers._check_outputs(function, circuit.encrypt_run_decrypt, sample, retries, "")

        circuit.enable_fhe_simulation()
        Helpers._check_outputs(function, circuit.simulate, sample, retries, " During Simulation")

    @staticmethod
    def _check_outputs(
        function: Callable,
        run: Callable,
        sample: List[Any],
        retries: int,
        context: str,
    ):
        """
        Assert that `run(*sample)` matches `function(*sample)` within `retries` attempts.

        `context` is appended to the headers of the failure message
        (e.g., " During Simulation").
        """

        expected: Tuple[Any, ...] = ()
        actual: Tuple[Any, ...] = ()
        for _ in range(retries):
            expected = Helpers._sanitize(function(*sample))
            actual = Helpers._sanitize(run(*sample))

            # Compare the number of outputs explicitly.
            # `zip` stops at the shorter tuple, so missing or extra outputs would go unnoticed.
            if len(expected) == len(actual) and all(
                np.array_equal(e, a) for e, a in zip(expected, actual)
            ):
                return

        expected_header = f"Expected Output{context}"
        actual_header = f"Actual Output{context}"
        expected_underline = "=" * len(expected_header)
        actual_underline = "=" * len(actual_header)

        message = (
            f"\n\n{expected_header}\n{expected_underline}\n{expected}\n\n"
            f"{actual_header}\n{actual_underline}\n{actual}\n"
        )
        raise AssertionError(message)

    @staticmethod
    def _sanitize(values: Any) -> Tuple[Any, ...]:
        """
        Normalize outputs into a tuple, converting booleans to integers.

        A non-tuple value is wrapped into a 1-tuple. `bool`/`np.bool_` scalars become
        `int`, and boolean arrays become `np.int64` arrays.
        """

        if not isinstance(values, tuple):
            values = (values,)

        result = []
        for value in values:
            if isinstance(value, (bool, np.bool_)):
                value = int(value)
            elif isinstance(value, np.ndarray) and value.dtype == np.bool_:
                value = value.astype(np.int64)
            result.append(value)

        return tuple(result)

    @staticmethod
    def check_str(expected: str, actual: str):
        """
        Assert that `actual` matches `expected`, ignoring surrounding whitespace.

        Before comparing, lines of `actual` that start (after stripping) with the path of
        the tests directory are removed.

        Args:
            expected (str):
                expected str

            actual (str):
                actual str

        Raises:
            AssertionError:
                if the strings differ
        """

        # Remove error line information.
        # Explicit tests make sure the line information is correct, but keeping every
        # other test up to date with it would be very hard.
        actual = "\n".join(
            line for line in actual.splitlines() if not line.strip().startswith(tests_directory)
        )

        assert actual.strip() == expected.strip(), (
            f"\n\nExpected Output\n===============\n{expected}\n\n"
            f"Actual Output\n=============\n{actual}\n"
        )
@pytest.fixture()
def helpers():
    """
    Provide the `Helpers` class itself (not an instance) to tests requesting `helpers`.
    """

    helpers_class = Helpers
    return helpers_class