diff --git a/concrete/__init__.py b/concrete/__init__.py deleted file mode 100644 index 18d9e7759..000000000 --- a/concrete/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Top level import.""" -# Do not modify, this is to have a compatible namespace package -# https://packaging.python.org/en/latest/guides/packaging-namespace-packages/ -# #pkg-resources-style-namespace-packages -__import__("pkg_resources").declare_namespace(__name__) # pragma: no cover diff --git a/concrete/common/__init__.py b/concrete/common/__init__.py deleted file mode 100644 index ce8e30ea3..000000000 --- a/concrete/common/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Module for shared data structures and code.""" -from . import compilation, data_types, debugging, representation, values -from .common_helpers import check_op_graph_is_integer_program, is_a_power_of_2 diff --git a/concrete/common/bounds_measurement/__init__.py b/concrete/common/bounds_measurement/__init__.py deleted file mode 100644 index a1ea8260d..000000000 --- a/concrete/common/bounds_measurement/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Bounds measurement module.""" -from . import inputset_eval diff --git a/concrete/common/bounds_measurement/inputset_eval.py b/concrete/common/bounds_measurement/inputset_eval.py deleted file mode 100644 index a50b98ca5..000000000 --- a/concrete/common/bounds_measurement/inputset_eval.py +++ /dev/null @@ -1,260 +0,0 @@ -"""Code to evaluate the IR graph on inputsets.""" - -import sys -from functools import partial -from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union - -from ..compilation import CompilationConfiguration -from ..data_types.dtypes_helpers import ( - get_base_value_for_python_constant_data, - is_data_type_compatible_with, -) -from ..debugging import assert_true -from ..operator_graph import OPGraph -from ..representation.intermediate import IntermediateNode - - -def _check_input_coherency( - input_to_check: Dict[str, Any], - parameters: Dict[str, Any], - get_base_value_for_constant_data_func: Callable[[Any], Any], -): - """Check whether `input_to_check` is coherent with `parameters`. - - This function works by iterating over each constant of the input, - determining base value of the constant using `get_base_value_for_constant_data_func` and - checking if the base value of the contant is compatible with the base value of the parameter. - - Args: - input_to_check (Dict[str, Any]): input to check coherency of - parameters (Dict[str, Any]): parameters and their expected base values - get_base_value_for_constant_data_func (Callable[[Any], Any]): - function to get the base value of python objects. - - Returns: - List[str]: List of warnings about the coherency - """ - - warnings = [] - for parameter_name, value in input_to_check.items(): - parameter_base_value = parameters[parameter_name] - - base_value_class = get_base_value_for_constant_data_func(value) - base_value = base_value_class(is_encrypted=parameter_base_value.is_encrypted) - - if base_value.shape != parameter_base_value.shape or not is_data_type_compatible_with( - base_value.dtype, parameter_base_value.dtype - ): - warnings.append( - f"expected {str(parameter_base_value)} " - f"for parameter `{parameter_name}` " - f"but got {str(base_value)} " - f"which is not compatible" - ) - return warnings - - -def _print_input_coherency_warnings( - current_input_index: int, - current_input_data: Dict[int, Any], - parameters: Dict[str, Any], - parameter_index_to_parameter_name: Dict[int, str], - get_base_value_for_constant_data_func: Callable[[Any], Any], - treat_warnings_as_errors: bool, -): - """Print coherency warning for `input_to_check` against `parameters`. - - Args: - current_input_index (int): index of the current input on the inputset - current_input_data (Dict[int, Any]): input to print coherency warnings of - parameters (Dict[str, Any]): parameters and their expected base values - parameter_index_to_parameter_name (Dict[int, str]): - dict to get parameter names from parameter indices - get_base_value_for_constant_data_func (Callable[[Any], Any]): - function to get the base value of python objects. - - Returns: - None - """ - - current_input_named_data = { - parameter_index_to_parameter_name[index]: data for index, data in current_input_data.items() - } - - problems = _check_input_coherency( - current_input_named_data, - parameters, - get_base_value_for_constant_data_func, - ) - messages = [ - ( - f"Input #{current_input_index} (0-indexed) " - f"is not coherent with the hinted parameters ({problem})\n" - ) - for problem in problems - ] - - if len(messages) > 0: - if treat_warnings_as_errors: - raise ValueError(", ".join(messages)) - - for message in messages: - sys.stderr.write(f"Warning: {message}") - - -def eval_op_graph_bounds_on_inputset( - op_graph: OPGraph, - inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]], - compilation_configuration: CompilationConfiguration, - min_func: Callable[[Any, Any], Any] = min, - max_func: Callable[[Any, Any], Any] = max, - get_base_value_for_constant_data_func: Callable[ - [Any], Any - ] = get_base_value_for_python_constant_data, - prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None, -) -> Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: - """Evaluate the bounds with a inputset. - - Evaluate the bounds for all output values of the operators in the graph op_graph over data - coming from the inputset - - Args: - op_graph (OPGraph): The graph for which we want to determine the bounds - inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset over which op_graph - is evaluated. It needs to be an iterable on tuples (can be single values in case the - function has only one argument) which are of the same length than the number of - parameters in the function, and in the same order than these same parameters - compilation_configuration (CompilationConfiguration): Configuration object to use - during determining input checking strategy - min_func (Callable[[Any, Any], Any], optional): custom function to compute a scalar minimum - between two values that can be encountered during evaluation (for e.g. numpy or torch - tensors). Defaults to min. - max_func (Callable[[Any, Any], Any], optional): custom function to compute a scalar maximum - between two values that can be encountered during evaluation (for e.g. numpy or torch - tensors). Defaults to max. - get_base_value_for_constant_data_func (Callable[[Any], Any], optional): custom function - to compute the base value of a python object. - prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional): - Bounds and samples from a previous run. Defaults to None. - - Returns: - Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: number of inputs in the inputset and - a dict containing the bounds for each node from op_graph, stored with the node - as key and a dict with keys "min", "max" and "sample" as value. - """ - - num_input_nodes = len(op_graph.input_nodes) - - def check_inputset_input_len_is_valid(data_to_check): - # Only check if there are more than one input node, otherwise accept the value as the sole - # argument passed to the OPGraph for evaluation - if num_input_nodes > 1: - assert_true( - len(data_to_check) == num_input_nodes, - ( - f"Got input data from inputset of len: {len(data_to_check)}, " - f"function being evaluated has {num_input_nodes} inputs, please make " - f"sure your data generator returns valid tuples of input values" - ), - ) - - def generate_input_values_dict(input_data) -> Dict[int, Any]: - if num_input_nodes > 1: - return dict(enumerate(input_data)) - # TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/772 - # update this to support tuple in case of 1-input functions accepting tuples - assert_true( - not isinstance(input_data, tuple), - "Tuples are unsupported for single input inputset evaluation", - ) - return {0: input_data} - - # TODO: do we want to check coherence between the input data type and the corresponding Input ir - # node expected data type ? Not considering bit_width as they may not make sense at this stage - - parameter_index_to_parameter_name = { - index: input_node.input_name for index, input_node in op_graph.input_nodes.items() - } - parameters = { - input_node.input_name: input_node.inputs[0] for input_node in op_graph.input_nodes.values() - } - - inputset_iterator = iter(inputset) - inputset_size = 0 - - current_input_data = generate_input_values_dict(next(inputset_iterator)) - inputset_size += 1 - - check_inputset_input_len_is_valid(current_input_data.values()) - _print_input_coherency_warnings( - inputset_size - 1, - current_input_data, - parameters, - parameter_index_to_parameter_name, - get_base_value_for_constant_data_func, - compilation_configuration.treat_warnings_as_errors, - ) - - first_output = op_graph.evaluate(current_input_data) - - prev_node_bounds_and_samples = ( - {} if prev_node_bounds_and_samples is None else prev_node_bounds_and_samples - ) - - def get_previous_value_for_key_or_default_for_dict( - dict_: Dict[IntermediateNode, Dict[str, Any]], - node: IntermediateNode, - key: str, - default: Any, - ) -> Any: - return_value = default - - previous_value_dict = dict_.get(node, None) - - if previous_value_dict is not None: - return_value = previous_value_dict.get(key, default) - - return return_value - - get_previous_value_for_key_or_default = partial( - get_previous_value_for_key_or_default_for_dict, prev_node_bounds_and_samples - ) - - # We evaluate the min and max func to be able to resolve the tensors min and max rather than - # having the tensor itself as the stored min and max values. - # As we don't know the integrity of prev_node_bounds_and_samples we make sure we can - # populate the new node_bounds_and_samples - node_bounds_and_samples = { - node: { - "min": min_func(value, get_previous_value_for_key_or_default(node, "min", value)), - "max": max_func(value, get_previous_value_for_key_or_default(node, "max", value)), - "sample": get_previous_value_for_key_or_default(node, "sample", value), - } - for node, value in first_output.items() - } - - for input_data in inputset_iterator: - inputset_size += 1 - current_input_data = generate_input_values_dict(input_data) - - check_inputset_input_len_is_valid(current_input_data.values()) - if compilation_configuration.check_every_input_in_inputset: - _print_input_coherency_warnings( - inputset_size - 1, - current_input_data, - parameters, - parameter_index_to_parameter_name, - get_base_value_for_constant_data_func, - compilation_configuration.treat_warnings_as_errors, - ) - - current_output = op_graph.evaluate(current_input_data) - for node, value in current_output.items(): - node_bounds_and_samples[node]["min"] = min_func( - node_bounds_and_samples[node]["min"], value - ) - node_bounds_and_samples[node]["max"] = max_func( - node_bounds_and_samples[node]["max"], value - ) - - return inputset_size, node_bounds_and_samples diff --git a/concrete/common/common_helpers.py b/concrete/common/common_helpers.py deleted file mode 100644 index 9ad5138d3..000000000 --- a/concrete/common/common_helpers.py +++ /dev/null @@ -1,67 +0,0 @@ -"""File to hold some helper code.""" - -from typing import List, Optional - -from .data_types.integers import Integer -from .debugging import assert_true -from .operator_graph import OPGraph -from .representation.intermediate import IntermediateNode - - -def is_a_power_of_2(x: int) -> bool: - """Check if an integer is a power of two. - - Args: - x (int): Number to check - - Returns: - bool: True if the number is a power of two - """ - # https://stackoverflow.com/questions/57025836/how-to-check-if-a-given-number-is-a-power-of-two - - return x > 0 and (x & (x - 1)) == 0 - - -def ir_nodes_has_integer_input_and_output(node: IntermediateNode) -> bool: - """Check if an ir node has Integer inputs and outputs. - - Args: - node (IntermediateNode): Node to check - - Returns: - bool: True if all input and output values hold Integers - """ - return all(isinstance(x.dtype, Integer) for x in node.inputs) and all( - isinstance(x.dtype, Integer) for x in node.outputs - ) - - -# This check makes sense as long as the compiler backend only manages integers, to be removed in the -# long run probably -def check_op_graph_is_integer_program( - op_graph: OPGraph, - offending_nodes_out: Optional[List[IntermediateNode]] = None, -) -> bool: - """Check if an op_graph inputs, outputs and intermediate values are Integers. - - Args: - op_graph (OPGraph): The OPGraph to check - offending_nodes_out (Optional[List[IntermediateNode]]): Optionally pass a list that will - be populated with offending nodes, the list will be cleared before being filled - - Returns: - bool: True if inputs, outputs and intermediate values are Integers, False otherwise - """ - offending_nodes = [] if offending_nodes_out is None else offending_nodes_out - - assert_true( - isinstance(offending_nodes, list), - f"offending_nodes_out must be a list, got {type(offending_nodes_out)}", - ) - - offending_nodes.clear() - offending_nodes.extend( - node for node in op_graph.graph.nodes() if not ir_nodes_has_integer_input_and_output(node) - ) - - return len(offending_nodes) == 0 diff --git a/concrete/common/compilation/__init__.py b/concrete/common/compilation/__init__.py deleted file mode 100644 index 59e402417..000000000 --- a/concrete/common/compilation/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Module for compilation related types.""" - -from .artifacts import CompilationArtifacts -from .configuration import CompilationConfiguration diff --git a/concrete/common/compilation/artifacts.py b/concrete/common/compilation/artifacts.py deleted file mode 100644 index 28526e908..000000000 --- a/concrete/common/compilation/artifacts.py +++ /dev/null @@ -1,222 +0,0 @@ -"""Module for compilation artifacts.""" - -import inspect -import platform -import shutil -import subprocess -from pathlib import Path -from typing import Any, Callable, Dict, Optional, Union - -import networkx as nx -from loguru import logger - -from ..debugging import assert_true, draw_graph, format_operation_graph -from ..operator_graph import OPGraph -from ..representation.intermediate import IntermediateNode -from ..values import BaseValue - -DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts") - - -class CompilationArtifacts: - """Class that conveys information about compilation process.""" - - output_directory: Path - - source_code_of_the_function_to_compile: Optional[str] - parameters_of_the_function_to_compile: Dict[str, str] - - drawings_of_operation_graphs: Dict[str, str] - textual_representations_of_operation_graphs: Dict[str, str] - - final_operation_graph: Optional[OPGraph] - bounds_of_the_final_operation_graph: Optional[Dict[IntermediateNode, Dict[str, Any]]] - mlir_of_the_final_operation_graph: Optional[str] - - def __init__(self, output_directory: Union[Path, str] = DEFAULT_OUTPUT_DIRECTORY): - self.output_directory = Path(output_directory) - - self.source_code_of_the_function_to_compile = None - self.parameters_of_the_function_to_compile = {} - - self.drawings_of_operation_graphs = {} - self.textual_representations_of_operation_graphs = {} - - self.final_operation_graph = None - self.bounds_of_the_final_operation_graph = None - self.mlir_of_the_final_operation_graph = None - - def add_function_to_compile(self, function: Union[Callable, str]): - """Add the function to compile to artifacts. - - Args: - function (Union[Callable, str]): the function to compile or source code of it - - Returns: - None - """ - - try: - self.source_code_of_the_function_to_compile = ( - function if isinstance(function, str) else inspect.getsource(function) - ) - # When using the python console we cannot use getsource, so catch that and emit an error - except OSError: # pragma: no cover - function_str = function if isinstance(function, str) else function.__name__ - logger.error(f"Could not get source for function: {function_str}") - self.source_code_of_the_function_to_compile = "unavailable" - - def add_parameter_of_function_to_compile(self, name: str, value: Union[BaseValue, str]): - """Add a parameter of the function to compile to the artifacts. - - Args: - name (str): name of the parameter - value (Union[BaseValue, str]): value of the parameter or textual representation of it - - Returns: - None - """ - - self.parameters_of_the_function_to_compile[name] = str(value) - - def add_operation_graph(self, name: str, operation_graph: OPGraph): - """Add an operation graph to the artifacts. - - Args: - name (str): name of the graph - operation_graph (OPGraph): the operation graph itself - - Returns: - None - """ - - try: - drawing = draw_graph(operation_graph) - self.drawings_of_operation_graphs[name] = drawing - # Do not crash on imports ourselves for drawings if the package is not installed - except ImportError as e: # pragma: no cover - if "pygraphviz" in str(e): - pass - else: - raise e - textual_representation = format_operation_graph(operation_graph) - - self.textual_representations_of_operation_graphs[name] = textual_representation - - self.final_operation_graph = operation_graph - - def add_final_operation_graph_bounds(self, bounds: Dict[IntermediateNode, Dict[str, Any]]): - """Add the bounds of the final operation graph to the artifacts. - - Args: - bounds (Dict[IntermediateNode, Dict[str, Any]]): the bound dictionary - - Returns: - None - """ - - assert_true(self.final_operation_graph is not None) - self.bounds_of_the_final_operation_graph = bounds - - def add_final_operation_graph_mlir(self, mlir: str): - """Add the mlir of the final operation graph to the artifacts. - - Args: - mlir (str): the mlir code of the final operation graph - - Returns: - None - """ - - assert_true(self.final_operation_graph is not None) - self.mlir_of_the_final_operation_graph = mlir - - def export(self): - """Export the artifacts to a the output directory. - - Returns: - None - """ - - output_directory = self.output_directory - if output_directory.exists(): - shutil.rmtree(output_directory) - output_directory.mkdir(parents=True) - - with open(output_directory.joinpath("environment.txt"), "w", encoding="utf-8") as f: - f.write(f"{platform.platform()} {platform.version()}\n") - f.write(f"Python {platform.python_version()}\n") - - with open(output_directory.joinpath("requirements.txt"), "w", encoding="utf-8") as f: - # example `pip list` output - - # Package Version - # ----------------------------- --------- - # alabaster 0.7.12 - # appdirs 1.4.4 - # ... ... - # ... ... - # wrapt 1.12.1 - # zipp 3.5.0 - - pip_process = subprocess.run( - ["pip", "--disable-pip-version-check", "list"], stdout=subprocess.PIPE, check=True - ) - dependencies = iter(pip_process.stdout.decode("utf-8").split("\n")) - - # skip 'Package ... Version' line - next(dependencies) - - # skip '------- ... -------' line - next(dependencies) - - for dependency in dependencies: - tokens = [token for token in dependency.split(" ") if token != ""] - if len(tokens) == 0: - continue - - name = tokens[0] - version = tokens[1] - - f.write(f"{name}=={version}\n") - - if self.source_code_of_the_function_to_compile is not None: - with open(output_directory.joinpath("function.txt"), "w", encoding="utf-8") as f: - f.write(self.source_code_of_the_function_to_compile) - - if len(self.parameters_of_the_function_to_compile) > 0: - with open(output_directory.joinpath("parameters.txt"), "w", encoding="utf-8") as f: - for name, parameter in self.parameters_of_the_function_to_compile.items(): - f.write(f"{name} :: {parameter}\n") - - drawings = self.drawings_of_operation_graphs.items() - for index, (name, drawing_filename) in enumerate(drawings): - identifier = CompilationArtifacts._identifier(index, name) - shutil.copy(drawing_filename, output_directory.joinpath(f"{identifier}.png")) - - textual_representations = self.textual_representations_of_operation_graphs.items() - for index, (name, representation) in enumerate(textual_representations): - identifier = CompilationArtifacts._identifier(index, name) - with open(output_directory.joinpath(f"{identifier}.txt"), "w", encoding="utf-8") as f: - f.write(f"{representation}") - - if self.bounds_of_the_final_operation_graph is not None: - assert_true(self.final_operation_graph is not None) - with open(output_directory.joinpath("bounds.txt"), "w", encoding="utf-8") as f: - # TODO: - # if nx.topological_sort is not deterministic between calls, - # the lines below will not work properly - # thus, we may want to change this in the future - for index, node in enumerate(nx.topological_sort(self.final_operation_graph.graph)): - bounds = self.bounds_of_the_final_operation_graph.get(node) - assert_true(bounds is not None) - f.write(f"%{index} :: [{bounds.get('min')}, {bounds.get('max')}]\n") - - if self.mlir_of_the_final_operation_graph is not None: - assert_true(self.final_operation_graph is not None) - with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f: - f.write(self.mlir_of_the_final_operation_graph) - - @staticmethod - def _identifier(index, name): - return f"{index + 1}.{name}.graph" diff --git a/concrete/common/compilation/configuration.py b/concrete/common/compilation/configuration.py deleted file mode 100644 index 79d64a665..000000000 --- a/concrete/common/compilation/configuration.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Module for compilation configuration.""" - - -class CompilationConfiguration: - """Class that allows the compilation process to be customized.""" - - dump_artifacts_on_unexpected_failures: bool - enable_topological_optimizations: bool - check_every_input_in_inputset: bool - treat_warnings_as_errors: bool - enable_unsafe_features: bool - random_inputset_samples: int - use_insecure_key_cache: bool - auto_parallelize: bool - loop_parallelize: bool - dataflow_parallelize: bool - - # pylint: disable=too-many-arguments - def __init__( - self, - dump_artifacts_on_unexpected_failures: bool = True, - enable_topological_optimizations: bool = True, - check_every_input_in_inputset: bool = False, - treat_warnings_as_errors: bool = False, - enable_unsafe_features: bool = False, - random_inputset_samples: int = 30, - use_insecure_key_cache: bool = False, - auto_parallelize: bool = False, - loop_parallelize: bool = True, - dataflow_parallelize: bool = False, - ): - self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures - self.enable_topological_optimizations = enable_topological_optimizations - self.check_every_input_in_inputset = check_every_input_in_inputset - self.treat_warnings_as_errors = treat_warnings_as_errors - self.enable_unsafe_features = enable_unsafe_features - self.random_inputset_samples = random_inputset_samples - self.use_insecure_key_cache = use_insecure_key_cache - self.auto_parallelize = auto_parallelize - self.loop_parallelize = loop_parallelize - self.dataflow_parallelize = dataflow_parallelize - - # pylint: enable=too-many-arguments - - def __eq__(self, other) -> bool: - return isinstance(other, CompilationConfiguration) and self.__dict__ == other.__dict__ diff --git a/concrete/common/data_types/__init__.py b/concrete/common/data_types/__init__.py deleted file mode 100644 index 8ee1eced0..000000000 --- a/concrete/common/data_types/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Module for data types code and data structures.""" -from . import dtypes_helpers, floats, integers -from .floats import Float, Float16, Float32, Float64 -from .integers import Integer, SignedInteger, UnsignedInteger diff --git a/concrete/common/data_types/base.py b/concrete/common/data_types/base.py deleted file mode 100644 index 834e75dc9..000000000 --- a/concrete/common/data_types/base.py +++ /dev/null @@ -1,11 +0,0 @@ -"""File holding code to represent data types in a program.""" - -from abc import ABC, abstractmethod - - -class BaseDataType(ABC): - """Base class to represent a data type.""" - - @abstractmethod - def __eq__(self, o: object) -> bool: - """No default implementation.""" diff --git a/concrete/common/data_types/dtypes_helpers.py b/concrete/common/data_types/dtypes_helpers.py deleted file mode 100644 index 77430f9b7..000000000 --- a/concrete/common/data_types/dtypes_helpers.py +++ /dev/null @@ -1,393 +0,0 @@ -"""File to hold helper functions for data types related stuff.""" - -from copy import deepcopy -from functools import partial -from typing import Callable, Optional, Tuple, Union, cast - -from ..debugging.custom_assert import assert_true -from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue -from .base import BaseDataType -from .floats import Float -from .integers import Integer, get_bits_to_represent_value_as_integer - -INTEGER_TYPES = (Integer,) -FLOAT_TYPES = (Float,) -BASE_DATA_TYPES = INTEGER_TYPES + FLOAT_TYPES - - -def value_is_encrypted_scalar_integer(value_to_check: BaseValue) -> bool: - """Check that a value is an encrypted scalar of type Integer. - - Args: - value_to_check (BaseValue): The value to check - - Returns: - bool: True if the passed value_to_check is an encrypted scalar of type Integer - """ - return value_is_scalar_integer(value_to_check) and value_to_check.is_encrypted - - -def value_is_encrypted_scalar_unsigned_integer(value_to_check: BaseValue) -> bool: - """Check that a value is an encrypted scalar of type unsigned Integer. - - Args: - value_to_check (BaseValue): The value to check - - Returns: - bool: True if the passed value_to_check is an encrypted scalar of type Integer and - unsigned - """ - return ( - value_is_encrypted_scalar_integer(value_to_check) - and not cast(Integer, value_to_check.dtype).is_signed - ) - - -def value_is_clear_scalar_integer(value_to_check: BaseValue) -> bool: - """Check that a value is a clear scalar of type Integer. - - Args: - value_to_check (BaseValue): The value to check - - Returns: - bool: True if the passed value_to_check is a clear scalar of type Integer - """ - return value_is_scalar_integer(value_to_check) and value_to_check.is_clear - - -def value_is_scalar_integer(value_to_check: BaseValue) -> bool: - """Check that a value is a scalar of type Integer. - - Args: - value_to_check (BaseValue): The value to check - - Returns: - bool: True if the passed value_to_check is a scalar of type Integer - """ - return ( - isinstance(value_to_check, TensorValue) - and value_to_check.is_scalar - and isinstance(value_to_check.dtype, INTEGER_TYPES) - ) - - -def value_is_integer(value_to_check: BaseValue) -> bool: - """Check that a value is of type Integer. - - Args: - value_to_check (BaseValue): The value to check - - Returns: - bool: True if the passed value_to_check is of type Integer - """ - - return isinstance(value_to_check.dtype, INTEGER_TYPES) - - -def value_is_unsigned_integer(value_to_check: BaseValue) -> bool: - """Check that a value is of type Integer and is unsigned. - - Args: - value_to_check (BaseValue): The value to check - - Returns: - bool: True if the passed value_to_check is of type Integer and is unsigned - """ - - return ( - isinstance(value_to_check.dtype, INTEGER_TYPES) - and not cast(Integer, value_to_check.dtype).is_signed - ) - - -def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool: - """Check that a value is an encrypted TensorValue of type Integer. - - Args: - value_to_check (BaseValue): The value to check - - Returns: - bool: True if the passed value_to_check is an encrypted TensorValue of type Integer - """ - return value_is_tensor_integer(value_to_check) and value_to_check.is_encrypted - - -def value_is_clear_tensor_integer(value_to_check: BaseValue) -> bool: - """Check that a value is a clear TensorValue of type Integer. - - Args: - value_to_check (BaseValue): The value to check - - Returns: - bool: True if the passed value_to_check is a clear TensorValue of type Integer - """ - return value_is_tensor_integer(value_to_check) and value_to_check.is_clear - - -def value_is_tensor_integer(value_to_check: BaseValue) -> bool: - """Check that a value is a TensorValue of type Integer. - - Args: - value_to_check (BaseValue): The value to check - - Returns: - bool: True if the passed value_to_check is a TensorValue of type Integer - """ - return ( - isinstance(value_to_check, TensorValue) - and not value_to_check.is_scalar - and isinstance(value_to_check.dtype, INTEGER_TYPES) - ) - - -def find_type_to_hold_both_lossy( - dtype1: BaseDataType, - dtype2: BaseDataType, -) -> BaseDataType: - """Determine the type that can represent both dtype1 and dtype2 separately. - - This is lossy with floating point types. - - Args: - dtype1 (BaseDataType): first dtype to hold - dtype2 (BaseDataType): second dtype to hold - - Raises: - NotImplementedError: Raised if one of the two input dtypes is not an Integer as they are the - only type supported for now - - Returns: - BaseDataType: The dtype able to hold (potentially lossy) dtype1 and dtype2 - """ - assert_true(isinstance(dtype1, BASE_DATA_TYPES), f"Unsupported dtype1: {type(dtype1)}") - assert_true(isinstance(dtype2, BASE_DATA_TYPES), f"Unsupported dtype2: {type(dtype2)}") - - type_to_return: BaseDataType - - if isinstance(dtype1, Integer) and isinstance(dtype2, Integer): - d1_signed = dtype1.is_signed - d2_signed = dtype2.is_signed - max_bits = max(dtype1.bit_width, dtype2.bit_width) - - if d1_signed and d2_signed: - type_to_return = Integer(max_bits, is_signed=True) - elif not d1_signed and not d2_signed: - type_to_return = Integer(max_bits, is_signed=False) - elif d1_signed and not d2_signed: - # 2 is unsigned, if it has the bigger bit_width, we need a signed integer that can hold - # it, so add 1 bit of sign to its bit_width - if dtype2.bit_width >= dtype1.bit_width: - new_bit_width = dtype2.bit_width + 1 - type_to_return = Integer(new_bit_width, is_signed=True) - else: - type_to_return = Integer(dtype1.bit_width, is_signed=True) - elif not d1_signed and d2_signed: - # Same as above, with 1 and 2 switched around - if dtype1.bit_width >= dtype2.bit_width: - new_bit_width = dtype1.bit_width + 1 - type_to_return = Integer(new_bit_width, is_signed=True) - else: - type_to_return = Integer(dtype2.bit_width, is_signed=True) - elif isinstance(dtype1, Float) and isinstance(dtype2, Float): - max_bits = max(dtype1.bit_width, dtype2.bit_width) - type_to_return = Float(max_bits) - elif isinstance(dtype1, Float): - type_to_return = deepcopy(dtype1) - elif isinstance(dtype2, Float): - type_to_return = deepcopy(dtype2) - - return type_to_return - - -def mix_tensor_values_determine_holding_dtype( - value1: TensorValue, - value2: TensorValue, -) -> TensorValue: - """Return mixed TensorValue with data type able to hold both value1 and value2 dtypes. - - Returns a TensorValue that would result from computation on both value1 and value2 while - determining the data type able to hold both value1 and value2 data type (this can be lossy - with floats). - - Args: - value1 (TensorValue): first TensorValue to mix. - value2 (TensorValue): second TensorValue to mix. - - Returns: - TensorValue: The resulting mixed TensorValue with data type able to hold both value1 and - value2 dtypes. - """ - - assert_true( - isinstance(value1, TensorValue), f"Unsupported value1: {value1}, expected TensorValue" - ) - assert_true( - isinstance(value2, TensorValue), f"Unsupported value2: {value2}, expected TensorValue" - ) - - resulting_shape = broadcast_shapes(value1.shape, value2.shape) - assert_true( - resulting_shape is not None, - ( - f"Tensors have incompatible shapes which is not supported.\n" - f"value1: {value1.shape}, value2: {value2.shape}" - ), - ) - assert resulting_shape is not None # this is to make mypy happy - - holding_type = find_type_to_hold_both_lossy(value1.dtype, value2.dtype) - if value1.is_encrypted or value2.is_encrypted: - mixed_value = EncryptedTensor(dtype=holding_type, shape=resulting_shape) - else: - mixed_value = ClearTensor(dtype=holding_type, shape=resulting_shape) - - return mixed_value - - -def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> BaseValue: - """Return mixed BaseValue with data type able to hold both value1 and value2 dtypes. - - Returns a BaseValue that would result from computation on both value1 and value2 while - determining the data type able to hold both value1 and value2 data type (this can be lossy - with floats). Supports only mixing instances from the same class. - - Args: - value1 (BaseValue): first BaseValue to mix. - value2 (BaseValue): second BaseValue to mix. - - Raises: - ValueError: raised if the BaseValue is not one of (TensorValue) - - Returns: - BaseValue: The resulting mixed BaseValue with data type able to hold both value1 and value2 - dtypes. - """ - - assert_true( - (value1.__class__ == value2.__class__), - f"Cannot mix values of different types: value 1:{type(value1)}, value2: {type(value2)}", - ) - - if isinstance(value1, TensorValue) and isinstance(value2, TensorValue): - return mix_tensor_values_determine_holding_dtype(value1, value2) - - raise ValueError( - f"{mix_values_determine_holding_dtype.__name__} does not support value {type(value1)}" - ) - - -def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType: - """Determine the BaseDataType to hold the input constant data. - - Args: - constant_data (Union[int, float]): The constant data for which to determine the - corresponding BaseDataType. - - Returns: - BaseDataType: The corresponding BaseDataType - """ - constant_data_type: BaseDataType - assert_true( - isinstance(constant_data, (int, float)), - f"Unsupported constant data of type {type(constant_data)}", - ) - - if isinstance(constant_data, int): - is_signed = constant_data < 0 - constant_data_type = Integer( - get_bits_to_represent_value_as_integer(constant_data, is_signed), is_signed - ) - elif isinstance(constant_data, float): - constant_data_type = Float(64) - - return constant_data_type - - -def get_base_value_for_python_constant_data( - constant_data: Union[int, float] -) -> Callable[..., BaseValue]: - """Wrap the BaseDataType to hold the input constant data in BaseValue partial. - - The returned object can then be instantiated as an Encrypted or Clear version - by calling it with the proper arguments forwarded to the BaseValue `__init__` function - - Args: - constant_data (Union[int, float]): The constant data for which to determine the - corresponding Value. - - Returns: - Callable[..., BaseValue]: A partial object that will return the proper BaseValue when - called with `is_encrypted` as keyword argument (forwarded to the BaseValue `__init__` - method). - """ - - constant_data_type = get_base_data_type_for_python_constant_data(constant_data) - return partial(TensorValue, dtype=constant_data_type, shape=()) - - -def get_constructor_for_python_constant_data(constant_data: Union[int, float]): - """Get the constructor for the passed python constant data. - - Args: - constant_data (Any): The data for which we want to determine the type constructor. - """ - return type(constant_data) - - -def is_data_type_compatible_with( - dtype: BaseDataType, - other: BaseDataType, -) -> bool: - """Determine whether dtype is compatible with other. - - `dtype` being compatible with `other` means `other` can hold every value of `dtype` - (e.g., uint2 is compatible with uint4 and int4) - (e.g., int2 is compatible with int4 but not with uint4) - - Note that this function is not symetric. - (e.g., uint2 is compatible with uint4, but uint4 is not compatible with uint2) - - Args: - dtype (BaseDataType): dtype to check compatiblity - other (BaseDataType): dtype to check compatiblity against - - Returns: - bool: Whether the dtype is compatible with other or not - """ - - combination = find_type_to_hold_both_lossy(dtype, other) - return other == combination - - -def broadcast_shapes(shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> Optional[Tuple[int, ...]]: - """Broadcast two shapes into a single shape. - - We are mimicing the exact semantics of broadcasting in numpy. - You can learn more about it here: https://numpy.org/doc/stable/user/theory.broadcasting.html - - Args: - shape1 (Tuple[int, ...]): first shape to broadcast - shape2 (Tuple[int, ...]): second shape to broadcast - - Returns: - Optional[Tuple[int, ...]]: None if the shapes are not broadcastable else broadcasted shape - """ - - result = [] - for size1, size2 in zip(shape1[::-1], shape2[::-1]): - if size1 != size2 and size1 != 1 and size2 != 1 and size1 != 0 and size2 != 0: - return None - - if size1 == 0 or size2 == 0: - result.append(0) - else: - result.append(max(size1, size2)) - - if len(result) < len(shape1): - for i in reversed(range(len(shape1) - len(result))): - result.append(shape1[i]) - elif len(result) < len(shape2): - for i in reversed(range(len(shape2) - len(result))): - result.append(shape2[i]) - - return tuple(reversed(result)) diff --git a/concrete/common/data_types/floats.py b/concrete/common/data_types/floats.py deleted file mode 100644 index 57db1956b..000000000 --- a/concrete/common/data_types/floats.py +++ /dev/null @@ -1,33 +0,0 @@ -"""This file holds the definitions for floating point types.""" - -from functools import partial - -from ..debugging.custom_assert import assert_true -from . import base - - -class Float(base.BaseDataType): - """Class representing a float.""" - - # bit_width is the total number of bits used to represent a floating point number, including - # sign bit, exponent and mantissa - bit_width: int - - def __init__(self, bit_width: int) -> None: - super().__init__() - assert_true(bit_width in (16, 32, 64), "Only 16, 32 and 64 bits floats are supported") - self.bit_width = bit_width - - def __repr__(self) -> str: - return f"{self.__class__.__name__}<{self.bit_width} bits>" - - def __str__(self) -> str: - return f"float{self.bit_width}" - - def __eq__(self, other: object) -> bool: - return isinstance(other, self.__class__) and self.bit_width == other.bit_width - - -Float16 = partial(Float, 16) -Float32 = partial(Float, 32) -Float64 = partial(Float, 64) diff --git a/concrete/common/data_types/integers.py b/concrete/common/data_types/integers.py deleted file mode 100644 index 29e33bc70..000000000 --- a/concrete/common/data_types/integers.py +++ /dev/null @@ -1,144 +0,0 @@ -"""This file holds the definitions for integer types.""" - -import math -from typing import Any, Iterable - -from ..debugging.custom_assert import assert_true -from . import base - - -class Integer(base.BaseDataType): - """Class representing an integer.""" - - bit_width: int - is_signed: bool - - def __init__(self, bit_width: int, is_signed: bool) -> None: - super().__init__() - assert_true(bit_width > 0, "bit_width must be > 0") - self.bit_width = bit_width - self.is_signed = is_signed - - def __repr__(self) -> str: - signed_str = "signed" if self.is_signed else "unsigned" - return f"{self.__class__.__name__}<{signed_str}, {self.bit_width} bits>" - - def __str__(self) -> str: - return f"{('int' if self.is_signed else 'uint')}{self.bit_width}" - - def __eq__(self, other: object) -> bool: - return ( - isinstance(other, self.__class__) - and self.bit_width == other.bit_width - and self.is_signed == other.is_signed - ) - - def min_value(self) -> int: - """Minimum value representable by the Integer.""" - if self.is_signed: - return -(2 ** (self.bit_width - 1)) - - return 0 - - def max_value(self) -> int: - """Maximum value representable by the Integer.""" - if self.is_signed: - return 2 ** (self.bit_width - 1) - 1 - - return 2 ** self.bit_width - 1 - - def can_represent_value(self, value_to_represent: int) -> bool: - """Check if a value is representable by the Integer. - - Args: - value_to_represent (int): Value to check - - Returns: - bool: True if the value can be represented by this integer - """ - return self.min_value() <= value_to_represent <= self.max_value() - - -def create_signed_integer(bit_width: int) -> Integer: - """Create a signed integer. - - Args: - bit_width (int): width of the integer - - Returns: - Integer: A signed integer with the requested bit_width - """ - return Integer(bit_width, is_signed=True) - - -SignedInteger = create_signed_integer - - -def create_unsigned_integer(bit_width: int) -> Integer: - """Create an unsigned integer. - - Args: - bit_width (int): width of the integer - - Returns: - Integer: An unsigned integer with the requested bit_width - """ - return Integer(bit_width, is_signed=False) - - -UnsignedInteger = create_unsigned_integer - - -def make_integer_to_hold(values: Iterable[Any], force_signed: bool) -> Integer: - """Return an Integer able to hold all values, it is possible to force the Integer to be signed. - - Args: - values (Iterable[Any]): The values to hold - force_signed (bool): Set to True to force the result to be a signed Integer - - Returns: - Integer: The Integer able to hold values - """ - min_value = min(values) - max_value = max(values) - - make_signed_integer = force_signed or min_value < 0 - - num_bits = max( - get_bits_to_represent_value_as_integer(min_value, make_signed_integer), - get_bits_to_represent_value_as_integer(max_value, make_signed_integer), - ) - - return Integer(num_bits, is_signed=make_signed_integer) - - -def get_bits_to_represent_value_as_integer(value: Any, force_signed: bool) -> int: - """Return how many bits are required to represent a numerical Value. - - Args: - value (Any): The value for which we want to know how many bits are required. - force_signed (bool): Set to True to force the result to be a signed integer. - - Returns: - int: required amount of bits - """ - # Writing this in a very dumb way - num_bits: int - if value < 0: - abs_value = abs(value) - if abs_value > 1: - num_bits = math.ceil(math.log2(abs_value)) + 1 - else: - # -1 case - num_bits = 2 - else: - if value > 1: - num_bits = math.ceil(math.log2(value + 1)) - else: - # 0 and 1 case - num_bits = 1 - - if force_signed: - num_bits += 1 - - return num_bits diff --git a/concrete/common/debugging/__init__.py b/concrete/common/debugging/__init__.py deleted file mode 100644 index 2c3d2fcac..000000000 --- a/concrete/common/debugging/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Module for debugging.""" -from .custom_assert import assert_true -from .drawing import draw_graph -from .formatting import format_operation_graph diff --git a/concrete/common/debugging/custom_assert.py b/concrete/common/debugging/custom_assert.py deleted file mode 100644 index 1a639776c..000000000 --- a/concrete/common/debugging/custom_assert.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Provide some variants of assert.""" - - -def _custom_assert(condition: bool, on_error_msg: str = "") -> None: - """Provide a custom assert which is kept even if the optimized python mode is used. - - See https://docs.python.org/3/reference/simple_stmts.html#assert for the documentation - on the classical assert function - - Args: - condition(bool): the condition. If False, raise AssertionError - on_error_msg(str): optional message for precising the error, in case of error - - """ - - if not condition: - raise AssertionError(on_error_msg) - - -def assert_true(condition: bool, on_error_msg: str = ""): - """Provide a custom assert to check that the condition is True. - - Args: - condition(bool): the condition. If False, raise AssertionError - on_error_msg(str): optional message for precising the error, in case of error - - """ - return _custom_assert(condition, on_error_msg) - - -def assert_false(condition: bool, on_error_msg: str = ""): - """Provide a custom assert to check that the condition is False. - - Args: - condition(bool): the condition. If True, raise AssertionError - on_error_msg(str): optional message for precising the error, in case of error - - """ - return _custom_assert(not condition, on_error_msg) - - -def assert_not_reached(on_error_msg: str): - """Provide a custom assert to check that a piece of code is never reached. - - Args: - on_error_msg(str): message for precising the error - - """ - return _custom_assert(False, on_error_msg) diff --git a/concrete/common/debugging/drawing.py b/concrete/common/debugging/drawing.py deleted file mode 100644 index 0bb5fc6ca..000000000 --- a/concrete/common/debugging/drawing.py +++ /dev/null @@ -1,156 +0,0 @@ -"""functions to draw the different graphs we can generate in the package, eg to debug.""" - -import os -import tempfile -from pathlib import Path -from typing import Optional - -import matplotlib.pyplot as plt -import networkx as nx -from PIL import Image - -from ..debugging.custom_assert import assert_true -from ..operator_graph import OPGraph -from ..representation.intermediate import ( - ALL_IR_NODES, - Add, - Constant, - Conv2D, - Dot, - GenericFunction, - IndexConstant, - Input, - MatMul, - Mul, - Sub, -) - -IR_NODE_COLOR_MAPPING = { - Input: "blue", - Constant: "cyan", - Conv2D: "brown", - Add: "red", - Sub: "yellow", - Mul: "green", - GenericFunction: "orange", - IndexConstant: "black", - Dot: "purple", - MatMul: "brown", - "GenericFunction": "orange", - "TLU": "grey", - "output": "magenta", -} - -_missing_nodes_in_mapping = ALL_IR_NODES - IR_NODE_COLOR_MAPPING.keys() -assert_true( - len(_missing_nodes_in_mapping) == 0, - ( - f"Missing IR node in IR_NODE_COLOR_MAPPING : " - f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}" - ), -) - -del _missing_nodes_in_mapping - - -def draw_graph( - op_graph: OPGraph, - show: bool = False, - vertical: bool = True, - save_to: Optional[Path] = None, -) -> str: - """Draws operation graphs and optionally saves/shows the drawing. - - Note that this function requires the python `pygraphviz` package which itself requires the - installation of `graphviz` packages, see - https://pygraphviz.github.io/documentation/stable/install.html - - Args: - op_graph (OPGraph): the operation graph to be drawn and optionally saved/shown - show (bool): if set to True, the drawing will be shown using matplotlib - vertical (bool): if set to True, the orientation will be vertical - save_to (Optional[Path]): if specified, the drawn graph will be saved to this path; else - it is saved in a temporary file - - Returns: - The path of the file where the drawn graph is saved - - """ - - def get_color(node, output_nodes): - value_to_return = IR_NODE_COLOR_MAPPING[type(node)] - if node in output_nodes: - value_to_return = IR_NODE_COLOR_MAPPING["output"] - elif isinstance(node, GenericFunction): - value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return) - return value_to_return - - graph = op_graph.graph - output_nodes = set(op_graph.output_nodes.values()) - - attributes = { - node: { - "label": node.text_for_drawing(), - "color": get_color(node, output_nodes), - "penwidth": 2, # double thickness for circles - "peripheries": 2 if node in output_nodes else 1, # double circle for output nodes - } - for node in graph.nodes - } - nx.set_node_attributes(graph, attributes) - - # TODO: #639 adapt drawing routine to manage output_idx - for edge in graph.edges(keys=True): - idx = graph.edges[edge]["input_idx"] - graph.edges[edge]["label"] = f" {idx} " # spaces are there intentionally for a better look - - try: - agraph = nx.nx_agraph.to_agraph(graph) - except ImportError as e: # pragma: no cover - if "pygraphviz" in str(e): - err_msg = ( - f"{draw_graph.__name__} requires pygraphviz, install your OS graphviz distribution " - "https://pygraphviz.github.io/documentation/stable/install.html " - "and reinstall with extras: `pip install --force-reinstall " - "concrete-numpy[full]`" - ) - raise ImportError(err_msg) from e - agraph.graph_attr["rankdir"] = "TB" if vertical else "LR" - agraph.layout("dot") - - if save_to is None: - with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: - # we need to change the permissions of the temporary file - # so that it can be read by all users - - # (https://stackoverflow.com/a/44130605) - - # get the old umask and replace it with 0o666 - old_umask = os.umask(0o666) - - # restore the old umask back - os.umask(old_umask) - - # combine the old umask with the wanted permissions - permissions = 0o666 & ~old_umask - - # set new permissions - os.chmod(tmp.name, permissions) - - save_to_str = str(tmp.name) - else: - save_to_str = str(save_to) - - agraph.draw(save_to_str) - - if show: # pragma: no cover - # We can't have coverage in this branch as `plt.show()` blocks and waits for user action. - plt.close("all") - plt.figure() - img = Image.open(save_to_str) - plt.imshow(img) - img.close() - plt.axis("off") - plt.show() - - return save_to_str diff --git a/concrete/common/debugging/formatting.py b/concrete/common/debugging/formatting.py deleted file mode 100644 index 3b6588625..000000000 --- a/concrete/common/debugging/formatting.py +++ /dev/null @@ -1,151 +0,0 @@ -"""Functions to format operation graphs for debugging purposes.""" - -from typing import Dict, List, Optional, Tuple - -import networkx as nx - -from ..debugging.custom_assert import assert_true -from ..operator_graph import OPGraph -from ..representation.intermediate import GenericFunction, IntermediateNode - - -def format_operation_graph( - op_graph: OPGraph, - maximum_constant_length: int = 25, - highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None, -) -> str: - """Format an operation graph. - - Args: - op_graph (OPGraph): - the operation graph to format - - maximum_constant_length (int): - maximum length of the constant throughout the formatting - - highlighted_nodes (Optional[Dict[IntermediateNode, List[str]]] = None): - the dict of nodes and their corresponding messages which will be highlighted - - Returns: - str: formatted operation graph - """ - - # This function is well documented and split into very readable sections - # Thus, splitting it to multiple functions doesn't increase readability - - # pylint: disable=too-many-locals,too-many-branches - - assert_true(isinstance(op_graph, OPGraph)) - - # (node, output_index) -> identifier - # e.g., id_map[(node1, 0)] = 2 and id_map[(node1, 1)] = 3 - # means line for node1 is in this form (%2, %3) = node1.format(...) - id_map: Dict[Tuple[IntermediateNode, int], int] = {} - - # lines that will be merged at the end - lines: List[str] = [] - - # type information to add to each line (for alingment, this is done after lines are determined) - type_informations: List[str] = [] - - # default highlighted nodes is empty - highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {} - - # highlight information for lines, this is required because highlights are added to lines - # after their type information is added and we only have line numbers, not nodes - highlighted_lines: Dict[int, List[str]] = {} - - # subgraphs to format after the main graph is formatted - subgraphs: Dict[str, OPGraph] = {} - - # format nodes - for node in nx.topological_sort(op_graph.graph): - # assign a unique id to outputs of node - assert_true(len(node.outputs) > 0) - for i in range(len(node.outputs)): - id_map[(node, i)] = len(id_map) - - # remember highlights of the node - if node in highlighted_nodes: - highlighted_lines[len(lines)] = highlighted_nodes[node] - - # extract predecessors and their ids - predecessors = [] - for predecessor, output_idx in op_graph.get_ordered_preds_and_inputs_of(node): - predecessors.append(f"%{id_map[(predecessor, output_idx)]}") - - # start the build the line for the node - line = "" - - # add output information to the line - outputs = ", ".join(f"%{id_map[(node, i)]}" for i in range(len(node.outputs))) - line += outputs if len(node.outputs) == 1 else f"({outputs})" - - # add node information to the line - line += " = " - line += node.text_for_formatting(predecessors, maximum_constant_length) - - # append line to list of lines - lines.append(line) - - # if exists, save the subgraph - if isinstance(node, GenericFunction) and "float_op_subgraph" in node.op_kwargs: - subgraphs[line] = node.op_kwargs["float_op_subgraph"] - - # remember type information of the node - types = ", ".join(str(output) for output in node.outputs) - type_informations.append(types if len(node.outputs) == 1 else f"({types})") - - # align = signs - # - # e.g., - # - # %1 = ... - # %2 = ... - # ... - # %8 = ... - # %9 = ... - # %10 = ... - # %11 = ... - # ... - longest_length_before_equals_sign = max(len(line.split("=")[0]) for line in lines) - for i, line in enumerate(lines): - length_before_equals_sign = len(line.split("=")[0]) - lines[i] = (" " * (longest_length_before_equals_sign - length_before_equals_sign)) + line - - # add type informations - longest_line_length = max(len(line) for line in lines) - for i, line in enumerate(lines): - lines[i] += " " * (longest_line_length - len(line)) - lines[i] += f" # {type_informations[i]}" - - # add highlights (this is done in reverse to keep indices consistent) - for i in reversed(range(len(lines))): - if i in highlighted_lines: - for j, message in enumerate(highlighted_lines[i]): - highlight = "^" if j == 0 else " " - lines.insert(i + 1 + j, f"{highlight * len(lines[i])} {message}") - - # add return information - # (if there is a single return, it's in the form `return %id` - # (otherwise, it's in the form `return (%id1, %id2, ..., %idN)` - returns: List[str] = [] - for node in op_graph.output_nodes.values(): - outputs = ", ".join(f"%{id_map[(node, i)]}" for i in range(len(node.outputs))) - returns.append(outputs if len(node.outputs) == 1 else f"({outputs})") - lines.append("return " + returns[0] if len(returns) == 1 else f"({', '.join(returns)})") - - # format subgraphs after the actual graph - result = "\n".join(lines) - if len(subgraphs) > 0: - result += "\n\n" - result += "Subgraphs:" - for line, subgraph in subgraphs.items(): - subgraph_lines = format_operation_graph(subgraph, maximum_constant_length).split("\n") - result += "\n\n" - result += f" {line}:\n\n" - result += "\n".join(f" {line}" for line in subgraph_lines) - - # pylint: enable=too-many-locals,too-many-branches - - return result diff --git a/concrete/common/extensions/__init__.py b/concrete/common/extensions/__init__.py deleted file mode 100644 index 95ccf9a83..000000000 --- a/concrete/common/extensions/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Extensions module to provide additional functionality to our users.""" -from . import convolution, multi_table, table diff --git a/concrete/common/extensions/convolution.py b/concrete/common/extensions/convolution.py deleted file mode 100644 index 79b9f1f0d..000000000 --- a/concrete/common/extensions/convolution.py +++ /dev/null @@ -1,161 +0,0 @@ -"""This file contains tracers for convolution operations.""" - -from typing import List, Optional, Tuple, Union, cast - -import numpy as np - -from ...numpy.tracing import NPConstant, NPTracer -from ..representation.intermediate import Conv2D -from ..tracing.base_tracer import BaseTracer - -SUPPORTED_AUTO_PAD = [ - "NOTSET", -] - - -def conv2d( - x: Union[np.ndarray, BaseTracer], - weight: Union[np.ndarray, BaseTracer], - bias: Optional[Union[np.ndarray, BaseTracer]] = None, - pads: Union[Tuple[int, int, int, int], List[int]] = (0, 0, 0, 0), - strides: Union[Tuple[int, int], List[int]] = (1, 1), - dilations: Union[Tuple[int, int], List[int]] = (1, 1), - auto_pad: str = "NOTSET", -) -> Union[np.ndarray, NPTracer]: - """Trace or evaluate 2D convolution. - - Args: - x (Union[np.ndarray, BaseTracer]): Input of shape (NxCxHxW) - weight (Union[np.ndarray, BaseTracer]): Weight (kernel) of shape (FxCxHxW) - bias (Optional[Union[np.ndarray, BaseTracer]], optional): Bias vector of size (F). - Defaults to None. - pads (Union[Tuple[int, int, int, int], List[int]], optional): Padding over each axis - (H_beg, W_beg, H_end, W_end). Defaults to (0, 0, 0, 0). - strides (Union[Tuple[int, int], List[int]], optional): Stride over each axis - (height and width). Defaults to (1, 1). - dilations (Union[Tuple[int, int], List[int]], optional): Dilation over each axis - (height and width). Defaults to (1, 1). - auto_pad (str, optional): Padding strategy. Defaults to "NOTSET". - - Raises: - ValueError: If one argument isn't in the range of expected values. - TypeError: If one argument isn't of the appropriate type. - - Returns: - Union[np.ndarray, BaseTracer]: Evaluation result, or traced computation - """ - if auto_pad not in SUPPORTED_AUTO_PAD: - raise ValueError("invalid auto_pad is specified") - - if not isinstance(x, (np.ndarray, BaseTracer)): - raise TypeError(f"input x must be an ndarray, or a BaseTracer, not a {type(x)}") - if not isinstance(weight, (np.ndarray, BaseTracer)): - raise TypeError(f"weight must be an ndarray, or a BaseTracer, not a {type(weight)}") - if not isinstance(bias, (np.ndarray, BaseTracer, type(None))): - raise TypeError(f"bias must be an ndarray, a BaseTracer, or None, not a {type(bias)}") - if not isinstance(pads, (tuple, list)): - raise TypeError(f"padding must be a tuple, or list, not a {type(pads)}") - if not isinstance(strides, (tuple, list)): - raise TypeError(f"strides must be a tuple, or list, not a {type(strides)}") - if not isinstance(dilations, (tuple, list)): - raise TypeError(f"dilations must be a tuple, or list, not a {type(dilations)}") - - if len(pads) != 4: - raise ValueError( - f"padding should be of the form (pad_height_begin, pad_width_begin, pad_height_end, " - f" pad_width_end), but got {type(pads)} of length {len(pads)}" - ) - if len(strides) != 2: - raise ValueError( - f"strides should be of the form (stride_height, stride_width), but got {type(strides)}" - f" of length {len(strides)}" - ) - if len(dilations) != 2: - raise ValueError( - f"dilations should be of the form (dilation_height, dilation_width), but got" - f" {type(dilations)} of length {len(dilations)}" - ) - - assert len(x.shape) == 4, f"input x should have size (N x C x H x W), not {x.shape}" - assert len(weight.shape) == 4, f"weight should have size (F x C x H x W), not {weight.shape}" - if bias is not None: - assert len(bias.shape) == 1, f"bias should have size (F), not {bias.shape}" - - if isinstance(x, BaseTracer): - return _trace_conv2d(x, weight, bias, pads, strides, dilations) - # X is an ndarray - bias = np.zeros(weight.shape[0]) if bias is None else bias - # For mypy - weight = cast(np.ndarray, weight) - bias = cast(np.ndarray, bias) - return _evaluate_conv2d(x, weight, bias, pads, strides, dilations) - - -def _trace_conv2d( - x: BaseTracer, - weight: Union[np.ndarray, BaseTracer], - bias: Optional[Union[np.ndarray, BaseTracer]], - pads: Union[Tuple[int, int, int, int], List[int]], - strides: Union[Tuple[int, int], List[int]], - dilations: Union[Tuple[int, int], List[int]], -) -> NPTracer: - """Trace 2D convolution. - - Args: - x (BaseTracer): Input of shape (NxCxHxW) - weight (Union[np.ndarray, BaseTracer]): Weight (kernel) of shape (FxCxHxW) - bias (Optional[Union[np.ndarray, BaseTracer]]): Bias vector of size (F) - pads (Union[Tuple[int, int, int, int], List[int]]): Padding over each - axis (H_beg, W_beg, H_end, W_end) - strides (Union[Tuple[int, int], List[int]]): Stride over each - axis (height and width) - dilations (Union[Tuple[int, int], List[int]]): Dilation over each - axis (height and width) - - Returns: - BaseTracer: Traced computation - """ - weight_tracer = ( - weight if isinstance(weight, BaseTracer) else NPTracer([], NPConstant(weight), 0) - ) - inputs = [x.output, weight_tracer.output] - output_tracer_inputs = [x, weight_tracer] - if bias is not None: - bias_tracer = bias if isinstance(bias, BaseTracer) else NPTracer([], NPConstant(bias), 0) - inputs.append(bias_tracer.output) - # For mypy - bias = cast(BaseTracer, bias_tracer) - output_tracer_inputs.append(bias) - - traced_computation = Conv2D(inputs, x.output.dtype, pads, strides, dilations) - output_tracer = x.__class__( - output_tracer_inputs, traced_computation=traced_computation, output_idx=0 - ) - # For mypy - assert isinstance(output_tracer, NPTracer) - return output_tracer - - -def _evaluate_conv2d( - x: np.ndarray, - weight: np.ndarray, - bias: np.ndarray, - pads: Union[Tuple[int, int, int, int], List[int]], - strides: Union[Tuple[int, int], List[int]], - dilations: Union[Tuple[int, int], List[int]], -) -> np.ndarray: - """Evaluate 2D convolution. - - Args: - x (np.ndarray): Input of shape (NxCxHxW) - weight (np.ndarray): Weight (kernel) of shape (FxCxHxW) - bias (np.ndarray): Bias vector of size (F) - pads (Union[Tuple[int, int, int, int], List[int]]): Padding over each - axis (H_beg, W_beg, H_end, W_end) - strides (Union[Tuple[int, int], List[int]]): Stride over each axis (height and width) - dilations (Union[Tuple[int, int], List[int]]): Dilation over each axis (height and width) - - Returns: - np.ndarray: Result of the convolution of shape (NxCxHxW) - """ - return Conv2D.evaluate_conv2d(x, weight, bias, pads, strides, dilations) diff --git a/concrete/common/extensions/multi_table.py b/concrete/common/extensions/multi_table.py deleted file mode 100644 index 73ac38191..000000000 --- a/concrete/common/extensions/multi_table.py +++ /dev/null @@ -1,233 +0,0 @@ -"""This file contains a wrapper class for direct multi table lookups.""" - -import itertools -from copy import deepcopy -from typing import List, Tuple, Union - -from ..data_types.base import BaseDataType -from ..data_types.dtypes_helpers import find_type_to_hold_both_lossy -from ..representation.intermediate import GenericFunction -from ..tracing.base_tracer import BaseTracer -from ..values import TensorValue -from .table import LookupTable - - -class MultiLookupTable: - """Class representing a multi lookup table.""" - - # Multi table lookup is needed when you want to perform a lookup on a tensor, - # but you want each element to be used with a different lookup table. - # - # Here is an example: - # - # You have x which is of shape (2, 3), - # you want the first row to be indexed with `table1 = LookupTable([2, 3, 1, 0])` - # and the second row to be indexed with `table1 = LookupTable([0, 1, 3, 2])` - # - # You can create such a multi lookup table - # multitable = MultiLookupTable( - # [ - # [table1, table1, table1], - # [table2, table2, table2], - # ], - # ) - # (notice the shape of multitable matches with the shape of x) - # - # and use multitable[x] toget the following result - # assert multitable[x] == [ - # [table1[x[0, 0]], table1[x[0, 1]], table1[x[0, 2]]], - # [table2[x[1, 0]], table2[x[1, 1]], table2[x[1, 2]]], - # ] - - # underlying lookup tables - tables: List - - # shape of the input of the lookup - input_shape: Tuple[int, ...] - - # type of the result of the lookup - output_dtype: BaseDataType - - def __init__(self, tables: List): - input_shape_list: List[int] = [] - MultiLookupTable._extract_shape_using_first_elements_only(tables, input_shape_list) - input_shape: Tuple[int, ...] = tuple(input_shape_list) - - table_sizes: List[int] = [] - table_output_dtypes: List[BaseDataType] = [] - MultiLookupTable._check_shape_and_record_luts( - tables, - 0, - input_shape, - table_sizes, - table_output_dtypes, - ) - - for i in range(1, len(table_sizes)): - if table_sizes[i - 1] != table_sizes[i]: - # this branch is for such a case: - # - # table1 = hnp.LookupTable([1, 3]) - # table2 = hnp.LookupTable([0, 2, 3, 1]) - # - # multitable = hnp.MultiLookupTable( - # [ - # [table1, table2, table1], - # [table2, table1, table2], - # ], - # ) - raise ValueError( - f"LookupTables within a MultiLookupTable " - f"should have the same size but they do not " - f"(there was a table with the size of {table_sizes[i - 1]} " - f"and another with the size of {table_sizes[i]})" - ) - - output_dtype = table_output_dtypes[0] - for table_output_dtype in table_output_dtypes: - output_dtype = find_type_to_hold_both_lossy(output_dtype, table_output_dtype) - - self.tables = tables - self.input_shape = input_shape - self.output_dtype = output_dtype - - def __getitem__(self, key: Union[int, BaseTracer]): - # this branch is used during tracing and the regular flow is used during evaluation - if isinstance(key, BaseTracer): - out_dtype = deepcopy(key.output.dtype) - out_shape = deepcopy(self.input_shape) - - generic_function_output_value = TensorValue( - out_dtype, - key.output.is_encrypted, - out_shape, - ) - - traced_computation = GenericFunction( - inputs=[key.output], - arbitrary_func=MultiLookupTable._checked_indexing, - output_value=generic_function_output_value, - op_kind="TLU", - op_kwargs={ - "input_shape": deepcopy(self.input_shape), - "tables": deepcopy(self.tables), - }, - op_name="MultiTLU", - ) - return key.__class__( - inputs=[key], - traced_computation=traced_computation, - output_idx=0, - ) - - # if not, it means table is indexed with a constant - # thus, the result of the lookup is a constant - # so, we can propagate it directly - return MultiLookupTable._checked_indexing(key, self.input_shape, self.tables) - - @staticmethod - def _extract_shape_using_first_elements_only(array, shape): - if not isinstance(array, list): - # base case for recursion - # the shape is already accumulated up to this point - # so we just return - return - - if len(array) == 0: - # this branch is for such a case: - # - # table1 = hnp.LookupTable([1, 3, 2, 0]) - # table2 = hnp.LookupTable([0, 2, 3, 1]) - # - # multitable = hnp.MultiLookupTable( - # [ - # [], - # [table1, table2, table1], - # [table2, table1, table2], - # ], - # ) - - raise ValueError("MultiLookupTable cannot have an empty array within it") - - shape.append(len(array)) - MultiLookupTable._extract_shape_using_first_elements_only(array[0], shape) - - @staticmethod - def _check_shape_and_record_luts(array, dimension, shape, table_sizes, table_output_dtypes): - if dimension == len(shape): - if not isinstance(array, LookupTable): - # this branch is for such a case: - # - # table1 = hnp.LookupTable([1, 3, 2, 0]) - # table2 = hnp.LookupTable([0, 2, 3, 1]) - # - # multitable = hnp.MultiLookupTable( - # [ - # [table1, table2, 4], - # [table2, table1, table2], - # ], - # ) - raise ValueError( - f"MultiLookupTable should have been made out of LookupTables " - f"but it had an object of type {array.__class__.__name__} within it" - ) - - table_sizes.append(len(array.table)) - table_output_dtypes.append(array.output_dtype) - return - - if not isinstance(array, list) or len(array) != shape[dimension]: - # this branch is for such a case: - # - # table1 = hnp.LookupTable([1, 3, 2, 0]) - # table2 = hnp.LookupTable([0, 2, 3, 1]) - # - # multitable = hnp.MultiLookupTable( - # [ - # [table1, table2], - # [table2, table1, table2], - # ], - # ) - raise ValueError( - f"MultiLookupTable should have the shape {shape} but it does not " - f"(an array on dimension {dimension} has the size {len(array)} " - f"but its size should have been {shape[dimension]} " - f"as the expected shape is {shape})" - ) - - for item in array: - MultiLookupTable._check_shape_and_record_luts( - item, - dimension + 1, - shape, - table_sizes, - table_output_dtypes, - ) - - @staticmethod - def _checked_indexing(x, input_shape, tables): - try: - result = [] - for indices in itertools.product(*[range(dimension) for dimension in input_shape]): - which_table_to_use = tables - what_value_to_use = x - where_to_append = result - - for index in indices[:-1]: - which_table_to_use = tables[index] - what_value_to_use = x[index] - - if len(where_to_append) == index: - where_to_append.append([]) - where_to_append = result[index] - - which_table_to_use = which_table_to_use[indices[-1]] - what_value_to_use = what_value_to_use[indices[-1]] - where_to_append.append(which_table_to_use[what_value_to_use]) - except Exception as error: - raise ValueError( - f"Multiple Lookup Table of shape {input_shape} cannot be looked up with {x} " - f"(you should check your inputset)", - ) from error - - return result diff --git a/concrete/common/extensions/table.py b/concrete/common/extensions/table.py deleted file mode 100644 index 39651a8b4..000000000 --- a/concrete/common/extensions/table.py +++ /dev/null @@ -1,118 +0,0 @@ -"""This file contains a wrapper class for direct table lookups.""" - -from copy import deepcopy -from typing import Any, Iterable, List, Tuple, Union - -from ..common_helpers import is_a_power_of_2 -from ..data_types.base import BaseDataType -from ..data_types.integers import make_integer_to_hold -from ..representation.intermediate import GenericFunction -from ..tracing.base_tracer import BaseTracer - - -class LookupTable: - """Class representing a lookup table.""" - - # lookup table itself, has 2^N entries - table: Tuple[int, ...] - - # type of the result of the lookup - output_dtype: BaseDataType - - def __init__(self, table: Iterable[int]): - table = tuple(table) - - if not is_a_power_of_2(len(table)): - raise ValueError( - f"Desired lookup table has inappropriate number of entries ({len(table)})" - ) - - self.table = table - self.output_dtype = make_integer_to_hold(table, force_signed=False) - - def __repr__(self): - return str(list(self.table)) - - def __getitem__(self, key: Union[int, Iterable, BaseTracer]): - # if a tracer is used for indexing, - # we need to create an `GenericFunction` node - # because the result will be determined during the runtime - if isinstance(key, BaseTracer): - generic_function_output_value = deepcopy(key.output) - generic_function_output_value.dtype = self.output_dtype - - traced_computation = GenericFunction( - inputs=[key.output], - arbitrary_func=LookupTable._checked_indexing, - output_value=generic_function_output_value, - op_kind="TLU", - op_kwargs={"table": deepcopy(self.table)}, - op_name="TLU", - ) - return key.__class__( - inputs=[key], - traced_computation=traced_computation, - output_idx=0, - ) - - # if not, it means table is indexed with a constant - # thus, the result of the lookup is a constant - # so, we can propagate it directly - return LookupTable._checked_indexing(key, self.table) - - @staticmethod - def _check_index_out_of_range(x, table): - if not -len(table) <= x < len(table): - raise ValueError( - f"Lookup table with {len(table)} entries cannot be indexed with {x} " - f"(you should check your inputset)", - ) - - @staticmethod - def _checked_indexing(x, table): - """Index `table` using `x`. - - There is a single table and the indexing works with the following semantics: - - when x == c - - table[x] == table[c] - - when x == [c1, c2] - - table[x] == [table[c1], table[c2]] - - when x == [[c1, c2], [c3, c4], [c5, c6]] - - table[x] == [[table[c1], table[c2]], [table[c3], table[c4]], [table[c5], table[c6]]] - - Args: - x (Union[int, Iterable]): index to use - table (Tuple[int, ...]): table to index - - Returns: - Union[int, List[int]]: result of indexing - """ - - if not isinstance(x, Iterable): - LookupTable._check_index_out_of_range(x, table) - return table[x] - - def fill_result(partial_result: List[Any], partial_x: Iterable[Any]): - """Fill partial result with partial x. - - This function implements the recursive indexing of nested iterables. - - Args: - partial_result (List[Any]): currently accumulated result - partial_x (Iterable[Any]): current index to use - - Returns: - None - """ - - for item in partial_x: - if isinstance(item, Iterable): - partial_result.append([]) - fill_result(partial_result[-1], item) - else: - LookupTable._check_index_out_of_range(item, table) - partial_result.append(table[item]) - - result = [] - fill_result(result, x) - return result diff --git a/concrete/common/fhe_circuit.py b/concrete/common/fhe_circuit.py deleted file mode 100644 index ae9c73964..000000000 --- a/concrete/common/fhe_circuit.py +++ /dev/null @@ -1,144 +0,0 @@ -"""Module to hold the result of compilation.""" - -from pathlib import Path -from typing import Optional, Union - -import numpy -from concrete.compiler import ( - ClientParameters, - ClientSupport, - CompilationOptions, - JITCompilationResult, - JITLambda, - JITSupport, - KeySet, - KeySetCache, - PublicArguments, - PublicResult, -) - -from .debugging import draw_graph, format_operation_graph -from .operator_graph import OPGraph - - -class FHECircuit: - """Class which is the result of compilation.""" - - op_graph: OPGraph - _jit_support: JITSupport - _compilation_result: JITCompilationResult - _client_parameters: ClientParameters - _server_lambda: JITLambda - _keyset_cache: KeySetCache - _keyset: KeySet - - def __init__( - self, - op_graph: OPGraph, - mlir_str: str, - unsecure_key_set_cache_path: Optional[str] = None, - auto_parallelize: bool = False, - loop_parallelize: bool = False, - dataflow_parallelize: bool = False, - ): - self.op_graph = op_graph - self._jit_support = JITSupport.new() - # Set compilation options - options = CompilationOptions.new("main") - options.set_auto_parallelize(auto_parallelize) - options.set_loop_parallelize(loop_parallelize) - options.set_dataflow_parallelize(dataflow_parallelize) - # Compile - self._compilation_result = self._jit_support.compile(mlir_str, options) - self._client_parameters = self._jit_support.load_client_parameters(self._compilation_result) - self._server_lambda = self._jit_support.load_server_lambda(self._compilation_result) - # Setup keyset cache - self._keyset_cache = None - if unsecure_key_set_cache_path: - self._keyset_cache = KeySetCache.new(unsecure_key_set_cache_path) - self._keyset = None - - def __str__(self): - return format_operation_graph(self.op_graph) - - def draw( - self, - show: bool = False, - vertical: bool = True, - save_to: Optional[Path] = None, - ) -> str: - """Draw operation graph of the circuit and optionally save/show the drawing. - - Args: - show (bool): if set to True, the drawing will be shown using matplotlib - vertical (bool): if set to True, the orientation will be vertical - save_to (Optional[Path]): if specified, the drawn graph will be saved to this path; - otherwise it will be saved to a temporary file - - Returns: - str: path of the file where the drawn graph is saved - - """ - - return draw_graph(self.op_graph, show, vertical, save_to) - - def keygen(self, force: bool = False): - """Generate the keys required for the encrypted circuit. - - Args: - force (bool, optional): generate even if keyset already exists. Defaults to False. - """ - if self._keyset is None or force: - self._keyset = ClientSupport.key_set(self._client_parameters, self._keyset_cache) - - def encrypt(self, *args: Union[int, numpy.ndarray]) -> PublicArguments: - """Encrypt the inputs of the circuit. - - Args: - *args (Union[int, numpy.ndarray]): plain input of the circuit - - Returns: - PublicArguments: encrypted and plain arguments as well as public keys - """ - # Make sure keys are available: shouldn't regenerate if they already exist - self.keygen(force=False) - return ClientSupport.encrypt_arguments(self._client_parameters, self._keyset, args) - - def run(self, args: PublicArguments) -> PublicResult: - """Evaluate the the encrypted circuit (no encryption or decryption involved). - - Args: - args (PublicArguments): encrypted inputs to the circuit - - Returns: - PublicResult: encrypted result - """ - return self._jit_support.server_call(self._server_lambda, args) - - def decrypt(self, result: PublicResult) -> Union[int, numpy.ndarray]: - """Decrypt the result of the circuit. - - Args: - result (PublicResult): encrypted result of the circuit - - Returns: - Union[int, numpy.ndarray]: plain result of the circuit - """ - return ClientSupport.decrypt_result(self._keyset, result) - - def encrypt_run_decrypt(self, *args: Union[int, numpy.ndarray]) -> Union[int, numpy.ndarray]: - """Encrypt, evaluate, and decrypt the inputs on the circuit. - - Generate keyset automatically if not yet done. - - Args: - *args (Union[int, numpy.ndarray]): plain inputs of the circuit - - Returns: - Union[int, numpy.ndarray]: plain result of the circuit - """ - self.keygen(force=False) - public_args = self.encrypt(*args) - encrypted_result = self.run(public_args) - decrypted_result = self.decrypt(encrypted_result) - return decrypted_result diff --git a/concrete/common/helpers/__init__.py b/concrete/common/helpers/__init__.py deleted file mode 100644 index 8680796ce..000000000 --- a/concrete/common/helpers/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Helpers for all kinds of tasks.""" - -from . import indexing_helpers, python_helpers diff --git a/concrete/common/helpers/formatting_helpers.py b/concrete/common/helpers/formatting_helpers.py deleted file mode 100644 index 6121e444b..000000000 --- a/concrete/common/helpers/formatting_helpers.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Helpers for formatting functionality.""" - -from typing import Any, Dict, Hashable - -import numpy - -from ..debugging.custom_assert import assert_true - -SPECIAL_OBJECT_MAPPING: Dict[Any, str] = { - numpy.float32: "float32", - numpy.float64: "float64", - numpy.int8: "int8", - numpy.int16: "int16", - numpy.int32: "int32", - numpy.int64: "int64", - numpy.uint8: "uint8", - numpy.uint16: "uint16", - numpy.uint32: "uint32", - numpy.uint64: "uint64", -} - - -def format_constant(constant: Any, maximum_length: int = 45) -> str: - """Format a constant. - - Args: - constant (Any): the constant to format - maximum_length (int): maximum length of the resulting string - - Returns: - str: the formatted constant - """ - - if isinstance(constant, Hashable) and constant in SPECIAL_OBJECT_MAPPING: - return SPECIAL_OBJECT_MAPPING[constant] - - # maximum_length should not be smaller than 7 characters because - # the constant will be formatted to `x ... y` - # where x and y are part of the constant and they are at least 1 character - assert_true(maximum_length >= 7) - - content = str(constant).replace("\n", "") - if len(content) > maximum_length: - from_start = (maximum_length - 5) // 2 - from_end = (maximum_length - 5) - from_start - content = f"{content[:from_start]} ... {content[-from_end:]}" - return content diff --git a/concrete/common/helpers/indexing_helpers.py b/concrete/common/helpers/indexing_helpers.py deleted file mode 100644 index 80b77fede..000000000 --- a/concrete/common/helpers/indexing_helpers.py +++ /dev/null @@ -1,277 +0,0 @@ -"""Helpers for indexing functionality.""" - -from typing import Tuple, Union - - -def format_indexing_element(indexing_element: Union[int, slice]) -> str: - """Format an indexing element. - - This is required mainly for slices. The reason is that string representation of slices - are very long and verbose. To give an example, `x[:, 2:]` will have the following index - `[slice(None, None, None), slice(2, None, None)]` if printed naively. With this helper, - it will be formatted as `[:, 2:]`. - - Args: - indexing_element (Union[int, slice]): indexing element to be formatted - - Returns: - str: formatted element - """ - - result = "" - if isinstance(indexing_element, slice): - if indexing_element.start is not None: - result += str(indexing_element.start) - result += ":" - if indexing_element.stop is not None: - result += str(indexing_element.stop) - if indexing_element.step is not None: - result += ":" - result += str(indexing_element.step) - else: - result += str(indexing_element) - return result.replace("\n", " ") - - -def validate_index( - index: Union[int, slice, Tuple[Union[int, slice], ...]], -) -> Tuple[Union[int, slice], ...]: - """Make sure index is valid and convert it to the tuple form. - - For example in `x[2]`, `index` is passed as `2`. - To make it easier to work with, this function converts index to `(2,)`. - - Args: - index (Union[int, slice, Tuple[Union[int, slice], ...]]): index to validate, improve - and return - - Returns: - Tuple[Union[int, slice], ...]: validated and improved index - """ - - if not isinstance(index, tuple): - index = (index,) - - for indexing_element in index: - valid = isinstance(indexing_element, (int, slice)) - - if isinstance(indexing_element, slice): - if ( - not (indexing_element.start is None or isinstance(indexing_element.start, int)) - or not (indexing_element.stop is None or isinstance(indexing_element.stop, int)) - or not (indexing_element.step is None or isinstance(indexing_element.step, int)) - ): - valid = False - - if not valid: - raise TypeError( - f"Only integers and integer slices can be used for indexing " - f"but you tried to use {format_indexing_element(indexing_element)} for indexing" - ) - - return index - - -def determine_output_shape( - input_shape: Tuple[int, ...], - index: Tuple[Union[int, slice], ...], -) -> Tuple[int, ...]: - """Determine the output shape from the input shape and the index. - - e.g., for `input_shape=(3, 2)` and `index=(:, 0)`, returns `(3,)` - for `input_shape=(4, 3, 2)` and `index=(2:,)`, returns `(2, 3, 2)` - - Args: - input_shape (Tuple[int, ...]): shape of the input tensor that is indexed - index (Tuple[Union[int, slice], ...]): desired and validated index - - Returns: - Tuple[int, ...]: shape of the result of indexing - """ - - indexing_elements = [format_indexing_element(indexing_element) for indexing_element in index] - index_str = f"[{', '.join(indexing_elements)}]" - - if len(index) > len(input_shape): - raise ValueError( - f"Tensor of shape {input_shape} cannot be indexed with {index_str} " - f"as the index has more elements than the number of dimensions of the tensor" - ) - - # indexing (3, 4, 5) with [1] is the same as indexing it with [1, :, :] - # indexing (3, 4, 5) with [1, 2] is the same as indexing it with [1, 2, :] - - # so let's replicate that behavior to make the rest of the code generic - index += (slice(None, None, None),) * (len(input_shape) - len(index)) - - output_shape = [] - for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)): - if isinstance(indexing_element, int): # indexing removes the dimension - indexing_element = ( - indexing_element if indexing_element >= 0 else indexing_element + dimension_size - ) - if not 0 <= indexing_element < dimension_size: - raise ValueError( - f"Tensor of shape {input_shape} cannot be indexed with {index_str} " - f"because index is out of range for dimension {dimension}" - ) - elif isinstance(indexing_element, slice): # indexing possibly shrinks the dimension - output_shape.append( - determine_new_dimension_size( - indexing_element, - dimension_size, - dimension, - input_shape, - index_str, - ) - ) - - return tuple(output_shape) - - -def sanitize_start_index( - start: int, - dimension_size: int, - # the rest is used for detailed exception message - dimension: int, - input_shape: Tuple[int, ...], - index_str: str, -) -> int: - """Sanitize and check start index of a slice. - - Args: - start (int): start index being sanitized - dimension_size (int): size of the dimension the slice is applied to - dimension (int): index of the dimension being sliced (for better messages) - input_shape (Tuple[int, ...]): shape of the whole input (for better messages) - index_str (str): string representation of the whole index (for better messages) - - Returns: - int: sanitized start index - """ - - start = start if start >= 0 else start + dimension_size - if not 0 <= start < dimension_size: - raise ValueError( - f"Tensor of shape {input_shape} cannot be indexed with {index_str} " - f"because start index is out of range for dimension {dimension}" - ) - return start - - -def sanitize_stop_index( - stop: int, - dimension_size: int, - # the rest is used for detailed exception message - dimension: int, - input_shape: Tuple[int, ...], - index_str: str, -) -> int: - """Sanitize and check stop index of a slice. - - Args: - stop (int): stop index being sanitized - dimension_size (int): size of the dimension the slice is applied to - dimension (int): index of the dimension being sliced (for better messages) - input_shape (Tuple[int, ...]): shape of the whole input (for better messages) - index_str (str): string representation of the whole index (for better messages) - - Returns: - int: sanitized stop index - """ - - stop = stop if stop >= 0 else stop + dimension_size - if not 0 <= stop <= dimension_size: - raise ValueError( - f"Tensor of shape {input_shape} cannot be indexed with {index_str} " - f"because stop index is out of range for dimension {dimension}" - ) - return stop - - -def determine_new_dimension_size( - slice_: slice, - dimension_size: int, - # the rest is used for detailed exception message - dimension: int, - input_shape: Tuple[int, ...], - index_str: str, -) -> int: - """Determine the new size of a dimension from the old size and the slice applied to it. - - e.g., for `slice_=1:4` and `dimension_size=5`, returns `3` - for `slice_=::-1` and `dimension_size=5`, returns `5` - - You may want to check this page to learn more about how this function works - https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing - - Args: - slice_ (slice): slice being applied to the dimension - dimension_size (int): size of the dimension the slice is applied to - dimension (int): index of the dimension being sliced (for better messages) - input_shape (Tuple[int, ...]): shape of the whole input (for better messages) - index_str (str): string representation of the whole index (for better messages) - - Returns: - int: new size of the dimension - """ - - step = slice_.step if slice_.step is not None else 1 - - if step > 0: - start = slice_.start if slice_.start is not None else 0 - stop = slice_.stop if slice_.stop is not None else dimension_size - - start = sanitize_start_index(start, dimension_size, dimension, input_shape, index_str) - stop = sanitize_stop_index(stop, dimension_size, dimension, input_shape, index_str) - - if start >= stop: - raise ValueError( - f"Tensor of shape {input_shape} cannot be indexed with {index_str} " - f"because start index is not less than stop index for dimension {dimension}" - ) - - size_before_stepping = stop - start - elif step < 0: - start = slice_.start if slice_.start is not None else dimension_size - 1 - stop = slice_.stop - - start = sanitize_start_index(start, dimension_size, dimension, input_shape, index_str) - - if stop is None: - # this is a weird case but it works as expected - # the issue is that it's impossible to slice whole vector reversed - # with a stop value different than none - - # if `x.shape == (6,)` the only one that works is `x[::-1].shape == (6,)` - # here is what doesn't work (and this is expected it's just weird) - # - # ... - # `x[:-2:-1].shape == (1,)` - # `x[:-1:-1].shape == (0,)` (note that this is a hard error for us) - # `x[:0:-1].shape == (5,)` - # `x[:1:-1].shape == (4,)` - # ... - - size_before_stepping = start + 1 - else: - stop = sanitize_stop_index(stop, dimension_size, dimension, input_shape, index_str) - - if stop >= start: - raise ValueError( - f"Tensor of shape {input_shape} cannot be indexed with {index_str} " - f"because step is negative and " - f"stop index is not less than start index for dimension {dimension}" - ) - - size_before_stepping = start - stop - else: - raise ValueError( - f"Tensor of shape {input_shape} cannot be indexed with {index_str} " - f"because step is zero for dimension {dimension}" - ) - - quotient = size_before_stepping // abs(step) - remainder = size_before_stepping % abs(step) - - return quotient + (remainder != 0) diff --git a/concrete/common/helpers/python_helpers.py b/concrete/common/helpers/python_helpers.py deleted file mode 100644 index e7c7bbed2..000000000 --- a/concrete/common/helpers/python_helpers.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Common python helpers.""" - -from typing import Any, Callable, Iterable, Mapping, Tuple, Union - - -def update_and_return_dict( - dict_to_update: dict, update_values: Union[Mapping, Iterable[Tuple[Any, Any]]] -) -> dict: - """Update a dictionary and return the ref to the dictionary that was updated. - - Args: - dict_to_update (dict): the dict to update - update_values (Union[Mapping, Iterable[Tuple[Any, Any]]]): the values to update the dict - with - - Returns: - dict: the dict that was just updated. - """ - dict_to_update.update(update_values) - return dict_to_update - - -def catch(func: Callable, *args, **kwargs) -> Union[Any, None]: - """Execute func by passing args and kwargs. Catch exceptions and return None in case of failure. - - Args: - func (Callable): function to execute and catch exceptions from - - Returns: - Union[Any, None]: the function result if there was no exception, None otherwise. - """ - try: - return func(*args, **kwargs) - except Exception: # pylint: disable=broad-except - return None diff --git a/concrete/common/mlir/__init__.py b/concrete/common/mlir/__init__.py deleted file mode 100644 index b06eb9322..000000000 --- a/concrete/common/mlir/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""MLIR conversion module.""" - -from .graph_converter import OPGraphConverter diff --git a/concrete/common/mlir/conversion_helpers.py b/concrete/common/mlir/conversion_helpers.py deleted file mode 100644 index 87af5e3bc..000000000 --- a/concrete/common/mlir/conversion_helpers.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Helpers for MLIR conversion functionality.""" - -# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints - -# pylint: disable=no-name-in-module - -from typing import Optional - -from concrete.lang.dialects.fhe import EncryptedIntegerType -from mlir.ir import Context, IntegerType, RankedTensorType, Type - -from ..data_types import Integer -from ..values import BaseValue, TensorValue - -# pylint: enable=no-name-in-module - - -def integer_to_mlir_type(ctx: Context, integer: Integer, is_encrypted: bool) -> Optional[Type]: - """Convert an integer to its corresponding MLIR type. - - Args: - ctx (Context): the MLIR context to perform the conversion - integer (Integer): the integer to convert - is_encrypted (bool): whether the integer is encrypted or not - - Returns: - Type: - the MLIR type corresponding to given integer and encryption status - if it's supported otherwise None - """ - - bit_width = integer.bit_width - - if is_encrypted: - result = EncryptedIntegerType.get(ctx, bit_width) - else: - result = IntegerType.get_signless(bit_width) - - return result - - -def value_to_mlir_type(ctx: Context, value: BaseValue) -> Type: - """Convert a value to its corresponding MLIR type. - - Args: - ctx (Context): the MLIR context to perform the conversion - value (BaseValue): the value to convert - - Returns: - Type: the MLIR type corresponding to given value - """ - - dtype = value.dtype - if isinstance(dtype, Integer): - mlir_type = integer_to_mlir_type(ctx, dtype, value.is_encrypted) - if isinstance(value, TensorValue): - if not value.is_scalar: - mlir_type = RankedTensorType.get(value.shape, mlir_type) - return mlir_type - - raise TypeError(f"{value} is not supported for MLIR conversion") diff --git a/concrete/common/mlir/graph_converter.py b/concrete/common/mlir/graph_converter.py deleted file mode 100644 index cfff6ac31..000000000 --- a/concrete/common/mlir/graph_converter.py +++ /dev/null @@ -1,138 +0,0 @@ -"""Module that provides OPGraph conversion functionality.""" - -# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints - -# pylint: disable=no-name-in-module - -from abc import ABC, abstractmethod -from typing import Any, Dict, List - -import concrete.lang as concretelang -import networkx as nx -from mlir.dialects import builtin -from mlir.ir import Context, InsertionPoint, Location, Module - -from ..operator_graph import OPGraph -from ..representation.intermediate import Input, IntermediateNode -from .conversion_helpers import value_to_mlir_type -from .node_converter import IntermediateNodeConverter - -# pylint: enable=no-name-in-module - - -class OPGraphConverter(ABC): - """Converter of OPGraph to MLIR.""" - - def convert(self, op_graph: OPGraph) -> str: - """Convert an operation graph to its corresponding MLIR representation. - - Args: - op_graph (OPGraph): the operation graph to be converted - - Returns: - str: textual MLIR representation corresponding to given operation graph - """ - - additional_conversion_info = self._generate_additional_info_dict(op_graph) - - # There are no tensor +*- scalar operations in the compiler - # But such operations are used commonly so we need to support them - # So, we implemented some workarounds (pull request #970) - # Once we have native support, this workaround shall be removed (issue #837) - # (most changes in #970 shall be reverted) - - # { node1: "%arg0", node2: "%0", node3: "%1" } - nodes_to_mlir_names: Dict[IntermediateNode, str] = {} - - # { "%arg0": "i5", "%0": "tensor<2x3x!FHE.eint<4>>" } - mlir_names_to_mlir_types: Dict[str, str] = {} - - # { "%0": ["%c1_i5"] } == for %0 we need to convert %c1_i5 to 1d tensor - scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] = {} - - with Context() as ctx, Location.unknown(): - concretelang.register_dialects(ctx) - - module = Module.create() - with InsertionPoint(module.body): - parameters = [ - value_to_mlir_type(ctx, input_node.outputs[0]) - for input_node in op_graph.get_ordered_inputs() - ] - - @builtin.FuncOp.from_py_func(*parameters) - def main(*arg): - ir_to_mlir = {} - for arg_num, node in op_graph.input_nodes.items(): - ir_to_mlir[node] = arg[arg_num] - - mlir_name = f"%arg{arg_num}" - nodes_to_mlir_names[node] = mlir_name - mlir_names_to_mlir_types[mlir_name] = str(parameters[arg_num]) - - for node in nx.topological_sort(op_graph.graph): - if isinstance(node, Input): - continue - - preds = [ir_to_mlir[pred] for pred in op_graph.get_ordered_preds(node)] - node_converter = IntermediateNodeConverter( - ctx, - op_graph, - node, - preds, - nodes_to_mlir_names, - mlir_names_to_mlir_types, - scalar_to_1d_tensor_conversion_hacks, - ) - ir_to_mlir[node] = node_converter.convert(additional_conversion_info) - - results = ( - ir_to_mlir[output_node] for output_node in op_graph.get_ordered_outputs() - ) - return results - - module_lines_after_hacks_are_applied = [] - for line in str(module).split("\n"): - mlir_name = line.split("=")[0].strip() - if mlir_name not in scalar_to_1d_tensor_conversion_hacks: - module_lines_after_hacks_are_applied.append(line) - continue - - to_be_replaced = scalar_to_1d_tensor_conversion_hacks[mlir_name] - for arg_name in to_be_replaced: - new_name = f"%hack_{mlir_name.replace('%', '')}_{arg_name.replace('%', '')}" - mlir_type = mlir_names_to_mlir_types[arg_name] - - hack_line = ( - f" {new_name} = tensor.from_elements {arg_name} : tensor<1x{mlir_type}>" - ) - module_lines_after_hacks_are_applied.append(hack_line) - - line = line.replace(arg_name, new_name) - - new_arg_types = [] - - arg_types = line.split(":")[1].split("->")[0].strip()[1:-1] - for arg in arg_types.split(", "): - if arg.startswith("tensor"): - new_arg_types.append(arg) - else: - new_arg_types.append(f"tensor<1x{arg}>") - - line = line.replace(arg_types, ", ".join(new_arg_types)) - - module_lines_after_hacks_are_applied.append(line) - - return "\n".join(module_lines_after_hacks_are_applied) - - @staticmethod - @abstractmethod - def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]: - """Generate additional conversion info dict for the MLIR converter. - - Args: - op_graph (OPGraph): the operation graph from which the additional info will be generated - - Returns: - Dict[str, Any]: dict of additional conversion info - """ diff --git a/concrete/common/mlir/node_converter.py b/concrete/common/mlir/node_converter.py deleted file mode 100644 index 52e280254..000000000 --- a/concrete/common/mlir/node_converter.py +++ /dev/null @@ -1,873 +0,0 @@ -"""Module that provides IntermediateNode conversion functionality.""" - -# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints - -# pylint: disable=no-name-in-module - -from typing import Any, Dict, List, Tuple, cast - -import numpy -from concrete.lang.dialects import fhe, fhelinalg -from mlir.dialects import arith, linalg, tensor -from mlir.ir import ( - ArrayAttr, - Attribute, - BoolAttr, - Context, - DenseElementsAttr, - IndexType, - IntegerAttr, - IntegerType, - OpResult, - RankedTensorType, -) - -from ..data_types import Integer -from ..debugging import assert_true -from ..helpers.indexing_helpers import determine_new_dimension_size -from ..operator_graph import OPGraph -from ..representation.intermediate import ( - Add, - Constant, - Conv2D, - Dot, - GenericFunction, - IndexConstant, - IntermediateNode, - MatMul, - Mul, - Sub, -) -from ..values import TensorValue -from .conversion_helpers import integer_to_mlir_type, value_to_mlir_type - -# pylint: enable=no-name-in-module - - -class IntermediateNodeConverter: - """Converter of IntermediateNode to MLIR.""" - - ctx: Context - op_graph: OPGraph - node: IntermediateNode - preds: List[OpResult] - - all_of_the_inputs_are_encrypted: bool - all_of_the_inputs_are_tensors: bool - one_of_the_inputs_is_a_tensor: bool - - nodes_to_mlir_names: Dict[IntermediateNode, str] - mlir_names_to_mlir_types: Dict[str, str] - scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] - - def __init__( - self, - ctx: Context, - op_graph: OPGraph, - node: IntermediateNode, - preds: List[OpResult], - nodes_to_mlir_names: Dict[OpResult, str], - mlir_names_to_mlir_types: Dict[str, str], - scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]], - ): - self.ctx = ctx - self.op_graph = op_graph - self.node = node - self.preds = preds - - self.all_of_the_inputs_are_encrypted = True - self.all_of_the_inputs_are_tensors = True - self.one_of_the_inputs_is_a_tensor = False - - for inp in node.inputs: - if inp.is_clear: - self.all_of_the_inputs_are_encrypted = False - - if isinstance(inp, TensorValue): - if inp.is_scalar: - self.all_of_the_inputs_are_tensors = False - else: - self.one_of_the_inputs_is_a_tensor = True - else: # pragma: no cover - # this branch is not covered as there are only TensorValues for now - self.all_of_the_inputs_are_tensors = False - - self.nodes_to_mlir_names = nodes_to_mlir_names - self.mlir_names_to_mlir_types = mlir_names_to_mlir_types - self.scalar_to_1d_tensor_conversion_hacks = scalar_to_1d_tensor_conversion_hacks - - def convert(self, additional_conversion_info: Dict[str, Any]) -> OpResult: - """Convert an intermediate node to its corresponding MLIR representation. - - Args: - additional_conversion_info (Dict[str, Any]): - external info that the converted node might need - - Returns: - str: textual MLIR representation corresponding to self.node - """ - - # pylint: disable=too-many-branches - - if isinstance(self.node, Add): - result = self.convert_add() - - elif isinstance(self.node, Constant): - result = self.convert_constant() - - elif isinstance(self.node, Dot): - result = self.convert_dot() - - elif isinstance(self.node, GenericFunction): - if self.node.op_name in ["flatten", "reshape"]: - # notice flatten() == reshape(-1) and convert_reshape can handle that - result = self.convert_reshape() - elif self.node.op_name == "sum": - result = self.convert_sum() - elif self.node.op_name == "concat": - result = self.convert_concat() - elif self.node.op_name == "transpose": - result = self.convert_transpose() - else: - result = self.convert_generic_function(additional_conversion_info) - - elif isinstance(self.node, IndexConstant): - result = self.convert_index_constant() - - elif isinstance(self.node, MatMul): - result = self.convert_matmul() - - elif isinstance(self.node, Mul): - result = self.convert_mul() - - elif isinstance(self.node, Sub): - result = self.convert_sub() - - elif isinstance(self.node, Conv2D): - result = self.convert_conv2d() - - else: # pragma: no cover - # this branch is not covered as unsupported opeations fail on check mlir compatibility - raise NotImplementedError(f"{type(self.node)} nodes cannot be converted to MLIR yet") - - # pylint: enable=too-many-branches - - mlir_name = str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip() - - self.nodes_to_mlir_names[self.node] = mlir_name - self.mlir_names_to_mlir_types[mlir_name] = str(result.type) - - if isinstance(self.node, (Add, Mul, Sub, Dot)): - if self.one_of_the_inputs_is_a_tensor and not self.all_of_the_inputs_are_tensors: - to_be_converted = [] - for (pred, output) in self.op_graph.get_ordered_preds_and_inputs_of(self.node): - inp = pred.outputs[output] - if isinstance(inp, TensorValue) and inp.is_scalar: - to_be_converted.append(self.nodes_to_mlir_names[pred]) - self.scalar_to_1d_tensor_conversion_hacks[mlir_name] = to_be_converted - - return result - - def convert_add(self) -> OpResult: - """Convert an Add node to its corresponding MLIR representation. - - Returns: - str: textual MLIR representation corresponding to self.node - """ - - assert_true(len(self.node.inputs) == 2) - assert_true(len(self.node.outputs) == 1) - - resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) - preds = self.preds - - if self.all_of_the_inputs_are_encrypted: - if self.one_of_the_inputs_is_a_tensor: - result = fhelinalg.AddEintOp(resulting_type, *preds).result - else: - result = fhe.AddEintOp(resulting_type, *preds).result - else: - if self.node.inputs[0].is_clear: # pragma: no cover - # this branch is not covered as it's impossible to get into due to how tracing works - # however, it doesn't hurt to keep it as an extra measure - preds = preds[::-1] - - if self.one_of_the_inputs_is_a_tensor: - result = fhelinalg.AddEintIntOp(resulting_type, *preds).result - else: - result = fhe.AddEintIntOp(resulting_type, *preds).result - - return result - - def convert_concat(self) -> OpResult: - """Convert a "concat" node to its corresponding MLIR representation. - - Returns: - str: textual MLIR representation corresponding to self.node - """ - - assert_true(len(self.node.inputs) >= 2) - assert_true(len(self.node.outputs) == 1) - - node = cast(GenericFunction, self.node) - resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) - - axis = node.op_kwargs.get("axis", 0) - if axis is not None: - if axis < 0: - axis += len(cast(TensorValue, self.node.inputs[0]).shape) - return fhelinalg.ConcatOp( - resulting_type, - self.preds, - IntegerAttr.get(IntegerType.get_signless(64), axis), - ).result - - flattened_preds = [] - for pred, input_value in zip(self.preds, self.node.inputs): - input_shape = cast(TensorValue, input_value).shape - input_size = numpy.prod(input_shape) - input_dtype = cast(Integer, input_value.dtype) - - flattened_pred_type = RankedTensorType.get( - [input_size], - integer_to_mlir_type(self.ctx, input_dtype, input_value.is_encrypted), - ) - flattened_pred = linalg.TensorCollapseShapeOp( - flattened_pred_type, - pred, - ArrayAttr.get( - [ - ArrayAttr.get( - [ - IntegerAttr.get(IndexType.parse("index"), i) - for i in range(len(input_shape)) - ] - ) - ] - ), - ).result - flattened_preds.append(flattened_pred) - - return fhelinalg.ConcatOp( - resulting_type, - flattened_preds, - IntegerAttr.get(IntegerType.get_signless(64), 0), - ).result - - def convert_constant(self) -> OpResult: - """Convert a Constant node to its corresponding MLIR representation. - - Returns: - str: textual MLIR representation corresponding to self.node - """ - - assert_true(len(self.node.inputs) == 0) - assert_true(len(self.node.outputs) == 1) - - value = self.node.outputs[0] - if not isinstance(value, TensorValue): # pragma: no cover - # this branch is not covered as there are only TensorValues for now - raise NotImplementedError(f"{value} constants cannot be converted to MLIR yet") - - resulting_type = value_to_mlir_type(self.ctx, value) - data = cast(Constant, self.node).constant_data - - if value.is_scalar: - attr = IntegerAttr.get(resulting_type, data) - else: - # usage of `Attribute.parse` is the result of some limitations in the MLIR module - # provided by LLVM - - # what should have been used is `DenseElementsAttr` but it's impossible to assign - # custom bit-widths using it (e.g., uint5) - - # since we coudn't create a `DenseElementsAttr` with a custom bit width using python api - # we use `Attribute.parse` to let the underlying library do it by itself - - attr = Attribute.parse(f"dense<{str(data.tolist())}> : {resulting_type}") - - return arith.ConstantOp(resulting_type, attr).result - - def convert_conv2d(self) -> OpResult: - """Convert a Conv2D node to its corresponding MLIR representation. - - Returns: - str: textual MLIR representation corresponding to self.node - """ - - assert_true(len(self.node.inputs) == 2 or len(self.node.inputs) == 3) - assert_true(len(self.node.outputs) == 1) - has_bias = len(self.node.inputs) == 3 - - x = self.node.inputs[0] - weight = self.node.inputs[1] - if not (x.is_encrypted and weight.is_clear): # pragma: no cover - raise NotImplementedError( - f"Conv2D with input {x} and weight {weight} cannot be converted to MLIR yet", - ) - - resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) - preds = self.preds - - node = cast(Conv2D, self.node) - integer_type = IntegerType.get_signless(64, context=self.ctx) - strides = DenseElementsAttr.get( - numpy.array(list(node.strides), dtype=numpy.uint64), - context=self.ctx, - type=integer_type, - ) - dilations = DenseElementsAttr.get( - numpy.array(list(node.dilations), dtype=numpy.uint64), - context=self.ctx, - type=integer_type, - ) - pads = DenseElementsAttr.get( - numpy.array(list(node.pads), dtype=numpy.uint64), context=self.ctx, type=integer_type - ) - if has_bias: - result = fhelinalg.Conv2dOp(resulting_type, *preds, pads, strides, dilations).result - else: - result = fhelinalg.Conv2dOp( - resulting_type, *preds, None, pads, strides, dilations - ).result - - return result - - def convert_dot(self) -> OpResult: - """Convert a Dot node to its corresponding MLIR representation. - - Returns: - str: textual MLIR representation corresponding to self.node - """ - - assert_true(len(self.node.inputs) == 2) - assert_true(len(self.node.outputs) == 1) - - if self.all_of_the_inputs_are_encrypted: - lhs = self.node.inputs[0] - rhs = self.node.inputs[1] - raise NotImplementedError( - f"Dot product between {lhs} and {rhs} cannot be converted to MLIR yet", - ) - - resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) - preds = self.preds - - if self.node.inputs[0].is_clear: - preds = preds[::-1] - - if self.all_of_the_inputs_are_tensors: - # numpy.dot(x, y) where x and y are both vectors = regular dot product - result = fhelinalg.Dot(resulting_type, *preds).result - - elif not self.one_of_the_inputs_is_a_tensor: - # numpy.dot(x, y) where x and y are both scalars = x * y - result = fhe.MulEintIntOp(resulting_type, *preds).result - - else: - # numpy.dot(x, y) where one of x or y is a scalar and the other one is a vector = x * y - result = fhelinalg.MulEintIntOp(resulting_type, *preds).result - - return result - - def convert_generic_function(self, additional_conversion_info: Dict[str, Any]) -> OpResult: - """Convert a GenericFunction node to its corresponding MLIR representation. - - Returns: - str: textual MLIR representation corresponding to self.node - """ - - variable_input_indices = [ - idx - for idx, inp in enumerate(self.op_graph.get_ordered_preds(self.node)) - if not isinstance(inp, Constant) - ] - if len(variable_input_indices) != 1: # pragma: no cover - # this branch is not covered as it's impossible to get into due to how tracing works - # however, it doesn't hurt to keep it as an extra measure - raise NotImplementedError( - "Table lookups with more than one variable input cannot be converted to MLIR yet" - ) - variable_input_index = variable_input_indices[0] - - assert_true(len(self.node.outputs) == 1) - output = self.node.outputs[0] - - value = self.node.inputs[variable_input_index] - assert_true(value.is_encrypted) - - if not isinstance(value.dtype, Integer): # pragma: no cover - # this branch is not covered as it's impossible to get into due to how compilation works - # however, it doesn't hurt to keep it as an extra measure - raise NotImplementedError(f"Table lookup on {value} cannot be converted to MLIR yet") - - tables = additional_conversion_info["tables"][self.node] - assert_true(len(tables) > 0) - - lut_shape: Tuple[int, ...] = () - map_shape: Tuple[int, ...] = () - - if len(tables) == 1: - table = tables[0][0] - - # The reduction on 63b is to avoid problems like doing a TLU of - # the form T[j] = 2< OpResult: - """Convert a IndexConstant node to its corresponding MLIR representation. - - Returns: - str: textual MLIR representation corresponding to self.node - """ - - # pylint: disable=too-many-locals - - assert_true(len(self.node.inputs) == 1) - assert_true(len(self.node.outputs) == 1) - - tensor_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) - pred = self.preds[0] - - input_value = cast(TensorValue, self.node.inputs[0]) - input_shape = input_value.shape - - index = cast(IndexConstant, self.node).index - index_str = self.node.text_for_formatting([""], 0) - - index_type = IndexType.parse("index") - - if len(index) == len(input_shape) and all(isinstance(i, int) for i in index): - indices = [] - for value, dimension_size in zip(index, input_shape): - assert isinstance(value, int) # mypy - attr = IntegerAttr.get(index_type, value if value >= 0 else value + dimension_size) - indices.append(arith.ConstantOp(index_type, attr).result) - return tensor.ExtractOp(tensor_type, pred, indices).result - - offsets = [] - sizes = [] - strides = [] - - destroyed_dimensions = [] - for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)): - - if isinstance(indexing_element, int): - destroyed_dimensions.append(dimension) - size = 1 - stride = 1 - offset = ( - indexing_element if indexing_element >= 0 else indexing_element + dimension_size - ) - - elif isinstance(indexing_element, slice): - size = determine_new_dimension_size( - indexing_element, - dimension_size, - dimension, - input_shape, - index_str, - ) - stride = indexing_element.step if isinstance(indexing_element.step, int) else 1 - offset = ( - ( - indexing_element.start - if indexing_element.start >= 0 - else indexing_element.start + dimension_size - ) - if isinstance(indexing_element.start, int) - else (0 if stride > 0 else dimension_size - 1) - ) - - else: # pragma: no cover - # this branch is impossible to reach with all the previous checks - # but let's keep it as an extra measure - raise NotImplementedError( - f"Indexing of {input_value} with {index_str} cannot be converted to MLIR", - ) - - offsets.append(offset) - sizes.append(size) - strides.append(stride) - - if len(destroyed_dimensions) == 0: - return tensor.ExtractSliceOp( - tensor_type, - pred, - [], - [], - [], - ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]), - ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]), - ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]), - ).result - - output_value = cast(TensorValue, self.node.outputs[0]) - - intermediate_shape = list(output_value.shape) - for dimension in destroyed_dimensions: - intermediate_shape.insert(dimension, 1) - - intermediate_type = RankedTensorType.get( - intermediate_shape, - integer_to_mlir_type( - self.ctx, - cast(Integer, output_value.dtype), - output_value.is_encrypted, - ), - ) - - intermediate = tensor.ExtractSliceOp( - intermediate_type, - pred, - [], - [], - [], - ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]), - ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]), - ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]), - ).result - - reassociaton = [] - - current_intermediate_dimension = 0 - for _ in range(len(output_value.shape)): - indices = [current_intermediate_dimension] - while current_intermediate_dimension in destroyed_dimensions: - current_intermediate_dimension += 1 - indices.append(current_intermediate_dimension) - - reassociaton.append(indices) - current_intermediate_dimension += 1 - while current_intermediate_dimension < len(intermediate_shape): - reassociaton[-1].append(current_intermediate_dimension) - current_intermediate_dimension += 1 - - return linalg.TensorCollapseShapeOp( - tensor_type, - intermediate, - ArrayAttr.get( - [ - ArrayAttr.get( - [IntegerAttr.get(index_type, index) for index in indices], - ) - for indices in reassociaton - ], - ), - ).result - - # pylint: enable=too-many-locals - - def convert_matmul(self) -> OpResult: - """Convert a MatMul node to its corresponding MLIR representation. - - Returns: - str: textual MLIR representation corresponding to self.node - """ - - assert_true(len(self.node.inputs) == 2) - assert_true(len(self.node.outputs) == 1) - - if self.all_of_the_inputs_are_encrypted: - lhs = self.node.inputs[0] - rhs = self.node.inputs[1] - raise NotImplementedError( - f"Matrix multiplication between {lhs} and {rhs} cannot be converted to MLIR yet", - ) - - resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) - preds = self.preds - - assert isinstance(self.node.outputs[0], TensorValue) - if self.node.outputs[0].shape == (): - if self.node.inputs[0].is_clear: - preds = preds[::-1] - result = fhelinalg.Dot(resulting_type, *preds).result - - elif self.node.inputs[0].is_clear: - result = fhelinalg.MatMulIntEintOp(resulting_type, *preds).result - else: - result = fhelinalg.MatMulEintIntOp(resulting_type, *preds).result - - return result - - def convert_mul(self) -> OpResult: - """Convert a Mul node to its corresponding MLIR representation. - - Returns: - str: textual MLIR representation corresponding to self.node - """ - - assert_true(len(self.node.inputs) == 2) - assert_true(len(self.node.outputs) == 1) - - if self.all_of_the_inputs_are_encrypted: - lhs = self.node.inputs[0] - rhs = self.node.inputs[1] - raise NotImplementedError( - f"Multiplication between {lhs} and {rhs} cannot be converted to MLIR yet", - ) - - resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) - preds = self.preds - - if self.node.inputs[0].is_clear: # pragma: no cover - # this branch is not covered as it's impossible to get into due to how tracing works - # however, it doesn't hurt to keep it as an extra measure - preds = preds[::-1] - - if self.one_of_the_inputs_is_a_tensor: - result = fhelinalg.MulEintIntOp(resulting_type, *preds).result - else: - result = fhe.MulEintIntOp(resulting_type, *preds).result - - return result - - def convert_reshape(self) -> OpResult: - """Convert a "reshape" node to its corresponding MLIR representation. - - Returns: - str: textual MLIR representation corresponding to self.node - """ - - assert_true(len(self.node.inputs) == 1) - assert_true(len(self.node.outputs) == 1) - - assert_true(isinstance(self.node.inputs[0], TensorValue)) - input_shape = cast(TensorValue, self.node.inputs[0]).shape - - assert_true(isinstance(self.node.outputs[0], TensorValue)) - output_shape = cast(TensorValue, self.node.outputs[0]).shape - - pred = self.preds[0] - if input_shape == output_shape: - return pred - - # we can either collapse or expand, which changes the number of dimensions - # this is a limitation of the current compiler and it will be improved in the future (#1060) - can_be_converted_directly = len(input_shape) != len(output_shape) - - reassociation: List[List[int]] = [] - if can_be_converted_directly: - if len(output_shape) == 1: - # output is 1 dimensional so collapse every dimension into the same dimension - reassociation.append(list(range(len(input_shape)))) - else: - # input is m dimensional - # output is n dimensional - # and m is different than n - - # we don't want to duplicate code so we forget about input and output - # and we focus on smaller shape and bigger shape - - smaller_shape, bigger_shape = ( - (output_shape, input_shape) - if len(output_shape) < len(input_shape) - else (input_shape, output_shape) - ) - s_index, b_index = 0, 0 - - # now we will figure out how to group the bigger shape to get the smaller shape - # think of the algorithm below as - # keep merging the dimensions of the bigger shape - # until we have a match on the smaller shape - # then try to match the next dimension of the smaller shape - # if all dimensions of the smaller shape is matched - # we can convert it - - group = [] - size = 1 - while s_index < len(smaller_shape) and b_index < len(bigger_shape): - # dimension `b_index` of `bigger_shape` belongs to current group - group.append(b_index) - - # and current group has `size * bigger_shape[b_index]` elements now - size *= bigger_shape[b_index] - - # if current group size matches the dimension `s_index` of `smaller_shape` - if size == smaller_shape[s_index]: - # we finalize this group and reset everything - size = 1 - reassociation.append(group) - group = [] - - # now try to match the next dimension of `smaller_shape` - s_index += 1 - - # now process the next dimension of `bigger_shape` - b_index += 1 - - # handle the case where bigger shape has proceeding 1s - # e.g., (5,) -> (5, 1) - while b_index < len(bigger_shape) and bigger_shape[b_index] == 1: - reassociation[-1].append(b_index) - b_index += 1 - - # if not all dimensions of both shapes are processed exactly - if s_index != len(smaller_shape) or b_index != len(bigger_shape): - # we cannot convert - can_be_converted_directly = False - - index_type = IndexType.parse("index") - resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) - - if can_be_converted_directly: - reassociation_attr = ArrayAttr.get( - [ - ArrayAttr.get([IntegerAttr.get(index_type, dimension) for dimension in group]) - for group in reassociation - ] - ) - if len(output_shape) < len(input_shape): - return linalg.TensorCollapseShapeOp(resulting_type, pred, reassociation_attr).result - return linalg.TensorExpandShapeOp(resulting_type, pred, reassociation_attr).result - - flattened_type = value_to_mlir_type( - self.ctx, - TensorValue( - self.node.inputs[0].dtype, - self.node.inputs[0].is_encrypted, - (numpy.prod(input_shape),), - ), - ) - flattened_result = linalg.TensorCollapseShapeOp( - flattened_type, - pred, - ArrayAttr.get( - [ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(input_shape))])] - ), - ).result - - return linalg.TensorExpandShapeOp( - resulting_type, - flattened_result, - ArrayAttr.get( - [ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(output_shape))])] - ), - ).result - - def convert_sub(self) -> OpResult: - """Convert a Sub node to its corresponding MLIR representation. - - Returns: - str: textual MLIR representation corresponding to self.node - """ - - assert_true(len(self.node.inputs) == 2) - assert_true(len(self.node.outputs) == 1) - - lhs = self.node.inputs[0] - rhs = self.node.inputs[1] - if not (lhs.is_clear and rhs.is_encrypted): - raise NotImplementedError( - f"Subtraction of {rhs} from {lhs} cannot be converted to MLIR yet", - ) - - resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) - preds = self.preds - - if self.one_of_the_inputs_is_a_tensor: - result = fhelinalg.SubIntEintOp(resulting_type, *preds).result - else: - result = fhe.SubIntEintOp(resulting_type, *preds).result - - return result - - def convert_sum(self) -> OpResult: - """Convert a "sum" node to its corresponding MLIR representation. - - Returns: - str: textual MLIR representation corresponding to self.node - """ - - assert_true(len(self.node.inputs) == 1) - assert_true(len(self.node.outputs) == 1) - - node = cast(GenericFunction, self.node) - resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) - - axes = node.op_kwargs.get("axis", []) - keep_dims = node.op_kwargs.get("keepdims", False) - - if isinstance(axes, int): - axes = [axes] - elif isinstance(axes, tuple): - axes = list(axes) - - input_dimensions = len(cast(TensorValue, self.node.inputs[0]).shape) - for i, axis in enumerate(axes): - if axis < 0: - axes[i] += input_dimensions - - return fhelinalg.SumOp( - resulting_type, - self.preds[0], - ArrayAttr.get([IntegerAttr.get(IntegerType.get_signless(64), axis) for axis in axes]), - BoolAttr.get(keep_dims), - ).result - - def convert_transpose(self) -> OpResult: - """Convert a Transpose node to its corresponding MLIR representation. - - Returns: - str: textual MLIR representation corresponding to self.node - """ - - assert_true(len(self.node.inputs) == 1) - assert_true(len(self.node.outputs) == 1) - - resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) - preds = self.preds - return fhelinalg.TransposeOp(resulting_type, *preds).result diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py deleted file mode 100644 index 56e7694fe..000000000 --- a/concrete/common/mlir/utils.py +++ /dev/null @@ -1,235 +0,0 @@ -"""Utilities for MLIR conversion.""" -from typing import Dict, List, Optional, Tuple - -import networkx as nx - -from ..data_types.dtypes_helpers import ( - value_is_clear_scalar_integer, - value_is_clear_tensor_integer, - value_is_encrypted_scalar_integer, - value_is_encrypted_tensor_integer, - value_is_integer, - value_is_unsigned_integer, -) -from ..debugging import format_operation_graph -from ..debugging.custom_assert import assert_not_reached, assert_true -from ..operator_graph import OPGraph -from ..representation import intermediate -from ..representation.intermediate import Conv2D, IntermediateNode - -# TODO: should be removed as the supported bit-width is now dynamic -ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 8 - - -def check_node_compatibility_with_mlir( - node: IntermediateNode, - nx_graph: nx.MultiDiGraph, - is_output: bool, -) -> Optional[str]: - """Check if node is compatible with MLIR. - - Args: - node (IntermediateNode): node to check - nx_graph (nx.MultiDiGraph): the networkx graph to which node belongs - is_output (bool): whether the node is an output node or not - - Returns: - Optional[str]: None if the node is compatible else reason for incompatibility - """ - - # pylint: disable=too-many-branches,too-many-return-statements - - inputs = node.inputs - outputs = node.outputs - - if isinstance(node, intermediate.Add): # constraints for addition - for inp in inputs: - if not value_is_integer(inp): - return "only integer addition is supported" - - elif isinstance(node, intermediate.Sub): # constraints for subtraction - for inp in inputs: - if not value_is_integer(inp): - return "only integer subtraction is supported" - - elif isinstance(node, intermediate.Mul): # constraints for multiplication - for inp in inputs: - if not value_is_integer(inp): - return "only integer multiplication is supported" - - elif isinstance(node, intermediate.Input): # constraints for inputs - assert_true(len(outputs) == 1) - if not value_is_unsigned_integer(outputs[0]): - return "only unsigned integer inputs are supported" - - elif isinstance(node, intermediate.Constant): # constraints for constants - assert_true(len(outputs) == 1) - # We currently can't fail on the following assert, but let it for possible changes in the - # future - if not value_is_integer(outputs[0]): - return "only integer constants are supported" # pragma: no cover - - elif isinstance(node, intermediate.GenericFunction): # constraints for univariate functions - for inp in inputs: - if not value_is_integer(inp): - return ( - f"{node.op_name} with floating-point inputs " - f"is required to be fused to be supported" - ) - - if node.op_kind == "TLU": - assert_true( - len( - [ - pred_node - for pred_node in nx_graph.pred[node] - if not isinstance(pred_node, intermediate.Constant) - ] - ) - == 1 - ) - else: - if node.op_name not in ["flatten", "reshape", "sum", "concat", "transpose"]: - return f"{node.op_name} is not supported for the time being" - - elif isinstance(node, intermediate.Dot): # constraints for dot product - assert_true(len(inputs) == 2) - if not value_is_integer(inputs[0]) or not value_is_integer(inputs[1]): - return "only integer dot product is supported" - - elif isinstance(node, intermediate.IndexConstant): # constraints for constant indexing - assert_true(len(outputs) == 1) - - elif isinstance(node, intermediate.MatMul): # constraints for matrix multiplication - assert_true(len(inputs) == 2) - - elif isinstance(node, Conv2D): - assert_true(len(inputs) in [2, 3]) - - else: # pragma: no cover - assert_not_reached("Non IntermediateNode object in the OPGraph") - - if is_output: - for out in outputs: - # For signed values and waiting for a real fix (#845): what is returned by the compiler - # is not the (possibly negative) result r, but the always-positive (r mod 2**t), where t - # is the bitwidth of r - - # We currently can't fail on the following assert, but let it for possible changes in - # the future - if not value_is_integer(out): - return "only integer outputs are supported" # pragma: no cover - else: - for out in outputs: - # We currently can't fail on the following assert, but let it for possible changes in - # the future - if not value_is_integer(out): - return "only integer intermediates are supported" # pragma: no cover - - # pylint: enable=too-many-branches,too-many-return-statements - - return None - - -def check_graph_values_compatibility_with_mlir( - op_graph: OPGraph, -) -> Optional[Dict[IntermediateNode, List[str]]]: - """Make sure the graph outputs are unsigned integers, which is what the compiler supports. - - Args: - op_graph: computation graph to check - - Returns: - Dict[IntermediateNode, str]: None if the graph is compatible - information about offending nodes otherwise - """ - - offending_nodes = {} - - for node in op_graph.graph.nodes: - is_output = node in op_graph.output_nodes.values() - if ( - reason := check_node_compatibility_with_mlir(node, op_graph.graph, is_output) - ) is not None: - offending_nodes[node] = [reason] - - return None if len(offending_nodes) == 0 else offending_nodes - - -def _set_all_bit_width(op_graph: OPGraph, p: int): - """Set all bit_width in the graph to `p` and `p+1` for clear and encrypted values respectively. - - Args: - op_graph: graph to set bit_width for - p: bit_width to set everywhere - """ - for node in op_graph.graph.nodes: - for value in node.outputs + node.inputs: - if value_is_clear_scalar_integer(value) or value_is_clear_tensor_integer(value): - value.dtype.bit_width = p + 1 - elif value_is_encrypted_scalar_integer(value) or value_is_encrypted_tensor_integer( - value - ): - value.dtype.bit_width = p - - -def get_op_graph_max_bit_width_and_nodes_over_bit_width_limit( - op_graph: OPGraph, -) -> Tuple[int, Dict[IntermediateNode, List[str]]]: - """Get the maximum bit width of integer nodes in the given OPGraph. - - Also returns a dictionary with nodes having an unsupported bit width. - - Args: - op_graph: graph to update bit_width for - - Returns: - Tuple[int, Dict[IntermediateNode, List[str]]]: a tuple containing the maximum bit width of - integer values in the OPGraph as well as a dictionary with nodes and the list of issues - that the nodes have, in this case having an unsupported bit width. - """ - max_bit_width = 0 - offending_nodes: Dict[IntermediateNode, List[str]] = {} - for node in op_graph.graph.nodes: - for value_out in node.outputs: - if value_is_clear_scalar_integer(value_out) or value_is_clear_tensor_integer(value_out): - current_node_out_bit_width = value_out.dtype.bit_width - 1 - else: - - assert_true( - value_is_encrypted_scalar_integer(value_out) - or value_is_encrypted_tensor_integer(value_out) - ) - - current_node_out_bit_width = value_out.dtype.bit_width - - max_bit_width = max(max_bit_width, current_node_out_bit_width) - - # Check that current_node_out_bit_width is supported by the compiler - if current_node_out_bit_width > ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB: - offending_nodes[node] = [ - f"{current_node_out_bit_width} bits is not supported for the time being" - ] - - return max_bit_width, offending_nodes - - -def update_bit_width_for_mlir(op_graph: OPGraph): - """Prepare bit_width of all nodes to be the same, set to the maximum value in the graph. - - Args: - op_graph: graph to update bit_width for - """ - max_bit_width, offending_nodes = get_op_graph_max_bit_width_and_nodes_over_bit_width_limit( - op_graph - ) - - if len(offending_nodes) != 0: - raise RuntimeError( - f"max_bit_width of some nodes is too high for the current version of " - f"the compiler (maximum must be {ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB}) " - f"which is not compatible with:\n\n" - + format_operation_graph(op_graph, highlighted_nodes=offending_nodes) - ) - - _set_all_bit_width(op_graph, max_bit_width) diff --git a/concrete/common/operator_graph.py b/concrete/common/operator_graph.py deleted file mode 100644 index 80fcae0f8..000000000 --- a/concrete/common/operator_graph.py +++ /dev/null @@ -1,319 +0,0 @@ -"""Code to wrap and make manipulating networkx graphs easier.""" - -from copy import deepcopy -from typing import Any, Callable, Dict, Iterable, List, Tuple, Union - -import networkx as nx - -from .data_types.base import BaseDataType -from .data_types.dtypes_helpers import ( - get_base_data_type_for_python_constant_data, - get_constructor_for_python_constant_data, -) -from .data_types.floats import Float -from .data_types.integers import Integer, make_integer_to_hold -from .debugging.custom_assert import assert_true -from .representation.intermediate import Input, IntermediateNode -from .tracing import BaseTracer -from .tracing.tracing_helpers import create_graph_from_output_tracers - - -class OPGraph: - """Class to make work with nx graphs easier.""" - - graph: nx.MultiDiGraph - input_nodes: Dict[int, Input] - output_nodes: Dict[int, IntermediateNode] - - def __init__( - self, - graph: nx.MultiDiGraph, - input_nodes: Dict[int, Input], - output_nodes: Dict[int, IntermediateNode], - ) -> None: - assert_true( - all(isinstance(node, Input) for node in input_nodes.values()), - "Got input nodes that were not Input, which is not supported", - ) - assert_true( - all(isinstance(node, IntermediateNode) for node in output_nodes.values()), - "Got output nodes which were not IntermediateNode, which is not supported", - ) - - self.graph = graph - self.input_nodes = input_nodes - self.output_nodes = output_nodes - self.prune_nodes() - - def __call__(self, *args) -> Union[Any, Tuple[Any, ...]]: - assert_true(len(self.input_nodes) > 0, "Cannot evaluate a graph with no input nodes") - inputs = dict(enumerate(args)) - - assert_true( - len(inputs) == len(self.input_nodes), - f"Expected {len(self.input_nodes)} arguments, got {len(inputs)} : {args}", - ) - - results = self.evaluate(inputs) - tuple_result = tuple(results[output_node] for output_node in self.get_ordered_outputs()) - return tuple_result if len(tuple_result) > 1 else tuple_result[0] - - @staticmethod - def from_output_tracers(output_tracers: Iterable[BaseTracer]) -> "OPGraph": - """Construct OPGraph from output tracers. - - Args: - output_tracers (Iterable[BaseTracer]): The tracers output by the function that was - traced. - - Returns: - OPGraph: The resulting OPGraph. - """ - graph = create_graph_from_output_tracers(output_tracers) - input_nodes = { - node.program_input_idx: node - for node in graph.nodes() - if len(graph.pred[node]) == 0 and isinstance(node, Input) - } - output_nodes = { - output_idx: tracer.traced_computation - for output_idx, tracer in enumerate(output_tracers) - } - return OPGraph(graph, input_nodes, output_nodes) - - @staticmethod - def from_graph( - graph: nx.MultiDiGraph, - input_nodes: Iterable[Input], - output_nodes: Iterable[IntermediateNode], - ) -> "OPGraph": - """Construct OPGraph from an existing networkx MultiDiGraph. - - Args: - graph (nx.MultiDiGraph): The networkx MultiDiGraph to use. - input_nodes (Iterable[Input]): The input nodes of the MultiDiGraph. - output_nodes (Iterable[IntermediateNode]): The output nodes of the MultiDiGraph. - - Returns: - OPGraph: The resulting OPGraph. - """ - return OPGraph(graph, dict(enumerate(input_nodes)), dict(enumerate(output_nodes))) - - def get_ordered_inputs(self) -> List[Input]: - """Get the input nodes of the graph, ordered by their index. - - Returns: - List[Input]: ordered input nodes - """ - return [self.input_nodes[idx] for idx in range(len(self.input_nodes))] - - def get_ordered_outputs(self) -> List[IntermediateNode]: - """Get the output nodes of the graph, ordered by their index. - - Returns: - List[IntermediateNode]: ordered input nodes - """ - return [self.output_nodes[idx] for idx in range(len(self.output_nodes))] - - def get_ordered_preds(self, node: IntermediateNode) -> List[IntermediateNode]: - """Get node predecessors ordered by their indices. - - Args: - node (IntermediateNode): The node for which we want the ordered predecessors. - - Returns: - List[IntermediateNode]: The list of predecessors ordered by input index. - """ - # Replication of pred is managed e.g. x + x will yield the proper pred x twice - idx_to_pred: Dict[int, IntermediateNode] = {} - for pred in self.graph.predecessors(node): - edge_data = self.graph.get_edge_data(pred, node) - idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values()) - return [idx_to_pred[i] for i in range(len(idx_to_pred))] - - def get_ordered_preds_and_inputs_of( - self, node: IntermediateNode - ) -> List[Tuple[IntermediateNode, int]]: - """Get node preds and inputs ordered by their indices. - - Args: - node (IntermediateNode): the node for which we want the ordered inputs - - Returns: - List[Tuple[IntermediateNode, int]]: the ordered list of preds and inputs - """ - - idx_to_inp: Dict[int, Tuple[IntermediateNode, int]] = {} - for pred in self.graph.predecessors(node): - edge_data = self.graph.get_edge_data(pred, node) - idx_to_inp.update( - (data["input_idx"], (pred, data["output_idx"])) for data in edge_data.values() - ) - return [idx_to_inp[i] for i in range(len(idx_to_inp))] - - def evaluate(self, inputs: Dict[int, Any]) -> Dict[IntermediateNode, Any]: - """Evaluate a graph and get intermediate values for all nodes. - - Args: - inputs (Dict[int, Any]): The inputs to the program - - Returns: - Dict[IntermediateNode, Any]: Dictionary with node as keys and resulting values - """ - node_results: Dict[IntermediateNode, Any] = {} - - def get_result_of_node_at_index(node: IntermediateNode, output_idx: int) -> Any: - """Get the output result at index output_idx for a node. - - Args: - node (IntermediateNode): the node from which we want the output. - output_idx (int): which output we want. - - Returns: - Any: the output value of the evaluation of node. - """ - result = node_results[node] - # TODO: #81 remove no cover once we have nodes with multiple outputs - if isinstance(result, tuple): # pragma: no cover - # If the node has multiple outputs (i.e. the result is a tuple), return the - # requested output - return result[output_idx] - # If the result is not a tuple, then the result is the node's only output. Check that - # the requested index is 0 (as it's the only valid value) and return the result itself. - assert_true( - output_idx == 0, - f"Unable to get output at index {output_idx} for node {node}.\n" - f"Node result: {result}", - ) - return result - - for node in nx.topological_sort(self.graph): - if not isinstance(node, Input): - curr_inputs = {} - for pred_node in self.graph.predecessors(node): - edges = self.graph.get_edge_data(pred_node, node) - curr_inputs.update( - { - edge["input_idx"]: get_result_of_node_at_index( - pred_node, - output_idx=edge["output_idx"], - ) - for edge in edges.values() - } - ) - node_results[node] = node.evaluate(curr_inputs) - else: - node_results[node] = node.evaluate({0: inputs[node.program_input_idx]}) - - return node_results - - def update_values_with_bounds_and_samples( - self, - node_bounds_and_samples: dict, - get_base_data_type_for_constant_data: Callable[ - [Any], BaseDataType - ] = get_base_data_type_for_python_constant_data, - get_constructor_for_constant_data: Callable[ - ..., Callable - ] = get_constructor_for_python_constant_data, - ): - """Update values with bounds. - - Update nodes inputs and outputs values with data types able to hold data ranges measured - and passed in nodes_bounds - - Args: - node_bounds_and_samples (dict): Dictionary with nodes as keys, holding dicts with a - 'min', 'max' and 'sample' keys. Those bounds will be taken as the data range to be - represented, per node. The sample allows to determine the data constructors to - prepare the GenericFunction nodes for table generation. - get_base_data_type_for_constant_data (Callable[ [Any], BaseDataType ], optional): This - is a callback function to convert data encountered during value updates to - BaseDataType. This allows to manage data coming from foreign frameworks without - specialising OPGraph. Defaults to get_base_data_type_for_python_constant_data. - get_constructor_for_constant_data (Callable[ ..., Callable ], optional): This is a - callback function to determine the type constructor of the data encountered while - updating the graph bounds. Defaults to get_constructor_for_python_constant_data. - """ - node: IntermediateNode - - for node in self.graph.nodes(): - current_node_bounds_and_samples = node_bounds_and_samples[node] - min_bound, max_bound, sample = ( - current_node_bounds_and_samples["min"], - current_node_bounds_and_samples["max"], - current_node_bounds_and_samples["sample"], - ) - - min_data_type = get_base_data_type_for_constant_data(min_bound) - max_data_type = get_base_data_type_for_constant_data(max_bound) - - # This is a sanity check - min_value_constructor = get_constructor_for_constant_data(min_bound) - max_value_constructor = get_constructor_for_constant_data(max_bound) - - assert_true( - max_value_constructor == min_value_constructor, - ( - f"Got two different type constructors for min and max bound: " - f"{min_value_constructor}, {max_value_constructor}" - ), - ) - - value_constructor = get_constructor_for_constant_data(sample) - - if not isinstance(node, Input): - for output_value in node.outputs: - if isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer): - output_value.dtype = make_integer_to_hold( - (min_bound, max_bound), force_signed=False - ) - else: - assert_true( - isinstance(min_data_type, Float) and isinstance(max_data_type, Float), - ( - "min_bound and max_bound have different common types, " - "this should never happen.\n" - f"min_bound: {min_data_type}, max_bound: {max_data_type}" - ), - ) - output_value.dtype = Float(64) - output_value.underlying_constructor = value_constructor - else: - # Currently variable inputs are only allowed to be integers - assert_true( - isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer), - ( - f"Inputs to a graph should be integers, got bounds that were float, \n" - f"min: {min_bound} ({type(min_bound)}), " - f"max: {max_bound} ({type(max_bound)})" - ), - ) - node.inputs[0].dtype = make_integer_to_hold( - (min_bound, max_bound), force_signed=False - ) - node.inputs[0].underlying_constructor = value_constructor - - node.outputs[0] = deepcopy(node.inputs[0]) - - successors = self.graph.successors(node) - for succ in successors: - edge_data = self.graph.get_edge_data(node, succ) - for edge in edge_data.values(): - input_idx, output_idx = edge["input_idx"], edge["output_idx"] - succ.inputs[input_idx] = deepcopy(node.outputs[output_idx]) - - def prune_nodes(self): - """Remove unreachable nodes from outputs.""" - - current_nodes = {node: None for node in self.get_ordered_outputs()} - useful_nodes: Dict[IntermediateNode, None] = {} - while current_nodes: - next_nodes: Dict[IntermediateNode, None] = {} - useful_nodes.update(current_nodes) - for node in current_nodes: - next_nodes.update({node: None for node in self.graph.predecessors(node)}) - current_nodes = next_nodes - - useless_nodes = [node for node in self.graph.nodes() if node not in useful_nodes] - self.graph.remove_nodes_from(useless_nodes) diff --git a/concrete/common/optimization/__init__.py b/concrete/common/optimization/__init__.py deleted file mode 100644 index f4180d6bb..000000000 --- a/concrete/common/optimization/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Module holding various optimization/simplification code.""" diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py deleted file mode 100644 index 9675e5623..000000000 --- a/concrete/common/optimization/topological.py +++ /dev/null @@ -1,594 +0,0 @@ -"""File holding topological optimization/simplification code.""" -from collections import defaultdict -from copy import deepcopy -from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple, cast - -import networkx as nx -from loguru import logger - -from ..compilation.artifacts import CompilationArtifacts -from ..data_types.floats import Float -from ..data_types.integers import Integer -from ..debugging import format_operation_graph -from ..debugging.custom_assert import assert_true -from ..operator_graph import OPGraph -from ..representation.intermediate import Constant, GenericFunction, Input, IntermediateNode -from ..values import TensorValue - - -def fuse_float_operations( - op_graph: OPGraph, - compilation_artifacts: Optional[CompilationArtifacts] = None, -): - """Find and fuse float domains into single Integer to Integer GenericFunction. - - Args: - op_graph (OPGraph): The OPGraph to simplify - compilation_artifacts (Optional[CompilationArtifacts]): The CompilationArtifacts of the - current compilation, this argument is optional as it's not required to execute float - fusing. - """ - - nx_graph = op_graph.graph - processed_terminal_nodes: Set[IntermediateNode] = set() - number_of_fuse = 0 - while True: - float_subgraph_search_result = find_float_subgraph_with_unique_terminal_node( - nx_graph, processed_terminal_nodes - ) - if float_subgraph_search_result is None: - break - - float_subgraph_start_nodes, terminal_node, subgraph_all_nodes = float_subgraph_search_result - processed_terminal_nodes.add(terminal_node) - - subgraph_conversion_result = convert_float_subgraph_to_fused_node( - op_graph, - float_subgraph_start_nodes, - terminal_node, - subgraph_all_nodes, - ) - - # Not a subgraph we can handle, continue - if subgraph_conversion_result is None: - continue - - fused_node, node_before_subgraph = subgraph_conversion_result - - nx_graph.add_node(fused_node) - - if terminal_node in op_graph.output_nodes.values(): - # Output value replace it - # As the graph changes recreate the output_node_to_idx dict - output_node_to_idx: Dict[IntermediateNode, List[int]] = { - out_node: [] for out_node in op_graph.output_nodes.values() - } - for output_idx, output_node in op_graph.output_nodes.items(): - output_node_to_idx[output_node].append(output_idx) - - for output_idx in output_node_to_idx.get(terminal_node, []): - op_graph.output_nodes[output_idx] = fused_node - - # Disconnect after terminal node and connect fused node instead - terminal_node_succ = list(nx_graph.successors(terminal_node)) - for succ in terminal_node_succ: - succ_edge_data = deepcopy(nx_graph.get_edge_data(terminal_node, succ)) - for edge_key, edge_data in succ_edge_data.items(): - nx_graph.remove_edge(terminal_node, succ, key=edge_key) - # fused_node is always a GenericFunction so output_idx == 0 always - new_edge_data = deepcopy(edge_data) - new_edge_data["output_idx"] = 0 - nx_graph.add_edge(fused_node, succ, key=edge_key, **new_edge_data) - - # Connect the node feeding the subgraph contained in fused_node - # node_before_subgraph has a single integer output currently so output_idx == 0 - nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0, output_idx=0) - - op_graph.prune_nodes() - if compilation_artifacts is not None: - compilation_artifacts.add_operation_graph( - f"after-float-fuse-{number_of_fuse}", op_graph - ) - - number_of_fuse += 1 - - -def convert_float_subgraph_to_fused_node( - op_graph: OPGraph, - float_subgraph_start_nodes: Dict[IntermediateNode, None], - terminal_node: IntermediateNode, - subgraph_all_nodes: Dict[IntermediateNode, None], -) -> Optional[Tuple[GenericFunction, IntermediateNode]]: - """Convert a float subgraph to an equivalent fused GenericFunction node. - - Args: - op_graph (OPGraph): The OPGraph the float subgraph is part of. - float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the float - subgraph in `op_graph`. - terminal_node (IntermediateNode): The node ending the float subgraph. - subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph. - - Returns: - Optional[Tuple[GenericFunction, IntermediateNode]]: None if the float subgraph - cannot be fused, otherwise returns a tuple containing the fused node and the node whose - output must be plugged as the input to the subgraph. - """ - - node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]] = defaultdict(list) - - subgraph_can_be_fused = subgraph_has_unique_variable_input( - float_subgraph_start_nodes, terminal_node, node_with_issues_for_fusing - ) - - if subgraph_can_be_fused: - # subgraph_values_allow_fusing can be called iff the subgraph has a unique variable input - subgraph_can_be_fused = subgraph_nodes_and_values_allow_fusing( - float_subgraph_start_nodes, subgraph_all_nodes, node_with_issues_for_fusing - ) - - # This test is separate from the previous one to only handle printing issues once - if not subgraph_can_be_fused: - float_subgraph = nx.MultiDiGraph(op_graph.graph.subgraph(subgraph_all_nodes)) - float_subgraph_as_op_graph = OPGraph.from_graph(float_subgraph, [], [terminal_node]) - - printable_graph = format_operation_graph( - float_subgraph_as_op_graph, - highlighted_nodes=node_with_issues_for_fusing, - ) - message = f"The following subgraph is not fusable:\n\n{printable_graph}" - logger.warning(message) - return None - - # Only one variable input node, find which node feeds its input - variable_input_nodes = [ - node for node in float_subgraph_start_nodes if not isinstance(node, Constant) - ] - assert_true(len(variable_input_nodes) == 1) - - current_subgraph_variable_input = variable_input_nodes[0] - assert_true(len(current_subgraph_variable_input.outputs) == 1) - new_input_value = deepcopy(current_subgraph_variable_input.outputs[0]) - - nx_graph = op_graph.graph - - nodes_after_input_set = { - node: None - for node in subgraph_all_nodes - if node in nx_graph.succ[current_subgraph_variable_input] - } - - # # Previous non-deterministic implementation : - # # For some reason creating a graph from a subgraph this way is not deterministic - # float_subgraph = nx.MultiDiGraph(nx_graph.subgraph(subgraph_all_nodes)) - - # Create a copy of the graph, remove nodes that are not in all the subgraph nodes in order to - # get a subgraph deterministically - float_subgraph = nx.MultiDiGraph(nx_graph) - nodes_to_remove = [node for node in float_subgraph.nodes() if node not in subgraph_all_nodes] - float_subgraph.remove_nodes_from(nodes_to_remove) - - new_subgraph_variable_input = Input(new_input_value, "float_subgraph_input", 0) - float_subgraph.add_node(new_subgraph_variable_input) - - for node_after_input in nodes_after_input_set: - # Connect the new input to our subgraph - edge_data_input_to_subgraph = deepcopy( - float_subgraph.get_edge_data( - current_subgraph_variable_input, - node_after_input, - ) - ) - for edge_key, edge_data in edge_data_input_to_subgraph.items(): - float_subgraph.remove_edge( - current_subgraph_variable_input, node_after_input, key=edge_key - ) - # new_subgraph_variable_input is always an Input so output_idx == 0 always - new_edge_data = deepcopy(edge_data) - new_edge_data["output_idx"] = 0 - float_subgraph.add_edge( - new_subgraph_variable_input, - node_after_input, - key=edge_key, - **new_edge_data, - ) - - float_op_subgraph = OPGraph.from_graph( - float_subgraph, - [new_subgraph_variable_input], - [terminal_node], - ) - - assert_true(len(terminal_node.outputs) == 1) - - # Create fused_node - fused_node = GenericFunction( - inputs=[new_subgraph_variable_input.inputs[0]], - arbitrary_func=lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate( - {0: x} - )[terminal_node], - output_value=terminal_node.outputs[0], - op_kind="TLU", - op_kwargs={ - "float_op_subgraph": float_op_subgraph, - "terminal_node": terminal_node, - }, - op_name="subgraph", - ) - - return ( - fused_node, - current_subgraph_variable_input, - ) - - -def is_single_int_output_node(node: IntermediateNode) -> bool: - """Check if a node has a single output and that output is an integer. - - Args: - node (IntermediateNode): the node to check. - - Returns: - bool: returns True if the node has a single integer output, False otherwise. - """ - return len(node.outputs) == 1 and isinstance(node.outputs[0].dtype, Integer) - - -def find_closest_single_int_output_nodes( - nx_graph: nx.MultiDiGraph, - start_nodes: List[IntermediateNode], - subgraph_all_nodes: Dict[IntermediateNode, None], -) -> Tuple[Dict[IntermediateNode, None], Dict[IntermediateNode, None]]: - """Find in nx_graph the closest upstream single integer output nodes to some start nodes. - - Args: - nx_graph (nx.MultiDiGraph): the networkx graph to search in. - start_nodes (List[IntermediateNode]): the nodes from which to start the search. - subgraph_all_nodes (Dict[IntermediateNode, None]): a set that will be updated with all the - nodes visited during the search. - - Returns: - Tuple[Dict[IntermediateNode, None], Dict[IntermediateNode, None]]: returns the dict used as - an ordered set containing the found single output nodes and the updated set of the - visited nodes during the search. - """ - - # Use dict as ordered set - current_nodes = {start_node: None for start_node in start_nodes} - closest_single_int_output_nodes: Dict[IntermediateNode, None] = {} - visited_nodes: Set[IntermediateNode] = set() - while current_nodes: - next_nodes: Dict[IntermediateNode, None] = {} - for node in current_nodes: - if node in visited_nodes: - continue - visited_nodes.add(node) - subgraph_all_nodes.update({node: None}) - predecessors = nx_graph.predecessors(node) - for pred in predecessors: - if is_single_int_output_node(pred): - # Limit of subgraph, record that and record the node as we won't visit it - closest_single_int_output_nodes.update({pred: None}) - subgraph_all_nodes.update({pred: None}) - else: - next_nodes.update({pred: None}) - current_nodes = next_nodes - - return closest_single_int_output_nodes, subgraph_all_nodes - - -def add_nodes_from_to( - nx_graph: nx.MultiDiGraph, - from_nodes: Iterable[IntermediateNode], - to_nodes: Dict[IntermediateNode, None], - subgraph_all_nodes: Dict[IntermediateNode, None], -) -> Dict[IntermediateNode, None]: - """Add nodes from from_nodes to to_nodes to the subgraph_all_nodes set. - - Args: - nx_graph (nx.MultiDiGraph): the graph to traverse. - from_nodes (Iterable[IntermediateNode]): the nodes from which we will add nodes to - subgraph_all_nodes. - to_nodes (Dict[IntermediateNode, None]): the nodes we should stop at. - subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph, will - be updated and returned. - - Returns: - Dict[IntermediateNode, None]: returns the updated subgraph_all_nodes. - """ - - # Add the end nodes we won't visit - subgraph_all_nodes.update(to_nodes) - - current_nodes = {from_node: None for from_node in from_nodes} - visited_nodes: Set[IntermediateNode] = set() - while current_nodes: - next_nodes: Dict[IntermediateNode, None] = {} - for node in current_nodes: - if node in visited_nodes: - continue - visited_nodes.add(node) - subgraph_all_nodes.update({node: None}) - predecessors = nx_graph.predecessors(node) - # Add nodes to explore next if they are not indicated as end nodes - next_nodes.update({pred: None for pred in predecessors if pred not in to_nodes}) - current_nodes = next_nodes - - return subgraph_all_nodes - - -def find_float_subgraph_with_unique_terminal_node( - nx_graph: nx.MultiDiGraph, - processed_terminal_nodes: Set[IntermediateNode], -) -> Optional[Tuple[Dict[IntermediateNode, None], IntermediateNode, Dict[IntermediateNode, None]]]: - """Find a subgraph of the graph with float computations. - - Args: - nx_graph (nx.MultiDiGraph): The networkx graph to search in. - processed_terminal_nodes (Dict[IntermediateNode, None]): The set of terminal nodes for which - subgraphs have already been searched, those will be skipped. - - Returns: - Optional[ - Tuple[Dict[IntermediateNode, None], IntermediateNode, Dict[IntermediateNode, None]]]: - None if there are no float subgraphs to process in `nx_graph`. Otherwise returns a - tuple containing the set of nodes beginning a float subgraph, the terminal node of - the subgraph and the set of all the nodes in the subgraph. - """ - - def is_float_to_single_int_node(node: IntermediateNode) -> bool: - return ( - any(isinstance(input_.dtype, Float) for input_ in node.inputs) - and len(node.outputs) == 1 - and isinstance(node.outputs[0].dtype, Integer) - ) - - float_subgraphs_terminal_nodes = ( - node - for node in nx_graph.nodes() - if is_float_to_single_int_node(node) and node not in processed_terminal_nodes - ) - - terminal_node: IntermediateNode - - try: - terminal_node = next(float_subgraphs_terminal_nodes) - except StopIteration: - return None - - # networkx does not implement lowest common ancestor search for multidigraph, but we only care - # about parent relationship here and not the meaning of edges, so we can convert our - # multidigraph to a digraph and use the lca search algorithm (if needed), we create the - # equivalent digraph here as it will avoid recreating it in a loop. Constant nodes could cause - # issues in our search so we remove them. - equivalent_digraph_without_constants = nx.DiGraph(nx_graph) - constant_graph_nodes = [ - constant_node - for constant_node in equivalent_digraph_without_constants.nodes() - if isinstance(constant_node, Constant) - ] - equivalent_digraph_without_constants.remove_nodes_from(constant_graph_nodes) - - # Use dict as ordered set - subgraph_all_nodes: Dict[IntermediateNode, None] = {} - - start_single_int_output_nodes_search_from = terminal_node - - while True: - float_subgraph_start_nodes, subgraph_all_nodes = find_closest_single_int_output_nodes( - nx_graph, - [start_single_int_output_nodes_search_from], - subgraph_all_nodes, - ) - - variable_start_nodes = [ - start_node - for start_node in float_subgraph_start_nodes - if not isinstance(start_node, Constant) - ] - - # We found a single input variable node - if len(variable_start_nodes) == 1: - break - - # Otherwise find a common ancestor as we need a single variable input node - # lca == lowest common ancestor - # lca search only works for node pairs in networkx, so we progressively find the ancestors - # setting the lca by default to one of the nodes we are searching the lca for - lca = variable_start_nodes.pop() - - while len(variable_start_nodes) > 0 and lca is not None: - node_to_find_lca = variable_start_nodes.pop() - lca = nx.algorithms.lowest_common_ancestors.lowest_common_ancestor( - equivalent_digraph_without_constants, lca, node_to_find_lca, default=None - ) - - # The subgraph cannot be fused as there is no way to find a common ancestor - if lca is None: - break - - # if lca is not None, add the nodes from the current start nodes to the lca to - # subgraph_all_nodes - subgraph_all_nodes = add_nodes_from_to( - nx_graph, float_subgraph_start_nodes, {lca: None}, subgraph_all_nodes - ) - - # if the lca is a valid starting node for fusing break - if is_single_int_output_node(lca): - # the lca is our new start node - float_subgraph_start_nodes = {lca: None} - break - - # otherwise push a little bit further the search (if there is a node just before that has an - # integer output e.g.) - start_single_int_output_nodes_search_from = lca - - return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes - - -def subgraph_nodes_and_values_allow_fusing( - float_subgraph_start_nodes: Dict[IntermediateNode, None], - subgraph_all_nodes: Dict[IntermediateNode, None], - node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]], -) -> bool: - """Check if a subgraph's values are compatible with fusing. - - A fused subgraph for example only works on an input tensor if the resulting GenericFunction - can be applied per cell, hence shuffling or tensor shape changes make fusing impossible. - - Args: - float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the float - subgraph. - subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph. - node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill - with potential nodes issues preventing fusing. - - Returns: - bool: True if all inputs and outputs of the nodes in the subgraph are compatible with fusing - i.e. outputs have the same shapes equal to the variable input. - """ - - node: IntermediateNode - - variable_input_nodes = [ - node for node in float_subgraph_start_nodes if not isinstance(node, Constant) - ] - - assert_true( - (num_variable_input_nodes := len(variable_input_nodes)) == 1, - f"{subgraph_nodes_and_values_allow_fusing.__name__} " - f"only works for subgraphs with 1 variable input node, got {num_variable_input_nodes}", - ) - - explicitely_non_fusable = [ - node - for node in subgraph_all_nodes - if isinstance(node, GenericFunction) and not node.op_attributes["fusable"] - ] - for node in explicitely_non_fusable: - node_with_issues_for_fusing[node].append( - "this node is explicitely marked by the package as non-fusable" - ) - if len(explicitely_non_fusable) > 0: - return False - - all_values_are_tensors = all( - all(isinstance(input_, TensorValue) for input_ in node.inputs) - and all(isinstance(output, TensorValue) for output in node.outputs) - for node in subgraph_all_nodes - ) - - if not all_values_are_tensors: - # This cannot be reached today as scalars are Tensors with shape == () (numpy convention) - return False # pragma: no cover - - variable_input_node = variable_input_nodes[0] - - # A cheap check is that the variable input node must have the biggest size, i.e. have the most - # elements, meaning all constants will broadcast to its shape. This is because the - # GenericFunction input and output must have the same shape so that it can be applied to each - # of the input tensor cells. - # There *may* be a way to manage the other case by simulating the broadcast of the smaller input - # array and then concatenating/stacking the results. This is not currently doable as we don't - # have a concatenate operator on the compiler side. - # TODO: #587 https://github.com/zama-ai/concrete-numpy-internal/issues/587 - - variable_input_node_output = cast(TensorValue, variable_input_node.outputs[0]) - variable_input_node_output_size, variable_input_node_output_shape = ( - variable_input_node_output.size, - variable_input_node_output.shape, - ) - - constant_nodes_with_bigger_size_than_variable_input = [ - constant_input_node - for constant_input_node in subgraph_all_nodes - if isinstance(constant_input_node, Constant) - and cast(TensorValue, constant_input_node.outputs[0]).size > variable_input_node_output_size - ] - - for bigger_constant_node in constant_nodes_with_bigger_size_than_variable_input: - bigger_constant_node_shape = cast(TensorValue, bigger_constant_node.outputs[0]).shape - node_with_issues_for_fusing[bigger_constant_node].append( - f"this constant node has a bigger shape {bigger_constant_node_shape} " - f"than the subgraph's input: {variable_input_node_output_shape}" - ) - - if len(constant_nodes_with_bigger_size_than_variable_input) > 0: - node_with_issues_for_fusing[variable_input_node].append( - f"input node with shape {variable_input_node_output_shape}" - ) - return False - - # Now that we know the variable input node has the biggest size we can check shapes are - # consistent throughout the subgraph: outputs of ir nodes that are not constant must be equal. - - non_constant_nodes = (node for node in subgraph_all_nodes if not isinstance(node, Constant)) - - nodes_with_different_output_shapes = { - node: [ - (output_idx, output.shape) - for output_idx, output in enumerate(node.outputs) - if isinstance(output, TensorValue) and output.shape != variable_input_node - ] - for node in non_constant_nodes - if any( - isinstance(output, TensorValue) and output.shape != variable_input_node_output_shape - for output in node.outputs - ) - } - - for node, node_shape_infos in nodes_with_different_output_shapes.items(): - shape_issue_details = "; ".join( - f"#{output_idx}, {output_shape}" for output_idx, output_shape in node_shape_infos - ) - node_with_issues_for_fusing[node].append( - f"output shapes: {shape_issue_details} are not the same as the subgraph's input: " - f"{variable_input_node_output_shape}" - ) - - all_nodes_have_same_shape_as_input = len(nodes_with_different_output_shapes) == 0 - - if not all_nodes_have_same_shape_as_input: - node_with_issues_for_fusing[variable_input_node].append( - f"input node with shape {variable_input_node_output_shape}" - ) - - # All non constant node outputs currently need to have the same shape - return all_nodes_have_same_shape_as_input - - -def subgraph_has_unique_variable_input( - float_subgraph_start_nodes: Dict[IntermediateNode, None], - terminal_node: IntermediateNode, - node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]], -) -> bool: - """Check that only one of the nodes starting the subgraph is variable. - - Args: - float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the subgraph. - terminal_node (IntermediateNode): The node ending the float subgraph. - node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill - with potential nodes issues preventing fusing. - - Returns: - bool: True if only one of the nodes is not an Constant - """ - - variable_inputs_list = [ - node for node in float_subgraph_start_nodes if not isinstance(node, Constant) - ] - variable_inputs_num = len(variable_inputs_list) - - # Only one input to the subgraph where computations are done in floats can be variable, this - # is the only case we can manage with GenericFunction fusing - has_unique_variable_input = variable_inputs_num == 1 - - if not has_unique_variable_input: - for node in variable_inputs_list: - node_with_issues_for_fusing[node].append( - f"one of {variable_inputs_num} variable inputs (can only have 1 for fusing)" - ) - node_with_issues_for_fusing[terminal_node].append( - f"cannot fuse here as the subgraph has {variable_inputs_num} variable inputs" - ) - - return has_unique_variable_input diff --git a/concrete/common/representation/__init__.py b/concrete/common/representation/__init__.py deleted file mode 100644 index 5a86259b7..000000000 --- a/concrete/common/representation/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Representation module to represent source programs.""" -from . import intermediate diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py deleted file mode 100644 index 53048dff2..000000000 --- a/concrete/common/representation/intermediate.py +++ /dev/null @@ -1,650 +0,0 @@ -"""File containing code to represent source programs operations.""" - -from abc import ABC, abstractmethod -from collections import deque -from copy import deepcopy -from enum import Enum, unique -from math import floor -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast - -import numpy as np -import torch -from loguru import logger - -from ..data_types.base import BaseDataType -from ..data_types.dtypes_helpers import ( - get_base_value_for_python_constant_data, - mix_values_determine_holding_dtype, -) -from ..data_types.integers import Integer -from ..debugging.custom_assert import assert_true -from ..helpers import indexing_helpers -from ..helpers.formatting_helpers import format_constant -from ..helpers.python_helpers import catch, update_and_return_dict -from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue - -IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func" - -ALL_IR_NODES: Set[Type] = set() - - -class IntermediateNode(ABC): - """Abstract Base Class to derive from to represent source program operations.""" - - inputs: List[BaseValue] - outputs: List[BaseValue] - _n_in: int # _n_in indicates how many inputs are required to evaluate the IntermediateNode - - def __init__( - self, - inputs: Iterable[BaseValue], - **_kwargs, # This is to be able to feed arbitrary arguments to IntermediateNodes - ) -> None: - self.inputs = list(inputs) - assert_true(all(isinstance(x, BaseValue) for x in self.inputs)) - - # Register all IR nodes - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - ALL_IR_NODES.add(cls) - - def _init_binary( - self, - inputs: Iterable[BaseValue], - mix_values_func: Callable[..., BaseValue] = mix_values_determine_holding_dtype, - **_kwargs, # Required to conform to __init__ typing - ) -> None: - """__init__ for a binary operation, ie two inputs.""" - IntermediateNode.__init__(self, inputs) - - assert_true(len(self.inputs) == 2) - - self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])] - - def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str: - """Get the formatted node (used in formatting operation graphs). - - Args: - predecessors (List[str]): predecessor names to this node - _maximum_constant_length (int): desired maximum constant length - - Returns: - str: the formatted node - """ - - return f"{self.__class__.__name__.lower()}({', '.join(predecessors)})" - - @abstractmethod - def text_for_drawing(self) -> str: - """Get the label of the node (used in drawing operation graphs). - - Returns: - str: the label of the node - """ - - @abstractmethod - def evaluate(self, inputs: Dict[int, Any]) -> Any: - """Simulate what the represented computation would output for the given inputs. - - Args: - inputs (Dict[int, Any]): Dict containing the inputs for the evaluation - - Returns: - Any: the result of the computation - """ - - @classmethod - def n_in(cls) -> int: - """Return how many inputs the node has. - - Returns: - int: The number of inputs of the node. - """ - return cls._n_in - - @classmethod - def requires_mix_values_func(cls) -> bool: - """Determine whether the Class requires a mix_values_func to be built. - - Returns: - bool: True if __init__ expects a mix_values_func argument. - """ - return cls.n_in() > 1 - - -class Add(IntermediateNode): - """Addition between two values.""" - - _n_in: int = 2 - - __init__ = IntermediateNode._init_binary - - def text_for_drawing(self) -> str: - return "+" - - def evaluate(self, inputs: Dict[int, Any]) -> Any: - return inputs[0] + inputs[1] - - -class Sub(IntermediateNode): - """Subtraction between two values.""" - - _n_in: int = 2 - - __init__ = IntermediateNode._init_binary - - def text_for_drawing(self) -> str: - return "-" - - def evaluate(self, inputs: Dict[int, Any]) -> Any: - return inputs[0] - inputs[1] - - -class Mul(IntermediateNode): - """Multiplication between two values.""" - - _n_in: int = 2 - - __init__ = IntermediateNode._init_binary - - def text_for_drawing(self) -> str: - return "*" - - def evaluate(self, inputs: Dict[int, Any]) -> Any: - return inputs[0] * inputs[1] - - -class Input(IntermediateNode): - """Node representing an input of the program.""" - - input_name: str - program_input_idx: int - _n_in: int = 1 - - def __init__( - self, - input_value: BaseValue, - input_name: str, - program_input_idx: int, - ) -> None: - super().__init__((input_value,)) - assert_true(len(self.inputs) == 1) - self.input_name = input_name - self.program_input_idx = program_input_idx - self.outputs = [deepcopy(self.inputs[0])] - - def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str: - assert_true(len(predecessors) == 0) - return self.input_name - - def text_for_drawing(self) -> str: - return self.input_name - - def evaluate(self, inputs: Dict[int, Any]) -> Any: - return inputs[0] - - -class Constant(IntermediateNode): - """Node representing a constant of the program.""" - - _constant_data: Any - _n_in: int = 0 - - def __init__( - self, - constant_data: Any, - get_base_value_for_data_func: Callable[ - [Any], Callable[..., BaseValue] - ] = get_base_value_for_python_constant_data, - ) -> None: - super().__init__([]) - - base_value_class = get_base_value_for_data_func(constant_data) - - self._constant_data = constant_data - self.outputs = [base_value_class(is_encrypted=False)] - - def text_for_formatting(self, predecessors: List[str], maximum_constant_length: int) -> str: - assert_true(len(predecessors) == 0) - return format_constant(self.constant_data, maximum_constant_length) - - def text_for_drawing(self) -> str: - return format_constant(self.constant_data) - - def evaluate(self, inputs: Dict[int, Any]) -> Any: - return self.constant_data - - @property - def constant_data(self) -> Any: - """Return the constant_data stored in the Constant node. - - Returns: - Any: The constant data that was stored. - """ - return self._constant_data - - -class Conv2D(IntermediateNode): - """Return the node representing a 2d-convolution.""" - - def __init__( - self, - inputs: Iterable[BaseValue], - output_dtype: BaseDataType, - pads: Union[List[int], Tuple[int, int, int, int]], - strides: Union[List[int], Tuple[int, int]], - dilations: Union[List[int], Tuple[int, int]], - ) -> None: - - # TODO: remove this when padding is supported (#427) - assert all(pad == 0 for pad in pads), "conv2d doesn't support padding yet" - - super().__init__(inputs) - self.pads = pads - self.strides = strides - self.dilations = dilations - - self._n_in = len(self.inputs) - assert_true(len(self.inputs) == 2 or len(self.inputs) == 3) - - assert_true( - all( - isinstance(input_value, TensorValue) and input_value.ndim == 4 - for input_value in self.inputs[:2] - ), - f"Conv2D only supports input and weight tensors of 4 dimensions" - f"({TensorValue.__name__} with ndim == 4)", - ) - bias = cast(TensorValue, self.inputs[2]) if len(self.inputs) == 3 else None - if bias is not None: - assert_true( - isinstance(bias, TensorValue) and bias.ndim == 1, - f"Conv2D only supports bias 1 dimension ({TensorValue.__name__} with ndim == 1)", - ) - - x = cast(TensorValue, self.inputs[0]) - weight = cast(TensorValue, self.inputs[1]) - - # Compute output shape - input_n, _, input_h, input_w = x.shape - weight_f, _, weight_h, weight_w = weight.shape - pads_h = pads[0] + pads[2] - pads_w = pads[1] + pads[3] - output_h = floor((input_h + pads_h - dilations[0] * (weight_h - 1) - 1) / strides[0]) + 1 - output_w = floor((input_w + pads_w - dilations[1] * (weight_w - 1) - 1) / strides[1]) + 1 - output_shape = (input_n, weight_f, output_h, output_w) - - output_value = EncryptedTensor(dtype=output_dtype, shape=output_shape) - self.outputs = [output_value] - - def text_for_drawing(self) -> str: - return "conv2d" - - def evaluate(self, inputs: Dict[int, Any]) -> Any: - - assert_true( - len(inputs) == self._n_in, f"expected {self.n_in} inputs, but got {len(inputs)}" - ) - x, weight = inputs[0], inputs[1] - bias = inputs[2] if len(inputs) == 3 else np.zeros(weight.shape[0]) - - return self.evaluate_conv2d(x, weight, bias, self.pads, self.strides, self.dilations) - - @staticmethod - def evaluate_conv2d( - x: np.ndarray, - weight: np.ndarray, - bias: np.ndarray, - # TODO: use padding when supported (#427) - _: Union[Tuple[int, int, int, int], List[int]], - strides: Union[Tuple[int, int], List[int]], - dilations: Union[Tuple[int, int], List[int]], - ): - """Evaluate 2D convolution. - - Args: - x (np.ndarray): Input of shape (NxCxHxW) - weight (np.ndarray): Weight (kernel) of shape (FxCxHxW) - bias (np.ndarray): Bias vector of size (F) - pads (Union[Tuple[int, int, int, int], List[int]]): Padding over each - axis (H_beg, W_beg, H_end, W_end) - strides (Union[Tuple[int, int], List[int]]): Stride over each - axis (height and width) - dilations (Union[Tuple[int, int], List[int]]): Dilation over each - axis (height and width) - - Returns: - np.ndarray: Result of the convolution of shape (NxCxHxW) - """ - # pylint: disable=no-member - return torch.conv2d( - torch.tensor(x, dtype=torch.long), - torch.tensor(weight, dtype=torch.long), - torch.tensor(bias, dtype=torch.long), - stride=strides, - dilation=dilations, - ).numpy() - # pylint: enable=no-member - - -class IndexConstant(IntermediateNode): - """Node representing a constant indexing in the program. - - What we mean by constant indexing is that the index part of the operation is a constant. - Here are some examples: `x[2]`, `x[0, 1]`, `y[:, 0]`, `y[3:, :5]` - - The opposite is to have dynamic indexing, which this node does not support. - Some examples of dynamic indexing are: `x[y]`, `x[y, z]`, `x[:, y]` - """ - - _n_in: int = 1 - - index: Tuple[Union[int, slice], ...] - - def __init__( - self, - input_: BaseValue, - index: Union[int, slice, Tuple[Union[int, slice], ...]], - ) -> None: - super().__init__((input_,)) - - if not isinstance(self.inputs[0], TensorValue) or self.inputs[0].is_scalar: - raise TypeError(f"Only tensors can be indexed but you tried to index {self.inputs[0]}") - - self.index = indexing_helpers.validate_index(index) - - output_dtype = self.inputs[0].dtype - output_shape = indexing_helpers.determine_output_shape(self.inputs[0].shape, self.index) - - self.outputs = [ - EncryptedTensor(output_dtype, output_shape) - if self.inputs[0].is_encrypted - else ClearTensor(output_dtype, output_shape) - ] - - def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str: - assert_true(len(predecessors) == 1) - elements = [indexing_helpers.format_indexing_element(element) for element in self.index] - index = ", ".join(elements) - return f"{predecessors[0]}[{index}]" - - def text_for_drawing(self) -> str: - return self.text_for_formatting(["value"], 0) # 0 is unused - - def evaluate(self, inputs: Dict[int, Any]) -> Any: - return inputs[0][self.index] - - -def flood_replace_none_values(table: list): - """Use a flooding algorithm to replace None values. - - Args: - table (list): the list in which there are None values that need to be replaced by copies of - the closest non None data from the list. - """ - assert_true(any(value is not None for value in table)) - - not_none_values_idx = deque(idx for idx, value in enumerate(table) if value is not None) - while not_none_values_idx: - current_idx = not_none_values_idx.popleft() - current_value = table[current_idx] - previous_idx = current_idx - 1 - next_idx = current_idx + 1 - if previous_idx >= 0 and table[previous_idx] is None: - table[previous_idx] = deepcopy(current_value) - not_none_values_idx.append(previous_idx) - if next_idx < len(table) and table[next_idx] is None: - table[next_idx] = deepcopy(current_value) - not_none_values_idx.append(next_idx) - - assert_true(all(value is not None for value in table)) - - -@unique -class GenericFunctionKind(str, Enum): - """Enum to validate GenericFunction op_kind.""" - - TLU = "TLU" - MEMORY = "Memory" - - -class GenericFunction(IntermediateNode): - """Node representing an arbitrary function with a single output, e.g. sin(x).""" - - # The arbitrary_func is not optional but mypy has a long standing bug and is not able to - # understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623 - # arbitrary_func can take more than one argument but during evaluation the input variable will - # be the first argument passed to it. You can add other constant arguments needed for the proper - # execution of the function through op_args and op_kwargs. - arbitrary_func: Optional[Callable] - op_kind: GenericFunctionKind - op_name: str - op_args: Tuple[Any, ...] - op_kwargs: Dict[str, Any] - op_attributes: Dict[str, Any] - _n_in: int - - # TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/798 have a proper - # attribute system - DEFAULT_OP_ATTRIBUTES: Dict[str, Any] = {"fusable": True} - - KWARGS_IGNORED_IN_FORMATTING: Set[str] = { - "float_op_subgraph", - "terminal_node", - } - - def __init__( - self, - inputs: Iterable[BaseValue], - arbitrary_func: Callable, - output_value: BaseValue, - op_kind: Union[str, GenericFunctionKind], - op_name: Optional[str] = None, - op_args: Optional[Tuple[Any, ...]] = None, - op_kwargs: Optional[Dict[str, Any]] = None, - op_attributes: Optional[Dict[str, Any]] = None, - ) -> None: - super().__init__([deepcopy(i) for i in inputs]) - self._n_in = len(self.inputs) - self.arbitrary_func = arbitrary_func - self.op_kind = GenericFunctionKind(op_kind) - self.op_args = op_args if op_args is not None else () - self.op_kwargs = op_kwargs if op_kwargs is not None else {} - self.op_attributes = deepcopy(self.DEFAULT_OP_ATTRIBUTES) - if op_attributes is not None: - self.op_attributes.update(op_attributes) - - self.outputs = [output_value] - - self.op_name = op_name if op_name is not None else self.__class__.__name__ - - def text_for_formatting(self, predecessors: List[str], maximum_constant_length: int) -> str: - if self.op_name == "concat": - all_args = ["(" + ", ".join(predecessors) + ")"] - else: - all_args = deepcopy(predecessors) - - all_args.extend(format_constant(value, maximum_constant_length) for value in self.op_args) - all_args.extend( - f"{name}={format_constant(value, maximum_constant_length)}" - for name, value in self.op_kwargs.items() - if name not in GenericFunction.KWARGS_IGNORED_IN_FORMATTING - ) - - return f"{self.op_name}({', '.join(all_args)})" - - def text_for_drawing(self) -> str: - return self.op_name - - def evaluate(self, inputs: Dict[int, Any]) -> Any: - # This is the continuation of the mypy bug workaround - assert self.arbitrary_func is not None - ordered_inputs = [inputs[idx] for idx in range(len(inputs))] - if self.op_name == "concat": - return self.arbitrary_func(tuple(ordered_inputs), *self.op_args, **self.op_kwargs) - return self.arbitrary_func(*ordered_inputs, *self.op_args, **self.op_kwargs) - - def get_table(self, ordered_preds: List[IntermediateNode]) -> List[Any]: - """Get the table for the current input value of this GenericFunction. - - This function only works if the GenericFunction variable input value is an Integer. - This function only works if there is a single variable input node among ordered_preds. - - Args: - ordered_preds (List[IntermediateNode]): List of predecessors of the node. This list must - contain a single non constant node and any number of Constant nodes. - - Returns: - List[Any]: The table. - """ - - variable_input_indices = [ - idx for idx, pred in enumerate(ordered_preds) if not isinstance(pred, Constant) - ] - - assert_true( - (non_constant_pred_count := len(variable_input_indices)) == 1, - f"Can only have 1 non constant predecessor in {self.get_table.__name__}, " - f"got {non_constant_pred_count}", - ) - - variable_input_idx = variable_input_indices[0] - variable_input_dtype = self.inputs[variable_input_idx].dtype - # Check the input is an integer to be able to build a table - assert_true( - isinstance(variable_input_dtype, Integer), - f"{self.get_table.__name__} only works for an unsigned Integer input", - ) - variable_input_dtype = cast(Integer, variable_input_dtype) - - input_value_constructor = self.inputs[variable_input_idx].underlying_constructor - if input_value_constructor is None: - logger.info( - f"{self.__class__.__name__} input data type constructor was None, defaulting to int" - ) - input_value_constructor = int - - min_input_range = variable_input_dtype.min_value() - max_input_range = variable_input_dtype.max_value() + 1 - - template_input_dict = { - idx: node.evaluate({}) if isinstance(node, Constant) else None - for idx, node in enumerate(ordered_preds) - } - - table = [ - catch( - self.evaluate, - update_and_return_dict( - template_input_dict, {variable_input_idx: input_value_constructor(input_value)} - ), - ) - for input_value in range(min_input_range, max_input_range) - ] - - flood_replace_none_values(table) - - return table - - -def default_dot_evaluation_function(lhs: Any, rhs: Any) -> Any: - """Return the default python dot implementation for 1D iterable arrays. - - Args: - lhs (Any): lhs vector of the dot. - rhs (Any): rhs vector of the dot. - - Returns: - Any: the result of the dot operation. - """ - return sum(lhs * rhs for lhs, rhs in zip(lhs, rhs)) - - -class Dot(IntermediateNode): - """Return the node representing a dot product.""" - - _n_in: int = 2 - # Optional, same issue as in GenericFunction for mypy - evaluation_function: Optional[Callable[[Any, Any], Any]] - # Allows to use specialized implementations from e.g. numpy - - def __init__( - self, - inputs: Iterable[BaseValue], - output_dtype: BaseDataType, - delegate_evaluation_function: Optional[ - Callable[[Any, Any], Any] - ] = default_dot_evaluation_function, - ) -> None: - super().__init__(inputs) - assert_true(len(self.inputs) == 2) - - assert_true( - all( - isinstance(input_value, TensorValue) and input_value.ndim <= 1 - for input_value in self.inputs - ), - f"Dot only supports two scalars or vectors ({TensorValue.__name__} with ndim up to 1)", - ) - - lhs = cast(TensorValue, self.inputs[0]) - rhs = cast(TensorValue, self.inputs[1]) - - if lhs.ndim == 1 and rhs.ndim == 1: - assert_true( - lhs.shape[0] == rhs.shape[0], - f"Dot between vectors of shapes {lhs.shape} and {rhs.shape} is not supported", - ) - - output_shape: Tuple[int, ...] - if (lhs.ndim == 1 and rhs.ndim == 1) or (lhs.ndim == 0 and rhs.ndim == 0): - # numpy.dot(x, y) where x and y are both vectors or both scalars - output_shape = () - elif lhs.ndim == 1: - # numpy.dot(x, y) where x is a vector and y is a scalar - output_shape = lhs.shape - else: - # numpy.dot(x, y) where x is a scalar and y is a vector - output_shape = rhs.shape - - output_value = EncryptedTensor if (lhs.is_encrypted or rhs.is_encrypted) else ClearTensor - - self.outputs = [output_value(output_dtype, output_shape)] - self.evaluation_function = delegate_evaluation_function - - def text_for_drawing(self) -> str: - return "dot" - - def evaluate(self, inputs: Dict[int, Any]) -> Any: - # This is the continuation of the mypy bug workaround - assert self.evaluation_function is not None - return self.evaluation_function(inputs[0], inputs[1]) - - -class MatMul(IntermediateNode): - """Return the node representing a matrix multiplication.""" - - _n_in: int = 2 - - def __init__( - self, - inputs: Iterable[BaseValue], - output_dtype: BaseDataType, - output_shape: Tuple[int, ...], - ) -> None: - super().__init__(inputs) - assert_true(len(self.inputs) == 2) - - output_value = ( - EncryptedTensor(dtype=output_dtype, shape=output_shape) - if (self.inputs[0].is_encrypted or self.inputs[1].is_encrypted) - else ClearTensor(dtype=output_dtype, shape=output_shape) - ) - - self.outputs = [output_value] - - def text_for_drawing(self) -> str: - return "matmul" - - def evaluate(self, inputs: Dict[int, Any]) -> Any: - return inputs[0] @ inputs[1] diff --git a/concrete/common/tracing/__init__.py b/concrete/common/tracing/__init__.py deleted file mode 100644 index e40ece012..000000000 --- a/concrete/common/tracing/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Module for basic tracing facilities.""" -from .base_tracer import BaseTracer -from .tracing_helpers import ( - create_graph_from_output_tracers, - make_input_tracer, - make_input_tracers, - prepare_function_parameters, -) diff --git a/concrete/common/tracing/base_tracer.py b/concrete/common/tracing/base_tracer.py deleted file mode 100644 index 798f9150c..000000000 --- a/concrete/common/tracing/base_tracer.py +++ /dev/null @@ -1,462 +0,0 @@ -"""This file holds the code that can be shared between tracers.""" - -from abc import ABC, abstractmethod -from copy import deepcopy -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union, cast - -from ..data_types import Float -from ..data_types.base import BaseDataType -from ..debugging.custom_assert import assert_true -from ..representation.intermediate import ( - IR_MIX_VALUES_FUNC_ARG_NAME, - Add, - Constant, - GenericFunction, - IndexConstant, - IntermediateNode, - Mul, - Sub, -) -from ..values import BaseValue, TensorValue - - -class BaseTracer(ABC): - """Base class for implementing tracers.""" - - # this variable changes the behavior of __eq__ so that it can be traced but still allows to hash - # BaseTracers when not tracing. - _is_tracing: bool = False - - inputs: List["BaseTracer"] - traced_computation: IntermediateNode - output_idx: int - output: BaseValue - _mix_values_func: Callable[..., BaseValue] - - def __init__( - self, - inputs: Iterable["BaseTracer"], - traced_computation: IntermediateNode, - output_idx: int, - ) -> None: - self.inputs = list(inputs) - self.traced_computation = traced_computation - self.output_idx = output_idx - self.output = traced_computation.outputs[output_idx] - - @property - def shape(self) -> Tuple[int, ...]: - """Get the shape of the output of the tracer. - - Returns: - Tuple[int, ...]: the shape of the output - """ - - if isinstance(self.output, TensorValue): - return self.output.shape - - raise AttributeError( - f"'{self.__class__.__name__}' object " - f"with '{self.output}' output " - f"has no attribute 'shape'" - ) # pragma: no cover - - # this error cannot be covered because we only have TensorValue for now - - @abstractmethod - def _supports_other_operand(self, other: Any) -> bool: - """Check if the current class supports tracing with the other operand. - - Args: - other (Any): the operand to check compatibility with. - - Returns: - bool: True if the tracer can manage operations with the other operand. - """ - return isinstance(other, self.__class__) - - @abstractmethod - def _make_const_input_tracer(self, constant_data: Any) -> "BaseTracer": - """Create a tracer for a constant input. - - Args: - constant_data (Any): The constant to store. - - Returns: - BaseTracer: The BaseTracer for that constant. - """ - - @classmethod - def set_is_tracing(cls, is_tracing: bool) -> None: - """Set whether we are in a tracing context to change __eq__ behavior. - - Args: - is_tracing (bool): boolean to use to set whether we are tracing - """ - cls._is_tracing = is_tracing - - @classmethod - def _get_mix_values_func(cls): - return cls._mix_values_func - - def _sanitize(self, inp) -> "BaseTracer": - if not isinstance(inp, BaseTracer) and not ( - isinstance(inp, Tuple) # type: ignore - and all(isinstance(item, BaseTracer) for item in inp) # type: ignore - ): - return self._make_const_input_tracer(inp) - return inp - - def instantiate_output_tracers( - self, - inputs: Iterable[Union["BaseTracer", Any]], - computation_to_trace: Type[IntermediateNode], - ) -> Tuple["BaseTracer", ...]: - """Instantiate all output BaseTracer for a given computation. - - Args: - inputs (Iterable[Union[BaseTracer, Any]]): Previous BaseTracer or data used as inputs - for a new node. - computation_to_trace (Type[IntermediateNode]): The IntermediateNode class - to instantiate for the computation being traced - - Returns: - Tuple[BaseTracer, ...]: A tuple containing an BaseTracer per output function - """ - - # For inputs which are actually constant, first convert into a tracer - sanitized_inputs = [self._sanitize(inp) for inp in inputs] - - additional_parameters = ( - {IR_MIX_VALUES_FUNC_ARG_NAME: self._get_mix_values_func()} - if computation_to_trace.requires_mix_values_func() - else {} - ) - - traced_computation = computation_to_trace( - (x.output for x in sanitized_inputs), - **additional_parameters, - ) - - output_tracers = tuple( - self.__class__(sanitized_inputs, traced_computation, output_idx) - for output_idx in range(len(traced_computation.outputs)) - ) - - return output_tracers - - def _helper_for_unary_functions(self, op_lambda: Callable, op_name: str) -> "BaseTracer": - """Trace a unary operator which maintains the shape, which will thus be replaced by a TLU. - - Returns: - BaseTracer: The output NPTracer containing the traced function - """ - first_arg_output = self.output - assert_true(isinstance(first_arg_output, TensorValue)) - first_arg_output = cast(TensorValue, first_arg_output) - - out_dtype = first_arg_output.dtype - out_shape = first_arg_output.shape - - generic_function_output_value = TensorValue( - out_dtype, - first_arg_output.is_encrypted, - out_shape, - ) - - traced_computation = GenericFunction( - inputs=[first_arg_output], - arbitrary_func=op_lambda, - output_value=generic_function_output_value, - op_kind="TLU", - op_name=f"{op_name}", - ) - output_tracer = self.__class__( - [self], - traced_computation=traced_computation, - output_idx=0, - ) - return output_tracer - - def _helper_for_binary_functions_with_one_cst_input( - self, - lhs: Union["BaseTracer", Any], - rhs: Union["BaseTracer", Any], - op_lambda: Callable, - op_name: str, - output_dtype: Optional[BaseDataType] = None, - ) -> "BaseTracer": - """Trace a binary operator which maintains the shape, when one input is a constant. - - This function is helpful to convert an operation with two inputs, one of which being a - constant, into a TLU, while maintaining the constant somewhere in the graph, eg to simplify - debugging. - - Returns: - BaseTracer: The output NPTracer containing the traced function - """ - if isinstance(lhs, BaseTracer): - if not self._supports_other_operand(rhs): - return NotImplemented - elif isinstance(rhs, BaseTracer): - if not self._supports_other_operand(lhs): - return NotImplemented - - sanitized_inputs = [self._sanitize(inp) for inp in [lhs, rhs]] - - # One of the inputs has to be constant - if not ( - isinstance(sanitized_inputs[0].traced_computation, Constant) - or isinstance(sanitized_inputs[1].traced_computation, Constant) - ): - raise NotImplementedError(f"Can't manage binary operator {op_name}") - - sanitized_input_values = [san_input.output for san_input in sanitized_inputs] - output_value = self._get_mix_values_func()(*sanitized_input_values) - if output_dtype is not None: - output_value.dtype = deepcopy(output_dtype) - - traced_computation = GenericFunction( - inputs=sanitized_input_values, - arbitrary_func=op_lambda, - output_value=output_value, - op_kind="TLU", - op_name=op_name, - ) - - result_tracer = self.__class__(sanitized_inputs, traced_computation, 0) - - return result_tracer - - def __hash__(self) -> int: - return id(self) - - def __add__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": - if not self._supports_other_operand(other): - return NotImplemented - - result_tracer = self.instantiate_output_tracers( - [self, other], - Add, - ) - - assert_true(len(result_tracer) == 1) - return result_tracer[0] - - # With that is that x + 1 and 1 + x have the same graph. If we want to keep - # the order, we need to do as in __rsub__, ie mostly a copy of __sub__ + - # some changes - __radd__ = __add__ - - def __neg__(self) -> "BaseTracer": - return 0 - self - - def __pos__(self) -> "BaseTracer": - # Remark that we don't want to return 'self' since we want the result to be a copy, ie not - # a reference to the same object - return 0 + self - - def _lshift(self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]) -> "BaseTracer": - return self._helper_for_binary_functions_with_one_cst_input( - lhs, rhs, lambda x, y: x << y, "lshift" - ) - - def __lshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": - # x << shift - return self._lshift(self, other) - - def __rlshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": - # cst << x - return self._lshift(other, self) - - def _rshift(self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]) -> "BaseTracer": - return self._helper_for_binary_functions_with_one_cst_input( - lhs, rhs, lambda x, y: x >> y, "rshift" - ) - - def __rshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": - # x >> shift - return self._rshift(self, other) - - def __rrshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": - # cst >> x - return self._rshift(other, self) - - def __gt__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": - # x > cst - return self._helper_for_binary_functions_with_one_cst_input( - self, other, lambda x, y: x > y, "gt" - ) - - def __ge__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": - # x >= cst - return self._helper_for_binary_functions_with_one_cst_input( - self, other, lambda x, y: x >= y, "ge" - ) - - def __lt__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": - # x < cst - return self._helper_for_binary_functions_with_one_cst_input( - self, other, lambda x, y: x < y, "lt" - ) - - def __le__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": - # x <= cst - return self._helper_for_binary_functions_with_one_cst_input( - self, other, lambda x, y: x <= y, "le" - ) - - def __eq__(self, other: Union["BaseTracer", Any]): - # x == cst - # Return the tracer if we are tracing, else return the result of the default __eq__ function - # allows to have hash capabilities outside of tracing - return ( - self._helper_for_binary_functions_with_one_cst_input( - self, other, lambda x, y: x == y, "eq" - ) - if self._is_tracing - else self is other - ) - - def __ne__(self, other: Union["BaseTracer", Any]): - # x != cst - return self._helper_for_binary_functions_with_one_cst_input( - self, other, lambda x, y: x != y, "ne" - ) - - def __pow__(self, other: Union["BaseTracer", Any]): - # x ** cst - return self._helper_for_binary_functions_with_one_cst_input( - self, other, lambda x, y: x ** y, "pow" - ) - - def __rpow__(self, other: Union["BaseTracer", Any]): - # cst ** x - return self._helper_for_binary_functions_with_one_cst_input( - other, self, lambda x, y: x ** y, "pow" - ) - - def __mod__(self, other: Union["BaseTracer", Any]): - # x % cst - return self._helper_for_binary_functions_with_one_cst_input( - self, other, lambda x, y: x % y, "mod" - ) - - def __rmod__(self, other: Union["BaseTracer", Any]): - # cst % x - return self._helper_for_binary_functions_with_one_cst_input( - other, self, lambda x, y: x % y, "mod" - ) - - def __and__(self, other: Union["BaseTracer", Any]): - # x & cst - return self._helper_for_binary_functions_with_one_cst_input( - self, other, lambda x, y: x & y, "and" - ) - - def __rand__(self, other: Union["BaseTracer", Any]): - # cst & x - return self._helper_for_binary_functions_with_one_cst_input( - other, self, lambda x, y: x & y, "and" - ) - - def __or__(self, other: Union["BaseTracer", Any]): - # x | cst - return self._helper_for_binary_functions_with_one_cst_input( - self, other, lambda x, y: x | y, "or" - ) - - def __ror__(self, other: Union["BaseTracer", Any]): - # cst | x - return self._helper_for_binary_functions_with_one_cst_input( - other, self, lambda x, y: x | y, "or" - ) - - def __xor__(self, other: Union["BaseTracer", Any]): - # x ^ cst - return self._helper_for_binary_functions_with_one_cst_input( - self, other, lambda x, y: x ^ y, "xor" - ) - - def __rxor__(self, other: Union["BaseTracer", Any]): - # cst ^ x - return self._helper_for_binary_functions_with_one_cst_input( - other, self, lambda x, y: x ^ y, "xor" - ) - - def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": - if not self._supports_other_operand(other): - return NotImplemented - - result_tracer = self.instantiate_output_tracers( - [self, other], - Sub, - ) - - assert_true(len(result_tracer) == 1) - return result_tracer[0] - - def __rsub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": - if not self._supports_other_operand(other): - return NotImplemented - - result_tracer = self.instantiate_output_tracers( - [other, self], - Sub, - ) - - assert_true(len(result_tracer) == 1) - return result_tracer[0] - - def __mul__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": - if not self._supports_other_operand(other): - return NotImplemented - - result_tracer = self.instantiate_output_tracers( - [self, other], - Mul, - ) - - assert_true(len(result_tracer) == 1) - return result_tracer[0] - - # With that is that x * 3 and 3 * x have the same graph. If we want to keep - # the order, we need to do as in __rmul__, ie mostly a copy of __mul__ + - # some changes - __rmul__ = __mul__ - - def __abs__(self): - return self._helper_for_unary_functions(lambda x: x.__abs__(), "__abs__") - - def __invert__(self): - return self._helper_for_unary_functions(lambda x: x.__invert__(), "__invert__") - - def __getitem__(self, item): - traced_computation = IndexConstant(self.output, item) - return self.__class__([self], traced_computation, 0) - - def _truediv( - self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any] - ) -> "BaseTracer": - return self._helper_for_binary_functions_with_one_cst_input( - lhs, rhs, lambda x, y: x / y, "truediv", Float(64) - ) - - def __truediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": - return self._truediv(self, other) - - def __rtruediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": - return self._truediv(other, self) - - def _floordiv( - self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any] - ) -> "BaseTracer": - return self._helper_for_binary_functions_with_one_cst_input( - lhs, rhs, lambda x, y: x // y, "floordiv" - ) - - def __floordiv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": - return self._floordiv(self, other) - - def __rfloordiv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": - return self._floordiv(other, self) diff --git a/concrete/common/tracing/tracing_helpers.py b/concrete/common/tracing/tracing_helpers.py deleted file mode 100644 index 85b5492a6..000000000 --- a/concrete/common/tracing/tracing_helpers.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Helper functions for tracing.""" -import collections -from contextlib import contextmanager -from inspect import signature -from typing import Callable, Dict, Iterable, List, OrderedDict, Set, Type - -import networkx as nx -from networkx.algorithms.dag import is_directed_acyclic_graph - -from ..debugging.custom_assert import assert_true -from ..representation.intermediate import Input -from ..values import BaseValue -from .base_tracer import BaseTracer - - -def make_input_tracers( - tracer_class: Type[BaseTracer], - function_parameters: OrderedDict[str, BaseValue], -) -> OrderedDict[str, BaseTracer]: - """Create tracers for a function's parameters. - - Args: - tracer_class (Type[BaseTracer]): the class of tracer to create an Input for - function_parameters (OrderedDict[str, BaseValue]): the dictionary with the parameters names - and corresponding Values - - Returns: - OrderedDict[str, BaseTracer]: the dictionary containing the Input Tracers for each parameter - """ - return collections.OrderedDict( - (param_name, make_input_tracer(tracer_class, param_name, input_idx, param)) - for input_idx, (param_name, param) in enumerate(function_parameters.items()) - ) - - -def make_input_tracer( - tracer_class: Type[BaseTracer], - input_name: str, - input_idx: int, - input_value: BaseValue, -) -> BaseTracer: - """Create a tracer for an input value. - - Args: - tracer_class (Type[BaseTracer]): the class of tracer to create an Input for - input_name (str): the name of the input in the traced function - input_idx (int): the input index in the function parameters - input_value (BaseValue): the Value that is an input and needs to be wrapped in an - BaseTracer - - Returns: - BaseTracer: The BaseTracer for that input value - """ - return tracer_class([], Input(input_value, input_name, input_idx), 0) - - -def prepare_function_parameters( - function_to_trace: Callable, function_parameters: Dict[str, BaseValue] -) -> OrderedDict[str, BaseValue]: - """Filter the passed function_parameters to trace function_to_trace. - - Args: - function_to_trace (Callable): function that will be traced for which parameters are checked - function_parameters (Dict[str, BaseValue]): parameters given to trace the function - - Raises: - ValueError: Raised when some parameters are missing to trace function_to_trace - - Returns: - OrderedDict[str, BaseValue]: filtered function_parameters dictionary - """ - function_signature = signature(function_to_trace) - - missing_args = function_signature.parameters.keys() - function_parameters.keys() - - if len(missing_args) > 0: - raise ValueError( - f"The function '{function_to_trace.__name__}' requires the following parameters" - f"that were not provided: {', '.join(sorted(missing_args))}" - ) - - # This convoluted way of creating the dict is to ensure key order is maintained - return collections.OrderedDict( - (param_name, function_parameters[param_name]) - for param_name in function_signature.parameters.keys() - ) - - -def create_graph_from_output_tracers( - output_tracers: Iterable[BaseTracer], -) -> nx.MultiDiGraph: - """Generate a networkx Directed Graph that represents the computation from a traced function. - - Args: - output_tracers (Iterable[BaseTracer]): the output tracers resulting from running the - function over the proper input tracers - - Returns: - nx.MultiDiGraph: Directed Graph that is guaranteed to be a DAG containing the ir nodes - representing the traced program/function - """ - graph = nx.MultiDiGraph() - - visited_tracers: Set[BaseTracer] = set() - # use dict as ordered set - current_tracers = {tracer: None for tracer in output_tracers} - - while current_tracers: - # use dict as ordered set - next_tracers: Dict[BaseTracer, None] = {} - for tracer in current_tracers: - if tracer in visited_tracers: - continue - current_ir_node = tracer.traced_computation - graph.add_node(current_ir_node) - - for input_idx, input_tracer in enumerate(tracer.inputs): - input_ir_node = input_tracer.traced_computation - output_idx = input_tracer.output_idx - graph.add_node(input_ir_node) - graph.add_edge( - input_ir_node, - current_ir_node, - input_idx=input_idx, - output_idx=output_idx, - ) - if input_tracer not in visited_tracers: - next_tracers.update({input_tracer: None}) - - visited_tracers.add(tracer) - - current_tracers = next_tracers - - assert_true(is_directed_acyclic_graph(graph)) - - # Check each edge is unique - unique_edges = set( - (pred, succ, tuple((k, v) for k, v in edge_data.items())) - for pred, succ, edge_data in graph.edges(data=True) - ) - number_of_edges = len(graph.edges) - assert_true(len(unique_edges) == number_of_edges) - - return graph - - -@contextmanager -def tracing_context(tracer_classes: List[Type[BaseTracer]]): - """Set tracer classes in tracing mode. - - Args: - tracer_classes (List[Type[BaseTracer]]): The list of tracers for which we should enable - tracing. - """ - - try: - for tracer_class in tracer_classes: - tracer_class.set_is_tracing(True) - yield - finally: - for tracer_class in tracer_classes: - tracer_class.set_is_tracing(False) diff --git a/concrete/common/values/__init__.py b/concrete/common/values/__init__.py deleted file mode 100644 index 4a1e3290c..000000000 --- a/concrete/common/values/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Module for value structures.""" - -from . import tensors -from .base import BaseValue -from .tensors import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue diff --git a/concrete/common/values/base.py b/concrete/common/values/base.py deleted file mode 100644 index b33f026e4..000000000 --- a/concrete/common/values/base.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Module that defines the values in a program.""" - -from abc import ABC, abstractmethod -from copy import deepcopy -from typing import Callable, Optional - -from ..data_types.base import BaseDataType - - -class BaseValue(ABC): - """Abstract base class to represent any kind of value in a program.""" - - dtype: BaseDataType - _is_encrypted: bool - underlying_constructor: Optional[Callable] - - def __init__(self, dtype: BaseDataType, is_encrypted: bool) -> None: - self.dtype = deepcopy(dtype) - self._is_encrypted = is_encrypted - self.underlying_constructor = None - - def __repr__(self) -> str: # pragma: no cover - return str(self) - - @abstractmethod - def __eq__(self, other: object) -> bool: - return isinstance(other, self.__class__) and self.dtype == other.dtype - - @property - def is_encrypted(self) -> bool: - """Whether Value is encrypted or not. - - Returns: - bool: True if encrypted False otherwise - """ - return self._is_encrypted - - @property - def is_clear(self) -> bool: - """Whether Value is clear or not. - - Returns: - bool: True if clear False otherwise - """ - return not self._is_encrypted diff --git a/concrete/common/values/tensors.py b/concrete/common/values/tensors.py deleted file mode 100644 index 5a6eb8185..000000000 --- a/concrete/common/values/tensors.py +++ /dev/null @@ -1,142 +0,0 @@ -"""Module that defines the tensor values in a program.""" - -from math import prod -from typing import Tuple - -from ..data_types.base import BaseDataType -from .base import BaseValue - - -class TensorValue(BaseValue): - """Class representing a tensor value.""" - - _shape: Tuple[int, ...] - _ndim: int - _size: int - - def __init__( - self, - dtype: BaseDataType, - is_encrypted: bool, - shape: Tuple[int, ...], - ): - super().__init__(dtype, is_encrypted) - # Managing tensors as in numpy, shape of () means the value is scalar - self._shape = shape - self._ndim = len(self._shape) - self._size = prod(self._shape) if self._shape != () else 1 - - def __eq__(self, other: object) -> bool: - return ( - isinstance(other, self.__class__) - and self.shape == other.shape - and self.ndim == other.ndim - and self.size == other.size - and super().__eq__(other) - ) - - def __str__(self) -> str: - encrypted_str = "Encrypted" if self._is_encrypted else "Clear" - tensor_or_scalar_str = "Scalar" if self.is_scalar else "Tensor" - shape_str = f", shape={self.shape}" if self.shape != () else "" - return f"{encrypted_str}{tensor_or_scalar_str}<{str(self.dtype)}{shape_str}>" - - @property - def shape(self) -> Tuple[int, ...]: - """Return the TensorValue shape property. - - Returns: - Tuple[int, ...]: The TensorValue shape. - """ - return self._shape - - @property - def ndim(self) -> int: - """Return the TensorValue ndim property. - - Returns: - int: The TensorValue ndim. - """ - return self._ndim - - @property - def size(self) -> int: - """Return the TensorValue size property. - - Returns: - int: The TensorValue size. - """ - return self._size - - @property - def is_scalar(self) -> bool: - """Whether Value is scalar or not. - - Returns: - bool: True if scalar False otherwise - """ - return self.shape == () - - -def make_clear_tensor( - dtype: BaseDataType, - shape: Tuple[int, ...], -) -> TensorValue: - """Create a clear TensorValue. - - Args: - dtype (BaseDataType): The data type for the tensor. - shape (Optional[Tuple[int, ...]], optional): The tensor shape. Defaults to None. - - Returns: - TensorValue: The corresponding TensorValue. - """ - return TensorValue(dtype=dtype, is_encrypted=False, shape=shape) - - -def make_encrypted_tensor( - dtype: BaseDataType, - shape: Tuple[int, ...], -) -> TensorValue: - """Create an encrypted TensorValue. - - Args: - dtype (BaseDataType): The data type for the tensor. - shape (Optional[Tuple[int, ...]], optional): The tensor shape. Defaults to None. - - Returns: - TensorValue: The corresponding TensorValue. - """ - return TensorValue(dtype=dtype, is_encrypted=True, shape=shape) - - -ClearTensor = make_clear_tensor -EncryptedTensor = make_encrypted_tensor - - -def make_clear_scalar(dtype: BaseDataType) -> TensorValue: - """Create a clear scalar value. - - Args: - dtype (BaseDataType): The data type for the value. - - Returns: - TensorValue: The corresponding TensorValue. - """ - return TensorValue(dtype=dtype, is_encrypted=False, shape=()) - - -def make_encrypted_scalar(dtype: BaseDataType) -> TensorValue: - """Create an encrypted scalar value. - - Args: - dtype (BaseDataType): The data type for the value. - - Returns: - TensorValue: The corresponding TensorValue. - """ - return TensorValue(dtype=dtype, is_encrypted=True, shape=()) - - -ClearScalar = make_clear_scalar -EncryptedScalar = make_encrypted_scalar diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py deleted file mode 100644 index a268d866e..000000000 --- a/concrete/numpy/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Module for compiling numpy functions to homomorphic equivalents.""" - -# Import differently to put at the top, and avoid circular import issues -from concrete.numpy.compile import ( - compile_numpy_function, - compile_numpy_function_into_op_graph_and_measure_bounds, -) -from concrete.numpy.np_fhe_compiler import NPFHECompiler -from concrete.numpy.tracing import trace_numpy_function - -from ..common.compilation import CompilationArtifacts, CompilationConfiguration -from ..common.data_types import ( - Float, - Float16, - Float32, - Float64, - Integer, - SignedInteger, - UnsignedInteger, -) -from ..common.debugging import draw_graph, format_operation_graph -from ..common.extensions.convolution import conv2d -from ..common.extensions.multi_table import MultiLookupTable -from ..common.extensions.table import LookupTable -from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py deleted file mode 100644 index 444738420..000000000 --- a/concrete/numpy/compile.py +++ /dev/null @@ -1,805 +0,0 @@ -"""numpy compilation function.""" - -import sys -import traceback -from copy import deepcopy -from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union, cast - -import numpy - -from ..common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_inputset -from ..common.common_helpers import check_op_graph_is_integer_program -from ..common.compilation import CompilationArtifacts, CompilationConfiguration -from ..common.data_types import Integer -from ..common.debugging import format_operation_graph -from ..common.debugging.custom_assert import assert_true -from ..common.fhe_circuit import FHECircuit -from ..common.mlir.utils import ( - check_graph_values_compatibility_with_mlir, - update_bit_width_for_mlir, -) -from ..common.operator_graph import OPGraph -from ..common.optimization.topological import fuse_float_operations -from ..common.representation.intermediate import Add, Constant, GenericFunction, IntermediateNode -from ..common.values import BaseValue, ClearScalar -from ..numpy.tracing import trace_numpy_function -from .np_dtypes_helpers import ( - get_base_data_type_for_numpy_or_python_constant_data, - get_base_value_for_numpy_or_python_constant_data, - get_constructor_for_numpy_or_python_constant_data, -) -from .np_inputset_helpers import _check_special_inputset_availability, _generate_random_inputset -from .np_mlir_converter import NPMLIRConverter - -_COMPILE_FHE_INSECURE_KEY_CACHE_DIR: Optional[str] = None - - -def numpy_max_func(lhs: Any, rhs: Any) -> Any: - """Compute the maximum value between two values which can be numpy classes (e.g. ndarray). - - Args: - lhs (Any): lhs value to compute max from. - rhs (Any): rhs value to compute max from. - - Returns: - Any: maximum scalar value between lhs and rhs. - """ - return numpy.maximum(lhs, rhs).max() - - -def numpy_min_func(lhs: Any, rhs: Any) -> Any: - """Compute the minimum value between two values which can be numpy classes (e.g. ndarray). - - Args: - lhs (Any): lhs value to compute min from. - rhs (Any): rhs value to compute min from. - - Returns: - Any: minimum scalar value between lhs and rhs. - """ - return numpy.minimum(lhs, rhs).min() - - -def sanitize_compilation_configuration_and_artifacts( - compilation_configuration: Optional[CompilationConfiguration] = None, - compilation_artifacts: Optional[CompilationArtifacts] = None, -) -> Tuple[CompilationConfiguration, CompilationArtifacts]: - """Return the proper compilation configuration and artifacts. - - Default values are returned if None is passed for each argument. - - Args: - compilation_configuration (Optional[CompilationConfiguration], optional): the compilation - configuration to sanitize. Defaults to None. - compilation_artifacts (Optional[CompilationArtifacts], optional): the compilation artifacts - to sanitize. Defaults to None. - - Returns: - Tuple[CompilationConfiguration, CompilationArtifacts]: the tuple of sanitized configuration - and artifacts. - """ - # Create default configuration if custom configuration is not specified - compilation_configuration = ( - CompilationConfiguration() - if compilation_configuration is None - else compilation_configuration - ) - - # Create temporary artifacts if custom artifacts is not specified (in case of exceptions) - if compilation_artifacts is None: - compilation_artifacts = CompilationArtifacts() - - return compilation_configuration, compilation_artifacts - - -def get_inputset_to_use( - function_parameters: Dict[str, BaseValue], - inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]], str], - compilation_configuration: CompilationConfiguration, -) -> Union[Iterable[Any], Iterable[Tuple[Any, ...]]]: - """Get the proper inputset to use for compilation. - - Args: - function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the - function is e.g. an EncryptedScalar holding a 7bits unsigned Integer - inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]], str]): The inputset over which - op_graph is evaluated. It needs to be an iterable on tuples which are of the same length - than the number of parameters in the function, and in the same order than these same - parameters - compilation_configuration (CompilationConfiguration): Configuration object to use during - compilation - - Returns: - Union[Iterable[Any], Iterable[Tuple[Any, ...]]]: the inputset to use. - """ - # Generate random inputset if it is requested and available - if isinstance(inputset, str): - _check_special_inputset_availability(inputset, compilation_configuration) - return _generate_random_inputset(function_parameters, compilation_configuration) - return inputset - - -def run_compilation_function_with_error_management( - compilation_function: Callable, - compilation_configuration: CompilationConfiguration, - compilation_artifacts: CompilationArtifacts, -) -> Any: - """Call compilation_function() and manage exceptions that may occur. - - Args: - compilation_function (Callable): the compilation function to call. - compilation_configuration (CompilationConfiguration): the current compilation configuration. - compilation_artifacts (CompilationArtifacts): the current compilation artifacts. - - Returns: - Any: returns the result of the call to compilation_function - """ - - # Try to compile the function and save partial artifacts on failure - try: - # Use context manager to restore numpy error handling - with numpy.errstate(**numpy.geterr()): - return compilation_function() - except Exception: # pragma: no cover - # This branch is reserved for unexpected issues and hence it shouldn't be tested. - # If it could be tested, we would have fixed the underlying issue. - - # We need to export all the information we have about the compilation - # If the user wants them to be exported - - if compilation_configuration.dump_artifacts_on_unexpected_failures: - compilation_artifacts.export() - - traceback_path = compilation_artifacts.output_directory.joinpath("traceback.txt") - with open(traceback_path, "w", encoding="utf-8") as f: - f.write(traceback.format_exc()) - - raise - - -def _compile_numpy_function_into_op_graph_internal( - function_to_compile: Callable, - function_parameters: Dict[str, BaseValue], - compilation_configuration: CompilationConfiguration, - compilation_artifacts: CompilationArtifacts, -) -> OPGraph: - """Compile a function into an OPGraph without evaluating the intermediate nodes bounds. - - Args: - function_to_compile (Callable): The function to compile - function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the - function is e.g. an EncryptedScalar holding a 7bits unsigned Integer - compilation_configuration (CompilationConfiguration): Configuration object to use - during compilation - compilation_artifacts (CompilationArtifacts): Artifacts object to fill - during compilation - - Returns: - OPGraph: compiled function into a graph, node values are not representative of the values - that can be observed during execution. - Use _compile_numpy_function_into_op_graph_and_measure_bounds_internal if you need bounds - estimation. - """ - # Check function parameters - wrong_inputs = { - inp: function_parameters[inp] - for inp in function_parameters.keys() - if not isinstance(function_parameters[inp], BaseValue) - } - list_of_possible_basevalue = [ - "ClearTensor", - "EncryptedTensor", - "ClearScalar", - "EncryptedScalar", - ] - assert_true( - len(wrong_inputs.keys()) == 0, - f"wrong type for inputs {wrong_inputs}, needs to be one of {list_of_possible_basevalue}", - ) - - # Add the function to compile as an artifact - compilation_artifacts.add_function_to_compile(function_to_compile) - - # Add the parameters of function to compile as artifacts - for name, value in function_parameters.items(): - compilation_artifacts.add_parameter_of_function_to_compile(name, str(value)) - - # Trace the function - op_graph = trace_numpy_function(function_to_compile, function_parameters) - - # Add the initial graph as an artifact - compilation_artifacts.add_operation_graph("initial", op_graph) - - # Apply topological optimizations if they are enabled - if compilation_configuration.enable_topological_optimizations: - # Fuse float operations to have int to int GenericFunction - if not check_op_graph_is_integer_program(op_graph): - fuse_float_operations(op_graph, compilation_artifacts) - - return op_graph - - -def compile_numpy_function_into_op_graph( - function_to_compile: Callable, - function_parameters: Dict[str, BaseValue], - compilation_configuration: Optional[CompilationConfiguration] = None, - compilation_artifacts: Optional[CompilationArtifacts] = None, -) -> OPGraph: - """Compile a function into an OPGraph. - - Args: - function_to_compile (Callable): The function to compile - function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the - function is e.g. an EncryptedScalar holding a 7bits unsigned Integer - compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use - during compilation - compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill - during compilation - - Returns: - OPGraph: compiled function into a graph - """ - - ( - compilation_configuration, - compilation_artifacts, - ) = sanitize_compilation_configuration_and_artifacts( - compilation_configuration, compilation_artifacts - ) - - def compilation_function(): - return _compile_numpy_function_into_op_graph_internal( - function_to_compile, - function_parameters, - compilation_configuration, - compilation_artifacts, - ) - - result = run_compilation_function_with_error_management( - compilation_function, compilation_configuration, compilation_artifacts - ) - - # for mypy - assert isinstance(result, OPGraph) - return result - - -def _measure_op_graph_bounds_and_update_internal( - op_graph: OPGraph, - function_parameters: Dict[str, BaseValue], - inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]], - compilation_configuration: CompilationConfiguration, - compilation_artifacts: CompilationArtifacts, - prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None, - warn_on_inputset_length: bool = True, -) -> Dict[IntermediateNode, Dict[str, Any]]: - """Measure the intermediate values and update the OPGraph accordingly for the given inputset. - - Args: - op_graph (OPGraph): the OPGraph for which to measure bounds and update node values. - function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the - function is e.g. an EncryptedScalar holding a 7bits unsigned Integer - inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset over which op_graph - is evaluated. It needs to be an iterable on tuples which are of the same length than the - number of parameters in the function, and in the same order than these same parameters - compilation_configuration (CompilationConfiguration): Configuration object to use - during compilation - compilation_artifacts (CompilationArtifacts): Artifacts object to fill - during compilation - prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional): - Bounds and samples from a previous run. Defaults to None. - warn_on_inputset_length (bool, optional): Set to True to get a warning if inputset is not - long enough. Defaults to True. - - Raises: - ValueError: Raises an error if the inputset is too small and the compilation configuration - treats warnings as error. - - Returns: - Dict[IntermediateNode, Dict[str, Any]]: a dict containing the bounds for each node from - op_graph, stored with the node as key and a dict with keys "min", "max" and "sample" as - value. - """ - - # Find bounds with the inputset - inputset_size, node_bounds_and_samples = eval_op_graph_bounds_on_inputset( - op_graph, - inputset, - compilation_configuration=compilation_configuration, - min_func=numpy_min_func, - max_func=numpy_max_func, - get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, - prev_node_bounds_and_samples=prev_node_bounds_and_samples, - ) - - if warn_on_inputset_length: - # Check inputset size - inputset_size_upper_limit = 1 - - # this loop will determine the number of possible inputs of the function - # if a function have a single 3-bit input, for example, inputset_size_upper_limit will be 8 - for parameter_value in function_parameters.values(): - if isinstance(parameter_value.dtype, Integer): - # multiple parameter bit-widths are multiplied as they can be combined into an input - inputset_size_upper_limit *= 2 ** parameter_value.dtype.bit_width - - # if the upper limit of the inputset size goes above 10, - # break the loop as we will require at least 10 inputs in this case - if inputset_size_upper_limit > 10: - break - - minimum_required_inputset_size = min(inputset_size_upper_limit, 10) - if inputset_size < minimum_required_inputset_size: - message = ( - f"Provided inputset contains too few inputs " - f"(it should have had at least {minimum_required_inputset_size} " - f"but it only had {inputset_size})\n" - ) - - if compilation_configuration.treat_warnings_as_errors: - raise ValueError(message) - - sys.stderr.write(f"Warning: {message}") - - # Add the bounds as an artifact - compilation_artifacts.add_final_operation_graph_bounds(node_bounds_and_samples) - - # Update the graph accordingly: after that, we have the compilable graph - op_graph.update_values_with_bounds_and_samples( - node_bounds_and_samples, - get_base_data_type_for_numpy_or_python_constant_data, - get_constructor_for_numpy_or_python_constant_data, - ) - - return node_bounds_and_samples - - -def measure_op_graph_bounds_and_update( - op_graph: OPGraph, - function_parameters: Dict[str, BaseValue], - inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]], str], - compilation_configuration: Optional[CompilationConfiguration] = None, - compilation_artifacts: Optional[CompilationArtifacts] = None, - prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None, - warn_on_inputset_length: bool = True, -) -> Dict[IntermediateNode, Dict[str, Any]]: - """Measure the intermediate values and update the OPGraph accordingly for the given inputset. - - Args: - op_graph (OPGraph): the OPGraph for which to measure bounds and update node values. - function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the - function is e.g. an EncryptedScalar holding a 7bits unsigned Integer - inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]], str]): The inputset over which - op_graph is evaluated. It needs to be an iterable on tuples which are of the same length - than the number of parameters in the function, and in the same order than these same - parameters - compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use - during compilation - compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill - during compilation - prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional): - Bounds and samples from a previous run. Defaults to None. - warn_on_inputset_length (bool, optional): Set to True to get a warning if inputset is not - long enough. Defaults to True. - - Raises: - ValueError: Raises an error if the inputset is too small and the compilation configuration - treats warnings as error. - - Returns: - Dict[IntermediateNode, Dict[str, Any]]: a dict containing the bounds for each node from - op_graph, stored with the node as key and a dict with keys "min", "max" and "sample" as - value. - """ - - ( - compilation_configuration, - compilation_artifacts, - ) = sanitize_compilation_configuration_and_artifacts( - compilation_configuration, compilation_artifacts - ) - - inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration) - - def compilation_function(): - return _measure_op_graph_bounds_and_update_internal( - op_graph, - function_parameters, - inputset, - compilation_configuration, - compilation_artifacts, - prev_node_bounds_and_samples, - warn_on_inputset_length, - ) - - result = run_compilation_function_with_error_management( - compilation_function, compilation_configuration, compilation_artifacts - ) - - # for mypy - assert isinstance(result, dict) - return result - - -def _compile_numpy_function_into_op_graph_and_measure_bounds_internal( - function_to_compile: Callable, - function_parameters: Dict[str, BaseValue], - inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]], - compilation_configuration: CompilationConfiguration, - compilation_artifacts: CompilationArtifacts, -) -> OPGraph: - """Compile a function into an OPGraph and evaluate the intermediate nodes bounds. - - Args: - function_to_compile (Callable): The function to compile - function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the - function is e.g. an EncryptedScalar holding a 7bits unsigned Integer - inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset over which op_graph - is evaluated. It needs to be an iterable on tuples which are of the same length than the - number of parameters in the function, and in the same order than these same parameters - compilation_configuration (CompilationConfiguration): Configuration object to use - during compilation - compilation_artifacts (CompilationArtifacts): Artifacts object to fill - during compilation - - Returns: - OPGraph: compiled function into a graph with estimated bounds in node values. - """ - - op_graph = _compile_numpy_function_into_op_graph_internal( - function_to_compile, - function_parameters, - compilation_configuration, - compilation_artifacts, - ) - - _measure_op_graph_bounds_and_update_internal( - op_graph, - function_parameters, - inputset, - compilation_configuration, - compilation_artifacts, - ) - - # Add the final graph as an artifact - compilation_artifacts.add_operation_graph("final", op_graph) - - return op_graph - - -def compile_numpy_function_into_op_graph_and_measure_bounds( - function_to_compile: Callable, - function_parameters: Dict[str, BaseValue], - inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]], str], - compilation_configuration: Optional[CompilationConfiguration] = None, - compilation_artifacts: Optional[CompilationArtifacts] = None, -) -> OPGraph: - """Compile a function into an OPGraph. - - Args: - function_to_compile (Callable): The function to compile - function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the - function is e.g. an EncryptedScalar holding a 7bits unsigned Integer - inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]], str]): The inputset over which - op_graph is evaluated. It needs to be an iterable on tuples which are of the same length - than the number of parameters in the function, and in the same order than these same - parameters. Alternatively, it can be "random" but that's an unstable feature and should - not be used in production. - compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use - during compilation - compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill - during compilation - - Returns: - OPGraph: compiled function into a graph - """ - - ( - compilation_configuration, - compilation_artifacts, - ) = sanitize_compilation_configuration_and_artifacts( - compilation_configuration, compilation_artifacts - ) - - inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration) - - def compilation_function(): - return _compile_numpy_function_into_op_graph_and_measure_bounds_internal( - function_to_compile, - function_parameters, - inputset, - compilation_configuration, - compilation_artifacts, - ) - - result = run_compilation_function_with_error_management( - compilation_function, compilation_configuration, compilation_artifacts - ) - - # for mypy - assert isinstance(result, OPGraph) - return result - - -# HACK -# TODO: remove this ugly hack when -# https://github.com/zama-ai/concrete-numpy-internal/issues/1001 is done -# TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/1015 -def hack_offset_negative_inputs_to_lookup_tables(op_graph: OPGraph) -> None: - """Hack the op_graph to add offsets to signed inputs to TLUs. - - Args: - op_graph (OPGraph): the OPGraph to hack. - """ - # Ugly hack to add an offset before entering a TLU if its variable input node has a signed - # output. - # This is ugly as this makes hardcoded assumptions about the way bit widths are handled in MLIR. - # This does not update the TLU input values to allow for proper table generation. - # Thankfully we are not supposed to touch the op_graph beyond that point - for node in list((nx_graph := op_graph.graph).nodes): - if isinstance(node, GenericFunction) and node.op_kind == "TLU": - ordered_preds_and_inputs = op_graph.get_ordered_preds_and_inputs_of(node) - variable_input_indices = [ - idx - for idx, (pred, _) in enumerate(ordered_preds_and_inputs) - if not isinstance(pred, Constant) - ] - assert_true(len(variable_input_indices) == 1) - variable_input_idx = variable_input_indices[0] - variable_input_node = ordered_preds_and_inputs[variable_input_idx][0] - variable_input_value = variable_input_node.outputs[0] - variable_input_dtype = variable_input_value.dtype - assert_true(isinstance(variable_input_dtype, Integer)) - variable_input_dtype = cast(Integer, variable_input_dtype) - if not variable_input_dtype.is_signed: - continue - - # input_bit_width + 1 to be MLIR compliant - input_bit_width = variable_input_dtype.bit_width - mlir_compliant_int_type = Integer(input_bit_width + 1, True) - - # Manually fix the output values to be MLIR compliant - # offset_constant is set to abs(min_value) for the variable input so that the values - # [- 2 ** (n - 1); 2 ** (n - 1) - 1] is mapped to [0; 2 ** n - 1], changing the signed - # TLU to an actual unsigned TLU. The get_table function creates the table from the min - # value to the max value. As we keep the input value as a signed value, it will be from - # - 2 ** (n - 1) to 2 ** (n - 1) - 1. Then, the get_table function stores corresponding - # values in increasing indexes from 0 to 2 ** n - 1. As our signed values have been - # shifted by 2 ** (n - 1), the table will be usable as-is, without needing any change in - # the lambda function of the GenericFunction. - offset_constant = Constant(abs(variable_input_dtype.min_value())) - offset_constant.outputs[0].dtype = deepcopy(mlir_compliant_int_type) - add_offset = Add( - [deepcopy(variable_input_value), ClearScalar(deepcopy(mlir_compliant_int_type))] - ) - add_offset.outputs[0] = deepcopy(variable_input_value) - - nx_graph.remove_edge(variable_input_node, node) - nx_graph.add_edge(variable_input_node, add_offset, input_idx=0, output_idx=0) - nx_graph.add_edge(offset_constant, add_offset, input_idx=1, output_idx=0) - nx_graph.add_edge(add_offset, node, input_idx=variable_input_idx, output_idx=0) - - -def prepare_op_graph_for_mlir(op_graph: OPGraph): - """Prepare OPGraph for MLIR lowering. - - This includes checking compatibility, changing bit-widths, and modifying lookup tables. - - Args: - op_graph (OPGraph): The operation graph to prepare - - Returns: - None - """ - - # Make sure the graph can be lowered to MLIR - offending_nodes = check_graph_values_compatibility_with_mlir(op_graph) - if offending_nodes is not None: - raise RuntimeError( - "function you are trying to compile isn't supported for MLIR lowering\n\n" - + format_operation_graph(op_graph, highlighted_nodes=offending_nodes) - ) - - # Update bit_width for MLIR - update_bit_width_for_mlir(op_graph) - - # HACK - # TODO: remove this ugly hack when - # https://github.com/zama-ai/concrete-numpy-internal/issues/1001 is done - # TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/1015 - hack_offset_negative_inputs_to_lookup_tables(op_graph) - - -def _compile_op_graph_to_fhe_circuit_internal( - op_graph: OPGraph, - show_mlir: bool, - compilation_configuration: CompilationConfiguration, - compilation_artifacts: CompilationArtifacts, -) -> FHECircuit: - """Compile the OPGraph to an FHECircuit. - - Args: - op_graph (OPGraph): the OPGraph to compile. - show_mlir (bool): determine whether we print the mlir string. - compilation_configuration (CompilationConfiguration): Configuration object to use - during compilation - compilation_artifacts (CompilationArtifacts): Artifacts object to fill - during compilation - - Returns: - FHECircuit: the compiled FHECircuit - """ - prepare_op_graph_for_mlir(op_graph) - - # Convert graph to an MLIR representation - converter = NPMLIRConverter() - mlir_result = converter.convert(op_graph) - - # Show MLIR representation if requested - if show_mlir: - print(f"MLIR which is going to be compiled: \n{mlir_result}") - - # Add MLIR representation as an artifact - compilation_artifacts.add_final_operation_graph_mlir(mlir_result) - - if _COMPILE_FHE_INSECURE_KEY_CACHE_DIR is not None and not ( - compilation_configuration.use_insecure_key_cache - and compilation_configuration.enable_unsafe_features - ): - raise RuntimeError( - f"Unable to use insecure key cache {_COMPILE_FHE_INSECURE_KEY_CACHE_DIR} " - "as use_insecure_key_cache or enable_unsafe_features are not set to True in" - "compilation_configuration" - ) - - return FHECircuit( - op_graph, - mlir_result, - unsecure_key_set_cache_path=_COMPILE_FHE_INSECURE_KEY_CACHE_DIR, - auto_parallelize=compilation_configuration.auto_parallelize, - loop_parallelize=compilation_configuration.loop_parallelize, - dataflow_parallelize=compilation_configuration.dataflow_parallelize, - ) - - -def compile_op_graph_to_fhe_circuit( - op_graph: OPGraph, - show_mlir: bool, - compilation_configuration: Optional[CompilationConfiguration] = None, - compilation_artifacts: Optional[CompilationArtifacts] = None, -) -> FHECircuit: - """Compile the OPGraph to an FHECircuit. - - Args: - op_graph (OPGraph): the OPGraph to compile. - show_mlir (bool): determine whether we print the mlir string. - compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use - during compilation - compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill - during compilation - - Returns: - FHECircuit: the compiled circuit and the compiled FHECircuit - """ - - ( - compilation_configuration, - compilation_artifacts, - ) = sanitize_compilation_configuration_and_artifacts( - compilation_configuration, compilation_artifacts - ) - - def compilation_function(): - return _compile_op_graph_to_fhe_circuit_internal( - op_graph, show_mlir, compilation_configuration, compilation_artifacts - ) - - result = run_compilation_function_with_error_management( - compilation_function, compilation_configuration, compilation_artifacts - ) - - # for mypy - assert isinstance(result, FHECircuit) - return result - - -def _compile_numpy_function_internal( - function_to_compile: Callable, - function_parameters: Dict[str, BaseValue], - inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]], - compilation_configuration: CompilationConfiguration, - compilation_artifacts: CompilationArtifacts, - show_mlir: bool, -) -> FHECircuit: - """Compile an homomorphic program (internal part of the API). - - Args: - function_to_compile (Callable): The function you want to compile - function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the - function is e.g. an EncryptedScalar holding a 7bits unsigned Integer - inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset over which op_graph - is evaluated. It needs to be an iterable on tuples which are of the same length than the - number of parameters in the function, and in the same order than these same parameters - compilation_configuration (CompilationConfiguration): Configuration object to use - during compilation - compilation_artifacts (CompilationArtifacts): Artifacts object to fill - during compilation - show_mlir (bool): if set, the MLIR produced by the converter and which is going - to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo - - Returns: - CompilerEngine: engine to run and debug the compiled graph - """ - - # Compile into an OPGraph - op_graph = _compile_numpy_function_into_op_graph_and_measure_bounds_internal( - function_to_compile, - function_parameters, - inputset, - compilation_configuration, - compilation_artifacts, - ) - - fhe_circuit = _compile_op_graph_to_fhe_circuit_internal( - op_graph, show_mlir, compilation_configuration, compilation_artifacts - ) - - return fhe_circuit - - -def compile_numpy_function( - function_to_compile: Callable, - function_parameters: Dict[str, BaseValue], - inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]], str], - compilation_configuration: Optional[CompilationConfiguration] = None, - compilation_artifacts: Optional[CompilationArtifacts] = None, - show_mlir: bool = False, -) -> FHECircuit: - """Compile an homomorphic program (main API). - - Args: - function_to_compile (Callable): The function to compile - function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the - function is e.g. an EncryptedScalar holding a 7bits unsigned Integer - inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]], str]): The inputset over which - op_graph is evaluated. It needs to be an iterable on tuples which are of the same length - than the number of parameters in the function, and in the same order than these same - parameters. Alternatively, it can be "random" but that's an unstable feature and should - not be used in production. - compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use - during compilation - compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill - during compilation - show_mlir (bool): if set, the MLIR produced by the converter and which is going - to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo - - Returns: - CompilerEngine: engine to run and debug the compiled graph - """ - - ( - compilation_configuration, - compilation_artifacts, - ) = sanitize_compilation_configuration_and_artifacts( - compilation_configuration, compilation_artifacts - ) - - inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration) - - def compilation_function(): - return _compile_numpy_function_internal( - function_to_compile, - function_parameters, - inputset, - compilation_configuration, - compilation_artifacts, - show_mlir, - ) - - result = run_compilation_function_with_error_management( - compilation_function, compilation_configuration, compilation_artifacts - ) - - # for mypy - assert isinstance(result, FHECircuit) - return result diff --git a/concrete/numpy/np_dtypes_helpers.py b/concrete/numpy/np_dtypes_helpers.py deleted file mode 100644 index 1e3307644..000000000 --- a/concrete/numpy/np_dtypes_helpers.py +++ /dev/null @@ -1,308 +0,0 @@ -"""File to hold code to manage package and numpy dtypes.""" - -from copy import deepcopy -from functools import partial -from typing import Any, Callable, Dict, List, Tuple, Union - -import numpy -from numpy.typing import DTypeLike - -from ..common.data_types.base import BaseDataType -from ..common.data_types.dtypes_helpers import ( - BASE_DATA_TYPES, - find_type_to_hold_both_lossy, - get_base_data_type_for_python_constant_data, - get_base_value_for_python_constant_data, - get_constructor_for_python_constant_data, -) -from ..common.data_types.floats import Float -from ..common.data_types.integers import Integer -from ..common.debugging.custom_assert import assert_true -from ..common.tracing import BaseTracer -from ..common.values import BaseValue, TensorValue - -NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = { - numpy.dtype(numpy.byte): Integer(numpy.byte(0).nbytes * 8, is_signed=True), - numpy.dtype(numpy.short): Integer(numpy.short(0).nbytes * 8, is_signed=True), - numpy.dtype(numpy.intc): Integer(numpy.intc(0).nbytes * 8, is_signed=True), - numpy.dtype(numpy.int_): Integer(numpy.int_(0).nbytes * 8, is_signed=True), - numpy.dtype(numpy.longlong): Integer(numpy.longlong(0).nbytes * 8, is_signed=True), - numpy.dtype(numpy.int8): Integer(numpy.int8(0).nbytes * 8, is_signed=True), - numpy.dtype(numpy.int16): Integer(numpy.int16(0).nbytes * 8, is_signed=True), - numpy.dtype(numpy.int32): Integer(numpy.int32(0).nbytes * 8, is_signed=True), - numpy.dtype(numpy.int64): Integer(numpy.int64(0).nbytes * 8, is_signed=True), - numpy.dtype(numpy.ubyte): Integer(numpy.ubyte(0).nbytes * 8, is_signed=False), - numpy.dtype(numpy.ushort): Integer(numpy.ushort(0).nbytes * 8, is_signed=False), - numpy.dtype(numpy.uintc): Integer(numpy.uintc(0).nbytes * 8, is_signed=False), - numpy.dtype(numpy.uint): Integer(numpy.uint(0).nbytes * 8, is_signed=False), - numpy.dtype(numpy.ulonglong): Integer(numpy.ulonglong(0).nbytes * 8, is_signed=False), - numpy.dtype(numpy.uint8): Integer(numpy.uint8(0).nbytes * 8, is_signed=False), - numpy.dtype(numpy.uint16): Integer(numpy.uint16(0).nbytes * 8, is_signed=False), - numpy.dtype(numpy.uint32): Integer(numpy.uint32(0).nbytes * 8, is_signed=False), - numpy.dtype(numpy.uint64): Integer(numpy.uint64(0).nbytes * 8, is_signed=False), - numpy.dtype(numpy.float16): Float(16), - numpy.dtype(numpy.float32): Float(32), - numpy.dtype(numpy.float64): Float(64), - numpy.dtype(bool): Integer(8, is_signed=False), -} - -SUPPORTED_NUMPY_DTYPES = tuple(NUMPY_TO_COMMON_DTYPE_MAPPING) -SUPPORTED_NUMPY_DTYPES_CLASS_TYPES = tuple(dtype.type for dtype in NUMPY_TO_COMMON_DTYPE_MAPPING) - -SUPPORTED_DTYPE_MSG_STRING = ", ".join(sorted(str(dtype) for dtype in SUPPORTED_NUMPY_DTYPES)) - - -def convert_numpy_dtype_to_base_data_type(numpy_dtype: DTypeLike) -> BaseDataType: - """Get the corresponding BaseDataType from a numpy dtype. - - Args: - numpy_dtype (DTypeLike): Any python object that can be translated to a numpy.dtype - - Raises: - ValueError: If the numpy_dtype is not supported - - Returns: - BaseDataType: The corresponding data type corresponding to the input numpy_dtype - """ - # Normalize numpy_dtype - normalized_numpy_dtype = numpy.dtype(numpy_dtype) - corresponding_common_dtype = NUMPY_TO_COMMON_DTYPE_MAPPING.get(normalized_numpy_dtype, None) - - if corresponding_common_dtype is None: - raise ValueError( - f"Unsupported numpy type: {numpy_dtype} ({normalized_numpy_dtype}), " - f"supported numpy types: " - f"{SUPPORTED_DTYPE_MSG_STRING}" - ) - - # deepcopy to avoid having the value from the dict modified - return deepcopy(corresponding_common_dtype) - - -def convert_base_data_type_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.dtype: - """Convert a BaseDataType to corresponding numpy.dtype. - - Args: - common_dtype (BaseDataType): dtype to convert to numpy.dtype - - Returns: - numpy.dtype: The resulting numpy.dtype - """ - assert_true( - isinstance(common_dtype, BASE_DATA_TYPES), f"Unsupported common_dtype: {type(common_dtype)}" - ) - type_to_return: numpy.dtype - - if isinstance(common_dtype, Float): - assert_true( - (bit_width := common_dtype.bit_width) - in ( - 16, - 32, - 64, - ), - "Only converting Float(16), Float(32) or Float(64) is supported", - ) - if bit_width == 64: - type_to_return = numpy.dtype(numpy.float64) - elif bit_width == 32: - type_to_return = numpy.dtype(numpy.float32) - else: - type_to_return = numpy.dtype(numpy.float16) - elif isinstance(common_dtype, Integer): - signed = common_dtype.is_signed - if common_dtype.bit_width <= 32: - type_to_return = numpy.dtype(numpy.int32) if signed else numpy.dtype(numpy.uint32) - elif common_dtype.bit_width <= 64: - type_to_return = numpy.dtype(numpy.int64) if signed else numpy.dtype(numpy.uint64) - else: - raise NotImplementedError( - f"Conversion to numpy dtype only supports Integers with bit_width <= 64, " - f"got {common_dtype!r}" - ) - - return type_to_return - - -def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) -> BaseDataType: - """Determine the BaseDataType to hold the input constant data. - - Args: - constant_data (Any): The constant data for which to determine the - corresponding BaseDataType. - - Returns: - BaseDataType: The corresponding BaseDataType - """ - base_dtype: BaseDataType - assert_true( - isinstance( - constant_data, (int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES) - ), - f"Unsupported constant data of type {type(constant_data)}", - ) - if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)): - native_type = float if (constant_data.dtype in (numpy.float32, numpy.float64)) else int - - min_value = native_type(constant_data.min()) - max_value = native_type(constant_data.max()) - - min_value_dtype = get_base_data_type_for_python_constant_data(min_value) - max_value_dtype = get_base_data_type_for_python_constant_data(max_value) - - # numpy - base_dtype = find_type_to_hold_both_lossy(min_value_dtype, max_value_dtype) - else: - # python - base_dtype = get_base_data_type_for_python_constant_data(constant_data) - return base_dtype - - -def get_base_value_for_numpy_or_python_constant_data( - constant_data: Any, -) -> Callable[..., BaseValue]: - """Determine the BaseValue and BaseDataType to hold the input constant data. - - This function is able to handle numpy types - - Args: - constant_data (Any): The constant data for which to determine the - corresponding BaseValue and BaseDataType. - - Raises: - AssertionError: If `constant_data` is of an unsupported type. - - Returns: - Callable[..., BaseValue]: A partial object that will return the proper BaseValue when called - with `encrypted` as keyword argument (forwarded to the BaseValue `__init__` method). - """ - constant_data_value: Callable[..., BaseValue] - assert_true( - not isinstance(constant_data, list), - "Unsupported constant data of type list " - "(if you meant to use a list as an array, please use numpy.array instead)", - ) - assert_true( - isinstance( - constant_data, - (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES), - ), - f"Unsupported constant data of type {type(constant_data)}", - ) - - base_dtype = get_base_data_type_for_numpy_or_python_constant_data(constant_data) - if isinstance(constant_data, numpy.ndarray): - constant_data_value = partial(TensorValue, dtype=base_dtype, shape=constant_data.shape) - elif isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES): - constant_data_value = partial(TensorValue, dtype=base_dtype, shape=()) - else: - constant_data_value = get_base_value_for_python_constant_data(constant_data) - return constant_data_value - - -def get_numpy_function_output_dtype_and_shape_from_input_dtypes( - function: Union[numpy.ufunc, Callable], - input_dtypes: List[BaseDataType], - input_shapes: List[Tuple[int, ...]], -) -> List[Tuple[numpy.dtype, Tuple[int, ...]]]: - """Record the output dtype of a numpy function given some input types. - - Args: - function (Union[numpy.ufunc, Callable]): The numpy function whose output types need to - be recorded - input_dtypes (List[BaseDataType]): BaseDataTypes in the same order as they will be used with - the function inputs - input_shapes (List[Tuple[int, ...]]): Shapes in the same order as they will be used with - the function inputs - - Returns: - List[Tuple[numpy.dtype, Tuple[int, ...]]]: appropriate (numpy.dtype, shape) tuple for each - output of the function - """ - if isinstance(function, numpy.ufunc): - assert_true( - (len(input_dtypes) == function.nin), - f"Expected {function.nin} types, got {len(input_dtypes)}: {input_dtypes}", - ) - - input_numpy_dtypes = [convert_base_data_type_to_numpy_dtype(dtype) for dtype in input_dtypes] - - dummy_inputs = tuple( - ( - dtype.type(10.0 * numpy.random.random_sample()) - if shape == () - else numpy.abs(numpy.random.randn(*shape) * 10.0).astype(dtype) - ) - for dtype, shape in zip(input_numpy_dtypes, input_shapes) - ) - - # We ignore errors as we may call functions with invalid inputs just to get the proper output - # dtypes - with numpy.errstate(all="ignore"): - outputs = function(*dummy_inputs) - - if not isinstance(outputs, tuple): - outputs = (outputs,) - - return [(output.dtype, output.shape) for output in outputs] - - -def get_numpy_function_output_dtype_and_shape_from_input_tracers( - func: Union[numpy.ufunc, Callable], - *input_tracers: BaseTracer, -) -> List[Tuple[BaseDataType, Tuple[int, ...]]]: - """Determine output dtypes and shapes for a numpy function. - - This function is responsible for determining the output dtype - of a numpy function after inputs with specific dtypes are passed to it. - - Args: - func (Union[numpy.ufunc, Callable]): function that is being managed - *input_tracers (BaseTracer): inputs to the function - - Returns: - List[Tuple[BaseDataType, Tuple[int, ...]]]: appropriate (BaseDataType, shape) tuple for each - output of the function - """ - - input_shapes = [ - input_tracer.output.shape if isinstance(input_tracer.output, TensorValue) else () - for input_tracer in input_tracers - ] - output_dtypes_and_shapes = get_numpy_function_output_dtype_and_shape_from_input_dtypes( - func, - [input_tracer.output.dtype for input_tracer in input_tracers], - input_shapes, - ) - common_output_dtypes = [ - (convert_numpy_dtype_to_base_data_type(dtype), shape) - for dtype, shape in output_dtypes_and_shapes - ] - return common_output_dtypes - - -def get_constructor_for_numpy_or_python_constant_data(constant_data: Any): - """Get the constructor for the numpy constant data or python dtype. - - Args: - constant_data (Any): The data for which we want to determine the type constructor. - """ - - assert_true( - isinstance( - constant_data, (int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES) - ), - f"Unsupported constant data of type {type(constant_data)}", - ) - - if isinstance(constant_data, list): - # this is required because some operations return python lists from their evaluate function - # an example of such operation is evaluation of multi tlu during bound measurements - constant_data = numpy.array(constant_data) - - if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)): - if isinstance(constant_data, numpy.ndarray): - return lambda x: numpy.full(constant_data.shape, x, dtype=constant_data.dtype) - return constant_data.dtype.type - - return get_constructor_for_python_constant_data(constant_data) diff --git a/concrete/numpy/np_fhe_compiler.py b/concrete/numpy/np_fhe_compiler.py deleted file mode 100644 index 9595352af..000000000 --- a/concrete/numpy/np_fhe_compiler.py +++ /dev/null @@ -1,309 +0,0 @@ -"""Module to hold a user friendly class to compile programs.""" - -import itertools -from copy import deepcopy -from enum import Enum, unique -from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union - -from loguru import logger - -from ..common.compilation import CompilationArtifacts, CompilationConfiguration -from ..common.data_types import Integer -from ..common.debugging import draw_graph, format_operation_graph -from ..common.fhe_circuit import FHECircuit -from ..common.operator_graph import OPGraph -from ..common.representation.intermediate import IntermediateNode -from ..common.values import BaseValue -from .compile import ( - compile_numpy_function_into_op_graph, - compile_op_graph_to_fhe_circuit, - measure_op_graph_bounds_and_update, -) -from .np_dtypes_helpers import get_base_value_for_numpy_or_python_constant_data - - -@unique -class EncryptedStatus(str, Enum): - """Enum to validate GenericFunction op_kind.""" - - CLEAR = "clear" - ENCRYPTED = "encrypted" - - -class NPFHECompiler: - """Class to ease the conversion of a numpy program to its FHE equivalent.""" - - INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE: int = 128 - - # _function_to_compile is not optional but mypy has a long standing bug and is not able to - # understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623 - _function_to_compile: Optional[Callable] - _function_parameters_encrypted_status: Dict[str, bool] - _current_inputset: List[Union[Any, Tuple]] - _op_graph: Optional[OPGraph] - _nodes_and_bounds: Dict[IntermediateNode, Dict[str, Any]] - - _compilation_configuration: CompilationConfiguration - - compilation_artifacts: CompilationArtifacts - - def __init__( - self, - function_to_compile: Callable, - function_parameters_encrypted_status: Dict[str, Union[str, EncryptedStatus]], - compilation_configuration: Optional[CompilationConfiguration] = None, - compilation_artifacts: Optional[CompilationArtifacts] = None, - ) -> None: - self._function_to_compile = function_to_compile - self._function_parameters_encrypted_status = { - param_name: EncryptedStatus(status.lower()) == EncryptedStatus.ENCRYPTED - for param_name, status in function_parameters_encrypted_status.items() - } - - self._current_inputset = [] - self._op_graph = None - self._nodes_and_bounds = {} - - self._compilation_configuration = ( - deepcopy(compilation_configuration) - if compilation_configuration is not None - else CompilationConfiguration() - ) - self.compilation_artifacts = ( - compilation_artifacts if compilation_artifacts is not None else CompilationArtifacts() - ) - - @property - def function_to_compile(self) -> Callable: - """Get the function to compile. - - Returns: - Callable: the function to compile. - """ - # Continuation of mypy bug - assert self._function_to_compile is not None - return self._function_to_compile - - @property - def op_graph(self) -> Optional[OPGraph]: - """Return a copy of the OPGraph. - - Returns: - Optional[OPGraph]: the held OPGraph or None - """ - # To keep consistency with what the user expects, we make sure to evaluate on the remaining - # inputset values if any before giving a copy of the OPGraph we trace - self._eval_on_current_inputset() - return deepcopy(self._op_graph) - - @property - def compilation_configuration(self) -> Optional[CompilationConfiguration]: - """Get a copy of the compilation configuration. - - Returns: - Optional[CompilationConfiguration]: copy of the current compilation configuration. - """ - return deepcopy(self._compilation_configuration) - - def __call__(self, *args: Any) -> Any: - """Evaluate the OPGraph corresponding to the function being compiled and return result. - - Returns: - Any: the result of the OPGraph evaluation. - """ - self._current_inputset.append(deepcopy(args) if len(args) > 1 else deepcopy(args[0])) - - inferred_args = { - param_name: get_base_value_for_numpy_or_python_constant_data(val)( - is_encrypted=is_encrypted - ) - for (param_name, is_encrypted), val in zip( - self._function_parameters_encrypted_status.items(), args - ) - } - - if len(self._current_inputset) >= self.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE: - self._eval_on_current_inputset() - - self._trace_op_graph_if_needed(inferred_args) - - # For mypy - assert self._op_graph is not None - return self._op_graph(*args) - - def __str__(self) -> str: - self._eval_on_current_inputset() - if self._op_graph is None: - warning_msg = ( - f"__str__ failed: OPGraph is None, {self.__class__.__name__} " - "needs evaluation on an inputset" - ) - logger.warning(warning_msg) - return warning_msg - return format_operation_graph(self._op_graph) - - def draw_graph( - self, - show: bool = False, - vertical: bool = True, - save_to: Optional[Path] = None, - ) -> Optional[str]: - """Draws operation graphs and optionally saves/shows the drawing. - - Args: - op_graph (OPGraph): the operation graph to be drawn and optionally saved/shown - show (bool): if set to True, the drawing will be shown using matplotlib - vertical (bool): if set to True, the orientation will be vertical - save_to (Optional[Path]): if specified, the drawn graph will be saved to this path; else - it is saved in a temporary file - - Returns: - Optional[str]: if OPGraph was not None returns the path as a string of the file where - the drawn graph is saved - """ - self._eval_on_current_inputset() - if self._op_graph is None: - logger.warning( - f"{self.draw_graph.__name__} failed: OPGraph is None, {self.__class__.__name__} " - "needs evaluation on an inputset" - ) - return None - return draw_graph(self._op_graph, show, vertical, save_to) - - def eval_on_inputset( - self, - inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]], - warn_on_inputset_length: bool = False, - ) -> None: - """Evaluate the underlying function on an inputset in one go, populates OPGraph and bounds. - - Args: - inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset on which the - function should be evaluated. - warn_on_inputset_length (bool, optional): Set to True to get a warning - if inputset is not long enough. Defaults to False. - """ - - inputset_iter = iter(inputset) - try: - first_sample = next(inputset_iter) - except StopIteration: - return - - inferred_args = { - param_name: get_base_value_for_numpy_or_python_constant_data(val)( - is_encrypted=is_encrypted - ) - for (param_name, is_encrypted), val in zip( - self._function_parameters_encrypted_status.items(), - first_sample - if len(self._function_parameters_encrypted_status) > 1 - else (first_sample,), - ) - } - - self._trace_op_graph_if_needed(inferred_args) - - # For mypy - assert self._op_graph is not None - - self._patch_op_graph_input_to_accept_any_integer_input() - - self._nodes_and_bounds = measure_op_graph_bounds_and_update( - self._op_graph, - inferred_args, - itertools.chain((first_sample,), inputset_iter), - self._compilation_configuration, - self.compilation_artifacts, - self._nodes_and_bounds, - warn_on_inputset_length, - ) - - def _eval_on_current_inputset(self) -> None: - """Evaluate OPGraph on _current_inputset.""" - self.eval_on_inputset(self._current_inputset) - self._current_inputset.clear() - - def _needs_tracing(self) -> bool: - """Return whether we need to trace the function and populate the OPGraph.""" - return self._op_graph is None - - def _trace_op_graph_if_needed(self, inferred_args: Dict[str, BaseValue]) -> None: - """Populate _op_graph with the OPGraph for _function_to_compile.""" - if not self._needs_tracing(): - return - - self._op_graph = compile_numpy_function_into_op_graph( - self.function_to_compile, - inferred_args, - self._compilation_configuration, - self.compilation_artifacts, - ) - - def _patch_op_graph_input_to_accept_any_integer_input(self) -> None: - """Patch inputs as we don't know what data we expect.""" - - # Can only do that if the OPGraph was created hence the test. - if self._needs_tracing(): - return - - # For mypy - assert self._op_graph is not None - - # Cheat on Input nodes to avoid issues during inputset eval as we do not know in advance - # what the final bit width for the inputs should be - for node in self._op_graph.input_nodes.values(): - for input_ in node.inputs: - if isinstance(dtype := (input_.dtype), Integer): - dtype.bit_width = 128 - dtype.is_signed = True - - def compile_on_inputset( - self, - inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]], - show_mlir: bool = False, - ) -> FHECircuit: - """Compile the function on an inputset and get resulting FHECircuit. - - Args: - inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): - The inputset on which the function is evaluated. - show_mlir (bool, optional, defaults to False): - The flag to enable printing the MLIR that is being compiled for debugging purposes. - - Returns: - FHECircuit: the compiled FHECircuit - """ - - self.eval_on_inputset(inputset) - return self.get_compiled_fhe_circuit(show_mlir) - - def get_compiled_fhe_circuit(self, show_mlir: bool = False) -> FHECircuit: - """Return a compiled FHECircuit if the instance was evaluated on an inputset. - - Args: - show_mlir (bool, optional): if set, the MLIR produced by the converter and which is - going to be sent to the compiler backend is shown on the screen, e.g., for debugging - or demo. Defaults to False. - - Raises: - RuntimeError: raised if no inputset was passed to the instance. - - Returns: - FHECircuit: the compiled FHECircuit - """ - self._eval_on_current_inputset() - - if self._op_graph is None: - raise RuntimeError( - "Requested FHECircuit but no OPGraph was compiled. " - f"Did you forget to evaluate {self.__class__.__name__} over an inputset?" - ) - - return compile_op_graph_to_fhe_circuit( - self._op_graph, - show_mlir, - self.compilation_configuration, - self.compilation_artifacts, - ) diff --git a/concrete/numpy/np_indexing_helpers.py b/concrete/numpy/np_indexing_helpers.py deleted file mode 100644 index 945f66412..000000000 --- a/concrete/numpy/np_indexing_helpers.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Helpers for indexing with numpy values functionality.""" - -from typing import Any - -import numpy - - -def should_sanitize(indexing_element: Any) -> bool: - """Decide whether to sanitize an indexing element or not. - - Sanitizing in this context means converting supported numpy values into python values. - - Args: - indexing_element (Any): the indexing element to decide sanitization. - - Returns: - bool: True if indexing element should be sanitized otherwise False. - """ - - return isinstance(indexing_element, numpy.integer) or ( - isinstance(indexing_element, numpy.ndarray) - and issubclass(indexing_element.dtype.type, numpy.integer) - and indexing_element.shape == () - ) - - -def process_indexing_element(indexing_element: Any) -> Any: - """Process an indexing element. - - Processing in this context means converting supported numpy values into python values. - (if they are decided to be sanitized) - - Args: - indexing_element (Any): the indexing element to sanitize. - - Returns: - Any: the sanitized indexing element. - """ - - if isinstance(indexing_element, slice): - - start = indexing_element.start - if should_sanitize(start): - start = int(start) - - stop = indexing_element.stop - if should_sanitize(stop): - stop = int(stop) - - step = indexing_element.step - if should_sanitize(step): - step = int(step) - - indexing_element = slice(start, stop, step) - - elif should_sanitize(indexing_element): - indexing_element = int(indexing_element) - - return indexing_element diff --git a/concrete/numpy/np_inputset_helpers.py b/concrete/numpy/np_inputset_helpers.py deleted file mode 100644 index f5c6b88d8..000000000 --- a/concrete/numpy/np_inputset_helpers.py +++ /dev/null @@ -1,157 +0,0 @@ -"""Helpers for numpy inputset related functionality.""" - -import random -from typing import Any, Dict, Iterable, Tuple, Union - -import numpy - -from ..common.compilation import CompilationConfiguration -from ..common.data_types import Float, Integer -from ..common.values import BaseValue, TensorValue - - -def _generate_random_integer_scalar(dtype: Integer) -> int: - """Generate a random integer scalar. - - Args: - dtype (Integer): the data type to extract bounds - - Returns: - int: a random value within the range [dtype.min_value(), dtype.max_value()] - """ - - return random.randint(dtype.min_value(), dtype.max_value()) - - -def _generate_random_integer_tensor(dtype: Integer, shape: Tuple[int, ...]) -> numpy.ndarray: - """Generate a random integer tensor. - - Args: - dtype (Integer): the data type to extract bounds - shape (Tuple[int, ...]): the shape of the generated tensor - - Returns: - numpy.ndarray: a random array of the specified shape where each value of it - is within the range [dtype.min_value(), dtype.max_value()] - """ - - return numpy.random.randint( - dtype.min_value(), - dtype.max_value() + 1, - size=shape, - dtype=numpy.int64 if dtype.is_signed else numpy.uint64, # type: ignore - ) - - -def _generate_random_float_scalar() -> float: - """Generate a random float scalar. - - Returns: - float: a random value within the range [0, 1) - """ - - return random.random() - - -def _generate_random_float_tensor(dtype: Float, shape: Tuple[int, ...]) -> numpy.ndarray: - """Generate a random float tensor. - - Args: - dtype (Integer): the data type to extract resulting numpy data type - shape (Tuple[int, ...]): the shape of the generated tensor - - Returns: - numpy.ndarray: a random array of the specified shape where each value of it - is within the range [0, 1) - """ - - result = numpy.random.rand(*shape) - return result.astype(numpy.float32 if dtype.bit_width == 32 else numpy.float64) - - -def _generate_random_inputset( - function_parameters: Dict[str, BaseValue], - compilation_configuration: CompilationConfiguration, -) -> Union[Iterable[Any], Iterable[Tuple[Any, ...]]]: - """Generate a random inputset from function parameters. - - Using this function is not a good practice since the randomly generated inputset - might not reflect real world data. We have it to speed up our development workflow - and we also don't use it in any of our tests, benchmarks, or examples. - - Args: - function_parameters (Dict[str, BaseValue]): the function parameters - to extract data types and shapes - compilation_configuration (CompilationConfiguration): the compilation configuration - to extract the sample size of the resulting inputset - - Raises: - ValueError: if the provided function arguments cannot be used for random inputset generation - - Returns: - Union[Iterable[Any], Iterable[Tuple[Any, ...]]]: the inputset - """ - - inputset = [] - for _ in range(compilation_configuration.random_inputset_samples): - sample = [] - for parameter in function_parameters.values(): - if not isinstance(parameter, TensorValue): - raise ValueError(f"Random inputset cannot be generated for {parameter} parameters") - - if isinstance(parameter.dtype, Integer): - sample.append( - _generate_random_integer_scalar(parameter.dtype) - if parameter.is_scalar - else _generate_random_integer_tensor(parameter.dtype, parameter.shape) - ) - elif isinstance(parameter.dtype, Float): - sample.append( - _generate_random_float_scalar() - if parameter.is_scalar - else _generate_random_float_tensor(parameter.dtype, parameter.shape) - ) - else: - raise ValueError( - f"Random inputset cannot be generated " - f"for parameters of type {parameter.dtype}" - ) - inputset.append(tuple(sample) if len(sample) > 1 else sample[0]) - return inputset - - -def _check_special_inputset_availability( - inputset: str, - compilation_configuration: CompilationConfiguration, -): - """Check special inputset is valid and is available. - - This function makes sure the provided special inputset is valid and can be used with the - provided compilation configuration. - - Currently, the only special inputset is "random" but this can be extended in the future. - - Args: - inputset (str): the special inputset to check - compilation_configuration (CompilationConfiguration): the compilation configuration - to check the availability of the provided special inputset - - Raises: - ValueError: if the provided special inputset is not valid - RuntimeError: if the provided special inputset is not available - - Returns: - None - """ - - if inputset != "random": - raise ValueError( - f"inputset can only be an iterable of tuples or the string 'random' " - f"but you specified '{inputset}' for it" - ) - - if not compilation_configuration.enable_unsafe_features: - raise RuntimeError( - "Random inputset generation is an unsafe feature and should not be used " - "if you don't know what you are doing" - ) diff --git a/concrete/numpy/np_mlir_converter.py b/concrete/numpy/np_mlir_converter.py deleted file mode 100644 index 0ffe00327..000000000 --- a/concrete/numpy/np_mlir_converter.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Numpy-specific MLIR converter.""" - -import math -from collections import defaultdict -from itertools import product -from typing import Any, DefaultDict, Dict, List, Tuple - -import numpy - -from ..common.debugging import assert_true -from ..common.mlir.graph_converter import OPGraphConverter -from ..common.operator_graph import OPGraph -from ..common.representation.intermediate import GenericFunction, IntermediateNode - - -class HashableNPArray: - """Class to easily manipulate numpy arrays for hashing. - - Note that the hash behavior won't work if the array is modified after being hashed, as it will - have been hashed to a certain value and the new array content will be hashed to a different one. - """ - - array: numpy.ndarray - - def __init__(self, array: numpy.ndarray) -> None: - self.array = array - - def __hash__(self) -> int: - return hash(self.array.tobytes()) - - def __eq__(self, other: object) -> bool: - return isinstance(other, HashableNPArray) and numpy.array_equal(self.array, other.array) - - -def generate_deduplicated_tables( - node: GenericFunction, ordered_preds: List[IntermediateNode] -) -> Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]: - """Deduplicate the tables for the different cells of a tensor if needed. - - Args: - node (GenericFunction): the node for which to deduplicate the table. - ordered_preds (List[IntermediateNode]): ordered list of predecessors of the node. - - Returns: - Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]: A tuple containing tuples whose - first element is a table and the second element is a list of tuples indicating which - cells in the tensor will use that table. - """ - # This is the tensor containing the tables for each cell of the tensor for node - node_complete_table = numpy.concatenate( - tuple(numpy.expand_dims(array, -1) for array in node.get_table(ordered_preds)), axis=-1 - ) - - all_cells_idx = product(*tuple(range(max_val) for max_val in node_complete_table.shape[:-1])) - tables_to_cell_idx: DefaultDict[HashableNPArray, List[Tuple[int, ...]]] = defaultdict(list) - idx: Tuple[int, ...] - all_idx_set = set() - for idx in all_cells_idx: - hashable_array = HashableNPArray(node_complete_table[idx]) - tables_to_cell_idx[hashable_array].append(idx) - all_idx_set.add(idx) - - assert_true(len(all_idx_set) == math.prod(node_complete_table.shape[:-1])) - - return tuple( - (hashable_array.array, indices) for hashable_array, indices in tables_to_cell_idx.items() - ) - - -class NPMLIRConverter(OPGraphConverter): - """Numpy-specific MLIR converter.""" - - @staticmethod - def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]: - additional_conversion_info = {} - - # Disable numpy warnings during conversion to avoid issues during TLU generation - with numpy.errstate(all="ignore"): - additional_conversion_info["tables"] = { - node: generate_deduplicated_tables(node, op_graph.get_ordered_preds(node)) - for node in op_graph.graph.nodes() - if isinstance(node, GenericFunction) and node.op_kind == "TLU" - } - - return additional_conversion_info diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py deleted file mode 100644 index 17356537f..000000000 --- a/concrete/numpy/tracing.py +++ /dev/null @@ -1,818 +0,0 @@ -"""numpy tracing utilities.""" -from copy import deepcopy -from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast - -import numpy -from numpy.typing import DTypeLike - -from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype -from ..common.debugging.custom_assert import assert_true -from ..common.operator_graph import OPGraph -from ..common.representation.intermediate import Constant, Dot, GenericFunction, MatMul -from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters -from ..common.tracing.tracing_helpers import tracing_context -from ..common.values import BaseValue, TensorValue -from .np_dtypes_helpers import ( - SUPPORTED_NUMPY_DTYPES_CLASS_TYPES, - convert_numpy_dtype_to_base_data_type, - get_base_value_for_numpy_or_python_constant_data, - get_numpy_function_output_dtype_and_shape_from_input_tracers, -) -from .np_indexing_helpers import process_indexing_element - -SUPPORTED_TYPES_FOR_TRACING = (int, float, numpy.ndarray) + tuple( - SUPPORTED_NUMPY_DTYPES_CLASS_TYPES -) - -NPConstant = partial( - Constant, - get_base_value_for_data_func=get_base_value_for_numpy_or_python_constant_data, -) - - -class NPTracer(BaseTracer): - """Tracer class for numpy operations.""" - - _mix_values_func: Callable[..., BaseValue] = mix_values_determine_holding_dtype - - def __array_ufunc__(self, ufunc: numpy.ufunc, method, *args, **kwargs): - """Catch calls to numpy ufunc and routes them to tracing functions if supported. - - Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch - """ - if method == "__call__": - tracing_func = self.get_tracing_func_for_np_function(ufunc) - assert_true( - (len(kwargs) == 0), - f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc.__name__}", - ) - - # Create constant tracers for args, numpy only passes ufunc.nin args so we can - # sanitize all of them without issues - sanitized_args = [self._sanitize(arg) for arg in args] - return tracing_func(*sanitized_args, **kwargs) - raise NotImplementedError("Only __call__ method is supported currently") - - def __array_function__(self, func, _types, args, kwargs): - """Catch calls to numpy function in routes them to tracing functions if supported. - - Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch - """ - tracing_func = self.get_tracing_func_for_np_function(func) - assert_true( - (tracing_func in [NPTracer.numpy_sum, NPTracer.numpy_concatenate]) or len(kwargs) == 0, - f"**kwargs are currently not supported for numpy functions, func: {func}", - ) - - # Fixme: Special case to be removed once #772 is done - if func is not numpy.reshape: - sanitized_args = [self._sanitize(arg) for arg in args] - else: - # In numpy.reshape, the second argument is the new shape - sanitized_args = [self._sanitize(args[0]), args[1]] - return tracing_func(self, sanitized_args[0], sanitized_args[1], **kwargs) - - return tracing_func(self, *sanitized_args, **kwargs) - - def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer": - r"""Support numpy astype feature. - - For now it only accepts a dtype and no additional parameters, \*args and - \*\*kwargs are accepted for interface compatibility only - - Args: - numpy_dtype (DTypeLike): The object describing a numpy type - - Returns: - NPTracer: The NPTracer representing the casting operation - """ - assert_true( - len(args) == 0, f"astype currently only supports tracing without *args, got {args}" - ) - assert_true( - (len(kwargs) == 0), - f"astype currently only supports tracing without **kwargs, got {kwargs}", - ) - - normalized_numpy_dtype = numpy.dtype(numpy_dtype) - output_dtype = convert_numpy_dtype_to_base_data_type(numpy_dtype) - generic_function_output_value = deepcopy(self.output) - generic_function_output_value.dtype = output_dtype - traced_computation = GenericFunction( - inputs=[self.output], - arbitrary_func=lambda x, dtype: x.astype(dtype), - output_value=generic_function_output_value, - op_kind="TLU", - op_kwargs={"dtype": normalized_numpy_dtype.type}, - op_name="astype", - ) - output_tracer = self.__class__([self], traced_computation=traced_computation, output_idx=0) - return output_tracer - - @staticmethod - def get_tracing_func_for_np_function(func: Union[numpy.ufunc, Callable]) -> Callable: - """Get the tracing function for a numpy function. - - Args: - func (Union[numpy.ufunc, Callable]): The numpy function that will be traced - - Raises: - NotImplementedError: Raised if the passed function is not supported by NPTracer - - Returns: - Callable: the tracing function that needs to be called to trace func - """ - tracing_func: Optional[Callable] - - # numpy.invert is not great in term of types it supports, so we've decided not to support it - # and to propose to the user to use numpy.bitwise_not - if func == numpy.invert: - raise RuntimeError( - f"NPTracer does not manage the following func: {func.__name__}. Please replace by " - f"calls to bitwise_xor with appropriate mask" - ) - - if isinstance(func, numpy.ufunc): - tracing_func = NPTracer.UFUNC_ROUTING.get(func, None) - else: - tracing_func = NPTracer.FUNC_ROUTING.get(func, None) - - if tracing_func is None: - raise NotImplementedError( - f"NPTracer does not yet manage the following func: {func.__name__}" - ) - return tracing_func - - def _supports_other_operand(self, other: Any) -> bool: - return super()._supports_other_operand(other) or isinstance( - other, SUPPORTED_TYPES_FOR_TRACING - ) - - def _make_const_input_tracer(self, constant_data: Any) -> "NPTracer": - return self.__class__([], NPConstant(constant_data), 0) - - @classmethod - def _np_operator( - cls, - numpy_operator, - numpy_operator_string, - numpy_operator_nin, - *input_tracers: "NPTracer", - **kwargs, - ) -> "NPTracer": - """Trace a numpy operator. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - assert_true(len(input_tracers) == numpy_operator_nin) - - common_output_dtypes_and_shapes = ( - get_numpy_function_output_dtype_and_shape_from_input_tracers( - numpy_operator, - *input_tracers, - ) - ) - assert_true(len(common_output_dtypes_and_shapes) == 1) - - variable_input_indices = [ - idx - for idx, pred in enumerate(input_tracers) - if not isinstance(pred.traced_computation, Constant) - ] - assert_true( - (non_constant_pred_count := len(variable_input_indices)) == 1, - f"Can only have 1 non constant predecessor in {cls._np_operator.__name__}, " - f"got {non_constant_pred_count} for operator {numpy_operator}", - ) - - variable_input_idx = variable_input_indices[0] - output_dtype, output_shape = common_output_dtypes_and_shapes[0] - - generic_function_output_value = TensorValue( - output_dtype, - input_tracers[variable_input_idx].output.is_encrypted, - output_shape, - ) - - op_kwargs = deepcopy(kwargs) - - traced_computation = GenericFunction( - inputs=[input_tracer.output for input_tracer in input_tracers], - arbitrary_func=numpy_operator, - output_value=generic_function_output_value, - op_kind="TLU", - op_kwargs=op_kwargs, - op_name=numpy_operator_string, - ) - output_tracer = cls( - input_tracers, - traced_computation=traced_computation, - output_idx=0, - ) - return output_tracer - - def numpy_dot(self, *args: "NPTracer", **_kwargs) -> "NPTracer": - """Trace numpy.dot. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - assert_true((num_args := len(args)) == 2, f"dot expects 2 inputs got {num_args}") - - common_output_dtypes_and_shapes = ( - get_numpy_function_output_dtype_and_shape_from_input_tracers(numpy.dot, *args) - ) - assert_true(len(common_output_dtypes_and_shapes) == 1) - - traced_computation = Dot( - [input_tracer.output for input_tracer in args], - common_output_dtypes_and_shapes[0][0], - delegate_evaluation_function=numpy.dot, - ) - - output_tracer = self.__class__( - args, - traced_computation=traced_computation, - output_idx=0, - ) - return output_tracer - - def clip(self, *args: Union["NPTracer", Any], **kwargs) -> "NPTracer": - """Trace x.clip. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - sanitized_args = [cast(NPTracer, self._sanitize(arg)) for arg in args] - return self.numpy_clip(self, *sanitized_args, **kwargs) - - def numpy_clip(self, *args: "NPTracer", **kwargs) -> "NPTracer": - """Trace numpy.clip. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - return self._np_operator(numpy.clip, "clip", 3, *args, **kwargs) - - def dot(self, *args: "NPTracer", **kwargs) -> "NPTracer": - """Trace x.dot. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - assert len(args) == 1 - arg0 = self._sanitize(args[0]) - assert_true(isinstance(arg0, NPTracer)) - arg0 = cast(NPTracer, arg0) - return self.numpy_dot(self, arg0, **kwargs) - - def transpose(self, *args: "NPTracer", **kwargs) -> "NPTracer": - """Trace x.transpose. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - return self.numpy_transpose(self, *args, **kwargs) - - def numpy_transpose(self, *args: "NPTracer", **kwargs) -> "NPTracer": - """Trace numpy.transpose. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - assert_true((num_args := len(args)) == 1, f"transpose expect 1 input got {num_args}") - - first_arg_output = args[0].output - assert_true(isinstance(first_arg_output, TensorValue)) - first_arg_output = cast(TensorValue, first_arg_output) - - transpose_is_fusable = first_arg_output.is_scalar or first_arg_output.ndim == 1 - - out_dtype = first_arg_output.dtype - out_shape = first_arg_output.shape[::-1] - - generic_function_output_value = TensorValue( - out_dtype, - first_arg_output.is_encrypted, - out_shape, - ) - - traced_computation = GenericFunction( - inputs=[first_arg_output], - arbitrary_func=numpy.transpose, - output_value=generic_function_output_value, - op_kind="Memory", - op_kwargs=deepcopy(kwargs), - op_name="transpose", - op_attributes={"fusable": transpose_is_fusable}, - ) - output_tracer = self.__class__( - args, - traced_computation=traced_computation, - output_idx=0, - ) - return output_tracer - - def ravel(self, *args: "NPTracer", **kwargs) -> "NPTracer": - """Trace x.ravel. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - return self.numpy_ravel(self, *args, **kwargs) - - def numpy_ravel(self, *args: "NPTracer", **kwargs) -> "NPTracer": - """Trace numpy.ravel. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - assert_true((num_args := len(args)) == 1, f"ravel expect 1 input got {num_args}") - - first_arg_output = args[0].output - assert_true(isinstance(first_arg_output, TensorValue)) - first_arg_output = cast(TensorValue, first_arg_output) - - ravel_is_fusable = first_arg_output.ndim == 1 - - out_dtype = first_arg_output.dtype - out_shape = (1,) if first_arg_output.is_scalar else (numpy.product(first_arg_output.shape),) - - generic_function_output_value = TensorValue( - out_dtype, - first_arg_output.is_encrypted, - out_shape, - ) - - traced_computation = GenericFunction( - inputs=[first_arg_output], - arbitrary_func=numpy.ravel, - output_value=generic_function_output_value, - op_kind="Memory", - op_kwargs=deepcopy(kwargs), - op_name="ravel", - op_attributes={"fusable": ravel_is_fusable}, - ) - output_tracer = self.__class__( - args, - traced_computation=traced_computation, - output_idx=0, - ) - return output_tracer - - def reshape(self, newshape: Tuple[Any, ...], **kwargs) -> "NPTracer": - """Trace x.reshape. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - return self.numpy_reshape(self, newshape, **kwargs) - - def numpy_reshape(self, arg0: "NPTracer", arg1: Tuple[Any, ...], **kwargs) -> "NPTracer": - """Trace numpy.reshape. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - - # FIXME: #772, restore reshape(self, *args, **kwargs) signature when possible, with mypy - # types - - # FIXME: #772, restore - # assert_true((num_args := len(args)) == 2, f"reshape expect 2 input got {num_args}") - # when possible - - assert_true((num_kwargs := len(kwargs)) == 0, f"reshape expect 0 kwargs got {num_kwargs}") - - first_arg_output = arg0.output - assert_true(isinstance(first_arg_output, TensorValue)) - first_arg_output = cast(TensorValue, first_arg_output) - - try: - # calculate a newshape using numpy to handle edge cases such as `-1`s within new shape - newshape = numpy.zeros(first_arg_output.shape).reshape(arg1).shape - except Exception as error: - raise ValueError( - f"shapes are not compatible (old shape {first_arg_output.shape}, new shape {arg1})" - ) from error - - reshape_is_fusable = newshape == first_arg_output.shape - - out_dtype = first_arg_output.dtype - out_shape = newshape - - generic_function_output_value = TensorValue( - out_dtype, - first_arg_output.is_encrypted, - out_shape, - ) - - traced_computation = GenericFunction( - inputs=[first_arg_output], - arbitrary_func=numpy.reshape, - output_value=generic_function_output_value, - op_kind="Memory", - op_kwargs={"newshape": newshape}, - op_name="reshape", - op_attributes={"fusable": reshape_is_fusable}, - ) - output_tracer = self.__class__( - [arg0], - traced_computation=traced_computation, - output_idx=0, - ) - return output_tracer - - def flatten(self, *args: "NPTracer", **kwargs) -> "NPTracer": - """Trace x.flatten. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - assert_true((num_args := len(args)) == 0, f"flatten expect 0 input got {num_args}") - - first_arg_output = self.output - assert_true(isinstance(first_arg_output, TensorValue)) - first_arg_output = cast(TensorValue, first_arg_output) - - flatten_is_fusable = first_arg_output.ndim == 1 - - out_dtype = first_arg_output.dtype - out_shape = (1,) if first_arg_output.is_scalar else (numpy.product(first_arg_output.shape),) - - generic_function_output_value = TensorValue( - out_dtype, - first_arg_output.is_encrypted, - out_shape, - ) - - traced_computation = GenericFunction( - inputs=[first_arg_output], - arbitrary_func=lambda x: x.flatten(), - output_value=generic_function_output_value, - op_kind="Memory", - op_kwargs=deepcopy(kwargs), - op_name="flatten", - op_attributes={"fusable": flatten_is_fusable}, - ) - output_tracer = self.__class__( - [self], - traced_computation=traced_computation, - output_idx=0, - ) - return output_tracer - - def numpy_sum(self, inp: "NPTracer", **kwargs) -> "NPTracer": - """Trace numpy.sum. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - - input_value = inp.output - - def supported(value): - if not value.is_encrypted or not isinstance(input_value, TensorValue): - return False - - value = cast(TensorValue, value) - if value.shape == (): - return False - - return True - - if not supported(input_value): - raise ValueError( - f"only encrypted tensor sum is supported but you tried to sum {input_value}" - ) - - try: - # calculate a newshape using numpy to handle all cases - newshape = numpy.sum(numpy.zeros(input_value.shape), **kwargs).shape # type: ignore - except Exception as error: - raise ValueError( - f"invalid sum on {input_value} with " - f"{', '.join('='.join([key, str(value)]) for key, value in kwargs.items())}" - ) from error - - output_value = TensorValue( - input_value.dtype, - input_value.is_encrypted, - newshape, - ) - traced_computation = GenericFunction( - inputs=[input_value], - arbitrary_func=numpy.sum, - output_value=output_value, - op_kind="Memory", - op_kwargs=kwargs, - op_name="sum", - op_attributes={"fusable": False}, - ) - output_tracer = self.__class__( - [inp], - traced_computation=traced_computation, - output_idx=0, - ) - return output_tracer - - def numpy_concatenate(self, inputs: Tuple["NPTracer", ...], **kwargs) -> "NPTracer": - """Trace numpy.concatenate. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - - input_values = [tracer.output for tracer in inputs] - - def supported(values): - if any( - not value.is_encrypted or not isinstance(value, TensorValue) for value in values - ): - return False - - values = [cast(TensorValue, value) for value in values] - if any(value.shape == () for value in values): - return False - - return True - - if not supported(input_values): - raise ValueError( - f"only encrypted tensor concatenation is supported " - f"but you tried to concatenate " - f"{', '.join(str(input_value) for input_value in input_values)}" - ) - - input_tensor_values = [cast(TensorValue, value) for value in input_values] - - try: - # calculate a newshape using numpy to handle all cases - sample = tuple(numpy.zeros(input_value.shape) for input_value in input_tensor_values) - newshape = numpy.concatenate(sample, **kwargs).shape - except Exception as error: - kwarg_info = "" - if len(kwargs) != 0: - kwarg_info += " with " - kwarg_info += ", ".join( - "=".join([key, str(value)]) for key, value in kwargs.items() - ) - - raise ValueError( - f"invalid concatenation of " - f"{', '.join(str(input_value) for input_value in input_values)}{kwarg_info}" - ) from error - - output_value = TensorValue( - input_tensor_values[0].dtype, - input_tensor_values[0].is_encrypted, - newshape, - ) - traced_computation = GenericFunction( - inputs=input_values, - arbitrary_func=numpy.concatenate, - output_value=output_value, - op_kind="Memory", - op_kwargs=kwargs, - op_name="concat", - op_attributes={"fusable": False}, - ) - output_tracer = self.__class__( - list(inputs), - traced_computation=traced_computation, - output_idx=0, - ) - return output_tracer - - def __getitem__(self, item): - if isinstance(item, tuple): - item = tuple(process_indexing_element(indexing_element) for indexing_element in item) - else: - item = process_indexing_element(item) - - return BaseTracer.__getitem__(self, item) - - def __matmul__(self, other): - """Trace numpy.matmul.""" - return self.__array_ufunc__(numpy.matmul, "__call__", self, other) - - # Supported functions are either univariate or bivariate for which one of the two - # sources is a constant - # - # numpy.add, numpy.multiply and numpy.subtract are not there since already managed - # by leveled operations - # - # numpy.conjugate is not there since working on complex numbers - # - # numpy.isnat is not there since it is about timings - # - # numpy.divmod, numpy.modf and numpy.frexp are not there since output two values - # - # numpy.invert (as known as numpy.bitwise_not) is not here, because it has strange input type. - # We ask the user to replace bitwise_xor instead - LIST_OF_SUPPORTED_UFUNC: List[numpy.ufunc] = [ - numpy.absolute, - numpy.arccos, - numpy.arccosh, - numpy.arcsin, - numpy.arcsinh, - numpy.arctan, - numpy.arctan2, - numpy.arctanh, - numpy.bitwise_and, - numpy.bitwise_or, - numpy.bitwise_xor, - numpy.cbrt, - numpy.ceil, - numpy.copysign, - numpy.cos, - numpy.cosh, - numpy.deg2rad, - numpy.degrees, - numpy.equal, - numpy.exp, - numpy.exp2, - numpy.expm1, - numpy.fabs, - numpy.float_power, - numpy.floor, - numpy.floor_divide, - numpy.fmax, - numpy.fmin, - numpy.fmod, - numpy.gcd, - numpy.greater, - numpy.greater_equal, - numpy.heaviside, - numpy.hypot, - numpy.isfinite, - numpy.isinf, - numpy.isnan, - numpy.lcm, - numpy.ldexp, - numpy.left_shift, - numpy.less, - numpy.less_equal, - numpy.log, - numpy.log10, - numpy.log1p, - numpy.log2, - numpy.logaddexp, - numpy.logaddexp2, - numpy.logical_and, - numpy.logical_not, - numpy.logical_or, - numpy.logical_xor, - numpy.maximum, - numpy.minimum, - numpy.negative, - numpy.nextafter, - numpy.not_equal, - numpy.positive, - numpy.power, - numpy.rad2deg, - numpy.radians, - numpy.reciprocal, - numpy.remainder, - numpy.right_shift, - numpy.rint, - numpy.sign, - numpy.signbit, - numpy.sin, - numpy.sinh, - numpy.spacing, - numpy.sqrt, - numpy.square, - numpy.tan, - numpy.tanh, - numpy.true_divide, - numpy.trunc, - ] - - # We build UFUNC_ROUTING dynamically after the creation of the class, - # because of some limits of python or our unability to do it properly - # in the class with techniques which are compatible with the different - # coding checks we use - UFUNC_ROUTING: Dict[numpy.ufunc, Callable] = {} - - FUNC_ROUTING: Dict[Callable, Callable] = { - numpy.dot: numpy_dot, - numpy.transpose: numpy_transpose, - numpy.reshape: numpy_reshape, - numpy.ravel: numpy_ravel, - numpy.clip: numpy_clip, - numpy.sum: numpy_sum, - numpy.concatenate: numpy_concatenate, - } - - -def _get_unary_fun(function: numpy.ufunc): - """Wrap _unary_operator in a lambda to populate NPTRACER.UFUNC_ROUTING.""" - - # We have to access this method to be able to build NPTracer.UFUNC_ROUTING - # dynamically - # pylint: disable=protected-access - return lambda *input_tracers, **kwargs: NPTracer._np_operator( - function, f"{function.__name__}", 1, *input_tracers, **kwargs - ) - # pylint: enable=protected-access - - -def _get_binary_fun(function: numpy.ufunc): - """Wrap _binary_operator in a lambda to populate NPTRACER.UFUNC_ROUTING.""" - - # We have to access this method to be able to build NPTracer.UFUNC_ROUTING - # dynamically - # pylint: disable=protected-access - return lambda *input_tracers, **kwargs: NPTracer._np_operator( - function, f"{function.__name__}", 2, *input_tracers, **kwargs - ) - # pylint: enable=protected-access - - -# We are populating NPTracer.UFUNC_ROUTING dynamically -NPTracer.UFUNC_ROUTING = { - fun: _get_unary_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC if fun.nin == 1 -} - -NPTracer.UFUNC_ROUTING.update( - {fun: _get_binary_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC if fun.nin == 2} -) - -list_of_not_supported = [ - (ufunc.__name__, ufunc.nin) - for ufunc in NPTracer.LIST_OF_SUPPORTED_UFUNC - if ufunc.nin not in [1, 2] -] - -assert_true(len(list_of_not_supported) == 0, f"Not supported nin's, {list_of_not_supported}") -del list_of_not_supported - -# We are adding initial support for `np.array(...)` +,-,* `BaseTracer` -# (note that this is not the proper complete handling of these functions) - - -def _on_numpy_add(lhs, rhs): - return lhs.__add__(rhs) - - -def _on_numpy_subtract(lhs, rhs): - return lhs.__sub__(rhs) - - -def _on_numpy_multiply(lhs, rhs): - return lhs.__mul__(rhs) - - -def _on_numpy_matmul(lhs: NPTracer, rhs: NPTracer): - common_output_dtypes_and_shapes = get_numpy_function_output_dtype_and_shape_from_input_tracers( - numpy.matmul, lhs, rhs - ) - assert_true(len(common_output_dtypes_and_shapes) == 1) - - output_shape = common_output_dtypes_and_shapes[0][1] - traced_computation = MatMul( - [lhs.output, rhs.output], - common_output_dtypes_and_shapes[0][0], - output_shape, - ) - return NPTracer([lhs, rhs], traced_computation, output_idx=0) - - -NPTracer.UFUNC_ROUTING[numpy.add] = _on_numpy_add -NPTracer.UFUNC_ROUTING[numpy.subtract] = _on_numpy_subtract -NPTracer.UFUNC_ROUTING[numpy.multiply] = _on_numpy_multiply -NPTracer.UFUNC_ROUTING[numpy.matmul] = _on_numpy_matmul - - -def trace_numpy_function( - function_to_trace: Callable, function_parameters: Dict[str, BaseValue] -) -> OPGraph: - """Trace a numpy function. - - Args: - function_to_trace (Callable): The function you want to trace - function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the - function is e.g. an EncryptedScalar holding a 7bits unsigned Integer - - Returns: - OPGraph: The graph containing the ir nodes representing the computation done in the input - function - """ - function_parameters = prepare_function_parameters(function_to_trace, function_parameters) - - input_tracers = make_input_tracers(NPTracer, function_parameters) - - # We could easily create a graph of NPTracer, but we may end up with dead nodes starting from - # the inputs that's why we create the graph starting from the outputs - with tracing_context([NPTracer]): - output_tracers = function_to_trace(**input_tracers) - - if isinstance(output_tracers, NPTracer): - output_tracers = (output_tracers,) - - op_graph = OPGraph.from_output_tracers(output_tracers) - - return op_graph diff --git a/tests/common/bounds_measurement/test_inputset_eval.py b/tests/common/bounds_measurement/test_inputset_eval.py deleted file mode 100644 index c917a3434..000000000 --- a/tests/common/bounds_measurement/test_inputset_eval.py +++ /dev/null @@ -1,533 +0,0 @@ -"""Test file for bounds evaluation with a inputset""" - -from typing import Tuple - -import numpy as np -import pytest - -from concrete.common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_inputset -from concrete.common.compilation import CompilationConfiguration -from concrete.common.data_types.floats import Float -from concrete.common.data_types.integers import Integer, UnsignedInteger -from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor -from concrete.numpy.compile import numpy_max_func, numpy_min_func -from concrete.numpy.np_dtypes_helpers import get_base_value_for_numpy_or_python_constant_data -from concrete.numpy.tracing import trace_numpy_function - - -@pytest.mark.parametrize( - "function,input_ranges,expected_output_bounds,expected_output_data_type", - [ - pytest.param( - lambda x, y: x + y, - ((-10, 10), (-10, 10)), - (-20, 20), - Integer(6, is_signed=True), - id="x + y, (-10, 10), (-10, 10), (-20, 20)", - ), - pytest.param( - lambda x, y: x + y, - ((-10, 2), (-4, 5)), - (-14, 7), - Integer(5, is_signed=True), - id="x + y, (-10, 2), (-4, 5), (-14, 7)", - ), - pytest.param( - lambda x, y: x + y + 1.7, - ((-10, 2), (-4, 5)), - (-12.3, 8.7), - Float(64), - id="x + y + 1.7, (-10, 2), (-4, 5), (-12.3, 8.7)", - ), - pytest.param( - lambda x, y: x + y + 1, - ((-10, 2), (-4, 5)), - (-13, 8), - Integer(5, is_signed=True), - id="x + y + 1, (-10, 2), (-4, 5), (-13, 8)", - ), - pytest.param( - lambda x, y: x + y + (-3), - ((-10, 2), (-4, 5)), - (-17, 4), - Integer(6, is_signed=True), - id="x + y + 1, (-10, 2), (-4, 5), (-17, 4)", - ), - pytest.param( - lambda x, y: (1 + x) + y, - ((-10, 2), (-4, 5)), - (-13, 8), - Integer(5, is_signed=True), - id="(1 + x) + y, (-10, 2), (-4, 5), (-13, 8)", - ), - pytest.param( - lambda x, y: x - y, - ((-10, 10), (-10, 10)), - (-20, 20), - Integer(6, is_signed=True), - id="x - y, (-10, 10), (-10, 10), (-20, 20)", - ), - pytest.param( - lambda x, y: x - y, - ((-10, 2), (-4, 5)), - (-15, 6), - Integer(5, is_signed=True), - id="x - y, (-10, 2), (-4, 5), (-15, 6)", - ), - pytest.param( - lambda x, y: x - y - 42, - ((-10, 2), (-4, 5)), - (-57, -36), - Integer(7, is_signed=True), - id="x - y - 42, (-10, 2), (-4, 5), (-57, -36)", - ), - pytest.param( - lambda x, y: x - y - 41.5, - ((-10, 2), (-4, 5)), - (-56.5, -35.5), - Float(64), - id="x - y - 41.5, (-10, 2), (-4, 5), (-56.5, -35.5)", - ), - pytest.param( - lambda x, y: 3 - x + y, - ((-10, 2), (-4, 5)), - (-3, 18), - Integer(6, is_signed=True), - id="3 - x + y, (-10, 2), (-4, 5), (-3, 18)", - ), - pytest.param( - lambda x, y: 2.8 - x + y, - ((-10, 2), (-4, 5)), - (-3.2, 17.8), - Float(64), - id="2.8 - x + y, (-10, 2), (-4, 5), (-3.2, 17.8)", - ), - pytest.param( - lambda x, y: (-13) - x + y, - ((-10, 2), (-4, 5)), - (-19, 2), - Integer(6, is_signed=True), - id="(-13) - x + y, (-10, 2), (-4, 5), (-19, 2)", - ), - pytest.param( - lambda x, y: (-13.5) - x + y, - ((-10, 2), (-4, 5)), - (-19.5, 1.5), - Float(64), - id="(-13.5) - x + y, (-10, 2), (-4, 5), (-19.5, 1.5)", - ), - pytest.param( - lambda x, y: x * y, - ((-10, 10), (-10, 10)), - (-100, 100), - Integer(8, is_signed=True), - id="x * y, (-10, 10), (-10, 10), (-100, 100)", - ), - pytest.param( - lambda x, y: x * y, - ((-10, 2), (-4, 5)), - (-50, 40), - Integer(7, is_signed=True), - id="x * y, (-10, 2), (-4, 5), (-50, 40)", - ), - pytest.param( - lambda x, y: (3 * x) * y, - ((-10, 2), (-4, 5)), - (-150, 120), - Integer(9, is_signed=True), - id="(3 * x) * y, (-10, 2), (-4, 5), (-150, 120)", - ), - pytest.param( - lambda x, y: (3.0 * x) * y, - ((-10, 2), (-4, 5)), - (-150.0, 120.0), - Float(64), - id="(3.0 * x) * y, (-10, 2), (-4, 5), (-150.0, 120.0)", - ), - pytest.param( - lambda x, y: (x * 11) * y, - ((-10, 2), (-4, 5)), - (-550, 440), - Integer(11, is_signed=True), - id="x * y, (-10, 2), (-4, 5), (-550, 440)", - ), - pytest.param( - lambda x, y: (x * (-11)) * y, - ((-10, 2), (-4, 5)), - (-440, 550), - Integer(11, is_signed=True), - id="(x * (-11)) * y, (-10, 2), (-4, 5), (-440, 550)", - ), - pytest.param( - lambda x, y: (x * (-11.0)) * y, - ((-10, 2), (-4, 5)), - (-440.0, 550.0), - Float(64), - id="(x * (-11.0)) * y, (-10, 2), (-4, 5), (-440.0, 550.0)", - ), - pytest.param( - lambda x, y: x + x + y, - ((-10, 10), (-10, 10)), - (-30, 30), - Integer(6, is_signed=True), - id="x + x + y, (-10, 10), (-10, 10), (-30, 30)", - ), - pytest.param( - lambda x, y: x - x + y, - ((-10, 10), (-10, 10)), - (-10, 10), - Integer(5, is_signed=True), - id="x - x + y, (-10, 10), (-10, 10), (-10, 10)", - ), - pytest.param( - lambda x, y: x - x + y, - ((-10, 2), (-4, 5)), - (-4, 5), - Integer(4, is_signed=True), - id="x - x + y, (-10, 2), (-4, 5), (-4, 5)", - ), - pytest.param( - lambda x, y: x * y - x, - ((-10, 10), (-10, 10)), - (-110, 110), - Integer(8, is_signed=True), - id="x * y - x, (-10, 10), (-10, 10), (-110, 110)", - ), - pytest.param( - lambda x, y: x * y - x, - ((-10, 2), (-4, 5)), - (-40, 50), - Integer(7, is_signed=True), - id="x * y - x, (-10, 2), (-4, 5), (-40, 50),", - ), - pytest.param( - lambda x, y: (x * 3) * y - (x + 3) + (y - 13) + x * (11 + y) * (12 + y) + (15 - x), - ((-10, 2), (-4, 5)), - (-2846, 574), - Integer(13, is_signed=True), - id="x * y - x, (-10, 2), (-4, 5), (-2846, 574),", - ), - ], -) -def test_eval_op_graph_bounds_on_inputset( - function, - input_ranges, - expected_output_bounds, - expected_output_data_type: Integer, -): - """Test function for eval_op_graph_bounds_on_inputset""" - - test_eval_op_graph_bounds_on_inputset_multiple_output( - function, - input_ranges, - (expected_output_bounds,), - (expected_output_data_type,), - ) - - -@pytest.mark.parametrize( - "function,input_ranges,expected_output_bounds,expected_output_data_type", - [ - pytest.param( - lambda x, y: (x + 1, y + 10), - ((-1, 1), (3, 4)), - ((0, 2), (13, 14)), - (Integer(2, is_signed=False), Integer(4, is_signed=False)), - ), - pytest.param( - lambda x, y: (x + 1.5, y + 9.6), - ((-1, 1), (3, 4)), - ((0.5, 2.5), (12.6, 13.6)), - (Float(64), Float(64)), - ), - pytest.param( - lambda x, y: (x + y + 1, x * y + 42), - ((-1, 1), (3, 4)), - ((3, 6), (38, 46)), - (Integer(3, is_signed=False), Integer(6, is_signed=False)), - ), - pytest.param( - lambda x, y: (x + y + 0.4, x * y + 41.7), - ((-1, 1), (3, 4)), - ((2.4, 5.4), (37.7, 45.7)), - (Float(64), Float(64)), - ), - pytest.param( - lambda x, y: (x + y + 1, x * y + 41.7), - ((-1, 1), (3, 4)), - ((3, 6), (37.7, 45.7)), - (Integer(3, is_signed=False), Float(64)), - ), - pytest.param( - lambda x, y: (x + y + 0.4, x * y + 42), - ((-1, 1), (3, 4)), - ((2.4, 5.4), (38, 46)), - (Float(64), Integer(6, is_signed=False)), - ), - ], -) -def test_eval_op_graph_bounds_on_inputset_multiple_output( - function, - input_ranges, - expected_output_bounds, - expected_output_data_type: Tuple[Integer], -): - """Test function for eval_op_graph_bounds_on_inputset""" - - op_graph = trace_numpy_function( - function, {"x": EncryptedScalar(Integer(64, True)), "y": EncryptedScalar(Integer(64, True))} - ) - - def data_gen(range_x, range_y): - for x_gen in range_x: - for y_gen in range_y: - yield (x_gen, y_gen) - - _, node_bounds_and_samples = eval_op_graph_bounds_on_inputset( - op_graph, - data_gen(*tuple(range(x[0], x[1] + 1) for x in input_ranges)), - CompilationConfiguration(), - ) - - for i, output_node in op_graph.output_nodes.items(): - output_node_bounds = node_bounds_and_samples[output_node] - assert (output_node_bounds["min"], output_node_bounds["max"]) == expected_output_bounds[i] - - op_graph.update_values_with_bounds_and_samples(node_bounds_and_samples) - - for i, output_node in op_graph.output_nodes.items(): - assert expected_output_data_type[i] == output_node.outputs[0].dtype - - -def test_eval_op_graph_bounds_on_non_conformant_inputset_default(capsys): - """Test function for eval_op_graph_bounds_on_inputset with non conformant inputset""" - - def f(x, y): - return np.dot(x, y) - - x = EncryptedTensor(UnsignedInteger(2), (3,)) - y = ClearTensor(UnsignedInteger(2), (3,)) - - inputset = [ - (np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])), - (np.array([3, 3, 3]), np.array([3, 3, 5])), - ] - - op_graph = trace_numpy_function(f, {"x": x, "y": y}) - - configuration = CompilationConfiguration() - eval_op_graph_bounds_on_inputset( - op_graph, - inputset, - compilation_configuration=configuration, - min_func=numpy_min_func, - max_func=numpy_max_func, - get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, - ) - - captured = capsys.readouterr() - assert ( - captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " - "(expected EncryptedTensor for parameter `x` " - "but got EncryptedTensor which is not compatible)\n" - "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " - "(expected ClearTensor for parameter `y` " - "but got ClearTensor which is not compatible)\n" - ) - - -def test_eval_op_graph_bounds_on_non_conformant_inputset_check_all(capsys): - """Test function for eval_op_graph_bounds_on_inputset with non conformant inputset, check all""" - - def f(x, y): - return np.dot(x, y) - - x = EncryptedTensor(UnsignedInteger(2), (3,)) - y = ClearTensor(UnsignedInteger(2), (3,)) - - inputset = [ - (np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])), - (np.array([3, 3, 3]), np.array([3, 3, 5])), - ] - - op_graph = trace_numpy_function(f, {"x": x, "y": y}) - - configuration = CompilationConfiguration(check_every_input_in_inputset=True) - eval_op_graph_bounds_on_inputset( - op_graph, - inputset, - compilation_configuration=configuration, - min_func=numpy_min_func, - max_func=numpy_max_func, - get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, - ) - - captured = capsys.readouterr() - assert ( - captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " - "(expected EncryptedTensor for parameter `x` " - "but got EncryptedTensor which is not compatible)\n" - "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " - "(expected ClearTensor for parameter `y` " - "but got ClearTensor which is not compatible)\n" - "Warning: Input #1 (0-indexed) is not coherent with the hinted parameters " - "(expected ClearTensor for parameter `y` " - "but got ClearTensor which is not compatible)\n" - ) - - -def test_eval_op_graph_bounds_on_conformant_numpy_inputset_check_all(capsys): - """Test function for eval_op_graph_bounds_on_inputset - with conformant inputset of numpy arrays, check all""" - - def f(x, y): - return np.dot(x, y) - - x = EncryptedTensor(UnsignedInteger(2), (3,)) - y = ClearTensor(UnsignedInteger(2), (3,)) - - inputset = [ - (np.array([2, 1, 3]), np.array([1, 2, 1])), - (np.array([3, 3, 3]), np.array([3, 3, 1])), - ] - - op_graph = trace_numpy_function(f, {"x": x, "y": y}) - - configuration = CompilationConfiguration(check_every_input_in_inputset=True) - eval_op_graph_bounds_on_inputset( - op_graph, - inputset, - compilation_configuration=configuration, - min_func=numpy_min_func, - max_func=numpy_max_func, - get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, - ) - - captured = capsys.readouterr() - assert captured.err == "" - - -def test_eval_op_graph_bounds_on_non_conformant_numpy_inputset_check_all(capsys): - """Test function for eval_op_graph_bounds_on_inputset with non conformant inputset, check all""" - - def f(x, y): - return np.dot(x, y) - - x = EncryptedTensor(UnsignedInteger(2), (3,)) - y = ClearTensor(UnsignedInteger(2), (3,)) - - inputset = [ - (np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])), - (np.array([3, 3, 3]), np.array([3, 3, 5])), - ] - - op_graph = trace_numpy_function(f, {"x": x, "y": y}) - - configuration = CompilationConfiguration(check_every_input_in_inputset=True) - eval_op_graph_bounds_on_inputset( - op_graph, - inputset, - compilation_configuration=configuration, - min_func=numpy_min_func, - max_func=numpy_max_func, - get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, - ) - - captured = capsys.readouterr() - assert ( - captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " - "(expected EncryptedTensor for parameter `x` " - "but got EncryptedTensor which is not compatible)\n" - "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " - "(expected ClearTensor for parameter `y` " - "but got ClearTensor which is not compatible)\n" - "Warning: Input #1 (0-indexed) is not coherent with the hinted parameters " - "(expected ClearTensor for parameter `y` " - "but got ClearTensor which is not compatible)\n" - ) - - -def test_eval_op_graph_bounds_on_non_conformant_inputset_treating_warnings_as_errors(): - """Test function for eval_op_graph_bounds_on_inputset with non conformant inputset and errors""" - - def f(x, y): - return np.dot(x, y) - - x = EncryptedTensor(UnsignedInteger(2), (3,)) - y = ClearTensor(UnsignedInteger(2), (3,)) - - inputset = [ - (np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])), - (np.array([3, 3, 3]), np.array([3, 3, 5])), - ] - - op_graph = trace_numpy_function(f, {"x": x, "y": y}) - - with pytest.raises(ValueError, match=".* is not coherent with the hinted parameters .*"): - configuration = CompilationConfiguration(treat_warnings_as_errors=True) - eval_op_graph_bounds_on_inputset( - op_graph, - inputset, - compilation_configuration=configuration, - min_func=numpy_min_func, - max_func=numpy_max_func, - get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, - ) - - -def test_inpuset_eval_1_input(default_compilation_configuration): - """Test case for a function with a single parameter and passing the inputset without tuples.""" - - def f(x): - return x + 42 - - x = EncryptedScalar(UnsignedInteger(4)) - - inputset = range(10) - - op_graph = trace_numpy_function(f, {"x": x}) - - eval_op_graph_bounds_on_inputset( - op_graph, - inputset, - compilation_configuration=default_compilation_configuration, - min_func=numpy_min_func, - max_func=numpy_max_func, - get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, - ) - - input_node = op_graph.input_nodes[0] - - assert input_node.inputs[0] == input_node.outputs[0] - assert input_node.inputs[0] == EncryptedScalar(UnsignedInteger(4)) - - output_node = op_graph.output_nodes[0] - - assert output_node.outputs[0] == EncryptedScalar(UnsignedInteger(6)) - - -# TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/772 -# Remove once this issue is done -def test_inpuset_eval_1_input_refuse_tuple(default_compilation_configuration): - """Test case for a function with a single parameter and passing the inputset with tuples.""" - - def f(x): - return x + 42 - - x = EncryptedScalar(UnsignedInteger(4)) - - inputset = [(i,) for i in range(10)] - - op_graph = trace_numpy_function(f, {"x": x}) - - with pytest.raises(AssertionError) as excinfo: - eval_op_graph_bounds_on_inputset( - op_graph, - inputset, - compilation_configuration=default_compilation_configuration, - min_func=numpy_min_func, - max_func=numpy_max_func, - get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, - ) - - assert str(excinfo.value) == "Tuples are unsupported for single input inputset evaluation" diff --git a/tests/common/compilation/test_artifacts.py b/tests/common/compilation/test_artifacts.py deleted file mode 100644 index 1b2e515d9..000000000 --- a/tests/common/compilation/test_artifacts.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Test file for compilation artifacts""" - -import tempfile -from pathlib import Path - -from concrete.common.compilation import CompilationArtifacts -from concrete.common.data_types.integers import UnsignedInteger -from concrete.common.values import EncryptedScalar -from concrete.numpy.compile import compile_numpy_function - - -def test_artifacts_export(default_compilation_configuration): - """Test function to check exporting compilation artifacts""" - - def function(x): - return x + 42 - - with tempfile.TemporaryDirectory() as tmp: - output_directory = Path(tmp) - artifacts = CompilationArtifacts(output_directory) - - compile_numpy_function( - function, - {"x": EncryptedScalar(UnsignedInteger(7))}, - range(10), - default_compilation_configuration, - compilation_artifacts=artifacts, - ) - - artifacts.export() - - assert output_directory.joinpath("environment.txt").exists() - assert output_directory.joinpath("requirements.txt").exists() - - assert output_directory.joinpath("function.txt").exists() - assert output_directory.joinpath("parameters.txt").exists() - - assert output_directory.joinpath("1.initial.graph.txt").exists() - assert output_directory.joinpath("1.initial.graph.png").exists() - - assert output_directory.joinpath("2.final.graph.txt").exists() - assert output_directory.joinpath("2.final.graph.png").exists() - - assert output_directory.joinpath("bounds.txt").exists() - assert output_directory.joinpath("mlir.txt").exists() - - # format of those files might change in the future - # so it is sufficient to test their existance diff --git a/tests/common/compilation/test_configuration.py b/tests/common/compilation/test_configuration.py deleted file mode 100644 index 0011f370c..000000000 --- a/tests/common/compilation/test_configuration.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Test file for compilation configuration""" - -from inspect import signature - -import numpy -import pytest - -from concrete.common.compilation import CompilationConfiguration -from concrete.common.data_types.integers import Integer -from concrete.common.values import EncryptedScalar -from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds - - -def no_fuse(x): - """No fuse""" - return x + 2 - - -def simple_fuse_not_output(x): - """Simple fuse not output""" - intermediate = x.astype(numpy.float64) - intermediate = intermediate.astype(numpy.uint32) - return intermediate + 2 - - -@pytest.mark.parametrize( - "function_to_trace,fused", - [ - pytest.param( - no_fuse, - False, - id="no_fuse", - ), - pytest.param( - simple_fuse_not_output, - True, - id="simple_fuse_not_output", - ), - ], -) -def test_enable_topological_optimizations( - test_helpers, function_to_trace, fused, default_compilation_configuration -): - """Test function for enable_topological_optimizations flag of compilation configuration""" - - op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( - function_to_trace, - { - param: EncryptedScalar(Integer(32, is_signed=False)) - for param in signature(function_to_trace).parameters.keys() - }, - [numpy.array(i) for i in range(10)], - default_compilation_configuration, - ) - op_graph_not_optimized = compile_numpy_function_into_op_graph_and_measure_bounds( - function_to_trace, - { - param: EncryptedScalar(Integer(32, is_signed=False)) - for param in signature(function_to_trace).parameters.keys() - }, - [numpy.array(i) for i in range(10)], - CompilationConfiguration( - dump_artifacts_on_unexpected_failures=False, - enable_topological_optimizations=False, - treat_warnings_as_errors=True, - ), - ) - - graph = op_graph.graph - not_optimized_graph = op_graph_not_optimized.graph - - if fused: - assert not test_helpers.digraphs_are_equivalent(graph, not_optimized_graph) - assert len(graph) < len(not_optimized_graph) - else: - assert test_helpers.digraphs_are_equivalent(graph, not_optimized_graph) - assert len(graph) == len(not_optimized_graph) diff --git a/tests/common/data_types/test_dtypes_helpers.py b/tests/common/data_types/test_dtypes_helpers.py deleted file mode 100644 index 2ba821c8c..000000000 --- a/tests/common/data_types/test_dtypes_helpers.py +++ /dev/null @@ -1,299 +0,0 @@ -"""Test file for data types helpers""" -import pytest - -from concrete.common.data_types.base import BaseDataType -from concrete.common.data_types.dtypes_helpers import ( - broadcast_shapes, - find_type_to_hold_both_lossy, - mix_values_determine_holding_dtype, - value_is_encrypted_scalar_integer, - value_is_encrypted_scalar_unsigned_integer, -) -from concrete.common.data_types.floats import Float -from concrete.common.data_types.integers import Integer -from concrete.common.values import ( - BaseValue, - ClearScalar, - ClearTensor, - EncryptedScalar, - EncryptedTensor, -) - - -@pytest.mark.parametrize( - "value,expected_result", - [ - pytest.param( - ClearScalar(Integer(8, is_signed=False)), - False, - id="ClearScalar 8 bits unsigned Integer", - ), - pytest.param( - EncryptedScalar(Integer(8, is_signed=True)), - True, - id="EncryptedScalar 8 bits signed Integer", - ), - ], -) -def test_value_is_encrypted_integer(value: BaseValue, expected_result: bool): - """Test value_is_encrypted_integer helper""" - assert value_is_encrypted_scalar_integer(value) == expected_result - - -@pytest.mark.parametrize( - "value,expected_result", - [ - pytest.param( - ClearScalar(Integer(8, is_signed=False)), - False, - id="ClearScalar 8 bits unsigned Integer", - ), - pytest.param( - EncryptedScalar(Integer(8, is_signed=True)), - False, - id="EncryptedScalar 8 bits signed Integer", - ), - pytest.param( - EncryptedScalar(Integer(8, is_signed=False)), - True, - id="EncryptedScalar 8 bits unsigned Integer", - ), - ], -) -def test_value_is_encrypted_unsigned_integer(value: BaseValue, expected_result: bool): - """Test value_is_encrypted_unsigned_integer helper""" - assert value_is_encrypted_scalar_unsigned_integer(value) == expected_result - - -class UnsupportedDataType(BaseDataType): - """Test helper class to represent an UnsupportedDataType""" - - def __eq__(self, o: object) -> bool: - return isinstance(o, self.__class__) - - -@pytest.mark.parametrize( - "dtype1,dtype2,expected_mixed_dtype", - [ - pytest.param(Integer(6, True), Integer(6, True), Integer(6, True), id="int6, int6, int6"), - pytest.param( - Integer(6, False), Integer(6, False), Integer(6, False), id="uint6, uint6, uint6" - ), - pytest.param(Integer(6, True), Integer(6, False), Integer(7, True), id="int6, uint6, int7"), - pytest.param(Integer(6, False), Integer(6, True), Integer(7, True), id="uint6, int6, int7"), - pytest.param(Integer(6, True), Integer(5, False), Integer(6, True), id="int6, uint5, int6"), - pytest.param(Integer(5, False), Integer(6, True), Integer(6, True), id="uint5, int6, int6"), - pytest.param(Integer(32, True), Float(32), Float(32), id="int32, float32, float32"), - pytest.param(Integer(64, True), Float(32), Float(32), id="int64, float32, float32"), - pytest.param(Integer(64, True), Float(64), Float(64), id="int64, float64, float64"), - pytest.param(Integer(32, True), Float(64), Float(64), id="int32, float64, float64"), - pytest.param(Float(64), Integer(32, True), Float(64), id="float64, int32, float64"), - pytest.param(Float(64), Integer(7, False), Float(64), id="float64, uint7, float64"), - pytest.param(Float(32), Float(32), Float(32), id="float32, float32, float32"), - pytest.param(Float(32), Float(64), Float(64), id="float32, float64, float64"), - pytest.param(Float(64), Float(32), Float(64), id="float64, float32, float64"), - pytest.param(Float(64), Float(64), Float(64), id="float64, float64, float64"), - pytest.param( - UnsupportedDataType(), - UnsupportedDataType(), - None, - id="unsupported, unsupported, xfail", - marks=pytest.mark.xfail(strict=True), - ), - pytest.param( - Integer(6, True), - UnsupportedDataType(), - None, - id="int6, unsupported, xfail", - marks=pytest.mark.xfail(strict=True), - ), - pytest.param( - UnsupportedDataType(), - Integer(6, True), - None, - id="unsupported, int6, xfail", - marks=pytest.mark.xfail(strict=True), - ), - pytest.param( - UnsupportedDataType(), - Float(32), - None, - id="unsupported, float32, xfail", - marks=pytest.mark.xfail(strict=True), - ), - ], -) -def test_mix_data_types( - dtype1: BaseDataType, - dtype2: BaseDataType, - expected_mixed_dtype: BaseDataType, -): - """Test find_type_to_hold_both_lossy helper""" - assert expected_mixed_dtype == find_type_to_hold_both_lossy(dtype1, dtype2) - - -@pytest.mark.parametrize( - "value1,value2,expected_mixed_value", - [ - pytest.param( - EncryptedScalar(Integer(7, False)), - EncryptedScalar(Integer(7, False)), - EncryptedScalar(Integer(7, False)), - id="euint7, euint7, euint7", - ), - pytest.param( - EncryptedScalar(Integer(7, False)), - ClearScalar(Integer(7, False)), - EncryptedScalar(Integer(7, False)), - id="euint7, cuint7, euint7", - ), - pytest.param( - ClearScalar(Integer(7, False)), - EncryptedScalar(Integer(7, False)), - EncryptedScalar(Integer(7, False)), - id="cuint7, euint7, euint7", - ), - pytest.param( - ClearScalar(Integer(7, False)), - ClearScalar(Integer(7, False)), - ClearScalar(Integer(7, False)), - id="cuint7, cuint7, cuint7", - ), - pytest.param( - ClearScalar(Float(32)), - ClearScalar(Float(32)), - ClearScalar(Float(32)), - id="cfloat32, cfloat32, cfloat32", - ), - pytest.param( - EncryptedScalar(Float(32)), - ClearScalar(Float(32)), - EncryptedScalar(Float(32)), - id="efloat32, cfloat32, efloat32", - ), - ], -) -def test_mix_scalar_values(value1, value2, expected_mixed_value): - """Test mix_values_determine_holding_dtype helper with scalars""" - - assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2) - - -@pytest.mark.parametrize( - "value1,value2,expected_mixed_value", - [ - pytest.param( - EncryptedTensor(Integer(7, False), (1, 2, 3)), - EncryptedTensor(Integer(7, False), (1, 2, 3)), - EncryptedTensor(Integer(7, False), (1, 2, 3)), - ), - pytest.param( - ClearTensor(Integer(7, False), (1, 2, 3)), - EncryptedTensor(Integer(7, False), (1, 2, 3)), - EncryptedTensor(Integer(7, False), (1, 2, 3)), - ), - pytest.param( - ClearTensor(Integer(7, False), (1, 2, 3)), - ClearTensor(Integer(7, False), (1, 2, 3)), - ClearTensor(Integer(7, False), (1, 2, 3)), - ), - pytest.param( - ClearTensor(Integer(7, False), (1, 2, 3)), - ClearTensor(Integer(7, False), (1, 2, 3)), - ClearTensor(Integer(7, False), (1, 2, 3)), - ), - pytest.param( - ClearTensor(Integer(7, False), (1, 2, 3)), - EncryptedScalar(Integer(7, False)), - None, - marks=pytest.mark.xfail(strict=True, raises=AssertionError), - ), - pytest.param( - ClearTensor(Integer(7, False), (1, 2, 3)), - ClearTensor(Integer(7, False), (3, 2, 1)), - None, - marks=pytest.mark.xfail(strict=True, raises=AssertionError), - ), - ], -) -def test_mix_tensor_values(value1, value2, expected_mixed_value): - """Test mix_values_determine_holding_dtype helper with tensors""" - - assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2) - - -class DummyValue(BaseValue): - """DummyValue""" - - def __eq__(self, other: object) -> bool: - return BaseValue.__eq__(self, other) - - -def test_fail_mix_values_determine_holding_dtype(): - """Test function for failure case of mix_values_determine_holding_dtype""" - - with pytest.raises(ValueError, match=r".* does not support value .*"): - mix_values_determine_holding_dtype( - DummyValue(Integer(32, True), True), - DummyValue(Integer(32, True), True), - ) - - -@pytest.mark.parametrize( - "shape1,shape2,expected_shape", - [ - pytest.param((), (), ()), - pytest.param((3,), (), (3,)), - pytest.param((3,), (1,), (3,)), - pytest.param((3,), (2,), None), - pytest.param((3,), (3,), (3,)), - pytest.param((2, 3), (), (2, 3)), - pytest.param((2, 3), (1,), (2, 3)), - pytest.param((2, 3), (2,), None), - pytest.param((2, 3), (3,), (2, 3)), - pytest.param((2, 3), (1, 1), (2, 3)), - pytest.param((2, 3), (2, 1), (2, 3)), - pytest.param((2, 3), (3, 1), None), - pytest.param((2, 3), (1, 2), None), - pytest.param((2, 3), (2, 2), None), - pytest.param((2, 3), (3, 2), None), - pytest.param((2, 3), (1, 3), (2, 3)), - pytest.param((2, 3), (2, 3), (2, 3)), - pytest.param((2, 3), (3, 3), None), - pytest.param((2, 1, 3), (1, 1, 1), (2, 1, 3)), - pytest.param((2, 1, 3), (1, 4, 1), (2, 4, 3)), - pytest.param((2, 1, 3), (2, 4, 3), (2, 4, 3)), - # Tests cases taken from `numpy` - # https://github.com/numpy/numpy/blob/623bc1fae1d47df24e7f1e29321d0c0ba2771ce0/numpy/lib/tests/test_stride_tricks.py#L296-L351 - pytest.param((1, 2), (2,), (1, 2)), - pytest.param((1, 1), (3, 4), (3, 4)), - pytest.param((1, 3), (3, 1), (3, 3)), - pytest.param((1, 0), (0, 0), (0, 0)), - pytest.param((0, 1), (0, 0), (0, 0)), - pytest.param((1, 0), (0, 1), (0, 0)), - pytest.param((1, 1), (0, 0), (0, 0)), - pytest.param((1, 1), (1, 0), (1, 0)), - pytest.param((1, 1), (0, 1), (0, 1)), - pytest.param((), (0,), (0,)), - pytest.param((0,), (0, 0), (0, 0)), - pytest.param((0,), (0, 1), (0, 0)), - pytest.param((1,), (0, 0), (0, 0)), - pytest.param((2,), (0, 0), (0, 0)), - pytest.param((), (0, 0), (0, 0)), - pytest.param((1, 1), (0,), (1, 0)), - pytest.param((1,), (0, 1), (0, 1)), - pytest.param((1,), (1, 0), (1, 0)), - pytest.param((), (1, 0), (1, 0)), - pytest.param((), (0, 1), (0, 1)), - pytest.param((1,), (3,), (3,)), - pytest.param((2,), (3, 2), (3, 2)), - pytest.param((3,), (4,), None), - pytest.param((2, 3), (2,), None), - pytest.param((1, 3, 4), (2, 3, 3), None), - pytest.param((2,), (2, 3), None), - ], -) -def test_broadcast_shapes(shape1, shape2, expected_shape): - """Test function for `broadcast_shapes` helper""" - assert broadcast_shapes(shape1=shape1, shape2=shape2) == expected_shape - assert broadcast_shapes(shape1=shape2, shape2=shape1) == expected_shape diff --git a/tests/common/data_types/test_floats.py b/tests/common/data_types/test_floats.py deleted file mode 100644 index fc33acce5..000000000 --- a/tests/common/data_types/test_floats.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Test file for float data types""" - - -import pytest - -from concrete.common.data_types.floats import Float, Float32, Float64 - - -@pytest.mark.parametrize( - "float_,expected_repr_str", - [ - pytest.param( - Float32(), - "Float<32 bits>", - id="Float32", - ), - pytest.param( - Float(32), - "Float<32 bits>", - id="32 bits Float", - ), - pytest.param( - Float64(), - "Float<64 bits>", - id="Float64", - ), - pytest.param( - Float(64), - "Float<64 bits>", - id="64 bits Float", - ), - ], -) -def test_floats_repr(float_: Float, expected_repr_str: str): - """Test float repr""" - assert float_.__repr__() == expected_repr_str - - -@pytest.mark.parametrize( - "float_1,float_2,expected_equal", - [ - pytest.param(Float32(), Float(32), True), - pytest.param(Float(64), Float32(), False), - pytest.param(Float64(), Float(64), True), - ], -) -def test_floats_eq(float_1: Float, float_2: Float, expected_equal: bool): - """Test float eq""" - assert expected_equal == (float_1 == float_2) - assert expected_equal == (float_2 == float_1) diff --git a/tests/common/data_types/test_integers.py b/tests/common/data_types/test_integers.py deleted file mode 100644 index 59c186b29..000000000 --- a/tests/common/data_types/test_integers.py +++ /dev/null @@ -1,112 +0,0 @@ -"""Test file for integers data types""" - -import random - -import pytest - -from concrete.common.data_types.integers import ( - Integer, - SignedInteger, - UnsignedInteger, - make_integer_to_hold, -) - - -@pytest.mark.parametrize( - "integer,expected_min,expected_max", - [ - pytest.param(Integer(8, is_signed=False), 0, 255, id="8 bits unsigned Integer"), - pytest.param(UnsignedInteger(8), 0, 255, id="8 bits UnsignedInteger"), - pytest.param(Integer(8, is_signed=True), -128, 127, id="8 bits signed Integer"), - pytest.param(SignedInteger(8), -128, 127, id="8 bits SignedInteger"), - pytest.param(Integer(32, is_signed=False), 0, 4_294_967_295, id="32 bits unsigned Integer"), - pytest.param(UnsignedInteger(32), 0, 4_294_967_295, id="32 bits UnsignedInteger"), - pytest.param( - Integer(32, is_signed=True), - -2_147_483_648, - 2_147_483_647, - id="32 bits signed Integer", - ), - pytest.param( - SignedInteger(32), - -2_147_483_648, - 2_147_483_647, - id="32 bits SignedInteger", - ), - ], -) -def test_basic_integers(integer: Integer, expected_min: int, expected_max: int): - """Test integer class basic functions""" - assert integer.min_value() == expected_min - assert integer.max_value() == expected_max - - assert integer.can_represent_value(random.randint(expected_min, expected_max)) - assert not integer.can_represent_value(expected_min - 1) - assert not integer.can_represent_value(expected_max + 1) - - -@pytest.mark.parametrize( - "integer,expected_repr_str", - [ - pytest.param( - Integer(8, is_signed=False), - "Integer", - id="8 bits unsigned Integer", - ), - pytest.param( - Integer(8, is_signed=True), - "Integer", - id="8 bits signed Integer", - ), - pytest.param( - Integer(32, is_signed=False), - "Integer", - id="32 bits unsigned Integer", - ), - pytest.param( - Integer(32, is_signed=True), - "Integer", - id="32 bits signed Integer", - ), - ], -) -def test_integers_repr(integer: Integer, expected_repr_str: str): - """Test integer repr""" - assert integer.__repr__() == expected_repr_str - - -@pytest.mark.parametrize( - "values,force_signed,expected_result", - [ - ([0], False, Integer(1, is_signed=False)), - ([0], True, Integer(2, is_signed=True)), - ([1], False, Integer(1, is_signed=False)), - ([1], True, Integer(2, is_signed=True)), - ([-1], False, Integer(2, is_signed=True)), - ([-2], False, Integer(2, is_signed=True)), - ([0, 1], False, Integer(1, is_signed=False)), - ([0, 1], True, Integer(2, is_signed=True)), - ([7], False, Integer(3, is_signed=False)), - ([7], True, Integer(4, is_signed=True)), - ([8], False, Integer(4, is_signed=False)), - ([8], True, Integer(5, is_signed=True)), - ([-7], False, Integer(4, is_signed=True)), - ([-8], False, Integer(4, is_signed=True)), - ([-7, -8], False, Integer(4, is_signed=True)), - ([-9], False, Integer(5, is_signed=True)), - ([-9], True, Integer(5, is_signed=True)), - ([0, 127], False, Integer(7, is_signed=False)), - ([0, 127], True, Integer(8, is_signed=True)), - ([0, 128], False, Integer(8, is_signed=False)), - ([0, 128], True, Integer(9, is_signed=True)), - ([-1, 127], False, Integer(8, is_signed=True)), - ([-256, 127], False, Integer(9, is_signed=True)), - ([-128, 127], False, Integer(8, is_signed=True)), - ([-128, 128], False, Integer(9, is_signed=True)), - ([-13, 4], False, Integer(5, is_signed=True)), - ([42, 1019], False, Integer(10, is_signed=False)), - ], -) -def test_make_integer_to_hold(values, force_signed, expected_result): - """Test make_integer_to_hold""" - assert expected_result == make_integer_to_hold(values, force_signed) diff --git a/tests/common/data_types/test_values.py b/tests/common/data_types/test_values.py deleted file mode 100644 index 25c24f488..000000000 --- a/tests/common/data_types/test_values.py +++ /dev/null @@ -1,86 +0,0 @@ -"""Test file for values related code.""" - -from copy import deepcopy -from functools import partial -from typing import Callable, Optional, Tuple, Union - -import pytest - -from concrete.common.data_types.base import BaseDataType -from concrete.common.data_types.floats import Float -from concrete.common.data_types.integers import Integer -from concrete.common.values import ClearTensor, EncryptedTensor, TensorValue - - -class DummyDtype(BaseDataType): - """Dummy Helper Dtype""" - - def __eq__(self, o: object) -> bool: - return isinstance(o, self.__class__) - - -@pytest.mark.parametrize( - "tensor_constructor,expected_is_encrypted", - [ - (ClearTensor, False), - (partial(TensorValue, is_encrypted=False), False), - (EncryptedTensor, True), - (partial(TensorValue, is_encrypted=True), True), - ], -) -@pytest.mark.parametrize( - "shape,expected_shape,expected_ndim,expected_size", - [ - ((), (), 0, 1), - ((3, 256, 256), (3, 256, 256), 3, 196_608), - ((1920, 1080, 3), (1920, 1080, 3), 3, 6_220_800), - ], -) -@pytest.mark.parametrize( - "data_type", - [ - Integer(7, False), - Integer(32, True), - Integer(32, False), - Integer(64, True), - Integer(64, False), - Float(32), - Float(64), - ], -) -def test_tensor_value( - tensor_constructor: Callable[..., TensorValue], - expected_is_encrypted: bool, - shape: Optional[Tuple[int, ...]], - expected_shape: Tuple[int, ...], - expected_ndim: int, - expected_size: int, - data_type: Union[Integer, Float], -): - """Test function for TensorValue""" - - tensor_value = tensor_constructor(dtype=data_type, shape=shape) - - assert expected_is_encrypted == tensor_value.is_encrypted - assert expected_shape == tensor_value.shape - assert expected_ndim == tensor_value.ndim - assert expected_size == tensor_value.size - - assert data_type == tensor_value.dtype - - other_tensor = deepcopy(tensor_value) - - assert other_tensor == tensor_value - - other_tensor_value = deepcopy(other_tensor) - other_tensor_value.dtype = DummyDtype() - assert other_tensor_value != tensor_value - - other_shape = tuple(val + 1 for val in shape) if shape is not None else () - other_shape += (2,) - other_tensor_value = tensor_constructor(dtype=data_type, shape=other_shape) - - assert other_tensor_value.shape != tensor_value.shape - assert other_tensor_value.ndim != tensor_value.ndim - assert other_tensor_value.size != tensor_value.size - assert other_tensor_value != tensor_value diff --git a/tests/common/debugging/test_custom_assert.py b/tests/common/debugging/test_custom_assert.py deleted file mode 100644 index b779f4bce..000000000 --- a/tests/common/debugging/test_custom_assert.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Test custom assert functions.""" -import pytest - -from concrete.common.debugging.custom_assert import assert_false, assert_not_reached, assert_true - - -def test_assert_not_functions(): - """Test custom assert functions""" - assert_true(True, "one check") - assert_false(False, "another check") - - with pytest.raises(AssertionError) as excinfo: - assert_not_reached("yet another one") - - assert "yet another one" in str(excinfo.value) - - with pytest.raises(AssertionError) as excinfo: - assert_true(False, "one failing check") - - assert "one failing check" in str(excinfo.value) - - with pytest.raises(AssertionError) as excinfo: - assert_false(True, "another failing check") - - assert "another failing check" in str(excinfo.value) diff --git a/tests/common/debugging/test_drawing.py b/tests/common/debugging/test_drawing.py deleted file mode 100644 index 34d3e6137..000000000 --- a/tests/common/debugging/test_drawing.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Test file for drawing""" - -import filecmp -import tempfile -from pathlib import Path - -from concrete.common.data_types.integers import Integer -from concrete.common.debugging import draw_graph -from concrete.common.values import EncryptedScalar -from concrete.numpy import NPFHECompiler -from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds - - -def test_draw_graph_with_saving(default_compilation_configuration): - """Tests drawing and saving a graph""" - - def function(x): - return x + 42 - - op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( - function, - {"x": EncryptedScalar(Integer(7, True))}, - range(-5, 5), - default_compilation_configuration, - ) - - compiler = NPFHECompiler(function, {"x": "encrypted"}, default_compilation_configuration) - - assert (got := compiler.draw_graph()) is None, got - - compiler.eval_on_inputset(range(-5, 5)) - - with tempfile.TemporaryDirectory() as tmp: - output_directory = Path(tmp) - output_file = output_directory.joinpath("test.png") - draw_graph(op_graph, save_to=output_file) - assert output_file.exists() - - output_file_compiler = output_directory.joinpath("test_compiler.png") - compiler_output_file = compiler.draw_graph(save_to=output_file_compiler) - assert compiler_output_file is not None - compiler_output_file = Path(compiler_output_file) - assert compiler_output_file == output_file_compiler - assert compiler_output_file.exists() - - assert filecmp.cmp(output_file, compiler_output_file) diff --git a/tests/common/debugging/test_formatting.py b/tests/common/debugging/test_formatting.py deleted file mode 100644 index 455802c78..000000000 --- a/tests/common/debugging/test_formatting.py +++ /dev/null @@ -1,160 +0,0 @@ -"""Test file for formatting""" - -import numpy - -from concrete.common.data_types.integers import Integer, UnsignedInteger -from concrete.common.debugging import format_operation_graph -from concrete.common.values import EncryptedScalar -from concrete.numpy import NPFHECompiler -from concrete.numpy.compile import ( - compile_numpy_function, - compile_numpy_function_into_op_graph_and_measure_bounds, -) - - -def test_format_operation_graph_with_multiple_edges(default_compilation_configuration): - """Test format_operation_graph with multiple edges""" - - def function(x): - return x + x - - op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( - function, - {"x": EncryptedScalar(Integer(4, True))}, - range(0, 10), - default_compilation_configuration, - ) - - formatted_graph = format_operation_graph(op_graph) - assert ( - formatted_graph - == """ - -%0 = x # EncryptedScalar -%1 = add(%0, %0) # EncryptedScalar -return %1 - -""".strip() - ) - - -def test_format_operation_graph_with_offending_nodes(default_compilation_configuration): - """Test format_operation_graph with offending nodes""" - - def function(x): - return x + 42 - - op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( - function, - {"x": EncryptedScalar(Integer(7, True))}, - range(-5, 5), - default_compilation_configuration, - ) - - highlighted_nodes = {op_graph.input_nodes[0]: ["foo"]} - formatted_graph = format_operation_graph(op_graph, highlighted_nodes=highlighted_nodes).strip() - assert ( - formatted_graph - == """ - -%0 = x # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo -%1 = 42 # ClearScalar -%2 = add(%0, %1) # EncryptedScalar -return %2 - -""".strip() - ) - - highlighted_nodes = {op_graph.input_nodes[0]: ["foo"], op_graph.output_nodes[0]: ["bar", "baz"]} - formatted_graph = format_operation_graph(op_graph, highlighted_nodes=highlighted_nodes).strip() - assert ( - formatted_graph - == """ - -%0 = x # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo -%1 = 42 # ClearScalar -%2 = add(%0, %1) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar - baz -return %2 - -""".strip() - ) - - -def test_format_operation_graph_with_fusing(default_compilation_configuration): - """Test format_operation_graph with fusing""" - - def function(x): - return (10 * (numpy.cos(x + 1) + 1)).astype(numpy.uint32) - - circuit = compile_numpy_function( - function, - { - "x": EncryptedScalar(UnsignedInteger(3)), - }, - range(2 ** 3), - default_compilation_configuration, - ) - - assert (got := str(circuit)) == ( - """ - -%0 = x # EncryptedScalar -%1 = 1 # ClearScalar -%2 = add(%0, %1) # EncryptedScalar -%3 = subgraph(%2) # EncryptedScalar -return %3 - -Subgraphs: - - %3 = subgraph(%2): - - %0 = 10 # ClearScalar - %1 = 1 # ClearScalar - %2 = float_subgraph_input # EncryptedScalar - %3 = cos(%2) # EncryptedScalar - %4 = add(%3, %1) # EncryptedScalar - %5 = mul(%4, %0) # EncryptedScalar - %6 = astype(%5, dtype=uint32) # EncryptedScalar - return %6 - -""".strip() - ), got - - compiler = NPFHECompiler(function, {"x": "encrypted"}, default_compilation_configuration) - - assert ( - got := str(compiler) - ) == "__str__ failed: OPGraph is None, NPFHECompiler needs evaluation on an inputset", got - - compiler.eval_on_inputset(range(2 ** 3)) - - # String is different here as the type that is first propagated to trace the opgraph is not the - # same - - assert (got := str(compiler)) == ( - """ - -%0 = x # EncryptedScalar -%1 = 1 # ClearScalar -%2 = add(%0, %1) # EncryptedScalar -%3 = subgraph(%2) # EncryptedScalar -return %3 - -Subgraphs: - - %3 = subgraph(%2): - - %0 = 10 # ClearScalar - %1 = 1 # ClearScalar - %2 = float_subgraph_input # EncryptedScalar - %3 = cos(%2) # EncryptedScalar - %4 = add(%3, %1) # EncryptedScalar - %5 = mul(%4, %0) # EncryptedScalar - %6 = astype(%5, dtype=uint32) # EncryptedScalar - return %6 -""".strip() - ), got diff --git a/tests/common/extensions/test_convolution.py b/tests/common/extensions/test_convolution.py deleted file mode 100644 index 8c2214ba1..000000000 --- a/tests/common/extensions/test_convolution.py +++ /dev/null @@ -1,193 +0,0 @@ -"""Test file for convolution""" - -import numpy as np -import pytest -import torch - -from concrete.common.extensions import convolution -from concrete.common.representation.intermediate import Conv2D -from concrete.common.tracing.base_tracer import BaseTracer -from concrete.common.values.tensors import TensorValue -from concrete.numpy.tracing import NPConstant, NPTracer - - -@pytest.mark.parametrize( - "kwargs, error_msg", - [ - pytest.param( - {"x": None, "weight": np.zeros(1)}, - "input x must be an ndarray, or a BaseTracer, not a", - ), - pytest.param( - {"x": np.zeros(1), "weight": None}, - "weight must be an ndarray, or a BaseTracer, not a", - ), - pytest.param( - {"x": np.zeros(1), "weight": np.zeros(1), "bias": 0}, - "bias must be an ndarray, a BaseTracer, or None, not a", - ), - pytest.param( - {"x": np.zeros(1), "weight": np.zeros(1), "strides": None}, - "strides must be a tuple, or list, not a", - ), - pytest.param( - {"x": np.zeros(1), "weight": np.zeros(1), "dilations": None}, - "dilations must be a tuple, or list, not a", - ), - pytest.param( - {"x": np.zeros(1), "weight": np.zeros(1), "pads": None}, - "padding must be a tuple, or list, not a", - ), - ], -) -def test_invalid_arg_types(kwargs, error_msg): - """Test function to make sure convolution doesn't accept invalid types""" - - with pytest.raises(TypeError) as err: - convolution.conv2d(**kwargs) - - assert error_msg in str(err) - - -@pytest.mark.parametrize( - "kwargs, error_msg", - [ - pytest.param( - {"x": np.zeros(1), "weight": np.zeros(1)}, - "input x should have size (N x C x H x W), not", - ), - pytest.param( - {"x": np.zeros((1, 2, 3, 4)), "weight": np.zeros(1)}, - "weight should have size (F x C x H x W), not", - ), - pytest.param( - { - "x": np.zeros((1, 2, 3, 4)), - "weight": np.zeros((1, 2, 3, 4)), - "bias": np.zeros((1, 2)), - }, - "bias should have size (F), not", - ), - pytest.param( - {"x": np.zeros(1), "weight": np.zeros(1), "strides": (1,)}, - "strides should be of the form", - ), - pytest.param( - {"x": np.zeros(1), "weight": np.zeros(1), "dilations": (1,)}, - "dilations should be of the form", - ), - pytest.param( - {"x": np.zeros(1), "weight": np.zeros(1), "pads": (1,)}, - "padding should be of the form", - ), - pytest.param( - {"x": np.zeros(1), "weight": np.zeros(1), "auto_pad": None}, - "invalid auto_pad is specified", - ), - ], -) -def test_invalid_input_shape(kwargs, error_msg): - """Test function to make sure convolution doesn't accept invalid shapes""" - - with pytest.raises((ValueError, AssertionError)) as err: - convolution.conv2d(**kwargs) - - assert error_msg in str(err) - - -@pytest.mark.parametrize( - "input_shape, weight_shape", - [ - pytest.param((1, 1, 4, 4), (1, 1, 2, 2)), - pytest.param((3, 1, 4, 4), (1, 1, 2, 2)), - pytest.param((1, 1, 4, 4), (3, 1, 2, 2)), - pytest.param((1, 3, 4, 4), (1, 3, 2, 2)), - pytest.param((4, 3, 4, 4), (3, 3, 2, 2)), - pytest.param((4, 3, 16, 16), (3, 3, 2, 2)), - pytest.param((4, 3, 16, 16), (3, 3, 3, 3)), - ], -) -@pytest.mark.parametrize("strides", [(1, 1), (1, 2), (2, 1), (2, 2)]) -@pytest.mark.parametrize("dilations", [(1, 1), (1, 2), (2, 1), (2, 2)]) -@pytest.mark.parametrize("has_bias", [True, False]) -@pytest.mark.parametrize("use_ndarray", [True, False]) -def test_tracing(input_shape, weight_shape, strides, dilations, has_bias, use_ndarray): - """Test function to make sure tracong of conv2d works properly""" - if has_bias: - bias = np.random.randint(0, 4, size=(weight_shape[0],)) - if not use_ndarray: - bias = NPTracer([], NPConstant(bias), 0) - else: - bias = None - - x = NPTracer([], NPConstant(np.random.randint(0, 4, size=input_shape)), 0) - weight = np.random.randint(0, 4, size=weight_shape) - if not use_ndarray: - weight = NPTracer([], NPConstant(weight), 0) - - output_tracer = convolution.conv2d(x, weight, bias, strides=strides, dilations=dilations) - traced_computation = output_tracer.traced_computation - assert isinstance(traced_computation, Conv2D) - - if has_bias: - assert len(output_tracer.inputs) == 3 - else: - assert len(output_tracer.inputs) == 2 - - assert all( - isinstance(input_, BaseTracer) for input_ in output_tracer.inputs - ), f"{output_tracer.inputs}" - - assert len(traced_computation.outputs) == 1 - output_value = traced_computation.outputs[0] - assert isinstance(output_value, TensorValue) and output_value.is_encrypted - # pylint: disable=no-member - expected_shape = torch.conv2d( - torch.randn(input_shape), - torch.randn(weight_shape), - torch.randn((weight_shape[0])), - stride=strides, - dilation=dilations, - ).shape - # pylint: enable=no-member - - assert output_value.shape == expected_shape - - -@pytest.mark.parametrize( - "input_shape, weight_shape", - [ - pytest.param((1, 1, 4, 4), (1, 1, 2, 2)), - pytest.param((3, 1, 4, 4), (1, 1, 2, 2)), - pytest.param((1, 1, 4, 4), (3, 1, 2, 2)), - pytest.param((1, 3, 4, 4), (1, 3, 2, 2)), - pytest.param((4, 3, 4, 4), (3, 3, 2, 2)), - pytest.param((4, 3, 16, 16), (3, 3, 2, 2)), - pytest.param((4, 3, 16, 16), (3, 3, 3, 3)), - ], -) -@pytest.mark.parametrize("strides", [(1, 1), (1, 2), (2, 1), (2, 2)]) -@pytest.mark.parametrize("dilations", [(1, 1), (1, 2), (2, 1), (2, 2)]) -@pytest.mark.parametrize("has_bias", [True, False]) -def test_evaluation(input_shape, weight_shape, strides, dilations, has_bias): - """Test function to make sure evaluation of conv2d on plain data works properly""" - if has_bias: - bias = np.random.randint(0, 4, size=(weight_shape[0],)) - else: - bias = np.zeros((weight_shape[0],)) - x = np.random.randint(0, 4, size=input_shape) - weight = np.random.randint(0, 4, size=weight_shape) - # pylint: disable=no-member - expected = torch.conv2d( - torch.tensor(x, dtype=torch.long), - torch.tensor(weight, dtype=torch.long), - torch.tensor(bias, dtype=torch.long), - stride=strides, - dilation=dilations, - ).numpy() - # pylint: enable=no-member - # conv2d should handle None biases - if not has_bias: - bias = None - result = convolution.conv2d(x, weight, bias, strides=strides, dilations=dilations) - assert (result == expected).all() diff --git a/tests/common/extensions/test_multi_table.py b/tests/common/extensions/test_multi_table.py deleted file mode 100644 index 6590e9699..000000000 --- a/tests/common/extensions/test_multi_table.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Test file for direct multi table lookups""" - -import random - -import numpy -import pytest - -from concrete.common.data_types.integers import Integer -from concrete.common.extensions.multi_table import MultiLookupTable -from concrete.common.extensions.table import LookupTable - -table_2b_to_2b = LookupTable([1, 2, 0, 3]) -table_2b_to_1b = LookupTable([1, 0, 0, 1]) -table_2b_to_3b = LookupTable([5, 2, 7, 0]) - -table_3b_to_2b = LookupTable([1, 2, 0, 3, 0, 3, 1, 2]) -table_3b_to_1b = LookupTable([1, 0, 0, 1, 1, 1, 1, 0]) -table_3b_to_3b = LookupTable([5, 2, 7, 0, 4, 1, 6, 2]) - -tables_2b = [table_2b_to_1b, table_2b_to_2b, table_2b_to_3b] -tables_3b = [table_3b_to_1b, table_3b_to_2b, table_3b_to_3b] - - -def test_multi_lookup_table_creation_and_indexing(): - """Test function for creating and indexing multi lookup tables""" - tables = [ - [tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]], - [tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]], - [tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]], - ] - multitable = MultiLookupTable(tables) - - assert multitable.input_shape == (3, 2) - - assert isinstance(multitable.output_dtype, Integer) - assert multitable.output_dtype.bit_width <= 3 - - index = numpy.random.randint(0, 2 ** 2, size=multitable.input_shape).tolist() - result = multitable[index] - - for i in range(3): - for j in range(2): - assert result[i][j] == multitable.tables[i][j][index[i][j]], f"i={i}, j={j}" - - -@pytest.mark.parametrize( - "tables,match", - [ - pytest.param( - [ - [], - [table_2b_to_2b, table_2b_to_3b], - ], - "MultiLookupTable cannot have an empty array within it", - ), - pytest.param( - [ - [table_2b_to_1b, 42.0], - [table_2b_to_2b, table_2b_to_3b], - ], - "MultiLookupTable should have been made out of LookupTables " - "but it had an object of type float within it", - ), - pytest.param( - [ - [table_2b_to_2b], - [table_2b_to_2b, table_2b_to_3b], - [table_2b_to_2b, table_2b_to_1b], - ], - "MultiLookupTable should have the shape (3, 1) but it does not " - "(an array on dimension 1 has the size 2 but its size should have been 1 " - "as the expected shape is (3, 1))", - ), - pytest.param( - [ - [table_2b_to_2b, table_3b_to_3b], - [table_2b_to_2b, table_3b_to_1b], - ], - "LookupTables within a MultiLookupTable should have the same size but they do not " - "(there was a table with the size of 4 and another with the size of 8)", - ), - ], -) -def test_multi_lookup_table_creation_failure(tables, match): - """Test function for failing to create multi lookup tables""" - - with pytest.raises(ValueError) as excinfo: - MultiLookupTable(tables) - - assert str(excinfo.value) == match - - -@pytest.mark.parametrize( - "tables,index,match", - [ - pytest.param( - [ - [table_2b_to_2b, table_2b_to_1b, table_2b_to_3b], - [table_2b_to_1b, table_2b_to_2b, table_2b_to_3b], - ], - [ - [1, 2], - [3, 0], - ], - "Multiple Lookup Table of shape (2, 3) cannot be looked up with [[1, 2], [3, 0]] " - "(you should check your inputset)", - ), - ], -) -def test_multi_lookup_table_indexing_failure(tables, index, match): - """Test function for failing to index multi lookup tables""" - - table = MultiLookupTable(tables) - - with pytest.raises(ValueError) as excinfo: - table.__getitem__(index) - - assert str(excinfo.value) == match diff --git a/tests/common/extensions/test_table.py b/tests/common/extensions/test_table.py deleted file mode 100644 index 82ca10374..000000000 --- a/tests/common/extensions/test_table.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Test file for direct table lookups""" - -from copy import deepcopy - -import networkx as nx -import pytest - -from concrete.common import is_a_power_of_2 -from concrete.common.data_types.integers import Integer -from concrete.common.extensions.table import LookupTable -from concrete.common.representation import intermediate as ir -from concrete.common.values import EncryptedScalar -from concrete.numpy import tracing - - -def test_lookup_table_size_constraints(): - """Test function to make sure lookup tables have correct size""" - - table = [] - - # creating empty lookup table is not acceptable - with pytest.raises(ValueError): - LookupTable(table) - - for _ in range(512): - table.append(0) - - if is_a_power_of_2(len(table)): - # creating lookup table with 2^N entries are acceptable - LookupTable(table) - else: - # creating lookup table with anything other than 2^N entries are not acceptable - with pytest.raises(ValueError): - LookupTable(table) - - -def test_lookup_table_encrypted_lookup(test_helpers): - """Test function for tracing with explicit table lookups using encrypted inputs""" - - table = LookupTable([3, 6, 0, 2]) - - def f(x): - return table[x] - - x = EncryptedScalar(Integer(2, is_signed=False)) - op_graph = tracing.trace_numpy_function(f, {"x": x}) - - table_node = op_graph.output_nodes[0] - - assert table_node.get_table(op_graph.get_ordered_preds(table_node)) == [3, 6, 0, 2] - - ref_graph = nx.MultiDiGraph() - # Here is the ASCII drawing of the expected graph: - # (x) - (TLU) - - input_x = ir.Input(input_value=x, input_name="x", program_input_idx=0) - ref_graph.add_node(input_x) - - generic_function_output_value = deepcopy(x) - generic_function_output_value.dtype = table.output_dtype - - # pylint: disable=protected-access - # Need access to _checked_indexing to have is_equivalent_to work for ir.GenericFunction - output_arbitrary_function = ir.GenericFunction( - inputs=[x], - arbitrary_func=LookupTable._checked_indexing, - output_value=generic_function_output_value, - op_kind="TLU", - op_kwargs={"table": deepcopy(table.table)}, - op_name="TLU", - ) - # pylint: enable=protected-access - ref_graph.add_node(output_arbitrary_function) - - ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0, output_idx=0) - - # TODO: discuss if this check is enough as == is not overloaded properly for GenericFunction - assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph) - - -def test_lookup_table_encrypted_and_plain_lookup(test_helpers): - """Test function for tracing with explicit table lookups using encrypted and plain inputs""" - - table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7]) - - def f(x): - return table[x] + table[0] - - x = EncryptedScalar(Integer(3, is_signed=False)) - op_graph = tracing.trace_numpy_function(f, {"x": x}) - - ref_graph = nx.MultiDiGraph() - # Here is the ASCII drawing of the expected graph: - # (x) - (TLU) - # \ - # (+) - # / - # (3) - - input_x = ir.Input(input_value=x, input_name="x", program_input_idx=0) - ref_graph.add_node(input_x) - - generic_function_output_value = deepcopy(x) - generic_function_output_value.dtype = table.output_dtype - - # pylint: disable=protected-access - # Need access to _checked_indexing to have is_equivalent_to work for ir.GenericFunction - intermediate_arbitrary_function = ir.GenericFunction( - inputs=[x], - arbitrary_func=LookupTable._checked_indexing, - output_value=generic_function_output_value, - op_kind="TLU", - op_kwargs={"table": deepcopy(table.table)}, - op_name="TLU", - ) - # pylint: enable=protected-access - ref_graph.add_node(intermediate_arbitrary_function) - - constant_3 = ir.Constant(3) - ref_graph.add_node(constant_3) - - output_add = ir.Add((intermediate_arbitrary_function.outputs[0], constant_3.outputs[0])) - ref_graph.add_node(output_add) - - ref_graph.add_edge(input_x, intermediate_arbitrary_function, input_idx=0, output_idx=0) - - ref_graph.add_edge(intermediate_arbitrary_function, output_add, input_idx=0, output_idx=0) - ref_graph.add_edge(constant_3, output_add, input_idx=1, output_idx=0) - - # TODO: discuss if this check is enough as == is not overloaded properly for GenericFunction - assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph) diff --git a/tests/common/helpers/test_python_helpers.py b/tests/common/helpers/test_python_helpers.py deleted file mode 100644 index 2b219b9b6..000000000 --- a/tests/common/helpers/test_python_helpers.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Test file for common python helpers""" - -from concrete.common.helpers.python_helpers import catch - - -def test_catch_failure(): - """Test case for when the function called with catch raises an exception.""" - - def f_fail(): - return 1 / 0 - - assert catch(f_fail) is None - - -def test_catch(): - """Test case for catch""" - - def f(*args, **kwargs): - return *args, dict(**kwargs) - - assert catch(f, (1, 2, 3,), **{"one": 1, "two": 2, "three": 3}) == ( - (1, 2, 3), - {"one": 1, "two": 2, "three": 3}, - ) diff --git a/tests/common/mlir/test_conversion_helpers.py b/tests/common/mlir/test_conversion_helpers.py deleted file mode 100644 index 2f672206b..000000000 --- a/tests/common/mlir/test_conversion_helpers.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Test file for MLIR conversion helpers.""" - -# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints - -# pylint: disable=no-name-in-module - -import concrete.lang as concretelang -import pytest -from mlir.ir import Context, Location - -from concrete.common.data_types import Float, SignedInteger, UnsignedInteger -from concrete.common.mlir.conversion_helpers import integer_to_mlir_type, value_to_mlir_type -from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor - -# pylint: enable=no-name-in-module - - -@pytest.mark.parametrize( - "integer,is_encrypted,expected_mlir_type_str", - [ - pytest.param(SignedInteger(5), False, "i5"), - pytest.param(UnsignedInteger(5), False, "i5"), - pytest.param(SignedInteger(32), False, "i32"), - pytest.param(UnsignedInteger(32), False, "i32"), - pytest.param(SignedInteger(5), True, "!FHE.eint<5>"), - pytest.param(UnsignedInteger(5), True, "!FHE.eint<5>"), - ], -) -def test_integer_to_mlir_type(integer, is_encrypted, expected_mlir_type_str): - """Test function for integer to MLIR type conversion.""" - - with Context() as ctx, Location.unknown(): - concretelang.register_dialects(ctx) - assert str(integer_to_mlir_type(ctx, integer, is_encrypted)) == expected_mlir_type_str - - -@pytest.mark.parametrize( - "value,expected_mlir_type_str", - [ - pytest.param(ClearScalar(SignedInteger(5)), "i5"), - pytest.param(ClearTensor(SignedInteger(5), shape=(2, 3)), "tensor<2x3xi5>"), - pytest.param(EncryptedScalar(SignedInteger(5)), "!FHE.eint<5>"), - pytest.param(EncryptedTensor(SignedInteger(5), shape=(2, 3)), "tensor<2x3x!FHE.eint<5>>"), - pytest.param(ClearScalar(UnsignedInteger(5)), "i5"), - pytest.param(ClearTensor(UnsignedInteger(5), shape=(2, 3)), "tensor<2x3xi5>"), - pytest.param(EncryptedScalar(UnsignedInteger(5)), "!FHE.eint<5>"), - pytest.param(EncryptedTensor(UnsignedInteger(5), shape=(2, 3)), "tensor<2x3x!FHE.eint<5>>"), - ], -) -def test_value_to_mlir_type(value, expected_mlir_type_str): - """Test function for value to MLIR type conversion.""" - - with Context() as ctx, Location.unknown(): - concretelang.register_dialects(ctx) - assert str(value_to_mlir_type(ctx, value)) == expected_mlir_type_str - - -@pytest.mark.parametrize( - "value,expected_error_message", - [ - pytest.param( - ClearScalar(Float(32)), - "ClearScalar is not supported for MLIR conversion", - ), - pytest.param( - ClearTensor(Float(32), shape=(2, 3)), - "ClearTensor is not supported for MLIR conversion", - ), - pytest.param( - EncryptedScalar(Float(32)), - "EncryptedScalar is not supported for MLIR conversion", - ), - pytest.param( - EncryptedTensor(Float(32), shape=(2, 3)), - "EncryptedTensor is not supported for MLIR conversion", - ), - ], -) -def test_fail_value_to_mlir_type(value, expected_error_message): - """Test function for failed value to MLIR type conversion.""" - - with pytest.raises(TypeError) as excinfo: - with Context() as ctx, Location.unknown(): - concretelang.register_dialects(ctx) - value_to_mlir_type(ctx, value) - - assert str(excinfo.value) == expected_error_message diff --git a/tests/common/mlir/test_node_converter.py b/tests/common/mlir/test_node_converter.py deleted file mode 100644 index 521be627e..000000000 --- a/tests/common/mlir/test_node_converter.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Test file for intermediate node to MLIR converter.""" - -import random - -import numpy -import pytest - -from concrete.common.data_types import UnsignedInteger -from concrete.common.values import EncryptedScalar, EncryptedTensor -from concrete.numpy import compile_numpy_function - - -@pytest.mark.parametrize( - "function_to_compile,parameters,inputset,expected_error_type,expected_error_message", - [ - pytest.param( - lambda x, y: x * y, - { - "x": EncryptedScalar(UnsignedInteger(3)), - "y": EncryptedScalar(UnsignedInteger(3)), - }, - [(random.randint(0, 7), random.randint(0, 7)) for _ in range(10)] + [(7, 7)], - NotImplementedError, - "Multiplication " - "between " - "EncryptedScalar " - "and " - "EncryptedScalar " - "cannot be converted to MLIR yet", - ), - pytest.param( - lambda x, y: x - y, - { - "x": EncryptedScalar(UnsignedInteger(3)), - "y": EncryptedScalar(UnsignedInteger(3)), - }, - [(random.randint(5, 7), random.randint(0, 5)) for _ in range(10)], - NotImplementedError, - "Subtraction " - "of " - "EncryptedScalar " - "from " - "EncryptedScalar " - "cannot be converted to MLIR yet", - ), - pytest.param( - lambda x, y: numpy.dot(x, y), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(2,)), - "y": EncryptedTensor(UnsignedInteger(3), shape=(2,)), - }, - [ - ( - numpy.random.randint(0, 2 ** 3, size=(2,)), - numpy.random.randint(0, 2 ** 3, size=(2,)), - ) - for _ in range(10) - ] - + [(numpy.array([7, 7]), numpy.array([7, 7]))], - NotImplementedError, - "Dot product " - "between " - "EncryptedTensor " - "and " - "EncryptedTensor " - "cannot be converted to MLIR yet", - ), - pytest.param( - lambda x, y: x @ y, - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), - "y": EncryptedTensor(UnsignedInteger(3), shape=(2, 1)), - }, - [ - ( - numpy.random.randint(0, 2 ** 3, size=(3, 2)), - numpy.random.randint(0, 2 ** 3, size=(2, 1)), - ) - for i in range(10) - ] - + [(numpy.array([[7, 7], [7, 7], [7, 7]]), numpy.array([[7], [7]]))], - NotImplementedError, - "Matrix multiplication " - "between " - "EncryptedTensor " - "and " - "EncryptedTensor " - "cannot be converted to MLIR yet", - ), - ], -) -def test_fail_node_conversion( - function_to_compile, - parameters, - inputset, - expected_error_type, - expected_error_message, - default_compilation_configuration, -): - """Test function for failed intermediate node conversion.""" - - with pytest.raises(expected_error_type) as excinfo: - compile_numpy_function( - function_to_compile, parameters, inputset, default_compilation_configuration - ) - - assert str(excinfo.value) == expected_error_message diff --git a/tests/common/optimization/test_float_fusing.py b/tests/common/optimization/test_float_fusing.py deleted file mode 100644 index 2f3902ae8..000000000 --- a/tests/common/optimization/test_float_fusing.py +++ /dev/null @@ -1,735 +0,0 @@ -"""Test file for float subgraph fusing""" - -import random -from copy import deepcopy -from inspect import signature - -import numpy -import pytest - -from concrete.common.data_types.integers import Integer -from concrete.common.debugging import format_operation_graph -from concrete.common.debugging.custom_assert import assert_not_reached -from concrete.common.optimization.topological import fuse_float_operations -from concrete.common.values import EncryptedScalar, EncryptedTensor -from concrete.numpy import tracing -from concrete.numpy.tracing import trace_numpy_function - - -def no_fuse(x): - """No fuse""" - return x + 2 - - -def no_fuse_unhandled(x, y): - """No fuse unhandled""" - x_1 = x + 0.7 - y_1 = y + 1.3 - intermediate = x_1 + y_1 - return intermediate.astype(numpy.int32) - - -def fusable_with_bigger_search(x, y): - """fusable with bigger search""" - x = x + 1 - x_1 = x.astype(numpy.int32) - x_1 = x_1 + 1.5 - x_2 = x.astype(numpy.int32) - x_2 = x_2 + 3.4 - add = x_1 + x_2 - add_int = add.astype(numpy.int32) - return add_int + y - - -def fusable_with_bigger_search_needs_second_iteration(x, y): - """fusable with bigger search and triggers a second iteration in the fusing""" - x = x + 1 - x = x + 0.5 - x = numpy.cos(x) - x_1 = x.astype(numpy.int32) - x_1 = x_1 + 1.5 - x_p = x + 1 - x_p2 = x_p + 1 - x_2 = (x_p + x_p2).astype(numpy.int32) - x_2 = x_2 + 3.4 - add = x_1 + x_2 - add_int = add.astype(numpy.int32) - return add_int + y - - -def no_fuse_big_constant_3_10_10(x): - """Pass an array x with size < 100 to trigger a no fuse condition.""" - x = x.astype(numpy.float64) - return (x + numpy.ones((3, 10, 10))).astype(numpy.int32) - - -def no_fuse_dot(x): - """No fuse dot""" - return numpy.dot(x, numpy.full((10,), 1.33, dtype=numpy.float64)).astype(numpy.int32) - - -def simple_create_fuse_opportunity(f, x): - """No fuse because the function is explicitely marked as unfusable in our code.""" - return f(x.astype(numpy.float64)).astype(numpy.int32) - - -def ravel_cases(x): - """Simple ravel cases""" - return simple_create_fuse_opportunity(numpy.ravel, x) - - -def transpose_cases(x): - """Simple transpose cases""" - return simple_create_fuse_opportunity(numpy.transpose, x) - - -def reshape_cases(x, newshape): - """Simple reshape cases""" - return simple_create_fuse_opportunity(lambda x: numpy.reshape(x, newshape), x) - - -def simple_fuse_not_output(x): - """Simple fuse not output""" - intermediate = x.astype(numpy.float64) - intermediate = intermediate.astype(numpy.int32) - return intermediate + 2 - - -def simple_fuse_output(x): - """Simple fuse output""" - return x.astype(numpy.float64).astype(numpy.int32) - - -def mix_x_and_y_intricately_and_call_f(function, x, y): - """Mix x and y in an intricated way, that can't be simplified by - an optimizer eg, and then call function - """ - intermediate = x + y - intermediate = intermediate + 2 - intermediate = intermediate.astype(numpy.float32) - intermediate = intermediate.astype(numpy.int32) - x_p_1 = intermediate + 1.5 - x_p_2 = intermediate + 2.7 - x_p_3 = function(x_p_1 + x_p_2) - return ( - x_p_3.astype(numpy.int32), - x_p_2.astype(numpy.int32), - (x_p_2 + 3).astype(numpy.int32), - x_p_3.astype(numpy.int32) + 67, - y, - (y + 4.7).astype(numpy.int32) + 3, - ) - - -def mix_x_and_y_and_call_f(function, x, y): - """Mix x and y and then call function""" - x_p_1 = x + 0.1 - x_p_2 = x + 0.2 - x_p_3 = function(x_p_1 + x_p_2) - return ( - x_p_3.astype(numpy.int32), - x_p_2.astype(numpy.int32), - (x_p_2 + 3).astype(numpy.int32), - x_p_3.astype(numpy.int32) + 67, - y, - (y + 4.7).astype(numpy.int32) + 3, - ) - - -def mix_x_and_y_into_range_0_to_1_and_call_f(function, x, y): - """Mix x and y and then call function, in such a way that the input to function is between - 0 and 1""" - x_p_1 = x + 0.1 - x_p_2 = x + 0.2 - x_p_4 = 1 - numpy.abs(numpy.sin(x_p_1 + x_p_2 + 0.3)) - x_p_3 = function(x_p_4) - return ( - x_p_3.astype(numpy.int32), - x_p_2.astype(numpy.int32), - (x_p_2 + 3).astype(numpy.int32), - x_p_3.astype(numpy.int32) + 67, - y, - (y + 4.7).astype(numpy.int32) + 3, - ) - - -def mix_x_and_y_into_integer_and_call_f(function, x, y): - """Mix x and y but keep the entry to function as an integer""" - x_p_1 = x + 1 - x_p_2 = x + 2 - x_p_3 = function(x_p_1 + x_p_2) - return ( - x_p_3.astype(numpy.int32), - x_p_2.astype(numpy.int32), - (x_p_2 + 3).astype(numpy.int32), - x_p_3.astype(numpy.int32) + 67, - y, - (y + 4.7).astype(numpy.int32) + 3, - ) - - -def get_func_params_int32(func, scalar=True): - """Returns a dict with parameters as scalar int32""" - - return { - param_name: EncryptedScalar(Integer(32, True)) - if scalar - else EncryptedTensor(Integer(32, True), (1,)) - for param_name in signature(func).parameters.keys() - } - - -@pytest.mark.parametrize( - "function_to_trace,fused,params,warning_message", - [ - pytest.param(no_fuse, False, get_func_params_int32(no_fuse), "", id="no_fuse"), - pytest.param( - no_fuse_unhandled, - False, - get_func_params_int32(no_fuse_unhandled), - """ - -The following subgraph is not fusable: - -%0 = x # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing) -%1 = 0.7 # ClearScalar -%2 = y # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing) -%3 = 1.3 # ClearScalar -%4 = add(%0, %1) # EncryptedScalar -%5 = add(%2, %3) # EncryptedScalar -%6 = add(%4, %5) # EncryptedScalar -%7 = astype(%6, dtype=int32) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot fuse here as the subgraph has 2 variable inputs -return %7 - - """.strip(), # noqa: E501 # pylint: disable=line-too-long - id="no_fuse_unhandled", - ), - pytest.param( - fusable_with_bigger_search, - True, - get_func_params_int32(fusable_with_bigger_search), - None, - id="fusable_with_bigger_search", - ), - pytest.param( - fusable_with_bigger_search_needs_second_iteration, - True, - get_func_params_int32(fusable_with_bigger_search_needs_second_iteration), - None, - id="fusable_with_bigger_search", - ), - pytest.param( - no_fuse_dot, - False, - {"x": EncryptedTensor(Integer(32, True), (10,))}, - """ - -The following subgraph is not fusable: - -%0 = x # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10,) -%1 = [1.33 1.33 ... 1.33 1.33] # ClearTensor -%2 = dot(%0, %1) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,) -%3 = astype(%2, dtype=int32) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,) -return %3 - - """.strip(), # noqa: E501 # pylint: disable=line-too-long - id="no_fuse_dot", - ), - pytest.param( - ravel_cases, - False, - {"x": EncryptedTensor(Integer(32, True), (10, 20))}, - """ - -The following subgraph is not fusable: - -%0 = x # EncryptedTensor -%1 = astype(%0, dtype=float64) # EncryptedTensor -%2 = ravel(%1) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable -%3 = astype(%2, dtype=int32) # EncryptedTensor -return %3 - - """.strip(), # noqa: E501 # pylint: disable=line-too-long - id="no_fuse_explicitely_ravel", - ), - pytest.param( - transpose_cases, - False, - {"x": EncryptedTensor(Integer(32, True), (10, 20))}, - """ - -The following subgraph is not fusable: - -%0 = x # EncryptedTensor -%1 = astype(%0, dtype=float64) # EncryptedTensor -%2 = transpose(%1) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable -%3 = astype(%2, dtype=int32) # EncryptedTensor -return %3 - - """.strip(), # noqa: E501 # pylint: disable=line-too-long - id="no_fuse_explicitely_transpose", - ), - pytest.param( - lambda x: reshape_cases(x, (20, 10)), - False, - {"x": EncryptedTensor(Integer(32, True), (10, 20))}, - """ - -The following subgraph is not fusable: - -%0 = x # EncryptedTensor -%1 = astype(%0, dtype=float64) # EncryptedTensor -%2 = reshape(%1, newshape=(20, 10)) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable -%3 = astype(%2, dtype=int32) # EncryptedTensor -return %3 - - """.strip(), # noqa: E501 # pylint: disable=line-too-long - id="no_fuse_explicitely_reshape", - ), - pytest.param( - no_fuse_big_constant_3_10_10, - False, - {"x": EncryptedTensor(Integer(32, True), (10, 10))}, - """ - -The following subgraph is not fusable: - -%0 = [[[1. 1. 1 ... . 1. 1.]]] # ClearTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this constant node has a bigger shape (3, 10, 10) than the subgraph's input: (10, 10) -%1 = x # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10, 10) -%2 = astype(%1, dtype=float64) # EncryptedTensor -%3 = add(%2, %0) # EncryptedTensor -%4 = astype(%3, dtype=int32) # EncryptedTensor -return %4 - - """.strip(), # noqa: E501 # pylint: disable=line-too-long - id="no_fuse_big_constant_3_10_10", - ), - pytest.param( - simple_fuse_not_output, - True, - get_func_params_int32(simple_fuse_not_output), - None, - id="simple_fuse_not_output", - ), - pytest.param( - simple_fuse_output, - True, - get_func_params_int32(simple_fuse_output), - None, - id="simple_fuse_output", - ), - pytest.param( - lambda x, y: mix_x_and_y_intricately_and_call_f(numpy.rint, x, y), - True, - get_func_params_int32(lambda x, y: None), - None, - id="mix_x_and_y_intricately_and_call_f_with_rint", - ), - pytest.param( - lambda x, y: mix_x_and_y_and_call_f(numpy.rint, x, y), - True, - get_func_params_int32(lambda x, y: None), - None, - id="mix_x_and_y_and_call_f_with_rint", - ), - pytest.param( - transpose_cases, - True, - get_func_params_int32(transpose_cases), - None, - id="transpose_cases scalar", - ), - pytest.param( - transpose_cases, - True, - {"x": EncryptedTensor(Integer(32, True), (10,))}, - None, - id="transpose_cases ndim == 1", - ), - pytest.param( - ravel_cases, - True, - {"x": EncryptedTensor(Integer(32, True), (10,))}, - None, - id="ravel_cases ndim == 1", - ), - pytest.param( - lambda x: reshape_cases(x, (10, 20)), - True, - {"x": EncryptedTensor(Integer(32, True), (10, 20))}, - None, - id="reshape_cases same shape", - ), - ], -) -def test_fuse_float_operations( - function_to_trace, - fused, - params, - warning_message, - capfd, - remove_color_codes, - check_array_equality, -): - """Test function for fuse_float_operations""" - - op_graph = trace_numpy_function( - function_to_trace, - params, - ) - copied_graph = deepcopy(op_graph) - orig_num_nodes = len(op_graph.graph) - fuse_float_operations(op_graph) - fused_num_nodes = len(op_graph.graph) - fuse_float_operations(copied_graph) - - # Check determinism - assert format_operation_graph(copied_graph) == format_operation_graph(op_graph) - - if fused: - assert fused_num_nodes < orig_num_nodes - else: - assert fused_num_nodes == orig_num_nodes - captured = capfd.readouterr() - assert warning_message in (output := remove_color_codes(captured.err)), output - - for input_ in [0, 2, 42, 44]: - inputs = () - for param_input_value in params.values(): - if param_input_value.is_scalar: - input_ = numpy.int32(input_) - else: - input_ = numpy.full(param_input_value.shape, input_, dtype=numpy.int32) - inputs += (input_,) - - check_array_equality(function_to_trace(*inputs), op_graph(*inputs)) - - -def subtest_tensor_no_fuse(fun, tensor_shape): - """Test case to verify float fusing is only applied on functions on scalars.""" - - if tensor_shape == (): - # We want tensors - return - - if fun in LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES: - # We need at least one input of the bivariate function to be float - return - - # Float fusing currently cannot work if the constant in a bivariate operator is bigger than the - # variable input. - # Make a broadcastable shape but with the constant being bigger - variable_tensor_shape = (1,) + tensor_shape - constant_bigger_shape = (random.randint(2, 10),) + tensor_shape - - def tensor_no_fuse(x): - intermediate = x.astype(numpy.float64) - intermediate = fun(intermediate, numpy.ones(constant_bigger_shape)) - return intermediate.astype(numpy.int32) - - function_to_trace = tensor_no_fuse - params_names = signature(function_to_trace).parameters.keys() - - op_graph = trace_numpy_function( - function_to_trace, - { - param_name: EncryptedTensor(Integer(32, True), shape=variable_tensor_shape) - for param_name in params_names - }, - ) - orig_num_nodes = len(op_graph.graph) - fuse_float_operations(op_graph) - fused_num_nodes = len(op_graph.graph) - - assert orig_num_nodes == fused_num_nodes - - -def check_results_are_equal(function_result, op_graph_result): - """Check the output of function execution and OPGraph evaluation are equal.""" - - if isinstance(function_result, tuple) and isinstance(op_graph_result, tuple): - assert len(function_result) == len(op_graph_result) - are_equal = ( - function_output == op_graph_output - for function_output, op_graph_output in zip(function_result, op_graph_result) - ) - elif not isinstance(function_result, tuple) and not isinstance(op_graph_result, tuple): - are_equal = (function_result == op_graph_result,) - else: - assert_not_reached(f"Incompatible outputs: {function_result}, {op_graph_result}") - - return all(value.all() if isinstance(value, numpy.ndarray) else value for value in are_equal) - - -def subtest_fuse_float_unary_operations_correctness(fun, tensor_shape): - """Test a unary function with fuse_float_operations.""" - - # Some manipulation to avoid issues with domain of definitions of functions - if fun == numpy.arccosh: - # 0 is not in the domain of definition - input_list = [1, 2, 42, 44] - super_fun_list = [mix_x_and_y_and_call_f] - elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]: - # Needs values between 0 and 1 in the call function - input_list = [0, 2, 42, 44] - super_fun_list = [mix_x_and_y_into_range_0_to_1_and_call_f] - elif fun in [numpy.cosh, numpy.sinh, numpy.exp, numpy.exp2, numpy.expm1]: - # Not too large values to avoid overflows - input_list = [1, 2, 5, 11] - super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f] - else: - # Regular case - input_list = [0, 2, 42, 44] - super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f] - - for super_fun in super_fun_list: - - for input_ in input_list: - - def get_function_to_trace(): - return lambda x, y: super_fun(fun, x, y) - - function_to_trace = get_function_to_trace() - - params_names = signature(function_to_trace).parameters.keys() - - op_graph = trace_numpy_function( - function_to_trace, - { - param_name: EncryptedTensor(Integer(32, True), tensor_shape) - for param_name in params_names - }, - ) - copied_graph = deepcopy(op_graph) - orig_num_nodes = len(op_graph.graph) - fuse_float_operations(op_graph) - fused_num_nodes = len(op_graph.graph) - fuse_float_operations(copied_graph) - - # Check determinism - assert format_operation_graph(copied_graph) == format_operation_graph(op_graph) - - assert fused_num_nodes < orig_num_nodes - - # Check that the call to the function or to the op_graph evaluation give the same - # result - tensor_diversifier = ( - # The following +1 in the range is to avoid to have 0's which is not in the - # domain definition of some of our functions - numpy.arange(1, numpy.product(tensor_shape) + 1, dtype=numpy.int32).reshape( - tensor_shape - ) - if tensor_shape != () - else 1 - ) - - if fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]: - # Domain of definition for these functions - tensor_diversifier = ( - numpy.ones(tensor_shape, dtype=numpy.int32) if tensor_shape != () else 1 - ) - - input_ = numpy.int32(input_ * tensor_diversifier) - - num_params = len(params_names) - assert num_params == 2 - - # Create inputs which are either of the form [x, x] or [x, y] - for j in range(4): - - if fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan] and j > 0: - # Domain of definition for these functions - break - - input_a = input_ - input_b = input_ + j - - if tensor_shape != (): - numpy.random.shuffle(input_a) - numpy.random.shuffle(input_b) - - inputs = (input_a, input_b) if random.randint(0, 1) == 0 else (input_b, input_a) - - function_result = function_to_trace(*inputs) - op_graph_result = op_graph(*inputs) - - assert check_results_are_equal(function_result, op_graph_result) - - -LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = { - numpy.bitwise_and, - numpy.bitwise_or, - numpy.bitwise_xor, - numpy.gcd, - numpy.lcm, - numpy.ldexp, - numpy.left_shift, - numpy.logical_and, - numpy.logical_not, - numpy.logical_or, - numpy.logical_xor, - numpy.remainder, - numpy.right_shift, -} - - -def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape): - """Test a binary functions with fuse_float_operations, with a constant as a source.""" - - for i in range(4): - - # Know if the function is defined for integer inputs - if fun in LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES: - if i not in [0, 2]: - continue - - # The .astype(numpy.float64) that we have in cases 0 and 2 is here to force - # a float output even for functions which return an integer (eg, XOR), such - # that our frontend always try to fuse them - - # The .astype(numpy.float64) that we have in cases 1 and 3 is here to force - # a float output even for functions which return a bool (eg, EQUAL), such - # that our frontend always try to fuse them - - # For bivariate functions: fix one of the inputs - if i == 0: - # With an integer in first position - ones_0 = numpy.ones(tensor_shape, dtype=numpy.int32) if tensor_shape != () else 1 - - def get_function_to_trace(): - return lambda x, y: fun(3 * ones_0, x + y).astype(numpy.float64).astype(numpy.int32) - - elif i == 1: - # With a float in first position - ones_1 = numpy.ones(tensor_shape, dtype=numpy.float64) if tensor_shape != () else 1 - - def get_function_to_trace(): - return ( - lambda x, y: fun(2.3 * ones_1, x + y).astype(numpy.float64).astype(numpy.int32) - ) - - elif i == 2: - # With an integer in second position - ones_2 = numpy.ones(tensor_shape, dtype=numpy.int32) if tensor_shape != () else 1 - - def get_function_to_trace(): - return lambda x, y: fun(x + y, 4 * ones_2).astype(numpy.float64).astype(numpy.int32) - - else: - # With a float in second position - ones_else = numpy.ones(tensor_shape, dtype=numpy.float64) if tensor_shape != () else 1 - - def get_function_to_trace(): - return ( - lambda x, y: fun(x + y, 5.7 * ones_else) - .astype(numpy.float64) - .astype(numpy.int32) - ) - - input_list = [0, 2, 42, 44] - - # Domain of definition - if fun in [numpy.true_divide, numpy.remainder, numpy.floor_divide, numpy.fmod]: - input_list = [2, 42, 44] - - for input_ in input_list: - function_to_trace = get_function_to_trace() - - params_names = signature(function_to_trace).parameters.keys() - - op_graph = trace_numpy_function( - function_to_trace, - { - param_name: EncryptedTensor(Integer(32, True), tensor_shape) - for param_name in params_names - }, - ) - copied_graph = deepcopy(op_graph) - orig_num_nodes = len(op_graph.graph) - fuse_float_operations(op_graph) - fused_num_nodes = len(op_graph.graph) - fuse_float_operations(copied_graph) - - # Check determinism - assert format_operation_graph(copied_graph) == format_operation_graph(op_graph) - - assert fused_num_nodes < orig_num_nodes - - # Check that the call to the function or to the op_graph evaluation give the same - # result - tensor_diversifier = ( - # The following +1 in the range is to avoid to have 0's which is not in the - # domain definition of some of our functions - numpy.arange(1, numpy.product(tensor_shape) + 1, dtype=numpy.int32).reshape( - tensor_shape - ) - if tensor_shape != () - else numpy.int64(1) - ) - # Make sure the tensor diversifier is a numpy variable, otherwise some cases may fail - # as python int and float don't have the astype method - input_ = input_ * tensor_diversifier - - num_params = len(params_names) - assert num_params == 2 - - # Create inputs which are either of the form [x, x] or [x, y] - for j in range(4): - inputs = (input_, input_ + j) - - function_result = function_to_trace(*inputs) - op_graph_result = op_graph(*inputs) - - assert check_results_are_equal(function_result, op_graph_result) - - -def subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_shape): - """Test a binary function with fuse_float_operations, with no constant as - a source.""" - - def get_function_to_trace(): - return lambda x, y: fun(x, y).astype(numpy.int32) - - function_to_trace = get_function_to_trace() - - params_names = signature(function_to_trace).parameters.keys() - - with pytest.raises( - AssertionError, - match=r"Can only have 1 non constant predecessor in _np_operator, got 2 for operator", - ): - trace_numpy_function( - function_to_trace, - { - param_name: EncryptedTensor(Integer(32, True), tensor_shape) - for param_name in params_names - }, - ) - - -@pytest.mark.parametrize("fun", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC) -@pytest.mark.parametrize( - "tensor_shape", [pytest.param((), id="scalar"), pytest.param((3, 1, 2), id="tensor")] -) -def test_ufunc_operations(fun, tensor_shape): - """Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC.""" - - if fun.nin == 1: - subtest_fuse_float_unary_operations_correctness(fun, tensor_shape) - elif fun.nin == 2: - subtest_fuse_float_binary_operations_correctness(fun, tensor_shape) - subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_shape) - subtest_tensor_no_fuse(fun, tensor_shape) - else: - raise NotImplementedError("Only unary and binary functions are tested for now") diff --git a/tests/common/representation/test_intermediate.py b/tests/common/representation/test_intermediate.py deleted file mode 100644 index ca5c073a1..000000000 --- a/tests/common/representation/test_intermediate.py +++ /dev/null @@ -1,433 +0,0 @@ -"""Test file for intermediate representation""" -from copy import deepcopy - -import numpy -import pytest - -from concrete.common.data_types.floats import Float -from concrete.common.data_types.integers import Integer -from concrete.common.representation import intermediate as ir -from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor - - -@pytest.mark.parametrize( - "node,input_data,expected_result", - [ - pytest.param( - ir.Add([EncryptedScalar(Integer(64, False)), EncryptedScalar(Integer(64, False))]), - [10, 4589], - 4599, - id="Add", - ), - pytest.param( - ir.Sub([EncryptedScalar(Integer(64, False)), EncryptedScalar(Integer(64, False))]), - [10, 4589], - -4579, - id="Sub", - ), - pytest.param( - ir.Mul([EncryptedScalar(Integer(64, False)), EncryptedScalar(Integer(64, False))]), - [10, 4589], - 45890, - id="Mul", - ), - pytest.param(ir.Input(ClearScalar(Integer(32, True)), "in", 0), [42], 42, id="Input"), - pytest.param(ir.Constant(42), None, 42, id="Constant"), - pytest.param(ir.Constant(-42), None, -42, id="Constant"), - pytest.param( - ir.GenericFunction( - [EncryptedScalar(Integer(7, False))], - lambda x: x + 3, - EncryptedScalar(Integer(7, False)), - op_kind="TLU", - ), - [10], - 13, - id="GenericFunction, x + 3", - ), - pytest.param( - ir.GenericFunction( - [EncryptedScalar(Integer(7, False))], - lambda x, y: x + y, - EncryptedScalar(Integer(7, False)), - op_kind="TLU", - op_kwargs={"y": 3}, - ), - [10], - 13, - id="GenericFunction, (x, y) -> x + y, where y is constant == 3", - ), - pytest.param( - ir.GenericFunction( - [EncryptedScalar(Integer(7, False))], - lambda x, y: y[x], - EncryptedScalar(Integer(7, False)), - op_kind="TLU", - op_kwargs={"y": (1, 2, 3, 4)}, - ), - [2], - 3, - id="GenericFunction, (x, y) -> y[x], where y is constant == (1, 2, 3, 4)", - ), - pytest.param( - ir.GenericFunction( - [EncryptedScalar(Integer(7, False))], - lambda x, y: y[3], - EncryptedScalar(Integer(7, False)), - op_kind="TLU", - op_kwargs={"y": (1, 2, 3, 4)}, - ), - [2], - 4, - id="GenericFunction, x, y -> y[3], where y is constant == (1, 2, 3, 4)", - ), - pytest.param( - ir.Dot( - [ - EncryptedTensor(Integer(32, True), shape=(4,)), - ClearTensor(Integer(32, True), shape=(4,)), - ], - Integer(32, True), - ), - [[1, 2, 3, 4], [4, 3, 2, 1]], - 20, - id="Dot, [1, 2, 3, 4], [4, 3, 2, 1]", - ), - pytest.param( - ir.Dot( - [ - EncryptedTensor(Float(32), shape=(4,)), - ClearTensor(Float(32), shape=(4,)), - ], - Float(32), - ), - [[1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]], - 20, - id="Dot, [1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]", - ), - pytest.param( - ir.Dot( - [ - EncryptedTensor(Integer(32, True), shape=(4,)), - ClearTensor(Integer(32, True), shape=(4,)), - ], - Integer(32, True), - delegate_evaluation_function=numpy.dot, - ), - [ - numpy.array([1, 2, 3, 4], dtype=numpy.int32), - numpy.array([4, 3, 2, 1], dtype=numpy.int32), - ], - 20, - id="Dot, np.array([1, 2, 3, 4]), np.array([4, 3, 2, 1])", - ), - pytest.param( - ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (0,)), - [ - numpy.array([1, 2, 3, 4], dtype=numpy.int32), - ], - 1, - id="IndexConstant, np.array([1, 2, 3, 4])[0]", - ), - pytest.param( - ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (slice(1, 3, None),)), - [ - numpy.array([1, 2, 3, 4], dtype=numpy.int32), - ], - numpy.array([2, 3]), - id="IndexConstant, np.array([1, 2, 3, 4])[1:3]", - ), - pytest.param( - ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (slice(3, 1, -1),)), - [ - numpy.array([1, 2, 3, 4], dtype=numpy.int32), - ], - numpy.array([4, 3], dtype=numpy.int32), - id="IndexConstant, np.array([1, 2, 3, 4])[3:1:-1]", - ), - pytest.param( - ir.IndexConstant( - EncryptedTensor(Integer(5, True), shape=(4, 4)), (slice(1, 3, 1), slice(2, 0, -1)) - ), - [ - numpy.array( - [ - [1, 2, 3, 4], - [5, 6, 7, 8], - [9, 10, 11, 12], - [13, 14, 15, 16], - ], - dtype=numpy.int32, - ), - ], - numpy.array( - [ - [7, 6], - [11, 10], - ], - dtype=numpy.int32, - ), - id="IndexConstant, np.array([[1, 2, 3, 4]...[13, 14, 15, 16]])[1:3, 2:0:-1]", - ), - pytest.param( - ir.MatMul( - [ - EncryptedTensor(Integer(32, True), shape=(3, 2)), - ClearTensor(Integer(32, True), shape=(2, 3)), - ], - Integer(32, True), - (3, 3), - ), - [numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)], - numpy.array([[9, 12, 15], [19, 26, 33], [29, 40, 51]]), - id="MatMul, numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)", - ), - pytest.param( - ir.GenericFunction( - [EncryptedTensor(Integer(32, False), shape=(3, 5))], - lambda x: numpy.transpose(x), - EncryptedTensor(Integer(32, False), shape=(5, 3)), - op_kind="Memory", - ), - [numpy.arange(15).reshape(3, 5)], - numpy.array([[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8, 13], [4, 9, 14]]), - id="GenericFunction, x transpose", - ), - pytest.param( - ir.GenericFunction( - [EncryptedTensor(Integer(32, False), shape=(3, 5))], - lambda x: numpy.ravel(x), - EncryptedTensor(Integer(32, False), shape=(5, 3)), - op_kind="Memory", - ), - [numpy.arange(15).reshape(3, 5)], - numpy.arange(15), - id="GenericFunction, x ravel", - ), - pytest.param( - ir.GenericFunction( - [EncryptedTensor(Integer(32, False), shape=(3, 5))], - lambda x: numpy.reshape(x, (5, 3)), - output_value=EncryptedTensor(Integer(32, False), shape=(5, 3)), - op_kind="Memory", - ), - [numpy.arange(15).reshape(3, 5)], - numpy.arange(15).reshape(5, 3), - id="GenericFunction, x reshape", - ), - ], -) -def test_evaluate( - node: ir.IntermediateNode, - input_data, - expected_result: int, - check_array_equality, -): - """Test evaluate methods on IntermediateNodes""" - if isinstance(expected_result, numpy.ndarray): - check_array_equality(node.evaluate(input_data), expected_result) - else: - assert node.evaluate(input_data) == expected_result - - -@pytest.mark.parametrize( - "node1,node2,expected_result", - [ - ( - ir.Add([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]), - ir.Add([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]), - True, - ), - ( - ir.Add([EncryptedScalar(Integer(16, False)), EncryptedScalar(Integer(32, False))]), - ir.Add([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(16, False))]), - True, - ), - ( - ir.Add([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]), - ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]), - False, - ), - ( - ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]), - ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]), - True, - ), - ( - ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(16, False))]), - ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(16, False))]), - True, - ), - ( - ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(16, False))]), - ir.Sub([EncryptedScalar(Integer(16, False)), EncryptedScalar(Integer(32, False))]), - False, - ), - ( - ir.Mul([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]), - ir.Mul([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]), - True, - ), - ( - ir.Mul([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]), - ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]), - False, - ), - ( - ir.Input(EncryptedScalar(Integer(32, False)), "x", 0), - ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]), - False, - ), - ( - ir.Input(EncryptedScalar(Integer(32, False)), "x", 0), - ir.Input(EncryptedScalar(Integer(32, False)), "x", 0), - True, - ), - ( - ir.Input(EncryptedScalar(Integer(32, False)), "x", 0), - ir.Input(EncryptedScalar(Integer(32, False)), "y", 0), - False, - ), - ( - ir.Input(EncryptedScalar(Integer(32, False)), "x", 0), - ir.Input(EncryptedScalar(Integer(32, False)), "x", 1), - False, - ), - ( - ir.Input(EncryptedScalar(Integer(32, False)), "x", 0), - ir.Input(EncryptedScalar(Integer(8, False)), "x", 0), - False, - ), - ( - ir.Constant(10), - ir.Constant(10), - True, - ), - ( - ir.Constant(10), - ir.Input(EncryptedScalar(Integer(8, False)), "x", 0), - False, - ), - ( - ir.Constant(10), - ir.Constant(10.0), - False, - ), - ( - ir.GenericFunction( - [EncryptedScalar(Integer(8, False))], - lambda x: x, - EncryptedScalar(Integer(8, False)), - op_kind="TLU", - ), - ir.GenericFunction( - [EncryptedScalar(Integer(8, False))], - lambda x: x, - EncryptedScalar(Integer(8, False)), - op_kind="TLU", - ), - True, - ), - ( - ir.GenericFunction( - [EncryptedScalar(Integer(8, False))], - lambda x: x, - EncryptedScalar(Integer(8, False)), - op_kind="TLU", - op_args=(1, 2, 3), - ), - ir.GenericFunction( - [EncryptedScalar(Integer(8, False))], - lambda x: x, - EncryptedScalar(Integer(8, False)), - op_kind="TLU", - ), - False, - ), - ( - ir.GenericFunction( - [EncryptedScalar(Integer(8, False))], - lambda x: x, - EncryptedScalar(Integer(8, False)), - op_kind="TLU", - op_kwargs={"tuple": (1, 2, 3)}, - ), - ir.GenericFunction( - [EncryptedScalar(Integer(8, False))], - lambda x: x, - EncryptedScalar(Integer(8, False)), - op_kind="TLU", - ), - False, - ), - ( - ir.Dot( - [ - EncryptedTensor(Integer(32, True), shape=(4,)), - ClearTensor(Integer(32, True), shape=(4,)), - ], - Integer(32, True), - delegate_evaluation_function=numpy.dot, - ), - ir.Dot( - [ - EncryptedTensor(Integer(32, True), shape=(4,)), - ClearTensor(Integer(32, True), shape=(4,)), - ], - Integer(32, True), - delegate_evaluation_function=numpy.dot, - ), - True, - ), - ( - ir.Dot( - [ - EncryptedTensor(Integer(32, True), shape=(4,)), - ClearTensor(Integer(32, True), shape=(4,)), - ], - Integer(32, True), - delegate_evaluation_function=numpy.dot, - ), - ir.Dot( - [ - EncryptedTensor(Integer(32, True), shape=(4,)), - ClearTensor(Integer(32, True), shape=(4,)), - ], - Integer(32, True), - ), - False, - ), - ], -) -def test_is_equivalent_to( - node1: ir.IntermediateNode, - node2: ir.IntermediateNode, - expected_result: bool, - test_helpers, -): - """Test is_equivalent_to methods on IntermediateNodes""" - assert ( - test_helpers.nodes_are_equivalent(node1, node2) - == test_helpers.nodes_are_equivalent(node2, node1) - == expected_result - ) - - -@pytest.mark.parametrize( - "list_to_fill,expected_list", - [ - pytest.param([None, 1, 2, 3, None, None], [1, 1, 2, 3, 3, 3]), - pytest.param([None], None, marks=pytest.mark.xfail(strict=True)), - pytest.param([None, None, None, None, 7, None, None, None], [7, 7, 7, 7, 7, 7, 7, 7]), - pytest.param([None, None, 3, None, None, None, 2, None], [3, 3, 3, 3, 3, 2, 2, 2]), - ], -) -def test_flood_replace_none_values(list_to_fill: list, expected_list: list): - """Unit test for flood_replace_none_values""" - - # avoid modifying the test input - list_to_fill_copy = deepcopy(list_to_fill) - ir.flood_replace_none_values(list_to_fill_copy) - - assert all(value is not None for value in list_to_fill_copy) - assert list_to_fill_copy == expected_list diff --git a/tests/common/test_common_helpers.py b/tests/common/test_common_helpers.py deleted file mode 100644 index 575e118cf..000000000 --- a/tests/common/test_common_helpers.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Test file for common helpers""" - -from copy import deepcopy - -import pytest - -from concrete.common import check_op_graph_is_integer_program, is_a_power_of_2 -from concrete.common.data_types.floats import Float64 -from concrete.common.data_types.integers import Integer -from concrete.common.values import EncryptedScalar -from concrete.numpy.tracing import trace_numpy_function - - -@pytest.mark.parametrize( - "x,result", - [ - (0, False), - (1, True), - (2, True), - (3, False), - (4, True), - (10, False), - (16, True), - ], -) -def test_is_a_power_of_2(x, result): - """Test function for test_is_a_power_of_2""" - - assert is_a_power_of_2(x) == result - - -def test_check_op_graph_is_integer_program(): - """Test function for check_op_graph_is_integer_program""" - - def function(x, y): - return x + y - y * y + x * y - - op_graph = trace_numpy_function( - function, {"x": EncryptedScalar(Integer(64, True)), "y": EncryptedScalar(Integer(64, True))} - ) - - # Test without and with output list - offending_nodes = [] - assert check_op_graph_is_integer_program(op_graph) - assert check_op_graph_is_integer_program(op_graph, offending_nodes) - assert len(offending_nodes) == 0 - - op_graph_copy = deepcopy(op_graph) - op_graph_copy.output_nodes[0].outputs[0].dtype = Float64 - - offending_nodes = [] - assert not check_op_graph_is_integer_program(op_graph_copy) - assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes) - assert len(offending_nodes) == 1 - assert offending_nodes == [op_graph_copy.output_nodes[0]] - - op_graph_copy = deepcopy(op_graph) - op_graph_copy.input_nodes[0].inputs[0].dtype = Float64 - - offending_nodes = [] - assert not check_op_graph_is_integer_program(op_graph_copy) - assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes) - assert len(offending_nodes) == 1 - assert offending_nodes == [op_graph_copy.input_nodes[0]] - - op_graph_copy = deepcopy(op_graph) - op_graph_copy.input_nodes[0].inputs[0].dtype = Float64 - op_graph_copy.input_nodes[1].inputs[0].dtype = Float64 - - offending_nodes = [] - assert not check_op_graph_is_integer_program(op_graph_copy) - assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes) - assert len(offending_nodes) == 2 - assert set(offending_nodes) == set([op_graph_copy.input_nodes[0], op_graph_copy.input_nodes[1]]) diff --git a/tests/common/test_fhe_circuit.py b/tests/common/test_fhe_circuit.py deleted file mode 100644 index 6a2c5d1a6..000000000 --- a/tests/common/test_fhe_circuit.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Test module for Circuit class""" - -import filecmp - -import concrete.numpy as hnp -from concrete.common.debugging import draw_graph, format_operation_graph - - -def test_circuit_str(default_compilation_configuration): - """Test function for `__str__` method of `Circuit`""" - - def f(x): - return x + 42 - - x = hnp.EncryptedScalar(hnp.UnsignedInteger(3)) - - inputset = range(2 ** 3) - circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration) - - assert str(circuit) == format_operation_graph(circuit.op_graph) - - -def test_circuit_draw(default_compilation_configuration): - """Test function for `draw` method of `Circuit`""" - - def f(x): - return x + 42 - - x = hnp.EncryptedScalar(hnp.UnsignedInteger(3)) - - inputset = range(2 ** 3) - circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration) - - assert filecmp.cmp(circuit.draw(), draw_graph(circuit.op_graph)) - assert filecmp.cmp(circuit.draw(vertical=False), draw_graph(circuit.op_graph, vertical=False)) - - -def test_circuit_run(default_compilation_configuration): - """Test equivalence of encrypt/run/decrypt and encrypt_run_decrypt""" - - def f(x): - return x + 42 - - x = hnp.EncryptedScalar(hnp.UnsignedInteger(3)) - - inputset = range(2 ** 3) - circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration) - - circuit.keygen() - for x in inputset: - enc_x = circuit.encrypt(x) - enc_res = circuit.run(enc_x) - res = circuit.decrypt(enc_res) - assert circuit.encrypt_run_decrypt(x) == res diff --git a/tests/common/tracing/test_tracing_helpers.py b/tests/common/tracing/test_tracing_helpers.py deleted file mode 100644 index 3cabf6ce4..000000000 --- a/tests/common/tracing/test_tracing_helpers.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Test file for common tracing helpers""" - -from typing import Any, Dict, List - -import pytest - -from concrete.common.tracing.tracing_helpers import prepare_function_parameters - - -@pytest.mark.parametrize( - "function,function_parameters,ref_dict", - [ - pytest.param(lambda x: None, {}, {}, id="Missing x", marks=pytest.mark.xfail(strict=True)), - pytest.param(lambda x: None, {"x": None}, {"x": None}, id="Only x"), - pytest.param( - lambda x: None, {"x": None, "y": None}, {"x": None}, id="Additional y filtered" - ), - ], -) -def test_prepare_function_parameters( - function, function_parameters: Dict[str, Any], ref_dict: Dict[str, Any] -): - """Test prepare_function_parameters""" - prepared_dict = prepare_function_parameters(function, function_parameters) - - assert prepared_dict == ref_dict - - -@pytest.mark.parametrize( - "function,function_parameters,expected_ordered_keys", - [ - (lambda x: None, {"x": None}, ["x"]), - (lambda x, y: None, {"x": None, "y": None}, ["x", "y"]), - (lambda x, y: None, {"y": None, "x": None}, ["x", "y"]), - (lambda z, x, y: None, {"y": None, "z": None, "x": None}, ["z", "x", "y"]), - ], -) -def test_prepare_function_parameters_order( - function, function_parameters: Dict[str, Any], expected_ordered_keys: List[str] -): - """Test prepare_function_parameters output order""" - prepared_dict = prepare_function_parameters(function, function_parameters) - - assert list(prepared_dict.keys()) == expected_ordered_keys diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 3597b5c8f..000000000 --- a/tests/conftest.py +++ /dev/null @@ -1,473 +0,0 @@ -"""PyTest configuration file""" -import json -import operator -import re -import shutil -from pathlib import Path -from typing import Any, Callable, Dict, Iterable, Optional, Type - -import networkx as nx -import networkx.algorithms.isomorphism as iso -import numpy -import pytest - -from concrete.common.compilation import CompilationConfiguration -from concrete.common.fhe_circuit import FHECircuit -from concrete.common.mlir.utils import ( - ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB, - get_op_graph_max_bit_width_and_nodes_over_bit_width_limit, -) -from concrete.common.representation.intermediate import ( - ALL_IR_NODES, - Add, - Constant, - Conv2D, - Dot, - GenericFunction, - IndexConstant, - Input, - IntermediateNode, - MatMul, - Mul, - Sub, -) -from concrete.numpy import compile as compile_ - - -def pytest_addoption(parser): - """Options for pytest""" - - parser.addoption( - "--global-coverage-infos-json", - action="store", - default=None, - type=str, - help="To dump pytest-cov term report to a text file.", - ) - - parser.addoption( - "--keyring-dir", - action="store", - default=None, - type=str, - help="Specify the dir to use to store key cache", - ) - - -DEFAULT_KEYRING_PATH = Path.home().resolve() / ".cache/concrete-numpy_pytest" - - -def get_keyring_dir_from_session_or_default( - session: Optional[pytest.Session] = None, -) -> Optional[Path]: - """Get keyring dir from test session.""" - if session is None: - return DEFAULT_KEYRING_PATH - - keyring_dir = session.config.getoption("--keyring-dir", default=None) - if keyring_dir is not None: - if keyring_dir.lower() == "disable": - return None - keyring_dir = Path(keyring_dir).expanduser().resolve() - else: - keyring_dir = DEFAULT_KEYRING_PATH - return keyring_dir - - -@pytest.fixture -def default_keyring_path(): - """Fixture to get test keyring dir.""" - return DEFAULT_KEYRING_PATH - - -# This is only for doctests where we currently cannot make use of fixtures -original_compilation_config_init = CompilationConfiguration.__init__ - - -def monkeypatched_compilation_configuration_init_for_codeblocks( - self: CompilationConfiguration, *args, **kwargs -): - """Monkeypatched compilation configuration init for codeblocks tests.""" - original_compilation_config_init(self, *args, **kwargs) - self.dump_artifacts_on_unexpected_failures = False - self.enable_unsafe_features = True # This is for our tests only, never use that in prod - self.treat_warnings_as_errors = True - self.use_insecure_key_cache = True # This is for our tests only, never use that in prod - - -def pytest_sessionstart(session: pytest.Session): - """Handle keyring for session and codeblocks CompilationConfiguration if needed.""" - if session.config.getoption("--codeblocks", default=False): - # setattr to avoid mypy complaining - # Disable the flake8 bug bear warning for the mypy fix - setattr( # noqa: B010 - CompilationConfiguration, - "__init__", - monkeypatched_compilation_configuration_init_for_codeblocks, - ) - - keyring_dir = get_keyring_dir_from_session_or_default(session) - if keyring_dir is None: - return - keyring_dir.mkdir(parents=True, exist_ok=True) - keyring_dir_as_str = str(keyring_dir) - print(f"Using {keyring_dir_as_str} as key cache dir") - compile_._COMPILE_FHE_INSECURE_KEY_CACHE_DIR = ( # pylint: disable=protected-access - keyring_dir_as_str - ) - - -def pytest_sessionfinish(session: pytest.Session, exitstatus): # pylint: disable=unused-argument - """Pytest callback when testing ends.""" - # Hacked together from the source code, they don't have an option to export to file and it's too - # much work to get a PR in for such a little thing - # https://github.com/pytest-dev/pytest-cov/blob/ - # ec344d8adf2d78238d8f07cb20ed2463d7536970/src/pytest_cov/plugin.py#L329 - if session.config.pluginmanager.hasplugin("_cov"): - global_coverage_file = session.config.getoption( - "--global-coverage-infos-json", default=None - ) - if global_coverage_file is not None: - cov_plugin = session.config.pluginmanager.getplugin("_cov") - coverage_txt = cov_plugin.cov_report.getvalue() - coverage_status = 0 - if ( - cov_plugin.options.cov_fail_under is not None - and cov_plugin.options.cov_fail_under > 0 - ): - failed = cov_plugin.cov_total < cov_plugin.options.cov_fail_under - # If failed is False coverage_status is 0, if True it's 1 - coverage_status = int(failed) - global_coverage_file_path = Path(global_coverage_file).resolve() - with open(global_coverage_file_path, "w", encoding="utf-8") as f: - json.dump({"exit_code": coverage_status, "content": coverage_txt}, f) - - keyring_dir = get_keyring_dir_from_session_or_default(session) - if keyring_dir is not None: - # Remove incomplete keys - for incomplete_keys in keyring_dir.glob("**/*incomplete*"): - shutil.rmtree(incomplete_keys, ignore_errors=True) - - -def _is_equivalent_to_binary_commutative(lhs: IntermediateNode, rhs: object) -> bool: - """is_equivalent_to for a binary and commutative operation.""" - return ( - isinstance(rhs, lhs.__class__) - and (lhs.inputs in (rhs.inputs, rhs.inputs[::-1])) - and lhs.outputs == rhs.outputs - ) - - -def _is_equivalent_to_binary_non_commutative(lhs: IntermediateNode, rhs: object) -> bool: - """is_equivalent_to for a binary and non-commutative operation.""" - return ( - isinstance(rhs, lhs.__class__) and lhs.inputs == rhs.inputs and lhs.outputs == rhs.outputs - ) - - -def is_equivalent_add(lhs: Add, rhs: object) -> bool: - """Helper function to check if an Add node is equivalent to an other object.""" - return _is_equivalent_to_binary_commutative(lhs, rhs) - - -# From https://stackoverflow.com/a/28635464 -_code_and_constants_attr_getter = operator.attrgetter("co_code", "co_consts") - - -def _code_and_constants(object_): - """Helper function to get python code and constants""" - return _code_and_constants_attr_getter(object_.__code__) - - -def python_functions_are_equal_or_equivalent(lhs: object, rhs: object) -> bool: - """Helper function to check if two functions are equal or their code are equivalent. - - This is not perfect, but will be good enough for tests. - """ - - if lhs == rhs: - return True - - try: - lhs_code_and_constants = _code_and_constants(lhs) - rhs_code_and_constants = _code_and_constants(rhs) - return lhs_code_and_constants == rhs_code_and_constants - except AttributeError: - return False - - -def is_equivalent_arbitrary_function(lhs: GenericFunction, rhs: object) -> bool: - """Helper function to check if an GenericFunction node is equivalent to an other object.""" - return ( - isinstance(rhs, GenericFunction) - and python_functions_are_equal_or_equivalent(lhs.arbitrary_func, rhs.arbitrary_func) - and lhs.op_kind == rhs.op_kind - and lhs.op_args == rhs.op_args - and lhs.op_kwargs == rhs.op_kwargs - and lhs.op_attributes == rhs.op_attributes - and lhs.op_name == rhs.op_name - and is_equivalent_intermediate_node(lhs, rhs) - ) - - -def is_equivalent_constant(lhs: Constant, rhs: object) -> bool: - """Helper function to check if a Constant node is equivalent to an other object.""" - return ( - isinstance(rhs, Constant) - and lhs.constant_data == rhs.constant_data - and is_equivalent_intermediate_node(lhs, rhs) - ) - - -def is_equivalent_dot(lhs: Dot, rhs: object) -> bool: - """Helper function to check if a Dot node is equivalent to an other object.""" - return ( - isinstance(rhs, Dot) - and lhs.evaluation_function == rhs.evaluation_function - and is_equivalent_intermediate_node(lhs, rhs) - ) - - -def is_equivalent_input(lhs: Input, rhs: object) -> bool: - """Helper function to check if an Input node is equivalent to an other object.""" - return ( - isinstance(rhs, Input) - and lhs.input_name == rhs.input_name - and lhs.program_input_idx == rhs.program_input_idx - and is_equivalent_intermediate_node(lhs, rhs) - ) - - -def is_equivalent_index_constant(lhs: IndexConstant, rhs: object) -> bool: - """Helper function to check if an IndexConstant node is equivalent to an other object.""" - return ( - isinstance(rhs, IndexConstant) - and lhs.index == rhs.index - and is_equivalent_intermediate_node(lhs, rhs) - ) - - -def is_equivalent_mul(lhs: Mul, rhs: object) -> bool: - """Helper function to check if a Mul node is equivalent to an other object.""" - return _is_equivalent_to_binary_commutative(lhs, rhs) - - -def is_equivalent_sub(lhs: Sub, rhs: object) -> bool: - """Helper function to check if a Sub node is equivalent to an other object.""" - return _is_equivalent_to_binary_non_commutative(lhs, rhs) - - -def is_equivalent_matmul(lhs: MatMul, rhs: object) -> bool: - """Helper function to check if a MatMul node is equivalent to an other object.""" - return isinstance(rhs, MatMul) and is_equivalent_intermediate_node(lhs, rhs) - - -def is_equivalent_conv2d(lhs: Conv2D, rhs: object) -> bool: - """Helper function to check if a Conv2D node is equivalent to an other object.""" - return isinstance(rhs, Conv2D) and is_equivalent_intermediate_node(lhs, rhs) - - -def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool: - """Helper function to check if an IntermediateNode node is equivalent to an other object.""" - return ( - isinstance(rhs, IntermediateNode) - and lhs.inputs == rhs.inputs - and lhs.outputs == rhs.outputs - ) - - -EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = { - Add: is_equivalent_add, - GenericFunction: is_equivalent_arbitrary_function, - Constant: is_equivalent_constant, - Conv2D: is_equivalent_conv2d, - Dot: is_equivalent_dot, - IndexConstant: is_equivalent_index_constant, - Input: is_equivalent_input, - Mul: is_equivalent_mul, - Sub: is_equivalent_sub, - MatMul: is_equivalent_matmul, -} - -_missing_nodes_in_mapping = ALL_IR_NODES - EQUIVALENT_TEST_FUNC.keys() -assert len(_missing_nodes_in_mapping) == 0, ( - f"Missing IR node in EQUIVALENT_TEST_FUNC : " - f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}" -) - -del _missing_nodes_in_mapping - - -class TestHelpers: - """Class allowing to pass helper functions to tests""" - - @staticmethod - def nodes_are_equivalent(lhs, rhs) -> bool: - """Helper function for tests to check if two nodes are equivalent.""" - equivalent_func = EQUIVALENT_TEST_FUNC.get(type(lhs), None) - if equivalent_func is not None: - return equivalent_func(lhs, rhs) - - # This is a default for the test_conftest.py that should remain separate from the package - # nodes is_equivalent_* functions - return lhs.is_equivalent_to(rhs) - - @staticmethod - def digraphs_are_equivalent(reference: nx.MultiDiGraph, to_compare: nx.MultiDiGraph): - """Check that two digraphs are equivalent without modifications""" - # edge_match is a copy of node_match - edge_matcher = iso.categorical_multiedge_match(["input_idx", "output_idx"], [None, None]) - node_matcher = iso.generic_node_match( - "_test_content", None, TestHelpers.nodes_are_equivalent - ) - - # Set the _test_content for each node in the graphs - for node in reference.nodes(): - reference.add_node(node, _test_content=node) - - for node in to_compare.nodes(): - to_compare.add_node(node, _test_content=node) - - graphs_are_isomorphic = nx.is_isomorphic( - reference, - to_compare, - node_match=node_matcher, - edge_match=edge_matcher, - ) - - return graphs_are_isomorphic - - @staticmethod - def python_functions_are_equal_or_equivalent(lhs, rhs): - """Helper function to check if two functions are equal or their code are equivalent. - - This is not perfect, but will be good enough for tests. - """ - return python_functions_are_equal_or_equivalent(lhs, rhs) - - -@pytest.fixture -def test_helpers(): - """Fixture to return the static helper class""" - return TestHelpers - - -@pytest.fixture -def default_compilation_configuration(): - """Return the default test compilation configuration""" - return CompilationConfiguration( - dump_artifacts_on_unexpected_failures=False, - enable_unsafe_features=True, # This is for our tests only, never use that in prod - treat_warnings_as_errors=True, - use_insecure_key_cache=True, # This is for our tests only, never use that in prod - ) - - -REMOVE_COLOR_CODES_RE = re.compile(r"\x1b[^m]*m") - - -@pytest.fixture -def remove_color_codes(): - """Return the re object to remove color codes""" - return lambda x: REMOVE_COLOR_CODES_RE.sub("", x) - - -def check_is_good_execution_impl( - fhe_circuit: FHECircuit, - function: Callable, - args: Iterable[Any], - preprocess_input_func: Callable[[Any], Any] = lambda x: x, - postprocess_output_func: Callable[[Any], Any] = lambda x: x, - check_function: Callable[[Any, Any], bool] = numpy.equal, - verbose: bool = True, -): - """Run several times the check compiler_engine.run(*args) == function(*args). If always wrong, - return an error. One can set the expected probability of success of one execution and the - number of tests, to finetune the probability of bad luck, ie that we run several times the - check and always have a wrong result.""" - max_bit_width, _ = get_op_graph_max_bit_width_and_nodes_over_bit_width_limit( - fhe_circuit.op_graph - ) - - # Allow tests to pass if cells of the output result are good at least once over the nb_tries - # Enabled only when we have a circuit that's using the maximum possible bit width - # >= if there are 8 bits signed integers - allow_relaxed_tests_passing = max_bit_width >= ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB - - # FIXME: https://github.com/zama-ai/concrete-numpy-internal/issues/1255 - # Increased with compiler accuracy which dropped, make sure to remove once accuracy improves - nb_tries = 10 - - # Prepare the bool array to record if cells were properly computed - preprocessed_args = tuple(preprocess_input_func(val) for val in args) - cells_were_properly_computed = numpy.zeros_like(function(*preprocessed_args), dtype=bool) - - for i in range(1, nb_tries + 1): - preprocessed_args = tuple(preprocess_input_func(val) for val in args) - last_engine_result = postprocess_output_func( - fhe_circuit.encrypt_run_decrypt(*preprocessed_args) - ) - last_function_result = postprocess_output_func(function(*preprocessed_args)) - - ok_execution = check_function(last_engine_result, last_function_result) - if isinstance(ok_execution, numpy.ndarray): - # Record the cells that were well computed - cells_were_properly_computed = numpy.logical_or( - cells_were_properly_computed, ok_execution - ) - - # Get a boolean for the execution - ok_execution = ok_execution.all() - - if ok_execution: - # Good computation after i tries - if verbose: - print(f"Good computation after {i} tries") - return - # FIXME: https://github.com/zama-ai/concrete-numpy-internal/issues/1264 - # Remove the relaxed tests once accuracy is good again for 7 bits - if allow_relaxed_tests_passing and cells_were_properly_computed.all(): - print( - "Computation was never good for all output cells at the same time, " - f"however each was evaluated properly at least once, stopped after {i} tries" - ) - return - - raise AssertionError( - f"bad computation after {nb_tries} tries.\nLast engine result:\n{last_engine_result}\n" - f"Last function result:\n{last_function_result}" - ) - - -@pytest.fixture -def check_is_good_execution(): - """Fixture to seed torch""" - - return check_is_good_execution_impl - - -def check_array_equality_impl(actual: Any, expected: Any, verbose: bool = True): - """Assert that `actual` is equal to `expected`.""" - - assert numpy.array_equal(actual, expected), ( - "" - if not verbose - else f""" - -Expected Output -=============== -{expected} - -Actual Output -============= -{actual} - - """ - ) - - -@pytest.fixture -def check_array_equality(): - """Fixture to check array equality""" - - return check_array_equality_impl diff --git a/tests/helpers/test_conftest.py b/tests/helpers/test_conftest.py deleted file mode 100644 index 65a5d3a09..000000000 --- a/tests/helpers/test_conftest.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Test file for conftest helper functions""" -import networkx as nx - - -def test_digraphs_are_equivalent(test_helpers): - """Function to test digraphs_are_equivalent helper function""" - - class TestNode: - """Dummy test node""" - - computation: str - - def __init__(self, computation: str) -> None: - self.computation = computation - - def __hash__(self) -> int: - return self.computation.__hash__() - - def __eq__(self, other: object) -> bool: - return isinstance(other, self.__class__) and self.computation == other.computation - - is_equivalent_to = __eq__ - - g_1 = nx.MultiDiGraph() - g_2 = nx.MultiDiGraph() - - t_0 = TestNode("Add") - t_1 = TestNode("Mul") - t_2 = TestNode("TLU") - - g_1.add_edge(t_0, t_2, input_idx=0, output_idx=0) - g_1.add_edge(t_1, t_2, input_idx=1, output_idx=0) - - t0p = TestNode("Add") - t1p = TestNode("Mul") - t2p = TestNode("TLU") - - g_2.add_edge(t1p, t2p, input_idx=1, output_idx=0) - g_2.add_edge(t0p, t2p, input_idx=0, output_idx=0) - - bad_g2 = nx.MultiDiGraph() - - bad_t0 = TestNode("Not Add") - - bad_g2.add_edge(bad_t0, t_2, input_idx=0, output_idx=0) - bad_g2.add_edge(t_1, t_2, input_idx=1, output_idx=0) - - bad_g3 = nx.MultiDiGraph() - - bad_g3.add_edge(t_0, t_2, input_idx=1, output_idx=0) - bad_g3.add_edge(t_1, t_2, input_idx=0, output_idx=0) - - assert test_helpers.digraphs_are_equivalent(g_1, g_2), "Graphs should be equivalent" - assert not test_helpers.digraphs_are_equivalent(g_1, bad_g2), "Graphs should not be equivalent" - assert not test_helpers.digraphs_are_equivalent(g_2, bad_g2), "Graphs should not be equivalent" - assert not test_helpers.digraphs_are_equivalent(g_1, bad_g3), "Graphs should not be equivalent" - assert not test_helpers.digraphs_are_equivalent(g_2, bad_g3), "Graphs should not be equivalent" diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py deleted file mode 100644 index 7349ffc19..000000000 --- a/tests/numpy/test_compile.py +++ /dev/null @@ -1,2734 +0,0 @@ -"""Test file for numpy compilation functions""" -import itertools -import random -from copy import deepcopy - -import numpy -import pytest - -from concrete.common.compilation import CompilationConfiguration -from concrete.common.data_types.integers import Integer, SignedInteger, UnsignedInteger -from concrete.common.debugging import draw_graph, format_operation_graph -from concrete.common.extensions.multi_table import MultiLookupTable -from concrete.common.extensions.table import LookupTable -from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor -from concrete.numpy import compile as compile_ -from concrete.numpy import tracing -from concrete.numpy.compile import ( - FHECircuit, - compile_numpy_function, - compile_numpy_function_into_op_graph_and_measure_bounds, -) - -# pylint: disable=too-many-lines - - -def data_gen(args): - """Helper to create an inputset""" - for prod in itertools.product(*args): - yield prod if len(prod) > 1 else prod[0] - - -def numpy_array_data_gen(args, tensor_shapes): - """Helper to create an inputset containing numpy arrays filled with the same value and of a - particular shape""" - for prod in itertools.product(*args): - yield tuple( - numpy.full(tensor_shape, val, numpy.int64) - for val, tensor_shape in zip(prod, tensor_shapes) - ) - - -def no_fuse_unhandled(x, y): - """No fuse unhandled""" - x_intermediate = x + 2.8 - y_intermediate = y + 9.3 - intermediate = x_intermediate - y_intermediate - return (intermediate * 1.5).astype(numpy.int32) - - -def identity_lut_generator(n): - """Test lookup table""" - return lambda x: LookupTable(list(range(2 ** n)))[x] - - -def negative_identity_smaller_lut_generator(n): - """Test negative lookup table""" - - table = LookupTable(range(2 ** (n - 1))) - offset = 2 ** (n - 1) - - return (lambda x: table[x + (-offset)]), table - - -def negative_identity_lut_generator(n): - """Test negative lookup table (bigger than bit-width)""" - - table = LookupTable(range(2 ** n)) - offset = 2 ** (n - 1) - - return (lambda x: table[x + (-offset)]), table - - -def negative_identity_bigger_lut_generator(n): - """Test negative lookup table (bigger than bit-width)""" - - table = LookupTable(range(2 ** (n + 1))) - offset = 2 ** (n - 1) - - return (lambda x: table[x + (-offset)]), table - - -def weird_lut(n): - """A weird lookup table to test an edge case related to negative indexing""" - - table = LookupTable([0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 4, 5, 6, 7]) - offset = 2 ** (n - 1) - - return (lambda x: table[x + (-offset)]), table - - -def random_lut_1b(x): - """1-bit random table lookup""" - - # fmt: off - table = LookupTable([10, 12]) - # fmt: on - - return table[x] - - -def random_lut_2b(x): - """2-bit random table lookup""" - - # fmt: off - table = LookupTable([3, 8, 22, 127]) - # fmt: on - - return table[x] - - -def random_lut_3b(x): - """3-bit random table lookup""" - - # fmt: off - table = LookupTable([30, 52, 125, 23, 17, 12, 90, 4]) - # fmt: on - - return table[x] - - -def random_lut_4b(x): - """4-bit random table lookup""" - - # fmt: off - table = LookupTable([30, 52, 125, 23, 17, 12, 90, 4, 21, 51, 22, 15, 53, 100, 75, 90]) - # fmt: on - - return table[x] - - -def random_lut_5b(x): - """5-bit random table lookup""" - - # fmt: off - table = LookupTable( - [ - 1, 5, 2, 3, 10, 2, 4, 8, 1, 12, 15, 12, 10, 1, 0, 2, - 4, 3, 8, 7, 10, 11, 6, 13, 9, 0, 2, 1, 15, 11, 12, 5 - ] - ) - # fmt: on - - return table[x] - - -def random_lut_6b(x): - """6-bit random table lookup""" - - # fmt: off - table = LookupTable( - [ - 95, 74, 11, 83, 24, 116, 28, 75, 26, 85, 114, 121, 91, 123, 78, 69, - 72, 115, 67, 5, 39, 11, 120, 88, 56, 43, 74, 16, 72, 85, 103, 92, - 44, 115, 50, 56, 107, 77, 25, 71, 52, 45, 80, 35, 69, 8, 40, 87, - 26, 85, 84, 53, 73, 95, 86, 22, 16, 45, 59, 112, 53, 113, 98, 116 - ] - ) - # fmt: on - - return table[x] - - -def random_lut_7b(x): - """7-bit random table lookup""" - - # fmt: off - table = LookupTable( - [ - 13, 58, 38, 58, 15, 15, 77, 86, 80, 94, 108, 27, 126, 60, 65, 95, - 50, 79, 22, 97, 38, 60, 25, 48, 73, 112, 27, 45, 88, 20, 67, 17, - 16, 6, 71, 60, 77, 43, 93, 40, 41, 31, 99, 122, 120, 40, 94, 13, - 111, 44, 96, 62, 108, 91, 34, 90, 103, 58, 3, 103, 19, 69, 55, 108, - 0, 111, 113, 0, 0, 73, 22, 52, 81, 2, 88, 76, 36, 121, 97, 121, - 123, 79, 82, 120, 12, 65, 54, 101, 90, 52, 84, 106, 23, 15, 110, 79, - 85, 101, 30, 61, 104, 35, 81, 30, 98, 44, 111, 32, 68, 18, 45, 123, - 84, 80, 68, 27, 31, 38, 126, 61, 51, 7, 49, 37, 63, 114, 22, 18, - ] - ) - # fmt: on - - return table[x] - - -def multi_lut(x): - """2-bit multi table lookup""" - - table = MultiLookupTable( - [ - [LookupTable([1, 2, 1, 0]), LookupTable([2, 2, 1, 3])], - [LookupTable([1, 0, 1, 0]), LookupTable([0, 2, 3, 3])], - [LookupTable([0, 2, 3, 0]), LookupTable([2, 1, 2, 0])], - ] - ) - return table[x] - - -def small_fused_table(x): - """Test with a small fused table""" - return (10 * (numpy.cos(x + 1) + 1)).astype(numpy.uint32) - - -def complicated_topology(x): - """Mix x in an intricated way.""" - intermediate = x - x_p_1 = intermediate + 1 - x_p_2 = intermediate + 2 - x_p_3 = x_p_1 + x_p_2 - return ( - x_p_3.astype(numpy.int32), - x_p_2.astype(numpy.int32), - (x_p_2 + 3).astype(numpy.int32), - x_p_3.astype(numpy.int32) + 67, - ) - - -def mix_x_and_y_and_call_f(func, x, y): - """Create an upper function to test `func`""" - z = numpy.abs(10 * func(x)) - z = z / 2 - z = z.astype(numpy.int32) + y - return z - - -def mix_x_and_y_and_call_f_with_float_inputs(func, x, y): - """Create an upper function to test `func`, with inputs which are forced to be floats""" - z = numpy.abs(10 * func(x + 0.1)) - z = z.astype(numpy.int32) + y - return z - - -def mix_x_and_y_and_call_f_with_integer_inputs(func, x, y): - """Create an upper function to test `func`, with inputs which are forced to be integers but - in a way which is fusable into a TLU""" - x = x // 2 - a = x + 0.1 - a = numpy.rint(a).astype(numpy.int32) - z = numpy.abs(10 * func(a)) - z = z.astype(numpy.int32) + y - return z - - -def mix_x_and_y_and_call_f_which_expects_small_inputs(func, x, y): - """Create an upper function to test `func`, which expects small values to not use too much - precision""" - # TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/993 - # Understand why it's failing with 0.77 for numpy.arctanh - a = numpy.abs(0.5 * numpy.sin(x)) - z = numpy.abs(3 * func(a)) - z = z.astype(numpy.int32) + y - return z - - -def mix_x_and_y_and_call_f_which_has_large_outputs(func, x, y): - """Create an upper function to test `func`, which outputs large values""" - a = numpy.abs(2 * numpy.sin(x)) - z = numpy.abs(func(a) * 0.131) - z = z.astype(numpy.int32) + y - return z - - -def mix_x_and_y_and_call_f_avoid_0_input(func, x, y): - """Create an upper function to test `func`, which makes that inputs are not 0""" - a = numpy.abs(7 * numpy.sin(x)) + 1 - c = 100 // a - b = 100 / a - a = a + b + c - z = numpy.abs(5 * func(a)) - z = z.astype(numpy.int32) + y - return z - - -def mix_x_and_y_and_call_binary_f_one(func, c, x, y): - """Create an upper function to test `func`""" - z = numpy.abs(func(x, c) + 1) - z = z.astype(numpy.uint32) + y - return z - - -def mix_x_and_y_and_call_binary_f_two(func, c, x, y): - """Create an upper function to test `func`""" - z = numpy.abs(func(c, x) + 1) - z = z.astype(numpy.uint32) + y - return z - - -def negative_binary_f_one(func, c, x, y): - """Test negative values as input to func as first argument.""" - x = x + (-4) - z = func(x, c) - z = numpy.clip(z, 0, 63).astype(numpy.int32) + y - return z - - -def negative_binary_f_two(func, c, x, y): - """Test negative values as input to func as second argument.""" - x = x + (-4) - z = func(c, x) - z = numpy.clip(z, 0, 63).astype(numpy.int32) + y - return z - - -def negative_unary_f(func, x, y): - """Test negative values as input to func.""" - x = x + (-4) - z = func(x) - z = numpy.clip(z, 0, 63).astype(numpy.int32) + y - return z - - -def subtest_compile_and_run_unary_ufunc_correctness( - ufunc, - upper_function, - input_ranges, - tensor_shape, - default_compilation_configuration, - check_is_good_execution, -): - """Test correctness of results when running a compiled function""" - - def get_function(ufunc, upper_function): - return lambda x, y: upper_function(ufunc, x, y) - - function = get_function(ufunc, upper_function) - - function_parameters = { - arg_name: EncryptedTensor(Integer(64, True), shape=tensor_shape) for arg_name in ["x", "y"] - } - - compiler_engine = compile_numpy_function( - function, - function_parameters, - numpy_array_data_gen( - tuple(range(x[0], x[1] + 1) for x in input_ranges), - [tensor_shape] * len(function_parameters), - ), - default_compilation_configuration, - ) - - # TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/910 - args = [ - numpy.random.randint(low, high, size=tensor_shape, dtype=numpy.uint8) - if tensor_shape != () - else random.randint(low, high) - for (low, high) in input_ranges - ] - - check_is_good_execution(compiler_engine, function, args) - - -def subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - upper_function, - c, - input_ranges, - tensor_shape, - default_compilation_configuration, - check_is_good_execution, -): - """Test correctness of results when running a compiled function""" - - def get_function(ufunc, upper_function): - return lambda x, y: upper_function(ufunc, c, x, y) - - function = get_function(ufunc, upper_function) - - function_parameters = { - arg_name: EncryptedTensor(Integer(64, True), shape=tensor_shape) for arg_name in ["x", "y"] - } - - compiler_engine = compile_numpy_function( - function, - function_parameters, - numpy_array_data_gen( - tuple(range(x[0], x[1] + 1) for x in input_ranges), - [tensor_shape] * len(function_parameters), - ), - default_compilation_configuration, - ) - - # TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/910 - args = [ - numpy.random.randint(low, high, size=tensor_shape, dtype=numpy.uint8) - if tensor_shape != () - else random.randint(low, high) - for (low, high) in input_ranges - ] - - check_is_good_execution(compiler_engine, function, args) - - -@pytest.mark.parametrize( - "ufunc", - [f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 2], -) -@pytest.mark.parametrize( - "tensor_shape", [pytest.param((), id="scalar"), pytest.param((3, 1, 2), id="tensor")] -) -def test_binary_ufunc_operations( - ufunc, - tensor_shape, - default_compilation_configuration, - check_is_good_execution, -): - """Test biary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC.""" - - run_multi_tlu_test = False - if tensor_shape != (): - run_multi_tlu_test = True - tensor_for_multi_tlu = numpy.arange(numpy.prod(tensor_shape)).reshape(tensor_shape) - tensor_for_multi_tlu_small_values = tensor_for_multi_tlu // 2 - - if ufunc in [numpy.power, numpy.float_power]: - # Need small constants to keep results really small - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_binary_f_one, - 3, - ((0, 4), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_binary_f_two, - 2, - ((0, 4), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - if run_multi_tlu_test: - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_binary_f_one, - tensor_for_multi_tlu_small_values, - ((0, 4), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_binary_f_two, - tensor_for_multi_tlu_small_values, - ((0, 4), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - elif ufunc in [numpy.floor_divide, numpy.fmod, numpy.remainder, numpy.true_divide]: - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_binary_f_two, - 31, - ((1, 5), (1, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - if run_multi_tlu_test: - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_binary_f_two, - tensor_for_multi_tlu, - ((1, 5), (1, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - elif ufunc in [numpy.lcm, numpy.left_shift]: - # Need small constants to keep results sufficiently small - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_binary_f_one, - 3, - ((0, 5), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_binary_f_two, - 2, - ((0, 5), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - if run_multi_tlu_test: - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_binary_f_one, - tensor_for_multi_tlu - if ufunc != numpy.left_shift - else tensor_for_multi_tlu_small_values, - ((0, 5), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_binary_f_two, - tensor_for_multi_tlu - if ufunc != numpy.left_shift - else tensor_for_multi_tlu_small_values, - ((0, 5), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - elif ufunc in [numpy.ldexp]: - # Need small constants to keep results sufficiently small - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_binary_f_two, - 2, - ((0, 5), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - if run_multi_tlu_test: - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_binary_f_two, - tensor_for_multi_tlu // 2, - ((0, 5), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - else: - # General case - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_binary_f_one, - 41, - ((0, 5), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_binary_f_two, - 42, - ((0, 5), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - if run_multi_tlu_test: - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_binary_f_one, - tensor_for_multi_tlu, - ((0, 5), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_binary_f_two, - tensor_for_multi_tlu, - ((0, 5), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - - # Negative inputs tests on compatible functions - if ufunc not in [ - numpy.floor_divide, - numpy.fmod, - numpy.remainder, - numpy.true_divide, - numpy.power, - numpy.float_power, - ]: - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - negative_binary_f_one, - 2, - ((0, 7), (0, 3)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - subtest_compile_and_run_binary_ufunc_correctness( - ufunc, - negative_binary_f_two, - 2, - ((0, 7), (0, 3)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - - -@pytest.mark.parametrize( - "ufunc", [f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1] -) -@pytest.mark.parametrize( - "tensor_shape", [pytest.param((), id="scalar"), pytest.param((3, 1, 2), id="tensor")] -) -def test_unary_ufunc_operations( - ufunc, tensor_shape, default_compilation_configuration, check_is_good_execution -): - """Test unary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC.""" - - if ufunc in [ - numpy.degrees, - numpy.rad2deg, - ]: - # Need to reduce the output value, to avoid to need too much precision - subtest_compile_and_run_unary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_f_which_has_large_outputs, - ((0, 5), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - elif ufunc in [ - numpy.negative, - ]: - # Need to turn the input into a float - subtest_compile_and_run_unary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_f_with_float_inputs, - ((0, 5), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - elif ufunc in [ - numpy.arccosh, - numpy.log, - numpy.log2, - numpy.log10, - numpy.reciprocal, - ]: - # No 0 in the domain of definition - subtest_compile_and_run_unary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_f_avoid_0_input, - ((1, 5), (1, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - elif ufunc in [ - numpy.cosh, - numpy.exp, - numpy.exp2, - numpy.expm1, - numpy.square, - numpy.arccos, - numpy.arcsin, - numpy.arctanh, - numpy.sinh, - ]: - # Need a small range of inputs, to avoid to need too much precision - subtest_compile_and_run_unary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_f_which_expects_small_inputs, - ((0, 5), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - else: - # Regular case for univariate functions - subtest_compile_and_run_unary_ufunc_correctness( - ufunc, - mix_x_and_y_and_call_f, - ((0, 5), (0, 5)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - - # Negative inputs tests on compatible functions - if ufunc not in [ - numpy.arccosh, - numpy.arccos, - numpy.arcsin, - numpy.arctanh, - numpy.sqrt, - numpy.log, - numpy.log1p, - numpy.log2, - numpy.log10, - numpy.reciprocal, - ]: - subtest_compile_and_run_unary_ufunc_correctness( - ufunc, - negative_unary_f, - ((0, 7), (0, 3)), - tensor_shape, - default_compilation_configuration, - check_is_good_execution, - ) - - -@pytest.mark.parametrize( - "function,input_ranges,list_of_arg_names", - [ - pytest.param(lambda x: x + 42, ((-5, 5),), ["x"]), - pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]), - pytest.param(lambda x, y: (x + 1, y + 10), ((-1, 1), (3, 8)), ["x", "y"]), - pytest.param( - lambda x, y, z: (x + y + 1 - z, x * y + 42, z, z + 99), - ((4, 8), (3, 4), (0, 4)), - ["x", "y", "z"], - ), - pytest.param(complicated_topology, ((0, 10),), ["x"]), - ], -) -def test_compile_function_multiple_outputs( - function, input_ranges, list_of_arg_names, default_compilation_configuration -): - """Test function compile_numpy_function_into_op_graph for a program with multiple outputs""" - - def data_gen_local(args): - for prod in itertools.product(*args): - yield tuple(numpy.array(val) for val in prod) if len(prod) > 1 else numpy.array(prod[0]) - - function_parameters = { - arg_name: EncryptedScalar(Integer(64, True)) for arg_name in list_of_arg_names - } - - op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( - function, - function_parameters, - data_gen_local(tuple(range(x[0], x[1] + 1) for x in input_ranges)), - default_compilation_configuration, - ) - - # TODO: For the moment, we don't have really checks, but some printfs. Later, - # when we have the converter, we can check the MLIR - draw_graph(op_graph, show=False) - - str_of_the_graph = format_operation_graph(op_graph) - print(f"\n{str_of_the_graph}\n") - - -@pytest.mark.parametrize( - "function,input_ranges,list_of_arg_names", - [ - pytest.param(lambda x: (-27) + 4 * (x + 8), ((0, 10),), ["x"]), - pytest.param(lambda x: x + (-33), ((40, 60),), ["x"]), - pytest.param(lambda x: 17 - (0 - x), ((0, 10),), ["x"]), - pytest.param(lambda x: 42 + x * (-3), ((0, 10),), ["x"]), - pytest.param(lambda x: 43 + (-4) * x, ((0, 10),), ["x"]), - pytest.param(lambda x: 3 - (-5) * x, ((0, 10),), ["x"]), - pytest.param(lambda x: (-2) * (-5) * x, ((0, 10),), ["x"]), - pytest.param(lambda x: (-2) * x * (-5), ((0, 10),), ["x"]), - pytest.param(lambda x, y: 40 - (-3 * x) + (-2 * y), ((0, 20), (0, 20)), ["x", "y"]), - pytest.param(lambda x: x + numpy.int32(42), ((0, 10),), ["x"]), - pytest.param(lambda x: x + 64, ((0, 10),), ["x"]), - pytest.param(lambda x: x * 3, ((0, 40),), ["x"]), - pytest.param(lambda x: 120 - x, ((40, 80),), ["x"]), - pytest.param(lambda x, y: x + y + 64, ((0, 20), (0, 20)), ["x", "y"]), - pytest.param(lambda x, y: 100 - y + x, ((0, 20), (0, 20)), ["x", "y"]), - pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]), - pytest.param(lambda x: -x + 50, ((0, 20),), ["x"]), - pytest.param(lambda x: numpy.dot(x, 2), ((0, 20),), ["x"]), - pytest.param(lambda x: numpy.dot(2, x), ((0, 20),), ["x"]), - ], -) -def test_compile_and_run_correctness( - function, input_ranges, list_of_arg_names, default_compilation_configuration -): - """Test correctness of results when running a compiled function""" - - function_parameters = { - arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names - } - - compiler_engine = compile_numpy_function( - function, - function_parameters, - data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), - default_compilation_configuration, - ) - - args = [random.randint(low, high) for (low, high) in input_ranges] - assert compiler_engine.encrypt_run_decrypt(*args) == function(*args) - - -@pytest.mark.parametrize( - "function,input_ranges,list_of_arg_names", - [ - pytest.param(lambda x: x ** 2, ((0, 10),), ["x"]), - pytest.param(lambda x: 2 ** (x % 5), ((0, 20),), ["x"]), - pytest.param(lambda x: x << 1, ((0, 13),), ["x"]), - pytest.param(lambda x: 2 << (x % 6), ((0, 13),), ["x"]), - pytest.param(lambda x: x >> 2, ((30, 100),), ["x"]), - pytest.param(lambda x: 115 >> (x % 3), ((0, 17),), ["x"]), - pytest.param(lambda x: x % 7, ((0, 100),), ["x"]), - pytest.param(lambda x: x > 7, ((0, 20),), ["x"]), - pytest.param(lambda x: x < 11, ((0, 20),), ["x"]), - pytest.param(lambda x: x >= 8, ((0, 20),), ["x"]), - pytest.param(lambda x: x <= 10, ((0, 20),), ["x"]), - pytest.param(lambda x: x == 15, ((0, 20),), ["x"]), - pytest.param(lambda x: x & 14, ((0, 20),), ["x"]), - pytest.param(lambda x: x | 18, ((0, 20),), ["x"]), - pytest.param(lambda x: x ^ 23, ((0, 20),), ["x"]), - pytest.param(lambda x: x % 3, ((0, 20),), ["x"]), - pytest.param(lambda x: 17 & x, ((0, 20),), ["x"]), - pytest.param(lambda x: 19 | x, ((0, 20),), ["x"]), - pytest.param(lambda x: 45 ^ x, ((0, 20),), ["x"]), - pytest.param(lambda x: 19 % (x + 1), ((0, 20),), ["x"]), - ], -) -def test_compile_and_run_correctness__for_prog_with_tlu( - function, - input_ranges, - list_of_arg_names, - default_compilation_configuration, - check_is_good_execution, -): - """Test correctness of results when running a compiled function which uses a TLU""" - - function_parameters = { - arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names - } - - compiler_engine = compile_numpy_function( - function, - function_parameters, - data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), - default_compilation_configuration, - ) - - for _ in range(16): - args = [random.randint(low, high) for (low, high) in input_ranges] - check_is_good_execution(compiler_engine, function, args, verbose=False) - - -@pytest.mark.parametrize( - "function,parameters,inputset,test_input,use_check_good_exec", - [ - pytest.param( - lambda x: x + 1, - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), - }, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for _ in range(10)], - ( - [ - [0, 7], - [6, 1], - [2, 5], - ], - ), - False, - ), - pytest.param( - lambda x: x + numpy.array([[1, 0], [2, 0], [3, 1]], dtype=numpy.uint32), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), - }, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for _ in range(10)], - ( - [ - [0, 7], - [6, 1], - [2, 5], - ], - ), - False, - ), - pytest.param( - lambda x, y: x + y, - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), - "y": EncryptedScalar(UnsignedInteger(3)), - }, - [ - ( - numpy.random.randint(0, 2 ** 3, size=(3, 2)), - random.randint(0, (2 ** 3) - 1), - ) - for _ in range(10) - ], - ( - [ - [0, 7], - [6, 1], - [2, 5], - ], - 2, - ), - False, - ), - pytest.param( - lambda x, y: x + y, - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), - "y": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), - }, - [ - ( - numpy.random.randint(0, 2 ** 3, size=(3, 2)), - numpy.random.randint(0, 2 ** 3, size=(3, 2)), - ) - for _ in range(10) - ], - ( - [ - [0, 7], - [6, 1], - [2, 5], - ], - [ - [1, 6], - [2, 5], - [3, 4], - ], - ), - False, - ), - pytest.param( - lambda x: 100 - x, - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), - }, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for _ in range(10)], - ( - [ - [0, 7], - [6, 1], - [2, 5], - ], - ), - False, - ), - pytest.param( - lambda x: numpy.array([[10, 15], [20, 15], [10, 30]], dtype=numpy.uint32) - x, - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), - }, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for _ in range(10)], - ( - [ - [0, 7], - [6, 1], - [2, 5], - ], - ), - False, - ), - pytest.param( - lambda x: x * 2, - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), - }, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for _ in range(10)], - ( - [ - [0, 7], - [6, 1], - [2, 5], - ], - ), - False, - ), - pytest.param( - lambda x: x * numpy.array([[1, 2], [2, 1], [3, 1]], dtype=numpy.uint32), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), - }, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for _ in range(10)], - ( - [ - [4, 7], - [6, 1], - [2, 5], - ], - ), - False, - ), - pytest.param( - lambda x: LookupTable([2, 1, 3, 0])[x], - { - "x": EncryptedTensor(UnsignedInteger(2), shape=(3, 2)), - }, - [numpy.random.randint(0, 2 ** 2, size=(3, 2)) for _ in range(10)], - ( - [ - [0, 1], - [2, 1], - [3, 0], - ], - ), - True, - ), - pytest.param( - lambda x: numpy.dot(x, 2), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3,)), - }, - [numpy.random.randint(0, 2 ** 3, size=(3,)) for _ in range(10)], - ([2, 7, 1],), - False, - ), - pytest.param( - lambda x: numpy.dot(2, x), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3,)), - }, - [numpy.random.randint(0, 2 ** 3, size=(3,)) for _ in range(10)], - ([2, 7, 1],), - False, - ), - pytest.param( - lambda x: x + x.shape[0], - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3,)), - }, - [numpy.random.randint(0, 2 ** 3, size=(3,)) for _ in range(10)], - ([2, 1, 3],), - False, - ), - pytest.param( - lambda x: numpy.clip(x, 1, 5), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), - }, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for _ in range(10)], - ( - [ - [0, 7], - [6, 1], - [2, 5], - ], - ), - True, - ), - pytest.param( - lambda x: numpy.clip(x + (-4), -3, 5) + 3, - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), - }, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for _ in range(10)], - ( - [ - [0, 7], - [6, 1], - [2, 5], - ], - ), - True, - ), - pytest.param( - lambda x: x.clip(1, 5), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), - }, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for _ in range(10)], - ( - [ - [0, 7], - [6, 1], - [2, 5], - ], - ), - True, - ), - pytest.param( - lambda x: (x + (-4)).clip(-3, 5) + 3, - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), - }, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for _ in range(10)], - ( - [ - [0, 7], - [6, 1], - [2, 5], - ], - ), - True, - ), - pytest.param( - lambda x: numpy.array([120, 60, 30], dtype=numpy.uint8) // x, - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(2, 3)), - }, - [numpy.random.randint(1, 2 ** 3, size=(2, 3)) for _ in range(10)], - ( - [ - [1, 7, 6], - [3, 2, 5], - ], - ), - True, - ), - pytest.param( - lambda x: numpy.sum(x), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(2, 3)), - }, - [numpy.random.randint(0, 2 ** 3, size=(2, 3)) for _ in range(10)], - ( - [ - [1, 7, 6], - [3, 2, 5], - ], - ), - False, - ), - pytest.param( - lambda x: numpy.sum(x, axis=0), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(2, 3)), - }, - [numpy.random.randint(0, 2 ** 3, size=(2, 3)) for _ in range(10)], - ( - [ - [1, 7, 6], - [3, 2, 5], - ], - ), - False, - ), - pytest.param( - lambda x: numpy.sum(x, axis=1), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(2, 3)), - }, - [numpy.random.randint(0, 2 ** 3, size=(2, 3)) for _ in range(10)], - ( - [ - [1, 7, 6], - [3, 2, 5], - ], - ), - False, - ), - pytest.param( - lambda x: numpy.sum(x, axis=(0, 1)), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(2, 3)), - }, - [numpy.random.randint(0, 2 ** 3, size=(2, 3)) for _ in range(10)], - ( - [ - [1, 7, 6], - [3, 2, 5], - ], - ), - False, - ), - pytest.param( - lambda x: numpy.sum(x, keepdims=True), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(2, 3)), - }, - [numpy.random.randint(0, 2 ** 3, size=(2, 3)) for _ in range(10)], - ( - [ - [1, 7, 6], - [3, 2, 5], - ], - ), - False, - ), - pytest.param( - lambda x: numpy.sum(x, axis=0, keepdims=True), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(2, 3)), - }, - [numpy.random.randint(0, 2 ** 3, size=(2, 3)) for _ in range(10)], - ( - [ - [1, 7, 6], - [3, 2, 5], - ], - ), - False, - ), - pytest.param( - lambda x: numpy.sum(x, axis=1, keepdims=True), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(2, 3)), - }, - [numpy.random.randint(0, 2 ** 3, size=(2, 3)) for _ in range(10)], - ( - [ - [1, 7, 6], - [3, 2, 5], - ], - ), - False, - ), - pytest.param( - lambda x: numpy.sum(x, axis=(0, 1), keepdims=True), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(2, 3)), - }, - [numpy.random.randint(0, 2 ** 3, size=(2, 3)) for _ in range(10)], - ( - [ - [1, 7, 6], - [3, 2, 5], - ], - ), - False, - ), - pytest.param( - lambda x: numpy.sum(x, axis=-1, keepdims=True), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(2, 3)), - }, - [numpy.random.randint(0, 2 ** 3, size=(2, 3)) for _ in range(10)], - ( - [ - [1, 7, 6], - [3, 2, 5], - ], - ), - False, - ), - pytest.param( - lambda x: numpy.sum(x, axis=-2, keepdims=True), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(2, 3)), - }, - [numpy.random.randint(0, 2 ** 3, size=(2, 3)) for _ in range(10)], - ( - [ - [1, 7, 6], - [3, 2, 5], - ], - ), - False, - ), - pytest.param( - lambda x: numpy.sum(x, axis=(-2, -1), keepdims=True), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(2, 3)), - }, - [numpy.random.randint(0, 2 ** 3, size=(2, 3)) for _ in range(10)], - ( - [ - [1, 7, 6], - [3, 2, 5], - ], - ), - False, - ), - pytest.param( - lambda x, y: numpy.concatenate((x, y)), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(4, 2)), - "y": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), - }, - [ - ( - numpy.random.randint(0, 2 ** 3, size=(4, 2)), - numpy.random.randint(0, 2 ** 3, size=(3, 2)), - ) - for _ in range(10) - ], - ( - [ - [0, 1], - [2, 3], - [4, 5], - [6, 7], - ], - [ - [4, 5], - [2, 3], - [0, 1], - ], - ), - False, - ), - pytest.param( - lambda x, y: numpy.concatenate((x, y), axis=1), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(2, 4)), - "y": EncryptedTensor(UnsignedInteger(3), shape=(2, 3)), - }, - [ - ( - numpy.random.randint(0, 2 ** 3, size=(2, 4)), - numpy.random.randint(0, 2 ** 3, size=(2, 3)), - ) - for _ in range(10) - ], - ( - [ - [0, 1, 2, 3], - [4, 5, 6, 7], - ], - [ - [5, 4, 3], - [2, 1, 0], - ], - ), - False, - ), - pytest.param( - lambda x, y: numpy.concatenate((x, y), axis=-1), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(2, 4)), - "y": EncryptedTensor(UnsignedInteger(3), shape=(2, 3)), - }, - [ - ( - numpy.random.randint(0, 2 ** 3, size=(2, 4)), - numpy.random.randint(0, 2 ** 3, size=(2, 3)), - ) - for _ in range(10) - ], - ( - [ - [0, 1, 2, 3], - [4, 5, 6, 7], - ], - [ - [5, 4, 3], - [2, 1, 0], - ], - ), - False, - ), - pytest.param( - lambda x, y: numpy.concatenate((x, y), axis=-2), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(4, 2)), - "y": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), - }, - [ - ( - numpy.random.randint(0, 2 ** 3, size=(4, 2)), - numpy.random.randint(0, 2 ** 3, size=(3, 2)), - ) - for _ in range(10) - ], - ( - [ - [0, 1], - [2, 3], - [4, 5], - [6, 7], - ], - [ - [4, 5], - [2, 3], - [0, 1], - ], - ), - False, - ), - pytest.param( - lambda x, y: numpy.concatenate((x, y), axis=None), - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 4)), - "y": EncryptedTensor(UnsignedInteger(3), shape=(2, 3)), - }, - [ - ( - numpy.random.randint(0, 2 ** 3, size=(3, 4)), - numpy.random.randint(0, 2 ** 3, size=(2, 3)), - ) - for _ in range(10) - ], - ( - [ - [0, 1, 2, 3], - [4, 5, 6, 7], - [7, 6, 5, 4], - ], - [ - [5, 4, 3], - [2, 1, 0], - ], - ), - False, - ), - ], -) -def test_compile_and_run_tensor_correctness( - function, - parameters, - inputset, - test_input, - use_check_good_exec, - default_compilation_configuration, - check_is_good_execution, - check_array_equality, -): - """Test correctness of results when running a compiled function with tensor operators""" - circuit = compile_numpy_function( - function, - parameters, - inputset, - default_compilation_configuration, - ) - - numpy_test_input = tuple( - item if isinstance(item, int) else numpy.array(item, dtype=numpy.uint8) - for item in test_input - ) - - if use_check_good_exec: - check_is_good_execution(circuit, function, numpy_test_input) - else: - check_array_equality( - circuit.encrypt_run_decrypt(*numpy_test_input), - numpy.array(function(*numpy_test_input), dtype=numpy.uint8), - ) - - -@pytest.mark.parametrize( - "size, input_range", - [ - pytest.param( - 1, - (0, 8), - ), - pytest.param( - 4, - (0, 5), - ), - pytest.param( - 6, - (0, 4), - ), - pytest.param( - 10, - (0, 3), - ), - ], -) -def test_compile_and_run_dot_correctness(size, input_range, default_compilation_configuration): - """Test correctness of results when running a compiled function""" - - low, high = input_range - shape = (size,) - - inputset = [ - (numpy.zeros(shape, dtype=numpy.uint32), numpy.zeros(shape, dtype=numpy.uint32)), - ( - numpy.ones(shape, dtype=numpy.uint32) * high, - numpy.ones(shape, dtype=numpy.uint32) * high, - ), - ] - for _ in range(8): - inputset.append( - ( - numpy.random.randint(low, high + 1, size=shape), - numpy.random.randint(low, high + 1, size=shape), - ) - ) - - function_parameters = { - "x": EncryptedTensor(Integer(64, False), shape), - "y": ClearTensor(Integer(64, False), shape), - } - - def function(x, y): - return numpy.dot(x, y) - - def function_indirect_args(x, y): - return numpy.dot(x.flatten(), y.flatten()) - - for func_to_compile in [function, function_indirect_args]: - compiler_engine = compile_numpy_function( - func_to_compile, - function_parameters, - inputset, - default_compilation_configuration, - ) - - args = [ - numpy.random.randint(low, high + 1, size=shape, dtype=numpy.uint8) for __ in range(2) - ] - assert compiler_engine.encrypt_run_decrypt(*args) == func_to_compile(*args) - - -@pytest.mark.parametrize( - "size, input_range_x, input_range_y", - [ - pytest.param(6, (0, 3), (-3, 3)), - pytest.param(3, (0, 3), (-7, 7)), - ], -) -def test_compile_and_run_dot_correctness_with_signed_cst( - size, input_range_x, input_range_y, default_compilation_configuration -): - """Test correctness of dot with signed constant tensor.""" - - low_x, high_x = input_range_x - low_y, high_y = input_range_y - shape = (size,) - - # Check that never, the dot goes too high - # For this, we simplify our check knowing that low_x >= 0. Under this condition, the maximal - # value is for the dot is size * max(abs(high_x * low_y), abs(high_x * high_y)). And we want - # is to be less than 64, to have a signed value on strictly less than 8b - assert low_x >= 0 - assert size * max(abs(high_x * low_y), abs(high_x * high_y)) < 64 - - function_parameters = { - "x": EncryptedTensor(Integer(64, False), shape), - } - - constant1 = numpy.random.randint(low_y, high_y + 1, size=(size,)) - constant2 = numpy.random.randint(low_y, high_y + 1, size=(size,)) - - worst_x_1_1 = numpy.where(constant1 < 0, 0, high_x) - worst_x_1_2 = numpy.where(constant1 > 0, 0, high_x) - - worst_x_2_1 = numpy.where(constant2 < 0, 0, high_x) - worst_x_2_2 = numpy.where(constant2 > 0, 0, high_x) - - for i in range(2): - - inputset = [ - numpy.zeros(shape, dtype=numpy.uint32), - numpy.ones(shape, dtype=numpy.uint32) * low_x, - numpy.ones(shape, dtype=numpy.uint32) * high_x, - ] - - for _ in range(128): - inputset.append(numpy.random.randint(low_x, high_x + 1, size=shape)) - - if i == 0: - - def function(x): - return numpy.dot(x, constant1) - - inputset.extend([worst_x_1_1, worst_x_1_2]) - - else: - - def function(x): - return numpy.dot(constant2, x) - - inputset.extend([worst_x_2_1, worst_x_2_2]) - - compiler_engine = compile_numpy_function( - function, function_parameters, inputset, default_compilation_configuration - ) - - # compute modulus used for the output - output_bit_width = compiler_engine.op_graph.output_nodes[0].outputs[0].dtype.bit_width - # bit width + 1 padding bit - modulus = 2 ** (output_bit_width + 1) - - for _ in range(5): - args = [ - numpy.random.randint(low_x, high_x + 1, size=(size,), dtype=numpy.uint8), - ] - assert check_equality_modulo( - compiler_engine.encrypt_run_decrypt(*args), function(*args), modulus - ) - - -@pytest.mark.parametrize( - "size,input_range", - [ - pytest.param( - 1, - (0, 8), - ), - pytest.param( - 4, - (0, 5), - ), - pytest.param( - 6, - (0, 4), - ), - pytest.param( - 10, - (0, 3), - ), - ], -) -def test_compile_and_run_constant_dot_correctness( - size, input_range, default_compilation_configuration -): - """Test correctness of results when running a compiled function""" - - low, high = input_range - shape = (size,) - - inputset = [ - numpy.zeros(shape, dtype=numpy.uint32), - numpy.ones(shape, dtype=numpy.uint32) * high, - ] - for _ in range(8): - inputset.append(numpy.random.randint(low, high + 1)) - - constant = numpy.random.randint(low, high + 1, size=shape) - - def left(x): - return numpy.dot(x, constant) - - def right(x): - return numpy.dot(constant, x) - - left_circuit = compile_numpy_function( - left, - {"x": EncryptedTensor(Integer(64, False), shape)}, - inputset, - default_compilation_configuration, - ) - right_circuit = compile_numpy_function( - right, - {"x": EncryptedTensor(Integer(64, False), shape)}, - inputset, - default_compilation_configuration, - ) - - args = (numpy.random.randint(low, high + 1, size=shape, dtype=numpy.uint8),) - assert left_circuit.encrypt_run_decrypt(*args) == left(*args) - assert right_circuit.encrypt_run_decrypt(*args) == right(*args) - - -@pytest.mark.parametrize( - "lhs_shape,rhs_shape,input_range_inclusive_bound", - [ - pytest.param( - (3, 2), - (2, 3), - (0, 3), - ), - pytest.param( - (1, 2), - (2, 1), - (0, 3), - ), - pytest.param( - (3, 3), - (3, 3), - (0, 3), - ), - pytest.param( - (2, 1), - (1, 2), - (0, 7), - ), - pytest.param( - (2,), - (2,), - (0, 7), - ), - pytest.param( - (5, 5), - (5,), - (0, 3), - ), - pytest.param( - (5,), - (5, 5), - (0, 3), - ), - pytest.param( - (3, 2), - (2, 3), - (-4, 3), - ), - pytest.param( - (5,), - (5, 3), - (0, 3), - ), - pytest.param( - (5, 3), - (3,), - (0, 3), - ), - pytest.param( - (5,), - (4, 5, 3), - (0, 5), - ), - pytest.param( - (4, 5, 3), - (3,), - (0, 5), - ), - pytest.param( - (5,), - (2, 4, 5, 3), - (0, 5), - ), - pytest.param( - (2, 4, 5, 3), - (3,), - (0, 5), - ), - pytest.param( - (5, 4, 3), - (3, 2), - (0, 5), - ), - pytest.param( - (4, 3), - (5, 3, 2), - (0, 5), - ), - pytest.param( - (2, 5, 4, 3), - (3, 2), - (0, 5), - ), - pytest.param( - (5, 4, 3), - (1, 3, 2), - (0, 5), - ), - pytest.param( - (1, 4, 3), - (5, 3, 2), - (0, 5), - ), - pytest.param( - (5, 4, 3), - (2, 1, 3, 2), - (0, 5), - ), - pytest.param( - (2, 1, 4, 3), - (5, 3, 2), - (0, 5), - ), - ], -) -def test_compile_and_run_matmul_correctness( - lhs_shape, - rhs_shape, - input_range_inclusive_bound, - default_compilation_configuration, - check_array_equality, -): - """Test correctness of results when running a compiled function""" - - low, high = input_range_inclusive_bound - - check_mod = low < 0 or high < 0 - - max_abs = max(abs(low), abs(high)) - - # Inputset for x as lhs of matmul - lhs_inputset = [ - numpy.zeros(lhs_shape, dtype=numpy.uint32), - numpy.ones(lhs_shape, dtype=numpy.uint32) * high, - ] - # Inputset for x as rhs of matmul - rhs_inputset = [ - numpy.zeros(rhs_shape, dtype=numpy.uint32), - numpy.ones(rhs_shape, dtype=numpy.uint32) * high, - ] - for _ in range(8): - lhs_inputset.append(numpy.random.randint(low, high + 1, size=lhs_shape)) - rhs_inputset.append(numpy.random.randint(low, high + 1, size=rhs_shape)) - - left_constant = numpy.random.randint(low, high + 1, size=lhs_shape) - right_constant = numpy.random.randint(low, high + 1, size=rhs_shape) - - # Generate worst case inputsets for bit widths, replacing negative values by 0 and putting - # the max value elsewhere, and then doing the same for positive values - rhs_inputset.extend( - [ - numpy.where(right_constant < 0, 0, max_abs), - numpy.where(right_constant > 0, 0, max_abs), - ] - ) - lhs_inputset.extend( - [ - numpy.where(left_constant < 0, 0, max_abs), - numpy.where(left_constant > 0, 0, max_abs), - ] - ) - - # Keep inputset positive - rhs_inputset = [numpy.clip(val, 0, high) for val in rhs_inputset] - lhs_inputset = [numpy.clip(val, 0, high) for val in lhs_inputset] - - def get_output_mod(circuit: FHECircuit): - assert len(circuit.op_graph.output_nodes) == 1 - assert isinstance( - output_dtype := circuit.op_graph.get_ordered_outputs()[0].outputs[0].dtype, Integer - ) - return 2 ** output_dtype.bit_width - - def using_operator_left(x): - return x @ right_constant - - def using_function_left(x): - return numpy.matmul(x, right_constant) - - def using_operator_right(x): - return left_constant @ x - - def using_function_right(x): - return numpy.matmul(left_constant, x) - - operator_left_circuit = compile_numpy_function( - using_operator_left, - {"x": EncryptedTensor(UnsignedInteger(3), lhs_shape)}, - lhs_inputset, - default_compilation_configuration, - ) - function_left_circuit = compile_numpy_function( - using_function_left, - {"x": EncryptedTensor(UnsignedInteger(3), lhs_shape)}, - lhs_inputset, - default_compilation_configuration, - ) - operator_right_circuit = compile_numpy_function( - using_operator_right, - {"x": EncryptedTensor(UnsignedInteger(3), rhs_shape)}, - rhs_inputset, - default_compilation_configuration, - ) - function_right_circuit = compile_numpy_function( - using_function_right, - {"x": EncryptedTensor(UnsignedInteger(3), rhs_shape)}, - rhs_inputset, - default_compilation_configuration, - ) - - def check_result(circuit: FHECircuit, func, arg): - # Stay positive for input to FHE circuit - arg = numpy.clip(arg, 0, high).astype(numpy.uint8) - - circuit_output = circuit.encrypt_run_decrypt(arg) - func_output = func(arg) - - if check_mod: - output_mod = get_output_mod(circuit) - - circuit_output %= output_mod - func_output %= output_mod - - check_array_equality(circuit_output, func_output) - - arg = numpy.random.randint(low, high + 1, size=lhs_shape) - check_result(operator_left_circuit, using_operator_left, arg) - check_result(function_left_circuit, using_function_left, arg) - - arg = numpy.random.randint(low, high + 1, size=rhs_shape) - check_result(operator_right_circuit, using_operator_right, arg) - check_result(function_right_circuit, using_function_right, arg) - - -@pytest.mark.parametrize( - "function,input_bits,list_of_arg_names", - [ - pytest.param(identity_lut_generator(1), (1,), ["x"], id="identity function (1-bit)"), - pytest.param(identity_lut_generator(2), (2,), ["x"], id="identity function (2-bit)"), - pytest.param(identity_lut_generator(3), (3,), ["x"], id="identity function (3-bit)"), - pytest.param(identity_lut_generator(4), (4,), ["x"], id="identity function (4-bit)"), - pytest.param(identity_lut_generator(5), (5,), ["x"], id="identity function (5-bit)"), - pytest.param(identity_lut_generator(6), (6,), ["x"], id="identity function (6-bit)"), - pytest.param(identity_lut_generator(7), (7,), ["x"], id="identity function (7-bit)"), - pytest.param(random_lut_1b, (1,), ["x"], id="random function (1-bit)"), - pytest.param(random_lut_2b, (2,), ["x"], id="random function (2-bit)"), - pytest.param(random_lut_3b, (3,), ["x"], id="random function (3-bit)"), - pytest.param(random_lut_4b, (4,), ["x"], id="random function (4-bit)"), - pytest.param(random_lut_5b, (5,), ["x"], id="random function (5-bit)"), - pytest.param(random_lut_6b, (6,), ["x"], id="random function (6-bit)"), - pytest.param(random_lut_7b, (7,), ["x"], id="random function (7-bit)"), - pytest.param(small_fused_table, (5,), ["x"], id="small fused table (5-bits)"), - ], -) -def test_compile_and_run_lut_correctness( - function, - input_bits, - list_of_arg_names, - default_compilation_configuration, - check_is_good_execution, -): - """Test correctness of results when running a compiled function with LUT""" - - input_ranges = tuple((0, 2 ** input_bit - 1) for input_bit in input_bits) - - function_parameters = { - arg_name: EncryptedScalar(Integer(input_bit, False)) - for input_bit, arg_name in zip(input_bits, list_of_arg_names) - } - - compiler_engine = compile_numpy_function( - function, - function_parameters, - data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), - default_compilation_configuration, - ) - - # testing random values - for _ in range(10): - args = [random.randint(low, high) for (low, high) in input_ranges] - check_is_good_execution(compiler_engine, function, args) - - # testing low values - args = [low for (low, _) in input_ranges] - check_is_good_execution(compiler_engine, function, args) - - # testing high values - args = [high for (_, high) in input_ranges] - check_is_good_execution(compiler_engine, function, args) - - -@pytest.mark.parametrize( - "function,table,bit_width", - [ - pytest.param(*negative_identity_smaller_lut_generator(n), n, id=f"smaller ({n}-bit)") - for n in range(1, 8) - ] - + [ - pytest.param(*negative_identity_lut_generator(n), n, id=f"normal ({n}-bit)") - for n in range(1, 8) - ] - + [ - pytest.param(*negative_identity_bigger_lut_generator(n), n, id=f"bigger ({n}-bit)") - for n in range(1, 7) - ] - + [ - pytest.param(*weird_lut(3), 3, id="weird"), - ], -) -def test_compile_and_run_negative_lut_correctness( - function, - table, - bit_width, - default_compilation_configuration, - check_is_good_execution, -): - """Test correctness when running a compiled function with LUT using negative values""" - - circuit = compile_numpy_function( - function, - {"x": EncryptedScalar(UnsignedInteger(bit_width))}, - range(2 ** bit_width), - default_compilation_configuration, - ) - - offset = 2 ** (bit_width - 1) - values = [-offset, -offset // 2, 0, offset // 2, offset - 1] - values.extend([random.randint(-offset, offset - 1) for _ in range(5)]) - for value in values: - assert table[value] == function(value + offset) - check_is_good_execution(circuit, function, [value + offset]) - - -def test_compile_and_run_multi_lut_correctness( - default_compilation_configuration, - check_is_good_execution, -): - """Test correctness of results when running a compiled function with Multi LUT""" - - def function_to_compile(x): - table = MultiLookupTable( - [ - [LookupTable([1, 2, 1, 0]), LookupTable([2, 2, 1, 3])], - [LookupTable([1, 0, 1, 0]), LookupTable([0, 2, 3, 3])], - [LookupTable([0, 2, 3, 0]), LookupTable([2, 1, 2, 0])], - ] - ) - return table[x] - - compiler_engine = compile_numpy_function( - function_to_compile, - { - "x": EncryptedTensor(UnsignedInteger(2), shape=(3, 2)), - }, - [numpy.random.randint(0, 2 ** 2, size=(3, 2)) for _ in range(10)], - default_compilation_configuration, - ) - - # testing random values - for _ in range(10): - args = [numpy.random.randint(0, 2 ** 2, size=(3, 2), dtype=numpy.uint8)] - check_is_good_execution(compiler_engine, function_to_compile, args) - - -def test_compile_function_with_direct_tlu(default_compilation_configuration): - """Test compile_numpy_function_into_op_graph for a program with direct table lookup""" - - table = LookupTable([9, 2, 4, 11]) - - def function(x): - return x + table[x] - - op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( - function, - {"x": EncryptedScalar(Integer(2, is_signed=False))}, - range(4), - default_compilation_configuration, - ) - - str_of_the_graph = format_operation_graph(op_graph) - print(f"\n{str_of_the_graph}\n") - - -def test_compile_function_with_direct_tlu_overflow(default_compilation_configuration): - """Test compile_numpy_function_into_op_graph for a program with direct table lookup overflow""" - - table = LookupTable([9, 2, 4, 11]) - - def function(x): - return table[x] - - with pytest.raises(ValueError): - compile_numpy_function_into_op_graph_and_measure_bounds( - function, - {"x": EncryptedScalar(Integer(3, is_signed=False))}, - range(8), - default_compilation_configuration, - ) - - -@pytest.mark.parametrize( - "input_shape", - [ - pytest.param((4,)), - pytest.param((3, 2)), - pytest.param((3, 2, 5)), - pytest.param((3, 2, 5, 3)), - ], -) -def test_compile_and_run_transpose_correctness(input_shape, default_compilation_configuration): - """Test function to make sure compilation and execution of transpose works properly""" - - def transpose(x): - return numpy.transpose(x) - - compiler_engine = compile_numpy_function( - transpose, - {"x": EncryptedTensor(Integer(64, False), input_shape)}, - [numpy.random.randint(0, 120, size=input_shape) for i in range(20)], - default_compilation_configuration, - ) - x = numpy.random.randint(0, 120, size=input_shape, dtype=numpy.uint8) - expected = transpose(x) - result = compiler_engine.encrypt_run_decrypt(x) - assert (expected == result).all() - - -@pytest.mark.parametrize( - "input_shape", - [ - pytest.param((8,)), - ], -) -@pytest.mark.parametrize( - "loop_parallelize", - [ - pytest.param(True), - pytest.param(False), - ], -) -def test_compile_and_run_loop_parallelization( - input_shape, loop_parallelize, default_compilation_configuration -): - """Test function to make sure compilation and execution with and without loop parallelization - works properly""" - - def dot_and_add(x, y, a): - return numpy.dot(x, y) + a - - # Enable/Disable loop parallelization - compilation_configuration = deepcopy(default_compilation_configuration) - compilation_configuration.loop_parallelize = loop_parallelize - - compiler_engine = compile_numpy_function( - dot_and_add, - { - "x": EncryptedTensor(Integer(64, False), input_shape), - "y": ClearTensor(Integer(64, False), input_shape), - "a": ClearScalar(Integer(64, False)), - }, - [ - ( - numpy.random.randint(0, 2, size=input_shape), - numpy.random.randint(0, 2, size=input_shape), - numpy.random.randint(0, 2, size=()), - ) - for i in range(20) - ], - compilation_configuration, - ) - x = numpy.random.randint(0, 2, size=input_shape, dtype=numpy.uint8) - y = numpy.random.randint(0, 2, size=input_shape, dtype=numpy.uint8) - a = numpy.random.randint(0, 2, size=(), dtype=numpy.uint8) - expected = dot_and_add(x, y, a) - result = compiler_engine.encrypt_run_decrypt(x, y, a) - assert (expected == result).all() - - -# pylint: disable=line-too-long -@pytest.mark.parametrize( - "function,parameters,inputset,error,match", - [ - pytest.param( - lambda x: numpy.dot(x, numpy.array([-1.5])), - { - "x": EncryptedTensor(Integer(2, is_signed=False), shape=(1,)), - }, - [numpy.array([i]) for i in [1, 1, 0, 0, 1, 1, 0, 0, 2, 2]], - RuntimeError, - ( - """ - - function you are trying to compile isn't supported for MLIR lowering - -%0 = x # EncryptedTensor -%1 = [-1.5] # ClearTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported -%2 = dot(%0, %1) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer dot product is supported -return %2 - """.strip() # noqa: E501 - ), - ), - pytest.param( - no_fuse_unhandled, - {"x": EncryptedScalar(Integer(2, False)), "y": EncryptedScalar(Integer(2, False))}, - [(numpy.array(i), numpy.array(i)) for i in range(10)], - RuntimeError, - ( - """ - -function you are trying to compile isn't supported for MLIR lowering - -%0 = 1.5 # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported -%1 = x # EncryptedScalar -%2 = 2.8 # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported -%3 = y # EncryptedScalar -%4 = 9.3 # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported -%5 = add(%1, %2) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported -%6 = add(%3, %4) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported -%7 = sub(%5, %6) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer subtraction is supported -%8 = mul(%7, %0) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer multiplication is supported -%9 = astype(%8, dtype=int32) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ astype with floating-point inputs is required to be fused to be supported -return %9 - - """.strip() # noqa: E501 - ), - ), - pytest.param( - lambda x: numpy.ravel(x), - {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))}, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for i in range(10)], - RuntimeError, - ( - """ - -function you are trying to compile isn't supported for MLIR lowering - -%0 = x # EncryptedTensor -%1 = ravel(%0) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ravel is not supported for the time being -return %1 - - """.strip() # noqa: E501 - ), - ), - pytest.param( - lambda x: numpy.sum(x), - {"x": EncryptedScalar(UnsignedInteger(3))}, - [numpy.random.randint(0, 2 ** 3) for i in range(10)], - ValueError, - ( - """ - -only encrypted tensor sum is supported but you tried to sum EncryptedScalar - - """.strip() # noqa: E501 - ), - ), - pytest.param( - lambda x: numpy.sum(x, axis="abc"), # type: ignore - {"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2))}, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for i in range(10)], - ValueError, - ( - """ - -invalid sum on EncryptedTensor with axis=abc - - """.strip() # noqa: E501 - ), - ), - pytest.param( - lambda x: numpy.sum(x), # type: ignore - {"x": ClearTensor(UnsignedInteger(3), shape=(3, 2))}, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for i in range(10)], - ValueError, - ( - """ - -only encrypted tensor sum is supported but you tried to sum ClearTensor - - """.strip() # noqa: E501 - ), - ), - pytest.param( - lambda x: numpy.concatenate((x, x)), - {"x": EncryptedScalar(UnsignedInteger(3))}, - [numpy.random.randint(0, 2 ** 3) for i in range(10)], - ValueError, - ( - """ - -only encrypted tensor concatenation is supported but you tried to concatenate EncryptedScalar, EncryptedScalar - - """.strip() # noqa: E501 - ), - ), - pytest.param( - lambda x: numpy.concatenate((x, x), axis="abc"), # type: ignore - {"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2))}, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for i in range(10)], - ValueError, - ( - """ - -invalid concatenation of EncryptedTensor, EncryptedTensor with axis=abc - - """.strip() # noqa: E501 - ), - ), - pytest.param( - lambda x: numpy.concatenate((x, x)), # type: ignore - {"x": ClearTensor(UnsignedInteger(3), shape=(3, 2))}, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for i in range(10)], - ValueError, - ( - """ - -only encrypted tensor concatenation is supported but you tried to concatenate ClearTensor, ClearTensor - - """.strip() # noqa: E501 - ), - ), - ], -) -# pylint: enable=line-too-long -def test_fail_compile( - function, - parameters, - inputset, - error, - match, - default_compilation_configuration, -): - """Test function compile_numpy_function_into_op_graph for a program with signed values""" - - with pytest.raises(error) as excinfo: - compile_numpy_function( - function, - parameters, - inputset, - default_compilation_configuration, - ) - - assert str(excinfo.value) == match, str(excinfo.value) - - -@pytest.mark.parametrize( - "function,parameters,inputset,match", - [ - pytest.param( - lambda x: (x * 1.5)[0, 1], - {"x": EncryptedTensor(SignedInteger(3), shape=(2, 2))}, - [numpy.random.randint(-4, 3, size=(2, 2)) for i in range(10)], - ( - """ - -function you are trying to compile isn't supported for MLIR lowering - -%0 = x # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported -%1 = 1.5 # ClearScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported -%2 = mul(%0, %1) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer multiplication is supported -%3 = %2[0, 1] # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer outputs are supported -return %3 - - """.strip() # noqa: E501 - ), - ), - ], -) -def test_fail_compile_while_fusing_is_disabled( - function, parameters, inputset, match, default_compilation_configuration -): - """Test compile_numpy_function without fusing and with failing inputs""" - - configuration_to_use = deepcopy(default_compilation_configuration) - configuration_to_use.enable_topological_optimizations = False - - with pytest.raises(RuntimeError) as excinfo: - compile_numpy_function( - function, - parameters, - inputset, - configuration_to_use, - ) - - assert str(excinfo.value) == match, str(excinfo.value) - - -def test_small_inputset_no_fail(): - """Test function compile_numpy_function_into_op_graph with an unacceptably small inputset""" - compile_numpy_function_into_op_graph_and_measure_bounds( - lambda x: x + 42, - {"x": EncryptedScalar(Integer(5, is_signed=False))}, - [0, 3], - CompilationConfiguration(dump_artifacts_on_unexpected_failures=False), - ) - - -def test_small_inputset_treat_warnings_as_errors(): - """Test function compile_numpy_function_into_op_graph with an unacceptably small inputset""" - with pytest.raises(ValueError, match=".* inputset contains too few inputs .*"): - compile_numpy_function_into_op_graph_and_measure_bounds( - lambda x: x + 42, - {"x": EncryptedScalar(Integer(5, is_signed=False))}, - [0, 3], - CompilationConfiguration( - dump_artifacts_on_unexpected_failures=False, - treat_warnings_as_errors=True, - ), - ) - - -@pytest.mark.parametrize( - "function,params,shape,ref_graph_str", - [ - ( - lambda x, y: numpy.dot(x, y), - { - "x": EncryptedTensor(Integer(2, is_signed=False), shape=(4,)), - "y": EncryptedTensor(Integer(2, is_signed=False), shape=(4,)), - }, - (4,), - # Remark that, when you do the dot of tensors of 4 values between 0 and 3, - # you can get a maximal value of 4*3*3 = 36, ie something on 6 bits - """ - -%0 = x # EncryptedTensor -%1 = y # EncryptedTensor -%2 = dot(%0, %1) # EncryptedScalar -return %2 - - """.strip(), - ), - ], -) -def test_compile_function_with_dot( - function, params, shape, ref_graph_str, default_compilation_configuration -): - """Test compile_numpy_function_into_op_graph for a program with np.dot""" - - # This is the exhaust, but if ever we have too long inputs (ie, large 'repeat'), - # we'll have to take random values, not all values one by one - def data_gen_local(max_for_ij, repeat): - iter_i = itertools.product(range(0, max_for_ij + 1), repeat=repeat) - iter_j = itertools.product(range(0, max_for_ij + 1), repeat=repeat) - for prod_i, prod_j in itertools.product(iter_i, iter_j): - yield numpy.array(prod_i), numpy.array(prod_j) - - max_for_ij = 3 - assert len(shape) == 1 - repeat = shape[0] - - op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( - function, - params, - data_gen_local(max_for_ij, repeat), - default_compilation_configuration, - ) - str_of_the_graph = format_operation_graph(op_graph) - assert str_of_the_graph == ref_graph_str, ( - f"\n==================\nGot \n{str_of_the_graph}" - f"==================\nExpected \n{ref_graph_str}" - f"==================\n" - ) - - -@pytest.mark.parametrize( - "function,input_ranges,list_of_arg_names", - [ - pytest.param(lambda x: x + 64, ((0, 10),), ["x"]), - pytest.param(lambda x: x * 3, ((0, 40),), ["x"]), - pytest.param(lambda x: 120 - x, ((40, 80),), ["x"]), - pytest.param(lambda x, y: x + y + 64, ((0, 20), (0, 20)), ["x", "y"]), - pytest.param(lambda x, y: 100 - y + x, ((0, 20), (0, 20)), ["x", "y"]), - pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]), - ], -) -def test_compile_with_show_mlir( - function, input_ranges, list_of_arg_names, default_compilation_configuration -): - """Test show_mlir option""" - - function_parameters = { - arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names - } - - compile_numpy_function( - function, - function_parameters, - data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), - default_compilation_configuration, - show_mlir=True, - ) - - -def test_compile_too_high_bitwidth(default_compilation_configuration): - """Check that the check of maximal bitwidth of intermediate data works fine.""" - - def function(x, y): - return x + y - - function_parameters = { - "x": EncryptedScalar(Integer(64, False)), - "y": EncryptedScalar(Integer(64, False)), - } - - # A bit too much - input_ranges = [(100, 200), (100, 200)] - - with pytest.raises(RuntimeError) as excinfo: - compile_numpy_function( - function, - function_parameters, - data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), - default_compilation_configuration, - ) - - assert ( - str(excinfo.value) - == """ - -max_bit_width of some nodes is too high for the current version of the compiler (maximum must be 8) which is not compatible with: - -%0 = x # EncryptedScalar -%1 = y # EncryptedScalar -%2 = add(%0, %1) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 9 bits is not supported for the time being -return %2 - - """.strip() # noqa: E501 # pylint: disable=line-too-long - ) - - # Just ok - input_ranges = [(0, 99), (0, 28)] - - compile_numpy_function( - function, - function_parameters, - data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), - default_compilation_configuration, - ) - - -def test_compile_with_random_inputset(default_compilation_configuration): - """Test function for compile with random input set""" - - configuration_to_use = deepcopy(default_compilation_configuration) - configuration_to_use.enable_unsafe_features = True - - compile_numpy_function_into_op_graph_and_measure_bounds( - lambda x: x + 1, - {"x": EncryptedScalar(UnsignedInteger(6))}, - inputset="random", - compilation_configuration=configuration_to_use, - ) - compile_numpy_function( - lambda x: x + 32, - {"x": EncryptedScalar(UnsignedInteger(6))}, - inputset="random", - compilation_configuration=configuration_to_use, - ) - - -def test_fail_compile_with_random_inputset(): - """Test function for failed compile with random input set""" - - compilation_configuration = CompilationConfiguration( - dump_artifacts_on_unexpected_failures=False, - treat_warnings_as_errors=True, - ) - - with pytest.raises(ValueError): - try: - compile_numpy_function_into_op_graph_and_measure_bounds( - lambda x: x + 1, - {"x": EncryptedScalar(UnsignedInteger(3))}, - inputset="unsupported", - compilation_configuration=compilation_configuration, - ) - except Exception as error: - expected = ( - "inputset can only be an iterable of tuples or the string 'random' " - "but you specified 'unsupported' for it" - ) - assert str(error) == expected - raise - - with pytest.raises(RuntimeError): - try: - compile_numpy_function( - lambda x: x + 1, - {"x": EncryptedScalar(UnsignedInteger(3))}, - inputset="random", - compilation_configuration=compilation_configuration, - ) - except Exception as error: - expected = ( - "Random inputset generation is an unsafe feature " - "and should not be used if you don't know what you are doing" - ) - assert str(error) == expected - raise - - -def test_wrong_inputs(default_compilation_configuration): - """Test compilation with faulty inputs""" - - # x should have been something like EncryptedScalar(UnsignedInteger(3)) - x = [1, 2, 3] - input_ranges = ((0, 10),) - inputset = data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)) - dict_for_inputs = {"x": x} - - with pytest.raises(AssertionError) as excinfo: - compile_numpy_function( - lambda x: 2 * x, dict_for_inputs, inputset, default_compilation_configuration - ) - - list_of_possible_basevalue = [ - "ClearTensor", - "EncryptedTensor", - "ClearScalar", - "EncryptedScalar", - ] - assert ( - str(excinfo.value) == f"wrong type for inputs {dict_for_inputs}, " - f"needs to be one of {list_of_possible_basevalue}" - ) - - -@pytest.mark.parametrize( - "function,input_ranges,list_of_arg_names", - [ - pytest.param(lambda x: (x + (-27)) + 32, ((0, 10),), ["x"]), - pytest.param(lambda x: ((-3) * x) + (100 - (x + 1)), ((0, 10),), ["x"]), - pytest.param( - lambda x, y: (-1) * x + (-2) * y + 40, - ( - (0, 10), - (0, 10), - ), - ["x", "y"], - ), - ], -) -def test_compile_and_run_correctness_with_negative_values( - function, input_ranges, list_of_arg_names, default_compilation_configuration -): - """Test correctness of results when running a compiled function, which has some negative - intermediate values.""" - - function_parameters = { - arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names - } - - compiler_engine = compile_numpy_function( - function, - function_parameters, - data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), - default_compilation_configuration, - ) - - args = [random.randint(low, high) for (low, high) in input_ranges] - assert compiler_engine.encrypt_run_decrypt(*args) == function(*args) - - -@pytest.mark.parametrize( - "function,input_ranges,list_of_arg_names", - [ - pytest.param( - lambda x: (20 + 10 * numpy.tanh(50 * (numpy.cos(x + 33.0)))).astype(numpy.uint32), - ((0, 31),), - ["x"], - ), - pytest.param( - lambda x: (20 * (numpy.cos(x + 33.0)) + 30).astype(numpy.uint32), - ((0, 31),), - ["x"], - ), - ], -) -def test_compile_and_run_correctness_with_negative_values_and_pbs( - function, - input_ranges, - list_of_arg_names, - default_compilation_configuration, - check_is_good_execution, -): - """Test correctness of results when running a compiled function, which has some negative - intermediate values.""" - - function_parameters = { - arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names - } - - compiler_engine = compile_numpy_function( - function, - function_parameters, - data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), - default_compilation_configuration, - ) - - args = [random.randint(low, high) for (low, high) in input_ranges] - check_is_good_execution(compiler_engine, function, args, verbose=False) - - -def check_equality_modulo(a, b, modulus): - """Check that (a mod modulus) == (b mod modulus)""" - return (a % modulus) == (b % modulus) - - -@pytest.mark.parametrize( - "function,input_ranges,list_of_arg_names,modulus", - [ - pytest.param(lambda x: x + (-20), ((0, 10),), ["x"], 128), - pytest.param(lambda x: 10 + x * (-3), ((0, 20),), ["x"], 128), - ], -) -def test_compile_and_run_correctness_with_negative_results( - function, input_ranges, list_of_arg_names, modulus, default_compilation_configuration -): - """Test correctness of computations when the result is possibly negative: until #845 is fixed, - results are currently only correct modulo a power of 2 (given by `modulus` parameter). Eg, - instead of returning -3, the execution may return -3 mod 128 = 125.""" - - function_parameters = { - arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names - } - - compiler_engine = compile_numpy_function( - function, - function_parameters, - data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), - default_compilation_configuration, - ) - - args = [random.randint(low, high) for (low, high) in input_ranges] - assert check_equality_modulo( - compiler_engine.encrypt_run_decrypt(*args), function(*args), modulus - ) - - -@pytest.mark.parametrize( - "compilation_configuration", - [ - CompilationConfiguration( - dump_artifacts_on_unexpected_failures=False, - enable_unsafe_features=False, - use_insecure_key_cache=False, - ), - CompilationConfiguration( - dump_artifacts_on_unexpected_failures=False, - enable_unsafe_features=True, - use_insecure_key_cache=False, - ), - CompilationConfiguration( - dump_artifacts_on_unexpected_failures=False, - enable_unsafe_features=False, - use_insecure_key_cache=True, - ), - ], -) -def test_compile_improper_use_of_insecure_key_cache( - default_keyring_path, compilation_configuration -): - """Test the case where the key cache is used with wrong compilation configuration. - - DO NOT USE INSECURE KEY CACHE FOR NORMAL PRODUCTION WORK - - This is a test to check we properly fail for users trying to incorrectly use the insecure key - cache (to reuse keys across compilations). This allows to speed up tests A LOT but should not be - used in normal prod environment /!\\ DANGER /!\\.""" - - def f(x): - return x + 42 - - if compile_._COMPILE_FHE_INSECURE_KEY_CACHE_DIR is None: # pylint: disable=protected-access - compile_._COMPILE_FHE_INSECURE_KEY_CACHE_DIR = str( # pylint: disable=protected-access - default_keyring_path - ) - - with pytest.raises( - RuntimeError, - match="Unable to use insecure key cache .* " - "as use_insecure_key_cache or enable_unsafe_features are not set to True in" - "compilation_configuration", - ): - _ = compile_numpy_function( - f, - {"x": EncryptedScalar(Integer(64, False))}, - range(10), - compilation_configuration, - ) diff --git a/tests/numpy/test_compile_constant_indexing.py b/tests/numpy/test_compile_constant_indexing.py deleted file mode 100644 index fcc186f47..000000000 --- a/tests/numpy/test_compile_constant_indexing.py +++ /dev/null @@ -1,727 +0,0 @@ -"""Test module for constant indexing.""" - -import numpy as np -import pytest - -from concrete.common.data_types import UnsignedInteger -from concrete.common.values import EncryptedScalar, EncryptedTensor -from concrete.numpy import ( - compile_numpy_function, - compile_numpy_function_into_op_graph_and_measure_bounds, -) - - -@pytest.mark.parametrize( - "input_value,function_with_indexing,output_value", - [ - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-3], - EncryptedScalar(UnsignedInteger(1)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-2], - EncryptedScalar(UnsignedInteger(1)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-1], - EncryptedScalar(UnsignedInteger(1)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[0], - EncryptedScalar(UnsignedInteger(1)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[1], - EncryptedScalar(UnsignedInteger(1)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[2], - EncryptedScalar(UnsignedInteger(1)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[:], - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-3:], - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-2:], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-1:], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[0:], - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[1:], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[2:], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[:-1], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[:-2], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[:1], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[:2], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[:3], - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-3:-2], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-3:-1], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-3:1], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-3:2], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-3:3], - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-2:-1], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-2:2], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-2:3], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-1:3], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[0:-2], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[0:-1], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[0:1], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[0:2], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[0:3], - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[1:-1], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[1:2], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[1:3], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[2:3], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[::-1], - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-3::-1], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-2::-1], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-1::-1], - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[0::-1], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[1::-1], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[2::-1], - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[:-3:-1], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[:-2:-1], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[:0:-1], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[:1:-1], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[2:0:-1], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[2:1:-1], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-1:1:-1], - EncryptedTensor(UnsignedInteger(1), shape=(1,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[-1:0:-1], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), - lambda x: x[:, :, :], - EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), - lambda x: x[0, :, :], - EncryptedTensor(UnsignedInteger(1), shape=(4, 5)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), - lambda x: x[:, 0, :], - EncryptedTensor(UnsignedInteger(1), shape=(3, 5)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), - lambda x: x[:, :, 0], - EncryptedTensor(UnsignedInteger(1), shape=(3, 4)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), - lambda x: x[0, 0, :], - EncryptedTensor(UnsignedInteger(1), shape=(5,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), - lambda x: x[0, :, 0], - EncryptedTensor(UnsignedInteger(1), shape=(4,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), - lambda x: x[:, 0, 0], - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), - lambda x: x[0:, 1:, 2:], - EncryptedTensor(UnsignedInteger(1), shape=(3, 3, 3)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), - lambda x: x[2:, 1:, 0:], - EncryptedTensor(UnsignedInteger(1), shape=(1, 3, 5)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), - lambda x: x[0], - EncryptedTensor(UnsignedInteger(1), shape=(4, 5)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), - lambda x: x[0, 0], - EncryptedTensor(UnsignedInteger(1), shape=(5,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)), - lambda x: x[0, 0, 0], - EncryptedScalar(UnsignedInteger(1)), - ), - ], -) -def test_constant_indexing( - default_compilation_configuration, - input_value, - function_with_indexing, - output_value, -): - """Test compile_numpy_function_into_op_graph with constant indexing""" - - inputset = [ - np.random.randint( - input_value.dtype.min_value(), - input_value.dtype.max_value() + 1, - size=input_value.shape, - ) - for _ in range(10) - ] - - op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( - function_with_indexing, - {"x": input_value}, - inputset, - default_compilation_configuration, - ) - - assert len(op_graph.output_nodes) == 1 - output_node = op_graph.output_nodes[0] - - assert len(output_node.outputs) == 1 - assert output_value == output_node.outputs[0] - - -@pytest.mark.parametrize( - "input_value,function_with_indexing,expected_error_type,expected_error_message", - [ - pytest.param( - EncryptedScalar(UnsignedInteger(1)), - lambda x: x[0], - TypeError, - "Only tensors can be indexed but you tried to index EncryptedScalar", - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[0.5], - TypeError, - "Only integers and integer slices can be used for indexing " - "but you tried to use 0.5 for indexing", - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[1:5:0.5], # type: ignore - TypeError, - "Only integers and integer slices can be used for indexing " - "but you tried to use 1:5:0.5 for indexing", - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[0, 1], - ValueError, - "Tensor of shape (3,) cannot be indexed with [0, 1] " - "as the index has more elements than the number of dimensions of the tensor", - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[5], - ValueError, - "Tensor of shape (3,) cannot be indexed with [5] " - "because index is out of range for dimension 0", - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[5:], - ValueError, - "Tensor of shape (3,) cannot be indexed with [5:] " - "because start index is out of range for dimension 0", - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[:10], - ValueError, - "Tensor of shape (3,) cannot be indexed with [:10] " - "because stop index is out of range for dimension 0", - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[2:0], - ValueError, - "Tensor of shape (3,) cannot be indexed with [2:0] " - "because start index is not less than stop index for dimension 0", - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[5::-1], - ValueError, - "Tensor of shape (3,) cannot be indexed with [5::-1] " - "because start index is out of range for dimension 0", - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[:10:-1], - ValueError, - "Tensor of shape (3,) cannot be indexed with [:10:-1] " - "because stop index is out of range for dimension 0", - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[0:2:-1], - ValueError, - "Tensor of shape (3,) cannot be indexed with [0:2:-1] " - "because step is negative and stop index is not less than start index for dimension 0", - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[::0], - ValueError, - "Tensor of shape (3,) cannot be indexed with [::0] " - "because step is zero for dimension 0", - ), - ], -) -def test_invalid_constant_indexing( - default_compilation_configuration, - input_value, - function_with_indexing, - expected_error_type, - expected_error_message, -): - """Test compile_numpy_function_into_op_graph with invalid constant indexing""" - - with pytest.raises(expected_error_type): - try: - inputset = [ - ( - np.random.randint( - input_value.dtype.min_value(), - input_value.dtype.max_value() + 1, - size=input_value.shape, - ), - ) - for _ in range(10) - ] - compile_numpy_function_into_op_graph_and_measure_bounds( - function_with_indexing, - {"x": input_value}, - inputset, - default_compilation_configuration, - ) - except Exception as error: - assert str(error) == expected_error_message - raise - - -@pytest.mark.parametrize( - "input_value,function_with_indexing,output_value", - [ - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[np.uint32(0)], - EncryptedScalar(UnsignedInteger(1)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[slice(np.uint32(2), np.int32(0), np.int8(-1))], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[np.array(0)], - EncryptedScalar(UnsignedInteger(1)), - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[slice(np.array(2), np.array(0), np.array(-1))], - EncryptedTensor(UnsignedInteger(1), shape=(2,)), - ), - ], -) -def test_constant_indexing_with_numpy_integers( - default_compilation_configuration, - input_value, - function_with_indexing, - output_value, -): - """Test compile_numpy_function_into_op_graph with constant indexing with numpy integers""" - - inputset = [ - np.random.randint( - input_value.dtype.min_value(), - input_value.dtype.max_value() + 1, - size=input_value.shape, - ) - for _ in range(10) - ] - - op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( - function_with_indexing, - {"x": input_value}, - inputset, - default_compilation_configuration, - ) - - assert len(op_graph.output_nodes) == 1 - output_node = op_graph.output_nodes[0] - - assert len(output_node.outputs) == 1 - assert output_value == output_node.outputs[0] - - -@pytest.mark.parametrize( - "input_value,function_with_indexing,expected_error_type,expected_error_message", - [ - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[np.float32(1.5)], - TypeError, - "Only integers and integer slices can be used for indexing " - "but you tried to use 1.5 for indexing", - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[np.array(1.5)], - TypeError, - "Only integers and integer slices can be used for indexing " - "but you tried to use 1.5 for indexing", - ), - pytest.param( - EncryptedTensor(UnsignedInteger(1), shape=(3,)), - lambda x: x[np.array([1, 2])], - TypeError, - "Only integers and integer slices can be used for indexing " - "but you tried to use [1 2] for indexing", - ), - ], -) -def test_invalid_constant_indexing_with_numpy_values( - default_compilation_configuration, - input_value, - function_with_indexing, - expected_error_type, - expected_error_message, -): - """Test compile_numpy_function_into_op_graph with invalid constant indexing with numpy values""" - - with pytest.raises(expected_error_type): - try: - inputset = [ - ( - np.random.randint( - input_value.dtype.min_value(), - input_value.dtype.max_value() + 1, - size=input_value.shape, - ), - ) - for _ in range(10) - ] - compile_numpy_function_into_op_graph_and_measure_bounds( - function_with_indexing, - {"x": input_value}, - inputset, - default_compilation_configuration, - ) - except Exception as error: - assert str(error) == expected_error_message - raise - - -@pytest.mark.parametrize( - "function,parameters,inputset,test_input,expected_output", - [ - pytest.param( - lambda x: x[0], - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3,)), - }, - [np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(10)], - ([4, 2, 6],), - 4, - ), - pytest.param( - lambda x: x[-1], - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(3,)), - }, - [np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(10)], - ([4, 2, 6],), - 6, - ), - pytest.param( - lambda x: x[:3], - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(4,)), - }, - [np.random.randint(0, 2 ** 3, size=(4,)) for _ in range(10)], - ([4, 2, 6, 1],), - [4, 2, 6], - ), - pytest.param( - lambda x: x[2:], - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(4,)), - }, - [np.random.randint(0, 2 ** 3, size=(4,)) for _ in range(10)], - ([4, 2, 6, 1],), - [6, 1], - ), - pytest.param( - lambda x: x[1:3], - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(4,)), - }, - [np.random.randint(0, 2 ** 3, size=(4,)) for _ in range(10)], - ([4, 2, 6, 1],), - [2, 6], - ), - pytest.param( - lambda x: x[::2], - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(4,)), - }, - [np.random.randint(0, 2 ** 3, size=(4,)) for _ in range(10)], - ([4, 2, 6, 1],), - [4, 6], - ), - pytest.param( - lambda x: x[::-1], - { - "x": EncryptedTensor(UnsignedInteger(3), shape=(4,)), - }, - [np.random.randint(0, 2 ** 3, size=(4,)) for _ in range(10)], - ([4, 2, 6, 1],), - [1, 6, 2, 4], - ), - pytest.param( - lambda x: x[1, 0], - { - "x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)), - }, - [np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)], - ([[11, 12], [21, 22], [31, 32]],), - 21, - ), - pytest.param( - lambda x: x[:, :], - { - "x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)), - }, - [np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)], - ([[11, 12], [21, 22], [31, 32]],), - [[11, 12], [21, 22], [31, 32]], - ), - pytest.param( - lambda x: x[0, :], - { - "x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)), - }, - [np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)], - ([[11, 12], [21, 22], [31, 32]],), - [11, 12], - ), - pytest.param( - lambda x: x[:, 0], - { - "x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)), - }, - [np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)], - ([[11, 12], [21, 22], [31, 32]],), - [11, 21, 31], - ), - ], -) -def test_constant_indexing_run_correctness( - function, - parameters, - inputset, - test_input, - expected_output, - default_compilation_configuration, - check_array_equality, -): - """Test correctness of results when running a compiled function with tensor operators""" - circuit = compile_numpy_function( - function, - parameters, - inputset, - default_compilation_configuration, - ) - - numpy_test_input = tuple( - item if isinstance(item, int) else np.array(item, dtype=np.uint8) for item in test_input - ) - - output = circuit.encrypt_run_decrypt(*numpy_test_input) - expected = np.array(expected_output, dtype=np.uint8) - - check_array_equality(output, expected) diff --git a/tests/numpy/test_compile_conv.py b/tests/numpy/test_compile_conv.py deleted file mode 100644 index 4c4383827..000000000 --- a/tests/numpy/test_compile_conv.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Test module for convolution compilation and execution.""" - -import numpy as np -import pytest - -import concrete.numpy as hnp -from concrete.common.data_types.integers import Integer -from concrete.common.values.tensors import EncryptedTensor -from concrete.numpy.compile import compile_numpy_function - - -@pytest.mark.parametrize( - "input_shape, weight_shape", - [ - pytest.param((1, 1, 4, 4), (1, 1, 2, 2)), - pytest.param((4, 3, 4, 4), (2, 3, 2, 2)), - ], -) -@pytest.mark.parametrize("strides", [(2, 2)]) -@pytest.mark.parametrize("dilations", [(1, 1)]) -@pytest.mark.parametrize("has_bias", [True, False]) -def test_compile_and_run( - input_shape, weight_shape, strides, dilations, has_bias, default_compilation_configuration -): - """Test function to make sure compilation and execution of conv2d works properly""" - if has_bias: - bias = np.random.randint(0, 4, size=(weight_shape[0],)) - else: - bias = None - weight = np.random.randint(0, 4, size=weight_shape) - - def conv(x): - return hnp.conv2d(x, weight, bias, strides=strides, dilations=dilations) - - compiler_engine = compile_numpy_function( - conv, - {"x": EncryptedTensor(Integer(64, False), input_shape)}, - [np.random.randint(0, 4, size=input_shape) for i in range(20)], - default_compilation_configuration, - ) - x = np.random.randint(0, 4, size=input_shape, dtype=np.uint8) - expected = conv(x) - result = compiler_engine.encrypt_run_decrypt(x) - assert (expected == result).all() diff --git a/tests/numpy/test_compile_memory_operations.py b/tests/numpy/test_compile_memory_operations.py deleted file mode 100644 index b743c8115..000000000 --- a/tests/numpy/test_compile_memory_operations.py +++ /dev/null @@ -1,265 +0,0 @@ -"""Test module for memory operations.""" - -import numpy -import pytest - -from concrete.common.data_types import UnsignedInteger -from concrete.common.values import EncryptedTensor -from concrete.numpy import compile_numpy_function - - -@pytest.mark.parametrize( - "function,parameters,inputset,test_input,expected_output", - [ - pytest.param( - lambda x: x.flatten(), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)), - }, - [numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)], - [[0, 1], [1, 2], [2, 3]], - [0, 1, 1, 2, 2, 3], - ), - pytest.param( - lambda x: x.flatten(), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)), - }, - [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)], - (numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)), - (numpy.arange(720) % 10), - ), - pytest.param( - lambda x: x.reshape((1, 3)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(3,)), - }, - [numpy.random.randint(0, 2 ** 4, size=(3,)) for _ in range(10)], - [5, 9, 1], - [[5, 9, 1]], - ), - pytest.param( - lambda x: x.reshape((3, 1)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(3,)), - }, - [numpy.random.randint(0, 2 ** 4, size=(3,)) for _ in range(10)], - [5, 9, 1], - [[5], [9], [1]], - ), - pytest.param( - lambda x: x.reshape((3, 2)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)), - }, - [numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)], - [[0, 1], [1, 2], [2, 3]], - [[0, 1], [1, 2], [2, 3]], - ), - pytest.param( - lambda x: x.reshape((3, 2)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3)), - }, - [numpy.random.randint(0, 2 ** 4, size=(2, 3)) for _ in range(10)], - [[0, 1, 1], [2, 2, 3]], - [[0, 1], [1, 2], [2, 3]], - ), - pytest.param( - lambda x: x.reshape(-1), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)), - }, - [numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)], - [[0, 1], [1, 2], [2, 3]], - [0, 1, 1, 2, 2, 3], - ), - pytest.param( - lambda x: x.reshape((2, 2, 3)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(4, 3)), - }, - [numpy.random.randint(0, 2 ** 4, size=(4, 3)) for _ in range(10)], - (numpy.arange(12) % 10).reshape((4, 3)), - (numpy.arange(12) % 10).reshape((2, 2, 3)), - ), - pytest.param( - lambda x: x.reshape((4, 3)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 2, 3)), - }, - [numpy.random.randint(0, 2 ** 4, size=(2, 2, 3)) for _ in range(10)], - (numpy.arange(12) % 10).reshape((2, 2, 3)), - (numpy.arange(12) % 10).reshape((4, 3)), - ), - pytest.param( - lambda x: x.reshape((3, 2, 2)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 4)), - }, - [numpy.random.randint(0, 2 ** 4, size=(3, 4)) for _ in range(10)], - (numpy.arange(12) % 10).reshape((3, 4)), - (numpy.arange(12) % 10).reshape((3, 2, 2)), - ), - pytest.param( - lambda x: x.reshape((3, 4)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2, 2)), - }, - [numpy.random.randint(0, 2 ** 4, size=(3, 2, 2)) for _ in range(10)], - (numpy.arange(12) % 10).reshape((3, 2, 2)), - (numpy.arange(12) % 10).reshape((3, 4)), - ), - pytest.param( - lambda x: x.reshape((5, 3, 2)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(6, 5)), - }, - [numpy.random.randint(0, 2 ** 4, size=(6, 5)) for _ in range(10)], - (numpy.arange(30) % 10).reshape((6, 5)), - (numpy.arange(30) % 10).reshape((5, 3, 2)), - ), - pytest.param( - lambda x: x.reshape((5, 6)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 5)), - }, - [numpy.random.randint(0, 2 ** 4, size=(2, 3, 5)) for _ in range(10)], - (numpy.arange(30) % 10).reshape((2, 3, 5)), - (numpy.arange(30) % 10).reshape((5, 6)), - ), - pytest.param( - lambda x: x.reshape((6, 4, 30)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)), - }, - [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)], - (numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)), - (numpy.arange(720) % 10).reshape((6, 4, 30)), - ), - pytest.param( - lambda x: x.reshape((2, 60, 6)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)), - }, - [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)], - (numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)), - (numpy.arange(720) % 10).reshape((2, 60, 6)), - ), - pytest.param( - lambda x: x.reshape((6, 6, -1)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)), - }, - [numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)], - (numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)), - (numpy.arange(144) % 10).reshape((6, 6, -1)), - ), - pytest.param( - lambda x: x.reshape((6, -1, 12)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)), - }, - [numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)], - (numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)), - (numpy.arange(144) % 10).reshape((6, -1, 12)), - ), - pytest.param( - lambda x: x.reshape((-1, 18, 4)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)), - }, - [numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)], - (numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)), - (numpy.arange(144) % 10).reshape((-1, 18, 4)), - ), - ], -) -def test_memory_operation_run_correctness( - function, - parameters, - inputset, - test_input, - expected_output, - default_compilation_configuration, - check_array_equality, -): - """ - Test correctness of results when running a compiled function with memory operators. - - e.g., - - flatten - - reshape - """ - circuit = compile_numpy_function( - function, - parameters, - inputset, - default_compilation_configuration, - ) - - actual = circuit.encrypt_run_decrypt(numpy.array(test_input, dtype=numpy.uint8)) - expected = numpy.array(expected_output, dtype=numpy.uint8) - - check_array_equality(actual, expected) - - -@pytest.mark.parametrize( - "function,parameters,inputset,error,match", - [ - pytest.param( - lambda x: x.reshape((-1, -1, 2)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4)), - }, - [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4)) for _ in range(10)], - ValueError, - "shapes are not compatible (old shape (2, 3, 4), new shape (-1, -1, 2))", - ), - pytest.param( - lambda x: x.reshape((3, -1, 3)), - { - "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4)), - }, - [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4)) for _ in range(10)], - ValueError, - "shapes are not compatible (old shape (2, 3, 4), new shape (3, -1, 3))", - ), - ], -) -def test_memory_operation_failed_compilation( - function, - parameters, - inputset, - error, - match, - default_compilation_configuration, -): - """ - Test compilation failures of compiled function with memory operations. - - e.g., - - reshape - """ - - with pytest.raises(error) as excinfo: - compile_numpy_function( - function, - parameters, - inputset, - default_compilation_configuration, - ) - - assert ( - str(excinfo.value) == match - ), f""" - -Actual Output -============= -{excinfo.value} - -Expected Output -=============== -{match} - - """ diff --git a/tests/numpy/test_compile_user_friendly_api.py b/tests/numpy/test_compile_user_friendly_api.py deleted file mode 100644 index cc48b56f3..000000000 --- a/tests/numpy/test_compile_user_friendly_api.py +++ /dev/null @@ -1,280 +0,0 @@ -"""Test file for user-friendly numpy compilation functions""" - -import numpy -import pytest - -from concrete.common.debugging import format_operation_graph -from concrete.numpy.np_fhe_compiler import NPFHECompiler - - -def complicated_topology(x, y): - """Mix x in an intricated way.""" - intermediate = x + y - x_p_1 = intermediate + 1 - x_p_2 = intermediate + 2 - x_p_3 = x_p_1 + x_p_2 - return ( - x_p_3.astype(numpy.int32), - x_p_2.astype(numpy.int32), - (x_p_2 + 3).astype(numpy.int32), - x_p_3.astype(numpy.int32) + 67, - ) - - -@pytest.mark.parametrize("input_shape", [(), (3, 1, 2)]) -def test_np_fhe_compiler_op_graph( - input_shape, default_compilation_configuration, check_array_equality -): - """Test NPFHECompiler in two subtests.""" - subtest_np_fhe_compiler_1_input_op_graph( - input_shape, default_compilation_configuration, check_array_equality - ) - subtest_np_fhe_compiler_2_inputs_op_graph( - input_shape, default_compilation_configuration, check_array_equality - ) - - -def subtest_np_fhe_compiler_1_input_op_graph( - input_shape, default_compilation_configuration, check_array_equality -): - """test for NPFHECompiler on one input function""" - - def function_to_compile(x): - return complicated_topology(x, 0) - - compiler = NPFHECompiler( - function_to_compile, - {"x": "encrypted"}, - default_compilation_configuration, - ) - - # For coverage when the OPGraph is not yet traced - compiler._patch_op_graph_input_to_accept_any_integer_input() # pylint: disable=protected-access - - assert compiler.compilation_configuration == default_compilation_configuration - assert compiler.compilation_configuration is not default_compilation_configuration - - for i in numpy.arange(5): - i = numpy.ones(input_shape, dtype=numpy.int64) * i - check_array_equality(compiler(i), function_to_compile(i)) - - # For coverage, check that we flush the inputset when we query the OPGraph - current_op_graph = compiler.op_graph - assert current_op_graph is not compiler.op_graph - assert len(compiler._current_inputset) == 0 # pylint: disable=protected-access - # For coverage, cover case where the current inputset is empty - compiler._eval_on_current_inputset() # pylint: disable=protected-access - - # Continue a bit more - for i in numpy.arange(5, 10): - i = numpy.ones(input_shape, dtype=numpy.int64) * i - check_array_equality(compiler(i), function_to_compile(i)) - - if input_shape == (): - assert ( - (got := format_operation_graph(compiler.op_graph)) - == """ %0 = 67 # ClearScalar - %1 = 2 # ClearScalar - %2 = 3 # ClearScalar - %3 = 1 # ClearScalar - %4 = x # EncryptedScalar - %5 = 0 # ClearScalar - %6 = add(%4, %5) # EncryptedScalar - %7 = add(%6, %1) # EncryptedScalar - %8 = add(%6, %3) # EncryptedScalar - %9 = astype(%7, dtype=int32) # EncryptedScalar -%10 = add(%7, %2) # EncryptedScalar -%11 = add(%8, %7) # EncryptedScalar -%12 = astype(%10, dtype=int32) # EncryptedScalar -%13 = astype(%11, dtype=int32) # EncryptedScalar -%14 = astype(%11, dtype=int32) # EncryptedScalar -%15 = add(%14, %0) # EncryptedScalar -(%13, %9, %12, %15)""" - ), got - else: - assert ( - (got := format_operation_graph(compiler.op_graph)) - == """ %0 = 67 # ClearScalar - %1 = 2 # ClearScalar - %2 = 3 # ClearScalar - %3 = 1 # ClearScalar - %4 = x # EncryptedTensor - %5 = 0 # ClearScalar - %6 = add(%4, %5) # EncryptedTensor - %7 = add(%6, %1) # EncryptedTensor - %8 = add(%6, %3) # EncryptedTensor - %9 = astype(%7, dtype=int32) # EncryptedTensor -%10 = add(%7, %2) # EncryptedTensor -%11 = add(%8, %7) # EncryptedTensor -%12 = astype(%10, dtype=int32) # EncryptedTensor -%13 = astype(%11, dtype=int32) # EncryptedTensor -%14 = astype(%11, dtype=int32) # EncryptedTensor -%15 = add(%14, %0) # EncryptedTensor -(%13, %9, %12, %15)""" - ), got - - -def subtest_np_fhe_compiler_2_inputs_op_graph( - input_shape, default_compilation_configuration, check_array_equality -): - """test for NPFHECompiler on two inputs function""" - - compiler = NPFHECompiler( - complicated_topology, - {"x": "encrypted", "y": "clear"}, - default_compilation_configuration, - ) - - # For coverage when the OPGraph is not yet traced - compiler._patch_op_graph_input_to_accept_any_integer_input() # pylint: disable=protected-access - - assert compiler.compilation_configuration == default_compilation_configuration - assert compiler.compilation_configuration is not default_compilation_configuration - - for i, j in zip(numpy.arange(5), numpy.arange(5, 10)): - i = numpy.ones(input_shape, dtype=numpy.int64) * i - j = numpy.ones(input_shape, dtype=numpy.int64) * j - check_array_equality(compiler(i, j), complicated_topology(i, j)) - - # For coverage, check that we flush the inputset when we query the OPGraph - current_op_graph = compiler.op_graph - assert current_op_graph is not compiler.op_graph - assert len(compiler._current_inputset) == 0 # pylint: disable=protected-access - # For coverage, cover case where the current inputset is empty - compiler._eval_on_current_inputset() # pylint: disable=protected-access - - # Continue a bit more - for i, j in zip(numpy.arange(5, 10), numpy.arange(5)): - i = numpy.ones(input_shape, dtype=numpy.int64) * i - j = numpy.ones(input_shape, dtype=numpy.int64) * j - check_array_equality(compiler(i, j), complicated_topology(i, j)) - - if input_shape == (): - assert ( - (got := format_operation_graph(compiler.op_graph)) - == """ %0 = 67 # ClearScalar - %1 = 2 # ClearScalar - %2 = 3 # ClearScalar - %3 = 1 # ClearScalar - %4 = x # EncryptedScalar - %5 = y # ClearScalar - %6 = add(%4, %5) # EncryptedScalar - %7 = add(%6, %1) # EncryptedScalar - %8 = add(%6, %3) # EncryptedScalar - %9 = astype(%7, dtype=int32) # EncryptedScalar -%10 = add(%7, %2) # EncryptedScalar -%11 = add(%8, %7) # EncryptedScalar -%12 = astype(%10, dtype=int32) # EncryptedScalar -%13 = astype(%11, dtype=int32) # EncryptedScalar -%14 = astype(%11, dtype=int32) # EncryptedScalar -%15 = add(%14, %0) # EncryptedScalar -(%13, %9, %12, %15)""" - ), got - else: - assert ( - (got := format_operation_graph(compiler.op_graph)) - == """ %0 = 67 # ClearScalar - %1 = 2 # ClearScalar - %2 = 3 # ClearScalar - %3 = 1 # ClearScalar - %4 = x # EncryptedTensor - %5 = y # ClearTensor - %6 = add(%4, %5) # EncryptedTensor - %7 = add(%6, %1) # EncryptedTensor - %8 = add(%6, %3) # EncryptedTensor - %9 = astype(%7, dtype=int32) # EncryptedTensor -%10 = add(%7, %2) # EncryptedTensor -%11 = add(%8, %7) # EncryptedTensor -%12 = astype(%10, dtype=int32) # EncryptedTensor -%13 = astype(%11, dtype=int32) # EncryptedTensor -%14 = astype(%11, dtype=int32) # EncryptedTensor -%15 = add(%14, %0) # EncryptedTensor -(%13, %9, %12, %15)""" - ), got - - -def remaining_inputset_size(inputset_len): - """Small function to generate test cases below for remaining inputset length.""" - return inputset_len % NPFHECompiler.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE - - -@pytest.mark.parametrize( - "inputset_len, expected_remaining_inputset_len", - [ - (42, remaining_inputset_size(42)), - (128, remaining_inputset_size(128)), - (234, remaining_inputset_size(234)), - ], -) -def test_np_fhe_compiler_auto_flush( - inputset_len, - expected_remaining_inputset_len, - default_compilation_configuration, - check_array_equality, -): - """Test the auto flush of NPFHECompiler once the inputset is 128 elements.""" - - def function_to_compile(x): - return x // 2 - - compiler = NPFHECompiler( - function_to_compile, - {"x": "encrypted"}, - default_compilation_configuration, - ) - - for i in numpy.arange(inputset_len): - check_array_equality(compiler(i), function_to_compile(i)) - - # Check the inputset was properly flushed - assert ( - len(compiler._current_inputset) # pylint: disable=protected-access - == expected_remaining_inputset_len - ) - - -def test_np_fhe_compiler_full_compilation(default_compilation_configuration, check_array_equality): - """Test the case where we generate an FHE circuit.""" - - def function_to_compile(x): - return x + 42 - - compiler = NPFHECompiler( - function_to_compile, - {"x": "encrypted"}, - default_compilation_configuration, - ) - - # For coverage - with pytest.raises(RuntimeError) as excinfo: - compiler.get_compiled_fhe_circuit() - - assert str(excinfo.value) == ( - "Requested FHECircuit but no OPGraph was compiled. " - "Did you forget to evaluate NPFHECompiler over an inputset?" - ) - - for i in numpy.arange(64): - check_array_equality(compiler(i), function_to_compile(i)) - - fhe_circuit = compiler.get_compiled_fhe_circuit() - - for i in range(64): - assert fhe_circuit.encrypt_run_decrypt(i) == function_to_compile(i) - - -def test_np_fhe_compiler_compile_on_inputset(default_compilation_configuration): - """Test the case where we generate an FHE circuit with a single call.""" - - def function_to_compile(x): - return x + 42 - - compiler = NPFHECompiler( - function_to_compile, - {"x": "encrypted"}, - default_compilation_configuration, - ) - circuit = compiler.compile_on_inputset(numpy.arange(64)) - - for i in range(64): - assert circuit.encrypt_run_decrypt(i) == function_to_compile(i) diff --git a/tests/numpy/test_np_dtypes_helpers.py b/tests/numpy/test_np_dtypes_helpers.py deleted file mode 100644 index 7f2a9cfab..000000000 --- a/tests/numpy/test_np_dtypes_helpers.py +++ /dev/null @@ -1,111 +0,0 @@ -"""Test file for numpy dtype helpers""" - -import numpy -import pytest - -from concrete.common.data_types.floats import Float -from concrete.common.data_types.integers import Integer -from concrete.numpy.np_dtypes_helpers import ( - convert_base_data_type_to_numpy_dtype, - convert_numpy_dtype_to_base_data_type, - get_base_value_for_numpy_or_python_constant_data, - get_constructor_for_numpy_or_python_constant_data, -) - - -@pytest.mark.parametrize( - "numpy_dtype,expected_common_type", - [ - pytest.param(numpy.int8, Integer(8, is_signed=True)), - pytest.param("int8", Integer(8, is_signed=True)), - pytest.param(numpy.int16, Integer(16, is_signed=True)), - pytest.param("int16", Integer(16, is_signed=True)), - pytest.param(numpy.int32, Integer(32, is_signed=True)), - pytest.param("int32", Integer(32, is_signed=True)), - pytest.param(numpy.int64, Integer(64, is_signed=True)), - pytest.param("int64", Integer(64, is_signed=True)), - pytest.param(numpy.uint8, Integer(8, is_signed=False)), - pytest.param("uint8", Integer(8, is_signed=False)), - pytest.param(numpy.uint16, Integer(16, is_signed=False)), - pytest.param("uint16", Integer(16, is_signed=False)), - pytest.param(numpy.uint32, Integer(32, is_signed=False)), - pytest.param("uint32", Integer(32, is_signed=False)), - pytest.param(numpy.uint64, Integer(64, is_signed=False)), - pytest.param("uint64", Integer(64, is_signed=False)), - pytest.param(numpy.float32, Float(32)), - pytest.param("float32", Float(32)), - pytest.param(numpy.float64, Float(64)), - pytest.param("float64", Float(64)), - pytest.param("complex64", None, marks=pytest.mark.xfail(strict=True, raises=ValueError)), - ], -) -def test_convert_numpy_dtype_to_base_data_type(numpy_dtype, expected_common_type): - """Test function for convert_numpy_dtype_to_base_data_type""" - assert convert_numpy_dtype_to_base_data_type(numpy_dtype) == expected_common_type - - -@pytest.mark.parametrize( - "common_dtype,expected_numpy_dtype", - [ - pytest.param(Integer(7, is_signed=False), numpy.uint32), - pytest.param(Integer(7, is_signed=True), numpy.int32), - pytest.param(Integer(32, is_signed=True), numpy.int32), - pytest.param(Integer(64, is_signed=True), numpy.int64), - pytest.param(Integer(32, is_signed=False), numpy.uint32), - pytest.param(Integer(64, is_signed=False), numpy.uint64), - pytest.param(Float(32), numpy.float32), - pytest.param(Float(64), numpy.float64), - pytest.param( - Integer(128, is_signed=True), - None, - marks=pytest.mark.xfail(strict=True, raises=NotImplementedError), - ), - ], -) -def test_convert_common_dtype_to_numpy_dtype(common_dtype, expected_numpy_dtype): - """Test function for convert_common_dtype_to_numpy_dtype""" - assert expected_numpy_dtype == convert_base_data_type_to_numpy_dtype(common_dtype) - - -@pytest.mark.parametrize( - "constant_data,expected_constructor", - [ - (10, int), - (42.0, float), - (numpy.int32(10), numpy.int32), - ], -) -def test_get_constructor_for_numpy_or_python_constant_data(constant_data, expected_constructor): - """Test function for get_constructor_for_numpy_or_python_constant_data""" - - assert expected_constructor == get_constructor_for_numpy_or_python_constant_data(constant_data) - - -def test_get_constructor_for_numpy_arrays(test_helpers): - """Test function for get_constructor_for_numpy_or_python_constant_data for numpy arrays.""" - - arrays = [ - numpy.array([[0, 1], [3, 4]], dtype=numpy.uint64), - numpy.array([[0, 1], [3, 4]], dtype=numpy.float64), - ] - - def get_expected_constructor(array: numpy.ndarray): - return lambda x: numpy.full(array.shape, x, dtype=array.dtype) - - expected_constructors = [get_expected_constructor(array) for array in arrays] - - for array, expected_constructor in zip(arrays, expected_constructors): - assert test_helpers.python_functions_are_equal_or_equivalent( - expected_constructor, get_constructor_for_numpy_or_python_constant_data(array) - ) - - -def test_get_base_value_for_numpy_or_python_constant_data_with_list(): - """Test function for get_base_value_for_numpy_or_python_constant_data called with list""" - - with pytest.raises( - AssertionError, - match="Unsupported constant data of type list " - "\\(if you meant to use a list as an array, please use numpy\\.array instead\\)", - ): - get_base_value_for_numpy_or_python_constant_data([1, 2, 3]) diff --git a/tests/numpy/test_np_inputset_helpers.py b/tests/numpy/test_np_inputset_helpers.py deleted file mode 100644 index 37eb25e2f..000000000 --- a/tests/numpy/test_np_inputset_helpers.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Test file for numpy inputset helpers""" - -import numpy as np -import pytest - -from concrete.common.compilation import CompilationConfiguration -from concrete.common.data_types import Float, UnsignedInteger -from concrete.common.data_types.base import BaseDataType -from concrete.common.values import BaseValue, EncryptedScalar, EncryptedTensor -from concrete.numpy.np_inputset_helpers import _generate_random_inputset - - -def test_generate_random_inputset(): - """Test function for generate_random_inputset""" - - inputset = _generate_random_inputset( - { - "x1": EncryptedScalar(UnsignedInteger(4)), - "x2": EncryptedTensor(UnsignedInteger(4), shape=(2, 3)), - "x3": EncryptedScalar(Float(64)), - "x4": EncryptedTensor(Float(64), shape=(3, 2)), - }, - CompilationConfiguration(random_inputset_samples=15), - ) - - assert isinstance(inputset, list) - assert len(inputset) == 15 - - for sample in inputset: - assert isinstance(sample, tuple) - assert len(sample) == 4 - - assert isinstance(sample[0], int) - assert 0 <= sample[0] < 2 ** 4 - - assert isinstance(sample[1], np.ndarray) - assert sample[1].dtype == np.uint64 - assert sample[1].shape == (2, 3) - assert (sample[1] >= 0).all() - assert (sample[1] < 2 ** 4).all() - - assert isinstance(sample[2], float) - assert 0 <= sample[2] < 1 - - assert isinstance(sample[3], np.ndarray) - assert sample[3].dtype == np.float64 - assert sample[3].shape == (3, 2) - assert (sample[3] >= 0).all() - assert (sample[3] < 1).all() - - -def test_fail_generate_random_inputset(): - """Test function for failed generate_random_inputset""" - - class MockDtype(BaseDataType): - """Unsupported dtype to check error messages""" - - def __eq__(self, o: object) -> bool: - return False - - def __str__(self): - return "MockDtype" - - class MockValue(BaseValue): - """Unsupported value to check error messages""" - - def __init__(self): - super().__init__(MockDtype(), is_encrypted=True) - - def __eq__(self, other: object) -> bool: - return False - - def __str__(self): - return "MockValue" - - with pytest.raises(ValueError): - try: - _generate_random_inputset( - {"x": MockValue()}, - CompilationConfiguration(random_inputset_samples=15), - ) - except Exception as error: - expected = "Random inputset cannot be generated for MockValue parameters" - assert str(error) == expected - raise - - with pytest.raises(ValueError): - try: - _generate_random_inputset( - {"x": EncryptedScalar(MockDtype())}, - CompilationConfiguration(random_inputset_samples=15), - ) - except Exception as error: - expected = "Random inputset cannot be generated for parameters of type MockDtype" - assert str(error) == expected - raise diff --git a/tests/numpy/test_np_mlir_converter.py b/tests/numpy/test_np_mlir_converter.py deleted file mode 100644 index a8c6eb5ff..000000000 --- a/tests/numpy/test_np_mlir_converter.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Test file for numpy mlir converter""" - -import math - -import numpy -import pytest - -import concrete.numpy as hnp -from concrete.common.representation.intermediate import GenericFunction -from concrete.numpy.np_mlir_converter import generate_deduplicated_tables - - -def multi_tlu_func(x, cst): - """Multi TLU function""" - y = x + cst - return y.astype(numpy.int32) - - -RESNET_BIGGEST_SHAPE = (64, 112, 112) -RESNET_BIGGEST_SIZE = math.prod(RESNET_BIGGEST_SHAPE) - - -@pytest.mark.parametrize( - "function,expected_number_of_tables", - [ - ( - lambda x: multi_tlu_func(x, numpy.zeros(RESNET_BIGGEST_SHAPE, dtype=numpy.float64)), - 1, - ), - ( - lambda x: multi_tlu_func( - x, - numpy.arange(RESNET_BIGGEST_SIZE, dtype=numpy.float64).reshape( - RESNET_BIGGEST_SHAPE - ), - ), - RESNET_BIGGEST_SIZE, - ), - ], -) -def test_generate_deduplicated_tables( - function, expected_number_of_tables, default_compilation_configuration -): - """Test function for generate_deduplicated_tables""" - op_graph = hnp.compile_numpy_function_into_op_graph_and_measure_bounds( - function, - {"x": hnp.EncryptedTensor(hnp.Integer(7, False), RESNET_BIGGEST_SHAPE)}, - (i * numpy.ones(RESNET_BIGGEST_SHAPE, dtype=numpy.int32) for i in range(128)), - default_compilation_configuration, - ) - - univariate_function_nodes = [ - node for node in op_graph.graph.nodes() if isinstance(node, GenericFunction) - ] - - assert len(univariate_function_nodes) == 1 - - tlu_node = univariate_function_nodes[0] - - deduplication_result = generate_deduplicated_tables( - tlu_node, op_graph.get_ordered_preds(tlu_node) - ) - - assert len(deduplication_result) == expected_number_of_tables - - -def test_deduplicated_tables_correctness(default_compilation_configuration): - """Check the deduplicated tables are the expected ones""" - - tensor_shape = (2, 2) - - op_graph = hnp.compile_numpy_function_into_op_graph_and_measure_bounds( - lambda x: multi_tlu_func(x, numpy.arange(4, dtype=numpy.float64).reshape(tensor_shape)), - {"x": hnp.EncryptedTensor(hnp.Integer(2, False), tensor_shape)}, - (i * numpy.ones(tensor_shape, dtype=numpy.int32) for i in range(4)), - default_compilation_configuration, - ) - - univariate_function_nodes = [ - node for node in op_graph.graph.nodes() if isinstance(node, GenericFunction) - ] - - assert len(univariate_function_nodes) == 1 - - tlu_node = univariate_function_nodes[0] - - deduplication_result = generate_deduplicated_tables( - tlu_node, op_graph.get_ordered_preds(tlu_node) - ) - - expected_result = tuple( - ( - numpy.arange(i, 4 + i, dtype=numpy.int32), - [ - numpy.unravel_index(i, tensor_shape), - ], - ) - for i in range(4) - ) - - assert len(deduplication_result) == len(expected_result) - for computed_array, computed_idx in deduplication_result: - for expected_array, expected_idx in expected_result: - if numpy.array_equal(computed_array, expected_array) and computed_idx == expected_idx: - break - else: - raise AssertionError( - f"Could not find {(computed_array, computed_idx)} " - f"in expected_result: {expected_result}" - ) diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py deleted file mode 100644 index b4050a3de..000000000 --- a/tests/numpy/test_tracing.py +++ /dev/null @@ -1,991 +0,0 @@ -"""Test file for numpy tracing""" - -import inspect - -import networkx as nx -import numpy -import pytest - -from concrete.common.data_types.dtypes_helpers import broadcast_shapes -from concrete.common.data_types.floats import Float -from concrete.common.data_types.integers import Integer -from concrete.common.debugging import format_operation_graph -from concrete.common.representation import intermediate as ir -from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor -from concrete.numpy import tracing - -OPERATIONS_TO_TEST = [ir.Add, ir.Sub, ir.Mul] - - -@pytest.mark.parametrize( - "operation", - OPERATIONS_TO_TEST, -) -@pytest.mark.parametrize( - "x", - [ - pytest.param(EncryptedScalar(Integer(64, is_signed=False)), id="x: Encrypted uint"), - pytest.param( - EncryptedScalar(Integer(64, is_signed=True)), - id="x: Encrypted int", - ), - pytest.param( - ClearScalar(Integer(64, is_signed=False)), - id="x: Clear uint", - ), - pytest.param( - ClearScalar(Integer(64, is_signed=True)), - id="x: Clear int", - ), - ], -) -@pytest.mark.parametrize( - "y", - [ - pytest.param(EncryptedScalar(Integer(64, is_signed=False)), id="y: Encrypted uint"), - pytest.param( - EncryptedScalar(Integer(64, is_signed=True)), - id="y: Encrypted int", - ), - pytest.param( - ClearScalar(Integer(64, is_signed=False)), - id="y: Clear uint", - ), - pytest.param( - ClearScalar(Integer(64, is_signed=True)), - id="y: Clear int", - ), - ], -) -def test_numpy_tracing_binary_op(operation, x, y, test_helpers): - "Test numpy tracing a binary operation (in the supported ops)" - - # Remark that the functions here have a common structure (which is - # 2x op y), such that creating further the ref_graph is easy, by - # hand - def simple_add_function(x, y): - z = x + x - return z + y - - def simple_sub_function(x, y): - z = x + x - return z - y - - def simple_mul_function(x, y): - z = x + x - return z * y - - assert operation in OPERATIONS_TO_TEST, f"unknown operation {operation}" - if operation == ir.Add: - function_to_compile = simple_add_function - elif operation == ir.Sub: - function_to_compile = simple_sub_function - elif operation == ir.Mul: - function_to_compile = simple_mul_function - - op_graph = tracing.trace_numpy_function(function_to_compile, {"x": x, "y": y}) - - ref_graph = nx.MultiDiGraph() - - input_x = ir.Input(x, input_name="x", program_input_idx=0) - input_y = ir.Input(y, input_name="y", program_input_idx=1) - - add_node_z = ir.Add( - ( - input_x.outputs[0], - input_x.outputs[0], - ) - ) - - returned_final_node = operation( - ( - add_node_z.outputs[0], - input_y.outputs[0], - ) - ) - - ref_graph.add_node(input_x) - ref_graph.add_node(input_y) - ref_graph.add_node(add_node_z) - ref_graph.add_node(returned_final_node) - - ref_graph.add_edge(input_x, add_node_z, input_idx=0, output_idx=0) - ref_graph.add_edge(input_x, add_node_z, input_idx=1, output_idx=0) - - ref_graph.add_edge(add_node_z, returned_final_node, input_idx=0, output_idx=0) - ref_graph.add_edge(input_y, returned_final_node, input_idx=1, output_idx=0) - - assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph) - - -def test_numpy_tracing_tensors(): - "Test numpy tracing tensors" - - def all_operations(x): - intermediate = x + numpy.array([[1, 2], [3, 4]]) - intermediate = numpy.array([[5, 6], [7, 8]]) + intermediate - - intermediate = numpy.array([[100, 200], [300, 400]]) - intermediate - intermediate = intermediate - numpy.array([[10, 20], [30, 40]]) - - intermediate = intermediate * numpy.array([[1, 2], [2, 1]]) - intermediate = numpy.array([[2, 1], [1, 2]]) * intermediate - - return intermediate - - op_graph = tracing.trace_numpy_function( - all_operations, {"x": EncryptedTensor(Integer(32, True), shape=(2, 2))} - ) - - expected = """ %0 = [[2 1] [1 2]] # ClearTensor - %1 = [[1 2] [2 1]] # ClearTensor - %2 = [[10 20] [30 40]] # ClearTensor - %3 = [[100 200] [300 400]] # ClearTensor - %4 = [[5 6] [7 8]] # ClearTensor - %5 = x # EncryptedTensor - %6 = [[1 2] [3 4]] # ClearTensor - %7 = add(%5, %6) # EncryptedTensor - %8 = add(%4, %7) # EncryptedTensor - %9 = sub(%3, %8) # EncryptedTensor -%10 = sub(%9, %2) # EncryptedTensor -%11 = mul(%10, %1) # EncryptedTensor -%12 = mul(%0, %11) # EncryptedTensor -return %12""" # noqa: E501 - - assert format_operation_graph(op_graph) == expected, format_operation_graph(op_graph) - - -def test_numpy_explicit_tracing_tensors(): - "Test numpy tracing tensors using explicit operations" - - def all_explicit_operations(x): - intermediate = numpy.add(x, numpy.array([[1, 2], [3, 4]])) - intermediate = numpy.add(numpy.array([[5, 6], [7, 8]]), intermediate) - - intermediate = numpy.subtract(numpy.array([[100, 200], [300, 400]]), intermediate) - intermediate = numpy.subtract(intermediate, numpy.array([[10, 20], [30, 40]])) - - intermediate = numpy.multiply(intermediate, numpy.array([[1, 2], [2, 1]])) - intermediate = numpy.multiply(numpy.array([[2, 1], [1, 2]]), intermediate) - - return intermediate - - op_graph = tracing.trace_numpy_function( - all_explicit_operations, {"x": EncryptedTensor(Integer(32, True), shape=(2, 2))} - ) - - expected = """ %0 = [[2 1] [1 2]] # ClearTensor - %1 = [[1 2] [2 1]] # ClearTensor - %2 = [[10 20] [30 40]] # ClearTensor - %3 = [[100 200] [300 400]] # ClearTensor - %4 = [[5 6] [7 8]] # ClearTensor - %5 = x # EncryptedTensor - %6 = [[1 2] [3 4]] # ClearTensor - %7 = add(%5, %6) # EncryptedTensor - %8 = add(%4, %7) # EncryptedTensor - %9 = sub(%3, %8) # EncryptedTensor -%10 = sub(%9, %2) # EncryptedTensor -%11 = mul(%10, %1) # EncryptedTensor -%12 = mul(%0, %11) # EncryptedTensor -return %12""" # noqa: E501 - - assert format_operation_graph(op_graph) == expected - - -@pytest.mark.parametrize( - "x_shape,y_shape", - [ - pytest.param((), ()), - pytest.param((3,), ()), - pytest.param((3,), (1,)), - pytest.param((3,), (2,), marks=pytest.mark.xfail(raises=AssertionError, strict=True)), - pytest.param((3,), (3,)), - pytest.param((2, 3), ()), - pytest.param((2, 3), (1,)), - pytest.param((2, 3), (2,), marks=pytest.mark.xfail(raises=AssertionError, strict=True)), - pytest.param((2, 3), (3,)), - pytest.param((2, 3), (1, 1)), - pytest.param((2, 3), (2, 1)), - pytest.param((2, 3), (3, 1), marks=pytest.mark.xfail(raises=AssertionError, strict=True)), - pytest.param((2, 3), (1, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)), - pytest.param((2, 3), (2, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)), - pytest.param((2, 3), (3, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)), - pytest.param((2, 3), (1, 3)), - pytest.param((2, 3), (2, 3)), - pytest.param((2, 3), (3, 3), marks=pytest.mark.xfail(raises=AssertionError, strict=True)), - pytest.param((2, 1, 3), (1, 1, 1)), - pytest.param((2, 1, 3), (1, 4, 1)), - pytest.param((2, 1, 3), (2, 4, 3)), - ], -) -def test_numpy_tracing_broadcasted_tensors(x_shape, y_shape): - """Test numpy tracing broadcasted tensors""" - - def f(x, y): - return x + y - - op_graph = tracing.trace_numpy_function( - f, - { - "x": EncryptedTensor(Integer(3, True), shape=x_shape), - "y": EncryptedTensor(Integer(3, True), shape=y_shape), - }, - ) - - assert op_graph.input_nodes[0].outputs[0].shape == x_shape - assert op_graph.input_nodes[1].outputs[0].shape == y_shape - assert op_graph.output_nodes[0].outputs[0].shape == broadcast_shapes(x_shape, y_shape) - - -@pytest.mark.parametrize( - "function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples", - [ - ( - lambda x: x.astype(numpy.int32), - Integer(32, is_signed=True), - [ - (14, numpy.int32(14)), - (1.5, numpy.int32(1)), - (2.0, numpy.int32(2)), - (-1.5, numpy.int32(-1)), - (2 ** 31 - 1, numpy.int32(2 ** 31 - 1)), - (-(2 ** 31), numpy.int32(-(2 ** 31))), - ], - ), - ( - lambda x: x.astype(numpy.uint32), - Integer(32, is_signed=False), - [ - (14, numpy.uint32(14)), - (1.5, numpy.uint32(1)), - (2.0, numpy.uint32(2)), - (2 ** 32 - 1, numpy.uint32(2 ** 32 - 1)), - ], - ), - ( - lambda x: x.astype(numpy.int64), - Integer(64, is_signed=True), - [ - (14, numpy.int64(14)), - (1.5, numpy.int64(1)), - (2.0, numpy.int64(2)), - (-1.5, numpy.int64(-1)), - (2 ** 63 - 1, numpy.int64(2 ** 63 - 1)), - (-(2 ** 63), numpy.int64(-(2 ** 63))), - ], - ), - ( - lambda x: x.astype(numpy.uint64), - Integer(64, is_signed=False), - [ - (14, numpy.uint64(14)), - (1.5, numpy.uint64(1)), - (2.0, numpy.uint64(2)), - (2 ** 64 - 1, numpy.uint64(2 ** 64 - 1)), - ], - ), - ( - lambda x: x.astype(numpy.float64), - Float(64), - [ - (14, numpy.float64(14.0)), - (1.5, numpy.float64(1.5)), - (2.0, numpy.float64(2.0)), - (-1.5, numpy.float64(-1.5)), - ], - ), - ( - lambda x: x.astype(numpy.float32), - Float(32), - [ - (14, numpy.float32(14.0)), - (1.5, numpy.float32(1.5)), - (2.0, numpy.float32(2.0)), - (-1.5, numpy.float32(-1.5)), - ], - ), - ], -) -def test_tracing_astype( - function_to_trace, op_graph_expected_output_type, input_and_expected_output_tuples -): - """Test function for NPTracer.astype""" - for input_, expected_output in input_and_expected_output_tuples: - input_value = ( - EncryptedScalar(Integer(64, is_signed=True)) - if isinstance(input_, int) - else EncryptedScalar(Float(64)) - ) - - op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value}) - output_node = op_graph.output_nodes[0] - assert op_graph_expected_output_type == output_node.outputs[0].dtype - - node_results = op_graph.evaluate({0: numpy.array(input_)}) - evaluated_output = node_results[output_node] - assert evaluated_output.dtype == expected_output.dtype - assert expected_output == evaluated_output - - -def test_tracing_astype_single_element_array_corner_case(check_array_equality): - """Test corner case where an array could be transformed to its scalar element""" - a = numpy.array([1], dtype=numpy.float64) - - op_graph = tracing.trace_numpy_function( - lambda x: x.astype(numpy.int32), {"x": EncryptedTensor(Float(64), (1,))} - ) - - eval_result = op_graph(a) - check_array_equality(eval_result, numpy.array([1], dtype=numpy.int32)) - - -@pytest.mark.parametrize( - "function_to_trace,inputs,expected_output_node,expected_output_value", - [ - pytest.param( - lambda x, y: numpy.dot(x, y), - { - "x": EncryptedTensor(Integer(7, is_signed=False), shape=(10,)), - "y": EncryptedTensor(Integer(7, is_signed=False), shape=(10,)), - }, - ir.Dot, - EncryptedScalar(Integer(32, False)), - ), - pytest.param( - lambda x, y: numpy.dot(x, y), - { - "x": EncryptedTensor(Float(64), shape=(10,)), - "y": EncryptedTensor(Float(64), shape=(10,)), - }, - ir.Dot, - EncryptedScalar(Float(64)), - ), - pytest.param( - lambda x, y: numpy.dot(x, y), - { - "x": ClearTensor(Integer(64, is_signed=True), shape=(6,)), - "y": ClearTensor(Integer(64, is_signed=True), shape=(6,)), - }, - ir.Dot, - ClearScalar(Integer(64, is_signed=True)), - ), - pytest.param( - lambda x: numpy.dot(x, numpy.array([1, 2, 3, 4, 5], dtype=numpy.int64)), - { - "x": EncryptedTensor(Integer(64, is_signed=True), shape=(5,)), - }, - ir.Dot, - EncryptedScalar(Integer(64, True)), - ), - pytest.param( - lambda x: x.dot(numpy.array([1, 2, 3, 4, 5], dtype=numpy.int64)), - { - "x": EncryptedTensor(Integer(64, is_signed=True), shape=(5,)), - }, - ir.Dot, - EncryptedScalar(Integer(64, True)), - ), - ], -) -def test_trace_numpy_dot(function_to_trace, inputs, expected_output_node, expected_output_value): - """Function to test dot tracing""" - - op_graph = tracing.trace_numpy_function(function_to_trace, inputs) - - assert len(op_graph.output_nodes) == 1 - assert isinstance(op_graph.output_nodes[0], expected_output_node) - assert len(op_graph.output_nodes[0].outputs) == 1 - assert op_graph.output_nodes[0].outputs[0] == expected_output_value - - -@pytest.mark.parametrize("np_function", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC) -def test_nptracer_get_tracing_func_for_np_functions(np_function): - """Test NPTracer get_tracing_func_for_np_function""" - - expected_tracing_func = tracing.NPTracer.UFUNC_ROUTING[np_function] - - assert tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func - - -def test_nptracer_get_tracing_func_for_np_functions_not_implemented(): - """Check NPTracer in case of not-implemented function""" - with pytest.raises(NotImplementedError) as excinfo: - tracing.NPTracer.get_tracing_func_for_np_function(numpy.conjugate) - - assert "NPTracer does not yet manage the following func: conjugate" in str(excinfo.value) - - -@pytest.mark.parametrize( - "operation,exception_type,match", - [ - pytest.param( - lambda x: x + "fail", - TypeError, - "unsupported operand type(s) for +: 'NPTracer' and 'str'", - ), - pytest.param( - lambda x: "fail" + x, - TypeError, - 'can only concatenate str (not "NPTracer") to str', - ), - pytest.param( - lambda x: x - "fail", - TypeError, - "unsupported operand type(s) for -: 'NPTracer' and 'str'", - ), - pytest.param( - lambda x: "fail" - x, - TypeError, - "unsupported operand type(s) for -: 'str' and 'NPTracer'", - ), - pytest.param( - lambda x: x * "fail", - TypeError, - "can't multiply sequence by non-int of type 'NPTracer'", - ), - pytest.param( - lambda x: "fail" * x, - TypeError, - "can't multiply sequence by non-int of type 'NPTracer'", - ), - pytest.param( - lambda x: x / "fail", - TypeError, - "unsupported operand type(s) for /: 'NPTracer' and 'str'", - ), - pytest.param( - lambda x: "fail" / x, - TypeError, - "unsupported operand type(s) for /: 'str' and 'NPTracer'", - ), - pytest.param( - lambda x: x // "fail", - TypeError, - "unsupported operand type(s) for //: 'NPTracer' and 'str'", - ), - pytest.param( - lambda x: "fail" // x, - TypeError, - "unsupported operand type(s) for //: 'str' and 'NPTracer'", - ), - pytest.param( - lambda x, y: x / y, NotImplementedError, "Can't manage binary operator truediv" - ), - pytest.param( - lambda x, y: x // y, NotImplementedError, "Can't manage binary operator floordiv" - ), - ], -) -def test_nptracer_unsupported_operands(operation, exception_type, match): - """Test cases where NPTracer cannot be used with other operands.""" - tracers = [ - tracing.NPTracer([], ir.Input(ClearScalar(Integer(32, True)), param_name, idx), 0) - for idx, param_name in enumerate(inspect.signature(operation).parameters.keys()) - ] - - with pytest.raises(exception_type) as exc_info: - _ = operation(*tracers) - - assert match in str(exc_info) - - -def subtest_tracing_calls( - function_to_trace, - input_value_input_and_expected_output_tuples, - check_array_equality, -): - """Test memory function managed by GenericFunction node of the form numpy.something""" - for input_value, input_, expected_output in input_value_input_and_expected_output_tuples: - - op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value}) - output_node = op_graph.output_nodes[0] - - node_results = op_graph.evaluate({0: input_}) - evaluated_output = node_results[output_node] - assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output) - check_array_equality(evaluated_output, expected_output) - - -@pytest.mark.parametrize( - "function_to_trace,input_value_input_and_expected_output_tuples", - [ - ( - lambda x: numpy.transpose(x), - [ - ( - EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)), - numpy.arange(4).reshape(2, 2), - numpy.array([[0, 2], [1, 3]]), - ), - ( - EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)), - numpy.arange(4, 8).reshape(2, 2), - numpy.array([[4, 6], [5, 7]]), - ), - ( - EncryptedTensor(Integer(6, is_signed=False), shape=()), - numpy.int64(42), - numpy.int64(42), - ), - ], - ), - ( - lambda x: numpy.transpose(x) + 42, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(42, 57).reshape(3, 5).transpose(), - ), - ( - EncryptedTensor(Integer(6, is_signed=False), shape=()), - numpy.int64(42), - numpy.int64(84), - ), - ], - ), - ( - lambda x: numpy.ravel(x), - [ - ( - EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)), - numpy.arange(4), - numpy.array([0, 1, 2, 3]), - ), - ( - EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)), - numpy.arange(4).reshape(2, 2), - numpy.array([0, 1, 2, 3]), - ), - ( - EncryptedTensor(Integer(6, is_signed=False), shape=()), - numpy.int64(42), - numpy.array([42], dtype=numpy.int64), - ), - ], - ), - ( - lambda x: numpy.reshape(x, (5, 3)) + 42, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(42, 57).reshape(5, 3), - ), - ], - ), - ], -) -def test_tracing_numpy_calls( - function_to_trace, - input_value_input_and_expected_output_tuples, - check_array_equality, -): - """Test memory function managed by GenericFunction node of the form numpy.something""" - subtest_tracing_calls( - function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality - ) - - -@pytest.mark.parametrize( - "function_to_trace,input_value_input_and_expected_output_tuples", - [ - ( - lambda x: x.transpose() + 42, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(42, 57).reshape(3, 5).transpose(), - ), - ( - EncryptedTensor(Integer(6, is_signed=False), shape=()), - numpy.int64(42), - numpy.int64(84), - ), - ], - ), - ( - lambda x: x.ravel(), - [ - ( - EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)), - numpy.arange(4), - numpy.array([0, 1, 2, 3]), - ), - ( - EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)), - numpy.arange(4).reshape(2, 2), - numpy.array([0, 1, 2, 3]), - ), - ( - EncryptedTensor(Integer(6, is_signed=False), shape=()), - numpy.int64(42), - numpy.array([42], dtype=numpy.int64), - ), - ], - ), - ( - lambda x: x.reshape((5, 3)) + 42, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(42, 57).reshape(5, 3), - ), - ], - ), - pytest.param( - lambda x: x.reshape((5, 3)), - [ - ( - EncryptedTensor(Integer(6, is_signed=False), shape=()), - numpy.int64(42), - None, - ) - ], - marks=pytest.mark.xfail(strict=True, raises=ValueError), - ), - pytest.param( - lambda x: x.flatten(), - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(15), - ) - ], - ), - pytest.param( - lambda x: abs(x), - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(15).reshape(3, 5), - ) - ], - ), - pytest.param( - lambda x: +x, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(15).reshape(3, 5), - ) - ], - ), - pytest.param( - lambda x: -x, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - (numpy.arange(15).reshape(3, 5)) * (-1), - ) - ], - ), - pytest.param( - lambda x: ~x, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(15).reshape(3, 5).__invert__(), - ) - ], - ), - pytest.param( - lambda x: x << 3, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(15).reshape(3, 5) * 8, - ) - ], - ), - pytest.param( - lambda x: x >> 1, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(15).reshape(3, 5) // 2, - ) - ], - ), - pytest.param( - lambda x: 2 << x, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5) % 8, - 2 << (numpy.arange(15).reshape(3, 5) % 8), - ) - ], - ), - pytest.param( - lambda x: 256 >> x, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5) % 8, - 256 >> (numpy.arange(15).reshape(3, 5) % 8), - ) - ], - ), - pytest.param( - lambda x: x > 4, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(15).reshape(3, 5) > 4, - ) - ], - ), - pytest.param( - lambda x: x < 5, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(15).reshape(3, 5) < 5, - ) - ], - ), - pytest.param( - lambda x: x <= 7, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(15).reshape(3, 5) <= 7, - ) - ], - ), - pytest.param( - lambda x: x >= 9, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(15).reshape(3, 5) >= 9, - ) - ], - ), - pytest.param( - lambda x: x == 11, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(15).reshape(3, 5) == 11, - ) - ], - ), - pytest.param( - lambda x: x != 12, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(15).reshape(3, 5) != 12, - ) - ], - ), - # Remove misplaced-comparison-constant because precisely, we want to be sure it works fine - # pylint: disable=misplaced-comparison-constant - pytest.param( - lambda x: 4 > x, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - 4 > numpy.arange(15).reshape(3, 5), - ) - ], - ), - pytest.param( - lambda x: 5 < x, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - 5 < numpy.arange(15).reshape(3, 5), - ) - ], - ), - pytest.param( - lambda x: 7 <= x, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - 7 <= numpy.arange(15).reshape(3, 5), - ) - ], - ), - pytest.param( - lambda x: 9 >= x, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - 9 >= numpy.arange(15).reshape(3, 5), - ) - ], - ), - pytest.param( - lambda x: 11 == x, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - 11 == numpy.arange(15).reshape(3, 5), - ) - ], - ), - pytest.param( - lambda x: 12 != x, - [ - ( - EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - 12 != numpy.arange(15).reshape(3, 5), - ) - ], - ), - # pylint: enable=misplaced-comparison-constant - ( - lambda x: x & 11, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.array([i & 11 for i in range(15)]).reshape(3, 5), - ), - ], - ), - ( - lambda x: 13 & x, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.array([i & 13 for i in range(15)]).reshape(3, 5), - ), - ], - ), - ( - lambda x: x | 6, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.array([i | 6 for i in range(15)]).reshape(3, 5), - ), - ], - ), - ( - lambda x: 30 | x, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.array([i | 30 for i in range(15)]).reshape(3, 5), - ), - ], - ), - ( - lambda x: x ^ 91, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.array([i ^ 91 for i in range(15)]).reshape(3, 5), - ), - ], - ), - ( - lambda x: 115 ^ x, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.array([i ^ 115 for i in range(15)]).reshape(3, 5), - ), - ], - ), - ( - lambda x: x % 11, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.array([i % 11 for i in range(15)]).reshape(3, 5), - ), - ], - ), - ( - lambda x: 150 % (x + 1), - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.array([150 % (i + 1) for i in range(15)]).reshape(3, 5), - ), - ], - ), - ( - lambda x: x ** 2, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.array([i ** 2 for i in range(15)]).reshape(3, 5), - ), - ], - ), - ( - lambda x: 2 ** x, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5) % 7, - numpy.array([2 ** (i % 7) for i in range(15)]).reshape(3, 5), - ), - ], - ), - ], -) -def test_tracing_ndarray_calls( - function_to_trace, - input_value_input_and_expected_output_tuples, - check_array_equality, -): - """Test memory function managed by GenericFunction node of the form ndarray.something""" - subtest_tracing_calls( - function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality - ) - - -@pytest.mark.parametrize( - "lambda_f,params", - [ - ( - lambda x: numpy.reshape(x, (5, 3)), - { - "x": EncryptedTensor(Integer(2, is_signed=False), shape=(7, 5)), - }, - ), - ], -) -def test_errors_with_generic_function(lambda_f, params): - "Test some errors with generic function" - with pytest.raises(ValueError) as excinfo: - tracing.trace_numpy_function(lambda_f, params) - - assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value) diff --git a/tests/numpy/test_tracing_calls.py b/tests/numpy/test_tracing_calls.py deleted file mode 100644 index 473a6bc85..000000000 --- a/tests/numpy/test_tracing_calls.py +++ /dev/null @@ -1,309 +0,0 @@ -"""Test file for numpy tracing""" - -from copy import deepcopy - -import numpy -import pytest - -from concrete.common.data_types.floats import Float -from concrete.common.data_types.integers import Integer -from concrete.common.representation import intermediate as ir -from concrete.common.values import EncryptedScalar, EncryptedTensor -from concrete.numpy import tracing - -OPERATIONS_TO_TEST = [ir.Add, ir.Sub, ir.Mul] - -# Functions from tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC, whose output -# is a float64, whatever the input type -LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64 = set( - [ - numpy.arccos, - numpy.arccosh, - numpy.arcsin, - numpy.arcsinh, - numpy.arctan, - numpy.arctanh, - numpy.cbrt, - numpy.ceil, - numpy.cos, - numpy.cosh, - numpy.deg2rad, - numpy.degrees, - numpy.exp, - numpy.exp2, - numpy.expm1, - numpy.fabs, - numpy.floor, - numpy.log, - numpy.log10, - numpy.log1p, - numpy.log2, - numpy.rad2deg, - numpy.radians, - numpy.rint, - numpy.sin, - numpy.sinh, - numpy.spacing, - numpy.sqrt, - numpy.tan, - numpy.tanh, - numpy.trunc, - ] -) - -# Functions from tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC, whose output -# is a boolean, whatever the input type -LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL = set( - [ - numpy.isfinite, - numpy.isinf, - numpy.isnan, - numpy.signbit, - numpy.logical_not, - ] -) - - -@pytest.mark.parametrize( - "inputs,expected_output_node", - [ - pytest.param( - {"x": EncryptedScalar(Integer(7, is_signed=False))}, - ir.GenericFunction, - ), - pytest.param( - {"x": EncryptedScalar(Integer(32, is_signed=True))}, - ir.GenericFunction, - ), - pytest.param( - {"x": EncryptedScalar(Integer(64, is_signed=True))}, - ir.GenericFunction, - ), - pytest.param( - {"x": EncryptedScalar(Integer(128, is_signed=True))}, - ir.GenericFunction, - marks=pytest.mark.xfail(strict=True, raises=NotImplementedError), - ), - pytest.param( - {"x": EncryptedScalar(Float(64))}, - ir.GenericFunction, - ), - ], -) -@pytest.mark.parametrize( - "function_to_trace_def", - [f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1], -) -def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def): - """Function to trace supported numpy ufuncs""" - - # We really need a lambda (because numpy functions are not playing - # nice with inspect.signature), but pylint and flake8 are not happy - # with it - # pylint: disable=cell-var-from-loop - function_to_trace = lambda x: function_to_trace_def(x) # noqa: E731 - # pylint: enable=cell-var-from-loop - - op_graph = tracing.trace_numpy_function(function_to_trace, inputs) - - assert len(op_graph.output_nodes) == 1 - assert isinstance(op_graph.output_nodes[0], expected_output_node) - assert len(op_graph.output_nodes[0].outputs) == 1 - - if function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64: - assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Float(64)) - elif function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL: - - # Boolean function - assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Integer(8, is_signed=False)) - else: - - # Function keeping more or less input type - input_node_type = inputs["x"] - - expected_output_node_type = deepcopy(input_node_type) - - expected_output_node_type.dtype.bit_width = max( - expected_output_node_type.dtype.bit_width, 32 - ) - - assert op_graph.output_nodes[0].outputs[0] == expected_output_node_type - - -@pytest.mark.parametrize("np_function", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC) -def test_nptracer_get_tracing_func_for_np_functions(np_function): - """Test NPTracer get_tracing_func_for_np_function""" - - expected_tracing_func = tracing.NPTracer.UFUNC_ROUTING[np_function] - - assert tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func - - -def subtest_tracing_calls( - function_to_trace, - input_value_input_and_expected_output_tuples, - check_array_equality, -): - """Test memory function managed by GenericFunction node of the form numpy.something""" - for input_value, input_, expected_output in input_value_input_and_expected_output_tuples: - - op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value}) - output_node = op_graph.output_nodes[0] - - node_results = op_graph.evaluate({0: input_}) - evaluated_output = node_results[output_node] - assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output) - check_array_equality(evaluated_output, expected_output) - - -@pytest.mark.parametrize( - "function_to_trace,input_value_input_and_expected_output_tuples", - [ - ( - lambda x: numpy.transpose(x), - [ - ( - EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)), - numpy.arange(4).reshape(2, 2), - numpy.array([[0, 2], [1, 3]]), - ), - ( - EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)), - numpy.arange(4, 8).reshape(2, 2), - numpy.array([[4, 6], [5, 7]]), - ), - ( - EncryptedTensor(Integer(6, is_signed=False), shape=()), - numpy.int64(42), - numpy.int64(42), - ), - ], - ), - ( - lambda x: numpy.transpose(x) + 42, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(42, 57).reshape(3, 5).transpose(), - ), - ( - EncryptedTensor(Integer(6, is_signed=False), shape=()), - numpy.int64(42), - numpy.int64(84), - ), - ], - ), - ( - lambda x: numpy.ravel(x), - [ - ( - EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)), - numpy.arange(4), - numpy.array([0, 1, 2, 3]), - ), - ( - EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)), - numpy.arange(4).reshape(2, 2), - numpy.array([0, 1, 2, 3]), - ), - ( - EncryptedTensor(Integer(6, is_signed=False), shape=()), - numpy.int64(42), - numpy.array([42], dtype=numpy.int64), - ), - ], - ), - ( - lambda x: numpy.reshape(x, (5, 3)) + 42, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(42, 57).reshape(5, 3), - ), - ], - ), - ], -) -def test_tracing_numpy_calls( - function_to_trace, - input_value_input_and_expected_output_tuples, - check_array_equality, -): - """Test memory function managed by GenericFunction node of the form numpy.something""" - subtest_tracing_calls( - function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality - ) - - -@pytest.mark.parametrize( - "function_to_trace,input_value_input_and_expected_output_tuples", - [ - ( - lambda x: x.transpose() + 42, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(42, 57).reshape(3, 5).transpose(), - ), - ( - EncryptedTensor(Integer(6, is_signed=False), shape=()), - numpy.int64(42), - numpy.int64(84), - ), - ], - ), - ( - lambda x: x.ravel(), - [ - ( - EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)), - numpy.arange(4), - numpy.array([0, 1, 2, 3]), - ), - ( - EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)), - numpy.arange(4).reshape(2, 2), - numpy.array([0, 1, 2, 3]), - ), - ( - EncryptedTensor(Integer(6, is_signed=False), shape=()), - numpy.int64(42), - numpy.array([42], dtype=numpy.int64), - ), - ], - ), - ( - lambda x: x.reshape((5, 3)) + 42, - [ - ( - EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), - numpy.arange(15).reshape(3, 5), - numpy.arange(42, 57).reshape(5, 3), - ), - ], - ), - pytest.param( - lambda x: x.reshape((5, 3)), - [ - ( - EncryptedTensor(Integer(6, is_signed=False), shape=()), - numpy.int64(42), - None, - ) - ], - marks=pytest.mark.xfail(strict=True, raises=ValueError), - ), - ], -) -def test_tracing_ndarray_calls( - function_to_trace, - input_value_input_and_expected_output_tuples, - check_array_equality, -): - """Test memory function managed by GenericFunction node of the form ndarray.something""" - subtest_tracing_calls( - function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality - ) diff --git a/tests/numpy/test_tracing_failures.py b/tests/numpy/test_tracing_failures.py deleted file mode 100644 index 7a01f8418..000000000 --- a/tests/numpy/test_tracing_failures.py +++ /dev/null @@ -1,184 +0,0 @@ -"""Test file for numpy tracing""" - -import inspect - -import numpy -import pytest - -from concrete.common.data_types.integers import Integer -from concrete.common.representation import intermediate as ir -from concrete.common.values import ClearScalar, EncryptedScalar, EncryptedTensor -from concrete.numpy import tracing - - -@pytest.mark.parametrize( - "inputs", - [ - pytest.param( - {"x": EncryptedScalar(Integer(32, is_signed=True))}, - ), - ], -) -@pytest.mark.parametrize( - "function_to_trace", - # We really need a lambda (because numpy functions are not playing - # nice with inspect.signature), but pylint is not happy - # with it - [lambda x: numpy.invert(x), lambda x: numpy.bitwise_not(x)], -) -def test_trace_numpy_fails_for_invert(inputs, function_to_trace): - """Check we catch calls to numpy.invert and tell user to change their code""" - - with pytest.raises(RuntimeError) as excinfo: - tracing.trace_numpy_function(function_to_trace, inputs) - - assert ( - "NPTracer does not manage the following func: invert. Please replace by calls to " - "bitwise_xor with appropriate mask" in str(excinfo.value) - ) - - -def test_trace_numpy_ufuncs_not_supported(): - """Testing a failure case of trace_numpy_function""" - inputs = {"x": EncryptedScalar(Integer(128, is_signed=True))} - - # We really need a lambda (because numpy functions are not playing - # nice with inspect.signature), but pylint and flake8 are not happy - # with it - function_to_trace = lambda x: numpy.add.reduce(x) # noqa: E731 - - with pytest.raises(NotImplementedError) as excinfo: - tracing.trace_numpy_function(function_to_trace, inputs) - - assert "Only __call__ method is supported currently" in str(excinfo.value) - - -def test_trace_numpy_ufuncs_no_kwargs_no_extra_args(): - """Test a case where kwargs are not allowed and too many inputs are passed""" - inputs = { - "x": EncryptedScalar(Integer(32, is_signed=True)), - "y": EncryptedScalar(Integer(32, is_signed=True)), - "z": EncryptedScalar(Integer(32, is_signed=True)), - } - - # We really need a lambda (because numpy functions are not playing - # nice with inspect.signature), but pylint and flake8 are not happy - # with it - function_to_trace = lambda x, y, z: numpy.add(x, y, z) # noqa: E731 - - with pytest.raises(AssertionError) as excinfo: - tracing.trace_numpy_function(function_to_trace, inputs) - - # numpy only passes ufunc.nin tracers so the extra arguments are passed as kwargs - assert "**kwargs are currently not supported for numpy ufuncs, ufunc: add" in str(excinfo.value) - - # We really need a lambda (because numpy functions are not playing - # nice with inspect.signature), but pylint and flake8 are not happy - # with it - function_to_trace = lambda x, y, z: numpy.add(x, y, out=z) # noqa: E731 - - with pytest.raises(AssertionError) as excinfo: - tracing.trace_numpy_function(function_to_trace, inputs) - - assert "**kwargs are currently not supported for numpy ufuncs, ufunc: add" in str(excinfo.value) - - -def test_nptracer_get_tracing_func_for_np_functions_not_implemented(): - """Check NPTracer in case of not-implemented function""" - with pytest.raises(NotImplementedError) as excinfo: - tracing.NPTracer.get_tracing_func_for_np_function(numpy.conjugate) - - assert "NPTracer does not yet manage the following func: conjugate" in str(excinfo.value) - - -@pytest.mark.parametrize( - "operation,exception_type,match", - [ - pytest.param( - lambda x: x + "fail", - TypeError, - "unsupported operand type(s) for +: 'NPTracer' and 'str'", - ), - pytest.param( - lambda x: "fail" + x, - TypeError, - 'can only concatenate str (not "NPTracer") to str', - ), - pytest.param( - lambda x: x - "fail", - TypeError, - "unsupported operand type(s) for -: 'NPTracer' and 'str'", - ), - pytest.param( - lambda x: "fail" - x, - TypeError, - "unsupported operand type(s) for -: 'str' and 'NPTracer'", - ), - pytest.param( - lambda x: x * "fail", - TypeError, - "can't multiply sequence by non-int of type 'NPTracer'", - ), - pytest.param( - lambda x: "fail" * x, - TypeError, - "can't multiply sequence by non-int of type 'NPTracer'", - ), - pytest.param( - lambda x: x / "fail", - TypeError, - "unsupported operand type(s) for /: 'NPTracer' and 'str'", - ), - pytest.param( - lambda x: "fail" / x, - TypeError, - "unsupported operand type(s) for /: 'str' and 'NPTracer'", - ), - pytest.param( - lambda x: x // "fail", - TypeError, - "unsupported operand type(s) for //: 'NPTracer' and 'str'", - ), - pytest.param( - lambda x: "fail" // x, - TypeError, - "unsupported operand type(s) for //: 'str' and 'NPTracer'", - ), - pytest.param( - lambda x, y: x / y, NotImplementedError, "Can't manage binary operator truediv" - ), - pytest.param( - lambda x, y: x // y, NotImplementedError, "Can't manage binary operator floordiv" - ), - ], -) -def test_nptracer_unsupported_operands(operation, exception_type, match): - """Test cases where NPTracer cannot be used with other operands.""" - tracers = [ - tracing.NPTracer([], ir.Input(ClearScalar(Integer(32, True)), param_name, idx), 0) - for idx, param_name in enumerate(inspect.signature(operation).parameters.keys()) - ] - - with pytest.raises(exception_type) as exc_info: - _ = operation(*tracers) - - assert match in str(exc_info) - - -@pytest.mark.parametrize( - "lambda_f,params", - [ - ( - lambda x: numpy.reshape(x, (5, 3)), - { - "x": EncryptedTensor(Integer(2, is_signed=False), shape=(7, 5)), - }, - ), - ], -) -def test_errors_with_generic_function(lambda_f, params): - "Test some errors with generic function" - with pytest.raises(ValueError) as excinfo: - tracing.trace_numpy_function(lambda_f, params) - - assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)