From 13b9ff96f00a8ef095015703ae85c1dbb4151cd9 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Thu, 18 Nov 2021 13:35:45 +0100 Subject: [PATCH] feat: create torch-like APIs part 1 - work on generating OPGraph with a torch-like API refs #233 --- .../bounds_measurement/inputset_eval.py | 37 +- concrete/common/compilation/configuration.py | 3 + concrete/numpy/__init__.py | 1 + concrete/numpy/compile.py | 387 +++++++++++++----- concrete/numpy/np_fhe_compiler.py | 199 +++++++++ tests/numpy/test_compile_user_friendly_api.py | 215 ++++++++++ 6 files changed, 739 insertions(+), 103 deletions(-) create mode 100644 concrete/numpy/np_fhe_compiler.py create mode 100644 tests/numpy/test_compile_user_friendly_api.py diff --git a/concrete/common/bounds_measurement/inputset_eval.py b/concrete/common/bounds_measurement/inputset_eval.py index b2af728e4..7f88bd35d 100644 --- a/concrete/common/bounds_measurement/inputset_eval.py +++ b/concrete/common/bounds_measurement/inputset_eval.py @@ -1,7 +1,8 @@ """Code to evaluate the IR graph on inputsets.""" import sys -from typing import Any, Callable, Dict, Iterable, Tuple, Union +from functools import partial +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union from ..compilation import CompilationConfiguration from ..data_types.dtypes_helpers import ( @@ -110,6 +111,7 @@ def eval_op_graph_bounds_on_inputset( get_base_value_for_constant_data_func: Callable[ [Any], Any ] = get_base_value_for_python_constant_data, + prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None, ) -> Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: """Evaluate the bounds with a inputset. @@ -132,6 +134,8 @@ def eval_op_graph_bounds_on_inputset( tensors). Defaults to max. get_base_value_for_constant_data_func (Callable[[Any], Any], optional): custom function to compute the base value of a python object. + prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional): + Bounds and samples from a previous run. Defaults to None. Returns: Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: number of inputs in the inputset and @@ -187,13 +191,38 @@ def eval_op_graph_bounds_on_inputset( first_output = op_graph.evaluate(current_input_data) + prev_node_bounds_and_samples = ( + {} if prev_node_bounds_and_samples is None else prev_node_bounds_and_samples + ) + + def get_previous_value_for_key_or_default_for_dict( + dict_: Dict[IntermediateNode, Dict[str, Any]], + node: IntermediateNode, + key: str, + default: Any, + ) -> Any: + return_value = default + + previous_value_dict = dict_.get(node, None) + + if previous_value_dict is not None: + return_value = previous_value_dict.get(key, default) + + return return_value + + get_previous_value_for_key_or_default = partial( + get_previous_value_for_key_or_default_for_dict, prev_node_bounds_and_samples + ) + # We evaluate the min and max func to be able to resolve the tensors min and max rather than # having the tensor itself as the stored min and max values. 
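+ # Illustrative example of the merge below (values are made up): with the numpy min/max
+ # funcs, a node whose previous record was {"min": 0, "max": 5, "sample": 3} and whose
+ # current evaluation yields 7 is merged into {"min": 0, "max": 7, "sample": 3}, i.e. the
+ # bounds widen and the previous sample is kept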
+ # As we can't trust the integrity of prev_node_bounds_and_samples, we fall back to the
+ # freshly computed value for any node without a previous entry, so that the new
+ # node_bounds_and_samples is always fully populated
 node_bounds_and_samples = {
 node: {
- "min": min_func(value, value),
- "max": max_func(value, value),
- "sample": value,
+ "min": min_func(value, get_previous_value_for_key_or_default(node, "min", value)),
+ "max": max_func(value, get_previous_value_for_key_or_default(node, "max", value)),
+ "sample": get_previous_value_for_key_or_default(node, "sample", value),
 }
 for node, value in first_output.items()
 }
diff --git a/concrete/common/compilation/configuration.py b/concrete/common/compilation/configuration.py
index 0fd1233cf..81220afdf 100644
--- a/concrete/common/compilation/configuration.py
+++ b/concrete/common/compilation/configuration.py
@@ -26,3 +26,6 @@ class CompilationConfiguration:
 self.treat_warnings_as_errors = treat_warnings_as_errors
 self.enable_unsafe_features = enable_unsafe_features
 self.random_inputset_samples = random_inputset_samples
+
+ def __eq__(self, other) -> bool:
+ return isinstance(other, CompilationConfiguration) and self.__dict__ == other.__dict__
diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py
index 39d26dc1d..0627a474d 100644
--- a/concrete/numpy/__init__.py
+++ b/concrete/numpy/__init__.py
@@ -15,4 +15,5 @@ from ..common.extensions.multi_table import MultiLookupTable
 from ..common.extensions.table import LookupTable
 from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue
 from .compile import compile_numpy_function, compile_numpy_function_into_op_graph_and_measure_bounds
+from .np_fhe_compiler import NPFHECompiler
 from .tracing import trace_numpy_function
diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py
index 872871482..3e204dd49 100644
--- a/concrete/numpy/compile.py
+++ b/concrete/numpy/compile.py
@@ -22,7 +22,7 @@ from ..common.mlir.utils import (
 )
 from ..common.operator_graph import OPGraph
 from ..common.optimization.topological import fuse_float_operations
-from ..common.representation.intermediate import Add, Constant, GenericFunction
+from ..common.representation.intermediate import Add, Constant, GenericFunction, IntermediateNode
 from ..common.values import BaseValue, ClearScalar
 from ..numpy.tracing import trace_numpy_function
 from .np_dtypes_helpers import (
@@ -60,6 +60,102 @@ def numpy_min_func(lhs: Any, rhs: Any) -> Any:
 return numpy.minimum(lhs, rhs).min()
 
 
+def sanitize_compilation_configuration_and_artifacts(
+ compilation_configuration: Optional[CompilationConfiguration] = None,
+ compilation_artifacts: Optional[CompilationArtifacts] = None,
+) -> Tuple[CompilationConfiguration, CompilationArtifacts]:
+ """Return the proper compilation configuration and artifacts.
+
+ Default values are returned if None is passed for each argument.
+
+ Args:
+ compilation_configuration (Optional[CompilationConfiguration], optional): the compilation
+ configuration to sanitize. Defaults to None.
+ compilation_artifacts (Optional[CompilationArtifacts], optional): the compilation artifacts
+ to sanitize. Defaults to None.
+
+ Returns:
+ Tuple[CompilationConfiguration, CompilationArtifacts]: the tuple of sanitized configuration
+ and artifacts.
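+
+ Example (illustrative): calling with no arguments yields fresh defaults:
+ config, artifacts = sanitize_compilation_configuration_and_artifacts()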
+ """ + # Create default configuration if custom configuration is not specified + compilation_configuration = ( + CompilationConfiguration() + if compilation_configuration is None + else compilation_configuration + ) + + # Create temporary artifacts if custom artifacts is not specified (in case of exceptions) + if compilation_artifacts is None: + compilation_artifacts = CompilationArtifacts() + + return compilation_configuration, compilation_artifacts + + +def get_inputset_to_use( + function_parameters: Dict[str, BaseValue], + inputset: Union[Iterable[Tuple[Any, ...]], str], + compilation_configuration: CompilationConfiguration, +) -> Iterable[Tuple[Any, ...]]: + """Get the proper inputset to use for compilation. + + Args: + function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the + function is e.g. an EncryptedScalar holding a 7bits unsigned Integer + inputset (Union[Iterable[Tuple[Any, ...]], str]): The inputset over which op_graph is + evaluated. It needs to be an iterable on tuples which are of the same length than the + number of parameters in the function, and in the same order than these same parameters + compilation_configuration (CompilationConfiguration): Configuration object to use during + compilation + + Returns: + Iterable[Tuple[Any, ...]]: the inputset to use. + """ + # Generate random inputset if it is requested and available + if isinstance(inputset, str): + _check_special_inputset_availability(inputset, compilation_configuration) + return _generate_random_inputset(function_parameters, compilation_configuration) + return inputset + + +def run_compilation_function_with_error_management( + compilation_function: Callable, + compilation_configuration: CompilationConfiguration, + compilation_artifacts: CompilationArtifacts, +) -> Any: + """Call compilation_function() and manage exceptions that may occur. + + Args: + compilation_function (Callable): the compilation function to call. + compilation_configuration (CompilationConfiguration): the current compilation configuration. + compilation_artifacts (CompilationArtifacts): the current compilation artifacts. + + Returns: + Any: returns the result of the call to compilation_function + """ + + # Try to compile the function and save partial artifacts on failure + try: + # Use context manager to restore numpy error handling + with numpy.errstate(**numpy.geterr()): + return compilation_function() + except Exception: # pragma: no cover + # This branch is reserved for unexpected issues and hence it shouldn't be tested. + # If it could be tested, we would have fixed the underlying issue. 
+
+ # Export all the information we have about the compilation,
+ # if the user asked for artifacts to be dumped on unexpected failures
+
+ if compilation_configuration.dump_artifacts_on_unexpected_failures:
+ compilation_artifacts.export()
+
+ traceback_path = compilation_artifacts.output_directory.joinpath("traceback.txt")
+ with open(traceback_path, "w", encoding="utf-8") as f:
+ f.write(traceback.format_exc())
+
+ raise
+
+
 def _compile_numpy_function_into_op_graph_internal(
 function_to_compile: Callable,
 function_parameters: Dict[str, BaseValue],
@@ -122,13 +218,60 @@
 return op_graph
 
 
+def compile_numpy_function_into_op_graph(
+ function_to_compile: Callable,
+ function_parameters: Dict[str, BaseValue],
+ compilation_configuration: Optional[CompilationConfiguration] = None,
+ compilation_artifacts: Optional[CompilationArtifacts] = None,
+) -> OPGraph:
+ """Compile a function into an OPGraph.
+
+ Args:
+ function_to_compile (Callable): The function to compile
+ function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
+ function is e.g. an EncryptedScalar holding a 7-bit unsigned Integer
+ compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
+ during compilation
+ compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
+ during compilation
+
+ Returns:
+ OPGraph: the function compiled into a graph
+ """
+
+ (
+ compilation_configuration,
+ compilation_artifacts,
+ ) = sanitize_compilation_configuration_and_artifacts(
+ compilation_configuration, compilation_artifacts
+ )
+
+ def compilation_function():
+ return _compile_numpy_function_into_op_graph_internal(
+ function_to_compile,
+ function_parameters,
+ compilation_configuration,
+ compilation_artifacts,
+ )
+
+ result = run_compilation_function_with_error_management(
+ compilation_function, compilation_configuration, compilation_artifacts
+ )
+
+ # for mypy
+ assert isinstance(result, OPGraph)
+ return result
+
+
 def _measure_op_graph_bounds_and_update_internal(
 op_graph: OPGraph,
 function_parameters: Dict[str, BaseValue],
 inputset: Iterable[Tuple[Any, ...]],
 compilation_configuration: CompilationConfiguration,
 compilation_artifacts: CompilationArtifacts,
-) -> None:
+ prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None,
+ warn_on_inputset_length: bool = True,
+) -> Dict[IntermediateNode, Dict[str, Any]]:
 """Measure the intermediate values and update the OPGraph accordingly for the given inputset.
 
 Args:
@@ -142,11 +285,21 @@
 during compilation
 compilation_artifacts (CompilationArtifacts): Artifacts object to fill
 during compilation
+ prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional):
+ Bounds and samples from a previous run. Defaults to None.
+ warn_on_inputset_length (bool, optional): Set to True to get a warning if the inputset is
+ not long enough. Defaults to True.
 
 Raises:
 ValueError: Raises an error if the inputset is too small and the compilation configuration
 treats warnings as error.
+
+ Returns:
+ Dict[IntermediateNode, Dict[str, Any]]: a dict containing the bounds for each node from
+ op_graph, stored with the node as key and a dict with keys "min", "max" and "sample" as
+ the value.
""" + # Find bounds with the inputset inputset_size, node_bounds_and_samples = eval_op_graph_bounds_on_inputset( op_graph, @@ -155,35 +308,37 @@ def _measure_op_graph_bounds_and_update_internal( min_func=numpy_min_func, max_func=numpy_max_func, get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, + prev_node_bounds_and_samples=prev_node_bounds_and_samples, ) - # Check inputset size - inputset_size_upper_limit = 1 + if warn_on_inputset_length: + # Check inputset size + inputset_size_upper_limit = 1 - # this loop will determine the number of possible inputs of the function - # if a function have a single 3-bit input, for example, `inputset_size_upper_limit` will be 8 - for parameter_value in function_parameters.values(): - if isinstance(parameter_value.dtype, Integer): - # multiple parameter bit-widths are multiplied as they can be combined into an input - inputset_size_upper_limit *= 2 ** parameter_value.dtype.bit_width + # this loop will determine the number of possible inputs of the function + # if a function have a single 3-bit input, for example, inputset_size_upper_limit will be 8 + for parameter_value in function_parameters.values(): + if isinstance(parameter_value.dtype, Integer): + # multiple parameter bit-widths are multiplied as they can be combined into an input + inputset_size_upper_limit *= 2 ** parameter_value.dtype.bit_width - # if the upper limit of the inputset size goes above 10, - # break the loop as we will require at least 10 inputs in this case - if inputset_size_upper_limit > 10: - break + # if the upper limit of the inputset size goes above 10, + # break the loop as we will require at least 10 inputs in this case + if inputset_size_upper_limit > 10: + break - minimum_required_inputset_size = min(inputset_size_upper_limit, 10) - if inputset_size < minimum_required_inputset_size: - message = ( - f"Provided inputset contains too few inputs " - f"(it should have had at least {minimum_required_inputset_size} " - f"but it only had {inputset_size})\n" - ) + minimum_required_inputset_size = min(inputset_size_upper_limit, 10) + if inputset_size < minimum_required_inputset_size: + message = ( + f"Provided inputset contains too few inputs " + f"(it should have had at least {minimum_required_inputset_size} " + f"but it only had {inputset_size})\n" + ) - if compilation_configuration.treat_warnings_as_errors: - raise ValueError(message) + if compilation_configuration.treat_warnings_as_errors: + raise ValueError(message) - sys.stderr.write(f"Warning: {message}") + sys.stderr.write(f"Warning: {message}") # Add the bounds as an artifact compilation_artifacts.add_final_operation_graph_bounds(node_bounds_and_samples) @@ -195,6 +350,74 @@ def _measure_op_graph_bounds_and_update_internal( get_constructor_for_numpy_or_python_constant_data, ) + return node_bounds_and_samples + + +def measure_op_graph_bounds_and_update( + op_graph: OPGraph, + function_parameters: Dict[str, BaseValue], + inputset: Union[Iterable[Tuple[Any, ...]], str], + compilation_configuration: Optional[CompilationConfiguration] = None, + compilation_artifacts: Optional[CompilationArtifacts] = None, + prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None, + warn_on_inputset_length: bool = True, +) -> Dict[IntermediateNode, Dict[str, Any]]: + """Measure the intermediate values and update the OPGraph accordingly for the given inputset. + + Args: + op_graph (OPGraph): the OPGraph for which to measure bounds and update node values. 
+ function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
+ function is e.g. an EncryptedScalar holding a 7-bit unsigned Integer
+ inputset (Union[Iterable[Tuple[Any, ...]], str]): The inputset over which op_graph is
+ evaluated. It needs to be an iterable of tuples of the same length as the number of
+ parameters of the function, with values in the same order as those parameters
+ compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
+ during compilation
+ compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
+ during compilation
+ prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional):
+ Bounds and samples from a previous run. Defaults to None.
+ warn_on_inputset_length (bool, optional): Set to True to get a warning if the inputset is
+ not long enough. Defaults to True.
+
+ Raises:
+ ValueError: Raises an error if the inputset is too small and the compilation configuration
+ treats warnings as errors.
+
+ Returns:
+ Dict[IntermediateNode, Dict[str, Any]]: a dict containing the bounds for each node from
+ op_graph, stored with the node as key and a dict with keys "min", "max" and "sample" as
+ the value.
+ """
+
+ (
+ compilation_configuration,
+ compilation_artifacts,
+ ) = sanitize_compilation_configuration_and_artifacts(
+ compilation_configuration, compilation_artifacts
+ )
+
+ inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration)
+
+ def compilation_function():
+ return _measure_op_graph_bounds_and_update_internal(
+ op_graph,
+ function_parameters,
+ inputset,
+ compilation_configuration,
+ compilation_artifacts,
+ prev_node_bounds_and_samples,
+ warn_on_inputset_length,
+ )
+
+ result = run_compilation_function_with_error_management(
+ compilation_function, compilation_configuration, compilation_artifacts
+ )
+
+ # for mypy
+ assert isinstance(result, dict)
+ return result
 
 def _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
 function_to_compile: Callable,
@@ -269,48 +492,31 @@
 OPGraph: compiled function into a graph
 """
 
- # Create default configuration if custom configuration is not specified
- compilation_configuration = (
- CompilationConfiguration()
- if compilation_configuration is None
- else compilation_configuration
+ (
+ compilation_configuration,
+ compilation_artifacts,
+ ) = sanitize_compilation_configuration_and_artifacts(
+ compilation_configuration, compilation_artifacts
 )
 
- # Create temporary artifacts if custom artifacts is not specified (in case of exceptions)
- if compilation_artifacts is None:
- compilation_artifacts = CompilationArtifacts()
+ inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration)
 
- # Generate random inputset if it is requested and available
- if isinstance(inputset, str):
- _check_special_inputset_availability(inputset, compilation_configuration)
- inputset = _generate_random_inputset(function_parameters, compilation_configuration)
+ def compilation_function():
+ return _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
+ function_to_compile,
+ function_parameters,
+ inputset,
+ compilation_configuration,
+ compilation_artifacts,
+ )
 
- # Try to compile the function and save partial artifacts on failure
- try:
- # Use context manager to restore numpy error handling
- with numpy.errstate(**numpy.geterr()):
- return
_compile_numpy_function_into_op_graph_and_measure_bounds_internal( - function_to_compile, - function_parameters, - inputset, - compilation_configuration, - compilation_artifacts, - ) - except Exception: # pragma: no cover - # This branch is reserved for unexpected issues and hence it shouldn't be tested. - # If it could be tested, we would have fixed the underlying issue. + result = run_compilation_function_with_error_management( + compilation_function, compilation_configuration, compilation_artifacts + ) - # We need to export all the information we have about the compilation - # If the user wants them to be exported - - if compilation_configuration.dump_artifacts_on_unexpected_failures: - compilation_artifacts.export() - - traceback_path = compilation_artifacts.output_directory.joinpath("traceback.txt") - with open(traceback_path, "w", encoding="utf-8") as f: - f.write(traceback.format_exc()) - - raise + # for mypy + assert isinstance(result, OPGraph) + return result # HACK @@ -492,46 +698,29 @@ def compile_numpy_function( CompilerEngine: engine to run and debug the compiled graph """ - # Create default configuration if custom configuration is not specified - compilation_configuration = ( - CompilationConfiguration() - if compilation_configuration is None - else compilation_configuration + ( + compilation_configuration, + compilation_artifacts, + ) = sanitize_compilation_configuration_and_artifacts( + compilation_configuration, compilation_artifacts ) - # Create temporary artifacts if custom artifacts is not specified (in case of exceptions) - if compilation_artifacts is None: - compilation_artifacts = CompilationArtifacts() + inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration) - # Generate random inputset if it is requested and available - if isinstance(inputset, str): - _check_special_inputset_availability(inputset, compilation_configuration) - inputset = _generate_random_inputset(function_parameters, compilation_configuration) + def compilation_function(): + return _compile_numpy_function_internal( + function_to_compile, + function_parameters, + inputset, + compilation_configuration, + compilation_artifacts, + show_mlir, + ) - # Try to compile the function and save partial artifacts on failure - try: - # Use context manager to restore numpy error handling - with numpy.errstate(**numpy.geterr()): - return _compile_numpy_function_internal( - function_to_compile, - function_parameters, - inputset, - compilation_configuration, - compilation_artifacts, - show_mlir, - ) - except Exception: # pragma: no cover - # This branch is reserved for unexpected issues and hence it shouldn't be tested. - # If it could be tested, we would have fixed the underlying issue. 
+ result = run_compilation_function_with_error_management(
+ compilation_function, compilation_configuration, compilation_artifacts
+ )
 
- # We need to export all the information we have about the compilation
- # If the user wants them to be exported
-
- if compilation_configuration.dump_artifacts_on_unexpected_failures:
- compilation_artifacts.export()
-
- traceback_path = compilation_artifacts.output_directory.joinpath("traceback.txt")
- with open(traceback_path, "w", encoding="utf-8") as f:
- f.write(traceback.format_exc())
-
- raise
+ # for mypy
+ assert isinstance(result, FHECircuit)
+ return result
diff --git a/concrete/numpy/np_fhe_compiler.py b/concrete/numpy/np_fhe_compiler.py
new file mode 100644
index 000000000..fe90b219c
--- /dev/null
+++ b/concrete/numpy/np_fhe_compiler.py
@@ -0,0 +1,199 @@
+"""Module to hold a user-friendly class to compile programs."""
+
+from copy import deepcopy
+from enum import Enum, unique
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+
+from ..common.compilation import CompilationArtifacts, CompilationConfiguration
+from ..common.data_types import Integer
+from ..common.operator_graph import OPGraph
+from ..common.representation.intermediate import IntermediateNode
+from ..common.values import BaseValue
+from .compile import compile_numpy_function_into_op_graph, measure_op_graph_bounds_and_update
+from .np_dtypes_helpers import get_base_value_for_numpy_or_python_constant_data
+
+
+@unique
+class EncryptedStatus(str, Enum):
+ """Enum to validate the encrypted status of function parameters."""
+
+ CLEAR = "clear"
+ ENCRYPTED = "encrypted"
+
+
+class NPFHECompiler:
+ """Class to ease the conversion of a numpy program to its FHE equivalent."""
+
+ INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE: int = 128
+
+ # _function_to_compile is not optional but mypy has a long-standing bug and is not able to
+ # understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
+ _function_to_compile: Optional[Callable]
+ _function_parameters_encrypted_status: Dict[str, bool]
+ _current_inputset: List[Union[Any, Tuple]]
+ _op_graph: Optional[OPGraph]
+ _nodes_and_bounds: Dict[IntermediateNode, Dict[str, Any]]
+
+ _compilation_configuration: CompilationConfiguration
+
+ compilation_artifacts: CompilationArtifacts
+
+ def __init__(
+ self,
+ function_to_compile: Callable,
+ function_parameters_encrypted_status: Dict[str, Union[str, EncryptedStatus]],
+ compilation_configuration: Optional[CompilationConfiguration] = None,
+ compilation_artifacts: Optional[CompilationArtifacts] = None,
+ ) -> None:
+ self._function_to_compile = function_to_compile
+ self._function_parameters_encrypted_status = {
+ param_name: EncryptedStatus(status.lower()) == EncryptedStatus.ENCRYPTED
+ for param_name, status in function_parameters_encrypted_status.items()
+ }
+
+ self._current_inputset = []
+ self._op_graph = None
+ self._nodes_and_bounds = {}
+
+ self._compilation_configuration = (
+ deepcopy(compilation_configuration)
+ if compilation_configuration is not None
+ else CompilationConfiguration()
+ )
+ self.compilation_artifacts = (
+ compilation_artifacts if compilation_artifacts is not None else CompilationArtifacts()
+ )
+
+ @property
+ def function_to_compile(self) -> Callable:
+ """Get the function to compile.
+
+ Returns:
+ Callable: the function to compile.
+ """ + # Continuation of mypy bug + assert self._function_to_compile is not None + return self._function_to_compile + + @property + def op_graph(self) -> Optional[OPGraph]: + """Return a copy of the OPGraph. + + Returns: + Optional[OPGraph]: the held OPGraph or None + """ + # To keep consistency with what the user expects, we make sure to evaluate on the remaining + # inputset values if any before giving a copy of the OPGraph we trace + self._eval_on_current_inputset() + return deepcopy(self._op_graph) + + @property + def compilation_configuration(self) -> Optional[CompilationConfiguration]: + """Get a copy of the compilation configuration. + + Returns: + Optional[CompilationConfiguration]: copy of the current compilation configuration. + """ + return deepcopy(self._compilation_configuration) + + def __call__(self, *args: Any) -> Any: + """Evaluate the OPGraph corresponding to the function being compiled and return result. + + Returns: + Any: the result of the OPGraph evaluation. + """ + self._current_inputset.append(deepcopy(args)) + + inferred_args = { + param_name: get_base_value_for_numpy_or_python_constant_data(val)( + is_encrypted=is_encrypted + ) + for (param_name, is_encrypted), val in zip( + self._function_parameters_encrypted_status.items(), args + ) + } + + if len(self._current_inputset) >= self.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE: + self._eval_on_current_inputset() + + self._trace_op_graph_if_needed(inferred_args) + + # For mypy + assert self._op_graph is not None + return self._op_graph(*args) + + def eval_on_inputset(self, inputset: Iterable[Union[Any, Tuple]]) -> None: + """Evaluate the underlying function on an inputset in one go, populates OPGraph and bounds. + + Args: + inputset (Iterable[Union[Any, Tuple]]): The inputset on which the function should be + evaluated. + """ + inputset_as_list = list(inputset) + if len(inputset_as_list) == 0: + return + + inferred_args = { + param_name: get_base_value_for_numpy_or_python_constant_data(val)( + is_encrypted=is_encrypted + ) + for (param_name, is_encrypted), val in zip( + self._function_parameters_encrypted_status.items(), self._current_inputset[0] + ) + } + + self._trace_op_graph_if_needed(inferred_args) + + # For mypy + assert self._op_graph is not None + + self._patch_op_graph_input_to_accept_any_integer_input() + + self._nodes_and_bounds = measure_op_graph_bounds_and_update( + self._op_graph, + inferred_args, + inputset_as_list, + self._compilation_configuration, + self.compilation_artifacts, + self._nodes_and_bounds, + False, + ) + + def _eval_on_current_inputset(self) -> None: + """Evaluate OPGraph on _current_inputset.""" + self.eval_on_inputset(self._current_inputset) + self._current_inputset.clear() + + def _needs_tracing(self) -> bool: + """Return whether we need to trace the function and populate the OPGraph.""" + return self._op_graph is None + + def _trace_op_graph_if_needed(self, inferred_args: Dict[str, BaseValue]) -> None: + """Populate _op_graph with the OPGraph for _function_to_compile.""" + if not self._needs_tracing(): + return + + self._op_graph = compile_numpy_function_into_op_graph( + self.function_to_compile, + inferred_args, + self._compilation_configuration, + self.compilation_artifacts, + ) + + def _patch_op_graph_input_to_accept_any_integer_input(self) -> None: + """Patch inputs as we don't know what data we expect.""" + + # Can only do that if the OPGraph was created hence the test. 
+ if self._needs_tracing():
+ return
+
+ # For mypy
+ assert self._op_graph is not None
+
+ # Cheat on Input nodes to avoid issues during inputset eval as we do not know in advance
+ # what the final bit width for the inputs should be
+ for node in self._op_graph.input_nodes.values():
+ for input_ in node.inputs:
+ if isinstance(dtype := input_.dtype, Integer):
+ dtype.bit_width = 128
+ dtype.is_signed = True
diff --git a/tests/numpy/test_compile_user_friendly_api.py b/tests/numpy/test_compile_user_friendly_api.py
new file mode 100644
index 000000000..02fbf4c6d
--- /dev/null
+++ b/tests/numpy/test_compile_user_friendly_api.py
@@ -0,0 +1,215 @@
+"""Test file for user-friendly numpy compilation functions."""
+
+import numpy
+import pytest
+
+from concrete.common.debugging import format_operation_graph
+from concrete.numpy.np_fhe_compiler import NPFHECompiler
+
+
+def complicated_topology(x, y):
+ """Mix x and y in an intricate way."""
+ intermediate = x + y
+ x_p_1 = intermediate + 1
+ x_p_2 = intermediate + 2
+ x_p_3 = x_p_1 + x_p_2
+ return (
+ x_p_3.astype(numpy.int32),
+ x_p_2.astype(numpy.int32),
+ (x_p_2 + 3).astype(numpy.int32),
+ x_p_3.astype(numpy.int32) + 67,
+ )
+
+
+@pytest.mark.parametrize("input_shape", [(), (3, 1, 2)])
+def test_np_fhe_compiler(input_shape, default_compilation_configuration):
+ """Test NPFHECompiler in two subtests."""
+ subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configuration)
+ subtest_np_fhe_compiler_2_inputs(input_shape, default_compilation_configuration)
+
+
+def subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configuration):
+ """Test for NPFHECompiler on a one-input function."""
+
+ compiler = NPFHECompiler(
+ lambda x: complicated_topology(x, 0),
+ {"x": "encrypted"},
+ default_compilation_configuration,
+ )
+
+ # For coverage when the OPGraph is not yet traced
+ compiler._patch_op_graph_input_to_accept_any_integer_input() # pylint: disable=protected-access
+
+ assert compiler.compilation_configuration == default_compilation_configuration
+ assert compiler.compilation_configuration is not default_compilation_configuration
+
+ for i in numpy.arange(5):
+ i = numpy.ones(input_shape, dtype=numpy.int64) * i
+ assert numpy.array_equal(compiler(i), complicated_topology(i, 0))
+
+ # For coverage, check that we flush the inputset when we query the OPGraph
+ current_op_graph = compiler.op_graph
+ assert current_op_graph is not compiler.op_graph
+ assert len(compiler._current_inputset) == 0 # pylint: disable=protected-access
+ # For coverage, cover case where the current inputset is empty
+ compiler._eval_on_current_inputset() # pylint: disable=protected-access
+
+ # Continue a bit more
+ for i in numpy.arange(5, 10):
+ i = numpy.ones(input_shape, dtype=numpy.int64) * i
+ assert numpy.array_equal(compiler(i), complicated_topology(i, 0))
+
+ if input_shape == ():
+ assert (
+ (got := format_operation_graph(compiler.op_graph))
+ == """ %0 = 67 # ClearScalar
+ %1 = 2 # ClearScalar
+ %2 = 3 # ClearScalar
+ %3 = 1 # ClearScalar
+ %4 = x # EncryptedScalar
+ %5 = 0 # ClearScalar
+ %6 = add(%4, %5) # EncryptedScalar
+ %7 = add(%6, %1) # EncryptedScalar
+ %8 = add(%6, %3) # EncryptedScalar
+ %9 = astype(%7, dtype=int32) # EncryptedScalar
+%10 = add(%7, %2) # EncryptedScalar
+%11 = add(%8, %7) # EncryptedScalar
+%12 = astype(%10, dtype=int32) # EncryptedScalar
+%13 = astype(%11, dtype=int32) # EncryptedScalar
+%14 = astype(%11, dtype=int32) # EncryptedScalar
+%15 = add(%14, %0) # EncryptedScalar
+(%13, %9, %12, %15)"""
+ ), got
+ else:
+ assert (
+ (got := format_operation_graph(compiler.op_graph))
+ == """ %0 = 67 # ClearScalar
+ %1 = 2 # ClearScalar
+ %2 = 3 # ClearScalar
+ %3 = 1 # ClearScalar
+ %4 = x # EncryptedTensor
+ %5 = 0 # ClearScalar
+ %6 = add(%4, %5) # EncryptedTensor
+ %7 = add(%6, %1) # EncryptedTensor
+ %8 = add(%6, %3) # EncryptedTensor
+ %9 = astype(%7, dtype=int32) # EncryptedTensor
+%10 = add(%7, %2) # EncryptedTensor
+%11 = add(%8, %7) # EncryptedTensor
+%12 = astype(%10, dtype=int32) # EncryptedTensor
+%13 = astype(%11, dtype=int32) # EncryptedTensor
+%14 = astype(%11, dtype=int32) # EncryptedTensor
+%15 = add(%14, %0) # EncryptedTensor
+(%13, %9, %12, %15)"""
+ ), got
+
+
+def subtest_np_fhe_compiler_2_inputs(input_shape, default_compilation_configuration):
+ """Test for NPFHECompiler on a two-input function."""
+
+ compiler = NPFHECompiler(
+ complicated_topology,
+ {"x": "encrypted", "y": "clear"},
+ default_compilation_configuration,
+ )
+
+ # For coverage when the OPGraph is not yet traced
+ compiler._patch_op_graph_input_to_accept_any_integer_input() # pylint: disable=protected-access
+
+ assert compiler.compilation_configuration == default_compilation_configuration
+ assert compiler.compilation_configuration is not default_compilation_configuration
+
+ for i, j in zip(numpy.arange(5), numpy.arange(5, 10)):
+ i = numpy.ones(input_shape, dtype=numpy.int64) * i
+ j = numpy.ones(input_shape, dtype=numpy.int64) * j
+ assert numpy.array_equal(compiler(i, j), complicated_topology(i, j))
+
+ # For coverage, check that we flush the inputset when we query the OPGraph
+ current_op_graph = compiler.op_graph
+ assert current_op_graph is not compiler.op_graph
+ assert len(compiler._current_inputset) == 0 # pylint: disable=protected-access
+ # For coverage, cover case where the current inputset is empty
+ compiler._eval_on_current_inputset() # pylint: disable=protected-access
+
+ # Continue a bit more
+ for i, j in zip(numpy.arange(5, 10), numpy.arange(5)):
+ i = numpy.ones(input_shape, dtype=numpy.int64) * i
+ j = numpy.ones(input_shape, dtype=numpy.int64) * j
+ assert numpy.array_equal(compiler(i, j), complicated_topology(i, j))
+
+ if input_shape == ():
+ assert (
+ (got := format_operation_graph(compiler.op_graph))
+ == """ %0 = 67 # ClearScalar
+ %1 = 2 # ClearScalar
+ %2 = 3 # ClearScalar
+ %3 = 1 # ClearScalar
+ %4 = x # EncryptedScalar
+ %5 = y # ClearScalar
+ %6 = add(%4, %5) # EncryptedScalar
+ %7 = add(%6, %1) # EncryptedScalar
+ %8 = add(%6, %3) # EncryptedScalar
+ %9 = astype(%7, dtype=int32) # EncryptedScalar
+%10 = add(%7, %2) # EncryptedScalar
+%11 = add(%8, %7) # EncryptedScalar
+%12 = astype(%10, dtype=int32) # EncryptedScalar
+%13 = astype(%11, dtype=int32) # EncryptedScalar
+%14 = astype(%11, dtype=int32) # EncryptedScalar
+%15 = add(%14, %0) # EncryptedScalar
+(%13, %9, %12, %15)"""
+ ), got
+ else:
+ assert (
+ (got := format_operation_graph(compiler.op_graph))
+ == """ %0 = 67 # ClearScalar
+ %1 = 2 # ClearScalar
+ %2 = 3 # ClearScalar
+ %3 = 1 # ClearScalar
+ %4 = x # EncryptedTensor
+ %5 = y # ClearTensor
+ %6 = add(%4, %5) # EncryptedTensor
+ %7 = add(%6, %1) # EncryptedTensor
+ %8 = add(%6, %3) # EncryptedTensor
+ %9 = astype(%7, dtype=int32) # EncryptedTensor
+%10 = add(%7, %2) # EncryptedTensor
+%11 = add(%8, %7) # EncryptedTensor
+%12 = astype(%10, dtype=int32) # EncryptedTensor
+%13 = astype(%11, dtype=int32) # EncryptedTensor
+%14 = astype(%11, dtype=int32) # EncryptedTensor
+%15 = add(%14, %0) # EncryptedTensor
+(%13, %9, %12, %15)"""
+ ), got
+
+
+def
remaining_inputset_size(inputset_len):
+ """Helper to compute the expected remaining inputset length for the test cases below."""
+ return inputset_len % NPFHECompiler.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE
+
+
+@pytest.mark.parametrize(
+ "inputset_len, expected_remaining_inputset_len",
+ [
+ (42, remaining_inputset_size(42)),
+ (128, remaining_inputset_size(128)),
+ (234, remaining_inputset_size(234)),
+ ],
+)
+def test_np_fhe_compiler_auto_flush(
+ inputset_len,
+ expected_remaining_inputset_len,
+ default_compilation_configuration,
+):
+ """Test the auto-flush of NPFHECompiler once the inputset reaches 128 elements."""
+ compiler = NPFHECompiler(
+ lambda x: x // 2,
+ {"x": "encrypted"},
+ default_compilation_configuration,
+ )
+
+ for i in numpy.arange(inputset_len):
+ assert numpy.array_equal(compiler(i), i // 2)
+
+ # Check the inputset was properly flushed
+ assert (
+ len(compiler._current_inputset) # pylint: disable=protected-access
+ == expected_remaining_inputset_len
+ )
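
For reference, a minimal usage sketch of the new NPFHECompiler API, distilled from the tests above; the wrapped function and the inputset values are illustrative and not part of the patch:

from concrete.numpy import NPFHECompiler

# Parameters are declared "encrypted" or "clear"; configuration and artifacts fall back to defaults.
compiler = NPFHECompiler(lambda x: x + 42, {"x": "encrypted"})

# Each call evaluates the traced OPGraph and records the arguments as inputset samples; bounds are
# re-measured automatically every INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE (128) calls.
for i in range(10):
    assert compiler(i) == i + 42

# Querying op_graph first flushes any pending samples, then returns a copy of the traced graph.
op_graph = compiler.op_graph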