chore: remove the old implementation and its tests

This commit is contained in:
Umut
2022-04-04 13:15:37 +02:00
parent c0ab74ec5a
commit 3239a147e6
84 changed files with 0 additions and 18657 deletions

View File

@@ -1,5 +0,0 @@
"""Top level import."""
# Do not modify, this is to have a compatible namespace package
# https://packaging.python.org/en/latest/guides/packaging-namespace-packages/
# #pkg-resources-style-namespace-packages
# NOTE(review): pkg_resources-style namespace packages are deprecated in favor of
# native namespace packages — confirm all distributions sharing this namespace
# agree before migrating.
__import__("pkg_resources").declare_namespace(__name__) # pragma: no cover

View File

@@ -1,3 +0,0 @@
"""Module for shared data structures and code."""
from . import compilation, data_types, debugging, representation, values
from .common_helpers import check_op_graph_is_integer_program, is_a_power_of_2

View File

@@ -1,2 +0,0 @@
"""Bounds measurement module."""
from . import inputset_eval

View File

@@ -1,260 +0,0 @@
"""Code to evaluate the IR graph on inputsets."""
import sys
from functools import partial
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
from ..compilation import CompilationConfiguration
from ..data_types.dtypes_helpers import (
get_base_value_for_python_constant_data,
is_data_type_compatible_with,
)
from ..debugging import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import IntermediateNode
def _check_input_coherency(
    input_to_check: Dict[str, Any],
    parameters: Dict[str, Any],
    get_base_value_for_constant_data_func: Callable[[Any], Any],
):
    """Collect warnings about mismatches between `input_to_check` and `parameters`.

    Each entry of the input is converted to its base value through
    `get_base_value_for_constant_data_func`, then compared (shape equality and
    dtype compatibility) against the hinted base value of the parameter with
    the same name.

    Args:
        input_to_check (Dict[str, Any]): input to check coherency of
        parameters (Dict[str, Any]): parameters and their expected base values
        get_base_value_for_constant_data_func (Callable[[Any], Any]):
            function to get the base value of python objects.

    Returns:
        List[str]: List of warnings about the coherency
    """
    coherency_warnings = []
    for name, data in input_to_check.items():
        expected_base_value = parameters[name]
        # The helper returns a value factory; instantiate it with the same
        # encryption status as the hint so only shape/dtype can differ.
        value_factory = get_base_value_for_constant_data_func(data)
        observed_base_value = value_factory(is_encrypted=expected_base_value.is_encrypted)
        shapes_match = observed_base_value.shape == expected_base_value.shape
        dtypes_compatible = is_data_type_compatible_with(
            observed_base_value.dtype, expected_base_value.dtype
        )
        if not (shapes_match and dtypes_compatible):
            coherency_warnings.append(
                f"expected {str(expected_base_value)} "
                f"for parameter `{name}` "
                f"but got {str(observed_base_value)} "
                f"which is not compatible"
            )
    return coherency_warnings
def _print_input_coherency_warnings(
    current_input_index: int,
    current_input_data: Dict[int, Any],
    parameters: Dict[str, Any],
    parameter_index_to_parameter_name: Dict[int, str],
    get_base_value_for_constant_data_func: Callable[[Any], Any],
    treat_warnings_as_errors: bool,
):
    """Report coherency problems of `current_input_data` against `parameters`.

    Problems are written to stderr as warnings, or raised as a single
    ValueError when `treat_warnings_as_errors` is set.

    Args:
        current_input_index (int): index of the current input on the inputset
        current_input_data (Dict[int, Any]): input to print coherency warnings of
        parameters (Dict[str, Any]): parameters and their expected base values
        parameter_index_to_parameter_name (Dict[int, str]):
            dict to get parameter names from parameter indices
        get_base_value_for_constant_data_func (Callable[[Any], Any]):
            function to get the base value of python objects.
        treat_warnings_as_errors (bool): raise instead of warning when True

    Returns:
        None
    """
    # Re-key the positional data by parameter name to match `parameters`.
    named_data = {
        parameter_index_to_parameter_name[idx]: value
        for idx, value in current_input_data.items()
    }
    problems = _check_input_coherency(
        named_data,
        parameters,
        get_base_value_for_constant_data_func,
    )
    messages = [
        (
            f"Input #{current_input_index} (0-indexed) "
            f"is not coherent with the hinted parameters ({problem})\n"
        )
        for problem in problems
    ]
    if not messages:
        return
    if treat_warnings_as_errors:
        raise ValueError(", ".join(messages))
    for message in messages:
        sys.stderr.write(f"Warning: {message}")
def eval_op_graph_bounds_on_inputset(
    op_graph: OPGraph,
    inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
    compilation_configuration: CompilationConfiguration,
    min_func: Callable[[Any, Any], Any] = min,
    max_func: Callable[[Any, Any], Any] = max,
    get_base_value_for_constant_data_func: Callable[
        [Any], Any
    ] = get_base_value_for_python_constant_data,
    prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None,
) -> Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]:
    """Evaluate the bounds with an inputset.

    Evaluate the bounds for all output values of the operators in the graph op_graph over data
    coming from the inputset.

    Args:
        op_graph (OPGraph): The graph for which we want to determine the bounds
        inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset over which
            op_graph is evaluated. It needs to be an iterable on tuples (can be single values in
            case the function has only one argument) which are of the same length than the number
            of parameters in the function, and in the same order than these same parameters
        compilation_configuration (CompilationConfiguration): Configuration object to use
            during determining input checking strategy
        min_func (Callable[[Any, Any], Any], optional): custom function to compute a scalar
            minimum between two values that can be encountered during evaluation (for e.g. numpy
            or torch tensors). Defaults to min.
        max_func (Callable[[Any, Any], Any], optional): custom function to compute a scalar
            maximum between two values that can be encountered during evaluation (for e.g. numpy
            or torch tensors). Defaults to max.
        get_base_value_for_constant_data_func (Callable[[Any], Any], optional): custom function
            to compute the base value of a python object.
        prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional):
            Bounds and samples from a previous run. Defaults to None.

    Raises:
        StopIteration: presumably if the inputset is empty (the first `next` call below is
            unguarded) — TODO confirm callers always pass a non-empty inputset.

    Returns:
        Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: number of inputs in the inputset and
            a dict containing the bounds for each node from op_graph, stored with the node
            as key and a dict with keys "min", "max" and "sample" as value.
    """
    num_input_nodes = len(op_graph.input_nodes)

    def check_inputset_input_len_is_valid(data_to_check):
        # Only check if there are more than one input node, otherwise accept the value as the sole
        # argument passed to the OPGraph for evaluation
        if num_input_nodes > 1:
            assert_true(
                len(data_to_check) == num_input_nodes,
                (
                    f"Got input data from inputset of len: {len(data_to_check)}, "
                    f"function being evaluated has {num_input_nodes} inputs, please make "
                    f"sure your data generator returns valid tuples of input values"
                ),
            )

    def generate_input_values_dict(input_data) -> Dict[int, Any]:
        # Map each positional input to its input-node index for OPGraph.evaluate.
        if num_input_nodes > 1:
            return dict(enumerate(input_data))
        # TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/772
        # update this to support tuple in case of 1-input functions accepting tuples
        assert_true(
            not isinstance(input_data, tuple),
            "Tuples are unsupported for single input inputset evaluation",
        )
        return {0: input_data}

    # TODO: do we want to check coherence between the input data type and the corresponding Input
    # ir node expected data type ? Not considering bit_width as they may not make sense at this
    # stage
    parameter_index_to_parameter_name = {
        index: input_node.input_name for index, input_node in op_graph.input_nodes.items()
    }
    parameters = {
        input_node.input_name: input_node.inputs[0] for input_node in op_graph.input_nodes.values()
    }

    inputset_iterator = iter(inputset)
    inputset_size = 0

    # The first sample is handled separately: it is always coherency-checked and
    # it seeds the bounds dictionary below.
    current_input_data = generate_input_values_dict(next(inputset_iterator))
    inputset_size += 1
    check_inputset_input_len_is_valid(current_input_data.values())
    _print_input_coherency_warnings(
        inputset_size - 1,
        current_input_data,
        parameters,
        parameter_index_to_parameter_name,
        get_base_value_for_constant_data_func,
        compilation_configuration.treat_warnings_as_errors,
    )
    first_output = op_graph.evaluate(current_input_data)

    prev_node_bounds_and_samples = (
        {} if prev_node_bounds_and_samples is None else prev_node_bounds_and_samples
    )

    def get_previous_value_for_key_or_default_for_dict(
        dict_: Dict[IntermediateNode, Dict[str, Any]],
        node: IntermediateNode,
        key: str,
        default: Any,
    ) -> Any:
        # Returns dict_[node][key], falling back to `default` when either the
        # node or the key is missing.
        return_value = default
        previous_value_dict = dict_.get(node, None)
        if previous_value_dict is not None:
            return_value = previous_value_dict.get(key, default)
        return return_value

    get_previous_value_for_key_or_default = partial(
        get_previous_value_for_key_or_default_for_dict, prev_node_bounds_and_samples
    )

    # We evaluate the min and max func to be able to resolve the tensors min and max rather than
    # having the tensor itself as the stored min and max values.
    # As we don't know the integrity of prev_node_bounds_and_samples we make sure we can
    # populate the new node_bounds_and_samples
    node_bounds_and_samples = {
        node: {
            "min": min_func(value, get_previous_value_for_key_or_default(node, "min", value)),
            "max": max_func(value, get_previous_value_for_key_or_default(node, "max", value)),
            "sample": get_previous_value_for_key_or_default(node, "sample", value),
        }
        for node, value in first_output.items()
    }

    for input_data in inputset_iterator:
        inputset_size += 1
        current_input_data = generate_input_values_dict(input_data)
        check_inputset_input_len_is_valid(current_input_data.values())
        # Checking every input is optional because it can be costly on large inputsets.
        if compilation_configuration.check_every_input_in_inputset:
            _print_input_coherency_warnings(
                inputset_size - 1,
                current_input_data,
                parameters,
                parameter_index_to_parameter_name,
                get_base_value_for_constant_data_func,
                compilation_configuration.treat_warnings_as_errors,
            )
        current_output = op_graph.evaluate(current_input_data)
        # Widen the running min/max per node; "sample" keeps its seeded value.
        for node, value in current_output.items():
            node_bounds_and_samples[node]["min"] = min_func(
                node_bounds_and_samples[node]["min"], value
            )
            node_bounds_and_samples[node]["max"] = max_func(
                node_bounds_and_samples[node]["max"], value
            )
    return inputset_size, node_bounds_and_samples

View File

@@ -1,67 +0,0 @@
"""File to hold some helper code."""
from typing import List, Optional
from .data_types.integers import Integer
from .debugging import assert_true
from .operator_graph import OPGraph
from .representation.intermediate import IntermediateNode
def is_a_power_of_2(x: int) -> bool:
    """Check if an integer is a power of two.

    Args:
        x (int): Number to check

    Returns:
        bool: True if the number is a power of two
    """
    # A positive power of two has exactly one set bit, so clearing the lowest
    # set bit (x & (x - 1)) must leave zero.
    # https://stackoverflow.com/questions/57025836/how-to-check-if-a-given-number-is-a-power-of-two
    if x <= 0:
        return False
    return x & (x - 1) == 0
def ir_nodes_has_integer_input_and_output(node: IntermediateNode) -> bool:
    """Check if an ir node has Integer inputs and outputs.

    Args:
        node (IntermediateNode): Node to check

    Returns:
        bool: True if all input and output values hold Integers
    """
    # Inputs and outputs are checked with the same predicate, so chain them.
    values_to_check = list(node.inputs) + list(node.outputs)
    return all(isinstance(value.dtype, Integer) for value in values_to_check)
# This check makes sense as long as the compiler backend only manages integers, to be removed in the
# long run probably
def check_op_graph_is_integer_program(
    op_graph: OPGraph,
    offending_nodes_out: Optional[List[IntermediateNode]] = None,
) -> bool:
    """Check if an op_graph inputs, outputs and intermediate values are Integers.

    Args:
        op_graph (OPGraph): The OPGraph to check
        offending_nodes_out (Optional[List[IntermediateNode]]): Optionally pass a list that will
            be populated with offending nodes, the list will be cleared before being filled

    Returns:
        bool: True if inputs, outputs and intermediate values are Integers, False otherwise
    """
    if offending_nodes_out is None:
        offending = []
    else:
        offending = offending_nodes_out
    assert_true(
        isinstance(offending, list),
        f"offending_nodes_out must be a list, got {type(offending_nodes_out)}",
    )
    # The caller-supplied list is cleared so it only reflects this run.
    offending.clear()
    for node in op_graph.graph.nodes():
        if not ir_nodes_has_integer_input_and_output(node):
            offending.append(node)
    return not offending

View File

@@ -1,4 +0,0 @@
"""Module for compilation related types."""
from .artifacts import CompilationArtifacts
from .configuration import CompilationConfiguration

View File

@@ -1,222 +0,0 @@
"""Module for compilation artifacts."""
import inspect
import platform
import shutil
import subprocess
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union
import networkx as nx
from loguru import logger
from ..debugging import assert_true, draw_graph, format_operation_graph
from ..operator_graph import OPGraph
from ..representation.intermediate import IntermediateNode
from ..values import BaseValue
# Default location where CompilationArtifacts.export writes its files.
DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts")


class CompilationArtifacts:
    """Class that conveys information about compilation process."""

    # Directory where `export` writes the artifact files.
    output_directory: Path
    # Source of the function to compile, None until `add_function_to_compile` is called.
    source_code_of_the_function_to_compile: Optional[str]
    # Parameter name -> textual representation of its value.
    parameters_of_the_function_to_compile: Dict[str, str]
    # Graph name -> drawing produced by `draw_graph` (used as a file path on export).
    drawings_of_operation_graphs: Dict[str, str]
    # Graph name -> textual representation produced by `format_operation_graph`.
    textual_representations_of_operation_graphs: Dict[str, str]
    # The last operation graph registered via `add_operation_graph`.
    final_operation_graph: Optional[OPGraph]
    # Per-node bounds of the final graph (dicts with at least "min"/"max" keys).
    bounds_of_the_final_operation_graph: Optional[Dict[IntermediateNode, Dict[str, Any]]]
    # MLIR code of the final graph, set by `add_final_operation_graph_mlir`.
    mlir_of_the_final_operation_graph: Optional[str]

    def __init__(self, output_directory: Union[Path, str] = DEFAULT_OUTPUT_DIRECTORY):
        self.output_directory = Path(output_directory)
        self.source_code_of_the_function_to_compile = None
        self.parameters_of_the_function_to_compile = {}
        self.drawings_of_operation_graphs = {}
        self.textual_representations_of_operation_graphs = {}
        self.final_operation_graph = None
        self.bounds_of_the_final_operation_graph = None
        self.mlir_of_the_final_operation_graph = None

    def add_function_to_compile(self, function: Union[Callable, str]):
        """Add the function to compile to artifacts.

        Args:
            function (Union[Callable, str]): the function to compile or source code of it

        Returns:
            None
        """
        try:
            self.source_code_of_the_function_to_compile = (
                function if isinstance(function, str) else inspect.getsource(function)
            )
        # When using the python console we cannot use getsource, so catch that and emit an error
        except OSError:  # pragma: no cover
            function_str = function if isinstance(function, str) else function.__name__
            logger.error(f"Could not get source for function: {function_str}")
            self.source_code_of_the_function_to_compile = "unavailable"

    def add_parameter_of_function_to_compile(self, name: str, value: Union[BaseValue, str]):
        """Add a parameter of the function to compile to the artifacts.

        Args:
            name (str): name of the parameter
            value (Union[BaseValue, str]): value of the parameter or textual representation of it

        Returns:
            None
        """
        self.parameters_of_the_function_to_compile[name] = str(value)

    def add_operation_graph(self, name: str, operation_graph: OPGraph):
        """Add an operation graph to the artifacts.

        The graph added last becomes `final_operation_graph`.

        Args:
            name (str): name of the graph
            operation_graph (OPGraph): the operation graph itself

        Returns:
            None
        """
        try:
            drawing = draw_graph(operation_graph)
            self.drawings_of_operation_graphs[name] = drawing
        # Do not crash on imports ourselves for drawings if the package is not installed
        except ImportError as e:  # pragma: no cover
            if "pygraphviz" in str(e):
                pass
            else:
                raise e
        textual_representation = format_operation_graph(operation_graph)
        self.textual_representations_of_operation_graphs[name] = textual_representation
        self.final_operation_graph = operation_graph

    def add_final_operation_graph_bounds(self, bounds: Dict[IntermediateNode, Dict[str, Any]]):
        """Add the bounds of the final operation graph to the artifacts.

        Args:
            bounds (Dict[IntermediateNode, Dict[str, Any]]): the bound dictionary

        Returns:
            None
        """
        # A graph must have been registered before bounds can refer to it.
        assert_true(self.final_operation_graph is not None)
        self.bounds_of_the_final_operation_graph = bounds

    def add_final_operation_graph_mlir(self, mlir: str):
        """Add the mlir of the final operation graph to the artifacts.

        Args:
            mlir (str): the mlir code of the final operation graph

        Returns:
            None
        """
        assert_true(self.final_operation_graph is not None)
        self.mlir_of_the_final_operation_graph = mlir

    def export(self):
        """Export the artifacts to the output directory.

        Any pre-existing output directory is removed first, then environment
        info, installed packages, function source, parameters, graph drawings
        and representations, bounds and MLIR are written as separate files.

        Returns:
            None
        """
        output_directory = self.output_directory
        # Start from a clean directory so stale artifacts never survive.
        if output_directory.exists():
            shutil.rmtree(output_directory)
        output_directory.mkdir(parents=True)
        with open(output_directory.joinpath("environment.txt"), "w", encoding="utf-8") as f:
            f.write(f"{platform.platform()} {platform.version()}\n")
            f.write(f"Python {platform.python_version()}\n")
        with open(output_directory.joinpath("requirements.txt"), "w", encoding="utf-8") as f:
            # example `pip list` output
            # Package                       Version
            # ----------------------------- ---------
            # alabaster                     0.7.12
            # appdirs                       1.4.4
            # ...                           ...
            # ...                           ...
            # wrapt                         1.12.1
            # zipp                          3.5.0
            pip_process = subprocess.run(
                ["pip", "--disable-pip-version-check", "list"], stdout=subprocess.PIPE, check=True
            )
            dependencies = iter(pip_process.stdout.decode("utf-8").split("\n"))
            # skip 'Package ... Version' line
            next(dependencies)
            # skip '------- ... -------' line
            next(dependencies)
            for dependency in dependencies:
                tokens = [token for token in dependency.split(" ") if token != ""]
                if len(tokens) == 0:
                    continue
                name = tokens[0]
                version = tokens[1]
                f.write(f"{name}=={version}\n")
        if self.source_code_of_the_function_to_compile is not None:
            with open(output_directory.joinpath("function.txt"), "w", encoding="utf-8") as f:
                f.write(self.source_code_of_the_function_to_compile)
        if len(self.parameters_of_the_function_to_compile) > 0:
            with open(output_directory.joinpath("parameters.txt"), "w", encoding="utf-8") as f:
                for name, parameter in self.parameters_of_the_function_to_compile.items():
                    f.write(f"{name} :: {parameter}\n")
        drawings = self.drawings_of_operation_graphs.items()
        for index, (name, drawing_filename) in enumerate(drawings):
            identifier = CompilationArtifacts._identifier(index, name)
            shutil.copy(drawing_filename, output_directory.joinpath(f"{identifier}.png"))
        textual_representations = self.textual_representations_of_operation_graphs.items()
        for index, (name, representation) in enumerate(textual_representations):
            identifier = CompilationArtifacts._identifier(index, name)
            with open(output_directory.joinpath(f"{identifier}.txt"), "w", encoding="utf-8") as f:
                f.write(f"{representation}")
        if self.bounds_of_the_final_operation_graph is not None:
            assert_true(self.final_operation_graph is not None)
            with open(output_directory.joinpath("bounds.txt"), "w", encoding="utf-8") as f:
                # TODO:
                #   if nx.topological_sort is not deterministic between calls,
                #   the lines below will not work properly
                #   thus, we may want to change this in the future
                for index, node in enumerate(nx.topological_sort(self.final_operation_graph.graph)):
                    bounds = self.bounds_of_the_final_operation_graph.get(node)
                    assert_true(bounds is not None)
                    f.write(f"%{index} :: [{bounds.get('min')}, {bounds.get('max')}]\n")
        if self.mlir_of_the_final_operation_graph is not None:
            assert_true(self.final_operation_graph is not None)
            with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f:
                f.write(self.mlir_of_the_final_operation_graph)

    @staticmethod
    def _identifier(index, name):
        # 1-based index keeps exported files sorted in insertion order.
        return f"{index + 1}.{name}.graph"

View File

@@ -1,46 +0,0 @@
"""Module for compilation configuration."""
class CompilationConfiguration:
    """Class that allows the compilation process to be customized."""

    # Each attribute mirrors the constructor argument of the same name.
    dump_artifacts_on_unexpected_failures: bool
    enable_topological_optimizations: bool
    check_every_input_in_inputset: bool
    treat_warnings_as_errors: bool
    enable_unsafe_features: bool
    random_inputset_samples: int
    use_insecure_key_cache: bool
    auto_parallelize: bool
    loop_parallelize: bool
    dataflow_parallelize: bool

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        dump_artifacts_on_unexpected_failures: bool = True,
        enable_topological_optimizations: bool = True,
        check_every_input_in_inputset: bool = False,
        treat_warnings_as_errors: bool = False,
        enable_unsafe_features: bool = False,
        random_inputset_samples: int = 30,
        use_insecure_key_cache: bool = False,
        auto_parallelize: bool = False,
        loop_parallelize: bool = True,
        dataflow_parallelize: bool = False,
    ):
        # Store every knob as a same-named attribute; no validation is performed.
        self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures
        self.enable_topological_optimizations = enable_topological_optimizations
        self.check_every_input_in_inputset = check_every_input_in_inputset
        self.treat_warnings_as_errors = treat_warnings_as_errors
        self.enable_unsafe_features = enable_unsafe_features
        self.random_inputset_samples = random_inputset_samples
        self.use_insecure_key_cache = use_insecure_key_cache
        self.auto_parallelize = auto_parallelize
        self.loop_parallelize = loop_parallelize
        self.dataflow_parallelize = dataflow_parallelize

    # pylint: enable=too-many-arguments

    def __eq__(self, other) -> bool:
        """Two configurations are equal when all their attributes match."""
        if not isinstance(other, CompilationConfiguration):
            return False
        return self.__dict__ == other.__dict__

View File

@@ -1,4 +0,0 @@
"""Module for data types code and data structures."""
from . import dtypes_helpers, floats, integers
from .floats import Float, Float16, Float32, Float64
from .integers import Integer, SignedInteger, UnsignedInteger

View File

@@ -1,11 +0,0 @@
"""File holding code to represent data types in a program."""
from abc import ABC, abstractmethod
class BaseDataType(ABC):
    """Base class to represent a data type."""

    @abstractmethod
    def __eq__(self, o: object) -> bool:
        """No default implementation."""
        # Subclasses must define structural equality. Note that overriding
        # __eq__ without __hash__ makes concrete subclasses unhashable unless
        # they also define __hash__.

View File

@@ -1,393 +0,0 @@
"""File to hold helper functions for data types related stuff."""
from copy import deepcopy
from functools import partial
from typing import Callable, Optional, Tuple, Union, cast
from ..debugging.custom_assert import assert_true
from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue
from .base import BaseDataType
from .floats import Float
from .integers import Integer, get_bits_to_represent_value_as_integer
# Tuples of dtype classes used with isinstance checks throughout this module.
INTEGER_TYPES = (Integer,)
FLOAT_TYPES = (Float,)
# Every dtype currently supported by the helpers below.
BASE_DATA_TYPES = INTEGER_TYPES + FLOAT_TYPES
def value_is_encrypted_scalar_integer(value_to_check: BaseValue) -> bool:
    """Check that a value is an encrypted scalar of type Integer.

    Args:
        value_to_check (BaseValue): The value to check

    Returns:
        bool: True if the passed value_to_check is an encrypted scalar of type Integer
    """
    if not value_is_scalar_integer(value_to_check):
        return False
    return value_to_check.is_encrypted
def value_is_encrypted_scalar_unsigned_integer(value_to_check: BaseValue) -> bool:
    """Check that a value is an encrypted scalar of type unsigned Integer.

    Args:
        value_to_check (BaseValue): The value to check

    Returns:
        bool: True if the passed value_to_check is an encrypted scalar of type Integer and
            unsigned
    """
    if not value_is_encrypted_scalar_integer(value_to_check):
        return False
    # The cast is safe: the check above guarantees dtype is an Integer.
    return not cast(Integer, value_to_check.dtype).is_signed
def value_is_clear_scalar_integer(value_to_check: BaseValue) -> bool:
    """Check that a value is a clear scalar of type Integer.

    Args:
        value_to_check (BaseValue): The value to check

    Returns:
        bool: True if the passed value_to_check is a clear scalar of type Integer
    """
    if not value_is_scalar_integer(value_to_check):
        return False
    return value_to_check.is_clear
def value_is_scalar_integer(value_to_check: BaseValue) -> bool:
    """Check that a value is a scalar of type Integer.

    Args:
        value_to_check (BaseValue): The value to check

    Returns:
        bool: True if the passed value_to_check is a scalar of type Integer
    """
    # Only TensorValue instances carry the is_scalar flag.
    if not isinstance(value_to_check, TensorValue):
        return False
    return value_to_check.is_scalar and isinstance(value_to_check.dtype, INTEGER_TYPES)
def value_is_integer(value_to_check: BaseValue) -> bool:
    """Check that a value is of type Integer.

    Args:
        value_to_check (BaseValue): The value to check

    Returns:
        bool: True if the passed value_to_check is of type Integer
    """
    dtype = value_to_check.dtype
    return isinstance(dtype, INTEGER_TYPES)
def value_is_unsigned_integer(value_to_check: BaseValue) -> bool:
    """Check that a value is of type Integer and is unsigned.

    Args:
        value_to_check (BaseValue): The value to check

    Returns:
        bool: True if the passed value_to_check is of type Integer and is unsigned
    """
    dtype = value_to_check.dtype
    if not isinstance(dtype, INTEGER_TYPES):
        return False
    # The cast is safe: the isinstance check above narrows dtype to Integer.
    return not cast(Integer, dtype).is_signed
def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool:
    """Check that a value is an encrypted TensorValue of type Integer.

    Args:
        value_to_check (BaseValue): The value to check

    Returns:
        bool: True if the passed value_to_check is an encrypted TensorValue of type Integer
    """
    if not value_is_tensor_integer(value_to_check):
        return False
    return value_to_check.is_encrypted
def value_is_clear_tensor_integer(value_to_check: BaseValue) -> bool:
    """Check that a value is a clear TensorValue of type Integer.

    Args:
        value_to_check (BaseValue): The value to check

    Returns:
        bool: True if the passed value_to_check is a clear TensorValue of type Integer
    """
    if not value_is_tensor_integer(value_to_check):
        return False
    return value_to_check.is_clear
def value_is_tensor_integer(value_to_check: BaseValue) -> bool:
    """Check that a value is a TensorValue of type Integer.

    Args:
        value_to_check (BaseValue): The value to check

    Returns:
        bool: True if the passed value_to_check is a TensorValue of type Integer
    """
    # Scalars are excluded here; value_is_scalar_integer covers them.
    if not isinstance(value_to_check, TensorValue):
        return False
    if value_to_check.is_scalar:
        return False
    return isinstance(value_to_check.dtype, INTEGER_TYPES)
def find_type_to_hold_both_lossy(
    dtype1: BaseDataType,
    dtype2: BaseDataType,
) -> BaseDataType:
    """Determine the type that can represent both dtype1 and dtype2 separately.

    This is lossy with floating point types: when an Integer is mixed with a
    Float, the Float is returned unchanged even if it cannot represent every
    value of the Integer exactly.

    Args:
        dtype1 (BaseDataType): first dtype to hold
        dtype2 (BaseDataType): second dtype to hold

    Returns:
        BaseDataType: The dtype able to hold (potentially lossy) dtype1 and dtype2
    """
    # NOTE(review): unsupported dtypes are rejected by assert_true below; no
    # NotImplementedError is raised despite what an older docstring claimed.
    assert_true(isinstance(dtype1, BASE_DATA_TYPES), f"Unsupported dtype1: {type(dtype1)}")
    assert_true(isinstance(dtype2, BASE_DATA_TYPES), f"Unsupported dtype2: {type(dtype2)}")
    type_to_return: BaseDataType
    if isinstance(dtype1, Integer) and isinstance(dtype2, Integer):
        d1_signed = dtype1.is_signed
        d2_signed = dtype2.is_signed
        max_bits = max(dtype1.bit_width, dtype2.bit_width)
        if d1_signed and d2_signed:
            type_to_return = Integer(max_bits, is_signed=True)
        elif not d1_signed and not d2_signed:
            type_to_return = Integer(max_bits, is_signed=False)
        elif d1_signed and not d2_signed:
            # 2 is unsigned, if it has the bigger bit_width, we need a signed integer that can hold
            # it, so add 1 bit of sign to its bit_width
            if dtype2.bit_width >= dtype1.bit_width:
                new_bit_width = dtype2.bit_width + 1
                type_to_return = Integer(new_bit_width, is_signed=True)
            else:
                type_to_return = Integer(dtype1.bit_width, is_signed=True)
        elif not d1_signed and d2_signed:
            # Same as above, with 1 and 2 switched around
            if dtype1.bit_width >= dtype2.bit_width:
                new_bit_width = dtype1.bit_width + 1
                type_to_return = Integer(new_bit_width, is_signed=True)
            else:
                type_to_return = Integer(dtype2.bit_width, is_signed=True)
    elif isinstance(dtype1, Float) and isinstance(dtype2, Float):
        # Two floats: keep the wider one.
        max_bits = max(dtype1.bit_width, dtype2.bit_width)
        type_to_return = Float(max_bits)
    elif isinstance(dtype1, Float):
        # Integer mixed with a float: the float wins (lossy for big integers).
        type_to_return = deepcopy(dtype1)
    elif isinstance(dtype2, Float):
        type_to_return = deepcopy(dtype2)
    return type_to_return
def mix_tensor_values_determine_holding_dtype(
    value1: TensorValue,
    value2: TensorValue,
) -> TensorValue:
    """Return mixed TensorValue with data type able to hold both value1 and value2 dtypes.

    Returns a TensorValue that would result from computation on both value1 and value2 while
    determining the data type able to hold both value1 and value2 data type (this can be lossy
    with floats).

    Args:
        value1 (TensorValue): first TensorValue to mix.
        value2 (TensorValue): second TensorValue to mix.

    Returns:
        TensorValue: The resulting mixed TensorValue with data type able to hold both value1 and
            value2 dtypes.
    """
    assert_true(
        isinstance(value1, TensorValue), f"Unsupported value1: {value1}, expected TensorValue"
    )
    assert_true(
        isinstance(value2, TensorValue), f"Unsupported value2: {value2}, expected TensorValue"
    )
    broadcasted_shape = broadcast_shapes(value1.shape, value2.shape)
    assert_true(
        broadcasted_shape is not None,
        (
            f"Tensors have incompatible shapes which is not supported.\n"
            f"value1: {value1.shape}, value2: {value2.shape}"
        ),
    )
    assert broadcasted_shape is not None  # this is to make mypy happy
    holding_dtype = find_type_to_hold_both_lossy(value1.dtype, value2.dtype)
    # The mix is encrypted as soon as either operand is encrypted.
    result_class = EncryptedTensor if value1.is_encrypted or value2.is_encrypted else ClearTensor
    return result_class(dtype=holding_dtype, shape=broadcasted_shape)
def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> BaseValue:
    """Return mixed BaseValue with data type able to hold both value1 and value2 dtypes.

    Returns a BaseValue that would result from computation on both value1 and value2 while
    determining the data type able to hold both value1 and value2 data type (this can be lossy
    with floats). Supports only mixing instances from the same class.

    Args:
        value1 (BaseValue): first BaseValue to mix.
        value2 (BaseValue): second BaseValue to mix.

    Raises:
        ValueError: raised if the BaseValue is not one of (TensorValue)

    Returns:
        BaseValue: The resulting mixed BaseValue with data type able to hold both value1 and value2
            dtypes.
    """
    assert_true(
        (value1.__class__ == value2.__class__),
        f"Cannot mix values of different types: value 1:{type(value1)}, value2: {type(value2)}",
    )
    # Only TensorValue mixing is implemented; anything else is rejected.
    if not (isinstance(value1, TensorValue) and isinstance(value2, TensorValue)):
        raise ValueError(
            f"{mix_values_determine_holding_dtype.__name__} does not support value {type(value1)}"
        )
    return mix_tensor_values_determine_holding_dtype(value1, value2)
def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType:
    """Determine the BaseDataType to hold the input constant data.

    Args:
        constant_data (Union[int, float]): The constant data for which to determine the
            corresponding BaseDataType.

    Returns:
        BaseDataType: The corresponding BaseDataType
    """
    assert_true(
        isinstance(constant_data, (int, float)),
        f"Unsupported constant data of type {type(constant_data)}",
    )
    result: BaseDataType
    if isinstance(constant_data, int):
        # Negative constants require a signed integer representation.
        signed = constant_data < 0
        result = Integer(get_bits_to_represent_value_as_integer(constant_data, signed), signed)
    else:
        # The assert above guarantees this branch only sees floats.
        result = Float(64)
    return result
def get_base_value_for_python_constant_data(
    constant_data: Union[int, float]
) -> Callable[..., BaseValue]:
    """Wrap the BaseDataType to hold the input constant data in BaseValue partial.

    The returned object can then be instantiated as an Encrypted or Clear version
    by calling it with the proper arguments forwarded to the BaseValue `__init__` function.

    Args:
        constant_data (Union[int, float]): The constant data for which to determine the
            corresponding Value.

    Returns:
        Callable[..., BaseValue]: A partial object that will return the proper BaseValue when
            called with `is_encrypted` as keyword argument (forwarded to the BaseValue `__init__`
            method).
    """
    holding_dtype = get_base_data_type_for_python_constant_data(constant_data)
    # Scalars are modeled as zero-dimensional tensors, hence shape=().
    return partial(TensorValue, dtype=holding_dtype, shape=())
def get_constructor_for_python_constant_data(constant_data: Union[int, float]):
    """Get the constructor for the passed python constant data.

    Args:
        constant_data (Union[int, float]): The data for which we want to determine the type
            constructor.

    Returns:
        type: the python type of `constant_data` (e.g. `int` for `1`, `float` for `1.0`)
    """
    return type(constant_data)
def is_data_type_compatible_with(
    dtype: BaseDataType,
    other: BaseDataType,
) -> bool:
    """Determine whether dtype is compatible with other.

    `dtype` being compatible with `other` means `other` can hold every value of `dtype`
    (e.g., uint2 is compatible with uint4 and int4)
    (e.g., int2 is compatible with int4 but not with uint4)

    Note that this function is not symmetric.
    (e.g., uint2 is compatible with uint4, but uint4 is not compatible with uint2)

    Args:
        dtype (BaseDataType): dtype to check compatibility
        other (BaseDataType): dtype to check compatibility against

    Returns:
        bool: Whether the dtype is compatible with other or not
    """
    # `other` can hold `dtype` iff combining both does not require widening `other`.
    return other == find_type_to_hold_both_lossy(dtype, other)
def broadcast_shapes(shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> Optional[Tuple[int, ...]]:
    """Broadcast two shapes into a single shape.

    We are mimicing the exact semantics of broadcasting in numpy.
    You can learn more about it here: https://numpy.org/doc/stable/user/theory.broadcasting.html

    Args:
        shape1 (Tuple[int, ...]): first shape to broadcast
        shape2 (Tuple[int, ...]): second shape to broadcast

    Returns:
        Optional[Tuple[int, ...]]: None if the shapes are not broadcastable else broadcasted shape
    """
    # walk the trailing dimensions together (broadcasting aligns shapes on the right)
    result = []
    for size1, size2 in zip(shape1[::-1], shape2[::-1]):
        # numpy semantics: two sizes are compatible iff they are equal or one of
        # them is 1; a 0-sized dimension only broadcasts against 0 or 1
        # (e.g., np.broadcast_shapes((0,), (5,)) raises in numpy)
        if size1 != size2 and size1 != 1 and size2 != 1:
            return None
        # broadcasting against an empty dimension yields an empty dimension
        result.append(0 if size1 == 0 or size2 == 0 else max(size1, size2))
    # the leading (unmatched) dimensions of the longer shape pass through as-is
    longer = shape1 if len(shape1) > len(shape2) else shape2
    leading = longer[: len(longer) - len(result)]
    return leading + tuple(reversed(result))

View File

@@ -1,33 +0,0 @@
"""This file holds the definitions for floating point types."""
from functools import partial
from ..debugging.custom_assert import assert_true
from . import base
class Float(base.BaseDataType):
    """Class representing a float."""

    # bit_width is the total number of bits used to represent a floating point number,
    # including sign bit, exponent and mantissa
    bit_width: int

    def __init__(self, bit_width: int) -> None:
        super().__init__()
        assert_true(bit_width in (16, 32, 64), "Only 16, 32 and 64 bits floats are supported")
        self.bit_width = bit_width

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}<{self.bit_width} bits>"

    def __str__(self) -> str:
        return f"float{self.bit_width}"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, self.__class__) and self.bit_width == other.bit_width

    def __hash__(self) -> int:
        # defining __eq__ alone sets __hash__ to None, which would make dtype
        # instances unusable as dict keys / set members; hash consistently with __eq__
        return hash((self.__class__, self.bit_width))


# convenience factories for the supported widths
Float16 = partial(Float, 16)
Float32 = partial(Float, 32)
Float64 = partial(Float, 64)

View File

@@ -1,144 +0,0 @@
"""This file holds the definitions for integer types."""
import math
from typing import Any, Iterable
from ..debugging.custom_assert import assert_true
from . import base
class Integer(base.BaseDataType):
    """Class representing an integer."""

    # total number of bits, including the sign bit for signed integers
    bit_width: int
    # whether the integer is signed (two's complement range) or unsigned
    is_signed: bool

    def __init__(self, bit_width: int, is_signed: bool) -> None:
        super().__init__()
        assert_true(bit_width > 0, "bit_width must be > 0")
        self.bit_width = bit_width
        self.is_signed = is_signed

    def __repr__(self) -> str:
        signed_str = "signed" if self.is_signed else "unsigned"
        return f"{self.__class__.__name__}<{signed_str}, {self.bit_width} bits>"

    def __str__(self) -> str:
        return f"{('int' if self.is_signed else 'uint')}{self.bit_width}"

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, self.__class__)
            and self.bit_width == other.bit_width
            and self.is_signed == other.is_signed
        )

    def __hash__(self) -> int:
        # defining __eq__ alone sets __hash__ to None, which would make dtype
        # instances unusable as dict keys / set members; hash consistently with __eq__
        return hash((self.__class__, self.bit_width, self.is_signed))

    def min_value(self) -> int:
        """Minimum value representable by the Integer."""
        if self.is_signed:
            return -(2 ** (self.bit_width - 1))
        return 0

    def max_value(self) -> int:
        """Maximum value representable by the Integer."""
        if self.is_signed:
            return 2 ** (self.bit_width - 1) - 1
        return 2 ** self.bit_width - 1

    def can_represent_value(self, value_to_represent: int) -> bool:
        """Check if a value is representable by the Integer.

        Args:
            value_to_represent (int): Value to check

        Returns:
            bool: True if the value can be represented by this integer
        """
        return self.min_value() <= value_to_represent <= self.max_value()
def create_signed_integer(bit_width: int) -> Integer:
    """Create a signed integer.

    Args:
        bit_width (int): width of the integer

    Returns:
        Integer: A signed integer with the requested bit_width
    """
    return Integer(bit_width, is_signed=True)


# constructor-like alias
SignedInteger = create_signed_integer
def create_unsigned_integer(bit_width: int) -> Integer:
    """Create an unsigned integer.

    Args:
        bit_width (int): width of the integer

    Returns:
        Integer: An unsigned integer with the requested bit_width
    """
    return Integer(bit_width, is_signed=False)


# constructor-like alias
UnsignedInteger = create_unsigned_integer
def make_integer_to_hold(values: Iterable[Any], force_signed: bool) -> Integer:
    """Return an Integer able to hold all values, it is possible to force the Integer to be signed.

    Args:
        values (Iterable[Any]): The values to hold
        force_signed (bool): Set to True to force the result to be a signed Integer

    Returns:
        Integer: The Integer able to hold values
    """
    smallest = min(values)
    largest = max(values)
    # signedness is required as soon as a negative value is present
    signed = force_signed or smallest < 0
    # the required width is driven by whichever extreme needs the most bits
    width = max(
        get_bits_to_represent_value_as_integer(smallest, signed),
        get_bits_to_represent_value_as_integer(largest, signed),
    )
    return Integer(width, is_signed=signed)
def get_bits_to_represent_value_as_integer(value: Any, force_signed: bool) -> int:
    """Return how many bits are required to represent a numerical Value.

    Args:
        value (Any): The value for which we want to know how many bits are required.
        force_signed (bool): Set to True to force the result to be a signed integer.

    Returns:
        int: required amount of bits
    """
    if value < 0:
        magnitude = -value
        # a magnitude above one needs ceil(log2(.)) bits plus a sign bit;
        # -1 is special-cased to two bits
        bits = math.ceil(math.log2(magnitude)) + 1 if magnitude > 1 else 2
    else:
        # 0 and 1 both fit in a single bit; larger values need ceil(log2(value + 1))
        bits = math.ceil(math.log2(value + 1)) if value > 1 else 1
    # NOTE(review): for negative values the sign bit is already counted above,
    # yet force_signed still adds one more bit — preserved as-is from the
    # original implementation.
    return bits + 1 if force_signed else bits

View File

@@ -1,4 +0,0 @@
"""Module for debugging."""
from .custom_assert import assert_true
from .drawing import draw_graph
from .formatting import format_operation_graph

View File

@@ -1,49 +0,0 @@
"""Provide some variants of assert."""
def _custom_assert(condition: bool, on_error_msg: str = "") -> None:
"""Provide a custom assert which is kept even if the optimized python mode is used.
See https://docs.python.org/3/reference/simple_stmts.html#assert for the documentation
on the classical assert function
Args:
condition(bool): the condition. If False, raise AssertionError
on_error_msg(str): optional message for precising the error, in case of error
"""
if not condition:
raise AssertionError(on_error_msg)
def assert_true(condition: bool, on_error_msg: str = ""):
    """Check that the condition holds, raising otherwise.

    Args:
        condition(bool): the condition. If False, raise AssertionError
        on_error_msg(str): optional message for precising the error, in case of error
    """
    # delegate to the optimization-proof assert
    return _custom_assert(condition, on_error_msg)
def assert_false(condition: bool, on_error_msg: str = ""):
    """Check that the condition does not hold, raising otherwise.

    Args:
        condition(bool): the condition. If True, raise AssertionError
        on_error_msg(str): optional message for precising the error, in case of error
    """
    # delegate to the optimization-proof assert with the condition negated
    return _custom_assert(not condition, on_error_msg)
def assert_not_reached(on_error_msg: str):
    """Fail unconditionally; call on code paths that must never execute.

    Args:
        on_error_msg(str): message for precising the error
    """
    return _custom_assert(False, on_error_msg)

View File

@@ -1,156 +0,0 @@
"""functions to draw the different graphs we can generate in the package, eg to debug."""
import os
import tempfile
from pathlib import Path
from typing import Optional
import matplotlib.pyplot as plt
import networkx as nx
from PIL import Image
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import (
ALL_IR_NODES,
Add,
Constant,
Conv2D,
Dot,
GenericFunction,
IndexConstant,
Input,
MatMul,
Mul,
Sub,
)
# Color used for each IR node type when drawing an operation graph.
# String keys are dynamic overrides resolved at draw time: a GenericFunction
# node may be colored by its op_name (e.g. "TLU"), and "output" overrides the
# color of the graph's output nodes.
IR_NODE_COLOR_MAPPING = {
    Input: "blue",
    Constant: "cyan",
    Conv2D: "brown",
    Add: "red",
    Sub: "yellow",
    Mul: "green",
    GenericFunction: "orange",
    IndexConstant: "black",
    Dot: "purple",
    MatMul: "brown",
    "GenericFunction": "orange",
    "TLU": "grey",
    "output": "magenta",
}
# Fail fast at import time if a new IR node type was added without a color,
# so the mapping cannot silently fall behind ALL_IR_NODES.
_missing_nodes_in_mapping = ALL_IR_NODES - IR_NODE_COLOR_MAPPING.keys()
assert_true(
    len(_missing_nodes_in_mapping) == 0,
    (
        f"Missing IR node in IR_NODE_COLOR_MAPPING : "
        f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}"
    ),
)
# temporary check variable, not part of the module's public state
del _missing_nodes_in_mapping
def draw_graph(
    op_graph: OPGraph,
    show: bool = False,
    vertical: bool = True,
    save_to: Optional[Path] = None,
) -> str:
    """Draws operation graphs and optionally saves/shows the drawing.

    Note that this function requires the python `pygraphviz` package which itself requires the
    installation of `graphviz` packages, see
    https://pygraphviz.github.io/documentation/stable/install.html

    Args:
        op_graph (OPGraph): the operation graph to be drawn and optionally saved/shown
        show (bool): if set to True, the drawing will be shown using matplotlib
        vertical (bool): if set to True, the orientation will be vertical
        save_to (Optional[Path]): if specified, the drawn graph will be saved to this path; else
            it is saved in a temporary file

    Returns:
        The path of the file where the drawn graph is saved
    """
    def get_color(node, output_nodes):
        # base color comes from the node type; output nodes always win, and
        # GenericFunction nodes may be recolored by their op_name (e.g. "TLU")
        value_to_return = IR_NODE_COLOR_MAPPING[type(node)]
        if node in output_nodes:
            value_to_return = IR_NODE_COLOR_MAPPING["output"]
        elif isinstance(node, GenericFunction):
            value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return)
        return value_to_return
    graph = op_graph.graph
    output_nodes = set(op_graph.output_nodes.values())
    # graphviz node attributes: label, color, and a doubled circle for outputs
    attributes = {
        node: {
            "label": node.text_for_drawing(),
            "color": get_color(node, output_nodes),
            "penwidth": 2,  # double thickness for circles
            "peripheries": 2 if node in output_nodes else 1,  # double circle for output nodes
        }
        for node in graph.nodes
    }
    nx.set_node_attributes(graph, attributes)
    # label each edge with the input index it feeds on its destination node
    # TODO: #639 adapt drawing routine to manage output_idx
    for edge in graph.edges(keys=True):
        idx = graph.edges[edge]["input_idx"]
        graph.edges[edge]["label"] = f" {idx} "  # spaces are there intentionally for a better look
    try:
        agraph = nx.nx_agraph.to_agraph(graph)
    except ImportError as e:  # pragma: no cover
        # NOTE(review): if the ImportError is unrelated to pygraphviz it is
        # swallowed here and `agraph` stays unbound (NameError below) — verify intent
        if "pygraphviz" in str(e):
            err_msg = (
                f"{draw_graph.__name__} requires pygraphviz, install your OS graphviz distribution "
                "https://pygraphviz.github.io/documentation/stable/install.html "
                "and reinstall with extras: `pip install --force-reinstall "
                "concrete-numpy[full]`"
            )
            raise ImportError(err_msg) from e
    agraph.graph_attr["rankdir"] = "TB" if vertical else "LR"
    agraph.layout("dot")
    if save_to is None:
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            # we need to change the permissions of the temporary file
            # so that it can be read by all users
            # (https://stackoverflow.com/a/44130605)
            # get the old umask and replace it with 0o666
            old_umask = os.umask(0o666)
            # restore the old umask back
            os.umask(old_umask)
            # combine the old umask with the wanted permissions
            permissions = 0o666 & ~old_umask
            # set new permissions
            os.chmod(tmp.name, permissions)
            save_to_str = str(tmp.name)
    else:
        save_to_str = str(save_to)
    agraph.draw(save_to_str)
    if show:  # pragma: no cover
        # We can't have coverage in this branch as `plt.show()` blocks and waits for user action.
        plt.close("all")
        plt.figure()
        img = Image.open(save_to_str)
        plt.imshow(img)
        img.close()
        plt.axis("off")
        plt.show()
    return save_to_str

View File

@@ -1,151 +0,0 @@
"""Functions to format operation graphs for debugging purposes."""
from typing import Dict, List, Optional, Tuple
import networkx as nx
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import GenericFunction, IntermediateNode
def format_operation_graph(
    op_graph: OPGraph,
    maximum_constant_length: int = 25,
    highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None,
) -> str:
    """Format an operation graph.

    Args:
        op_graph (OPGraph):
            the operation graph to format
        maximum_constant_length (int):
            maximum length of the constant throughout the formatting
        highlighted_nodes (Optional[Dict[IntermediateNode, List[str]]] = None):
            the dict of nodes and their corresponding messages which will be highlighted

    Returns:
        str: formatted operation graph
    """
    # This function is well documented and split into very readable sections
    # Thus, splitting it to multiple functions doesn't increase readability
    # pylint: disable=too-many-locals,too-many-branches
    assert_true(isinstance(op_graph, OPGraph))
    # (node, output_index) -> identifier
    # e.g., id_map[(node1, 0)] = 2 and id_map[(node1, 1)] = 3
    # means line for node1 is in this form (%2, %3) = node1.format(...)
    id_map: Dict[Tuple[IntermediateNode, int], int] = {}
    # lines that will be merged at the end
    lines: List[str] = []
    # type information to add to each line (for alignment, this is done after lines are determined)
    type_informations: List[str] = []
    # default highlighted nodes is empty
    highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {}
    # highlight information for lines, this is required because highlights are added to lines
    # after their type information is added and we only have line numbers, not nodes
    highlighted_lines: Dict[int, List[str]] = {}
    # subgraphs to format after the main graph is formatted
    subgraphs: Dict[str, OPGraph] = {}
    # format nodes
    for node in nx.topological_sort(op_graph.graph):
        # assign a unique id to outputs of node
        assert_true(len(node.outputs) > 0)
        for i in range(len(node.outputs)):
            id_map[(node, i)] = len(id_map)
        # remember highlights of the node
        if node in highlighted_nodes:
            highlighted_lines[len(lines)] = highlighted_nodes[node]
        # extract predecessors and their ids
        predecessors = []
        for predecessor, output_idx in op_graph.get_ordered_preds_and_inputs_of(node):
            predecessors.append(f"%{id_map[(predecessor, output_idx)]}")
        # start the build the line for the node
        line = ""
        # add output information to the line
        outputs = ", ".join(f"%{id_map[(node, i)]}" for i in range(len(node.outputs)))
        line += outputs if len(node.outputs) == 1 else f"({outputs})"
        # add node information to the line
        line += " = "
        line += node.text_for_formatting(predecessors, maximum_constant_length)
        # append line to list of lines
        lines.append(line)
        # if exists, save the subgraph
        if isinstance(node, GenericFunction) and "float_op_subgraph" in node.op_kwargs:
            subgraphs[line] = node.op_kwargs["float_op_subgraph"]
        # remember type information of the node
        types = ", ".join(str(output) for output in node.outputs)
        type_informations.append(types if len(node.outputs) == 1 else f"({types})")
    # align = signs
    #
    # e.g.,
    #
    #  %1 = ...
    #  %2 = ...
    #  ...
    #  %8 = ...
    #  %9 = ...
    # %10 = ...
    # %11 = ...
    # ...
    longest_length_before_equals_sign = max(len(line.split("=")[0]) for line in lines)
    for i, line in enumerate(lines):
        length_before_equals_sign = len(line.split("=")[0])
        lines[i] = (" " * (longest_length_before_equals_sign - length_before_equals_sign)) + line
    # add type informations
    longest_line_length = max(len(line) for line in lines)
    for i, line in enumerate(lines):
        lines[i] += " " * (longest_line_length - len(line))
        lines[i] += f"  # {type_informations[i]}"
    # add highlights (this is done in reverse to keep indices consistent)
    for i in reversed(range(len(lines))):
        if i in highlighted_lines:
            for j, message in enumerate(highlighted_lines[i]):
                highlight = "^" if j == 0 else " "
                lines.insert(i + 1 + j, f"{highlight * len(lines[i])} {message}")
    # add return information
    # (if there is a single return, it's in the form `return %id`
    # (otherwise, it's in the form `return (%id1, %id2, ..., %idN)`
    returns: List[str] = []
    for node in op_graph.output_nodes.values():
        outputs = ", ".join(f"%{id_map[(node, i)]}" for i in range(len(node.outputs)))
        returns.append(outputs if len(node.outputs) == 1 else f"({outputs})")
    # the parentheses around the conditional are required: without them the
    # conditional expression binds tighter than `+`, and the `return ` prefix
    # is silently dropped for graphs with multiple outputs
    lines.append("return " + (returns[0] if len(returns) == 1 else f"({', '.join(returns)})"))
    # format subgraphs after the actual graph
    result = "\n".join(lines)
    if len(subgraphs) > 0:
        result += "\n\n"
        result += "Subgraphs:"
        for line, subgraph in subgraphs.items():
            subgraph_lines = format_operation_graph(subgraph, maximum_constant_length).split("\n")
            result += "\n\n"
            result += f"    {line}:\n\n"
            result += "\n".join(f"        {line}" for line in subgraph_lines)
    # pylint: enable=too-many-locals,too-many-branches
    return result

View File

@@ -1,2 +0,0 @@
"""Extensions module to provide additional functionality to our users."""
from . import convolution, multi_table, table

View File

@@ -1,161 +0,0 @@
"""This file contains tracers for convolution operations."""
from typing import List, Optional, Tuple, Union, cast
import numpy as np
from ...numpy.tracing import NPConstant, NPTracer
from ..representation.intermediate import Conv2D
from ..tracing.base_tracer import BaseTracer
# Only the default ONNX auto_pad mode is accepted by conv2d below; other modes
# (e.g. "SAME_UPPER", "SAME_LOWER", "VALID") are rejected with a ValueError.
SUPPORTED_AUTO_PAD = [
    "NOTSET",
]
def conv2d(
    x: Union[np.ndarray, BaseTracer],
    weight: Union[np.ndarray, BaseTracer],
    bias: Optional[Union[np.ndarray, BaseTracer]] = None,
    pads: Union[Tuple[int, int, int, int], List[int]] = (0, 0, 0, 0),
    strides: Union[Tuple[int, int], List[int]] = (1, 1),
    dilations: Union[Tuple[int, int], List[int]] = (1, 1),
    auto_pad: str = "NOTSET",
) -> Union[np.ndarray, NPTracer]:
    """Trace or evaluate 2D convolution.

    If `x` is a tracer, the convolution is traced into the operation graph;
    otherwise it is evaluated eagerly on the concrete ndarray.

    Args:
        x (Union[np.ndarray, BaseTracer]): Input of shape (NxCxHxW)
        weight (Union[np.ndarray, BaseTracer]): Weight (kernel) of shape (FxCxHxW)
        bias (Optional[Union[np.ndarray, BaseTracer]], optional): Bias vector of size (F).
            Defaults to None.
        pads (Union[Tuple[int, int, int, int], List[int]], optional): Padding over each axis
            (H_beg, W_beg, H_end, W_end). Defaults to (0, 0, 0, 0).
        strides (Union[Tuple[int, int], List[int]], optional): Stride over each axis
            (height and width). Defaults to (1, 1).
        dilations (Union[Tuple[int, int], List[int]], optional): Dilation over each axis
            (height and width). Defaults to (1, 1).
        auto_pad (str, optional): Padding strategy. Defaults to "NOTSET".

    Raises:
        ValueError: If one argument isn't in the range of expected values.
        TypeError: If one argument isn't of the appropriate type.

    Returns:
        Union[np.ndarray, BaseTracer]: Evaluation result, or traced computation
    """
    # type and value validation; checks run in a fixed order, so the first
    # offending argument determines which exception the caller sees
    if auto_pad not in SUPPORTED_AUTO_PAD:
        raise ValueError("invalid auto_pad is specified")
    if not isinstance(x, (np.ndarray, BaseTracer)):
        raise TypeError(f"input x must be an ndarray, or a BaseTracer, not a {type(x)}")
    if not isinstance(weight, (np.ndarray, BaseTracer)):
        raise TypeError(f"weight must be an ndarray, or a BaseTracer, not a {type(weight)}")
    if not isinstance(bias, (np.ndarray, BaseTracer, type(None))):
        raise TypeError(f"bias must be an ndarray, a BaseTracer, or None, not a {type(bias)}")
    if not isinstance(pads, (tuple, list)):
        raise TypeError(f"padding must be a tuple, or list, not a {type(pads)}")
    if not isinstance(strides, (tuple, list)):
        raise TypeError(f"strides must be a tuple, or list, not a {type(strides)}")
    if not isinstance(dilations, (tuple, list)):
        raise TypeError(f"dilations must be a tuple, or list, not a {type(dilations)}")
    if len(pads) != 4:
        raise ValueError(
            f"padding should be of the form (pad_height_begin, pad_width_begin, pad_height_end, "
            f" pad_width_end), but got {type(pads)} of length {len(pads)}"
        )
    if len(strides) != 2:
        raise ValueError(
            f"strides should be of the form (stride_height, stride_width), but got {type(strides)}"
            f" of length {len(strides)}"
        )
    if len(dilations) != 2:
        raise ValueError(
            f"dilations should be of the form (dilation_height, dilation_width), but got"
            f" {type(dilations)} of length {len(dilations)}"
        )
    # NOTE(review): these shape checks are plain asserts, so they are stripped
    # when python runs with -O; consider raising explicitly if that matters
    assert len(x.shape) == 4, f"input x should have size (N x C x H x W), not {x.shape}"
    assert len(weight.shape) == 4, f"weight should have size (F x C x H x W), not {weight.shape}"
    if bias is not None:
        assert len(bias.shape) == 1, f"bias should have size (F), not {bias.shape}"
    # tracer input: build the traced computation instead of evaluating
    if isinstance(x, BaseTracer):
        return _trace_conv2d(x, weight, bias, pads, strides, dilations)
    # X is an ndarray
    # a missing bias is evaluated as a zero vector of size (F)
    bias = np.zeros(weight.shape[0]) if bias is None else bias
    # For mypy
    weight = cast(np.ndarray, weight)
    bias = cast(np.ndarray, bias)
    return _evaluate_conv2d(x, weight, bias, pads, strides, dilations)
def _trace_conv2d(
    x: BaseTracer,
    weight: Union[np.ndarray, BaseTracer],
    bias: Optional[Union[np.ndarray, BaseTracer]],
    pads: Union[Tuple[int, int, int, int], List[int]],
    strides: Union[Tuple[int, int], List[int]],
    dilations: Union[Tuple[int, int], List[int]],
) -> NPTracer:
    """Trace 2D convolution into the operation graph.

    Args:
        x (BaseTracer): Input of shape (NxCxHxW)
        weight (Union[np.ndarray, BaseTracer]): Weight (kernel) of shape (FxCxHxW)
        bias (Optional[Union[np.ndarray, BaseTracer]]): Bias vector of size (F)
        pads (Union[Tuple[int, int, int, int], List[int]]): Padding over each
            axis (H_beg, W_beg, H_end, W_end)
        strides (Union[Tuple[int, int], List[int]]): Stride over each
            axis (height and width)
        dilations (Union[Tuple[int, int], List[int]]): Dilation over each
            axis (height and width)

    Returns:
        BaseTracer: Traced computation
    """
    def as_tracer(value):
        # raw ndarrays become constant tracers; tracers pass through untouched
        return value if isinstance(value, BaseTracer) else NPTracer([], NPConstant(value), 0)

    kernel_tracer = as_tracer(weight)
    tracer_inputs = [x, kernel_tracer]
    node_inputs = [x.output, kernel_tracer.output]
    if bias is not None:
        bias_tracer = cast(BaseTracer, as_tracer(bias))  # cast for mypy
        tracer_inputs.append(bias_tracer)
        node_inputs.append(bias_tracer.output)
    conv_node = Conv2D(node_inputs, x.output.dtype, pads, strides, dilations)
    result = x.__class__(tracer_inputs, traced_computation=conv_node, output_idx=0)
    # For mypy
    assert isinstance(result, NPTracer)
    return result
def _evaluate_conv2d(
    x: np.ndarray,
    weight: np.ndarray,
    bias: np.ndarray,
    pads: Union[Tuple[int, int, int, int], List[int]],
    strides: Union[Tuple[int, int], List[int]],
    dilations: Union[Tuple[int, int], List[int]],
) -> np.ndarray:
    """Evaluate 2D convolution on concrete ndarrays.

    Args:
        x (np.ndarray): Input of shape (NxCxHxW)
        weight (np.ndarray): Weight (kernel) of shape (FxCxHxW)
        bias (np.ndarray): Bias vector of size (F)
        pads (Union[Tuple[int, int, int, int], List[int]]): Padding over each
            axis (H_beg, W_beg, H_end, W_end)
        strides (Union[Tuple[int, int], List[int]]): Stride over each axis (height and width)
        dilations (Union[Tuple[int, int], List[int]]): Dilation over each axis (height and width)

    Returns:
        np.ndarray: Result of the convolution of shape (NxCxHxW)
    """
    # delegate to the shared IR-node implementation so that eager evaluation
    # and traced execution stay consistent
    return Conv2D.evaluate_conv2d(x, weight, bias, pads, strides, dilations)

View File

@@ -1,233 +0,0 @@
"""This file contains a wrapper class for direct multi table lookups."""
import itertools
from copy import deepcopy
from typing import List, Tuple, Union
from ..data_types.base import BaseDataType
from ..data_types.dtypes_helpers import find_type_to_hold_both_lossy
from ..representation.intermediate import GenericFunction
from ..tracing.base_tracer import BaseTracer
from ..values import TensorValue
from .table import LookupTable
class MultiLookupTable:
    """Class representing a multi lookup table."""

    # Multi table lookup is needed when you want to perform a lookup on a tensor,
    # but you want each element to be used with a different lookup table.
    #
    # Here is an example:
    #
    # You have x which is of shape (2, 3),
    # you want the first row to be indexed with `table1 = LookupTable([2, 3, 1, 0])`
    # and the second row to be indexed with `table2 = LookupTable([0, 1, 3, 2])`
    #
    # You can create such a multi lookup table
    # multitable = MultiLookupTable(
    #     [
    #         [table1, table1, table1],
    #         [table2, table2, table2],
    #     ],
    # )
    # (notice the shape of multitable matches with the shape of x)
    #
    # and use multitable[x] to get the following result
    # assert multitable[x] == [
    #     [table1[x[0, 0]], table1[x[0, 1]], table1[x[0, 2]]],
    #     [table2[x[1, 0]], table2[x[1, 1]], table2[x[1, 2]]],
    # ]

    # underlying lookup tables
    tables: List
    # shape of the input of the lookup
    input_shape: Tuple[int, ...]
    # type of the result of the lookup
    output_dtype: BaseDataType

    def __init__(self, tables: List):
        input_shape_list: List[int] = []
        MultiLookupTable._extract_shape_using_first_elements_only(tables, input_shape_list)
        input_shape: Tuple[int, ...] = tuple(input_shape_list)
        table_sizes: List[int] = []
        table_output_dtypes: List[BaseDataType] = []
        MultiLookupTable._check_shape_and_record_luts(
            tables,
            0,
            input_shape,
            table_sizes,
            table_output_dtypes,
        )
        # every inner table must have the same number of entries
        for i in range(1, len(table_sizes)):
            if table_sizes[i - 1] != table_sizes[i]:
                # this branch is for such a case:
                #
                # table1 = hnp.LookupTable([1, 3])
                # table2 = hnp.LookupTable([0, 2, 3, 1])
                #
                # multitable = hnp.MultiLookupTable(
                #     [
                #         [table1, table2, table1],
                #         [table2, table1, table2],
                #     ],
                # )
                raise ValueError(
                    f"LookupTables within a MultiLookupTable "
                    f"should have the same size but they do not "
                    f"(there was a table with the size of {table_sizes[i - 1]} "
                    f"and another with the size of {table_sizes[i]})"
                )
        # the result dtype must be able to hold the output of every inner table
        output_dtype = table_output_dtypes[0]
        for table_output_dtype in table_output_dtypes:
            output_dtype = find_type_to_hold_both_lossy(output_dtype, table_output_dtype)
        self.tables = tables
        self.input_shape = input_shape
        self.output_dtype = output_dtype

    def __getitem__(self, key: Union[int, BaseTracer]):
        # this branch is used during tracing and the regular flow is used during evaluation
        if isinstance(key, BaseTracer):
            out_dtype = deepcopy(key.output.dtype)
            out_shape = deepcopy(self.input_shape)
            generic_function_output_value = TensorValue(
                out_dtype,
                key.output.is_encrypted,
                out_shape,
            )
            traced_computation = GenericFunction(
                inputs=[key.output],
                arbitrary_func=MultiLookupTable._checked_indexing,
                output_value=generic_function_output_value,
                op_kind="TLU",
                op_kwargs={
                    "input_shape": deepcopy(self.input_shape),
                    "tables": deepcopy(self.tables),
                },
                op_name="MultiTLU",
            )
            return key.__class__(
                inputs=[key],
                traced_computation=traced_computation,
                output_idx=0,
            )
        # if not, it means table is indexed with a constant
        # thus, the result of the lookup is a constant
        # so, we can propagate it directly
        return MultiLookupTable._checked_indexing(key, self.input_shape, self.tables)

    @staticmethod
    def _extract_shape_using_first_elements_only(array, shape):
        """Accumulate the shape of `array` into `shape`, descending first elements only."""
        if not isinstance(array, list):
            # base case for recursion
            # the shape is already accumulated up to this point
            # so we just return
            return
        if len(array) == 0:
            # this branch is for such a case:
            #
            # table1 = hnp.LookupTable([1, 3, 2, 0])
            # table2 = hnp.LookupTable([0, 2, 3, 1])
            #
            # multitable = hnp.MultiLookupTable(
            #     [
            #         [],
            #         [table1, table2, table1],
            #         [table2, table1, table2],
            #     ],
            # )
            raise ValueError("MultiLookupTable cannot have an empty array within it")
        shape.append(len(array))
        MultiLookupTable._extract_shape_using_first_elements_only(array[0], shape)

    @staticmethod
    def _check_shape_and_record_luts(array, dimension, shape, table_sizes, table_output_dtypes):
        """Verify `array` matches `shape` and record size/dtype of every leaf LookupTable."""
        if dimension == len(shape):
            if not isinstance(array, LookupTable):
                # this branch is for such a case:
                #
                # table1 = hnp.LookupTable([1, 3, 2, 0])
                # table2 = hnp.LookupTable([0, 2, 3, 1])
                #
                # multitable = hnp.MultiLookupTable(
                #     [
                #         [table1, table2, 4],
                #         [table2, table1, table2],
                #     ],
                # )
                raise ValueError(
                    f"MultiLookupTable should have been made out of LookupTables "
                    f"but it had an object of type {array.__class__.__name__} within it"
                )
            table_sizes.append(len(array.table))
            table_output_dtypes.append(array.output_dtype)
            return
        if not isinstance(array, list) or len(array) != shape[dimension]:
            # this branch is for such a case:
            #
            # table1 = hnp.LookupTable([1, 3, 2, 0])
            # table2 = hnp.LookupTable([0, 2, 3, 1])
            #
            # multitable = hnp.MultiLookupTable(
            #     [
            #         [table1, table2],
            #         [table2, table1, table2],
            #     ],
            # )
            raise ValueError(
                f"MultiLookupTable should have the shape {shape} but it does not "
                f"(an array on dimension {dimension} has the size {len(array)} "
                f"but its size should have been {shape[dimension]} "
                f"as the expected shape is {shape})"
            )
        for item in array:
            MultiLookupTable._check_shape_and_record_luts(
                item,
                dimension + 1,
                shape,
                table_sizes,
                table_output_dtypes,
            )

    @staticmethod
    def _checked_indexing(x, input_shape, tables):
        """Look each element of `x` up in its corresponding table, building a nested result."""
        try:
            result = []
            for indices in itertools.product(*[range(dimension) for dimension in input_shape]):
                which_table_to_use = tables
                what_value_to_use = x
                where_to_append = result
                for index in indices[:-1]:
                    # descend one level at a time; the previous implementation
                    # restarted from the root (`tables[index]`, `x[index]`,
                    # `result[index]`) on every step, which broke inputs with
                    # more than two dimensions
                    which_table_to_use = which_table_to_use[index]
                    what_value_to_use = what_value_to_use[index]
                    if len(where_to_append) == index:
                        where_to_append.append([])
                    where_to_append = where_to_append[index]
                which_table_to_use = which_table_to_use[indices[-1]]
                what_value_to_use = what_value_to_use[indices[-1]]
                where_to_append.append(which_table_to_use[what_value_to_use])
        except Exception as error:
            raise ValueError(
                f"Multiple Lookup Table of shape {input_shape} cannot be looked up with {x} "
                f"(you should check your inputset)",
            ) from error
        return result

View File

@@ -1,118 +0,0 @@
"""This file contains a wrapper class for direct table lookups."""
from copy import deepcopy
from typing import Any, Iterable, List, Tuple, Union
from ..common_helpers import is_a_power_of_2
from ..data_types.base import BaseDataType
from ..data_types.integers import make_integer_to_hold
from ..representation.intermediate import GenericFunction
from ..tracing.base_tracer import BaseTracer
class LookupTable:
    """Class representing a lookup table."""

    # lookup table itself, has 2^N entries
    table: Tuple[int, ...]
    # type of the result of the lookup
    output_dtype: BaseDataType

    def __init__(self, table: Iterable[int]):
        entries = tuple(table)
        if not is_a_power_of_2(len(entries)):
            raise ValueError(
                f"Desired lookup table has inappropriate number of entries ({len(entries)})"
            )
        self.table = entries
        self.output_dtype = make_integer_to_hold(entries, force_signed=False)

    def __repr__(self):
        return str(list(self.table))

    def __getitem__(self, key: Union[int, Iterable, BaseTracer]):
        # a constant index means the result of the lookup is itself a constant,
        # so it can be propagated directly
        if not isinstance(key, BaseTracer):
            return LookupTable._checked_indexing(key, self.table)
        # a tracer index means the result is only known at runtime, so emit a
        # `GenericFunction` node carrying a copy of the table
        output_value = deepcopy(key.output)
        output_value.dtype = self.output_dtype
        lookup_node = GenericFunction(
            inputs=[key.output],
            arbitrary_func=LookupTable._checked_indexing,
            output_value=output_value,
            op_kind="TLU",
            op_kwargs={"table": deepcopy(self.table)},
            op_name="TLU",
        )
        return key.__class__(
            inputs=[key],
            traced_computation=lookup_node,
            output_idx=0,
        )

    @staticmethod
    def _check_index_out_of_range(x, table):
        # python-style negative indices are allowed, hence the -len(table) bound
        if not -len(table) <= x < len(table):
            raise ValueError(
                f"Lookup table with {len(table)} entries cannot be indexed with {x} "
                f"(you should check your inputset)",
            )

    @staticmethod
    def _checked_indexing(x, table):
        """Index `table` using `x`.

        There is a single table and the indexing works with the following semantics:
        - when x == c
            - table[x] == table[c]
        - when x == [c1, c2]
            - table[x] == [table[c1], table[c2]]
        - when x == [[c1, c2], [c3, c4], [c5, c6]]
            - table[x] == [[table[c1], table[c2]], [table[c3], table[c4]], [table[c5], table[c6]]]

        Args:
            x (Union[int, Iterable]): index to use
            table (Tuple[int, ...]): table to index

        Returns:
            Union[int, List[int]]: result of indexing
        """
        if isinstance(x, Iterable):
            # recurse into nested iterables, preserving the nesting structure
            return [LookupTable._checked_indexing(item, table) for item in x]
        LookupTable._check_index_out_of_range(x, table)
        return table[x]

View File

@@ -1,144 +0,0 @@
"""Module to hold the result of compilation."""
from pathlib import Path
from typing import Optional, Union
import numpy
from concrete.compiler import (
ClientParameters,
ClientSupport,
CompilationOptions,
JITCompilationResult,
JITLambda,
JITSupport,
KeySet,
KeySetCache,
PublicArguments,
PublicResult,
)
from .debugging import draw_graph, format_operation_graph
from .operator_graph import OPGraph
class FHECircuit:
    """Class which is the result of compilation."""
    # operation graph this circuit was compiled from
    op_graph: OPGraph
    # JIT compilation backend
    _jit_support: JITSupport
    # artifacts produced by compiling the MLIR program
    _compilation_result: JITCompilationResult
    # client-side parameters loaded from the compilation result
    _client_parameters: ClientParameters
    # compiled entry point invokable on public arguments
    _server_lambda: JITLambda
    # optional on-disk cache used to avoid regenerating keys; None when unset
    _keyset_cache: KeySetCache
    # keys produced by keygen(); None until generated
    _keyset: KeySet
def __init__(
self,
op_graph: OPGraph,
mlir_str: str,
unsecure_key_set_cache_path: Optional[str] = None,
auto_parallelize: bool = False,
loop_parallelize: bool = False,
dataflow_parallelize: bool = False,
):
self.op_graph = op_graph
self._jit_support = JITSupport.new()
# Set compilation options
options = CompilationOptions.new("main")
options.set_auto_parallelize(auto_parallelize)
options.set_loop_parallelize(loop_parallelize)
options.set_dataflow_parallelize(dataflow_parallelize)
# Compile
self._compilation_result = self._jit_support.compile(mlir_str, options)
self._client_parameters = self._jit_support.load_client_parameters(self._compilation_result)
self._server_lambda = self._jit_support.load_server_lambda(self._compilation_result)
# Setup keyset cache
self._keyset_cache = None
if unsecure_key_set_cache_path:
self._keyset_cache = KeySetCache.new(unsecure_key_set_cache_path)
self._keyset = None
def __str__(self):
return format_operation_graph(self.op_graph)
def draw(
self,
show: bool = False,
vertical: bool = True,
save_to: Optional[Path] = None,
) -> str:
"""Draw operation graph of the circuit and optionally save/show the drawing.
Args:
show (bool): if set to True, the drawing will be shown using matplotlib
vertical (bool): if set to True, the orientation will be vertical
save_to (Optional[Path]): if specified, the drawn graph will be saved to this path;
otherwise it will be saved to a temporary file
Returns:
str: path of the file where the drawn graph is saved
"""
return draw_graph(self.op_graph, show, vertical, save_to)
def keygen(self, force: bool = False):
"""Generate the keys required for the encrypted circuit.
Args:
force (bool, optional): generate even if keyset already exists. Defaults to False.
"""
if self._keyset is None or force:
self._keyset = ClientSupport.key_set(self._client_parameters, self._keyset_cache)
def encrypt(self, *args: Union[int, numpy.ndarray]) -> PublicArguments:
"""Encrypt the inputs of the circuit.
Args:
*args (Union[int, numpy.ndarray]): plain input of the circuit
Returns:
PublicArguments: encrypted and plain arguments as well as public keys
"""
# Make sure keys are available: shouldn't regenerate if they already exist
self.keygen(force=False)
return ClientSupport.encrypt_arguments(self._client_parameters, self._keyset, args)
def run(self, args: PublicArguments) -> PublicResult:
"""Evaluate the the encrypted circuit (no encryption or decryption involved).
Args:
args (PublicArguments): encrypted inputs to the circuit
Returns:
PublicResult: encrypted result
"""
return self._jit_support.server_call(self._server_lambda, args)
def decrypt(self, result: PublicResult) -> Union[int, numpy.ndarray]:
"""Decrypt the result of the circuit.
Args:
result (PublicResult): encrypted result of the circuit
Returns:
Union[int, numpy.ndarray]: plain result of the circuit
"""
return ClientSupport.decrypt_result(self._keyset, result)
def encrypt_run_decrypt(self, *args: Union[int, numpy.ndarray]) -> Union[int, numpy.ndarray]:
"""Encrypt, evaluate, and decrypt the inputs on the circuit.
Generate keyset automatically if not yet done.
Args:
*args (Union[int, numpy.ndarray]): plain inputs of the circuit
Returns:
Union[int, numpy.ndarray]: plain result of the circuit
"""
self.keygen(force=False)
public_args = self.encrypt(*args)
encrypted_result = self.run(public_args)
decrypted_result = self.decrypt(encrypted_result)
return decrypted_result

View File

@@ -1,3 +0,0 @@
"""Helpers for all kinds of tasks."""
from . import indexing_helpers, python_helpers

View File

@@ -1,47 +0,0 @@
"""Helpers for formatting functionality."""
from typing import Any, Dict, Hashable
import numpy
from ..debugging.custom_assert import assert_true
# Maps numpy scalar *types* (the type objects themselves, not instances) to the
# short names used when such a type appears as a constant in formatting,
# e.g. `numpy.float64` is displayed as "float64" instead of its verbose repr.
SPECIAL_OBJECT_MAPPING: Dict[Any, str] = {
    numpy.float32: "float32",
    numpy.float64: "float64",
    numpy.int8: "int8",
    numpy.int16: "int16",
    numpy.int32: "int32",
    numpy.int64: "int64",
    numpy.uint8: "uint8",
    numpy.uint16: "uint16",
    numpy.uint32: "uint32",
    numpy.uint64: "uint64",
}
def format_constant(constant: Any, maximum_length: int = 45) -> str:
    """Format a constant into a short single-line string.

    numpy scalar types are mapped to their short names via `SPECIAL_OBJECT_MAPPING`;
    anything else is stringified and, if too long, shortened to `head ... tail`.

    Args:
        constant (Any): the constant to format
        maximum_length (int): maximum length of the resulting string

    Returns:
        str: the formatted constant
    """
    if isinstance(constant, Hashable) and constant in SPECIAL_OBJECT_MAPPING:
        return SPECIAL_OBJECT_MAPPING[constant]

    # maximum_length should not be smaller than 7 characters because
    # the constant will be formatted to `x ... y`
    # where x and y are part of the constant and they are at least 1 character
    assert_true(maximum_length >= 7)

    text = str(constant).replace("\n", "")
    if len(text) <= maximum_length:
        return text

    # keep roughly equal prefix and suffix around the ` ... ` separator (5 chars)
    prefix_length = (maximum_length - 5) // 2
    suffix_length = (maximum_length - 5) - prefix_length
    return f"{text[:prefix_length]} ... {text[-suffix_length:]}"

View File

@@ -1,277 +0,0 @@
"""Helpers for indexing functionality."""
from typing import Tuple, Union
def format_indexing_element(indexing_element: Union[int, slice]) -> str:
    """Format an indexing element the way it would appear in source code.

    This is required mainly for slices: their default string representation is very
    verbose (e.g., `x[:, 2:]` would print as `[slice(None, None, None), slice(2, None, None)]`),
    whereas this helper formats them as `[:, 2:]`.

    Args:
        indexing_element (Union[int, slice]): indexing element to be formatted

    Returns:
        str: formatted element
    """
    if not isinstance(indexing_element, slice):
        return str(indexing_element).replace("\n", " ")

    pieces = []
    if indexing_element.start is not None:
        pieces.append(str(indexing_element.start))
    pieces.append(":")
    if indexing_element.stop is not None:
        pieces.append(str(indexing_element.stop))
    if indexing_element.step is not None:
        # the second colon is only printed when there is an explicit step
        pieces.append(":")
        pieces.append(str(indexing_element.step))
    return "".join(pieces).replace("\n", " ")
def validate_index(
    index: Union[int, slice, Tuple[Union[int, slice], ...]],
) -> Tuple[Union[int, slice], ...]:
    """Make sure index is valid and convert it to the tuple form.

    For example in `x[2]`, `index` is passed as `2`; to make it easier
    to work with, this function converts it to `(2,)`.

    Args:
        index (Union[int, slice, Tuple[Union[int, slice], ...]]): index to validate, improve
            and return

    Returns:
        Tuple[Union[int, slice], ...]: validated and improved index

    Raises:
        TypeError: if any element is not an int or a slice of ints/Nones
    """
    normalized = index if isinstance(index, tuple) else (index,)

    for element in normalized:
        if isinstance(element, slice):
            # every bound of a slice must be an int or omitted (None)
            if all(
                bound is None or isinstance(bound, int)
                for bound in (element.start, element.stop, element.step)
            ):
                continue
        elif isinstance(element, int):
            continue
        raise TypeError(
            f"Only integers and integer slices can be used for indexing "
            f"but you tried to use {format_indexing_element(element)} for indexing"
        )

    return normalized
def determine_output_shape(
    input_shape: Tuple[int, ...],
    index: Tuple[Union[int, slice], ...],
) -> Tuple[int, ...]:
    """Determine the output shape from the input shape and the index.

    e.g., for `input_shape=(3, 2)` and `index=(:, 0)`, returns `(3,)`
          for `input_shape=(4, 3, 2)` and `index=(2:,)`, returns `(2, 3, 2)`

    Args:
        input_shape (Tuple[int, ...]): shape of the input tensor that is indexed
        index (Tuple[Union[int, slice], ...]): desired and validated index

    Returns:
        Tuple[int, ...]: shape of the result of indexing

    Raises:
        ValueError: if the index has too many elements or any element is out of range
    """
    # source-code-like representation of the index, used in every error message
    index_str = f"[{', '.join(format_indexing_element(element) for element in index)}]"

    if len(index) > len(input_shape):
        raise ValueError(
            f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
            f"as the index has more elements than the number of dimensions of the tensor"
        )

    # indexing (3, 4, 5) with [1] is the same as indexing it with [1, :, :]
    # indexing (3, 4, 5) with [1, 2] is the same as indexing it with [1, 2, :]
    # so pad the index with full slices to make the rest of the code generic
    padded_index = index + (slice(None, None, None),) * (len(input_shape) - len(index))

    output_shape = []
    for dimension, (element, dimension_size) in enumerate(zip(padded_index, input_shape)):
        if isinstance(element, int):
            # integer indexing removes the dimension entirely
            sanitized = element + dimension_size if element < 0 else element
            if not 0 <= sanitized < dimension_size:
                raise ValueError(
                    f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
                    f"because index is out of range for dimension {dimension}"
                )
        elif isinstance(element, slice):
            # slicing keeps the dimension but possibly shrinks it
            output_shape.append(
                determine_new_dimension_size(
                    element,
                    dimension_size,
                    dimension,
                    input_shape,
                    index_str,
                )
            )

    return tuple(output_shape)
def sanitize_start_index(
    start: int,
    dimension_size: int,
    # the rest is used for detailed exception message
    dimension: int,
    input_shape: Tuple[int, ...],
    index_str: str,
) -> int:
    """Sanitize and check start index of a slice.

    Negative start indices are converted to their non-negative equivalent
    before the range check (e.g., `-1` becomes `dimension_size - 1`).

    Args:
        start (int): start index being sanitized
        dimension_size (int): size of the dimension the slice is applied to
        dimension (int): index of the dimension being sliced (for better messages)
        input_shape (Tuple[int, ...]): shape of the whole input (for better messages)
        index_str (str): string representation of the whole index (for better messages)

    Returns:
        int: sanitized start index

    Raises:
        ValueError: if the sanitized start index falls outside the dimension
    """
    sanitized = start + dimension_size if start < 0 else start
    if 0 <= sanitized < dimension_size:
        return sanitized
    raise ValueError(
        f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
        f"because start index is out of range for dimension {dimension}"
    )
def sanitize_stop_index(
    stop: int,
    dimension_size: int,
    # the rest is used for detailed exception message
    dimension: int,
    input_shape: Tuple[int, ...],
    index_str: str,
) -> int:
    """Sanitize and check stop index of a slice.

    Negative stop indices are converted to their non-negative equivalent before
    the range check. Unlike start indices, a stop index may equal `dimension_size`.

    Args:
        stop (int): stop index being sanitized
        dimension_size (int): size of the dimension the slice is applied to
        dimension (int): index of the dimension being sliced (for better messages)
        input_shape (Tuple[int, ...]): shape of the whole input (for better messages)
        index_str (str): string representation of the whole index (for better messages)

    Returns:
        int: sanitized stop index

    Raises:
        ValueError: if the sanitized stop index falls outside `[0, dimension_size]`
    """
    sanitized = stop + dimension_size if stop < 0 else stop
    if 0 <= sanitized <= dimension_size:
        return sanitized
    raise ValueError(
        f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
        f"because stop index is out of range for dimension {dimension}"
    )
def determine_new_dimension_size(
    slice_: slice,
    dimension_size: int,
    # the rest is used for detailed exception message
    dimension: int,
    input_shape: Tuple[int, ...],
    index_str: str,
) -> int:
    """Determine the new size of a dimension from the old size and the slice applied to it.

    e.g., for `slice_=1:4` and `dimension_size=5`, returns `3`
          for `slice_=::-1` and `dimension_size=5`, returns `5`

    You may want to check this page to learn more about how this function works
    https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing

    Args:
        slice_ (slice): slice being applied to the dimension
        dimension_size (int): size of the dimension the slice is applied to
        dimension (int): index of the dimension being sliced (for better messages)
        input_shape (Tuple[int, ...]): shape of the whole input (for better messages)
        index_str (str): string representation of the whole index (for better messages)

    Returns:
        int: new size of the dimension

    Raises:
        ValueError: if the step is zero, or if the slice selects no elements
            (note that unlike numpy, empty slices are hard errors here)
    """
    # a missing step means a step of 1 (forward, every element)
    step = slice_.step if slice_.step is not None else 1
    if step > 0:
        # forward slice: defaults are the beginning and the end of the dimension
        start = slice_.start if slice_.start is not None else 0
        stop = slice_.stop if slice_.stop is not None else dimension_size

        start = sanitize_start_index(start, dimension_size, dimension, input_shape, index_str)
        stop = sanitize_stop_index(stop, dimension_size, dimension, input_shape, index_str)

        if start >= stop:
            # forward slices that select nothing are rejected
            raise ValueError(
                f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
                f"because start index is not less than stop index for dimension {dimension}"
            )

        # number of elements covered before the step is taken into account
        size_before_stepping = stop - start
    elif step < 0:
        # backward slice: default start is the last element of the dimension
        start = slice_.start if slice_.start is not None else dimension_size - 1
        stop = slice_.stop

        start = sanitize_start_index(start, dimension_size, dimension, input_shape, index_str)

        if stop is None:
            # this is a weird case but it works as expected
            # the issue is that it's impossible to slice whole vector reversed
            # with a stop value different than none
            # if `x.shape == (6,)` the only one that works is `x[::-1].shape == (6,)`
            # here is what doesn't work (and this is expected it's just weird)
            #
            # ...
            # `x[:-2:-1].shape == (1,)`
            # `x[:-1:-1].shape == (0,)` (note that this is a hard error for us)
            # `x[:0:-1].shape == (5,)`
            # `x[:1:-1].shape == (4,)`
            # ...
            size_before_stepping = start + 1
        else:
            stop = sanitize_stop_index(stop, dimension_size, dimension, input_shape, index_str)
            if stop >= start:
                # a backward slice must go from a larger start to a smaller stop
                raise ValueError(
                    f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
                    f"because step is negative and "
                    f"stop index is not less than start index for dimension {dimension}"
                )
            size_before_stepping = start - stop
    else:
        # python itself also rejects a step of zero
        raise ValueError(
            f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
            f"because step is zero for dimension {dimension}"
        )

    # ceiling division: one element is kept per `step`, including the first one
    quotient = size_before_stepping // abs(step)
    remainder = size_before_stepping % abs(step)
    return quotient + (remainder != 0)

View File

@@ -1,35 +0,0 @@
"""Common python helpers."""
from typing import Any, Callable, Iterable, Mapping, Tuple, Union
def update_and_return_dict(
    dict_to_update: dict, update_values: Union[Mapping, Iterable[Tuple[Any, Any]]]
) -> dict:
    """Update a dictionary in place and return it.

    Unlike `dict.update`, which returns None, this helper returns the updated
    dictionary so the call can be used as an expression.

    Args:
        dict_to_update (dict): the dict to update
        update_values (Union[Mapping, Iterable[Tuple[Any, Any]]]): the values to update the dict
            with

    Returns:
        dict: the dict that was just updated.
    """
    dict_to_update.update(update_values)
    return dict_to_update
def catch(func: Callable, *args, **kwargs) -> Union[Any, None]:
    """Execute func by passing args and kwargs, returning None on any failure.

    This is a best-effort helper: all exceptions are deliberately swallowed
    and converted to a None result.

    Args:
        func (Callable): function to execute and catch exceptions from

    Returns:
        Union[Any, None]: the function result if there was no exception, None otherwise.
    """
    try:
        result = func(*args, **kwargs)
    except Exception:  # pylint: disable=broad-except
        return None
    return result

View File

@@ -1,3 +0,0 @@
"""MLIR conversion module."""
from .graph_converter import OPGraphConverter

View File

@@ -1,61 +0,0 @@
"""Helpers for MLIR conversion functionality."""
# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
# pylint: disable=no-name-in-module
from typing import Optional
from concrete.lang.dialects.fhe import EncryptedIntegerType
from mlir.ir import Context, IntegerType, RankedTensorType, Type
from ..data_types import Integer
from ..values import BaseValue, TensorValue
# pylint: enable=no-name-in-module
def integer_to_mlir_type(ctx: Context, integer: Integer, is_encrypted: bool) -> Optional[Type]:
"""Convert an integer to its corresponding MLIR type.
Args:
ctx (Context): the MLIR context to perform the conversion
integer (Integer): the integer to convert
is_encrypted (bool): whether the integer is encrypted or not
Returns:
Type:
the MLIR type corresponding to given integer and encryption status
if it's supported otherwise None
"""
bit_width = integer.bit_width
if is_encrypted:
result = EncryptedIntegerType.get(ctx, bit_width)
else:
result = IntegerType.get_signless(bit_width)
return result
def value_to_mlir_type(ctx: Context, value: BaseValue) -> Type:
    """Convert a value to its corresponding MLIR type.

    Args:
        ctx (Context): the MLIR context to perform the conversion
        value (BaseValue): the value to convert

    Returns:
        Type: the MLIR type corresponding to given value

    Raises:
        TypeError: if the value's dtype is not an Integer
    """
    dtype = value.dtype
    if not isinstance(dtype, Integer):
        # only integer dtypes are supported for now
        raise TypeError(f"{value} is not supported for MLIR conversion")

    mlir_type = integer_to_mlir_type(ctx, dtype, value.is_encrypted)
    # non-scalar tensors are wrapped in a ranked tensor type
    if isinstance(value, TensorValue) and not value.is_scalar:
        mlir_type = RankedTensorType.get(value.shape, mlir_type)
    return mlir_type

View File

@@ -1,138 +0,0 @@
"""Module that provides OPGraph conversion functionality."""
# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
# pylint: disable=no-name-in-module
from abc import ABC, abstractmethod
from typing import Any, Dict, List
import concrete.lang as concretelang
import networkx as nx
from mlir.dialects import builtin
from mlir.ir import Context, InsertionPoint, Location, Module
from ..operator_graph import OPGraph
from ..representation.intermediate import Input, IntermediateNode
from .conversion_helpers import value_to_mlir_type
from .node_converter import IntermediateNodeConverter
# pylint: enable=no-name-in-module
class OPGraphConverter(ABC):
    """Converter of OPGraph to MLIR."""

    def convert(self, op_graph: OPGraph) -> str:
        """Convert an operation graph to its corresponding MLIR representation.

        The graph is converted node by node (in topological order) inside a fresh
        MLIR context; afterwards, the textual MLIR is post-processed to work around
        the compiler's lack of tensor-scalar operations (see comments below).

        Args:
            op_graph (OPGraph): the operation graph to be converted

        Returns:
            str: textual MLIR representation corresponding to given operation graph
        """
        additional_conversion_info = self._generate_additional_info_dict(op_graph)

        # There are no tensor +*- scalar operations in the compiler
        # But such operations are used commonly so we need to support them
        # So, we implemented some workarounds (pull request #970)
        # Once we have native support, this workaround shall be removed (issue #837)
        # (most changes in #970 shall be reverted)

        # { node1: "%arg0", node2: "%0", node3: "%1" }
        nodes_to_mlir_names: Dict[IntermediateNode, str] = {}

        # { "%arg0": "i5", "%0": "tensor<2x3x!FHE.eint<4>>" }
        mlir_names_to_mlir_types: Dict[str, str] = {}

        # { "%0": ["%c1_i5"] } == for %0 we need to convert %c1_i5 to 1d tensor
        scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] = {}

        with Context() as ctx, Location.unknown():
            concretelang.register_dialects(ctx)

            module = Module.create()
            with InsertionPoint(module.body):
                # one MLIR function parameter per graph input, in input order
                parameters = [
                    value_to_mlir_type(ctx, input_node.outputs[0])
                    for input_node in op_graph.get_ordered_inputs()
                ]

                @builtin.FuncOp.from_py_func(*parameters)
                def main(*arg):
                    # maps every converted graph node to its MLIR value
                    ir_to_mlir = {}
                    for arg_num, node in op_graph.input_nodes.items():
                        ir_to_mlir[node] = arg[arg_num]

                        mlir_name = f"%arg{arg_num}"
                        nodes_to_mlir_names[node] = mlir_name
                        mlir_names_to_mlir_types[mlir_name] = str(parameters[arg_num])

                    # convert the remaining nodes in dependency order so that every
                    # predecessor is already available in `ir_to_mlir`
                    for node in nx.topological_sort(op_graph.graph):
                        if isinstance(node, Input):
                            continue

                        preds = [ir_to_mlir[pred] for pred in op_graph.get_ordered_preds(node)]
                        node_converter = IntermediateNodeConverter(
                            ctx,
                            op_graph,
                            node,
                            preds,
                            nodes_to_mlir_names,
                            mlir_names_to_mlir_types,
                            scalar_to_1d_tensor_conversion_hacks,
                        )
                        ir_to_mlir[node] = node_converter.convert(additional_conversion_info)

                    results = (
                        ir_to_mlir[output_node] for output_node in op_graph.get_ordered_outputs()
                    )
                    return results

        # textual post-processing: for every line whose result needs the workaround,
        # insert `tensor.from_elements` lines converting the scalar operands to 1d
        # tensors, rename the operands accordingly, and patch the operand type list
        module_lines_after_hacks_are_applied = []
        for line in str(module).split("\n"):
            mlir_name = line.split("=")[0].strip()
            if mlir_name not in scalar_to_1d_tensor_conversion_hacks:
                module_lines_after_hacks_are_applied.append(line)
                continue

            to_be_replaced = scalar_to_1d_tensor_conversion_hacks[mlir_name]
            for arg_name in to_be_replaced:
                new_name = f"%hack_{mlir_name.replace('%', '')}_{arg_name.replace('%', '')}"
                mlir_type = mlir_names_to_mlir_types[arg_name]

                hack_line = (
                    f"    {new_name} = tensor.from_elements {arg_name} : tensor<1x{mlir_type}>"
                )
                module_lines_after_hacks_are_applied.append(hack_line)

                line = line.replace(arg_name, new_name)

            # rewrite non-tensor operand types as 1-element tensor types
            new_arg_types = []
            arg_types = line.split(":")[1].split("->")[0].strip()[1:-1]
            for arg in arg_types.split(", "):
                if arg.startswith("tensor"):
                    new_arg_types.append(arg)
                else:
                    new_arg_types.append(f"tensor<1x{arg}>")

            line = line.replace(arg_types, ", ".join(new_arg_types))
            module_lines_after_hacks_are_applied.append(line)

        return "\n".join(module_lines_after_hacks_are_applied)

    @staticmethod
    @abstractmethod
    def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]:
        """Generate additional conversion info dict for the MLIR converter.

        Args:
            op_graph (OPGraph): the operation graph from which the additional info will be generated

        Returns:
            Dict[str, Any]: dict of additional conversion info
        """

View File

@@ -1,873 +0,0 @@
"""Module that provides IntermediateNode conversion functionality."""
# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
# pylint: disable=no-name-in-module
from typing import Any, Dict, List, Tuple, cast
import numpy
from concrete.lang.dialects import fhe, fhelinalg
from mlir.dialects import arith, linalg, tensor
from mlir.ir import (
ArrayAttr,
Attribute,
BoolAttr,
Context,
DenseElementsAttr,
IndexType,
IntegerAttr,
IntegerType,
OpResult,
RankedTensorType,
)
from ..data_types import Integer
from ..debugging import assert_true
from ..helpers.indexing_helpers import determine_new_dimension_size
from ..operator_graph import OPGraph
from ..representation.intermediate import (
Add,
Constant,
Conv2D,
Dot,
GenericFunction,
IndexConstant,
IntermediateNode,
MatMul,
Mul,
Sub,
)
from ..values import TensorValue
from .conversion_helpers import integer_to_mlir_type, value_to_mlir_type
# pylint: enable=no-name-in-module
class IntermediateNodeConverter:
"""Converter of IntermediateNode to MLIR."""
ctx: Context
op_graph: OPGraph
node: IntermediateNode
preds: List[OpResult]
all_of_the_inputs_are_encrypted: bool
all_of_the_inputs_are_tensors: bool
one_of_the_inputs_is_a_tensor: bool
nodes_to_mlir_names: Dict[IntermediateNode, str]
mlir_names_to_mlir_types: Dict[str, str]
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]]
def __init__(
    self,
    ctx: Context,
    op_graph: OPGraph,
    node: IntermediateNode,
    preds: List[OpResult],
    nodes_to_mlir_names: Dict[OpResult, str],
    mlir_names_to_mlir_types: Dict[str, str],
    scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]],
):
    """Store conversion state and classify the node's inputs.

    Args:
        ctx (Context): MLIR context used for the conversion
        op_graph (OPGraph): operation graph the node belongs to
        node (IntermediateNode): node to be converted
        preds (List[OpResult]): already-converted MLIR values of the node's predecessors
        nodes_to_mlir_names (Dict[OpResult, str]): shared node -> MLIR name mapping
        mlir_names_to_mlir_types (Dict[str, str]): shared MLIR name -> MLIR type mapping
        scalar_to_1d_tensor_conversion_hacks (Dict[str, List[str]]): shared registry of
            scalar operands that must be converted to 1d tensors (see pull request #970)
    """
    self.ctx = ctx
    self.op_graph = op_graph
    self.node = node
    self.preds = preds

    # classify the inputs once; the convert_* methods branch on these flags
    all_encrypted = True
    all_tensors = True
    any_tensor = False
    for inp in node.inputs:
        if inp.is_clear:
            all_encrypted = False
        if not isinstance(inp, TensorValue):  # pragma: no cover
            # this branch is not covered as there are only TensorValues for now
            all_tensors = False
        elif inp.is_scalar:
            all_tensors = False
        else:
            any_tensor = True

    self.all_of_the_inputs_are_encrypted = all_encrypted
    self.all_of_the_inputs_are_tensors = all_tensors
    self.one_of_the_inputs_is_a_tensor = any_tensor

    self.nodes_to_mlir_names = nodes_to_mlir_names
    self.mlir_names_to_mlir_types = mlir_names_to_mlir_types
    self.scalar_to_1d_tensor_conversion_hacks = scalar_to_1d_tensor_conversion_hacks
def convert(self, additional_conversion_info: Dict[str, Any]) -> OpResult:
    """Convert an intermediate node to its corresponding MLIR representation.

    Dispatches on the node's type (and, for GenericFunction nodes, on its
    `op_name`) to the matching `convert_*` method, then records the MLIR
    name/type of the result in the shared mappings.

    Args:
        additional_conversion_info (Dict[str, Any]):
            external info that the converted node might need

    Returns:
        OpResult: MLIR value corresponding to self.node
    """
    # pylint: disable=too-many-branches

    # NOTE(review): the order of these isinstance checks may matter if any of the
    # node classes inherit from each other — preserve it when editing
    if isinstance(self.node, Add):
        result = self.convert_add()
    elif isinstance(self.node, Constant):
        result = self.convert_constant()
    elif isinstance(self.node, Dot):
        result = self.convert_dot()
    elif isinstance(self.node, GenericFunction):
        if self.node.op_name in ["flatten", "reshape"]:
            # notice flatten() == reshape(-1) and convert_reshape can handle that
            result = self.convert_reshape()
        elif self.node.op_name == "sum":
            result = self.convert_sum()
        elif self.node.op_name == "concat":
            result = self.convert_concat()
        elif self.node.op_name == "transpose":
            result = self.convert_transpose()
        else:
            # any other generic function is lowered as a table lookup
            result = self.convert_generic_function(additional_conversion_info)
    elif isinstance(self.node, IndexConstant):
        result = self.convert_index_constant()
    elif isinstance(self.node, MatMul):
        result = self.convert_matmul()
    elif isinstance(self.node, Mul):
        result = self.convert_mul()
    elif isinstance(self.node, Sub):
        result = self.convert_sub()
    elif isinstance(self.node, Conv2D):
        result = self.convert_conv2d()
    else:  # pragma: no cover
        # this branch is not covered as unsupported operations fail on check mlir compatibility
        raise NotImplementedError(f"{type(self.node)} nodes cannot be converted to MLIR yet")

    # pylint: enable=too-many-branches

    # extract the SSA name (e.g., "%0") from the textual form of the result
    mlir_name = str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip()

    self.nodes_to_mlir_names[self.node] = mlir_name
    self.mlir_names_to_mlir_types[mlir_name] = str(result.type)

    # register scalar operands of mixed tensor/scalar arithmetic nodes so the
    # textual post-processing can convert them to 1d tensors (pull request #970)
    if isinstance(self.node, (Add, Mul, Sub, Dot)):
        if self.one_of_the_inputs_is_a_tensor and not self.all_of_the_inputs_are_tensors:
            to_be_converted = []
            for (pred, output) in self.op_graph.get_ordered_preds_and_inputs_of(self.node):
                inp = pred.outputs[output]
                if isinstance(inp, TensorValue) and inp.is_scalar:
                    to_be_converted.append(self.nodes_to_mlir_names[pred])
            self.scalar_to_1d_tensor_conversion_hacks[mlir_name] = to_be_converted

    return result
def convert_add(self) -> OpResult:
    """Convert an Add node to its corresponding MLIR representation.

    Returns:
        OpResult: MLIR value corresponding to self.node
    """
    assert_true(len(self.node.inputs) == 2)
    assert_true(len(self.node.outputs) == 1)

    resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])

    operands = self.preds
    if self.all_of_the_inputs_are_encrypted:
        # encrypted + encrypted
        op = fhelinalg.AddEintOp if self.one_of_the_inputs_is_a_tensor else fhe.AddEintOp
    else:
        if self.node.inputs[0].is_clear:  # pragma: no cover
            # this branch is not covered as it's impossible to get into due to how tracing works
            # however, it doesn't hurt to keep it as an extra measure
            operands = operands[::-1]
        # encrypted + clear: the encrypted operand must come first
        op = fhelinalg.AddEintIntOp if self.one_of_the_inputs_is_a_tensor else fhe.AddEintIntOp

    return op(resulting_type, *operands).result
def convert_concat(self) -> OpResult:
    """Convert a "concat" node to its corresponding MLIR representation.

    Returns:
        OpResult: MLIR value corresponding to self.node
    """
    assert_true(len(self.node.inputs) >= 2)
    assert_true(len(self.node.outputs) == 1)

    node = cast(GenericFunction, self.node)
    resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])

    axis = node.op_kwargs.get("axis", 0)

    if axis is not None:
        if axis < 0:
            # negative axes count from the end of the first input's shape
            axis += len(cast(TensorValue, self.node.inputs[0]).shape)
        return fhelinalg.ConcatOp(
            resulting_type,
            self.preds,
            IntegerAttr.get(IntegerType.get_signless(64), axis),
        ).result

    # axis is None: flatten every input, then concatenate along dimension 0
    flattened_preds = []
    for pred, input_value in zip(self.preds, self.node.inputs):
        input_shape = cast(TensorValue, input_value).shape
        input_size = numpy.prod(input_shape)
        input_dtype = cast(Integer, input_value.dtype)

        flattened_pred_type = RankedTensorType.get(
            [input_size],
            integer_to_mlir_type(self.ctx, input_dtype, input_value.is_encrypted),
        )
        # collapse every dimension of the input into a single one
        collapse_dims = ArrayAttr.get(
            [
                ArrayAttr.get(
                    [
                        IntegerAttr.get(IndexType.parse("index"), i)
                        for i in range(len(input_shape))
                    ]
                )
            ]
        )
        flattened_preds.append(
            linalg.TensorCollapseShapeOp(flattened_pred_type, pred, collapse_dims).result
        )

    return fhelinalg.ConcatOp(
        resulting_type,
        flattened_preds,
        IntegerAttr.get(IntegerType.get_signless(64), 0),
    ).result
def convert_constant(self) -> OpResult:
    """Convert a Constant node to its corresponding MLIR representation.

    Returns:
        OpResult: MLIR value corresponding to self.node
    """
    assert_true(len(self.node.inputs) == 0)
    assert_true(len(self.node.outputs) == 1)

    value = self.node.outputs[0]
    if not isinstance(value, TensorValue):  # pragma: no cover
        # this branch is not covered as there are only TensorValues for now
        raise NotImplementedError(f"{value} constants cannot be converted to MLIR yet")

    resulting_type = value_to_mlir_type(self.ctx, value)
    data = cast(Constant, self.node).constant_data

    if value.is_scalar:
        attr = IntegerAttr.get(resulting_type, data)
    else:
        # usage of `Attribute.parse` is the result of some limitations in the MLIR module
        # provided by LLVM: `DenseElementsAttr` cannot be created with custom bit widths
        # (e.g., uint5) through the python api, so we let the underlying library parse
        # the attribute by itself
        attr = Attribute.parse(f"dense<{str(data.tolist())}> : {resulting_type}")

    return arith.ConstantOp(resulting_type, attr).result
def convert_conv2d(self) -> OpResult:
    """Convert a Conv2D node to its corresponding MLIR representation.

    Returns:
        OpResult: MLIR value corresponding to self.node
    """
    assert_true(len(self.node.inputs) in (2, 3))
    assert_true(len(self.node.outputs) == 1)

    # the optional third input is the bias
    has_bias = len(self.node.inputs) == 3

    x = self.node.inputs[0]
    weight = self.node.inputs[1]
    if not (x.is_encrypted and weight.is_clear):  # pragma: no cover
        raise NotImplementedError(
            f"Conv2D with input {x} and weight {weight} cannot be converted to MLIR yet",
        )

    resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
    node = cast(Conv2D, self.node)

    integer_type = IntegerType.get_signless(64, context=self.ctx)

    def to_dense_attr(values):
        # conv parameters are passed to MLIR as dense 64-bit integer attributes
        return DenseElementsAttr.get(
            numpy.array(list(values), dtype=numpy.uint64),
            context=self.ctx,
            type=integer_type,
        )

    strides = to_dense_attr(node.strides)
    dilations = to_dense_attr(node.dilations)
    pads = to_dense_attr(node.pads)

    if has_bias:
        # with a bias, preds already contains (x, weight, bias)
        return fhelinalg.Conv2dOp(resulting_type, *self.preds, pads, strides, dilations).result
    # without a bias, an explicit None takes its place
    return fhelinalg.Conv2dOp(resulting_type, *self.preds, None, pads, strides, dilations).result
def convert_dot(self) -> OpResult:
    """Convert a Dot node to its corresponding MLIR representation.

    Returns:
        OpResult: MLIR value corresponding to self.node
    """
    assert_true(len(self.node.inputs) == 2)
    assert_true(len(self.node.outputs) == 1)

    if self.all_of_the_inputs_are_encrypted:
        lhs = self.node.inputs[0]
        rhs = self.node.inputs[1]
        raise NotImplementedError(
            f"Dot product between {lhs} and {rhs} cannot be converted to MLIR yet",
        )

    resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])

    operands = self.preds
    if self.node.inputs[0].is_clear:
        # the compiler expects the encrypted operand first
        operands = operands[::-1]

    if self.all_of_the_inputs_are_tensors:
        # numpy.dot(x, y) where x and y are both vectors = regular dot product
        return fhelinalg.Dot(resulting_type, *operands).result
    if not self.one_of_the_inputs_is_a_tensor:
        # numpy.dot(x, y) where x and y are both scalars = x * y
        return fhe.MulEintIntOp(resulting_type, *operands).result
    # numpy.dot(x, y) where one of x or y is a scalar and the other one is a vector = x * y
    return fhelinalg.MulEintIntOp(resulting_type, *operands).result
def convert_generic_function(self, additional_conversion_info: Dict[str, Any]) -> OpResult:
    """Convert a GenericFunction node to its corresponding MLIR representation.

    The node is lowered to a table lookup (TLU) on its single non-constant input.

    Args:
        additional_conversion_info (Dict[str, Any]): extra information computed during
            compilation; its "tables" entry maps each node to its lookup table(s)

    Returns:
        str: textual MLIR representation corresponding to self.node
    """
    # a table lookup only supports a single variable (non-constant) input
    variable_input_indices = [
        idx
        for idx, inp in enumerate(self.op_graph.get_ordered_preds(self.node))
        if not isinstance(inp, Constant)
    ]
    if len(variable_input_indices) != 1:  # pragma: no cover
        # this branch is not covered as it's impossible to get into due to how tracing works
        # however, it doesn't hurt to keep it as an extra measure
        raise NotImplementedError(
            "Table lookups with more than one variable input cannot be converted to MLIR yet"
        )
    variable_input_index = variable_input_indices[0]

    assert_true(len(self.node.outputs) == 1)
    output = self.node.outputs[0]

    value = self.node.inputs[variable_input_index]
    assert_true(value.is_encrypted)
    if not isinstance(value.dtype, Integer):  # pragma: no cover
        # this branch is not covered as it's impossible to get into due to how compilation works
        # however, it doesn't hurt to keep it as an extra measure
        raise NotImplementedError(f"Table lookup on {value} cannot be converted to MLIR yet")

    tables = additional_conversion_info["tables"][self.node]
    assert_true(len(tables) > 0)

    lut_shape: Tuple[int, ...] = ()
    map_shape: Tuple[int, ...] = ()
    if len(tables) == 1:
        # single table: every element of the input goes through the same lookup table
        table = tables[0][0]

        # The reduction on 63b is to avoid problems like doing a TLU of
        # the form T[j] = 2<<j, for j which is supposed to be 7b as per
        # constraint of the compiler, while in practice, it is a small
        # value. Reducing on 64b was not ok for some reason
        lut_shape = (len(table),)
        lut_values = numpy.array(table % (2 << 63), dtype=numpy.uint64)

        map_shape = ()
        map_values = None
    else:
        # multiple tables: each element of the input tensor is mapped (via `map_values`)
        # to a row of a 2D lookup table holding one table per row
        assert_true(isinstance(output, TensorValue))
        assert isinstance(output, TensorValue)  # for mypy narrowing

        individual_table_size = len(tables[0][0])

        lut_shape = (len(tables), individual_table_size)
        map_shape = output.shape

        lut_values = numpy.zeros(lut_shape, dtype=numpy.uint64)
        map_values = numpy.zeros(map_shape, dtype=numpy.intp)
        for i, (table, indices) in enumerate(tables):
            assert_true(len(table) == individual_table_size)
            lut_values[i, :] = table
            for index in indices:
                map_values[index] = i

    # materialize the lookup table(s) as a constant tensor of 64-bit integers
    lut_type = RankedTensorType.get(lut_shape, IntegerType.get_signless(64, context=self.ctx))
    lut_attr = DenseElementsAttr.get(lut_values, context=self.ctx)
    lut = arith.ConstantOp(lut_type, lut_attr).result

    resulting_type = value_to_mlir_type(self.ctx, output)
    pred = self.preds[variable_input_index]

    if self.one_of_the_inputs_is_a_tensor:
        if len(tables) == 1:
            result = fhelinalg.ApplyLookupTableEintOp(resulting_type, pred, lut).result
        else:
            assert_true(map_shape != ())
            assert_true(map_values is not None)

            # the element -> table-row mapping is a constant index tensor
            index_type = IndexType.parse("index")
            map_type = RankedTensorType.get(map_shape, index_type)
            map_attr = DenseElementsAttr.get(map_values, context=self.ctx, type=index_type)

            result = fhelinalg.ApplyMappedLookupTableEintOp(
                resulting_type,
                pred,
                lut,
                arith.ConstantOp(map_type, map_attr).result,
            ).result
    else:
        result = fhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result
    return result
def convert_index_constant(self) -> OpResult:
    """Convert a IndexConstant node to its corresponding MLIR representation.

    Full integer indexing becomes a scalar `tensor.extract`; anything else becomes a
    `tensor.extract_slice`, followed by a collapse when integer indices drop dimensions.

    Returns:
        str: textual MLIR representation corresponding to self.node
    """
    # pylint: disable=too-many-locals

    assert_true(len(self.node.inputs) == 1)
    assert_true(len(self.node.outputs) == 1)

    tensor_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
    pred = self.preds[0]

    input_value = cast(TensorValue, self.node.inputs[0])
    input_shape = input_value.shape

    index = cast(IndexConstant, self.node).index
    index_str = self.node.text_for_formatting([""], 0)

    index_type = IndexType.parse("index")

    # fast path: every dimension is indexed with a plain integer -> scalar extraction
    if len(index) == len(input_shape) and all(isinstance(i, int) for i in index):
        indices = []
        for value, dimension_size in zip(index, input_shape):
            assert isinstance(value, int)  # mypy
            # negative indices are normalized to their positive counterparts
            attr = IntegerAttr.get(index_type, value if value >= 0 else value + dimension_size)
            indices.append(arith.ConstantOp(index_type, attr).result)
        return tensor.ExtractOp(tensor_type, pred, indices).result

    # general path: describe the slice with one (offset, size, stride) triple per dimension
    offsets = []
    sizes = []
    strides = []

    # dimensions indexed with an integer are dropped from the result
    destroyed_dimensions = []
    for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)):
        if isinstance(indexing_element, int):
            destroyed_dimensions.append(dimension)
            size = 1
            stride = 1
            offset = (
                indexing_element if indexing_element >= 0 else indexing_element + dimension_size
            )
        elif isinstance(indexing_element, slice):
            size = determine_new_dimension_size(
                indexing_element,
                dimension_size,
                dimension,
                input_shape,
                index_str,
            )
            # a missing step defaults to 1 (numpy slice semantics)
            stride = indexing_element.step if isinstance(indexing_element.step, int) else 1
            # a missing start defaults to the first element walked by the stride direction
            offset = (
                (
                    indexing_element.start
                    if indexing_element.start >= 0
                    else indexing_element.start + dimension_size
                )
                if isinstance(indexing_element.start, int)
                else (0 if stride > 0 else dimension_size - 1)
            )
        else:  # pragma: no cover
            # this branch is impossible to reach with all the previous checks
            # but let's keep it as an extra measure
            raise NotImplementedError(
                f"Indexing of {input_value} with {index_str} cannot be converted to MLIR",
            )

        offsets.append(offset)
        sizes.append(size)
        strides.append(stride)

    if len(destroyed_dimensions) == 0:
        # no dimension is dropped: a single extract_slice is enough
        return tensor.ExtractSliceOp(
            tensor_type,
            pred,
            [],
            [],
            [],
            ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]),
            ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]),
            ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]),
        ).result

    # otherwise, extract a slice that keeps the destroyed dimensions with size 1,
    # then collapse those size-1 dimensions away to reach the final output shape
    output_value = cast(TensorValue, self.node.outputs[0])

    intermediate_shape = list(output_value.shape)
    for dimension in destroyed_dimensions:
        intermediate_shape.insert(dimension, 1)

    intermediate_type = RankedTensorType.get(
        intermediate_shape,
        integer_to_mlir_type(
            self.ctx,
            cast(Integer, output_value.dtype),
            output_value.is_encrypted,
        ),
    )

    intermediate = tensor.ExtractSliceOp(
        intermediate_type,
        pred,
        [],
        [],
        [],
        ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]),
        ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]),
        ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]),
    ).result

    # group each kept output dimension with the destroyed (size-1) dimensions around it,
    # so the collapse removes them
    reassociaton = []

    current_intermediate_dimension = 0
    for _ in range(len(output_value.shape)):
        indices = [current_intermediate_dimension]
        while current_intermediate_dimension in destroyed_dimensions:
            current_intermediate_dimension += 1
            indices.append(current_intermediate_dimension)
        reassociaton.append(indices)
        current_intermediate_dimension += 1
    # trailing destroyed dimensions are appended to the last group
    while current_intermediate_dimension < len(intermediate_shape):
        reassociaton[-1].append(current_intermediate_dimension)
        current_intermediate_dimension += 1

    return linalg.TensorCollapseShapeOp(
        tensor_type,
        intermediate,
        ArrayAttr.get(
            [
                ArrayAttr.get(
                    [IntegerAttr.get(index_type, index) for index in indices],
                )
                for indices in reassociaton
            ],
        ),
    ).result

    # pylint: enable=too-many-locals
def convert_matmul(self) -> OpResult:
    """Convert a MatMul node to its corresponding MLIR representation.

    Returns:
        str: textual MLIR representation corresponding to self.node
    """
    assert_true(len(self.node.inputs) == 2)
    assert_true(len(self.node.outputs) == 1)

    # encrypted x encrypted matrix products are not supported by the compiler yet
    if self.all_of_the_inputs_are_encrypted:
        lhs = self.node.inputs[0]
        rhs = self.node.inputs[1]
        raise NotImplementedError(
            f"Matrix multiplication between {lhs} and {rhs} cannot be converted to MLIR yet",
        )

    output = self.node.outputs[0]
    output_type = value_to_mlir_type(self.ctx, output)
    assert isinstance(output, TensorValue)

    if output.shape == ():
        # a scalar output means both operands were vectors, i.e. a dot product;
        # the encrypted operand has to come first
        operands = self.preds[::-1] if self.node.inputs[0].is_clear else self.preds
        return fhelinalg.Dot(output_type, *operands).result

    if self.node.inputs[0].is_clear:
        return fhelinalg.MatMulIntEintOp(output_type, *self.preds).result
    return fhelinalg.MatMulEintIntOp(output_type, *self.preds).result
def convert_mul(self) -> OpResult:
    """Convert a Mul node to its corresponding MLIR representation.

    Returns:
        str: textual MLIR representation corresponding to self.node
    """
    assert_true(len(self.node.inputs) == 2)
    assert_true(len(self.node.outputs) == 1)

    # encrypted x encrypted multiplication is not supported by the compiler yet
    if self.all_of_the_inputs_are_encrypted:
        lhs = self.node.inputs[0]
        rhs = self.node.inputs[1]
        raise NotImplementedError(
            f"Multiplication between {lhs} and {rhs} cannot be converted to MLIR yet",
        )

    output_type = value_to_mlir_type(self.ctx, self.node.outputs[0])

    operands = self.preds
    if self.node.inputs[0].is_clear:  # pragma: no cover
        # this branch is not covered as it's impossible to get into due to how tracing works
        # however, it doesn't hurt to keep it as an extra measure
        operands = operands[::-1]

    # tensor operands use the linalg variant, scalar operands the plain one
    op_class = fhelinalg.MulEintIntOp if self.one_of_the_inputs_is_a_tensor else fhe.MulEintIntOp
    return op_class(output_type, *operands).result
def convert_reshape(self) -> OpResult:
    """Convert a "reshape" node to its corresponding MLIR representation.

    When the rank changes and the dimensions can be grouped consistently, a single
    collapse/expand op is emitted; otherwise the input is flattened to 1D and
    re-expanded to the output shape.

    Returns:
        str: textual MLIR representation corresponding to self.node
    """
    assert_true(len(self.node.inputs) == 1)
    assert_true(len(self.node.outputs) == 1)

    assert_true(isinstance(self.node.inputs[0], TensorValue))
    input_shape = cast(TensorValue, self.node.inputs[0]).shape

    assert_true(isinstance(self.node.outputs[0], TensorValue))
    output_shape = cast(TensorValue, self.node.outputs[0]).shape

    pred = self.preds[0]
    # identical shapes: reshape is a no-op, forward the predecessor value
    if input_shape == output_shape:
        return pred

    # we can either collapse or expand, which changes the number of dimensions
    # this is a limitation of the current compiler and it will be improved in the future (#1060)
    can_be_converted_directly = len(input_shape) != len(output_shape)

    reassociation: List[List[int]] = []
    if can_be_converted_directly:
        if len(output_shape) == 1:
            # output is 1 dimensional so collapse every dimension into the same dimension
            reassociation.append(list(range(len(input_shape))))
        else:
            # input is m dimensional
            # output is n dimensional
            # and m is different than n
            # we don't want to duplicate code so we forget about input and output
            # and we focus on smaller shape and bigger shape
            smaller_shape, bigger_shape = (
                (output_shape, input_shape)
                if len(output_shape) < len(input_shape)
                else (input_shape, output_shape)
            )
            s_index, b_index = 0, 0

            # now we will figure out how to group the bigger shape to get the smaller shape
            # think of the algorithm below as:
            #   keep merging the dimensions of the bigger shape
            #   until we have a match on the smaller shape
            #   then try to match the next dimension of the smaller shape
            #   if all dimensions of the smaller shape is matched
            #   we can convert it
            group = []
            size = 1
            while s_index < len(smaller_shape) and b_index < len(bigger_shape):
                # dimension `b_index` of `bigger_shape` belongs to current group
                group.append(b_index)

                # and current group has `size * bigger_shape[b_index]` elements now
                size *= bigger_shape[b_index]

                # if current group size matches the dimension `s_index` of `smaller_shape`
                if size == smaller_shape[s_index]:
                    # we finalize this group and reset everything
                    size = 1
                    reassociation.append(group)
                    group = []

                    # now try to match the next dimension of `smaller_shape`
                    s_index += 1

                # now process the next dimension of `bigger_shape`
                b_index += 1

            # handle the case where bigger shape has proceeding 1s
            # e.g., (5,) -> (5, 1)
            while b_index < len(bigger_shape) and bigger_shape[b_index] == 1:
                reassociation[-1].append(b_index)
                b_index += 1

            # if not all dimensions of both shapes are processed exactly
            if s_index != len(smaller_shape) or b_index != len(bigger_shape):
                # we cannot convert
                can_be_converted_directly = False

    index_type = IndexType.parse("index")
    resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])

    if can_be_converted_directly:
        reassociation_attr = ArrayAttr.get(
            [
                ArrayAttr.get([IntegerAttr.get(index_type, dimension) for dimension in group])
                for group in reassociation
            ]
        )
        if len(output_shape) < len(input_shape):
            return linalg.TensorCollapseShapeOp(resulting_type, pred, reassociation_attr).result
        return linalg.TensorExpandShapeOp(resulting_type, pred, reassociation_attr).result

    # fallback: collapse everything into a 1D tensor, then expand it to the output shape
    flattened_type = value_to_mlir_type(
        self.ctx,
        TensorValue(
            self.node.inputs[0].dtype,
            self.node.inputs[0].is_encrypted,
            (numpy.prod(input_shape),),
        ),
    )
    flattened_result = linalg.TensorCollapseShapeOp(
        flattened_type,
        pred,
        ArrayAttr.get(
            [ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(input_shape))])]
        ),
    ).result

    return linalg.TensorExpandShapeOp(
        resulting_type,
        flattened_result,
        ArrayAttr.get(
            [ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(output_shape))])]
        ),
    ).result
def convert_sub(self) -> OpResult:
    """Convert a Sub node to its corresponding MLIR representation.

    Returns:
        str: textual MLIR representation corresponding to self.node
    """
    assert_true(len(self.node.inputs) == 2)
    assert_true(len(self.node.outputs) == 1)

    lhs = self.node.inputs[0]
    rhs = self.node.inputs[1]
    # only (clear - encrypted) subtraction is supported by the compiler for now
    if not (lhs.is_clear and rhs.is_encrypted):
        raise NotImplementedError(
            f"Subtraction of {rhs} from {lhs} cannot be converted to MLIR yet",
        )

    output_type = value_to_mlir_type(self.ctx, self.node.outputs[0])

    # tensor operands use the linalg variant, scalar operands the plain one
    op_class = fhelinalg.SubIntEintOp if self.one_of_the_inputs_is_a_tensor else fhe.SubIntEintOp
    return op_class(output_type, *self.preds).result
def convert_sum(self) -> OpResult:
    """Convert a "sum" node to its corresponding MLIR representation.

    Returns:
        str: textual MLIR representation corresponding to self.node
    """
    assert_true(len(self.node.inputs) == 1)
    assert_true(len(self.node.outputs) == 1)

    node = cast(GenericFunction, self.node)
    output_type = value_to_mlir_type(self.ctx, self.node.outputs[0])

    raw_axes = node.op_kwargs.get("axis", [])
    keep_dims = node.op_kwargs.get("keepdims", False)

    # `axis` can be an int, a tuple, or a list; canonicalize to a list
    if isinstance(raw_axes, int):
        raw_axes = [raw_axes]
    elif isinstance(raw_axes, tuple):
        raw_axes = list(raw_axes)

    rank = len(cast(TensorValue, self.node.inputs[0]).shape)
    # normalize negative axes (numpy convention) to their positive counterparts
    axes = [axis + rank if axis < 0 else axis for axis in raw_axes]

    axes_attr = ArrayAttr.get(
        [IntegerAttr.get(IntegerType.get_signless(64), axis) for axis in axes]
    )
    return fhelinalg.SumOp(
        output_type,
        self.preds[0],
        axes_attr,
        BoolAttr.get(keep_dims),
    ).result
def convert_transpose(self) -> OpResult:
    """Convert a Transpose node to its corresponding MLIR representation.

    Returns:
        str: textual MLIR representation corresponding to self.node
    """
    assert_true(len(self.node.inputs) == 1)
    assert_true(len(self.node.outputs) == 1)

    output_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
    return fhelinalg.TransposeOp(output_type, *self.preds).result

View File

@@ -1,235 +0,0 @@
"""Utilities for MLIR conversion."""
from typing import Dict, List, Optional, Tuple
import networkx as nx
from ..data_types.dtypes_helpers import (
value_is_clear_scalar_integer,
value_is_clear_tensor_integer,
value_is_encrypted_scalar_integer,
value_is_encrypted_tensor_integer,
value_is_integer,
value_is_unsigned_integer,
)
from ..debugging import format_operation_graph
from ..debugging.custom_assert import assert_not_reached, assert_true
from ..operator_graph import OPGraph
from ..representation import intermediate
from ..representation.intermediate import Conv2D, IntermediateNode
# TODO: should be removed as the supported bit-width is now dynamic
ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 8
def check_node_compatibility_with_mlir(
    node: IntermediateNode,
    nx_graph: nx.MultiDiGraph,
    is_output: bool,
) -> Optional[str]:
    """Check if node is compatible with MLIR.

    Each supported node kind has its own constraints; the first violated constraint
    is reported as a human-readable reason.

    Args:
        node (IntermediateNode): node to check
        nx_graph (nx.MultiDiGraph): the networkx graph to which node belongs
        is_output (bool): whether the node is an output node or not

    Returns:
        Optional[str]: None if the node is compatible else reason for incompatibility
    """
    # pylint: disable=too-many-branches,too-many-return-statements

    inputs = node.inputs
    outputs = node.outputs

    if isinstance(node, intermediate.Add):  # constraints for addition
        for inp in inputs:
            if not value_is_integer(inp):
                return "only integer addition is supported"
    elif isinstance(node, intermediate.Sub):  # constraints for subtraction
        for inp in inputs:
            if not value_is_integer(inp):
                return "only integer subtraction is supported"
    elif isinstance(node, intermediate.Mul):  # constraints for multiplication
        for inp in inputs:
            if not value_is_integer(inp):
                return "only integer multiplication is supported"
    elif isinstance(node, intermediate.Input):  # constraints for inputs
        assert_true(len(outputs) == 1)
        if not value_is_unsigned_integer(outputs[0]):
            return "only unsigned integer inputs are supported"
    elif isinstance(node, intermediate.Constant):  # constraints for constants
        assert_true(len(outputs) == 1)
        # We currently can't fail on the following assert, but let it for possible changes in the
        # future
        if not value_is_integer(outputs[0]):
            return "only integer constants are supported"  # pragma: no cover
    elif isinstance(node, intermediate.GenericFunction):  # constraints for univariate functions
        for inp in inputs:
            if not value_is_integer(inp):
                return (
                    f"{node.op_name} with floating-point inputs "
                    f"is required to be fused to be supported"
                )
        if node.op_kind == "TLU":
            # a table lookup must have exactly one non-constant predecessor
            assert_true(
                len(
                    [
                        pred_node
                        for pred_node in nx_graph.pred[node]
                        if not isinstance(pred_node, intermediate.Constant)
                    ]
                )
                == 1
            )
        else:
            # non-TLU generic functions are limited to a whitelist of shape operations
            if node.op_name not in ["flatten", "reshape", "sum", "concat", "transpose"]:
                return f"{node.op_name} is not supported for the time being"
    elif isinstance(node, intermediate.Dot):  # constraints for dot product
        assert_true(len(inputs) == 2)
        if not value_is_integer(inputs[0]) or not value_is_integer(inputs[1]):
            return "only integer dot product is supported"
    elif isinstance(node, intermediate.IndexConstant):  # constraints for constant indexing
        assert_true(len(outputs) == 1)
    elif isinstance(node, intermediate.MatMul):  # constraints for matrix multiplication
        assert_true(len(inputs) == 2)
    elif isinstance(node, Conv2D):
        # 2 inputs without bias, 3 with bias
        assert_true(len(inputs) in [2, 3])
    else:  # pragma: no cover
        assert_not_reached("Non IntermediateNode object in the OPGraph")

    if is_output:
        for out in outputs:
            # For signed values and waiting for a real fix (#845): what is returned by the compiler
            # is not the (possibly negative) result r, but the always-positive (r mod 2**t), where t
            # is the bitwidth of r
            # We currently can't fail on the following assert, but let it for possible changes in
            # the future
            if not value_is_integer(out):
                return "only integer outputs are supported"  # pragma: no cover
    else:
        for out in outputs:
            # We currently can't fail on the following assert, but let it for possible changes in
            # the future
            if not value_is_integer(out):
                return "only integer intermediates are supported"  # pragma: no cover

    # pylint: enable=too-many-branches,too-many-return-statements
    return None
def check_graph_values_compatibility_with_mlir(
    op_graph: OPGraph,
) -> Optional[Dict[IntermediateNode, List[str]]]:
    """Make sure the graph outputs are unsigned integers, which is what the compiler supports.

    Args:
        op_graph: computation graph to check

    Returns:
        Dict[IntermediateNode, str]: None if the graph is compatible
            information about offending nodes otherwise
    """
    offending: Dict[IntermediateNode, List[str]] = {}
    output_nodes = op_graph.output_nodes.values()
    for node in op_graph.graph.nodes:
        reason = check_node_compatibility_with_mlir(node, op_graph.graph, node in output_nodes)
        if reason is not None:
            offending[node] = [reason]
    return offending if offending else None
def _set_all_bit_width(op_graph: OPGraph, p: int):
    """Set all bit_width in the graph to `p` and `p+1` for clear and encrypted values respectively.

    Args:
        op_graph: graph to set bit_width for
        p: bit_width to set everywhere
    """
    for node in op_graph.graph.nodes:
        for value in node.outputs + node.inputs:
            is_clear_integer = value_is_clear_scalar_integer(value) or value_is_clear_tensor_integer(
                value
            )
            if is_clear_integer:
                value.dtype.bit_width = p + 1
            else:
                is_encrypted_integer = value_is_encrypted_scalar_integer(
                    value
                ) or value_is_encrypted_tensor_integer(value)
                if is_encrypted_integer:
                    value.dtype.bit_width = p
def get_op_graph_max_bit_width_and_nodes_over_bit_width_limit(
    op_graph: OPGraph,
) -> Tuple[int, Dict[IntermediateNode, List[str]]]:
    """Get the maximum bit width of integer nodes in the given OPGraph.

    Also returns a dictionary with nodes having an unsupported bit width.

    Args:
        op_graph: graph to update bit_width for

    Returns:
        Tuple[int, Dict[IntermediateNode, List[str]]]: a tuple containing the maximum bit width of
            integer values in the OPGraph as well as a dictionary with nodes and the list of issues
            that the nodes have, in this case having an unsupported bit width.
    """
    max_bit_width = 0
    offending_nodes: Dict[IntermediateNode, List[str]] = {}

    for node in op_graph.graph.nodes:
        for value_out in node.outputs:
            is_clear = value_is_clear_scalar_integer(value_out) or value_is_clear_tensor_integer(
                value_out
            )
            if is_clear:
                # clear integers carry one extra bit compared to encrypted ones
                bit_width = value_out.dtype.bit_width - 1
            else:
                assert_true(
                    value_is_encrypted_scalar_integer(value_out)
                    or value_is_encrypted_tensor_integer(value_out)
                )
                bit_width = value_out.dtype.bit_width

            max_bit_width = max(max_bit_width, bit_width)

            # Check that bit_width is supported by the compiler
            if bit_width > ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB:
                offending_nodes[node] = [
                    f"{bit_width} bits is not supported for the time being"
                ]

    return max_bit_width, offending_nodes
def update_bit_width_for_mlir(op_graph: OPGraph):
    """Prepare bit_width of all nodes to be the same, set to the maximum value in the graph.

    Args:
        op_graph: graph to update bit_width for

    Raises:
        RuntimeError: if some node exceeds the maximum bit width supported by the compiler
    """
    max_bit_width, offending_nodes = get_op_graph_max_bit_width_and_nodes_over_bit_width_limit(
        op_graph
    )

    if offending_nodes:
        highlighted_graph = format_operation_graph(op_graph, highlighted_nodes=offending_nodes)
        raise RuntimeError(
            f"max_bit_width of some nodes is too high for the current version of "
            f"the compiler (maximum must be {ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB}) "
            f"which is not compatible with:\n\n" + highlighted_graph
        )

    _set_all_bit_width(op_graph, max_bit_width)

View File

@@ -1,319 +0,0 @@
"""Code to wrap and make manipulating networkx graphs easier."""
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
import networkx as nx
from .data_types.base import BaseDataType
from .data_types.dtypes_helpers import (
get_base_data_type_for_python_constant_data,
get_constructor_for_python_constant_data,
)
from .data_types.floats import Float
from .data_types.integers import Integer, make_integer_to_hold
from .debugging.custom_assert import assert_true
from .representation.intermediate import Input, IntermediateNode
from .tracing import BaseTracer
from .tracing.tracing_helpers import create_graph_from_output_tracers
class OPGraph:
"""Class to make work with nx graphs easier."""
graph: nx.MultiDiGraph
input_nodes: Dict[int, Input]
output_nodes: Dict[int, IntermediateNode]
def __init__(
self,
graph: nx.MultiDiGraph,
input_nodes: Dict[int, Input],
output_nodes: Dict[int, IntermediateNode],
) -> None:
assert_true(
all(isinstance(node, Input) for node in input_nodes.values()),
"Got input nodes that were not Input, which is not supported",
)
assert_true(
all(isinstance(node, IntermediateNode) for node in output_nodes.values()),
"Got output nodes which were not IntermediateNode, which is not supported",
)
self.graph = graph
self.input_nodes = input_nodes
self.output_nodes = output_nodes
self.prune_nodes()
def __call__(self, *args) -> Union[Any, Tuple[Any, ...]]:
assert_true(len(self.input_nodes) > 0, "Cannot evaluate a graph with no input nodes")
inputs = dict(enumerate(args))
assert_true(
len(inputs) == len(self.input_nodes),
f"Expected {len(self.input_nodes)} arguments, got {len(inputs)} : {args}",
)
results = self.evaluate(inputs)
tuple_result = tuple(results[output_node] for output_node in self.get_ordered_outputs())
return tuple_result if len(tuple_result) > 1 else tuple_result[0]
@staticmethod
def from_output_tracers(output_tracers: Iterable[BaseTracer]) -> "OPGraph":
"""Construct OPGraph from output tracers.
Args:
output_tracers (Iterable[BaseTracer]): The tracers output by the function that was
traced.
Returns:
OPGraph: The resulting OPGraph.
"""
graph = create_graph_from_output_tracers(output_tracers)
input_nodes = {
node.program_input_idx: node
for node in graph.nodes()
if len(graph.pred[node]) == 0 and isinstance(node, Input)
}
output_nodes = {
output_idx: tracer.traced_computation
for output_idx, tracer in enumerate(output_tracers)
}
return OPGraph(graph, input_nodes, output_nodes)
@staticmethod
def from_graph(
graph: nx.MultiDiGraph,
input_nodes: Iterable[Input],
output_nodes: Iterable[IntermediateNode],
) -> "OPGraph":
"""Construct OPGraph from an existing networkx MultiDiGraph.
Args:
graph (nx.MultiDiGraph): The networkx MultiDiGraph to use.
input_nodes (Iterable[Input]): The input nodes of the MultiDiGraph.
output_nodes (Iterable[IntermediateNode]): The output nodes of the MultiDiGraph.
Returns:
OPGraph: The resulting OPGraph.
"""
return OPGraph(graph, dict(enumerate(input_nodes)), dict(enumerate(output_nodes)))
def get_ordered_inputs(self) -> List[Input]:
"""Get the input nodes of the graph, ordered by their index.
Returns:
List[Input]: ordered input nodes
"""
return [self.input_nodes[idx] for idx in range(len(self.input_nodes))]
def get_ordered_outputs(self) -> List[IntermediateNode]:
"""Get the output nodes of the graph, ordered by their index.
Returns:
List[IntermediateNode]: ordered input nodes
"""
return [self.output_nodes[idx] for idx in range(len(self.output_nodes))]
def get_ordered_preds(self, node: IntermediateNode) -> List[IntermediateNode]:
"""Get node predecessors ordered by their indices.
Args:
node (IntermediateNode): The node for which we want the ordered predecessors.
Returns:
List[IntermediateNode]: The list of predecessors ordered by input index.
"""
# Replication of pred is managed e.g. x + x will yield the proper pred x twice
idx_to_pred: Dict[int, IntermediateNode] = {}
for pred in self.graph.predecessors(node):
edge_data = self.graph.get_edge_data(pred, node)
idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values())
return [idx_to_pred[i] for i in range(len(idx_to_pred))]
def get_ordered_preds_and_inputs_of(
self, node: IntermediateNode
) -> List[Tuple[IntermediateNode, int]]:
"""Get node preds and inputs ordered by their indices.
Args:
node (IntermediateNode): the node for which we want the ordered inputs
Returns:
List[Tuple[IntermediateNode, int]]: the ordered list of preds and inputs
"""
idx_to_inp: Dict[int, Tuple[IntermediateNode, int]] = {}
for pred in self.graph.predecessors(node):
edge_data = self.graph.get_edge_data(pred, node)
idx_to_inp.update(
(data["input_idx"], (pred, data["output_idx"])) for data in edge_data.values()
)
return [idx_to_inp[i] for i in range(len(idx_to_inp))]
def evaluate(self, inputs: Dict[int, Any]) -> Dict[IntermediateNode, Any]:
"""Evaluate a graph and get intermediate values for all nodes.
Args:
inputs (Dict[int, Any]): The inputs to the program
Returns:
Dict[IntermediateNode, Any]: Dictionary with node as keys and resulting values
"""
node_results: Dict[IntermediateNode, Any] = {}
def get_result_of_node_at_index(node: IntermediateNode, output_idx: int) -> Any:
"""Get the output result at index output_idx for a node.
Args:
node (IntermediateNode): the node from which we want the output.
output_idx (int): which output we want.
Returns:
Any: the output value of the evaluation of node.
"""
result = node_results[node]
# TODO: #81 remove no cover once we have nodes with multiple outputs
if isinstance(result, tuple): # pragma: no cover
# If the node has multiple outputs (i.e. the result is a tuple), return the
# requested output
return result[output_idx]
# If the result is not a tuple, then the result is the node's only output. Check that
# the requested index is 0 (as it's the only valid value) and return the result itself.
assert_true(
output_idx == 0,
f"Unable to get output at index {output_idx} for node {node}.\n"
f"Node result: {result}",
)
return result
for node in nx.topological_sort(self.graph):
if not isinstance(node, Input):
curr_inputs = {}
for pred_node in self.graph.predecessors(node):
edges = self.graph.get_edge_data(pred_node, node)
curr_inputs.update(
{
edge["input_idx"]: get_result_of_node_at_index(
pred_node,
output_idx=edge["output_idx"],
)
for edge in edges.values()
}
)
node_results[node] = node.evaluate(curr_inputs)
else:
node_results[node] = node.evaluate({0: inputs[node.program_input_idx]})
return node_results
def update_values_with_bounds_and_samples(
self,
node_bounds_and_samples: dict,
get_base_data_type_for_constant_data: Callable[
[Any], BaseDataType
] = get_base_data_type_for_python_constant_data,
get_constructor_for_constant_data: Callable[
..., Callable
] = get_constructor_for_python_constant_data,
):
"""Update values with bounds.
Update nodes inputs and outputs values with data types able to hold data ranges measured
and passed in nodes_bounds
Args:
node_bounds_and_samples (dict): Dictionary with nodes as keys, holding dicts with a
'min', 'max' and 'sample' keys. Those bounds will be taken as the data range to be
represented, per node. The sample allows to determine the data constructors to
prepare the GenericFunction nodes for table generation.
get_base_data_type_for_constant_data (Callable[ [Any], BaseDataType ], optional): This
is a callback function to convert data encountered during value updates to
BaseDataType. This allows to manage data coming from foreign frameworks without
specialising OPGraph. Defaults to get_base_data_type_for_python_constant_data.
get_constructor_for_constant_data (Callable[ ..., Callable ], optional): This is a
callback function to determine the type constructor of the data encountered while
updating the graph bounds. Defaults to get_constructor_for_python_constant_data.
"""
node: IntermediateNode
for node in self.graph.nodes():
current_node_bounds_and_samples = node_bounds_and_samples[node]
min_bound, max_bound, sample = (
current_node_bounds_and_samples["min"],
current_node_bounds_and_samples["max"],
current_node_bounds_and_samples["sample"],
)
min_data_type = get_base_data_type_for_constant_data(min_bound)
max_data_type = get_base_data_type_for_constant_data(max_bound)
# This is a sanity check
min_value_constructor = get_constructor_for_constant_data(min_bound)
max_value_constructor = get_constructor_for_constant_data(max_bound)
assert_true(
max_value_constructor == min_value_constructor,
(
f"Got two different type constructors for min and max bound: "
f"{min_value_constructor}, {max_value_constructor}"
),
)
value_constructor = get_constructor_for_constant_data(sample)
if not isinstance(node, Input):
for output_value in node.outputs:
if isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer):
output_value.dtype = make_integer_to_hold(
(min_bound, max_bound), force_signed=False
)
else:
assert_true(
isinstance(min_data_type, Float) and isinstance(max_data_type, Float),
(
"min_bound and max_bound have different common types, "
"this should never happen.\n"
f"min_bound: {min_data_type}, max_bound: {max_data_type}"
),
)
output_value.dtype = Float(64)
output_value.underlying_constructor = value_constructor
else:
# Currently variable inputs are only allowed to be integers
assert_true(
isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer),
(
f"Inputs to a graph should be integers, got bounds that were float, \n"
f"min: {min_bound} ({type(min_bound)}), "
f"max: {max_bound} ({type(max_bound)})"
),
)
node.inputs[0].dtype = make_integer_to_hold(
(min_bound, max_bound), force_signed=False
)
node.inputs[0].underlying_constructor = value_constructor
node.outputs[0] = deepcopy(node.inputs[0])
successors = self.graph.successors(node)
for succ in successors:
edge_data = self.graph.get_edge_data(node, succ)
for edge in edge_data.values():
input_idx, output_idx = edge["input_idx"], edge["output_idx"]
succ.inputs[input_idx] = deepcopy(node.outputs[output_idx])
def prune_nodes(self):
    """Drop every node from which no graph output can be reached.

    Walks the predecessor relation backwards from the output nodes and removes
    everything that was never visited.
    """
    # Dicts are used as ordered sets to keep traversal deterministic
    reachable: Dict[IntermediateNode, None] = {}
    frontier = {node: None for node in self.get_ordered_outputs()}
    while frontier:
        reachable.update(frontier)
        upcoming: Dict[IntermediateNode, None] = {}
        for node in frontier:
            upcoming.update({pred: None for pred in self.graph.predecessors(node)})
        frontier = upcoming
    self.graph.remove_nodes_from(
        [node for node in self.graph.nodes() if node not in reachable]
    )

View File

@@ -1 +0,0 @@
"""Module holding various optimization/simplification code."""

View File

@@ -1,594 +0,0 @@
"""File holding topological optimization/simplification code."""
from collections import defaultdict
from copy import deepcopy
from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple, cast
import networkx as nx
from loguru import logger
from ..compilation.artifacts import CompilationArtifacts
from ..data_types.floats import Float
from ..data_types.integers import Integer
from ..debugging import format_operation_graph
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import Constant, GenericFunction, Input, IntermediateNode
from ..values import TensorValue
def fuse_float_operations(
    op_graph: OPGraph,
    compilation_artifacts: Optional[CompilationArtifacts] = None,
):
    """Find and fuse float domains into single Integer to Integer GenericFunction.

    Repeatedly searches `op_graph` for subgraphs that compute in floats between an
    integer input and a single integer output, and replaces each one with a single
    GenericFunction ("TLU") node, rewiring edges and output bookkeeping accordingly.

    Args:
        op_graph (OPGraph): The OPGraph to simplify
        compilation_artifacts (Optional[CompilationArtifacts]): The CompilationArtifacts of the
            current compilation, this argument is optional as it's not required to execute float
            fusing.
    """
    nx_graph = op_graph.graph
    processed_terminal_nodes: Set[IntermediateNode] = set()
    number_of_fuse = 0
    while True:
        # Look for the next float subgraph whose terminal node has not been handled yet
        float_subgraph_search_result = find_float_subgraph_with_unique_terminal_node(
            nx_graph, processed_terminal_nodes
        )
        if float_subgraph_search_result is None:
            break

        float_subgraph_start_nodes, terminal_node, subgraph_all_nodes = float_subgraph_search_result
        # Mark the terminal node as processed whether or not fusing succeeds, so the
        # outer loop always terminates
        processed_terminal_nodes.add(terminal_node)

        subgraph_conversion_result = convert_float_subgraph_to_fused_node(
            op_graph,
            float_subgraph_start_nodes,
            terminal_node,
            subgraph_all_nodes,
        )

        # Not a subgraph we can handle, continue
        if subgraph_conversion_result is None:
            continue

        fused_node, node_before_subgraph = subgraph_conversion_result
        nx_graph.add_node(fused_node)

        if terminal_node in op_graph.output_nodes.values():
            # Output value replace it
            # As the graph changes recreate the output_node_to_idx dict
            output_node_to_idx: Dict[IntermediateNode, List[int]] = {
                out_node: [] for out_node in op_graph.output_nodes.values()
            }
            for output_idx, output_node in op_graph.output_nodes.items():
                output_node_to_idx[output_node].append(output_idx)
            for output_idx in output_node_to_idx.get(terminal_node, []):
                op_graph.output_nodes[output_idx] = fused_node

        # Disconnect after terminal node and connect fused node instead
        terminal_node_succ = list(nx_graph.successors(terminal_node))
        for succ in terminal_node_succ:
            succ_edge_data = deepcopy(nx_graph.get_edge_data(terminal_node, succ))
            for edge_key, edge_data in succ_edge_data.items():
                nx_graph.remove_edge(terminal_node, succ, key=edge_key)
                # fused_node is always a GenericFunction so output_idx == 0 always
                new_edge_data = deepcopy(edge_data)
                new_edge_data["output_idx"] = 0
                nx_graph.add_edge(fused_node, succ, key=edge_key, **new_edge_data)

        # Connect the node feeding the subgraph contained in fused_node
        # node_before_subgraph has a single integer output currently so output_idx == 0
        nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0, output_idx=0)

        # Drop the now-disconnected float nodes
        op_graph.prune_nodes()
        if compilation_artifacts is not None:
            compilation_artifacts.add_operation_graph(
                f"after-float-fuse-{number_of_fuse}", op_graph
            )
        number_of_fuse += 1
def convert_float_subgraph_to_fused_node(
    op_graph: OPGraph,
    float_subgraph_start_nodes: Dict[IntermediateNode, None],
    terminal_node: IntermediateNode,
    subgraph_all_nodes: Dict[IntermediateNode, None],
) -> Optional[Tuple[GenericFunction, IntermediateNode]]:
    """Convert a float subgraph to an equivalent fused GenericFunction node.

    Args:
        op_graph (OPGraph): The OPGraph the float subgraph is part of.
        float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the float
            subgraph in `op_graph`.
        terminal_node (IntermediateNode): The node ending the float subgraph.
        subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph.

    Returns:
        Optional[Tuple[GenericFunction, IntermediateNode]]: None if the float subgraph
            cannot be fused, otherwise returns a tuple containing the fused node and the node whose
            output must be plugged as the input to the subgraph.
    """
    node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]] = defaultdict(list)

    subgraph_can_be_fused = subgraph_has_unique_variable_input(
        float_subgraph_start_nodes, terminal_node, node_with_issues_for_fusing
    )
    if subgraph_can_be_fused:
        # subgraph_values_allow_fusing can be called iff the subgraph has a unique variable input
        subgraph_can_be_fused = subgraph_nodes_and_values_allow_fusing(
            float_subgraph_start_nodes, subgraph_all_nodes, node_with_issues_for_fusing
        )

    # This test is separate from the previous one to only handle printing issues once
    if not subgraph_can_be_fused:
        # Render the offending subgraph with the collected issues, then bail out
        float_subgraph = nx.MultiDiGraph(op_graph.graph.subgraph(subgraph_all_nodes))
        float_subgraph_as_op_graph = OPGraph.from_graph(float_subgraph, [], [terminal_node])
        printable_graph = format_operation_graph(
            float_subgraph_as_op_graph,
            highlighted_nodes=node_with_issues_for_fusing,
        )
        message = f"The following subgraph is not fusable:\n\n{printable_graph}"
        logger.warning(message)
        return None

    # Only one variable input node, find which node feeds its input
    variable_input_nodes = [
        node for node in float_subgraph_start_nodes if not isinstance(node, Constant)
    ]
    assert_true(len(variable_input_nodes) == 1)

    current_subgraph_variable_input = variable_input_nodes[0]
    assert_true(len(current_subgraph_variable_input.outputs) == 1)
    new_input_value = deepcopy(current_subgraph_variable_input.outputs[0])

    nx_graph = op_graph.graph
    # Direct successors of the variable input that belong to the subgraph; these will be
    # reconnected to a fresh Input node inside the extracted subgraph
    nodes_after_input_set = {
        node: None
        for node in subgraph_all_nodes
        if node in nx_graph.succ[current_subgraph_variable_input]
    }

    # # Previous non-deterministic implementation :
    # # For some reason creating a graph from a subgraph this way is not deterministic
    # float_subgraph = nx.MultiDiGraph(nx_graph.subgraph(subgraph_all_nodes))

    # Create a copy of the graph, remove nodes that are not in all the subgraph nodes in order to
    # get a subgraph deterministically
    float_subgraph = nx.MultiDiGraph(nx_graph)
    nodes_to_remove = [node for node in float_subgraph.nodes() if node not in subgraph_all_nodes]
    float_subgraph.remove_nodes_from(nodes_to_remove)

    new_subgraph_variable_input = Input(new_input_value, "float_subgraph_input", 0)
    float_subgraph.add_node(new_subgraph_variable_input)

    for node_after_input in nodes_after_input_set:
        # Connect the new input to our subgraph
        edge_data_input_to_subgraph = deepcopy(
            float_subgraph.get_edge_data(
                current_subgraph_variable_input,
                node_after_input,
            )
        )
        for edge_key, edge_data in edge_data_input_to_subgraph.items():
            float_subgraph.remove_edge(
                current_subgraph_variable_input, node_after_input, key=edge_key
            )
            # new_subgraph_variable_input is always an Input so output_idx == 0 always
            new_edge_data = deepcopy(edge_data)
            new_edge_data["output_idx"] = 0
            float_subgraph.add_edge(
                new_subgraph_variable_input,
                node_after_input,
                key=edge_key,
                **new_edge_data,
            )

    float_op_subgraph = OPGraph.from_graph(
        float_subgraph,
        [new_subgraph_variable_input],
        [terminal_node],
    )

    assert_true(len(terminal_node.outputs) == 1)

    # Create fused_node
    # The extracted subgraph and its terminal node are passed through op_kwargs so the
    # lambda captures them explicitly rather than by closure
    fused_node = GenericFunction(
        inputs=[new_subgraph_variable_input.inputs[0]],
        arbitrary_func=lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate(
            {0: x}
        )[terminal_node],
        output_value=terminal_node.outputs[0],
        op_kind="TLU",
        op_kwargs={
            "float_op_subgraph": float_op_subgraph,
            "terminal_node": terminal_node,
        },
        op_name="subgraph",
    )
    return (
        fused_node,
        current_subgraph_variable_input,
    )
def is_single_int_output_node(node: IntermediateNode) -> bool:
    """Tell whether a node produces exactly one output of integer type.

    Args:
        node (IntermediateNode): the node to inspect.

    Returns:
        bool: True when the node has a single output whose dtype is an Integer.
    """
    outputs = node.outputs
    if len(outputs) != 1:
        return False
    return isinstance(outputs[0].dtype, Integer)
def find_closest_single_int_output_nodes(
    nx_graph: nx.MultiDiGraph,
    start_nodes: List[IntermediateNode],
    subgraph_all_nodes: Dict[IntermediateNode, None],
) -> Tuple[Dict[IntermediateNode, None], Dict[IntermediateNode, None]]:
    """Find in nx_graph the closest upstream single integer output nodes to some start nodes.

    Walks predecessors breadth-first from `start_nodes`; any predecessor with a single
    integer output is recorded as a boundary and not traversed further.

    Args:
        nx_graph (nx.MultiDiGraph): the networkx graph to search in.
        start_nodes (List[IntermediateNode]): the nodes from which to start the search.
        subgraph_all_nodes (Dict[IntermediateNode, None]): a set that will be updated with all the
            nodes visited during the search.

    Returns:
        Tuple[Dict[IntermediateNode, None], Dict[IntermediateNode, None]]: the dict used as
            an ordered set containing the found single output nodes and the updated set of the
            visited nodes during the search.
    """
    # Dicts keep insertion order, so they double as deterministic ordered sets
    frontier = {start_node: None for start_node in start_nodes}
    boundary_nodes: Dict[IntermediateNode, None] = {}
    already_seen: Set[IntermediateNode] = set()
    while frontier:
        upcoming: Dict[IntermediateNode, None] = {}
        for node in frontier:
            if node in already_seen:
                continue
            already_seen.add(node)
            subgraph_all_nodes[node] = None
            for pred in nx_graph.predecessors(node):
                if is_single_int_output_node(pred):
                    # Limit of the subgraph: record it but do not traverse past it
                    boundary_nodes[pred] = None
                    subgraph_all_nodes[pred] = None
                else:
                    upcoming[pred] = None
        frontier = upcoming
    return boundary_nodes, subgraph_all_nodes
def add_nodes_from_to(
    nx_graph: nx.MultiDiGraph,
    from_nodes: Iterable[IntermediateNode],
    to_nodes: Dict[IntermediateNode, None],
    subgraph_all_nodes: Dict[IntermediateNode, None],
) -> Dict[IntermediateNode, None]:
    """Collect into subgraph_all_nodes every node between from_nodes and to_nodes.

    Traverses predecessors breadth-first starting at `from_nodes`, stopping at (but
    including) the nodes of `to_nodes`.

    Args:
        nx_graph (nx.MultiDiGraph): the graph to traverse.
        from_nodes (Iterable[IntermediateNode]): the nodes from which the traversal starts.
        to_nodes (Dict[IntermediateNode, None]): the nodes at which the traversal stops.
        subgraph_all_nodes (Dict[IntermediateNode, None]): set of subgraph nodes, updated in
            place and returned.

    Returns:
        Dict[IntermediateNode, None]: the updated subgraph_all_nodes.
    """
    # The end nodes belong to the subgraph even though they are never traversed
    subgraph_all_nodes.update(to_nodes)

    frontier = {from_node: None for from_node in from_nodes}
    already_seen: Set[IntermediateNode] = set()
    while frontier:
        upcoming: Dict[IntermediateNode, None] = {}
        for node in frontier:
            if node in already_seen:
                continue
            already_seen.add(node)
            subgraph_all_nodes[node] = None
            # Queue predecessors unless they are designated end nodes
            for pred in nx_graph.predecessors(node):
                if pred not in to_nodes:
                    upcoming[pred] = None
        frontier = upcoming
    return subgraph_all_nodes
def find_float_subgraph_with_unique_terminal_node(
    nx_graph: nx.MultiDiGraph,
    processed_terminal_nodes: Set[IntermediateNode],
) -> Optional[Tuple[Dict[IntermediateNode, None], IntermediateNode, Dict[IntermediateNode, None]]]:
    """Find a subgraph of the graph with float computations.

    Args:
        nx_graph (nx.MultiDiGraph): The networkx graph to search in.
        processed_terminal_nodes (Set[IntermediateNode]): The set of terminal nodes for which
            subgraphs have already been searched, those will be skipped.

    Returns:
        Optional[
            Tuple[Dict[IntermediateNode, None], IntermediateNode, Dict[IntermediateNode, None]]]:
            None if there are no float subgraphs to process in `nx_graph`. Otherwise returns a
            tuple containing the set of nodes beginning a float subgraph, the terminal node of
            the subgraph and the set of all the nodes in the subgraph.
    """

    def is_float_to_single_int_node(node: IntermediateNode) -> bool:
        # A terminal node takes at least one float input and produces a single integer output
        return (
            any(isinstance(input_.dtype, Float) for input_ in node.inputs)
            and len(node.outputs) == 1
            and isinstance(node.outputs[0].dtype, Integer)
        )

    float_subgraphs_terminal_nodes = (
        node
        for node in nx_graph.nodes()
        if is_float_to_single_int_node(node) and node not in processed_terminal_nodes
    )

    terminal_node: IntermediateNode
    try:
        terminal_node = next(float_subgraphs_terminal_nodes)
    except StopIteration:
        # No unprocessed terminal node left: nothing to search for
        return None

    # networkx does not implement lowest common ancestor search for multidigraph, but we only care
    # about parent relationship here and not the meaning of edges, so we can convert our
    # multidigraph to a digraph and use the lca search algorithm (if needed), we create the
    # equivalent digraph here as it will avoid recreating it in a loop. Constant nodes could cause
    # issues in our search so we remove them.
    equivalent_digraph_without_constants = nx.DiGraph(nx_graph)
    constant_graph_nodes = [
        constant_node
        for constant_node in equivalent_digraph_without_constants.nodes()
        if isinstance(constant_node, Constant)
    ]
    equivalent_digraph_without_constants.remove_nodes_from(constant_graph_nodes)

    # Use dict as ordered set
    subgraph_all_nodes: Dict[IntermediateNode, None] = {}
    start_single_int_output_nodes_search_from = terminal_node
    while True:
        float_subgraph_start_nodes, subgraph_all_nodes = find_closest_single_int_output_nodes(
            nx_graph,
            [start_single_int_output_nodes_search_from],
            subgraph_all_nodes,
        )

        variable_start_nodes = [
            start_node
            for start_node in float_subgraph_start_nodes
            if not isinstance(start_node, Constant)
        ]

        # We found a single input variable node
        if len(variable_start_nodes) == 1:
            break

        # Otherwise find a common ancestor as we need a single variable input node
        # lca == lowest common ancestor
        # lca search only works for node pairs in networkx, so we progressively find the ancestors
        # setting the lca by default to one of the nodes we are searching the lca for
        lca = variable_start_nodes.pop()
        while len(variable_start_nodes) > 0 and lca is not None:
            node_to_find_lca = variable_start_nodes.pop()
            lca = nx.algorithms.lowest_common_ancestors.lowest_common_ancestor(
                equivalent_digraph_without_constants, lca, node_to_find_lca, default=None
            )

        # The subgraph cannot be fused as there is no way to find a common ancestor
        if lca is None:
            break

        # if lca is not None, add the nodes from the current start nodes to the lca to
        # subgraph_all_nodes
        subgraph_all_nodes = add_nodes_from_to(
            nx_graph, float_subgraph_start_nodes, {lca: None}, subgraph_all_nodes
        )

        # if the lca is a valid starting node for fusing break
        if is_single_int_output_node(lca):
            # the lca is our new start node
            float_subgraph_start_nodes = {lca: None}
            break

        # otherwise push a little bit further the search (if there is a node just before that has an
        # integer output e.g.)
        start_single_int_output_nodes_search_from = lca

    return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes
def subgraph_nodes_and_values_allow_fusing(
    float_subgraph_start_nodes: Dict[IntermediateNode, None],
    subgraph_all_nodes: Dict[IntermediateNode, None],
    node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]],
) -> bool:
    """Check if a subgraph's values are compatible with fusing.

    A fused subgraph for example only works on an input tensor if the resulting GenericFunction
    can be applied per cell, hence shuffling or tensor shape changes make fusing impossible.

    Args:
        float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the float
            subgraph.
        subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph.
        node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill
            with potential nodes issues preventing fusing.

    Returns:
        bool: True if all inputs and outputs of the nodes in the subgraph are compatible with
            fusing i.e. outputs have the same shapes equal to the variable input.
    """
    node: IntermediateNode
    variable_input_nodes = [
        node for node in float_subgraph_start_nodes if not isinstance(node, Constant)
    ]
    assert_true(
        (num_variable_input_nodes := len(variable_input_nodes)) == 1,
        f"{subgraph_nodes_and_values_allow_fusing.__name__} "
        f"only works for subgraphs with 1 variable input node, got {num_variable_input_nodes}",
    )

    # Nodes that were marked non-fusable at construction time veto the whole subgraph
    explicitly_non_fusable = [
        node
        for node in subgraph_all_nodes
        if isinstance(node, GenericFunction) and not node.op_attributes["fusable"]
    ]
    for node in explicitly_non_fusable:
        node_with_issues_for_fusing[node].append(
            "this node is explicitly marked by the package as non-fusable"
        )
    if len(explicitly_non_fusable) > 0:
        return False

    all_values_are_tensors = all(
        all(isinstance(input_, TensorValue) for input_ in node.inputs)
        and all(isinstance(output, TensorValue) for output in node.outputs)
        for node in subgraph_all_nodes
    )
    if not all_values_are_tensors:
        # This cannot be reached today as scalars are Tensors with shape == () (numpy convention)
        return False  # pragma: no cover

    variable_input_node = variable_input_nodes[0]

    # A cheap check is that the variable input node must have the biggest size, i.e. have the most
    # elements, meaning all constants will broadcast to its shape. This is because the
    # GenericFunction input and output must have the same shape so that it can be applied to each
    # of the input tensor cells.
    # There *may* be a way to manage the other case by simulating the broadcast of the smaller input
    # array and then concatenating/stacking the results. This is not currently doable as we don't
    # have a concatenate operator on the compiler side.
    # TODO: #587 https://github.com/zama-ai/concrete-numpy-internal/issues/587
    variable_input_node_output = cast(TensorValue, variable_input_node.outputs[0])
    variable_input_node_output_size, variable_input_node_output_shape = (
        variable_input_node_output.size,
        variable_input_node_output.shape,
    )
    constant_nodes_with_bigger_size_than_variable_input = [
        constant_input_node
        for constant_input_node in subgraph_all_nodes
        if isinstance(constant_input_node, Constant)
        and cast(TensorValue, constant_input_node.outputs[0]).size > variable_input_node_output_size
    ]
    for bigger_constant_node in constant_nodes_with_bigger_size_than_variable_input:
        bigger_constant_node_shape = cast(TensorValue, bigger_constant_node.outputs[0]).shape
        node_with_issues_for_fusing[bigger_constant_node].append(
            f"this constant node has a bigger shape {bigger_constant_node_shape} "
            f"than the subgraph's input: {variable_input_node_output_shape}"
        )
    if len(constant_nodes_with_bigger_size_than_variable_input) > 0:
        node_with_issues_for_fusing[variable_input_node].append(
            f"input node with shape {variable_input_node_output_shape}"
        )
        return False

    # Now that we know the variable input node has the biggest size we can check shapes are
    # consistent throughout the subgraph: outputs of ir nodes that are not constant must be equal.
    non_constant_nodes = (node for node in subgraph_all_nodes if not isinstance(node, Constant))
    nodes_with_different_output_shapes: Dict[IntermediateNode, List[Tuple[int, Tuple[int, ...]]]] = {}
    for node in non_constant_nodes:
        # BUGFIX: the mismatch filter previously compared `output.shape` to the *node*
        # (`variable_input_node`) instead of its output shape, so every tensor output of a
        # flagged node was reported, not just the mismatched ones.
        mismatched_outputs = [
            (output_idx, output.shape)
            for output_idx, output in enumerate(node.outputs)
            if isinstance(output, TensorValue) and output.shape != variable_input_node_output_shape
        ]
        if mismatched_outputs:
            nodes_with_different_output_shapes[node] = mismatched_outputs

    for node, node_shape_infos in nodes_with_different_output_shapes.items():
        shape_issue_details = "; ".join(
            f"#{output_idx}, {output_shape}" for output_idx, output_shape in node_shape_infos
        )
        node_with_issues_for_fusing[node].append(
            f"output shapes: {shape_issue_details} are not the same as the subgraph's input: "
            f"{variable_input_node_output_shape}"
        )

    all_nodes_have_same_shape_as_input = len(nodes_with_different_output_shapes) == 0
    if not all_nodes_have_same_shape_as_input:
        node_with_issues_for_fusing[variable_input_node].append(
            f"input node with shape {variable_input_node_output_shape}"
        )

    # All non constant node outputs currently need to have the same shape
    return all_nodes_have_same_shape_as_input
def subgraph_has_unique_variable_input(
    float_subgraph_start_nodes: Dict[IntermediateNode, None],
    terminal_node: IntermediateNode,
    node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]],
) -> bool:
    """Check that exactly one of the nodes starting the subgraph is variable.

    Args:
        float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the subgraph.
        terminal_node (IntermediateNode): The node ending the float subgraph.
        node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill
            with potential nodes issues preventing fusing.

    Returns:
        bool: True if exactly one of the start nodes is not a Constant.
    """
    variable_inputs = [
        node for node in float_subgraph_start_nodes if not isinstance(node, Constant)
    ]
    num_variable_inputs = len(variable_inputs)

    # GenericFunction fusing can only handle a single variable input to the float domain
    if num_variable_inputs == 1:
        return True

    # Record the problem both on each offending input and on the terminal node
    for node in variable_inputs:
        node_with_issues_for_fusing[node].append(
            f"one of {num_variable_inputs} variable inputs (can only have 1 for fusing)"
        )
    node_with_issues_for_fusing[terminal_node].append(
        f"cannot fuse here as the subgraph has {num_variable_inputs} variable inputs"
    )
    return False

View File

@@ -1,2 +0,0 @@
"""Representation module to represent source programs."""
from . import intermediate

View File

@@ -1,650 +0,0 @@
"""File containing code to represent source programs operations."""
from abc import ABC, abstractmethod
from collections import deque
from copy import deepcopy
from enum import Enum, unique
from math import floor
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast
import numpy as np
import torch
from loguru import logger
from ..data_types.base import BaseDataType
from ..data_types.dtypes_helpers import (
get_base_value_for_python_constant_data,
mix_values_determine_holding_dtype,
)
from ..data_types.integers import Integer
from ..debugging.custom_assert import assert_true
from ..helpers import indexing_helpers
from ..helpers.formatting_helpers import format_constant
from ..helpers.python_helpers import catch, update_and_return_dict
from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue
IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func"
ALL_IR_NODES: Set[Type] = set()
class IntermediateNode(ABC):
    """Abstract Base Class to derive from to represent source program operations."""

    # Values consumed by the operation, in positional order
    inputs: List[BaseValue]
    # Values produced by the operation; subclasses are responsible for setting this
    outputs: List[BaseValue]
    _n_in: int  # _n_in indicates how many inputs are required to evaluate the IntermediateNode

    def __init__(
        self,
        inputs: Iterable[BaseValue],
        **_kwargs,  # This is to be able to feed arbitrary arguments to IntermediateNodes
    ) -> None:
        self.inputs = list(inputs)
        # Every input must be a BaseValue; note `outputs` is NOT set here
        assert_true(all(isinstance(x, BaseValue) for x in self.inputs))

    # Register all IR nodes in the module-level ALL_IR_NODES set as they are defined
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        ALL_IR_NODES.add(cls)

    def _init_binary(
        self,
        inputs: Iterable[BaseValue],
        mix_values_func: Callable[..., BaseValue] = mix_values_determine_holding_dtype,
        **_kwargs,  # Required to conform to __init__ typing
    ) -> None:
        """__init__ for a binary operation, ie two inputs.

        Subclasses alias this as their ``__init__``; the single output value is
        computed from the two inputs by `mix_values_func`.
        """
        IntermediateNode.__init__(self, inputs)
        assert_true(len(self.inputs) == 2)
        self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])]

    def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str:
        """Get the formatted node (used in formatting operation graphs).

        Args:
            predecessors (List[str]): predecessor names to this node
            _maximum_constant_length (int): desired maximum constant length

        Returns:
            str: the formatted node
        """
        # Default rendering: lowercase class name applied to its predecessors, e.g. "add(x, y)"
        return f"{self.__class__.__name__.lower()}({', '.join(predecessors)})"

    @abstractmethod
    def text_for_drawing(self) -> str:
        """Get the label of the node (used in drawing operation graphs).

        Returns:
            str: the label of the node
        """

    @abstractmethod
    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Simulate what the represented computation would output for the given inputs.

        Args:
            inputs (Dict[int, Any]): Dict containing the inputs for the evaluation

        Returns:
            Any: the result of the computation
        """

    @classmethod
    def n_in(cls) -> int:
        """Return how many inputs the node has.

        Returns:
            int: The number of inputs of the node.
        """
        return cls._n_in

    @classmethod
    def requires_mix_values_func(cls) -> bool:
        """Determine whether the Class requires a mix_values_func to be built.

        Returns:
            bool: True if __init__ expects a mix_values_func argument.
        """
        return cls.n_in() > 1
class Add(IntermediateNode):
    """Node computing the sum of its two inputs."""

    _n_in: int = 2

    __init__ = IntermediateNode._init_binary

    def text_for_drawing(self) -> str:
        """Draw this node as a plus sign."""
        return "+"

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Return the sum of the two evaluated inputs."""
        lhs, rhs = inputs[0], inputs[1]
        return lhs + rhs
class Sub(IntermediateNode):
    """Node computing the difference of its two inputs."""

    _n_in: int = 2

    __init__ = IntermediateNode._init_binary

    def text_for_drawing(self) -> str:
        """Draw this node as a minus sign."""
        return "-"

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Return the first evaluated input minus the second."""
        minuend, subtrahend = inputs[0], inputs[1]
        return minuend - subtrahend
class Mul(IntermediateNode):
    """Node computing the product of its two inputs."""

    _n_in: int = 2

    __init__ = IntermediateNode._init_binary

    def text_for_drawing(self) -> str:
        """Draw this node as a star."""
        return "*"

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Return the product of the two evaluated inputs."""
        lhs, rhs = inputs[0], inputs[1]
        return lhs * rhs
class Input(IntermediateNode):
    """Node representing an input of the program."""

    # Name given to the input in the source program
    input_name: str
    # Position of this input in the program's input list
    program_input_idx: int
    _n_in: int = 1

    def __init__(
        self,
        input_value: BaseValue,
        input_name: str,
        program_input_idx: int,
    ) -> None:
        """Create the node for a single program input."""
        super().__init__((input_value,))
        assert_true(len(self.inputs) == 1)
        self.input_name = input_name
        self.program_input_idx = program_input_idx
        # The input is forwarded unchanged, so the output mirrors the input value
        self.outputs = [deepcopy(self.inputs[0])]

    def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str:
        """Format the node as its input name (inputs have no predecessors)."""
        assert_true(len(predecessors) == 0)
        return self.input_name

    def text_for_drawing(self) -> str:
        """Label the node with its input name when drawing."""
        return self.input_name

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Forward the provided input value unchanged."""
        return inputs[0]
class Constant(IntermediateNode):
    """Node representing a constant of the program."""

    # The raw constant payload, exposed read-only through the `constant_data` property
    _constant_data: Any
    _n_in: int = 0

    def __init__(
        self,
        constant_data: Any,
        get_base_value_for_data_func: Callable[
            [Any], Callable[..., BaseValue]
        ] = get_base_value_for_python_constant_data,
    ) -> None:
        """Store the constant and derive its output value."""
        super().__init__([])

        value_class = get_base_value_for_data_func(constant_data)
        self._constant_data = constant_data
        # Constants are always clear (non-encrypted) values
        self.outputs = [value_class(is_encrypted=False)]

    def text_for_formatting(self, predecessors: List[str], maximum_constant_length: int) -> str:
        """Format the constant, truncated to the requested maximum length."""
        assert_true(len(predecessors) == 0)
        return format_constant(self.constant_data, maximum_constant_length)

    def text_for_drawing(self) -> str:
        """Label the node with its formatted constant data when drawing."""
        return format_constant(self.constant_data)

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Return the stored constant; `inputs` is ignored."""
        return self.constant_data

    @property
    def constant_data(self) -> Any:
        """Return the constant_data stored in the Constant node.

        Returns:
            Any: The constant data that was stored.
        """
        return self._constant_data
class Conv2D(IntermediateNode):
    """Node representing a 2d-convolution."""

    def __init__(
        self,
        inputs: Iterable[BaseValue],
        output_dtype: BaseDataType,
        pads: Union[List[int], Tuple[int, int, int, int]],
        strides: Union[List[int], Tuple[int, int]],
        dilations: Union[List[int], Tuple[int, int]],
    ) -> None:
        """Create a 2d-convolution node.

        Args:
            inputs (Iterable[BaseValue]): (x, weight) or (x, weight, bias) values
            output_dtype (BaseDataType): dtype of the convolution result
            pads (Union[List[int], Tuple[int, int, int, int]]): padding over each
                axis (H_beg, W_beg, H_end, W_end); must currently be all zeros
            strides (Union[List[int], Tuple[int, int]]): stride over height and width
            dilations (Union[List[int], Tuple[int, int]]): dilation over height and width
        """
        # TODO: remove this when padding is supported (#427)
        # Use assert_true like the rest of the module so the check survives `python -O`
        # (a bare `assert` is stripped by the optimizer)
        assert_true(all(pad == 0 for pad in pads), "conv2d doesn't support padding yet")

        super().__init__(inputs)
        self.pads = pads
        self.strides = strides
        self.dilations = dilations
        self._n_in = len(self.inputs)

        assert_true(len(self.inputs) in (2, 3))
        assert_true(
            all(
                isinstance(input_value, TensorValue) and input_value.ndim == 4
                for input_value in self.inputs[:2]
            ),
            f"Conv2D only supports input and weight tensors of 4 dimensions"
            f"({TensorValue.__name__} with ndim == 4)",
        )
        bias = cast(TensorValue, self.inputs[2]) if len(self.inputs) == 3 else None
        if bias is not None:
            assert_true(
                isinstance(bias, TensorValue) and bias.ndim == 1,
                f"Conv2D only supports bias 1 dimension ({TensorValue.__name__} with ndim == 1)",
            )

        x = cast(TensorValue, self.inputs[0])
        weight = cast(TensorValue, self.inputs[1])

        # Compute output shape with the standard convolution output-size arithmetic
        input_n, _, input_h, input_w = x.shape
        weight_f, _, weight_h, weight_w = weight.shape
        pads_h = pads[0] + pads[2]
        pads_w = pads[1] + pads[3]
        output_h = floor((input_h + pads_h - dilations[0] * (weight_h - 1) - 1) / strides[0]) + 1
        output_w = floor((input_w + pads_w - dilations[1] * (weight_w - 1) - 1) / strides[1]) + 1
        output_shape = (input_n, weight_f, output_h, output_w)

        output_value = EncryptedTensor(dtype=output_dtype, shape=output_shape)
        self.outputs = [output_value]

    def text_for_drawing(self) -> str:
        """Label used when drawing this node."""
        return "conv2d"

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Evaluate the convolution on concrete inputs.

        Args:
            inputs (Dict[int, Any]): evaluated inputs, indexed by input position

        Returns:
            Any: the result of the convolution
        """
        # BUGFIX: the message interpolated `{self.n_in}`, which renders the bound classmethod
        # object instead of the expected count; use the instance attribute checked just above.
        assert_true(
            len(inputs) == self._n_in, f"expected {self._n_in} inputs, but got {len(inputs)}"
        )
        x, weight = inputs[0], inputs[1]
        # A missing bias defaults to zeros so torch.conv2d always receives three tensors
        bias = inputs[2] if len(inputs) == 3 else np.zeros(weight.shape[0])
        return self.evaluate_conv2d(x, weight, bias, self.pads, self.strides, self.dilations)

    @staticmethod
    def evaluate_conv2d(
        x: np.ndarray,
        weight: np.ndarray,
        bias: np.ndarray,
        # TODO: use padding when supported (#427)
        _pads: Union[Tuple[int, int, int, int], List[int]],
        strides: Union[Tuple[int, int], List[int]],
        dilations: Union[Tuple[int, int], List[int]],
    ):
        """Evaluate 2D convolution.

        Args:
            x (np.ndarray): Input of shape (NxCxHxW)
            weight (np.ndarray): Weight (kernel) of shape (FxCxHxW)
            bias (np.ndarray): Bias vector of size (F)
            _pads (Union[Tuple[int, int, int, int], List[int]]): Padding over each
                axis (H_beg, W_beg, H_end, W_end); currently ignored since padding is unsupported
            strides (Union[Tuple[int, int], List[int]]): Stride over each
                axis (height and width)
            dilations (Union[Tuple[int, int], List[int]]): Dilation over each
                axis (height and width)

        Returns:
            np.ndarray: Result of the convolution of shape (NxFxHxW)
        """
        # pylint: disable=no-member
        return torch.conv2d(
            torch.tensor(x, dtype=torch.long),
            torch.tensor(weight, dtype=torch.long),
            torch.tensor(bias, dtype=torch.long),
            stride=strides,
            dilation=dilations,
        ).numpy()
        # pylint: enable=no-member
class IndexConstant(IntermediateNode):
    """Node representing a constant indexing in the program.

    What we mean by constant indexing is that the index part of the operation is a constant.
    Here are some examples: `x[2]`, `x[0, 1]`, `y[:, 0]`, `y[3:, :5]`
    The opposite is to have dynamic indexing, which this node does not support.
    Some examples of dynamic indexing are: `x[y]`, `x[y, z]`, `x[:, y]`
    """

    _n_in: int = 1
    # Validated index, normalized to a tuple of ints and slices
    index: Tuple[Union[int, slice], ...]

    def __init__(
        self,
        input_: BaseValue,
        index: Union[int, slice, Tuple[Union[int, slice], ...]],
    ) -> None:
        """Validate the index and derive the resulting output value."""
        super().__init__((input_,))

        indexed = self.inputs[0]
        # Scalars (and non-tensor values) cannot be indexed
        if not isinstance(indexed, TensorValue) or indexed.is_scalar:
            raise TypeError(f"Only tensors can be indexed but you tried to index {self.inputs[0]}")

        self.index = indexing_helpers.validate_index(index)

        result_dtype = indexed.dtype
        result_shape = indexing_helpers.determine_output_shape(indexed.shape, self.index)
        # Indexing preserves encryption status
        value_class = EncryptedTensor if indexed.is_encrypted else ClearTensor
        self.outputs = [value_class(result_dtype, result_shape)]

    def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str:
        """Format the node as ``predecessor[index]``."""
        assert_true(len(predecessors) == 1)
        formatted_elements = [
            indexing_helpers.format_indexing_element(element) for element in self.index
        ]
        return f"{predecessors[0]}[{', '.join(formatted_elements)}]"

    def text_for_drawing(self) -> str:
        """Draw the node as an indexing of a generic "value"."""
        return self.text_for_formatting(["value"], 0)  # 0 is unused

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Apply the stored constant index to the evaluated input."""
        return inputs[0][self.index]
def flood_replace_none_values(table: list):
    """Use a flooding algorithm to replace None values.

    Args:
        table (list): the list in which there are None values that need to be replaced by copies of
            the closest non None data from the list.
    """
    assert_true(any(entry is not None for entry in table))
    # Breadth-first flood: seed the queue with the positions holding real data,
    # then repeatedly copy values into adjacent None slots.
    pending = deque(position for position, entry in enumerate(table) if entry is not None)
    while pending:
        position = pending.popleft()
        filler = table[position]
        # Visit left neighbour first, then right, to match a stable fill order.
        for neighbour in (position - 1, position + 1):
            if 0 <= neighbour < len(table) and table[neighbour] is None:
                table[neighbour] = deepcopy(filler)
                pending.append(neighbour)
    assert_true(all(entry is not None for entry in table))
@unique
class GenericFunctionKind(str, Enum):
    """Enum used to validate the ``op_kind`` of a ``GenericFunction`` node.

    The ``str`` mixin lets instances compare equal to their raw string values.
    """

    TLU = "TLU"
    MEMORY = "Memory"
class GenericFunction(IntermediateNode):
    """Node representing an arbitrary function with a single output, e.g. sin(x)."""
    # The arbitrary_func is not optional but mypy has a long standing bug and is not able to
    # understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
    # arbitrary_func can take more than one argument but during evaluation the input variable will
    # be the first argument passed to it. You can add other constant arguments needed for the proper
    # execution of the function through op_args and op_kwargs.
    arbitrary_func: Optional[Callable]
    op_kind: GenericFunctionKind
    op_name: str
    op_args: Tuple[Any, ...]
    op_kwargs: Dict[str, Any]
    op_attributes: Dict[str, Any]
    _n_in: int
    # TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/798 have a proper
    # attribute system
    DEFAULT_OP_ATTRIBUTES: Dict[str, Any] = {"fusable": True}
    # Internal bookkeeping kwargs excluded from textual dumps (see text_for_formatting).
    KWARGS_IGNORED_IN_FORMATTING: Set[str] = {
        "float_op_subgraph",
        "terminal_node",
    }
    def __init__(
        self,
        inputs: Iterable[BaseValue],
        arbitrary_func: Callable,
        output_value: BaseValue,
        op_kind: Union[str, GenericFunctionKind],
        op_name: Optional[str] = None,
        op_args: Optional[Tuple[Any, ...]] = None,
        op_kwargs: Optional[Dict[str, Any]] = None,
        op_attributes: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Create the node.

        Args:
            inputs (Iterable[BaseValue]): input values; deep-copied so the node owns them.
            arbitrary_func (Callable): function evaluated by this node; the variable inputs are
                passed first, followed by op_args and op_kwargs.
            output_value (BaseValue): the single output value of the node.
            op_kind (Union[str, GenericFunctionKind]): kind of operation, validated through
                GenericFunctionKind.
            op_name (Optional[str]): display name; defaults to the class name.
            op_args (Optional[Tuple[Any, ...]]): extra positional constants for arbitrary_func.
            op_kwargs (Optional[Dict[str, Any]]): extra keyword constants for arbitrary_func.
            op_attributes (Optional[Dict[str, Any]]): overrides merged over
                DEFAULT_OP_ATTRIBUTES.
        """
        super().__init__([deepcopy(i) for i in inputs])
        self._n_in = len(self.inputs)
        self.arbitrary_func = arbitrary_func
        self.op_kind = GenericFunctionKind(op_kind)
        self.op_args = op_args if op_args is not None else ()
        self.op_kwargs = op_kwargs if op_kwargs is not None else {}
        # Start from a copy of the defaults so class-level state is never mutated.
        self.op_attributes = deepcopy(self.DEFAULT_OP_ATTRIBUTES)
        if op_attributes is not None:
            self.op_attributes.update(op_attributes)
        self.outputs = [output_value]
        self.op_name = op_name if op_name is not None else self.__class__.__name__
    def text_for_formatting(self, predecessors: List[str], maximum_constant_length: int) -> str:
        """Format the node as a call expression, e.g. ``sin(x, 2, k=3)``."""
        # concat takes its operands as a single tuple argument, so render them as one.
        if self.op_name == "concat":
            all_args = ["(" + ", ".join(predecessors) + ")"]
        else:
            all_args = deepcopy(predecessors)
        all_args.extend(format_constant(value, maximum_constant_length) for value in self.op_args)
        all_args.extend(
            f"{name}={format_constant(value, maximum_constant_length)}"
            for name, value in self.op_kwargs.items()
            if name not in GenericFunction.KWARGS_IGNORED_IN_FORMATTING
        )
        return f"{self.op_name}({', '.join(all_args)})"
    def text_for_drawing(self) -> str:
        """Return the short label used when drawing the node."""
        return self.op_name
    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Evaluate arbitrary_func on the ordered inputs plus the stored constants."""
        # This is the continuation of the mypy bug workaround
        assert self.arbitrary_func is not None
        ordered_inputs = [inputs[idx] for idx in range(len(inputs))]
        # concat receives its variable inputs bundled into a single tuple argument.
        if self.op_name == "concat":
            return self.arbitrary_func(tuple(ordered_inputs), *self.op_args, **self.op_kwargs)
        return self.arbitrary_func(*ordered_inputs, *self.op_args, **self.op_kwargs)
    def get_table(self, ordered_preds: List[IntermediateNode]) -> List[Any]:
        """Get the table for the current input value of this GenericFunction.

        This function only works if the GenericFunction variable input value is an Integer.
        This function only works if there is a single variable input node among ordered_preds.

        Args:
            ordered_preds (List[IntermediateNode]): List of predecessors of the node. This list must
                contain a single non constant node and any number of Constant nodes.

        Returns:
            List[Any]: The table.
        """
        variable_input_indices = [
            idx for idx, pred in enumerate(ordered_preds) if not isinstance(pred, Constant)
        ]
        assert_true(
            (non_constant_pred_count := len(variable_input_indices)) == 1,
            f"Can only have 1 non constant predecessor in {self.get_table.__name__}, "
            f"got {non_constant_pred_count}",
        )
        variable_input_idx = variable_input_indices[0]
        variable_input_dtype = self.inputs[variable_input_idx].dtype
        # Check the input is an integer to be able to build a table
        # NOTE(review): the message says "unsigned Integer" but only Integer is checked — confirm.
        assert_true(
            isinstance(variable_input_dtype, Integer),
            f"{self.get_table.__name__} only works for an unsigned Integer input",
        )
        variable_input_dtype = cast(Integer, variable_input_dtype)
        input_value_constructor = self.inputs[variable_input_idx].underlying_constructor
        if input_value_constructor is None:
            logger.info(
                f"{self.__class__.__name__} input data type constructor was None, defaulting to int"
            )
            input_value_constructor = int
        min_input_range = variable_input_dtype.min_value()
        max_input_range = variable_input_dtype.max_value() + 1
        # Constant predecessors are evaluated once; the variable slot stays None and is
        # overwritten for each candidate input value below.
        template_input_dict = {
            idx: node.evaluate({}) if isinstance(node, Constant) else None
            for idx, node in enumerate(ordered_preds)
        }
        # `catch` presumably maps evaluation failures to None — TODO confirm; the resulting
        # holes are patched by flood_replace_none_values right after.
        table = [
            catch(
                self.evaluate,
                update_and_return_dict(
                    template_input_dict, {variable_input_idx: input_value_constructor(input_value)}
                ),
            )
            for input_value in range(min_input_range, max_input_range)
        ]
        flood_replace_none_values(table)
        return table
def default_dot_evaluation_function(lhs: Any, rhs: Any) -> Any:
    """Return the default python dot implementation for 1D iterable arrays.

    Args:
        lhs (Any): lhs vector of the dot.
        rhs (Any): rhs vector of the dot.

    Returns:
        Any: the result of the dot operation.
    """
    accumulator = 0
    for left_element, right_element in zip(lhs, rhs):
        accumulator = accumulator + left_element * right_element
    return accumulator
class Dot(IntermediateNode):
    """Return the node representing a dot product."""

    _n_in: int = 2
    # Optional only to work around the same mypy issue as in GenericFunction.
    # Allows plugging specialized implementations from e.g. numpy.
    evaluation_function: Optional[Callable[[Any, Any], Any]]

    def __init__(
        self,
        inputs: Iterable[BaseValue],
        output_dtype: BaseDataType,
        delegate_evaluation_function: Optional[
            Callable[[Any, Any], Any]
        ] = default_dot_evaluation_function,
    ) -> None:
        """Check the operands are scalars/vectors and compute the output value."""
        super().__init__(inputs)
        assert_true(len(self.inputs) == 2)
        assert_true(
            all(
                isinstance(input_value, TensorValue) and input_value.ndim <= 1
                for input_value in self.inputs
            ),
            f"Dot only supports two scalars or vectors ({TensorValue.__name__} with ndim up to 1)",
        )
        lhs_value = cast(TensorValue, self.inputs[0])
        rhs_value = cast(TensorValue, self.inputs[1])
        both_vectors = lhs_value.ndim == 1 and rhs_value.ndim == 1
        if both_vectors:
            assert_true(
                lhs_value.shape[0] == rhs_value.shape[0],
                f"Dot between vectors of shapes {lhs_value.shape} and {rhs_value.shape} "
                f"is not supported",
            )
        output_shape: Tuple[int, ...]
        if both_vectors or (lhs_value.ndim == 0 and rhs_value.ndim == 0):
            # vector . vector or scalar . scalar yields a scalar
            output_shape = ()
        elif lhs_value.ndim == 1:
            # vector . scalar keeps the vector shape
            output_shape = lhs_value.shape
        else:
            # scalar . vector keeps the vector shape
            output_shape = rhs_value.shape
        # The result is encrypted as soon as one operand is encrypted.
        tensor_constructor = (
            EncryptedTensor if (lhs_value.is_encrypted or rhs_value.is_encrypted) else ClearTensor
        )
        self.outputs = [tensor_constructor(output_dtype, output_shape)]
        self.evaluation_function = delegate_evaluation_function

    def text_for_drawing(self) -> str:
        """Return the short label used when drawing the node."""
        return "dot"

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Apply the configured evaluation function on the two operands."""
        # Continuation of the mypy workaround: the function is never actually None.
        assert self.evaluation_function is not None
        return self.evaluation_function(inputs[0], inputs[1])
class MatMul(IntermediateNode):
    """Return the node representing a matrix multiplication."""

    _n_in: int = 2

    def __init__(
        self,
        inputs: Iterable[BaseValue],
        output_dtype: BaseDataType,
        output_shape: Tuple[int, ...],
    ) -> None:
        """Store the two operands and build the output value from dtype and shape."""
        super().__init__(inputs)
        assert_true(len(self.inputs) == 2)
        # The result is encrypted as soon as one of the operands is encrypted.
        if self.inputs[0].is_encrypted or self.inputs[1].is_encrypted:
            result_value = EncryptedTensor(dtype=output_dtype, shape=output_shape)
        else:
            result_value = ClearTensor(dtype=output_dtype, shape=output_shape)
        self.outputs = [result_value]

    def text_for_drawing(self) -> str:
        """Return the short label used when drawing the node."""
        return "matmul"

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Evaluate using the ``@`` operator on the two operands."""
        return inputs[0] @ inputs[1]

View File

@@ -1,8 +0,0 @@
"""Module for basic tracing facilities."""
from .base_tracer import BaseTracer
from .tracing_helpers import (
create_graph_from_output_tracers,
make_input_tracer,
make_input_tracers,
prepare_function_parameters,
)

View File

@@ -1,462 +0,0 @@
"""This file holds the code that can be shared between tracers."""
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union, cast
from ..data_types import Float
from ..data_types.base import BaseDataType
from ..debugging.custom_assert import assert_true
from ..representation.intermediate import (
IR_MIX_VALUES_FUNC_ARG_NAME,
Add,
Constant,
GenericFunction,
IndexConstant,
IntermediateNode,
Mul,
Sub,
)
from ..values import BaseValue, TensorValue
class BaseTracer(ABC):
    """Base class for implementing tracers."""
    # this variable changes the behavior of __eq__ so that it can be traced but still allows to hash
    # BaseTracers when not tracing.
    _is_tracing: bool = False
    inputs: List["BaseTracer"]
    traced_computation: IntermediateNode
    output_idx: int
    output: BaseValue
    _mix_values_func: Callable[..., BaseValue]
    def __init__(
        self,
        inputs: Iterable["BaseTracer"],
        traced_computation: IntermediateNode,
        output_idx: int,
    ) -> None:
        """Store the traced computation and remember which of its outputs this tracer follows.

        Args:
            inputs (Iterable[BaseTracer]): tracers feeding the traced computation
            traced_computation (IntermediateNode): the IR node this tracer wraps
            output_idx (int): index of the node output this tracer represents
        """
        self.inputs = list(inputs)
        self.traced_computation = traced_computation
        self.output_idx = output_idx
        self.output = traced_computation.outputs[output_idx]
    @property
    def shape(self) -> Tuple[int, ...]:
        """Get the shape of the output of the tracer.

        Returns:
            Tuple[int, ...]: the shape of the output
        """
        if isinstance(self.output, TensorValue):
            return self.output.shape
        raise AttributeError(
            f"'{self.__class__.__name__}' object "
            f"with '{self.output}' output "
            f"has no attribute 'shape'"
        )  # pragma: no cover
        # this error cannot be covered because we only have TensorValue for now
    @abstractmethod
    def _supports_other_operand(self, other: Any) -> bool:
        """Check if the current class supports tracing with the other operand.

        Args:
            other (Any): the operand to check compatibility with.

        Returns:
            bool: True if the tracer can manage operations with the other operand.
        """
        return isinstance(other, self.__class__)
    @abstractmethod
    def _make_const_input_tracer(self, constant_data: Any) -> "BaseTracer":
        """Create a tracer for a constant input.

        Args:
            constant_data (Any): The constant to store.

        Returns:
            BaseTracer: The BaseTracer for that constant.
        """
    @classmethod
    def set_is_tracing(cls, is_tracing: bool) -> None:
        """Set whether we are in a tracing context to change __eq__ behavior.

        Args:
            is_tracing (bool): boolean to use to set whether we are tracing
        """
        cls._is_tracing = is_tracing
    @classmethod
    def _get_mix_values_func(cls):
        """Return the class function used to combine operand values into an output value."""
        return cls._mix_values_func
    def _sanitize(self, inp) -> "BaseTracer":
        """Wrap raw constant data in a constant tracer; tracers (and tuples of tracers) pass through."""
        if not isinstance(inp, BaseTracer) and not (
            isinstance(inp, Tuple)  # type: ignore
            and all(isinstance(item, BaseTracer) for item in inp)  # type: ignore
        ):
            return self._make_const_input_tracer(inp)
        return inp
    def instantiate_output_tracers(
        self,
        inputs: Iterable[Union["BaseTracer", Any]],
        computation_to_trace: Type[IntermediateNode],
    ) -> Tuple["BaseTracer", ...]:
        """Instantiate all output BaseTracer for a given computation.

        Args:
            inputs (Iterable[Union[BaseTracer, Any]]): Previous BaseTracer or data used as inputs
                for a new node.
            computation_to_trace (Type[IntermediateNode]): The IntermediateNode class
                to instantiate for the computation being traced

        Returns:
            Tuple[BaseTracer, ...]: A tuple containing a BaseTracer per output function
        """
        # For inputs which are actually constant, first convert into a tracer
        sanitized_inputs = [self._sanitize(inp) for inp in inputs]
        additional_parameters = (
            {IR_MIX_VALUES_FUNC_ARG_NAME: self._get_mix_values_func()}
            if computation_to_trace.requires_mix_values_func()
            else {}
        )
        traced_computation = computation_to_trace(
            (x.output for x in sanitized_inputs),
            **additional_parameters,
        )
        # One tracer per node output, all sharing the same traced computation.
        output_tracers = tuple(
            self.__class__(sanitized_inputs, traced_computation, output_idx)
            for output_idx in range(len(traced_computation.outputs))
        )
        return output_tracers
    def _helper_for_unary_functions(self, op_lambda: Callable, op_name: str) -> "BaseTracer":
        """Trace a unary operator which maintains the shape, which will thus be replaced by a TLU.

        Returns:
            BaseTracer: The output BaseTracer containing the traced function
        """
        first_arg_output = self.output
        assert_true(isinstance(first_arg_output, TensorValue))
        first_arg_output = cast(TensorValue, first_arg_output)
        out_dtype = first_arg_output.dtype
        out_shape = first_arg_output.shape
        generic_function_output_value = TensorValue(
            out_dtype,
            first_arg_output.is_encrypted,
            out_shape,
        )
        traced_computation = GenericFunction(
            inputs=[first_arg_output],
            arbitrary_func=op_lambda,
            output_value=generic_function_output_value,
            op_kind="TLU",
            op_name=f"{op_name}",
        )
        output_tracer = self.__class__(
            [self],
            traced_computation=traced_computation,
            output_idx=0,
        )
        return output_tracer
    def _helper_for_binary_functions_with_one_cst_input(
        self,
        lhs: Union["BaseTracer", Any],
        rhs: Union["BaseTracer", Any],
        op_lambda: Callable,
        op_name: str,
        output_dtype: Optional[BaseDataType] = None,
    ) -> "BaseTracer":
        """Trace a binary operator which maintains the shape, when one input is a constant.

        This function is helpful to convert an operation with two inputs, one of which being a
        constant, into a TLU, while maintaining the constant somewhere in the graph, eg to simplify
        debugging.

        Returns:
            BaseTracer: The output BaseTracer containing the traced function
        """
        # Returning NotImplemented lets python fall back to the reflected operator.
        if isinstance(lhs, BaseTracer):
            if not self._supports_other_operand(rhs):
                return NotImplemented
        elif isinstance(rhs, BaseTracer):
            if not self._supports_other_operand(lhs):
                return NotImplemented
        sanitized_inputs = [self._sanitize(inp) for inp in [lhs, rhs]]
        # One of the inputs has to be constant
        if not (
            isinstance(sanitized_inputs[0].traced_computation, Constant)
            or isinstance(sanitized_inputs[1].traced_computation, Constant)
        ):
            raise NotImplementedError(f"Can't manage binary operator {op_name}")
        sanitized_input_values = [san_input.output for san_input in sanitized_inputs]
        output_value = self._get_mix_values_func()(*sanitized_input_values)
        if output_dtype is not None:
            output_value.dtype = deepcopy(output_dtype)
        traced_computation = GenericFunction(
            inputs=sanitized_input_values,
            arbitrary_func=op_lambda,
            output_value=output_value,
            op_kind="TLU",
            op_name=op_name,
        )
        result_tracer = self.__class__(sanitized_inputs, traced_computation, 0)
        return result_tracer
    def __hash__(self) -> int:
        # Identity-based hash: __eq__ is overloaded for tracing, so hashing relies on object id.
        return id(self)
    def __add__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        if not self._supports_other_operand(other):
            return NotImplemented
        result_tracer = self.instantiate_output_tracers(
            [self, other],
            Add,
        )
        assert_true(len(result_tracer) == 1)
        return result_tracer[0]
    # With that is that x + 1 and 1 + x have the same graph. If we want to keep
    # the order, we need to do as in __rsub__, ie mostly a copy of __sub__ +
    # some changes
    __radd__ = __add__
    def __neg__(self) -> "BaseTracer":
        return 0 - self
    def __pos__(self) -> "BaseTracer":
        # Remark that we don't want to return 'self' since we want the result to be a copy, ie not
        # a reference to the same object
        return 0 + self
    def _lshift(self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]) -> "BaseTracer":
        return self._helper_for_binary_functions_with_one_cst_input(
            lhs, rhs, lambda x, y: x << y, "lshift"
        )
    def __lshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        # x << shift
        return self._lshift(self, other)
    def __rlshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        # cst << x
        return self._lshift(other, self)
    def _rshift(self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]) -> "BaseTracer":
        return self._helper_for_binary_functions_with_one_cst_input(
            lhs, rhs, lambda x, y: x >> y, "rshift"
        )
    def __rshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        # x >> shift
        return self._rshift(self, other)
    def __rrshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        # cst >> x
        return self._rshift(other, self)
    def __gt__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        # x > cst
        return self._helper_for_binary_functions_with_one_cst_input(
            self, other, lambda x, y: x > y, "gt"
        )
    def __ge__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        # x >= cst
        return self._helper_for_binary_functions_with_one_cst_input(
            self, other, lambda x, y: x >= y, "ge"
        )
    def __lt__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        # x < cst
        return self._helper_for_binary_functions_with_one_cst_input(
            self, other, lambda x, y: x < y, "lt"
        )
    def __le__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        # x <= cst
        return self._helper_for_binary_functions_with_one_cst_input(
            self, other, lambda x, y: x <= y, "le"
        )
    def __eq__(self, other: Union["BaseTracer", Any]):
        # x == cst
        # Return the tracer if we are tracing, else return the result of the default __eq__ function
        # allows to have hash capabilities outside of tracing
        return (
            self._helper_for_binary_functions_with_one_cst_input(
                self, other, lambda x, y: x == y, "eq"
            )
            if self._is_tracing
            else self is other
        )
    def __ne__(self, other: Union["BaseTracer", Any]):
        # x != cst
        return self._helper_for_binary_functions_with_one_cst_input(
            self, other, lambda x, y: x != y, "ne"
        )
    def __pow__(self, other: Union["BaseTracer", Any]):
        # x ** cst
        return self._helper_for_binary_functions_with_one_cst_input(
            self, other, lambda x, y: x ** y, "pow"
        )
    def __rpow__(self, other: Union["BaseTracer", Any]):
        # cst ** x
        return self._helper_for_binary_functions_with_one_cst_input(
            other, self, lambda x, y: x ** y, "pow"
        )
    def __mod__(self, other: Union["BaseTracer", Any]):
        # x % cst
        return self._helper_for_binary_functions_with_one_cst_input(
            self, other, lambda x, y: x % y, "mod"
        )
    def __rmod__(self, other: Union["BaseTracer", Any]):
        # cst % x
        return self._helper_for_binary_functions_with_one_cst_input(
            other, self, lambda x, y: x % y, "mod"
        )
    def __and__(self, other: Union["BaseTracer", Any]):
        # x & cst
        return self._helper_for_binary_functions_with_one_cst_input(
            self, other, lambda x, y: x & y, "and"
        )
    def __rand__(self, other: Union["BaseTracer", Any]):
        # cst & x
        return self._helper_for_binary_functions_with_one_cst_input(
            other, self, lambda x, y: x & y, "and"
        )
    def __or__(self, other: Union["BaseTracer", Any]):
        # x | cst
        return self._helper_for_binary_functions_with_one_cst_input(
            self, other, lambda x, y: x | y, "or"
        )
    def __ror__(self, other: Union["BaseTracer", Any]):
        # cst | x
        return self._helper_for_binary_functions_with_one_cst_input(
            other, self, lambda x, y: x | y, "or"
        )
    def __xor__(self, other: Union["BaseTracer", Any]):
        # x ^ cst
        return self._helper_for_binary_functions_with_one_cst_input(
            self, other, lambda x, y: x ^ y, "xor"
        )
    def __rxor__(self, other: Union["BaseTracer", Any]):
        # cst ^ x
        return self._helper_for_binary_functions_with_one_cst_input(
            other, self, lambda x, y: x ^ y, "xor"
        )
    def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        if not self._supports_other_operand(other):
            return NotImplemented
        result_tracer = self.instantiate_output_tracers(
            [self, other],
            Sub,
        )
        assert_true(len(result_tracer) == 1)
        return result_tracer[0]
    def __rsub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        if not self._supports_other_operand(other):
            return NotImplemented
        result_tracer = self.instantiate_output_tracers(
            [other, self],
            Sub,
        )
        assert_true(len(result_tracer) == 1)
        return result_tracer[0]
    def __mul__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        if not self._supports_other_operand(other):
            return NotImplemented
        result_tracer = self.instantiate_output_tracers(
            [self, other],
            Mul,
        )
        assert_true(len(result_tracer) == 1)
        return result_tracer[0]
    # With that is that x * 3 and 3 * x have the same graph. If we want to keep
    # the order, we need to do as in __rmul__, ie mostly a copy of __mul__ +
    # some changes
    __rmul__ = __mul__
    def __abs__(self):
        return self._helper_for_unary_functions(lambda x: x.__abs__(), "__abs__")
    def __invert__(self):
        return self._helper_for_unary_functions(lambda x: x.__invert__(), "__invert__")
    def __getitem__(self, item):
        # Constant indexing only (see IndexConstant); dynamic indices are not supported.
        traced_computation = IndexConstant(self.output, item)
        return self.__class__([self], traced_computation, 0)
    def _truediv(
        self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]
    ) -> "BaseTracer":
        # True division always produces a 64 bit float output value.
        return self._helper_for_binary_functions_with_one_cst_input(
            lhs, rhs, lambda x, y: x / y, "truediv", Float(64)
        )
    def __truediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        return self._truediv(self, other)
    def __rtruediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        return self._truediv(other, self)
    def _floordiv(
        self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]
    ) -> "BaseTracer":
        return self._helper_for_binary_functions_with_one_cst_input(
            lhs, rhs, lambda x, y: x // y, "floordiv"
        )
    def __floordiv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        return self._floordiv(self, other)
    def __rfloordiv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
        return self._floordiv(other, self)

View File

@@ -1,162 +0,0 @@
"""Helper functions for tracing."""
import collections
from contextlib import contextmanager
from inspect import signature
from typing import Callable, Dict, Iterable, List, OrderedDict, Set, Type
import networkx as nx
from networkx.algorithms.dag import is_directed_acyclic_graph
from ..debugging.custom_assert import assert_true
from ..representation.intermediate import Input
from ..values import BaseValue
from .base_tracer import BaseTracer
def make_input_tracers(
    tracer_class: Type[BaseTracer],
    function_parameters: OrderedDict[str, BaseValue],
) -> OrderedDict[str, BaseTracer]:
    """Create tracers for a function's parameters.

    Args:
        tracer_class (Type[BaseTracer]): the class of tracer to create an Input for
        function_parameters (OrderedDict[str, BaseValue]): the dictionary with the parameters names
            and corresponding Values

    Returns:
        OrderedDict[str, BaseTracer]: the dictionary containing the Input Tracers for each parameter
    """
    # Build the result in parameter order; the enumeration index becomes the input index.
    tracers: OrderedDict[str, BaseTracer] = collections.OrderedDict()
    for input_idx, (param_name, param) in enumerate(function_parameters.items()):
        tracers[param_name] = make_input_tracer(tracer_class, param_name, input_idx, param)
    return tracers
def make_input_tracer(
    tracer_class: Type[BaseTracer],
    input_name: str,
    input_idx: int,
    input_value: BaseValue,
) -> BaseTracer:
    """Create a tracer for an input value.

    Args:
        tracer_class (Type[BaseTracer]): the class of tracer to create an Input for
        input_name (str): the name of the input in the traced function
        input_idx (int): the input index in the function parameters
        input_value (BaseValue): the Value that is an input and needs to be wrapped in a
            BaseTracer

    Returns:
        BaseTracer: The BaseTracer for that input value
    """
    input_node = Input(input_value, input_name, input_idx)
    # An input tracer has no predecessors and follows the node's only output (index 0).
    return tracer_class([], input_node, 0)
def prepare_function_parameters(
    function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
) -> OrderedDict[str, BaseValue]:
    """Filter the passed function_parameters to trace function_to_trace.

    Args:
        function_to_trace (Callable): function that will be traced for which parameters are checked
        function_parameters (Dict[str, BaseValue]): parameters given to trace the function

    Raises:
        ValueError: Raised when some parameters are missing to trace function_to_trace

    Returns:
        OrderedDict[str, BaseValue]: filtered function_parameters dictionary, ordered as in the
            signature of function_to_trace
    """
    function_signature = signature(function_to_trace)
    missing_args = function_signature.parameters.keys() - function_parameters.keys()
    if len(missing_args) > 0:
        # BUGFIX: the two f-string fragments used to be joined without a separating space,
        # yielding "parametersthat were not provided" in the error message.
        raise ValueError(
            f"The function '{function_to_trace.__name__}' requires the following parameters "
            f"that were not provided: {', '.join(sorted(missing_args))}"
        )
    # This convoluted way of creating the dict is to ensure key order is maintained: the
    # result follows the signature order, not the order the caller used.
    return collections.OrderedDict(
        (param_name, function_parameters[param_name])
        for param_name in function_signature.parameters.keys()
    )
def create_graph_from_output_tracers(
    output_tracers: Iterable[BaseTracer],
) -> nx.MultiDiGraph:
    """Generate a networkx Directed Graph that represents the computation from a traced function.

    The graph is built by walking backwards from the output tracers towards the inputs; edges
    carry ``input_idx`` and ``output_idx`` attributes so multi-output nodes can be reconnected.

    Args:
        output_tracers (Iterable[BaseTracer]): the output tracers resulting from running the
            function over the proper input tracers

    Returns:
        nx.MultiDiGraph: Directed Graph that is guaranteed to be a DAG containing the ir nodes
            representing the traced program/function
    """
    graph = nx.MultiDiGraph()
    visited_tracers: Set[BaseTracer] = set()
    # use dict as ordered set
    current_tracers = {tracer: None for tracer in output_tracers}
    while current_tracers:
        # use dict as ordered set
        next_tracers: Dict[BaseTracer, None] = {}
        for tracer in current_tracers:
            if tracer in visited_tracers:
                continue
            current_ir_node = tracer.traced_computation
            graph.add_node(current_ir_node)
            for input_idx, input_tracer in enumerate(tracer.inputs):
                input_ir_node = input_tracer.traced_computation
                output_idx = input_tracer.output_idx
                graph.add_node(input_ir_node)
                # Record which producer output feeds which consumer input on the edge.
                graph.add_edge(
                    input_ir_node,
                    current_ir_node,
                    input_idx=input_idx,
                    output_idx=output_idx,
                )
                if input_tracer not in visited_tracers:
                    next_tracers.update({input_tracer: None})
            visited_tracers.add(tracer)
        current_tracers = next_tracers
    assert_true(is_directed_acyclic_graph(graph))
    # Check each edge is unique
    unique_edges = set(
        (pred, succ, tuple((k, v) for k, v in edge_data.items()))
        for pred, succ, edge_data in graph.edges(data=True)
    )
    number_of_edges = len(graph.edges)
    assert_true(len(unique_edges) == number_of_edges)
    return graph
@contextmanager
def tracing_context(tracer_classes: List[Type[BaseTracer]]):
    """Set tracer classes in tracing mode for the duration of the context.

    Args:
        tracer_classes (List[Type[BaseTracer]]): The list of tracers for which we should enable
            tracing.
    """
    try:
        for cls_ in tracer_classes:
            cls_.set_is_tracing(True)
        yield
    finally:
        # Always restore the non-tracing state, even if the body raised.
        for cls_ in tracer_classes:
            cls_.set_is_tracing(False)

View File

@@ -1,5 +0,0 @@
"""Module for value structures."""
from . import tensors
from .base import BaseValue
from .tensors import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue

View File

@@ -1,45 +0,0 @@
"""Module that defines the values in a program."""
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Callable, Optional
from ..data_types.base import BaseDataType
class BaseValue(ABC):
    """Abstract base class to represent any kind of value in a program."""

    dtype: BaseDataType
    _is_encrypted: bool
    # Constructor of the underlying python data, when known (None otherwise).
    underlying_constructor: Optional[Callable]

    def __init__(self, dtype: BaseDataType, is_encrypted: bool) -> None:
        # Keep a private copy of the dtype so later mutations of the caller's
        # object cannot change this value.
        self.dtype = deepcopy(dtype)
        self._is_encrypted = is_encrypted
        self.underlying_constructor = None

    def __repr__(self) -> str:  # pragma: no cover
        return str(self)

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        same_kind = isinstance(other, self.__class__)
        return same_kind and self.dtype == other.dtype

    @property
    def is_encrypted(self) -> bool:
        """Tell whether this value is encrypted.

        Returns:
            bool: True if encrypted False otherwise
        """
        return self._is_encrypted

    @property
    def is_clear(self) -> bool:
        """Tell whether this value is clear (i.e. not encrypted).

        Returns:
            bool: True if clear False otherwise
        """
        return not self._is_encrypted

View File

@@ -1,142 +0,0 @@
"""Module that defines the tensor values in a program."""
from math import prod
from typing import Tuple
from ..data_types.base import BaseDataType
from .base import BaseValue
class TensorValue(BaseValue):
    """Class representing a tensor value."""

    _shape: Tuple[int, ...]
    _ndim: int
    _size: int

    def __init__(
        self,
        dtype: BaseDataType,
        is_encrypted: bool,
        shape: Tuple[int, ...],
    ):
        super().__init__(dtype, is_encrypted)
        # As in numpy, a shape of () means the value is a scalar.
        self._shape = shape
        self._ndim = len(shape)
        self._size = prod(shape) if shape != () else 1

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        return (
            self.shape == other.shape
            and self.ndim == other.ndim
            and self.size == other.size
            and super().__eq__(other)
        )

    def __str__(self) -> str:
        prefix = "Encrypted" if self._is_encrypted else "Clear"
        kind = "Scalar" if self.is_scalar else "Tensor"
        suffix = "" if self.shape == () else f", shape={self.shape}"
        return f"{prefix}{kind}<{str(self.dtype)}{suffix}>"

    @property
    def shape(self) -> Tuple[int, ...]:
        """Return the TensorValue shape property.

        Returns:
            Tuple[int, ...]: The TensorValue shape.
        """
        return self._shape

    @property
    def ndim(self) -> int:
        """Return the TensorValue ndim property.

        Returns:
            int: The TensorValue ndim.
        """
        return self._ndim

    @property
    def size(self) -> int:
        """Return the TensorValue size property.

        Returns:
            int: The TensorValue size.
        """
        return self._size

    @property
    def is_scalar(self) -> bool:
        """Tell whether this value is a scalar (empty shape).

        Returns:
            bool: True if scalar False otherwise
        """
        return self.shape == ()
def make_clear_tensor(
    dtype: BaseDataType,
    shape: Tuple[int, ...],
) -> TensorValue:
    """Create a clear TensorValue.

    Args:
        dtype (BaseDataType): The data type for the tensor.
        shape (Tuple[int, ...]): The tensor shape; () denotes a scalar.

    Returns:
        TensorValue: The corresponding TensorValue.
    """
    return TensorValue(dtype=dtype, is_encrypted=False, shape=shape)
def make_encrypted_tensor(
    dtype: BaseDataType,
    shape: Tuple[int, ...],
) -> TensorValue:
    """Create an encrypted TensorValue.

    Args:
        dtype (BaseDataType): The data type for the tensor.
        shape (Tuple[int, ...]): The tensor shape; () denotes a scalar.

    Returns:
        TensorValue: The corresponding TensorValue.
    """
    return TensorValue(dtype=dtype, is_encrypted=True, shape=shape)
# Public aliases for the tensor factory functions.
ClearTensor = make_clear_tensor
EncryptedTensor = make_encrypted_tensor
def make_clear_scalar(dtype: BaseDataType) -> TensorValue:
    """Create a clear scalar value.

    Args:
        dtype (BaseDataType): The data type for the value.

    Returns:
        TensorValue: The corresponding TensorValue.
    """
    # A scalar is a tensor whose shape is the empty tuple.
    return TensorValue(dtype, False, ())
def make_encrypted_scalar(dtype: BaseDataType) -> TensorValue:
    """Create an encrypted scalar value.

    Args:
        dtype (BaseDataType): The data type for the value.

    Returns:
        TensorValue: The corresponding TensorValue.
    """
    # A scalar is a tensor whose shape is the empty tuple.
    return TensorValue(dtype, True, ())


# Public aliases for the scalar factory functions.
ClearScalar = make_clear_scalar
EncryptedScalar = make_encrypted_scalar

View File

@@ -1,25 +0,0 @@
"""Module for compiling numpy functions to homomorphic equivalents."""
# Import differently to put at the top, and avoid circular import issues
from concrete.numpy.compile import (
compile_numpy_function,
compile_numpy_function_into_op_graph_and_measure_bounds,
)
from concrete.numpy.np_fhe_compiler import NPFHECompiler
from concrete.numpy.tracing import trace_numpy_function
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..common.data_types import (
Float,
Float16,
Float32,
Float64,
Integer,
SignedInteger,
UnsignedInteger,
)
from ..common.debugging import draw_graph, format_operation_graph
from ..common.extensions.convolution import conv2d
from ..common.extensions.multi_table import MultiLookupTable
from ..common.extensions.table import LookupTable
from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue

View File

@@ -1,805 +0,0 @@
"""numpy compilation function."""
import sys
import traceback
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union, cast
import numpy
from ..common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_inputset
from ..common.common_helpers import check_op_graph_is_integer_program
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..common.data_types import Integer
from ..common.debugging import format_operation_graph
from ..common.debugging.custom_assert import assert_true
from ..common.fhe_circuit import FHECircuit
from ..common.mlir.utils import (
check_graph_values_compatibility_with_mlir,
update_bit_width_for_mlir,
)
from ..common.operator_graph import OPGraph
from ..common.optimization.topological import fuse_float_operations
from ..common.representation.intermediate import Add, Constant, GenericFunction, IntermediateNode
from ..common.values import BaseValue, ClearScalar
from ..numpy.tracing import trace_numpy_function
from .np_dtypes_helpers import (
get_base_data_type_for_numpy_or_python_constant_data,
get_base_value_for_numpy_or_python_constant_data,
get_constructor_for_numpy_or_python_constant_data,
)
from .np_inputset_helpers import _check_special_inputset_availability, _generate_random_inputset
from .np_mlir_converter import NPMLIRConverter
_COMPILE_FHE_INSECURE_KEY_CACHE_DIR: Optional[str] = None
def numpy_max_func(lhs: Any, rhs: Any) -> Any:
"""Compute the maximum value between two values which can be numpy classes (e.g. ndarray).
Args:
lhs (Any): lhs value to compute max from.
rhs (Any): rhs value to compute max from.
Returns:
Any: maximum scalar value between lhs and rhs.
"""
return numpy.maximum(lhs, rhs).max()
def numpy_min_func(lhs: Any, rhs: Any) -> Any:
"""Compute the minimum value between two values which can be numpy classes (e.g. ndarray).
Args:
lhs (Any): lhs value to compute min from.
rhs (Any): rhs value to compute min from.
Returns:
Any: minimum scalar value between lhs and rhs.
"""
return numpy.minimum(lhs, rhs).min()
def sanitize_compilation_configuration_and_artifacts(
    compilation_configuration: Optional[CompilationConfiguration] = None,
    compilation_artifacts: Optional[CompilationArtifacts] = None,
) -> Tuple[CompilationConfiguration, CompilationArtifacts]:
    """Return the proper compilation configuration and artifacts.

    Default values are returned if None is passed for each argument.

    Args:
        compilation_configuration (Optional[CompilationConfiguration], optional): the compilation
            configuration to sanitize. Defaults to None.
        compilation_artifacts (Optional[CompilationArtifacts], optional): the compilation artifacts
            to sanitize. Defaults to None.

    Returns:
        Tuple[CompilationConfiguration, CompilationArtifacts]: the tuple of sanitized configuration
            and artifacts.
    """
    # Fall back to the default configuration when the caller gave none
    if compilation_configuration is None:
        compilation_configuration = CompilationConfiguration()

    # Fall back to fresh artifacts so failures can still record information
    if compilation_artifacts is None:
        compilation_artifacts = CompilationArtifacts()

    return compilation_configuration, compilation_artifacts
def get_inputset_to_use(
    function_parameters: Dict[str, BaseValue],
    inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]], str],
    compilation_configuration: CompilationConfiguration,
) -> Union[Iterable[Any], Iterable[Tuple[Any, ...]]]:
    """Get the proper inputset to use for compilation.

    Args:
        function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
            function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
        inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]], str]): The inputset over which
            op_graph is evaluated. It needs to be an iterable on tuples which are of the same length
            than the number of parameters in the function, and in the same order than these same
            parameters
        compilation_configuration (CompilationConfiguration): Configuration object to use during
            compilation

    Returns:
        Union[Iterable[Any], Iterable[Tuple[Any, ...]]]: the inputset to use.
    """
    # A real iterable is used as-is; only the special string form (e.g. "random")
    # needs expansion into a generated inputset
    if not isinstance(inputset, str):
        return inputset

    # Special inputsets are gated behind configuration flags; check first
    _check_special_inputset_availability(inputset, compilation_configuration)
    return _generate_random_inputset(function_parameters, compilation_configuration)
def run_compilation_function_with_error_management(
    compilation_function: Callable,
    compilation_configuration: CompilationConfiguration,
    compilation_artifacts: CompilationArtifacts,
) -> Any:
    """Call compilation_function() and manage exceptions that may occur.

    On an unexpected failure, and if the configuration requests it, the partial
    compilation artifacts plus the full traceback are exported to the artifacts'
    output directory before the exception is re-raised.

    Args:
        compilation_function (Callable): the compilation function to call.
        compilation_configuration (CompilationConfiguration): the current compilation configuration.
        compilation_artifacts (CompilationArtifacts): the current compilation artifacts.

    Returns:
        Any: returns the result of the call to compilation_function
    """
    # Try to compile the function and save partial artifacts on failure
    try:
        # Use context manager to restore numpy error handling on exit; the
        # current settings are captured with geterr() and re-applied by errstate
        with numpy.errstate(**numpy.geterr()):
            return compilation_function()
    except Exception:  # pragma: no cover
        # This branch is reserved for unexpected issues and hence it shouldn't be tested.
        # If it could be tested, we would have fixed the underlying issue.

        # We need to export all the information we have about the compilation
        # If the user wants them to be exported
        if compilation_configuration.dump_artifacts_on_unexpected_failures:
            compilation_artifacts.export()

            # Also persist the traceback next to the exported artifacts
            traceback_path = compilation_artifacts.output_directory.joinpath("traceback.txt")
            with open(traceback_path, "w", encoding="utf-8") as f:
                f.write(traceback.format_exc())

        # Always propagate: this wrapper only records, it never swallows errors
        raise
def _compile_numpy_function_into_op_graph_internal(
    function_to_compile: Callable,
    function_parameters: Dict[str, BaseValue],
    compilation_configuration: CompilationConfiguration,
    compilation_artifacts: CompilationArtifacts,
) -> OPGraph:
    """Compile a function into an OPGraph without evaluating the intermediate nodes bounds.

    Args:
        function_to_compile (Callable): The function to compile
        function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
            function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
        compilation_configuration (CompilationConfiguration): Configuration object to use
            during compilation
        compilation_artifacts (CompilationArtifacts): Artifacts object to fill
            during compilation

    Returns:
        OPGraph: compiled function into a graph, node values are not representative of the values
        that can be observed during execution.
        Use _compile_numpy_function_into_op_graph_and_measure_bounds_internal if you need bounds
        estimation.
    """
    # Check function parameters: every declared parameter must be a BaseValue
    wrong_inputs = {
        inp: function_parameters[inp]
        for inp in function_parameters.keys()
        if not isinstance(function_parameters[inp], BaseValue)
    }
    # Names listed only for the error message below; they are the public
    # helpers users are expected to build parameter values with
    list_of_possible_basevalue = [
        "ClearTensor",
        "EncryptedTensor",
        "ClearScalar",
        "EncryptedScalar",
    ]
    assert_true(
        len(wrong_inputs.keys()) == 0,
        f"wrong type for inputs {wrong_inputs}, needs to be one of {list_of_possible_basevalue}",
    )
    # Add the function to compile as an artifact
    compilation_artifacts.add_function_to_compile(function_to_compile)
    # Add the parameters of function to compile as artifacts
    for name, value in function_parameters.items():
        compilation_artifacts.add_parameter_of_function_to_compile(name, str(value))
    # Trace the function
    op_graph = trace_numpy_function(function_to_compile, function_parameters)
    # Add the initial graph as an artifact
    compilation_artifacts.add_operation_graph("initial", op_graph)
    # Apply topological optimizations if they are enabled
    if compilation_configuration.enable_topological_optimizations:
        # Fuse float operations to have int to int GenericFunction; only needed
        # when the graph is not already a pure integer program
        if not check_op_graph_is_integer_program(op_graph):
            fuse_float_operations(op_graph, compilation_artifacts)
    return op_graph
def compile_numpy_function_into_op_graph(
    function_to_compile: Callable,
    function_parameters: Dict[str, BaseValue],
    compilation_configuration: Optional[CompilationConfiguration] = None,
    compilation_artifacts: Optional[CompilationArtifacts] = None,
) -> OPGraph:
    """Compile a function into an OPGraph.

    Args:
        function_to_compile (Callable): The function to compile
        function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
            function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
        compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
            during compilation
        compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
            during compilation

    Returns:
        OPGraph: compiled function into a graph
    """
    # Replace missing configuration/artifacts with defaults
    compilation_configuration, compilation_artifacts = (
        sanitize_compilation_configuration_and_artifacts(
            compilation_configuration, compilation_artifacts
        )
    )

    def run_compilation() -> OPGraph:
        # Deferred so the error-management wrapper controls execution
        return _compile_numpy_function_into_op_graph_internal(
            function_to_compile,
            function_parameters,
            compilation_configuration,
            compilation_artifacts,
        )

    op_graph = run_compilation_function_with_error_management(
        run_compilation, compilation_configuration, compilation_artifacts
    )

    # for mypy
    assert isinstance(op_graph, OPGraph)
    return op_graph
def _measure_op_graph_bounds_and_update_internal(
    op_graph: OPGraph,
    function_parameters: Dict[str, BaseValue],
    inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
    compilation_configuration: CompilationConfiguration,
    compilation_artifacts: CompilationArtifacts,
    prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None,
    warn_on_inputset_length: bool = True,
) -> Dict[IntermediateNode, Dict[str, Any]]:
    """Measure the intermediate values and update the OPGraph accordingly for the given inputset.

    Args:
        op_graph (OPGraph): the OPGraph for which to measure bounds and update node values.
        function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
            function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
        inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset over which op_graph
            is evaluated. It needs to be an iterable on tuples which are of the same length than the
            number of parameters in the function, and in the same order than these same parameters
        compilation_configuration (CompilationConfiguration): Configuration object to use
            during compilation
        compilation_artifacts (CompilationArtifacts): Artifacts object to fill
            during compilation
        prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional):
            Bounds and samples from a previous run. Defaults to None.
        warn_on_inputset_length (bool, optional): Set to True to get a warning if inputset is not
            long enough. Defaults to True.

    Raises:
        ValueError: Raises an error if the inputset is too small and the compilation configuration
            treats warnings as error.

    Returns:
        Dict[IntermediateNode, Dict[str, Any]]: a dict containing the bounds for each node from
        op_graph, stored with the node as key and a dict with keys "min", "max" and "sample" as
        value.
    """
    # Find bounds with the inputset; min/max are folded with the numpy helpers
    inputset_size, node_bounds_and_samples = eval_op_graph_bounds_on_inputset(
        op_graph,
        inputset,
        compilation_configuration=compilation_configuration,
        min_func=numpy_min_func,
        max_func=numpy_max_func,
        get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
        prev_node_bounds_and_samples=prev_node_bounds_and_samples,
    )

    if warn_on_inputset_length:
        # Check inputset size
        inputset_size_upper_limit = 1

        # this loop will determine the number of possible inputs of the function
        # if a function have a single 3-bit input, for example, inputset_size_upper_limit will be 8
        for parameter_value in function_parameters.values():
            if isinstance(parameter_value.dtype, Integer):
                # multiple parameter bit-widths are multiplied as they can be combined into an input
                inputset_size_upper_limit *= 2 ** parameter_value.dtype.bit_width

                # if the upper limit of the inputset size goes above 10,
                # break the loop as we will require at least 10 inputs in this case
                if inputset_size_upper_limit > 10:
                    break

        # Require up to 10 samples, capped by the number of possible inputs
        minimum_required_inputset_size = min(inputset_size_upper_limit, 10)
        if inputset_size < minimum_required_inputset_size:
            message = (
                f"Provided inputset contains too few inputs "
                f"(it should have had at least {minimum_required_inputset_size} "
                f"but it only had {inputset_size})\n"
            )
            # Either fail hard or warn on stderr, per configuration
            if compilation_configuration.treat_warnings_as_errors:
                raise ValueError(message)
            sys.stderr.write(f"Warning: {message}")

    # Add the bounds as an artifact
    compilation_artifacts.add_final_operation_graph_bounds(node_bounds_and_samples)

    # Update the graph accordingly: after that, we have the compilable graph
    op_graph.update_values_with_bounds_and_samples(
        node_bounds_and_samples,
        get_base_data_type_for_numpy_or_python_constant_data,
        get_constructor_for_numpy_or_python_constant_data,
    )

    return node_bounds_and_samples
def measure_op_graph_bounds_and_update(
    op_graph: OPGraph,
    function_parameters: Dict[str, BaseValue],
    inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]], str],
    compilation_configuration: Optional[CompilationConfiguration] = None,
    compilation_artifacts: Optional[CompilationArtifacts] = None,
    prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None,
    warn_on_inputset_length: bool = True,
) -> Dict[IntermediateNode, Dict[str, Any]]:
    """Measure the intermediate values and update the OPGraph accordingly for the given inputset.

    Args:
        op_graph (OPGraph): the OPGraph for which to measure bounds and update node values.
        function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
            function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
        inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]], str]): The inputset over which
            op_graph is evaluated. It needs to be an iterable on tuples which are of the same length
            than the number of parameters in the function, and in the same order than these same
            parameters
        compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
            during compilation
        compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
            during compilation
        prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional):
            Bounds and samples from a previous run. Defaults to None.
        warn_on_inputset_length (bool, optional): Set to True to get a warning if inputset is not
            long enough. Defaults to True.

    Raises:
        ValueError: Raises an error if the inputset is too small and the compilation configuration
            treats warnings as error.

    Returns:
        Dict[IntermediateNode, Dict[str, Any]]: a dict containing the bounds for each node from
        op_graph, stored with the node as key and a dict with keys "min", "max" and "sample" as
        value.
    """
    # Replace missing configuration/artifacts with defaults
    compilation_configuration, compilation_artifacts = (
        sanitize_compilation_configuration_and_artifacts(
            compilation_configuration, compilation_artifacts
        )
    )
    # Expand the special string inputset ("random") when requested
    inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration)

    def run_measurement() -> Dict[IntermediateNode, Dict[str, Any]]:
        # Deferred so the error-management wrapper controls execution
        return _measure_op_graph_bounds_and_update_internal(
            op_graph,
            function_parameters,
            inputset,
            compilation_configuration,
            compilation_artifacts,
            prev_node_bounds_and_samples,
            warn_on_inputset_length,
        )

    node_bounds = run_compilation_function_with_error_management(
        run_measurement, compilation_configuration, compilation_artifacts
    )

    # for mypy
    assert isinstance(node_bounds, dict)
    return node_bounds
def _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
    function_to_compile: Callable,
    function_parameters: Dict[str, BaseValue],
    inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
    compilation_configuration: CompilationConfiguration,
    compilation_artifacts: CompilationArtifacts,
) -> OPGraph:
    """Compile a function into an OPGraph and evaluate the intermediate nodes bounds.

    Args:
        function_to_compile (Callable): The function to compile
        function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
            function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
        inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset over which op_graph
            is evaluated. It needs to be an iterable on tuples which are of the same length than the
            number of parameters in the function, and in the same order than these same parameters
        compilation_configuration (CompilationConfiguration): Configuration object to use
            during compilation
        compilation_artifacts (CompilationArtifacts): Artifacts object to fill
            during compilation

    Returns:
        OPGraph: compiled function into a graph with estimated bounds in node values.
    """
    # Step 1: trace and optimize the function (no bounds information yet)
    compiled_graph = _compile_numpy_function_into_op_graph_internal(
        function_to_compile,
        function_parameters,
        compilation_configuration,
        compilation_artifacts,
    )

    # Step 2: evaluate the inputset to attach measured bounds to every node
    _measure_op_graph_bounds_and_update_internal(
        compiled_graph,
        function_parameters,
        inputset,
        compilation_configuration,
        compilation_artifacts,
    )

    # Add the final graph as an artifact
    compilation_artifacts.add_operation_graph("final", compiled_graph)

    return compiled_graph
def compile_numpy_function_into_op_graph_and_measure_bounds(
    function_to_compile: Callable,
    function_parameters: Dict[str, BaseValue],
    inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]], str],
    compilation_configuration: Optional[CompilationConfiguration] = None,
    compilation_artifacts: Optional[CompilationArtifacts] = None,
) -> OPGraph:
    """Compile a function into an OPGraph.

    Args:
        function_to_compile (Callable): The function to compile
        function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
            function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
        inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]], str]): The inputset over which
            op_graph is evaluated. It needs to be an iterable on tuples which are of the same length
            than the number of parameters in the function, and in the same order than these same
            parameters. Alternatively, it can be "random" but that's an unstable feature and should
            not be used in production.
        compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
            during compilation
        compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
            during compilation

    Returns:
        OPGraph: compiled function into a graph
    """
    # Replace missing configuration/artifacts with defaults
    compilation_configuration, compilation_artifacts = (
        sanitize_compilation_configuration_and_artifacts(
            compilation_configuration, compilation_artifacts
        )
    )
    # Expand the special string inputset ("random") when requested
    inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration)

    def run_compilation() -> OPGraph:
        # Deferred so the error-management wrapper controls execution
        return _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
            function_to_compile,
            function_parameters,
            inputset,
            compilation_configuration,
            compilation_artifacts,
        )

    op_graph = run_compilation_function_with_error_management(
        run_compilation, compilation_configuration, compilation_artifacts
    )

    # for mypy
    assert isinstance(op_graph, OPGraph)
    return op_graph
# HACK
# TODO: remove this ugly hack when
# https://github.com/zama-ai/concrete-numpy-internal/issues/1001 is done
# TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/1015
def hack_offset_negative_inputs_to_lookup_tables(op_graph: OPGraph) -> None:
    """Hack the op_graph to add offsets to signed inputs to TLUs.

    For every table-lookup node whose single variable input is signed, an
    Add node with a constant offset of abs(min_value) is spliced between the
    input and the TLU so the lookup index becomes non-negative. The graph is
    mutated in place.

    Args:
        op_graph (OPGraph): the OPGraph to hack.
    """
    # Ugly hack to add an offset before entering a TLU if its variable input node has a signed
    # output.
    # This is ugly as this makes hardcoded assumptions about the way bit widths are handled in MLIR.
    # This does not update the TLU input values to allow for proper table generation.
    # Thankfully we are not supposed to touch the op_graph beyond that point
    # Iterate over a list copy: the loop adds/removes edges and nodes as it goes
    for node in list((nx_graph := op_graph.graph).nodes):
        if isinstance(node, GenericFunction) and node.op_kind == "TLU":
            ordered_preds_and_inputs = op_graph.get_ordered_preds_and_inputs_of(node)
            # A TLU is expected to have exactly one non-constant predecessor
            variable_input_indices = [
                idx
                for idx, (pred, _) in enumerate(ordered_preds_and_inputs)
                if not isinstance(pred, Constant)
            ]
            assert_true(len(variable_input_indices) == 1)
            variable_input_idx = variable_input_indices[0]
            variable_input_node = ordered_preds_and_inputs[variable_input_idx][0]
            variable_input_value = variable_input_node.outputs[0]
            variable_input_dtype = variable_input_value.dtype
            assert_true(isinstance(variable_input_dtype, Integer))
            variable_input_dtype = cast(Integer, variable_input_dtype)
            # Unsigned inputs already index the table directly; nothing to do
            if not variable_input_dtype.is_signed:
                continue
            # input_bit_width + 1 to be MLIR compliant
            input_bit_width = variable_input_dtype.bit_width
            mlir_compliant_int_type = Integer(input_bit_width + 1, True)
            # Manually fix the output values to be MLIR compliant
            # offset_constant is set to abs(min_value) for the variable input so that the values
            # [- 2 ** (n - 1); 2 ** (n - 1) - 1] is mapped to [0; 2 ** n - 1], changing the signed
            # TLU to an actual unsigned TLU. The get_table function creates the table from the min
            # value to the max value. As we keep the input value as a signed value, it will be from
            # - 2 ** (n - 1) to 2 ** (n - 1) - 1. Then, the get_table function stores corresponding
            # values in increasing indexes from 0 to 2 ** n - 1. As our signed values have been
            # shifted by 2 ** (n - 1), the table will be usable as-is, without needing any change in
            # the lambda function of the GenericFunction.
            offset_constant = Constant(abs(variable_input_dtype.min_value()))
            offset_constant.outputs[0].dtype = deepcopy(mlir_compliant_int_type)
            add_offset = Add(
                [deepcopy(variable_input_value), ClearScalar(deepcopy(mlir_compliant_int_type))]
            )
            add_offset.outputs[0] = deepcopy(variable_input_value)
            # Rewire: variable_input -> add_offset -> TLU, with the constant
            # offset feeding add_offset's second input
            nx_graph.remove_edge(variable_input_node, node)
            nx_graph.add_edge(variable_input_node, add_offset, input_idx=0, output_idx=0)
            nx_graph.add_edge(offset_constant, add_offset, input_idx=1, output_idx=0)
            nx_graph.add_edge(add_offset, node, input_idx=variable_input_idx, output_idx=0)
def prepare_op_graph_for_mlir(op_graph: OPGraph):
    """Prepare OPGraph for MLIR lowering.

    This includes checking compatibility, changing bit-widths, and modifying lookup tables.

    Args:
        op_graph (OPGraph): The operation graph to prepare

    Returns:
        None
    """
    # Refuse to continue if some node values cannot be expressed in MLIR;
    # the error message highlights the offending nodes in the printed graph
    offending_nodes = check_graph_values_compatibility_with_mlir(op_graph)
    if offending_nodes is not None:
        message = (
            "function you are trying to compile isn't supported for MLIR lowering\n\n"
            + format_operation_graph(op_graph, highlighted_nodes=offending_nodes)
        )
        raise RuntimeError(message)

    # Update bit_width for MLIR
    update_bit_width_for_mlir(op_graph)

    # HACK
    # TODO: remove this ugly hack when
    # https://github.com/zama-ai/concrete-numpy-internal/issues/1001 is done
    # TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/1015
    hack_offset_negative_inputs_to_lookup_tables(op_graph)
def _compile_op_graph_to_fhe_circuit_internal(
    op_graph: OPGraph,
    show_mlir: bool,
    compilation_configuration: CompilationConfiguration,
    compilation_artifacts: CompilationArtifacts,
) -> FHECircuit:
    """Compile the OPGraph to an FHECircuit.

    Args:
        op_graph (OPGraph): the OPGraph to compile.
        show_mlir (bool): determine whether we print the mlir string.
        compilation_configuration (CompilationConfiguration): Configuration object to use
            during compilation
        compilation_artifacts (CompilationArtifacts): Artifacts object to fill
            during compilation

    Raises:
        RuntimeError: if an insecure key cache directory is set but unsafe
            features are not enabled in the compilation configuration.

    Returns:
        FHECircuit: the compiled FHECircuit
    """
    # Check MLIR compatibility, fix bit widths and signed TLU inputs in place
    prepare_op_graph_for_mlir(op_graph)

    # Convert graph to an MLIR representation
    converter = NPMLIRConverter()
    mlir_result = converter.convert(op_graph)

    # Show MLIR representation if requested
    if show_mlir:
        print(f"MLIR which is going to be compiled: \n{mlir_result}")

    # Add MLIR representation as an artifact
    compilation_artifacts.add_final_operation_graph_mlir(mlir_result)

    # Using the insecure key cache requires explicit opt-in to unsafe features
    if _COMPILE_FHE_INSECURE_KEY_CACHE_DIR is not None and not (
        compilation_configuration.use_insecure_key_cache
        and compilation_configuration.enable_unsafe_features
    ):
        # BUGFIX: the message fragments previously concatenated to
        # "...set to True incompilation_configuration"; a space was missing
        # at the end of the second fragment.
        raise RuntimeError(
            f"Unable to use insecure key cache {_COMPILE_FHE_INSECURE_KEY_CACHE_DIR} "
            "as use_insecure_key_cache or enable_unsafe_features are not set to True in "
            "compilation_configuration"
        )

    return FHECircuit(
        op_graph,
        mlir_result,
        unsecure_key_set_cache_path=_COMPILE_FHE_INSECURE_KEY_CACHE_DIR,
        auto_parallelize=compilation_configuration.auto_parallelize,
        loop_parallelize=compilation_configuration.loop_parallelize,
        dataflow_parallelize=compilation_configuration.dataflow_parallelize,
    )
def compile_op_graph_to_fhe_circuit(
    op_graph: OPGraph,
    show_mlir: bool,
    compilation_configuration: Optional[CompilationConfiguration] = None,
    compilation_artifacts: Optional[CompilationArtifacts] = None,
) -> FHECircuit:
    """Compile the OPGraph to an FHECircuit.

    Args:
        op_graph (OPGraph): the OPGraph to compile.
        show_mlir (bool): determine whether we print the mlir string.
        compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
            during compilation
        compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
            during compilation

    Returns:
        FHECircuit: the compiled circuit and the compiled FHECircuit
    """
    # Replace missing configuration/artifacts with defaults
    compilation_configuration, compilation_artifacts = (
        sanitize_compilation_configuration_and_artifacts(
            compilation_configuration, compilation_artifacts
        )
    )

    def run_compilation() -> FHECircuit:
        # Deferred so the error-management wrapper controls execution
        return _compile_op_graph_to_fhe_circuit_internal(
            op_graph, show_mlir, compilation_configuration, compilation_artifacts
        )

    fhe_circuit = run_compilation_function_with_error_management(
        run_compilation, compilation_configuration, compilation_artifacts
    )

    # for mypy
    assert isinstance(fhe_circuit, FHECircuit)
    return fhe_circuit
def _compile_numpy_function_internal(
    function_to_compile: Callable,
    function_parameters: Dict[str, BaseValue],
    inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
    compilation_configuration: CompilationConfiguration,
    compilation_artifacts: CompilationArtifacts,
    show_mlir: bool,
) -> FHECircuit:
    """Compile an homomorphic program (internal part of the API).

    Args:
        function_to_compile (Callable): The function you want to compile
        function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
            function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
        inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset over which op_graph
            is evaluated. It needs to be an iterable on tuples which are of the same length than the
            number of parameters in the function, and in the same order than these same parameters
        compilation_configuration (CompilationConfiguration): Configuration object to use
            during compilation
        compilation_artifacts (CompilationArtifacts): Artifacts object to fill
            during compilation
        show_mlir (bool): if set, the MLIR produced by the converter and which is going
            to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo

    Returns:
        CompilerEngine: engine to run and debug the compiled graph
    """
    # Trace, optimize and measure bounds to obtain the compilable OPGraph
    measured_op_graph = _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
        function_to_compile,
        function_parameters,
        inputset,
        compilation_configuration,
        compilation_artifacts,
    )

    # Lower the measured OPGraph to MLIR and build the FHE circuit
    return _compile_op_graph_to_fhe_circuit_internal(
        measured_op_graph, show_mlir, compilation_configuration, compilation_artifacts
    )
def compile_numpy_function(
    function_to_compile: Callable,
    function_parameters: Dict[str, BaseValue],
    inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]], str],
    compilation_configuration: Optional[CompilationConfiguration] = None,
    compilation_artifacts: Optional[CompilationArtifacts] = None,
    show_mlir: bool = False,
) -> FHECircuit:
    """Compile an homomorphic program (main API).

    Args:
        function_to_compile (Callable): The function to compile
        function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
            function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
        inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]], str]): The inputset over which
            op_graph is evaluated. It needs to be an iterable on tuples which are of the same length
            than the number of parameters in the function, and in the same order than these same
            parameters. Alternatively, it can be "random" but that's an unstable feature and should
            not be used in production.
        compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
            during compilation
        compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
            during compilation
        show_mlir (bool): if set, the MLIR produced by the converter and which is going
            to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo

    Returns:
        CompilerEngine: engine to run and debug the compiled graph
    """
    # Replace missing configuration/artifacts with defaults
    compilation_configuration, compilation_artifacts = (
        sanitize_compilation_configuration_and_artifacts(
            compilation_configuration, compilation_artifacts
        )
    )
    # Expand the special string inputset ("random") when requested
    inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration)

    def run_compilation() -> FHECircuit:
        # Deferred so the error-management wrapper controls execution
        return _compile_numpy_function_internal(
            function_to_compile,
            function_parameters,
            inputset,
            compilation_configuration,
            compilation_artifacts,
            show_mlir,
        )

    fhe_circuit = run_compilation_function_with_error_management(
        run_compilation, compilation_configuration, compilation_artifacts
    )

    # for mypy
    assert isinstance(fhe_circuit, FHECircuit)
    return fhe_circuit

View File

@@ -1,308 +0,0 @@
"""File to hold code to manage package and numpy dtypes."""
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Union
import numpy
from numpy.typing import DTypeLike
from ..common.data_types.base import BaseDataType
from ..common.data_types.dtypes_helpers import (
BASE_DATA_TYPES,
find_type_to_hold_both_lossy,
get_base_data_type_for_python_constant_data,
get_base_value_for_python_constant_data,
get_constructor_for_python_constant_data,
)
from ..common.data_types.floats import Float
from ..common.data_types.integers import Integer
from ..common.debugging.custom_assert import assert_true
from ..common.tracing import BaseTracer
from ..common.values import BaseValue, TensorValue
# Mapping from numpy dtypes to the package's common data types.
# Integer bit widths are derived from the runtime size of each numpy scalar
# type (nbytes * 8), so the C-named aliases (byte, short, intc, int_,
# longlong, ...) stay correct regardless of their platform-dependent width.
NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
    numpy.dtype(numpy.byte): Integer(numpy.byte(0).nbytes * 8, is_signed=True),
    numpy.dtype(numpy.short): Integer(numpy.short(0).nbytes * 8, is_signed=True),
    numpy.dtype(numpy.intc): Integer(numpy.intc(0).nbytes * 8, is_signed=True),
    numpy.dtype(numpy.int_): Integer(numpy.int_(0).nbytes * 8, is_signed=True),
    numpy.dtype(numpy.longlong): Integer(numpy.longlong(0).nbytes * 8, is_signed=True),
    numpy.dtype(numpy.int8): Integer(numpy.int8(0).nbytes * 8, is_signed=True),
    numpy.dtype(numpy.int16): Integer(numpy.int16(0).nbytes * 8, is_signed=True),
    numpy.dtype(numpy.int32): Integer(numpy.int32(0).nbytes * 8, is_signed=True),
    numpy.dtype(numpy.int64): Integer(numpy.int64(0).nbytes * 8, is_signed=True),
    numpy.dtype(numpy.ubyte): Integer(numpy.ubyte(0).nbytes * 8, is_signed=False),
    numpy.dtype(numpy.ushort): Integer(numpy.ushort(0).nbytes * 8, is_signed=False),
    numpy.dtype(numpy.uintc): Integer(numpy.uintc(0).nbytes * 8, is_signed=False),
    numpy.dtype(numpy.uint): Integer(numpy.uint(0).nbytes * 8, is_signed=False),
    numpy.dtype(numpy.ulonglong): Integer(numpy.ulonglong(0).nbytes * 8, is_signed=False),
    numpy.dtype(numpy.uint8): Integer(numpy.uint8(0).nbytes * 8, is_signed=False),
    numpy.dtype(numpy.uint16): Integer(numpy.uint16(0).nbytes * 8, is_signed=False),
    numpy.dtype(numpy.uint32): Integer(numpy.uint32(0).nbytes * 8, is_signed=False),
    numpy.dtype(numpy.uint64): Integer(numpy.uint64(0).nbytes * 8, is_signed=False),
    numpy.dtype(numpy.float16): Float(16),
    numpy.dtype(numpy.float32): Float(32),
    numpy.dtype(numpy.float64): Float(64),
    numpy.dtype(bool): Integer(8, is_signed=False),  # bool mapped to an unsigned 8-bit integer
}
# Tuple of the supported numpy dtypes (keys of the mapping above).
SUPPORTED_NUMPY_DTYPES = tuple(NUMPY_TO_COMMON_DTYPE_MAPPING)
# Corresponding numpy scalar classes (e.g. numpy.int64) for isinstance checks.
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES = tuple(dtype.type for dtype in NUMPY_TO_COMMON_DTYPE_MAPPING)
# Sorted, human-readable list of the supported dtypes for error messages.
SUPPORTED_DTYPE_MSG_STRING = ", ".join(sorted(str(dtype) for dtype in SUPPORTED_NUMPY_DTYPES))
def convert_numpy_dtype_to_base_data_type(numpy_dtype: DTypeLike) -> BaseDataType:
    """Get the corresponding BaseDataType from a numpy dtype.

    Args:
        numpy_dtype (DTypeLike): Any python object that can be translated to a numpy.dtype

    Raises:
        ValueError: If the numpy_dtype is not supported

    Returns:
        BaseDataType: The corresponding data type corresponding to the input numpy_dtype
    """
    # Normalize to a canonical numpy.dtype so equivalent aliases hit the same key
    normalized_numpy_dtype = numpy.dtype(numpy_dtype)

    if normalized_numpy_dtype not in NUMPY_TO_COMMON_DTYPE_MAPPING:
        raise ValueError(
            f"Unsupported numpy type: {numpy_dtype} ({normalized_numpy_dtype}), "
            f"supported numpy types: "
            f"{SUPPORTED_DTYPE_MSG_STRING}"
        )

    # deepcopy to avoid having the value from the dict modified
    return deepcopy(NUMPY_TO_COMMON_DTYPE_MAPPING[normalized_numpy_dtype])
def convert_base_data_type_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.dtype:
    """Convert a BaseDataType to corresponding numpy.dtype.

    Args:
        common_dtype (BaseDataType): dtype to convert to numpy.dtype

    Returns:
        numpy.dtype: The resulting numpy.dtype
    """
    assert_true(
        isinstance(common_dtype, BASE_DATA_TYPES), f"Unsupported common_dtype: {type(common_dtype)}"
    )

    result: numpy.dtype

    if isinstance(common_dtype, Float):
        bit_width = common_dtype.bit_width
        assert_true(
            bit_width in (16, 32, 64),
            "Only converting Float(16), Float(32) or Float(64) is supported",
        )
        # Map each supported bit width directly to its numpy float type
        float_types = {16: numpy.float16, 32: numpy.float32, 64: numpy.float64}
        result = numpy.dtype(float_types[bit_width])
    elif isinstance(common_dtype, Integer):
        if common_dtype.bit_width > 64:
            raise NotImplementedError(
                f"Conversion to numpy dtype only supports Integers with bit_width <= 64, "
                f"got {common_dtype!r}"
            )
        # Integers up to 32 bits fit in 32-bit numpy types, the rest in 64-bit ones
        if common_dtype.is_signed:
            result = numpy.dtype(numpy.int32 if common_dtype.bit_width <= 32 else numpy.int64)
        else:
            result = numpy.dtype(numpy.uint32 if common_dtype.bit_width <= 32 else numpy.uint64)
    return result
def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) -> BaseDataType:
    """Determine the BaseDataType to hold the input constant data.

    Args:
        constant_data (Any): The constant data for which to determine the
            corresponding BaseDataType.

    Returns:
        BaseDataType: The corresponding BaseDataType
    """
    assert_true(
        isinstance(
            constant_data, (int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
        ),
        f"Unsupported constant data of type {type(constant_data)}",
    )

    if not isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)):
        # Plain python scalar: delegate to the python-only helper
        return get_base_data_type_for_python_constant_data(constant_data)

    # numpy data: find a dtype able to hold both the smallest and the largest value
    native_type = float if constant_data.dtype in (numpy.float32, numpy.float64) else int
    lowest = native_type(constant_data.min())
    highest = native_type(constant_data.max())
    return find_type_to_hold_both_lossy(
        get_base_data_type_for_python_constant_data(lowest),
        get_base_data_type_for_python_constant_data(highest),
    )
def get_base_value_for_numpy_or_python_constant_data(
    constant_data: Any,
) -> Callable[..., BaseValue]:
    """Determine the BaseValue and BaseDataType to hold the input constant data.

    This function is able to handle numpy types

    Args:
        constant_data (Any): The constant data for which to determine the
            corresponding BaseValue and BaseDataType.

    Raises:
        AssertionError: If `constant_data` is of an unsupported type.

    Returns:
        Callable[..., BaseValue]: A partial object that will return the proper BaseValue when
            called with `encrypted` as keyword argument (forwarded to the BaseValue `__init__`
            method).
    """
    assert_true(
        not isinstance(constant_data, list),
        "Unsupported constant data of type list "
        "(if you meant to use a list as an array, please use numpy.array instead)",
    )
    assert_true(
        isinstance(
            constant_data,
            (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES),
        ),
        f"Unsupported constant data of type {type(constant_data)}",
    )

    base_dtype = get_base_data_type_for_numpy_or_python_constant_data(constant_data)

    # Arrays keep their shape, numpy scalars become 0-dimensional tensors and
    # python scalars are delegated to the python-only helper.
    if isinstance(constant_data, numpy.ndarray):
        return partial(TensorValue, dtype=base_dtype, shape=constant_data.shape)
    if isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
        return partial(TensorValue, dtype=base_dtype, shape=())
    return get_base_value_for_python_constant_data(constant_data)
def get_numpy_function_output_dtype_and_shape_from_input_dtypes(
    function: Union[numpy.ufunc, Callable],
    input_dtypes: List[BaseDataType],
    input_shapes: List[Tuple[int, ...]],
) -> List[Tuple[numpy.dtype, Tuple[int, ...]]]:
    """Record the output dtype of a numpy function given some input types.

    Works by calling the function once on randomly generated dummy inputs of the
    requested dtypes/shapes and inspecting the outputs.

    Args:
        function (Union[numpy.ufunc, Callable]): The numpy function whose output types need to
            be recorded
        input_dtypes (List[BaseDataType]): BaseDataTypes in the same order as they will be used
            with the function inputs
        input_shapes (List[Tuple[int, ...]]): Shapes in the same order as they will be used with
            the function inputs

    Returns:
        List[Tuple[numpy.dtype, Tuple[int, ...]]]: appropriate (numpy.dtype, shape) tuple for
            each output of the function
    """
    # ufuncs declare their arity; sanity-check the caller provided exactly that many dtypes
    if isinstance(function, numpy.ufunc):
        assert_true(
            (len(input_dtypes) == function.nin),
            f"Expected {function.nin} types, got {len(input_dtypes)}: {input_dtypes}",
        )
    input_numpy_dtypes = [convert_base_data_type_to_numpy_dtype(dtype) for dtype in input_dtypes]
    # Build one dummy value per input: a scalar for empty shapes, an array otherwise.
    # Values are random non-negative magnitudes around [0, 10).
    dummy_inputs = tuple(
        (
            dtype.type(10.0 * numpy.random.random_sample())
            if shape == ()
            else numpy.abs(numpy.random.randn(*shape) * 10.0).astype(dtype)
        )
        for dtype, shape in zip(input_numpy_dtypes, input_shapes)
    )
    # We ignore errors as we may call functions with invalid inputs just to get the proper output
    # dtypes
    with numpy.errstate(all="ignore"):
        outputs = function(*dummy_inputs)
    # Normalize single outputs to a 1-tuple so the return shape is uniform
    if not isinstance(outputs, tuple):
        outputs = (outputs,)
    return [(output.dtype, output.shape) for output in outputs]
def get_numpy_function_output_dtype_and_shape_from_input_tracers(
    func: Union[numpy.ufunc, Callable],
    *input_tracers: BaseTracer,
) -> List[Tuple[BaseDataType, Tuple[int, ...]]]:
    """Determine output dtypes and shapes for a numpy function.

    This function is responsible for determining the output dtype
    of a numpy function after inputs with specific dtypes are passed to it.

    Args:
        func (Union[numpy.ufunc, Callable]): function that is being managed
        *input_tracers (BaseTracer): inputs to the function

    Returns:
        List[Tuple[BaseDataType, Tuple[int, ...]]]: appropriate (BaseDataType, shape) tuple for
            each output of the function
    """
    # Non tensor outputs are modeled with an empty shape
    shapes: List[Tuple[int, ...]] = []
    dtypes: List[BaseDataType] = []
    for tracer in input_tracers:
        tracer_output = tracer.output
        shapes.append(tracer_output.shape if isinstance(tracer_output, TensorValue) else ())
        dtypes.append(tracer_output.dtype)

    recorded = get_numpy_function_output_dtype_and_shape_from_input_dtypes(func, dtypes, shapes)

    # Translate the recorded numpy dtypes back to the framework's common dtypes
    return [(convert_numpy_dtype_to_base_data_type(dtype), shape) for dtype, shape in recorded]
def get_constructor_for_numpy_or_python_constant_data(constant_data: Any):
    """Get the constructor for the numpy constant data or python dtype.

    Args:
        constant_data (Any): The data for which we want to determine the type constructor.
    """
    assert_true(
        isinstance(
            constant_data, (int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
        ),
        f"Unsupported constant data of type {type(constant_data)}",
    )

    # this is required because some operations return python lists from their evaluate function
    # an example of such operation is evaluation of multi tlu during bound measurements
    if isinstance(constant_data, list):
        constant_data = numpy.array(constant_data)

    if isinstance(constant_data, numpy.ndarray):
        # Rebuild an array of the same shape/dtype, filled with the given value
        return lambda x: numpy.full(constant_data.shape, x, dtype=constant_data.dtype)
    if isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
        # numpy scalar: its dtype's scalar type is itself the constructor
        return constant_data.dtype.type
    return get_constructor_for_python_constant_data(constant_data)

View File

@@ -1,309 +0,0 @@
"""Module to hold a user friendly class to compile programs."""
import itertools
from copy import deepcopy
from enum import Enum, unique
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from loguru import logger
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..common.data_types import Integer
from ..common.debugging import draw_graph, format_operation_graph
from ..common.fhe_circuit import FHECircuit
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import IntermediateNode
from ..common.values import BaseValue
from .compile import (
compile_numpy_function_into_op_graph,
compile_op_graph_to_fhe_circuit,
measure_op_graph_bounds_and_update,
)
from .np_dtypes_helpers import get_base_value_for_numpy_or_python_constant_data
@unique
class EncryptedStatus(str, Enum):
    """Enum to validate GenericFunction op_kind."""

    # str mixin lets members be compared directly against plain strings
    CLEAR = "clear"
    ENCRYPTED = "encrypted"
class NPFHECompiler:
    """Class to ease the conversion of a numpy program to its FHE equivalent."""

    # Once this many samples are buffered through __call__, bounds are re-measured
    # automatically and the buffer is flushed to keep memory usage in check.
    INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE: int = 128

    # _function_to_compile is not optional but mypy has a long standing bug and is not able to
    # understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
    _function_to_compile: Optional[Callable]
    # Parameter name -> True when the parameter is encrypted
    _function_parameters_encrypted_status: Dict[str, bool]
    # Samples accumulated via __call__, not yet used for bound measurement
    _current_inputset: List[Union[Any, Tuple]]
    # Lazily traced operation graph (None until the first evaluation/trace)
    _op_graph: Optional[OPGraph]
    # Running bound-measurement state, updated incrementally across inputsets
    _nodes_and_bounds: Dict[IntermediateNode, Dict[str, Any]]
    _compilation_configuration: CompilationConfiguration
    compilation_artifacts: CompilationArtifacts

    def __init__(
        self,
        function_to_compile: Callable,
        function_parameters_encrypted_status: Dict[str, Union[str, EncryptedStatus]],
        compilation_configuration: Optional[CompilationConfiguration] = None,
        compilation_artifacts: Optional[CompilationArtifacts] = None,
    ) -> None:
        self._function_to_compile = function_to_compile
        # Normalize "clear"/"encrypted" (case-insensitive) strings to booleans
        self._function_parameters_encrypted_status = {
            param_name: EncryptedStatus(status.lower()) == EncryptedStatus.ENCRYPTED
            for param_name, status in function_parameters_encrypted_status.items()
        }
        self._current_inputset = []
        self._op_graph = None
        self._nodes_and_bounds = {}
        # Deep copy keeps this instance isolated from caller-side mutation
        self._compilation_configuration = (
            deepcopy(compilation_configuration)
            if compilation_configuration is not None
            else CompilationConfiguration()
        )
        self.compilation_artifacts = (
            compilation_artifacts if compilation_artifacts is not None else CompilationArtifacts()
        )

    @property
    def function_to_compile(self) -> Callable:
        """Get the function to compile.

        Returns:
            Callable: the function to compile.
        """
        # Continuation of mypy bug
        assert self._function_to_compile is not None
        return self._function_to_compile

    @property
    def op_graph(self) -> Optional[OPGraph]:
        """Return a copy of the OPGraph.

        Returns:
            Optional[OPGraph]: the held OPGraph or None
        """
        # To keep consistency with what the user expects, we make sure to evaluate on the remaining
        # inputset values if any before giving a copy of the OPGraph we trace
        self._eval_on_current_inputset()
        return deepcopy(self._op_graph)

    @property
    def compilation_configuration(self) -> Optional[CompilationConfiguration]:
        """Get a copy of the compilation configuration.

        Returns:
            Optional[CompilationConfiguration]: copy of the current compilation configuration.
        """
        return deepcopy(self._compilation_configuration)

    def __call__(self, *args: Any) -> Any:
        """Evaluate the OPGraph corresponding to the function being compiled and return result.

        Returns:
            Any: the result of the OPGraph evaluation.
        """
        # Buffer the sample so bounds can be measured lazily later on; single
        # arguments are stored bare, multiple arguments as a tuple
        self._current_inputset.append(deepcopy(args) if len(args) > 1 else deepcopy(args[0]))
        inferred_args = {
            param_name: get_base_value_for_numpy_or_python_constant_data(val)(
                is_encrypted=is_encrypted
            )
            for (param_name, is_encrypted), val in zip(
                self._function_parameters_encrypted_status.items(), args
            )
        }
        # Flush the buffered samples periodically to bound memory usage
        if len(self._current_inputset) >= self.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE:
            self._eval_on_current_inputset()
        self._trace_op_graph_if_needed(inferred_args)
        # For mypy
        assert self._op_graph is not None
        return self._op_graph(*args)

    def __str__(self) -> str:
        # Flush pending samples so the printed graph reflects everything seen so far
        self._eval_on_current_inputset()
        if self._op_graph is None:
            warning_msg = (
                f"__str__ failed: OPGraph is None, {self.__class__.__name__} "
                "needs evaluation on an inputset"
            )
            logger.warning(warning_msg)
            return warning_msg
        return format_operation_graph(self._op_graph)

    def draw_graph(
        self,
        show: bool = False,
        vertical: bool = True,
        save_to: Optional[Path] = None,
    ) -> Optional[str]:
        """Draws operation graphs and optionally saves/shows the drawing.

        Args:
            show (bool): if set to True, the drawing will be shown using matplotlib
            vertical (bool): if set to True, the orientation will be vertical
            save_to (Optional[Path]): if specified, the drawn graph will be saved to this path;
                else it is saved in a temporary file

        Returns:
            Optional[str]: if OPGraph was not None returns the path as a string of the file where
                the drawn graph is saved
        """
        self._eval_on_current_inputset()
        if self._op_graph is None:
            logger.warning(
                f"{self.draw_graph.__name__} failed: OPGraph is None, {self.__class__.__name__} "
                "needs evaluation on an inputset"
            )
            return None
        return draw_graph(self._op_graph, show, vertical, save_to)

    def eval_on_inputset(
        self,
        inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
        warn_on_inputset_length: bool = False,
    ) -> None:
        """Evaluate the underlying function on an inputset in one go, populates OPGraph and bounds.

        Args:
            inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset on which the
                function should be evaluated.
            warn_on_inputset_length (bool, optional): Set to True to get a warning
                if inputset is not long enough. Defaults to False.
        """
        inputset_iter = iter(inputset)
        try:
            first_sample = next(inputset_iter)
        except StopIteration:
            # Empty inputset: nothing to trace or measure
            return
        inferred_args = {
            param_name: get_base_value_for_numpy_or_python_constant_data(val)(
                is_encrypted=is_encrypted
            )
            for (param_name, is_encrypted), val in zip(
                self._function_parameters_encrypted_status.items(),
                # Single-parameter functions receive bare samples; wrap them so zip works
                first_sample
                if len(self._function_parameters_encrypted_status) > 1
                else (first_sample,),
            )
        }
        self._trace_op_graph_if_needed(inferred_args)
        # For mypy
        assert self._op_graph is not None
        self._patch_op_graph_input_to_accept_any_integer_input()
        # Re-chain the consumed first sample so measurement sees the full inputset
        self._nodes_and_bounds = measure_op_graph_bounds_and_update(
            self._op_graph,
            inferred_args,
            itertools.chain((first_sample,), inputset_iter),
            self._compilation_configuration,
            self.compilation_artifacts,
            self._nodes_and_bounds,
            warn_on_inputset_length,
        )

    def _eval_on_current_inputset(self) -> None:
        """Evaluate OPGraph on _current_inputset."""
        self.eval_on_inputset(self._current_inputset)
        # Samples are consumed exactly once; drop them after measurement
        self._current_inputset.clear()

    def _needs_tracing(self) -> bool:
        """Return whether we need to trace the function and populate the OPGraph."""
        return self._op_graph is None

    def _trace_op_graph_if_needed(self, inferred_args: Dict[str, BaseValue]) -> None:
        """Populate _op_graph with the OPGraph for _function_to_compile."""
        if not self._needs_tracing():
            return
        self._op_graph = compile_numpy_function_into_op_graph(
            self.function_to_compile,
            inferred_args,
            self._compilation_configuration,
            self.compilation_artifacts,
        )

    def _patch_op_graph_input_to_accept_any_integer_input(self) -> None:
        """Patch inputs as we don't know what data we expect."""
        # Can only do that if the OPGraph was created hence the test.
        if self._needs_tracing():
            return
        # For mypy
        assert self._op_graph is not None
        # Cheat on Input nodes to avoid issues during inputset eval as we do not know in advance
        # what the final bit width for the inputs should be
        for node in self._op_graph.input_nodes.values():
            for input_ in node.inputs:
                if isinstance(dtype := (input_.dtype), Integer):
                    dtype.bit_width = 128
                    dtype.is_signed = True

    def compile_on_inputset(
        self,
        inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
        show_mlir: bool = False,
    ) -> FHECircuit:
        """Compile the function on an inputset and get resulting FHECircuit.

        Args:
            inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]):
                The inputset on which the function is evaluated.
            show_mlir (bool, optional, defaults to False):
                The flag to enable printing the MLIR that is being compiled for debugging purposes.

        Returns:
            FHECircuit: the compiled FHECircuit
        """
        self.eval_on_inputset(inputset)
        return self.get_compiled_fhe_circuit(show_mlir)

    def get_compiled_fhe_circuit(self, show_mlir: bool = False) -> FHECircuit:
        """Return a compiled FHECircuit if the instance was evaluated on an inputset.

        Args:
            show_mlir (bool, optional): if set, the MLIR produced by the converter and which is
                going to be sent to the compiler backend is shown on the screen, e.g., for
                debugging or demo. Defaults to False.

        Raises:
            RuntimeError: raised if no inputset was passed to the instance.

        Returns:
            FHECircuit: the compiled FHECircuit
        """
        self._eval_on_current_inputset()
        if self._op_graph is None:
            raise RuntimeError(
                "Requested FHECircuit but no OPGraph was compiled. "
                f"Did you forget to evaluate {self.__class__.__name__} over an inputset?"
            )
        return compile_op_graph_to_fhe_circuit(
            self._op_graph,
            show_mlir,
            self.compilation_configuration,
            self.compilation_artifacts,
        )

View File

@@ -1,59 +0,0 @@
"""Helpers for indexing with numpy values functionality."""
from typing import Any
import numpy
def should_sanitize(indexing_element: Any) -> bool:
    """Decide whether to sanitize an indexing element or not.

    Sanitizing in this context means converting supported numpy values into python values.

    Args:
        indexing_element (Any): the indexing element to decide sanitization.

    Returns:
        bool: True if indexing element should be sanitized otherwise False.
    """
    # numpy integer scalars are always sanitized
    if isinstance(indexing_element, numpy.integer):
        return True

    # 0-dimensional integer arrays behave like scalars and are sanitized as well
    return (
        isinstance(indexing_element, numpy.ndarray)
        and issubclass(indexing_element.dtype.type, numpy.integer)
        and indexing_element.shape == ()
    )
def process_indexing_element(indexing_element: Any) -> Any:
    """Process an indexing element.

    Processing in this context means converting supported numpy values into python values.
    (if they are decided to be sanitized)

    Args:
        indexing_element (Any): the indexing element to sanitize.

    Returns:
        Any: the sanitized indexing element.
    """
    if isinstance(indexing_element, slice):
        # Sanitize each slice component independently, keeping unsupported ones as-is
        components = []
        for component in (indexing_element.start, indexing_element.stop, indexing_element.step):
            components.append(int(component) if should_sanitize(component) else component)
        return slice(*components)

    if should_sanitize(indexing_element):
        return int(indexing_element)

    return indexing_element

View File

@@ -1,157 +0,0 @@
"""Helpers for numpy inputset related functionality."""
import random
from typing import Any, Dict, Iterable, Tuple, Union
import numpy
from ..common.compilation import CompilationConfiguration
from ..common.data_types import Float, Integer
from ..common.values import BaseValue, TensorValue
def _generate_random_integer_scalar(dtype: Integer) -> int:
"""Generate a random integer scalar.
Args:
dtype (Integer): the data type to extract bounds
Returns:
int: a random value within the range [dtype.min_value(), dtype.max_value()]
"""
return random.randint(dtype.min_value(), dtype.max_value())
def _generate_random_integer_tensor(dtype: Integer, shape: Tuple[int, ...]) -> numpy.ndarray:
"""Generate a random integer tensor.
Args:
dtype (Integer): the data type to extract bounds
shape (Tuple[int, ...]): the shape of the generated tensor
Returns:
numpy.ndarray: a random array of the specified shape where each value of it
is within the range [dtype.min_value(), dtype.max_value()]
"""
return numpy.random.randint(
dtype.min_value(),
dtype.max_value() + 1,
size=shape,
dtype=numpy.int64 if dtype.is_signed else numpy.uint64, # type: ignore
)
def _generate_random_float_scalar() -> float:
"""Generate a random float scalar.
Returns:
float: a random value within the range [0, 1)
"""
return random.random()
def _generate_random_float_tensor(dtype: Float, shape: Tuple[int, ...]) -> numpy.ndarray:
"""Generate a random float tensor.
Args:
dtype (Integer): the data type to extract resulting numpy data type
shape (Tuple[int, ...]): the shape of the generated tensor
Returns:
numpy.ndarray: a random array of the specified shape where each value of it
is within the range [0, 1)
"""
result = numpy.random.rand(*shape)
return result.astype(numpy.float32 if dtype.bit_width == 32 else numpy.float64)
def _generate_random_inputset(
    function_parameters: Dict[str, BaseValue],
    compilation_configuration: CompilationConfiguration,
) -> Union[Iterable[Any], Iterable[Tuple[Any, ...]]]:
    """Generate a random inputset from function parameters.

    Using this function is not a good practice since the randomly generated inputset
    might not reflect real world data. We have it to speed up our development workflow
    and we also don't use it in any of our tests, benchmarks, or examples.

    Args:
        function_parameters (Dict[str, BaseValue]): the function parameters
            to extract data types and shapes
        compilation_configuration (CompilationConfiguration): the compilation configuration
            to extract the sample size of the resulting inputset

    Raises:
        ValueError: if the provided function arguments cannot be used for random inputset
            generation

    Returns:
        Union[Iterable[Any], Iterable[Tuple[Any, ...]]]: the inputset
    """
    inputset = []
    for _ in range(compilation_configuration.random_inputset_samples):
        sample = []
        for parameter in function_parameters.values():
            # Only tensor-valued parameters carry the dtype/shape information needed here
            if not isinstance(parameter, TensorValue):
                raise ValueError(f"Random inputset cannot be generated for {parameter} parameters")
            # Dispatch on the parameter dtype and scalar/tensor kind
            if isinstance(parameter.dtype, Integer):
                sample.append(
                    _generate_random_integer_scalar(parameter.dtype)
                    if parameter.is_scalar
                    else _generate_random_integer_tensor(parameter.dtype, parameter.shape)
                )
            elif isinstance(parameter.dtype, Float):
                sample.append(
                    _generate_random_float_scalar()
                    if parameter.is_scalar
                    else _generate_random_float_tensor(parameter.dtype, parameter.shape)
                )
            else:
                raise ValueError(
                    f"Random inputset cannot be generated "
                    f"for parameters of type {parameter.dtype}"
                )
        # Single-parameter functions get bare samples, multi-parameter ones get tuples
        inputset.append(tuple(sample) if len(sample) > 1 else sample[0])
    return inputset
def _check_special_inputset_availability(
inputset: str,
compilation_configuration: CompilationConfiguration,
):
"""Check special inputset is valid and is available.
This function makes sure the provided special inputset is valid and can be used with the
provided compilation configuration.
Currently, the only special inputset is "random" but this can be extended in the future.
Args:
inputset (str): the special inputset to check
compilation_configuration (CompilationConfiguration): the compilation configuration
to check the availability of the provided special inputset
Raises:
ValueError: if the provided special inputset is not valid
RuntimeError: if the provided special inputset is not available
Returns:
None
"""
if inputset != "random":
raise ValueError(
f"inputset can only be an iterable of tuples or the string 'random' "
f"but you specified '{inputset}' for it"
)
if not compilation_configuration.enable_unsafe_features:
raise RuntimeError(
"Random inputset generation is an unsafe feature and should not be used "
"if you don't know what you are doing"
)

View File

@@ -1,85 +0,0 @@
"""Numpy-specific MLIR converter."""
import math
from collections import defaultdict
from itertools import product
from typing import Any, DefaultDict, Dict, List, Tuple
import numpy
from ..common.debugging import assert_true
from ..common.mlir.graph_converter import OPGraphConverter
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import GenericFunction, IntermediateNode
class HashableNPArray:
    """Class to easily manipulate numpy arrays for hashing.

    Note that the hash behavior won't work if the array is modified after being hashed, as it
    will have been hashed to a certain value and the new array content will be hashed to a
    different one.
    """

    # The wrapped numpy array; treat it as immutable once hashed.
    array: numpy.ndarray

    def __init__(self, array: numpy.ndarray) -> None:
        self.array = array

    def __hash__(self) -> int:
        # Hash the raw bytes of the underlying buffer
        raw_bytes = self.array.tobytes()
        return hash(raw_bytes)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, HashableNPArray):
            return False
        return numpy.array_equal(self.array, other.array)
def generate_deduplicated_tables(
    node: GenericFunction, ordered_preds: List[IntermediateNode]
) -> Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]:
    """Deduplicate the tables for the different cells of a tensor if needed.

    Args:
        node (GenericFunction): the node for which to deduplicate the table.
        ordered_preds (List[IntermediateNode]): ordered list of predecessors of the node.

    Returns:
        Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]: A tuple containing tuples whose
            first element is a table and the second element is a list of tuples indicating which
            cells in the tensor will use that table.
    """
    # This is the tensor containing the tables for each cell of the tensor for node
    # (last axis indexes table entries, the other axes index cells)
    node_complete_table = numpy.concatenate(
        tuple(numpy.expand_dims(array, -1) for array in node.get_table(ordered_preds)), axis=-1
    )
    # Iterate over every cell position (all axes except the table axis)
    all_cells_idx = product(*tuple(range(max_val) for max_val in node_complete_table.shape[:-1]))
    # Group cell positions by table content so identical tables are stored only once;
    # HashableNPArray makes the per-cell table usable as a dict key
    tables_to_cell_idx: DefaultDict[HashableNPArray, List[Tuple[int, ...]]] = defaultdict(list)
    idx: Tuple[int, ...]
    all_idx_set = set()
    for idx in all_cells_idx:
        hashable_array = HashableNPArray(node_complete_table[idx])
        tables_to_cell_idx[hashable_array].append(idx)
        all_idx_set.add(idx)
    # Sanity check: every cell of the tensor was assigned to exactly one table
    assert_true(len(all_idx_set) == math.prod(node_complete_table.shape[:-1]))
    return tuple(
        (hashable_array.array, indices) for hashable_array, indices in tables_to_cell_idx.items()
    )
class NPMLIRConverter(OPGraphConverter):
    """Numpy-specific MLIR converter."""

    @staticmethod
    def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]:
        # Collect extra per-node information needed for the MLIR lowering;
        # currently only the deduplicated tables of TLU GenericFunction nodes.
        additional_conversion_info = {}
        # Disable numpy warnings during conversion to avoid issues during TLU generation
        with numpy.errstate(all="ignore"):
            additional_conversion_info["tables"] = {
                node: generate_deduplicated_tables(node, op_graph.get_ordered_preds(node))
                for node in op_graph.graph.nodes()
                if isinstance(node, GenericFunction) and node.op_kind == "TLU"
            }
        return additional_conversion_info

View File

@@ -1,818 +0,0 @@
"""numpy tracing utilities."""
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
import numpy
from numpy.typing import DTypeLike
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
from ..common.debugging.custom_assert import assert_true
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import Constant, Dot, GenericFunction, MatMul
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
from ..common.tracing.tracing_helpers import tracing_context
from ..common.values import BaseValue, TensorValue
from .np_dtypes_helpers import (
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
convert_numpy_dtype_to_base_data_type,
get_base_value_for_numpy_or_python_constant_data,
get_numpy_function_output_dtype_and_shape_from_input_tracers,
)
from .np_indexing_helpers import process_indexing_element
# Python / numpy types that tracing knows how to wrap into constant tracers.
SUPPORTED_TYPES_FOR_TRACING = (int, float, numpy.ndarray) + tuple(
    SUPPORTED_NUMPY_DTYPES_CLASS_TYPES
)

# Constant node factory pre-bound to the numpy-aware base value helper.
NPConstant = partial(
    Constant,
    get_base_value_for_data_func=get_base_value_for_numpy_or_python_constant_data,
)
class NPTracer(BaseTracer):
    """Tracer class for numpy operations."""

    # Strategy used to combine two traced values into a dtype able to hold both.
    _mix_values_func: Callable[..., BaseValue] = mix_values_determine_holding_dtype
    def __array_ufunc__(self, ufunc: numpy.ufunc, method, *args, **kwargs):
        """Catch calls to numpy ufunc and routes them to tracing functions if supported.

        Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch
        """
        # Only plain calls (e.g. numpy.add(x, y)) are traced; reduce/accumulate/etc. are not
        if method == "__call__":
            tracing_func = self.get_tracing_func_for_np_function(ufunc)
            assert_true(
                (len(kwargs) == 0),
                f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc.__name__}",
            )
            # Create constant tracers for args, numpy only passes ufunc.nin args so we can
            # sanitize all of them without issues
            sanitized_args = [self._sanitize(arg) for arg in args]
            return tracing_func(*sanitized_args, **kwargs)
        raise NotImplementedError("Only __call__ method is supported currently")
def __array_function__(self, func, _types, args, kwargs):
"""Catch calls to numpy function in routes them to tracing functions if supported.
Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch
"""
tracing_func = self.get_tracing_func_for_np_function(func)
assert_true(
(tracing_func in [NPTracer.numpy_sum, NPTracer.numpy_concatenate]) or len(kwargs) == 0,
f"**kwargs are currently not supported for numpy functions, func: {func}",
)
# Fixme: Special case to be removed once #772 is done
if func is not numpy.reshape:
sanitized_args = [self._sanitize(arg) for arg in args]
else:
# In numpy.reshape, the second argument is the new shape
sanitized_args = [self._sanitize(args[0]), args[1]]
return tracing_func(self, sanitized_args[0], sanitized_args[1], **kwargs)
return tracing_func(self, *sanitized_args, **kwargs)
def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer":
r"""Support numpy astype feature.
For now it only accepts a dtype and no additional parameters, \*args and
\*\*kwargs are accepted for interface compatibility only
Args:
numpy_dtype (DTypeLike): The object describing a numpy type
Returns:
NPTracer: The NPTracer representing the casting operation
"""
assert_true(
len(args) == 0, f"astype currently only supports tracing without *args, got {args}"
)
assert_true(
(len(kwargs) == 0),
f"astype currently only supports tracing without **kwargs, got {kwargs}",
)
normalized_numpy_dtype = numpy.dtype(numpy_dtype)
output_dtype = convert_numpy_dtype_to_base_data_type(numpy_dtype)
generic_function_output_value = deepcopy(self.output)
generic_function_output_value.dtype = output_dtype
traced_computation = GenericFunction(
inputs=[self.output],
arbitrary_func=lambda x, dtype: x.astype(dtype),
output_value=generic_function_output_value,
op_kind="TLU",
op_kwargs={"dtype": normalized_numpy_dtype.type},
op_name="astype",
)
output_tracer = self.__class__([self], traced_computation=traced_computation, output_idx=0)
return output_tracer
@staticmethod
def get_tracing_func_for_np_function(func: Union[numpy.ufunc, Callable]) -> Callable:
"""Get the tracing function for a numpy function.
Args:
func (Union[numpy.ufunc, Callable]): The numpy function that will be traced
Raises:
NotImplementedError: Raised if the passed function is not supported by NPTracer
Returns:
Callable: the tracing function that needs to be called to trace func
"""
tracing_func: Optional[Callable]
# numpy.invert is not great in term of types it supports, so we've decided not to support it
# and to propose to the user to use numpy.bitwise_not
if func == numpy.invert:
raise RuntimeError(
f"NPTracer does not manage the following func: {func.__name__}. Please replace by "
f"calls to bitwise_xor with appropriate mask"
)
if isinstance(func, numpy.ufunc):
tracing_func = NPTracer.UFUNC_ROUTING.get(func, None)
else:
tracing_func = NPTracer.FUNC_ROUTING.get(func, None)
if tracing_func is None:
raise NotImplementedError(
f"NPTracer does not yet manage the following func: {func.__name__}"
)
return tracing_func
def _supports_other_operand(self, other: Any) -> bool:
return super()._supports_other_operand(other) or isinstance(
other, SUPPORTED_TYPES_FOR_TRACING
)
    def _make_const_input_tracer(self, constant_data: Any) -> "NPTracer":
        # Wrap raw constant data in a Constant node with no predecessors.
        return self.__class__([], NPConstant(constant_data), 0)
    @classmethod
    def _np_operator(
        cls,
        numpy_operator,
        numpy_operator_string,
        numpy_operator_nin,
        *input_tracers: "NPTracer",
        **kwargs,
    ) -> "NPTracer":
        """Trace a numpy operator.

        Args:
            numpy_operator: the numpy callable being traced.
            numpy_operator_string: readable operator name, recorded on the node.
            numpy_operator_nin: expected number of inputs for the operator.
            *input_tracers: tracers feeding the operator.
            **kwargs: keyword arguments forwarded to the operator at evaluation time.

        Returns:
            NPTracer: The output NPTracer containing the traced function
        """
        assert_true(len(input_tracers) == numpy_operator_nin)
        # Probe the operator to record its output dtype and shape
        common_output_dtypes_and_shapes = (
            get_numpy_function_output_dtype_and_shape_from_input_tracers(
                numpy_operator,
                *input_tracers,
            )
        )
        assert_true(len(common_output_dtypes_and_shapes) == 1)
        # This lowering path only supports a single non-constant predecessor
        variable_input_indices = [
            idx
            for idx, pred in enumerate(input_tracers)
            if not isinstance(pred.traced_computation, Constant)
        ]
        assert_true(
            (non_constant_pred_count := len(variable_input_indices)) == 1,
            f"Can only have 1 non constant predecessor in {cls._np_operator.__name__}, "
            f"got {non_constant_pred_count} for operator {numpy_operator}",
        )
        variable_input_idx = variable_input_indices[0]
        output_dtype, output_shape = common_output_dtypes_and_shapes[0]
        # The output's encryption status follows the single variable input
        generic_function_output_value = TensorValue(
            output_dtype,
            input_tracers[variable_input_idx].output.is_encrypted,
            output_shape,
        )
        # Deep copy so later mutation of the caller's kwargs cannot affect the node
        op_kwargs = deepcopy(kwargs)
        traced_computation = GenericFunction(
            inputs=[input_tracer.output for input_tracer in input_tracers],
            arbitrary_func=numpy_operator,
            output_value=generic_function_output_value,
            op_kind="TLU",
            op_kwargs=op_kwargs,
            op_name=numpy_operator_string,
        )
        output_tracer = cls(
            input_tracers,
            traced_computation=traced_computation,
            output_idx=0,
        )
        return output_tracer
def numpy_dot(self, *args: "NPTracer", **_kwargs) -> "NPTracer":
    """Trace numpy.dot.

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    assert_true((num_args := len(args)) == 2, f"dot expects 2 inputs got {num_args}")
    common_output_dtypes_and_shapes = (
        get_numpy_function_output_dtype_and_shape_from_input_tracers(numpy.dot, *args)
    )
    assert_true(len(common_output_dtypes_and_shapes) == 1)
    # dot gets a dedicated IR node (not a GenericFunction); numpy.dot is kept
    # as the delegate so the node can be evaluated directly.
    traced_computation = Dot(
        [input_tracer.output for input_tracer in args],
        common_output_dtypes_and_shapes[0][0],
        delegate_evaluation_function=numpy.dot,
    )
    output_tracer = self.__class__(
        args,
        traced_computation=traced_computation,
        output_idx=0,
    )
    return output_tracer
def clip(self, *args: Union["NPTracer", Any], **kwargs) -> "NPTracer":
    """Trace x.clip.

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    # Turn plain constants into tracers, then delegate to the numpy.clip tracer
    # with self as the clipped operand.
    sanitized = []
    for arg in args:
        sanitized.append(cast(NPTracer, self._sanitize(arg)))
    return self.numpy_clip(self, *sanitized, **kwargs)
def numpy_clip(self, *args: "NPTracer", **kwargs) -> "NPTracer":
    """Trace numpy.clip.

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    # numpy.clip takes 3 inputs: the array plus the min and max bounds.
    return self._np_operator(numpy.clip, "clip", 3, *args, **kwargs)
def dot(self, *args: "NPTracer", **kwargs) -> "NPTracer":
    """Trace x.dot.

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    # Use assert_true (like every sibling tracing method) instead of a bare
    # assert: it reports the actual argument count and is not stripped by -O.
    assert_true((num_args := len(args)) == 1, f"dot expects 1 input got {num_args}")
    arg0 = self._sanitize(args[0])
    assert_true(isinstance(arg0, NPTracer))
    arg0 = cast(NPTracer, arg0)
    return self.numpy_dot(self, arg0, **kwargs)
def transpose(self, *args: "NPTracer", **kwargs) -> "NPTracer":
    """Trace x.transpose.

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    # Method form simply delegates to the numpy.transpose tracer with self as input.
    return self.numpy_transpose(self, *args, **kwargs)
def numpy_transpose(self, *args: "NPTracer", **kwargs) -> "NPTracer":
    """Trace numpy.transpose.

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    assert_true((num_args := len(args)) == 1, f"transpose expect 1 input got {num_args}")
    first_arg_output = args[0].output
    assert_true(isinstance(first_arg_output, TensorValue))
    first_arg_output = cast(TensorValue, first_arg_output)
    # transpose is the identity on scalars and 1D tensors, so only then is the op fusable.
    transpose_is_fusable = first_arg_output.is_scalar or first_arg_output.ndim == 1
    out_dtype = first_arg_output.dtype
    # Transposing reverses the axis order, hence the reversed shape.
    out_shape = first_arg_output.shape[::-1]
    generic_function_output_value = TensorValue(
        out_dtype,
        first_arg_output.is_encrypted,
        out_shape,
    )
    traced_computation = GenericFunction(
        inputs=[first_arg_output],
        arbitrary_func=numpy.transpose,
        output_value=generic_function_output_value,
        op_kind="Memory",
        op_kwargs=deepcopy(kwargs),
        op_name="transpose",
        op_attributes={"fusable": transpose_is_fusable},
    )
    output_tracer = self.__class__(
        args,
        traced_computation=traced_computation,
        output_idx=0,
    )
    return output_tracer
def ravel(self, *args: "NPTracer", **kwargs) -> "NPTracer":
    """Trace x.ravel.

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    # Method form simply delegates to the numpy.ravel tracer with self as input.
    return self.numpy_ravel(self, *args, **kwargs)
def numpy_ravel(self, *args: "NPTracer", **kwargs) -> "NPTracer":
    """Trace numpy.ravel.

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    assert_true((num_args := len(args)) == 1, f"ravel expect 1 input got {num_args}")
    first_arg_output = args[0].output
    assert_true(isinstance(first_arg_output, TensorValue))
    first_arg_output = cast(TensorValue, first_arg_output)
    # ravel is the identity on 1D tensors, so only then is the op fusable.
    ravel_is_fusable = first_arg_output.ndim == 1
    out_dtype = first_arg_output.dtype
    # numpy.prod replaces the deprecated numpy.product alias (removed in NumPy 2.0).
    # Scalars ravel to a single-element 1D tensor; otherwise the output is a flat
    # tensor holding every element.
    out_shape = (1,) if first_arg_output.is_scalar else (numpy.prod(first_arg_output.shape),)
    generic_function_output_value = TensorValue(
        out_dtype,
        first_arg_output.is_encrypted,
        out_shape,
    )
    traced_computation = GenericFunction(
        inputs=[first_arg_output],
        arbitrary_func=numpy.ravel,
        output_value=generic_function_output_value,
        op_kind="Memory",
        op_kwargs=deepcopy(kwargs),
        op_name="ravel",
        op_attributes={"fusable": ravel_is_fusable},
    )
    output_tracer = self.__class__(
        args,
        traced_computation=traced_computation,
        output_idx=0,
    )
    return output_tracer
def reshape(self, newshape: Tuple[Any, ...], **kwargs) -> "NPTracer":
    """Trace x.reshape.

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    # Method form simply delegates to the numpy.reshape tracer with self as input.
    return self.numpy_reshape(self, newshape, **kwargs)
def numpy_reshape(self, arg0: "NPTracer", arg1: Tuple[Any, ...], **kwargs) -> "NPTracer":
    """Trace numpy.reshape.

    Args:
        arg0 (NPTracer): tracer holding the tensor being reshaped.
        arg1 (Tuple[Any, ...]): the requested new shape (may contain -1).

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    # FIXME: #772, restore reshape(self, *args, **kwargs) signature when possible, with mypy
    # types
    # FIXME: #772, restore
    # assert_true((num_args := len(args)) == 2, f"reshape expect 2 input got {num_args}")
    # when possible
    assert_true((num_kwargs := len(kwargs)) == 0, f"reshape expect 0 kwargs got {num_kwargs}")
    first_arg_output = arg0.output
    assert_true(isinstance(first_arg_output, TensorValue))
    first_arg_output = cast(TensorValue, first_arg_output)
    try:
        # calculate a newshape using numpy to handle edge cases such as `-1`s within new shape
        newshape = numpy.zeros(first_arg_output.shape).reshape(arg1).shape
    except Exception as error:
        raise ValueError(
            f"shapes are not compatible (old shape {first_arg_output.shape}, new shape {arg1})"
        ) from error
    # reshape is the identity when the resolved shape equals the input shape;
    # only then is the op fusable.
    reshape_is_fusable = newshape == first_arg_output.shape
    out_dtype = first_arg_output.dtype
    out_shape = newshape
    generic_function_output_value = TensorValue(
        out_dtype,
        first_arg_output.is_encrypted,
        out_shape,
    )
    # The resolved (fully concrete) newshape is recorded on the node, not the raw arg1.
    traced_computation = GenericFunction(
        inputs=[first_arg_output],
        arbitrary_func=numpy.reshape,
        output_value=generic_function_output_value,
        op_kind="Memory",
        op_kwargs={"newshape": newshape},
        op_name="reshape",
        op_attributes={"fusable": reshape_is_fusable},
    )
    output_tracer = self.__class__(
        [arg0],
        traced_computation=traced_computation,
        output_idx=0,
    )
    return output_tracer
def flatten(self, *args: "NPTracer", **kwargs) -> "NPTracer":
    """Trace x.flatten.

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    assert_true((num_args := len(args)) == 0, f"flatten expect 0 input got {num_args}")
    first_arg_output = self.output
    assert_true(isinstance(first_arg_output, TensorValue))
    first_arg_output = cast(TensorValue, first_arg_output)
    # flatten is the identity on 1D tensors, so only then is the op fusable.
    flatten_is_fusable = first_arg_output.ndim == 1
    out_dtype = first_arg_output.dtype
    # numpy.prod replaces the deprecated numpy.product alias (removed in NumPy 2.0).
    # Scalars flatten to a single-element 1D tensor; otherwise the output is a flat
    # tensor holding every element.
    out_shape = (1,) if first_arg_output.is_scalar else (numpy.prod(first_arg_output.shape),)
    generic_function_output_value = TensorValue(
        out_dtype,
        first_arg_output.is_encrypted,
        out_shape,
    )
    traced_computation = GenericFunction(
        inputs=[first_arg_output],
        arbitrary_func=lambda x: x.flatten(),
        output_value=generic_function_output_value,
        op_kind="Memory",
        op_kwargs=deepcopy(kwargs),
        op_name="flatten",
        op_attributes={"fusable": flatten_is_fusable},
    )
    output_tracer = self.__class__(
        [self],
        traced_computation=traced_computation,
        output_idx=0,
    )
    return output_tracer
def numpy_sum(self, inp: "NPTracer", **kwargs) -> "NPTracer":
    """Trace numpy.sum.

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    input_value = inp.output

    def supported(value):
        # Only non-scalar encrypted tensors can be summed.
        if not value.is_encrypted or not isinstance(value, TensorValue):
            return False
        return cast(TensorValue, value).shape != ()

    if not supported(input_value):
        raise ValueError(
            f"only encrypted tensor sum is supported but you tried to sum {input_value}"
        )

    try:
        # calculate a newshape using numpy to handle all cases
        newshape = numpy.sum(numpy.zeros(input_value.shape), **kwargs).shape  # type: ignore
    except Exception as error:
        kwarg_description = ", ".join(
            "=".join([key, str(value)]) for key, value in kwargs.items()
        )
        raise ValueError(f"invalid sum on {input_value} with {kwarg_description}") from error

    output_value = TensorValue(
        input_value.dtype,
        input_value.is_encrypted,
        newshape,
    )
    traced_computation = GenericFunction(
        inputs=[input_value],
        arbitrary_func=numpy.sum,
        output_value=output_value,
        op_kind="Memory",
        op_kwargs=kwargs,
        op_name="sum",
        op_attributes={"fusable": False},
    )
    return self.__class__(
        [inp],
        traced_computation=traced_computation,
        output_idx=0,
    )
def numpy_concatenate(self, inputs: Tuple["NPTracer", ...], **kwargs) -> "NPTracer":
    """Trace numpy.concatenate.

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    input_values = [tracer.output for tracer in inputs]

    def supported(values):
        # Only non-scalar encrypted tensors can be concatenated.
        if any(
            not value.is_encrypted or not isinstance(value, TensorValue) for value in values
        ):
            return False
        values = [cast(TensorValue, value) for value in values]
        if any(value.shape == () for value in values):
            return False
        return True

    if not supported(input_values):
        raise ValueError(
            f"only encrypted tensor concatenation is supported "
            f"but you tried to concatenate "
            f"{', '.join(str(input_value) for input_value in input_values)}"
        )
    input_tensor_values = [cast(TensorValue, value) for value in input_values]
    try:
        # calculate a newshape using numpy to handle all cases
        sample = tuple(numpy.zeros(input_value.shape) for input_value in input_tensor_values)
        newshape = numpy.concatenate(sample, **kwargs).shape
    except Exception as error:
        kwarg_info = ""
        if len(kwargs) != 0:
            kwarg_info += " with "
            kwarg_info += ", ".join(
                "=".join([key, str(value)]) for key, value in kwargs.items()
            )
        raise ValueError(
            f"invalid concatenation of "
            f"{', '.join(str(input_value) for input_value in input_values)}{kwarg_info}"
        ) from error
    # The output takes the dtype/encryption of the first input.
    output_value = TensorValue(
        input_tensor_values[0].dtype,
        input_tensor_values[0].is_encrypted,
        newshape,
    )
    traced_computation = GenericFunction(
        inputs=input_values,
        arbitrary_func=numpy.concatenate,
        output_value=output_value,
        op_kind="Memory",
        op_kwargs=kwargs,
        op_name="concat",
        op_attributes={"fusable": False},
    )
    output_tracer = self.__class__(
        list(inputs),
        traced_computation=traced_computation,
        output_idx=0,
    )
    return output_tracer
def __getitem__(self, item):
    """Trace indexing, normalizing every indexing element first."""
    if isinstance(item, tuple):
        processed = tuple(map(process_indexing_element, item))
    else:
        processed = process_indexing_element(item)
    # Delegate the actual graph construction to the generic BaseTracer indexing.
    return BaseTracer.__getitem__(self, processed)
def __matmul__(self, other):
    """Trace numpy.matmul."""
    # Route `tracer @ other` through __array_ufunc__ so it follows the same
    # dispatch path as numpy.matmul(tracer, other).
    return self.__array_ufunc__(numpy.matmul, "__call__", self, other)
# Supported functions are either univariate or bivariate for which one of the two
# sources is a constant
#
# numpy.add, numpy.multiply and numpy.subtract are not there since already managed
# by leveled operations
#
# numpy.conjugate is not there since working on complex numbers
#
# numpy.isnat is not there since it is about timings
#
# numpy.divmod, numpy.modf and numpy.frexp are not there since they output two values
#
# numpy.invert (also known as numpy.bitwise_not) is not here, because it has a strange
# input type. We ask the user to use bitwise_xor instead
LIST_OF_SUPPORTED_UFUNC: List[numpy.ufunc] = [
    numpy.absolute,
    numpy.arccos,
    numpy.arccosh,
    numpy.arcsin,
    numpy.arcsinh,
    numpy.arctan,
    numpy.arctan2,
    numpy.arctanh,
    numpy.bitwise_and,
    numpy.bitwise_or,
    numpy.bitwise_xor,
    numpy.cbrt,
    numpy.ceil,
    numpy.copysign,
    numpy.cos,
    numpy.cosh,
    numpy.deg2rad,
    numpy.degrees,
    numpy.equal,
    numpy.exp,
    numpy.exp2,
    numpy.expm1,
    numpy.fabs,
    numpy.float_power,
    numpy.floor,
    numpy.floor_divide,
    numpy.fmax,
    numpy.fmin,
    numpy.fmod,
    numpy.gcd,
    numpy.greater,
    numpy.greater_equal,
    numpy.heaviside,
    numpy.hypot,
    numpy.isfinite,
    numpy.isinf,
    numpy.isnan,
    numpy.lcm,
    numpy.ldexp,
    numpy.left_shift,
    numpy.less,
    numpy.less_equal,
    numpy.log,
    numpy.log10,
    numpy.log1p,
    numpy.log2,
    numpy.logaddexp,
    numpy.logaddexp2,
    numpy.logical_and,
    numpy.logical_not,
    numpy.logical_or,
    numpy.logical_xor,
    numpy.maximum,
    numpy.minimum,
    numpy.negative,
    numpy.nextafter,
    numpy.not_equal,
    numpy.positive,
    numpy.power,
    numpy.rad2deg,
    numpy.radians,
    numpy.reciprocal,
    numpy.remainder,
    numpy.right_shift,
    numpy.rint,
    numpy.sign,
    numpy.signbit,
    numpy.sin,
    numpy.sinh,
    numpy.spacing,
    numpy.sqrt,
    numpy.square,
    numpy.tan,
    numpy.tanh,
    numpy.true_divide,
    numpy.trunc,
]
# We build UFUNC_ROUTING dynamically after the creation of the class,
# because of some limits of python or our inability to do it properly
# in the class with techniques which are compatible with the different
# coding checks we use
UFUNC_ROUTING: Dict[numpy.ufunc, Callable] = {}
# Non-ufunc numpy functions get explicit per-function tracing handlers.
FUNC_ROUTING: Dict[Callable, Callable] = {
    numpy.dot: numpy_dot,
    numpy.transpose: numpy_transpose,
    numpy.reshape: numpy_reshape,
    numpy.ravel: numpy_ravel,
    numpy.clip: numpy_clip,
    numpy.sum: numpy_sum,
    numpy.concatenate: numpy_concatenate,
}
def _get_unary_fun(function: numpy.ufunc):
    """Wrap _np_operator in a lambda to populate NPTracer.UFUNC_ROUTING (unary case)."""
    # We have to access this method to be able to build NPTracer.UFUNC_ROUTING
    # dynamically
    # pylint: disable=protected-access
    return lambda *input_tracers, **kwargs: NPTracer._np_operator(
        function, f"{function.__name__}", 1, *input_tracers, **kwargs
    )
    # pylint: enable=protected-access
def _get_binary_fun(function: numpy.ufunc):
    """Wrap _np_operator in a lambda to populate NPTracer.UFUNC_ROUTING (binary case)."""
    # We have to access this method to be able to build NPTracer.UFUNC_ROUTING
    # dynamically
    # pylint: disable=protected-access
    return lambda *input_tracers, **kwargs: NPTracer._np_operator(
        function, f"{function.__name__}", 2, *input_tracers, **kwargs
    )
    # pylint: enable=protected-access
# We are populating NPTracer.UFUNC_ROUTING dynamically: unary ufuncs first,
# then binary ufuncs.
NPTracer.UFUNC_ROUTING = {
    fun: _get_unary_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC if fun.nin == 1
}
NPTracer.UFUNC_ROUTING.update(
    {fun: _get_binary_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC if fun.nin == 2}
)
# Sanity check: every supported ufunc must be unary or binary, otherwise it got
# silently dropped from the routing tables above.
list_of_not_supported = [
    (ufunc.__name__, ufunc.nin)
    for ufunc in NPTracer.LIST_OF_SUPPORTED_UFUNC
    if ufunc.nin not in [1, 2]
]
assert_true(len(list_of_not_supported) == 0, f"Not supported nin's, {list_of_not_supported}")
del list_of_not_supported
# We are adding initial support for `np.array(...)` +,-,* `BaseTracer`
# (note that this is not the proper complete handling of these functions)
def _on_numpy_add(lhs, rhs):
return lhs.__add__(rhs)
def _on_numpy_subtract(lhs, rhs):
return lhs.__sub__(rhs)
def _on_numpy_multiply(lhs, rhs):
return lhs.__mul__(rhs)
def _on_numpy_matmul(lhs: NPTracer, rhs: NPTracer):
    """Trace numpy.matmul between two tracers into a dedicated MatMul IR node."""
    common_output_dtypes_and_shapes = get_numpy_function_output_dtype_and_shape_from_input_tracers(
        numpy.matmul, lhs, rhs
    )
    assert_true(len(common_output_dtypes_and_shapes) == 1)
    output_shape = common_output_dtypes_and_shapes[0][1]
    traced_computation = MatMul(
        [lhs.output, rhs.output],
        common_output_dtypes_and_shapes[0][0],
        output_shape,
    )
    return NPTracer([lhs, rhs], traced_computation, output_idx=0)
# Route the leveled arithmetic ufuncs and matmul to dedicated handlers instead
# of the generic TLU-based _np_operator.
NPTracer.UFUNC_ROUTING[numpy.add] = _on_numpy_add
NPTracer.UFUNC_ROUTING[numpy.subtract] = _on_numpy_subtract
NPTracer.UFUNC_ROUTING[numpy.multiply] = _on_numpy_multiply
NPTracer.UFUNC_ROUTING[numpy.matmul] = _on_numpy_matmul
def trace_numpy_function(
    function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
) -> OPGraph:
    """Trace a numpy function.

    Args:
        function_to_trace (Callable): The function you want to trace
        function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
            function is e.g. an EncryptedScalar holding a 7bits unsigned Integer

    Returns:
        OPGraph: The graph containing the ir nodes representing the computation done in the input
            function
    """
    prepared_parameters = prepare_function_parameters(function_to_trace, function_parameters)
    tracers = make_input_tracers(NPTracer, prepared_parameters)

    # Building the graph from the outputs (rather than the inputs) ensures dead
    # nodes hanging off the inputs never end up in the resulting OPGraph.
    with tracing_context([NPTracer]):
        outputs = function_to_trace(**tracers)

    normalized_outputs = (outputs,) if isinstance(outputs, NPTracer) else outputs
    return OPGraph.from_output_tracers(normalized_outputs)

View File

@@ -1,533 +0,0 @@
"""Test file for bounds evaluation with a inputset"""
from typing import Tuple
import numpy as np
import pytest
from concrete.common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_inputset
from concrete.common.compilation import CompilationConfiguration
from concrete.common.data_types.floats import Float
from concrete.common.data_types.integers import Integer, UnsignedInteger
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
from concrete.numpy.compile import numpy_max_func, numpy_min_func
from concrete.numpy.np_dtypes_helpers import get_base_value_for_numpy_or_python_constant_data
from concrete.numpy.tracing import trace_numpy_function
@pytest.mark.parametrize(
    "function,input_ranges,expected_output_bounds,expected_output_data_type",
    [
        pytest.param(
            lambda x, y: x + y,
            ((-10, 10), (-10, 10)),
            (-20, 20),
            Integer(6, is_signed=True),
            id="x + y, (-10, 10), (-10, 10), (-20, 20)",
        ),
        pytest.param(
            lambda x, y: x + y,
            ((-10, 2), (-4, 5)),
            (-14, 7),
            Integer(5, is_signed=True),
            id="x + y, (-10, 2), (-4, 5), (-14, 7)",
        ),
        pytest.param(
            lambda x, y: x + y + 1.7,
            ((-10, 2), (-4, 5)),
            (-12.3, 8.7),
            Float(64),
            id="x + y + 1.7, (-10, 2), (-4, 5), (-12.3, 8.7)",
        ),
        pytest.param(
            lambda x, y: x + y + 1,
            ((-10, 2), (-4, 5)),
            (-13, 8),
            Integer(5, is_signed=True),
            id="x + y + 1, (-10, 2), (-4, 5), (-13, 8)",
        ),
        pytest.param(
            lambda x, y: x + y + (-3),
            ((-10, 2), (-4, 5)),
            (-17, 4),
            Integer(6, is_signed=True),
            # id fixed to match the lambda (was mislabeled "x + y + 1")
            id="x + y + (-3), (-10, 2), (-4, 5), (-17, 4)",
        ),
        pytest.param(
            lambda x, y: (1 + x) + y,
            ((-10, 2), (-4, 5)),
            (-13, 8),
            Integer(5, is_signed=True),
            id="(1 + x) + y, (-10, 2), (-4, 5), (-13, 8)",
        ),
        pytest.param(
            lambda x, y: x - y,
            ((-10, 10), (-10, 10)),
            (-20, 20),
            Integer(6, is_signed=True),
            id="x - y, (-10, 10), (-10, 10), (-20, 20)",
        ),
        pytest.param(
            lambda x, y: x - y,
            ((-10, 2), (-4, 5)),
            (-15, 6),
            Integer(5, is_signed=True),
            id="x - y, (-10, 2), (-4, 5), (-15, 6)",
        ),
        pytest.param(
            lambda x, y: x - y - 42,
            ((-10, 2), (-4, 5)),
            (-57, -36),
            Integer(7, is_signed=True),
            id="x - y - 42, (-10, 2), (-4, 5), (-57, -36)",
        ),
        pytest.param(
            lambda x, y: x - y - 41.5,
            ((-10, 2), (-4, 5)),
            (-56.5, -35.5),
            Float(64),
            id="x - y - 41.5, (-10, 2), (-4, 5), (-56.5, -35.5)",
        ),
        pytest.param(
            lambda x, y: 3 - x + y,
            ((-10, 2), (-4, 5)),
            (-3, 18),
            Integer(6, is_signed=True),
            id="3 - x + y, (-10, 2), (-4, 5), (-3, 18)",
        ),
        pytest.param(
            lambda x, y: 2.8 - x + y,
            ((-10, 2), (-4, 5)),
            (-3.2, 17.8),
            Float(64),
            id="2.8 - x + y, (-10, 2), (-4, 5), (-3.2, 17.8)",
        ),
        pytest.param(
            lambda x, y: (-13) - x + y,
            ((-10, 2), (-4, 5)),
            (-19, 2),
            Integer(6, is_signed=True),
            id="(-13) - x + y, (-10, 2), (-4, 5), (-19, 2)",
        ),
        pytest.param(
            lambda x, y: (-13.5) - x + y,
            ((-10, 2), (-4, 5)),
            (-19.5, 1.5),
            Float(64),
            id="(-13.5) - x + y, (-10, 2), (-4, 5), (-19.5, 1.5)",
        ),
        pytest.param(
            lambda x, y: x * y,
            ((-10, 10), (-10, 10)),
            (-100, 100),
            Integer(8, is_signed=True),
            id="x * y, (-10, 10), (-10, 10), (-100, 100)",
        ),
        pytest.param(
            lambda x, y: x * y,
            ((-10, 2), (-4, 5)),
            (-50, 40),
            Integer(7, is_signed=True),
            id="x * y, (-10, 2), (-4, 5), (-50, 40)",
        ),
        pytest.param(
            lambda x, y: (3 * x) * y,
            ((-10, 2), (-4, 5)),
            (-150, 120),
            Integer(9, is_signed=True),
            id="(3 * x) * y, (-10, 2), (-4, 5), (-150, 120)",
        ),
        pytest.param(
            lambda x, y: (3.0 * x) * y,
            ((-10, 2), (-4, 5)),
            (-150.0, 120.0),
            Float(64),
            id="(3.0 * x) * y, (-10, 2), (-4, 5), (-150.0, 120.0)",
        ),
        pytest.param(
            lambda x, y: (x * 11) * y,
            ((-10, 2), (-4, 5)),
            (-550, 440),
            Integer(11, is_signed=True),
            # id fixed to match the lambda (was mislabeled "x * y")
            id="(x * 11) * y, (-10, 2), (-4, 5), (-550, 440)",
        ),
        pytest.param(
            lambda x, y: (x * (-11)) * y,
            ((-10, 2), (-4, 5)),
            (-440, 550),
            Integer(11, is_signed=True),
            id="(x * (-11)) * y, (-10, 2), (-4, 5), (-440, 550)",
        ),
        pytest.param(
            lambda x, y: (x * (-11.0)) * y,
            ((-10, 2), (-4, 5)),
            (-440.0, 550.0),
            Float(64),
            id="(x * (-11.0)) * y, (-10, 2), (-4, 5), (-440.0, 550.0)",
        ),
        pytest.param(
            lambda x, y: x + x + y,
            ((-10, 10), (-10, 10)),
            (-30, 30),
            Integer(6, is_signed=True),
            id="x + x + y, (-10, 10), (-10, 10), (-30, 30)",
        ),
        pytest.param(
            lambda x, y: x - x + y,
            ((-10, 10), (-10, 10)),
            (-10, 10),
            Integer(5, is_signed=True),
            id="x - x + y, (-10, 10), (-10, 10), (-10, 10)",
        ),
        pytest.param(
            lambda x, y: x - x + y,
            ((-10, 2), (-4, 5)),
            (-4, 5),
            Integer(4, is_signed=True),
            id="x - x + y, (-10, 2), (-4, 5), (-4, 5)",
        ),
        pytest.param(
            lambda x, y: x * y - x,
            ((-10, 10), (-10, 10)),
            (-110, 110),
            Integer(8, is_signed=True),
            id="x * y - x, (-10, 10), (-10, 10), (-110, 110)",
        ),
        pytest.param(
            lambda x, y: x * y - x,
            ((-10, 2), (-4, 5)),
            (-40, 50),
            Integer(7, is_signed=True),
            id="x * y - x, (-10, 2), (-4, 5), (-40, 50),",
        ),
        pytest.param(
            lambda x, y: (x * 3) * y - (x + 3) + (y - 13) + x * (11 + y) * (12 + y) + (15 - x),
            ((-10, 2), (-4, 5)),
            (-2846, 574),
            Integer(13, is_signed=True),
            # id fixed to match the lambda (was mislabeled "x * y - x")
            id="(x * 3) * y - (x + 3) + (y - 13) + x * (11 + y) * (12 + y) + (15 - x), "
            "(-10, 2), (-4, 5), (-2846, 574)",
        ),
    ],
)
def test_eval_op_graph_bounds_on_inputset(
    function,
    input_ranges,
    expected_output_bounds,
    expected_output_data_type: Integer,
):
    """Test function for eval_op_graph_bounds_on_inputset (single-output case).

    Delegates to the multiple-output variant with singleton tuples.
    """
    test_eval_op_graph_bounds_on_inputset_multiple_output(
        function,
        input_ranges,
        (expected_output_bounds,),
        (expected_output_data_type,),
    )
@pytest.mark.parametrize(
    "function,input_ranges,expected_output_bounds,expected_output_data_type",
    [
        pytest.param(
            lambda x, y: (x + 1, y + 10),
            ((-1, 1), (3, 4)),
            ((0, 2), (13, 14)),
            (Integer(2, is_signed=False), Integer(4, is_signed=False)),
        ),
        pytest.param(
            lambda x, y: (x + 1.5, y + 9.6),
            ((-1, 1), (3, 4)),
            ((0.5, 2.5), (12.6, 13.6)),
            (Float(64), Float(64)),
        ),
        pytest.param(
            lambda x, y: (x + y + 1, x * y + 42),
            ((-1, 1), (3, 4)),
            ((3, 6), (38, 46)),
            (Integer(3, is_signed=False), Integer(6, is_signed=False)),
        ),
        pytest.param(
            lambda x, y: (x + y + 0.4, x * y + 41.7),
            ((-1, 1), (3, 4)),
            ((2.4, 5.4), (37.7, 45.7)),
            (Float(64), Float(64)),
        ),
        pytest.param(
            lambda x, y: (x + y + 1, x * y + 41.7),
            ((-1, 1), (3, 4)),
            ((3, 6), (37.7, 45.7)),
            (Integer(3, is_signed=False), Float(64)),
        ),
        pytest.param(
            lambda x, y: (x + y + 0.4, x * y + 42),
            ((-1, 1), (3, 4)),
            ((2.4, 5.4), (38, 46)),
            (Float(64), Integer(6, is_signed=False)),
        ),
    ],
)
def test_eval_op_graph_bounds_on_inputset_multiple_output(
    function,
    input_ranges,
    expected_output_bounds,
    expected_output_data_type: Tuple[Integer],
):
    """Test function for eval_op_graph_bounds_on_inputset (multiple-output case)."""
    op_graph = trace_numpy_function(
        function, {"x": EncryptedScalar(Integer(64, True)), "y": EncryptedScalar(Integer(64, True))}
    )

    def data_gen(range_x, range_y):
        # Exhaustive cartesian product of the two input ranges.
        for x_gen in range_x:
            for y_gen in range_y:
                yield (x_gen, y_gen)

    _, node_bounds_and_samples = eval_op_graph_bounds_on_inputset(
        op_graph,
        data_gen(*tuple(range(x[0], x[1] + 1) for x in input_ranges)),
        CompilationConfiguration(),
    )

    # Measured bounds must match the expected (min, max) per output node.
    for i, output_node in op_graph.output_nodes.items():
        output_node_bounds = node_bounds_and_samples[output_node]
        assert (output_node_bounds["min"], output_node_bounds["max"]) == expected_output_bounds[i]

    # After applying the bounds, output dtypes must tighten to the expected types.
    op_graph.update_values_with_bounds_and_samples(node_bounds_and_samples)
    for i, output_node in op_graph.output_nodes.items():
        assert expected_output_data_type[i] == output_node.outputs[0].dtype
def test_eval_op_graph_bounds_on_non_conformant_inputset_default(capsys):
    """Test function for eval_op_graph_bounds_on_inputset with non conformant inputset.

    Default configuration: only the first input is checked, so only input #0 warns.
    """

    def f(x, y):
        return np.dot(x, y)

    x = EncryptedTensor(UnsignedInteger(2), (3,))
    y = ClearTensor(UnsignedInteger(2), (3,))

    # First input has the wrong shape (4 instead of 3); second has out-of-range values.
    inputset = [
        (np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])),
        (np.array([3, 3, 3]), np.array([3, 3, 5])),
    ]

    op_graph = trace_numpy_function(f, {"x": x, "y": y})
    configuration = CompilationConfiguration()
    eval_op_graph_bounds_on_inputset(
        op_graph,
        inputset,
        compilation_configuration=configuration,
        min_func=numpy_min_func,
        max_func=numpy_max_func,
        get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
    )

    captured = capsys.readouterr()
    assert (
        captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
        "(expected EncryptedTensor<uint2, shape=(3,)> for parameter `x` "
        "but got EncryptedTensor<uint2, shape=(4,)> which is not compatible)\n"
        "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
        "(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
        "but got ClearTensor<uint2, shape=(4,)> which is not compatible)\n"
    )
def test_eval_op_graph_bounds_on_non_conformant_inputset_check_all(capsys):
    """Test function for eval_op_graph_bounds_on_inputset with non conformant inputset, check all.

    With check_every_input_in_inputset=True, input #1's out-of-range `y` also warns.
    """

    def f(x, y):
        return np.dot(x, y)

    x = EncryptedTensor(UnsignedInteger(2), (3,))
    y = ClearTensor(UnsignedInteger(2), (3,))

    inputset = [
        (np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])),
        (np.array([3, 3, 3]), np.array([3, 3, 5])),
    ]

    op_graph = trace_numpy_function(f, {"x": x, "y": y})
    configuration = CompilationConfiguration(check_every_input_in_inputset=True)
    eval_op_graph_bounds_on_inputset(
        op_graph,
        inputset,
        compilation_configuration=configuration,
        min_func=numpy_min_func,
        max_func=numpy_max_func,
        get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
    )

    captured = capsys.readouterr()
    assert (
        captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
        "(expected EncryptedTensor<uint2, shape=(3,)> for parameter `x` "
        "but got EncryptedTensor<uint2, shape=(4,)> which is not compatible)\n"
        "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
        "(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
        "but got ClearTensor<uint2, shape=(4,)> which is not compatible)\n"
        "Warning: Input #1 (0-indexed) is not coherent with the hinted parameters "
        "(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
        "but got ClearTensor<uint3, shape=(3,)> which is not compatible)\n"
    )
def test_eval_op_graph_bounds_on_conformant_numpy_inputset_check_all(capsys):
    """Test function for eval_op_graph_bounds_on_inputset
    with conformant inputset of numpy arrays, check all.

    Every input matches the hinted parameters, so no warning is emitted.
    """

    def f(x, y):
        return np.dot(x, y)

    x = EncryptedTensor(UnsignedInteger(2), (3,))
    y = ClearTensor(UnsignedInteger(2), (3,))

    inputset = [
        (np.array([2, 1, 3]), np.array([1, 2, 1])),
        (np.array([3, 3, 3]), np.array([3, 3, 1])),
    ]

    op_graph = trace_numpy_function(f, {"x": x, "y": y})
    configuration = CompilationConfiguration(check_every_input_in_inputset=True)
    eval_op_graph_bounds_on_inputset(
        op_graph,
        inputset,
        compilation_configuration=configuration,
        min_func=numpy_min_func,
        max_func=numpy_max_func,
        get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
    )

    captured = capsys.readouterr()
    assert captured.err == ""
def test_eval_op_graph_bounds_on_non_conformant_numpy_inputset_check_all(capsys):
    """Test function for eval_op_graph_bounds_on_inputset with non conformant inputset, check all.

    Same as the non-numpy variant: shape mismatch on input #0 and an
    out-of-range `y` on input #1 must each produce a warning.
    """

    def f(x, y):
        return np.dot(x, y)

    x = EncryptedTensor(UnsignedInteger(2), (3,))
    y = ClearTensor(UnsignedInteger(2), (3,))

    inputset = [
        (np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])),
        (np.array([3, 3, 3]), np.array([3, 3, 5])),
    ]

    op_graph = trace_numpy_function(f, {"x": x, "y": y})
    configuration = CompilationConfiguration(check_every_input_in_inputset=True)
    eval_op_graph_bounds_on_inputset(
        op_graph,
        inputset,
        compilation_configuration=configuration,
        min_func=numpy_min_func,
        max_func=numpy_max_func,
        get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
    )

    captured = capsys.readouterr()
    assert (
        captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
        "(expected EncryptedTensor<uint2, shape=(3,)> for parameter `x` "
        "but got EncryptedTensor<uint2, shape=(4,)> which is not compatible)\n"
        "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
        "(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
        "but got ClearTensor<uint2, shape=(4,)> which is not compatible)\n"
        "Warning: Input #1 (0-indexed) is not coherent with the hinted parameters "
        "(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
        "but got ClearTensor<uint3, shape=(3,)> which is not compatible)\n"
    )
def test_eval_op_graph_bounds_on_non_conformant_inputset_treating_warnings_as_errors():
    """Test function for eval_op_graph_bounds_on_inputset with non conformant inputset and errors.

    With treat_warnings_as_errors=True the coherency warning becomes a ValueError.
    """

    def f(x, y):
        return np.dot(x, y)

    x = EncryptedTensor(UnsignedInteger(2), (3,))
    y = ClearTensor(UnsignedInteger(2), (3,))

    inputset = [
        (np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])),
        (np.array([3, 3, 3]), np.array([3, 3, 5])),
    ]

    op_graph = trace_numpy_function(f, {"x": x, "y": y})
    with pytest.raises(ValueError, match=".* is not coherent with the hinted parameters .*"):
        configuration = CompilationConfiguration(treat_warnings_as_errors=True)
        eval_op_graph_bounds_on_inputset(
            op_graph,
            inputset,
            compilation_configuration=configuration,
            min_func=numpy_min_func,
            max_func=numpy_max_func,
            get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
        )
# NOTE(review): "inpuset" is a typo for "inputset"; the name is kept to avoid renaming the test.
def test_inpuset_eval_1_input(default_compilation_configuration):
    """Test case for a function with a single parameter and passing the inputset without tuples."""

    def f(x):
        return x + 42

    x = EncryptedScalar(UnsignedInteger(4))
    # Single-parameter inputsets are given as bare values, not 1-tuples.
    inputset = range(10)

    op_graph = trace_numpy_function(f, {"x": x})
    eval_op_graph_bounds_on_inputset(
        op_graph,
        inputset,
        compilation_configuration=default_compilation_configuration,
        min_func=numpy_min_func,
        max_func=numpy_max_func,
        get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
    )

    # Input stays uint4; the output grows to uint6 to hold max 9 + 42 = 51.
    input_node = op_graph.input_nodes[0]
    assert input_node.inputs[0] == input_node.outputs[0]
    assert input_node.inputs[0] == EncryptedScalar(UnsignedInteger(4))
    output_node = op_graph.output_nodes[0]
    assert output_node.outputs[0] == EncryptedScalar(UnsignedInteger(6))
# TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/772
# Remove once this issue is done
def test_inpuset_eval_1_input_refuse_tuple(default_compilation_configuration):
    """Test case for a function with a single parameter and passing the inputset with tuples."""

    def f(x):
        return x + 42

    x = EncryptedScalar(UnsignedInteger(4))
    # 1-tuples are rejected for single-input functions: an assertion must trip.
    inputset = [(i,) for i in range(10)]

    op_graph = trace_numpy_function(f, {"x": x})
    with pytest.raises(AssertionError) as excinfo:
        eval_op_graph_bounds_on_inputset(
            op_graph,
            inputset,
            compilation_configuration=default_compilation_configuration,
            min_func=numpy_min_func,
            max_func=numpy_max_func,
            get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
        )
    assert str(excinfo.value) == "Tuples are unsupported for single input inputset evaluation"

View File

@@ -1,48 +0,0 @@
"""Test file for compilation artifacts"""
import tempfile
from pathlib import Path
from concrete.common.compilation import CompilationArtifacts
from concrete.common.data_types.integers import UnsignedInteger
from concrete.common.values import EncryptedScalar
from concrete.numpy.compile import compile_numpy_function
def test_artifacts_export(default_compilation_configuration):
    """Test function to check exporting compilation artifacts."""

    def function(x):
        return x + 42

    with tempfile.TemporaryDirectory() as tmp:
        output_directory = Path(tmp)
        artifacts = CompilationArtifacts(output_directory)
        compile_numpy_function(
            function,
            {"x": EncryptedScalar(UnsignedInteger(7))},
            range(10),
            default_compilation_configuration,
            compilation_artifacts=artifacts,
        )
        artifacts.export()

        # Every expected artifact file must be written to the output directory.
        assert output_directory.joinpath("environment.txt").exists()
        assert output_directory.joinpath("requirements.txt").exists()
        assert output_directory.joinpath("function.txt").exists()
        assert output_directory.joinpath("parameters.txt").exists()
        assert output_directory.joinpath("1.initial.graph.txt").exists()
        assert output_directory.joinpath("1.initial.graph.png").exists()
        assert output_directory.joinpath("2.final.graph.txt").exists()
        assert output_directory.joinpath("2.final.graph.png").exists()
        assert output_directory.joinpath("bounds.txt").exists()
        assert output_directory.joinpath("mlir.txt").exists()
        # format of those files might change in the future
        # so it is sufficient to test their existence

View File

@@ -1,77 +0,0 @@
"""Test file for compilation configuration"""
from inspect import signature
import numpy
import pytest
from concrete.common.compilation import CompilationConfiguration
from concrete.common.data_types.integers import Integer
from concrete.common.values import EncryptedScalar
from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds
def no_fuse(x):
    """Return ``x + 2`` with no fusable float round-trip in between."""
    result = x + 2
    return result
def simple_fuse_not_output(x):
    """Round-trip ``x`` through float64 then uint32 before adding 2 (fusable, non-output)."""
    as_float = x.astype(numpy.float64)
    back_to_int = as_float.astype(numpy.uint32)
    return back_to_int + 2
@pytest.mark.parametrize(
    "function_to_trace,fused",
    [
        pytest.param(
            no_fuse,
            False,
            id="no_fuse",
        ),
        pytest.param(
            simple_fuse_not_output,
            True,
            id="simple_fuse_not_output",
        ),
    ],
)
def test_enable_topological_optimizations(
    test_helpers, function_to_trace, fused, default_compilation_configuration
):
    """Test function for enable_topological_optimizations flag of compilation configuration"""
    # Compile once with the default configuration (optimizations enabled).
    op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
        function_to_trace,
        {
            param: EncryptedScalar(Integer(32, is_signed=False))
            for param in signature(function_to_trace).parameters.keys()
        },
        [numpy.array(i) for i in range(10)],
        default_compilation_configuration,
    )
    # Compile again with topological optimizations explicitly disabled.
    op_graph_not_optimized = compile_numpy_function_into_op_graph_and_measure_bounds(
        function_to_trace,
        {
            param: EncryptedScalar(Integer(32, is_signed=False))
            for param in signature(function_to_trace).parameters.keys()
        },
        [numpy.array(i) for i in range(10)],
        CompilationConfiguration(
            dump_artifacts_on_unexpected_failures=False,
            enable_topological_optimizations=False,
            treat_warnings_as_errors=True,
        ),
    )
    graph = op_graph.graph
    not_optimized_graph = op_graph_not_optimized.graph
    if fused:
        # Fusing must have shrunk the optimized graph relative to the unoptimized one.
        assert not test_helpers.digraphs_are_equivalent(graph, not_optimized_graph)
        assert len(graph) < len(not_optimized_graph)
    else:
        # Without a fusable pattern, both graphs must be structurally equivalent.
        assert test_helpers.digraphs_are_equivalent(graph, not_optimized_graph)
        assert len(graph) == len(not_optimized_graph)

View File

@@ -1,299 +0,0 @@
"""Test file for data types helpers"""
import pytest
from concrete.common.data_types.base import BaseDataType
from concrete.common.data_types.dtypes_helpers import (
broadcast_shapes,
find_type_to_hold_both_lossy,
mix_values_determine_holding_dtype,
value_is_encrypted_scalar_integer,
value_is_encrypted_scalar_unsigned_integer,
)
from concrete.common.data_types.floats import Float
from concrete.common.data_types.integers import Integer
from concrete.common.values import (
BaseValue,
ClearScalar,
ClearTensor,
EncryptedScalar,
EncryptedTensor,
)
@pytest.mark.parametrize(
    "value,expected_result",
    [
        pytest.param(
            ClearScalar(Integer(8, is_signed=False)),
            False,
            id="ClearScalar 8 bits unsigned Integer",
        ),
        pytest.param(
            EncryptedScalar(Integer(8, is_signed=True)),
            True,
            id="EncryptedScalar 8 bits signed Integer",
        ),
    ],
)
def test_value_is_encrypted_integer(value: BaseValue, expected_result: bool):
    """Test value_is_encrypted_integer helper"""
    # Only encrypted scalar integers should be recognized, regardless of signedness.
    assert value_is_encrypted_scalar_integer(value) == expected_result
@pytest.mark.parametrize(
    "value,expected_result",
    [
        pytest.param(
            ClearScalar(Integer(8, is_signed=False)),
            False,
            id="ClearScalar 8 bits unsigned Integer",
        ),
        pytest.param(
            EncryptedScalar(Integer(8, is_signed=True)),
            False,
            id="EncryptedScalar 8 bits signed Integer",
        ),
        pytest.param(
            EncryptedScalar(Integer(8, is_signed=False)),
            True,
            id="EncryptedScalar 8 bits unsigned Integer",
        ),
    ],
)
def test_value_is_encrypted_unsigned_integer(value: BaseValue, expected_result: bool):
    """Test value_is_encrypted_unsigned_integer helper"""
    # Unlike the helper above, signed encrypted integers must be rejected too.
    assert value_is_encrypted_scalar_unsigned_integer(value) == expected_result
class UnsupportedDataType(BaseDataType):
    """Test helper class to represent an UnsupportedDataType"""

    def __eq__(self, o: object) -> bool:
        # Two instances of this helper compare equal; used only to trigger
        # the "unsupported dtype" failure paths in the helpers under test.
        return isinstance(o, self.__class__)
@pytest.mark.parametrize(
    "dtype1,dtype2,expected_mixed_dtype",
    [
        pytest.param(Integer(6, True), Integer(6, True), Integer(6, True), id="int6, int6, int6"),
        pytest.param(
            Integer(6, False), Integer(6, False), Integer(6, False), id="uint6, uint6, uint6"
        ),
        # Mixing signed and unsigned of the same width needs one extra bit.
        pytest.param(Integer(6, True), Integer(6, False), Integer(7, True), id="int6, uint6, int7"),
        pytest.param(Integer(6, False), Integer(6, True), Integer(7, True), id="uint6, int6, int7"),
        pytest.param(Integer(6, True), Integer(5, False), Integer(6, True), id="int6, uint5, int6"),
        pytest.param(Integer(5, False), Integer(6, True), Integer(6, True), id="uint5, int6, int6"),
        # Floats absorb integers (lossily), wider float wins.
        pytest.param(Integer(32, True), Float(32), Float(32), id="int32, float32, float32"),
        pytest.param(Integer(64, True), Float(32), Float(32), id="int64, float32, float32"),
        pytest.param(Integer(64, True), Float(64), Float(64), id="int64, float64, float64"),
        pytest.param(Integer(32, True), Float(64), Float(64), id="int32, float64, float64"),
        pytest.param(Float(64), Integer(32, True), Float(64), id="float64, int32, float64"),
        pytest.param(Float(64), Integer(7, False), Float(64), id="float64, uint7, float64"),
        pytest.param(Float(32), Float(32), Float(32), id="float32, float32, float32"),
        pytest.param(Float(32), Float(64), Float(64), id="float32, float64, float64"),
        pytest.param(Float(64), Float(32), Float(64), id="float64, float32, float64"),
        pytest.param(Float(64), Float(64), Float(64), id="float64, float64, float64"),
        # Unsupported dtypes must make the helper raise (strict xfail).
        pytest.param(
            UnsupportedDataType(),
            UnsupportedDataType(),
            None,
            id="unsupported, unsupported, xfail",
            marks=pytest.mark.xfail(strict=True),
        ),
        pytest.param(
            Integer(6, True),
            UnsupportedDataType(),
            None,
            id="int6, unsupported, xfail",
            marks=pytest.mark.xfail(strict=True),
        ),
        pytest.param(
            UnsupportedDataType(),
            Integer(6, True),
            None,
            id="unsupported, int6, xfail",
            marks=pytest.mark.xfail(strict=True),
        ),
        pytest.param(
            UnsupportedDataType(),
            Float(32),
            None,
            id="unsupported, float32, xfail",
            marks=pytest.mark.xfail(strict=True),
        ),
    ],
)
def test_mix_data_types(
    dtype1: BaseDataType,
    dtype2: BaseDataType,
    expected_mixed_dtype: BaseDataType,
):
    """Test find_type_to_hold_both_lossy helper"""
    assert expected_mixed_dtype == find_type_to_hold_both_lossy(dtype1, dtype2)
@pytest.mark.parametrize(
    "value1,value2,expected_mixed_value",
    [
        pytest.param(
            EncryptedScalar(Integer(7, False)),
            EncryptedScalar(Integer(7, False)),
            EncryptedScalar(Integer(7, False)),
            id="euint7, euint7, euint7",
        ),
        # Encrypted "wins" over clear: the mix is encrypted if either side is.
        pytest.param(
            EncryptedScalar(Integer(7, False)),
            ClearScalar(Integer(7, False)),
            EncryptedScalar(Integer(7, False)),
            id="euint7, cuint7, euint7",
        ),
        pytest.param(
            ClearScalar(Integer(7, False)),
            EncryptedScalar(Integer(7, False)),
            EncryptedScalar(Integer(7, False)),
            id="cuint7, euint7, euint7",
        ),
        pytest.param(
            ClearScalar(Integer(7, False)),
            ClearScalar(Integer(7, False)),
            ClearScalar(Integer(7, False)),
            id="cuint7, cuint7, cuint7",
        ),
        pytest.param(
            ClearScalar(Float(32)),
            ClearScalar(Float(32)),
            ClearScalar(Float(32)),
            id="cfloat32, cfloat32, cfloat32",
        ),
        pytest.param(
            EncryptedScalar(Float(32)),
            ClearScalar(Float(32)),
            EncryptedScalar(Float(32)),
            id="efloat32, cfloat32, efloat32",
        ),
    ],
)
def test_mix_scalar_values(value1, value2, expected_mixed_value):
    """Test mix_values_determine_holding_dtype helper with scalars"""
    assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2)
@pytest.mark.parametrize(
    "value1,value2,expected_mixed_value",
    [
        pytest.param(
            EncryptedTensor(Integer(7, False), (1, 2, 3)),
            EncryptedTensor(Integer(7, False), (1, 2, 3)),
            EncryptedTensor(Integer(7, False), (1, 2, 3)),
        ),
        pytest.param(
            ClearTensor(Integer(7, False), (1, 2, 3)),
            EncryptedTensor(Integer(7, False), (1, 2, 3)),
            EncryptedTensor(Integer(7, False), (1, 2, 3)),
        ),
        pytest.param(
            ClearTensor(Integer(7, False), (1, 2, 3)),
            ClearTensor(Integer(7, False), (1, 2, 3)),
            ClearTensor(Integer(7, False), (1, 2, 3)),
        ),
        # NOTE(review): this case duplicates the previous one exactly;
        # likely a copy-paste slip — consider removing or varying the shapes.
        pytest.param(
            ClearTensor(Integer(7, False), (1, 2, 3)),
            ClearTensor(Integer(7, False), (1, 2, 3)),
            ClearTensor(Integer(7, False), (1, 2, 3)),
        ),
        # Mixing a tensor with a scalar, or mismatched shapes, must assert.
        pytest.param(
            ClearTensor(Integer(7, False), (1, 2, 3)),
            EncryptedScalar(Integer(7, False)),
            None,
            marks=pytest.mark.xfail(strict=True, raises=AssertionError),
        ),
        pytest.param(
            ClearTensor(Integer(7, False), (1, 2, 3)),
            ClearTensor(Integer(7, False), (3, 2, 1)),
            None,
            marks=pytest.mark.xfail(strict=True, raises=AssertionError),
        ),
    ],
)
def test_mix_tensor_values(value1, value2, expected_mixed_value):
    """Test mix_values_determine_holding_dtype helper with tensors"""
    assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2)
class DummyValue(BaseValue):
    """DummyValue"""

    def __eq__(self, other: object) -> bool:
        # Deliberately delegates to the base implementation so the value is
        # not recognized by mix_values_determine_holding_dtype.
        return BaseValue.__eq__(self, other)
def test_fail_mix_values_determine_holding_dtype():
    """Test function for failure case of mix_values_determine_holding_dtype"""
    # Values of an unknown kind must be rejected with a descriptive ValueError.
    with pytest.raises(ValueError, match=r".* does not support value .*"):
        mix_values_determine_holding_dtype(
            DummyValue(Integer(32, True), True),
            DummyValue(Integer(32, True), True),
        )
@pytest.mark.parametrize(
    "shape1,shape2,expected_shape",
    [
        # `None` as the expected shape means the shapes are not broadcastable.
        pytest.param((), (), ()),
        pytest.param((3,), (), (3,)),
        pytest.param((3,), (1,), (3,)),
        pytest.param((3,), (2,), None),
        pytest.param((3,), (3,), (3,)),
        pytest.param((2, 3), (), (2, 3)),
        pytest.param((2, 3), (1,), (2, 3)),
        pytest.param((2, 3), (2,), None),
        pytest.param((2, 3), (3,), (2, 3)),
        pytest.param((2, 3), (1, 1), (2, 3)),
        pytest.param((2, 3), (2, 1), (2, 3)),
        pytest.param((2, 3), (3, 1), None),
        pytest.param((2, 3), (1, 2), None),
        pytest.param((2, 3), (2, 2), None),
        pytest.param((2, 3), (3, 2), None),
        pytest.param((2, 3), (1, 3), (2, 3)),
        pytest.param((2, 3), (2, 3), (2, 3)),
        pytest.param((2, 3), (3, 3), None),
        pytest.param((2, 1, 3), (1, 1, 1), (2, 1, 3)),
        pytest.param((2, 1, 3), (1, 4, 1), (2, 4, 3)),
        pytest.param((2, 1, 3), (2, 4, 3), (2, 4, 3)),
        # Tests cases taken from `numpy`
        # https://github.com/numpy/numpy/blob/623bc1fae1d47df24e7f1e29321d0c0ba2771ce0/numpy/lib/tests/test_stride_tricks.py#L296-L351
        pytest.param((1, 2), (2,), (1, 2)),
        pytest.param((1, 1), (3, 4), (3, 4)),
        pytest.param((1, 3), (3, 1), (3, 3)),
        pytest.param((1, 0), (0, 0), (0, 0)),
        pytest.param((0, 1), (0, 0), (0, 0)),
        pytest.param((1, 0), (0, 1), (0, 0)),
        pytest.param((1, 1), (0, 0), (0, 0)),
        pytest.param((1, 1), (1, 0), (1, 0)),
        pytest.param((1, 1), (0, 1), (0, 1)),
        pytest.param((), (0,), (0,)),
        pytest.param((0,), (0, 0), (0, 0)),
        pytest.param((0,), (0, 1), (0, 0)),
        pytest.param((1,), (0, 0), (0, 0)),
        pytest.param((2,), (0, 0), (0, 0)),
        pytest.param((), (0, 0), (0, 0)),
        pytest.param((1, 1), (0,), (1, 0)),
        pytest.param((1,), (0, 1), (0, 1)),
        pytest.param((1,), (1, 0), (1, 0)),
        pytest.param((), (1, 0), (1, 0)),
        pytest.param((), (0, 1), (0, 1)),
        pytest.param((1,), (3,), (3,)),
        pytest.param((2,), (3, 2), (3, 2)),
        pytest.param((3,), (4,), None),
        pytest.param((2, 3), (2,), None),
        pytest.param((1, 3, 4), (2, 3, 3), None),
        pytest.param((2,), (2, 3), None),
    ],
)
def test_broadcast_shapes(shape1, shape2, expected_shape):
    """Test function for `broadcast_shapes` helper"""
    # Broadcasting is symmetric, so check both argument orders.
    assert broadcast_shapes(shape1=shape1, shape2=shape2) == expected_shape
    assert broadcast_shapes(shape1=shape2, shape2=shape1) == expected_shape

View File

@@ -1,50 +0,0 @@
"""Test file for float data types"""
import pytest
from concrete.common.data_types.floats import Float, Float32, Float64
@pytest.mark.parametrize(
    "float_,expected_repr_str",
    [
        # The shorthand classes must repr identically to the generic Float.
        pytest.param(
            Float32(),
            "Float<32 bits>",
            id="Float32",
        ),
        pytest.param(
            Float(32),
            "Float<32 bits>",
            id="32 bits Float",
        ),
        pytest.param(
            Float64(),
            "Float<64 bits>",
            id="Float64",
        ),
        pytest.param(
            Float(64),
            "Float<64 bits>",
            id="64 bits Float",
        ),
    ],
)
def test_floats_repr(float_: Float, expected_repr_str: str):
    """Test float repr"""
    assert float_.__repr__() == expected_repr_str
@pytest.mark.parametrize(
    "float_1,float_2,expected_equal",
    [
        pytest.param(Float32(), Float(32), True),
        pytest.param(Float(64), Float32(), False),
        pytest.param(Float64(), Float(64), True),
    ],
)
def test_floats_eq(float_1: Float, float_2: Float, expected_equal: bool):
    """Test float eq"""
    # Equality must be symmetric.
    assert expected_equal == (float_1 == float_2)
    assert expected_equal == (float_2 == float_1)

View File

@@ -1,112 +0,0 @@
"""Test file for integers data types"""
import random
import pytest
from concrete.common.data_types.integers import (
Integer,
SignedInteger,
UnsignedInteger,
make_integer_to_hold,
)
@pytest.mark.parametrize(
    "integer,expected_min,expected_max",
    [
        pytest.param(Integer(8, is_signed=False), 0, 255, id="8 bits unsigned Integer"),
        pytest.param(UnsignedInteger(8), 0, 255, id="8 bits UnsignedInteger"),
        pytest.param(Integer(8, is_signed=True), -128, 127, id="8 bits signed Integer"),
        pytest.param(SignedInteger(8), -128, 127, id="8 bits SignedInteger"),
        pytest.param(Integer(32, is_signed=False), 0, 4_294_967_295, id="32 bits unsigned Integer"),
        pytest.param(UnsignedInteger(32), 0, 4_294_967_295, id="32 bits UnsignedInteger"),
        pytest.param(
            Integer(32, is_signed=True),
            -2_147_483_648,
            2_147_483_647,
            id="32 bits signed Integer",
        ),
        pytest.param(
            SignedInteger(32),
            -2_147_483_648,
            2_147_483_647,
            id="32 bits SignedInteger",
        ),
    ],
)
def test_basic_integers(integer: Integer, expected_min: int, expected_max: int):
    """Test integer class basic functions"""
    assert integer.min_value() == expected_min
    assert integer.max_value() == expected_max
    # Any value inside the range is representable; either bound plus/minus one is not.
    assert integer.can_represent_value(random.randint(expected_min, expected_max))
    assert not integer.can_represent_value(expected_min - 1)
    assert not integer.can_represent_value(expected_max + 1)
@pytest.mark.parametrize(
    "integer,expected_repr_str",
    [
        pytest.param(
            Integer(8, is_signed=False),
            "Integer<unsigned, 8 bits>",
            id="8 bits unsigned Integer",
        ),
        pytest.param(
            Integer(8, is_signed=True),
            "Integer<signed, 8 bits>",
            id="8 bits signed Integer",
        ),
        pytest.param(
            Integer(32, is_signed=False),
            "Integer<unsigned, 32 bits>",
            id="32 bits unsigned Integer",
        ),
        pytest.param(
            Integer(32, is_signed=True),
            "Integer<signed, 32 bits>",
            id="32 bits signed Integer",
        ),
    ],
)
def test_integers_repr(integer: Integer, expected_repr_str: str):
    """Test integer repr"""
    assert integer.__repr__() == expected_repr_str
@pytest.mark.parametrize(
    "values,force_signed,expected_result",
    [
        # force_signed=True costs one extra bit for non-negative inputs.
        ([0], False, Integer(1, is_signed=False)),
        ([0], True, Integer(2, is_signed=True)),
        ([1], False, Integer(1, is_signed=False)),
        ([1], True, Integer(2, is_signed=True)),
        # Negative inputs always yield a signed result.
        ([-1], False, Integer(2, is_signed=True)),
        ([-2], False, Integer(2, is_signed=True)),
        ([0, 1], False, Integer(1, is_signed=False)),
        ([0, 1], True, Integer(2, is_signed=True)),
        ([7], False, Integer(3, is_signed=False)),
        ([7], True, Integer(4, is_signed=True)),
        ([8], False, Integer(4, is_signed=False)),
        ([8], True, Integer(5, is_signed=True)),
        ([-7], False, Integer(4, is_signed=True)),
        ([-8], False, Integer(4, is_signed=True)),
        ([-7, -8], False, Integer(4, is_signed=True)),
        ([-9], False, Integer(5, is_signed=True)),
        ([-9], True, Integer(5, is_signed=True)),
        ([0, 127], False, Integer(7, is_signed=False)),
        ([0, 127], True, Integer(8, is_signed=True)),
        ([0, 128], False, Integer(8, is_signed=False)),
        ([0, 128], True, Integer(9, is_signed=True)),
        ([-1, 127], False, Integer(8, is_signed=True)),
        ([-256, 127], False, Integer(9, is_signed=True)),
        ([-128, 127], False, Integer(8, is_signed=True)),
        ([-128, 128], False, Integer(9, is_signed=True)),
        ([-13, 4], False, Integer(5, is_signed=True)),
        ([42, 1019], False, Integer(10, is_signed=False)),
    ],
)
def test_make_integer_to_hold(values, force_signed, expected_result):
    """Test make_integer_to_hold"""
    assert expected_result == make_integer_to_hold(values, force_signed)

View File

@@ -1,86 +0,0 @@
"""Test file for values related code."""
from copy import deepcopy
from functools import partial
from typing import Callable, Optional, Tuple, Union
import pytest
from concrete.common.data_types.base import BaseDataType
from concrete.common.data_types.floats import Float
from concrete.common.data_types.integers import Integer
from concrete.common.values import ClearTensor, EncryptedTensor, TensorValue
class DummyDtype(BaseDataType):
    """Dummy Helper Dtype"""

    def __eq__(self, o: object) -> bool:
        # Equal only to other DummyDtype instances; used to make an otherwise
        # identical TensorValue compare unequal in the tests below.
        return isinstance(o, self.__class__)
@pytest.mark.parametrize(
    "tensor_constructor,expected_is_encrypted",
    [
        # The convenience constructors must match TensorValue(is_encrypted=...).
        (ClearTensor, False),
        (partial(TensorValue, is_encrypted=False), False),
        (EncryptedTensor, True),
        (partial(TensorValue, is_encrypted=True), True),
    ],
)
@pytest.mark.parametrize(
    "shape,expected_shape,expected_ndim,expected_size",
    [
        # The empty shape denotes a scalar: ndim 0, size 1.
        ((), (), 0, 1),
        ((3, 256, 256), (3, 256, 256), 3, 196_608),
        ((1920, 1080, 3), (1920, 1080, 3), 3, 6_220_800),
    ],
)
@pytest.mark.parametrize(
    "data_type",
    [
        Integer(7, False),
        Integer(32, True),
        Integer(32, False),
        Integer(64, True),
        Integer(64, False),
        Float(32),
        Float(64),
    ],
)
def test_tensor_value(
    tensor_constructor: Callable[..., TensorValue],
    expected_is_encrypted: bool,
    shape: Optional[Tuple[int, ...]],
    expected_shape: Tuple[int, ...],
    expected_ndim: int,
    expected_size: int,
    data_type: Union[Integer, Float],
):
    """Test function for TensorValue"""
    tensor_value = tensor_constructor(dtype=data_type, shape=shape)
    assert expected_is_encrypted == tensor_value.is_encrypted
    assert expected_shape == tensor_value.shape
    assert expected_ndim == tensor_value.ndim
    assert expected_size == tensor_value.size
    assert data_type == tensor_value.dtype
    # A deep copy must compare equal to the original.
    other_tensor = deepcopy(tensor_value)
    assert other_tensor == tensor_value
    # Changing only the dtype must break equality.
    other_tensor_value = deepcopy(other_tensor)
    other_tensor_value.dtype = DummyDtype()
    assert other_tensor_value != tensor_value
    # Build a tensor with a strictly different shape (always at least one
    # extra trailing dimension) and check all shape-derived properties differ.
    other_shape = tuple(val + 1 for val in shape) if shape is not None else ()
    other_shape += (2,)
    other_tensor_value = tensor_constructor(dtype=data_type, shape=other_shape)
    assert other_tensor_value.shape != tensor_value.shape
    assert other_tensor_value.ndim != tensor_value.ndim
    assert other_tensor_value.size != tensor_value.size
    assert other_tensor_value != tensor_value

View File

@@ -1,25 +0,0 @@
"""Test custom assert functions."""
import pytest
from concrete.common.debugging.custom_assert import assert_false, assert_not_reached, assert_true
def test_assert_not_functions():
    """Test custom assert functions"""
    # The passing checks must be silent.
    assert_true(True, "one check")
    assert_false(False, "another check")

    # Each failing call must raise AssertionError carrying its message.
    failing_calls = [
        (assert_not_reached, ("yet another one",), "yet another one"),
        (assert_true, (False, "one failing check"), "one failing check"),
        (assert_false, (True, "another failing check"), "another failing check"),
    ]
    for func, args, expected_message in failing_calls:
        with pytest.raises(AssertionError) as excinfo:
            func(*args)
        assert expected_message in str(excinfo.value)

View File

@@ -1,46 +0,0 @@
"""Test file for drawing"""
import filecmp
import tempfile
from pathlib import Path
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import draw_graph
from concrete.common.values import EncryptedScalar
from concrete.numpy import NPFHECompiler
from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds
def test_draw_graph_with_saving(default_compilation_configuration):
    """Tests drawing and saving a graph"""

    def function(x):
        return x + 42

    # Build the same graph through both entry points: the low-level compile
    # helper and the NPFHECompiler wrapper.
    op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
        function,
        {"x": EncryptedScalar(Integer(7, True))},
        range(-5, 5),
        default_compilation_configuration,
    )
    compiler = NPFHECompiler(function, {"x": "encrypted"}, default_compilation_configuration)
    # Before inputset evaluation the compiler has no graph, so draw_graph() is None.
    assert (got := compiler.draw_graph()) is None, got
    compiler.eval_on_inputset(range(-5, 5))
    with tempfile.TemporaryDirectory() as tmp:
        output_directory = Path(tmp)
        output_file = output_directory.joinpath("test.png")
        draw_graph(op_graph, save_to=output_file)
        assert output_file.exists()
        output_file_compiler = output_directory.joinpath("test_compiler.png")
        compiler_output_file = compiler.draw_graph(save_to=output_file_compiler)
        assert compiler_output_file is not None
        compiler_output_file = Path(compiler_output_file)
        assert compiler_output_file == output_file_compiler
        assert compiler_output_file.exists()
        # Both drawings come from equivalent graphs, so the files must match.
        assert filecmp.cmp(output_file, compiler_output_file)

View File

@@ -1,160 +0,0 @@
"""Test file for formatting"""
import numpy
from concrete.common.data_types.integers import Integer, UnsignedInteger
from concrete.common.debugging import format_operation_graph
from concrete.common.values import EncryptedScalar
from concrete.numpy import NPFHECompiler
from concrete.numpy.compile import (
compile_numpy_function,
compile_numpy_function_into_op_graph_and_measure_bounds,
)
def test_format_operation_graph_with_multiple_edges(default_compilation_configuration):
    """Test format_operation_graph with multiple edges"""

    def function(x):
        # x + x creates two parallel edges from the input node to the add node.
        return x + x

    op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
        function,
        {"x": EncryptedScalar(Integer(4, True))},
        range(0, 10),
        default_compilation_configuration,
    )
    formatted_graph = format_operation_graph(op_graph)
    # NOTE(review): the intra-line column alignment of the expected string was
    # lost during extraction of this file — restore the exact spacing from VCS.
    assert (
        formatted_graph
        == """
%0 = x # EncryptedScalar<uint4>
%1 = add(%0, %0) # EncryptedScalar<uint5>
return %1
""".strip()
    )
def test_format_operation_graph_with_offending_nodes(default_compilation_configuration):
    """Test format_operation_graph with offending nodes"""

    def function(x):
        return x + 42

    op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
        function,
        {"x": EncryptedScalar(Integer(7, True))},
        range(-5, 5),
        default_compilation_configuration,
    )
    # NOTE(review): the intra-line column alignment of the expected strings was
    # lost during extraction of this file — restore the exact spacing from VCS.
    # Highlighting a single node underlines it with one caret line plus message.
    highlighted_nodes = {op_graph.input_nodes[0]: ["foo"]}
    formatted_graph = format_operation_graph(op_graph, highlighted_nodes=highlighted_nodes).strip()
    assert (
        formatted_graph
        == """
%0 = x # EncryptedScalar<int4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
%1 = 42 # ClearScalar<uint6>
%2 = add(%0, %1) # EncryptedScalar<uint6>
return %2
""".strip()
    )
    # A node may carry several messages; each gets its own line.
    highlighted_nodes = {op_graph.input_nodes[0]: ["foo"], op_graph.output_nodes[0]: ["bar", "baz"]}
    formatted_graph = format_operation_graph(op_graph, highlighted_nodes=highlighted_nodes).strip()
    assert (
        formatted_graph
        == """
%0 = x # EncryptedScalar<int4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
%1 = 42 # ClearScalar<uint6>
%2 = add(%0, %1) # EncryptedScalar<uint6>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar
baz
return %2
""".strip()
    )
def test_format_operation_graph_with_fusing(default_compilation_configuration):
    """Test format_operation_graph with fusing"""

    def function(x):
        # The float part (cos, +1, *10, astype) is expected to fuse into a subgraph.
        return (10 * (numpy.cos(x + 1) + 1)).astype(numpy.uint32)

    # NOTE(review): the intra-line column alignment of the expected strings was
    # lost during extraction of this file — restore the exact spacing from VCS.
    circuit = compile_numpy_function(
        function,
        {
            "x": EncryptedScalar(UnsignedInteger(3)),
        },
        range(2 ** 3),
        default_compilation_configuration,
    )
    assert (got := str(circuit)) == (
        """
%0 = x # EncryptedScalar<uint5>
%1 = 1 # ClearScalar<uint6>
%2 = add(%0, %1) # EncryptedScalar<uint5>
%3 = subgraph(%2) # EncryptedScalar<uint5>
return %3
Subgraphs:
%3 = subgraph(%2):
%0 = 10 # ClearScalar<uint4>
%1 = 1 # ClearScalar<uint1>
%2 = float_subgraph_input # EncryptedScalar<uint3>
%3 = cos(%2) # EncryptedScalar<float64>
%4 = add(%3, %1) # EncryptedScalar<float64>
%5 = mul(%4, %0) # EncryptedScalar<float64>
%6 = astype(%5, dtype=uint32) # EncryptedScalar<uint5>
return %6
""".strip()
    ), got
    compiler = NPFHECompiler(function, {"x": "encrypted"}, default_compilation_configuration)
    # Before inputset evaluation there is no OPGraph to format.
    assert (
        got := str(compiler)
    ) == "__str__ failed: OPGraph is None, NPFHECompiler needs evaluation on an inputset", got
    compiler.eval_on_inputset(range(2 ** 3))
    # String is different here as the type that is first propagated to trace the opgraph is not the
    # same
    assert (got := str(compiler)) == (
        """
%0 = x # EncryptedScalar<uint3>
%1 = 1 # ClearScalar<uint1>
%2 = add(%0, %1) # EncryptedScalar<uint4>
%3 = subgraph(%2) # EncryptedScalar<uint5>
return %3
Subgraphs:
%3 = subgraph(%2):
%0 = 10 # ClearScalar<uint4>
%1 = 1 # ClearScalar<uint1>
%2 = float_subgraph_input # EncryptedScalar<uint1>
%3 = cos(%2) # EncryptedScalar<float64>
%4 = add(%3, %1) # EncryptedScalar<float64>
%5 = mul(%4, %0) # EncryptedScalar<float64>
%6 = astype(%5, dtype=uint32) # EncryptedScalar<uint5>
return %6
""".strip()
    ), got

View File

@@ -1,193 +0,0 @@
"""Test file for convolution"""
import numpy as np
import pytest
import torch
from concrete.common.extensions import convolution
from concrete.common.representation.intermediate import Conv2D
from concrete.common.tracing.base_tracer import BaseTracer
from concrete.common.values.tensors import TensorValue
from concrete.numpy.tracing import NPConstant, NPTracer
@pytest.mark.parametrize(
    "kwargs, error_msg",
    [
        pytest.param(
            {"x": None, "weight": np.zeros(1)},
            "input x must be an ndarray, or a BaseTracer, not a",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": None},
            "weight must be an ndarray, or a BaseTracer, not a",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1), "bias": 0},
            "bias must be an ndarray, a BaseTracer, or None, not a",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1), "strides": None},
            "strides must be a tuple, or list, not a",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1), "dilations": None},
            "dilations must be a tuple, or list, not a",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1), "pads": None},
            "padding must be a tuple, or list, not a",
        ),
    ],
)
def test_invalid_arg_types(kwargs, error_msg):
    """Test function to make sure convolution doesn't accept invalid types"""
    # Only a prefix of the message is checked: the suffix contains the
    # offending type's name.
    with pytest.raises(TypeError) as err:
        convolution.conv2d(**kwargs)
    assert error_msg in str(err)
@pytest.mark.parametrize(
    "kwargs, error_msg",
    [
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1)},
            "input x should have size (N x C x H x W), not",
        ),
        pytest.param(
            {"x": np.zeros((1, 2, 3, 4)), "weight": np.zeros(1)},
            "weight should have size (F x C x H x W), not",
        ),
        pytest.param(
            {
                "x": np.zeros((1, 2, 3, 4)),
                "weight": np.zeros((1, 2, 3, 4)),
                "bias": np.zeros((1, 2)),
            },
            "bias should have size (F), not",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1), "strides": (1,)},
            "strides should be of the form",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1), "dilations": (1,)},
            "dilations should be of the form",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1), "pads": (1,)},
            "padding should be of the form",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1), "auto_pad": None},
            "invalid auto_pad is specified",
        ),
    ],
)
def test_invalid_input_shape(kwargs, error_msg):
    """Test function to make sure convolution doesn't accept invalid shapes"""
    # Shape problems surface as ValueError or AssertionError depending on the check.
    with pytest.raises((ValueError, AssertionError)) as err:
        convolution.conv2d(**kwargs)
    assert error_msg in str(err)
@pytest.mark.parametrize(
    "input_shape, weight_shape",
    [
        pytest.param((1, 1, 4, 4), (1, 1, 2, 2)),
        pytest.param((3, 1, 4, 4), (1, 1, 2, 2)),
        pytest.param((1, 1, 4, 4), (3, 1, 2, 2)),
        pytest.param((1, 3, 4, 4), (1, 3, 2, 2)),
        pytest.param((4, 3, 4, 4), (3, 3, 2, 2)),
        pytest.param((4, 3, 16, 16), (3, 3, 2, 2)),
        pytest.param((4, 3, 16, 16), (3, 3, 3, 3)),
    ],
)
@pytest.mark.parametrize("strides", [(1, 1), (1, 2), (2, 1), (2, 2)])
@pytest.mark.parametrize("dilations", [(1, 1), (1, 2), (2, 1), (2, 2)])
@pytest.mark.parametrize("has_bias", [True, False])
@pytest.mark.parametrize("use_ndarray", [True, False])
def test_tracing(input_shape, weight_shape, strides, dilations, has_bias, use_ndarray):
    """Test function to make sure tracing of conv2d works properly"""
    # Weight/bias can be passed either as plain ndarrays or as tracers.
    if has_bias:
        bias = np.random.randint(0, 4, size=(weight_shape[0],))
        if not use_ndarray:
            bias = NPTracer([], NPConstant(bias), 0)
    else:
        bias = None
    x = NPTracer([], NPConstant(np.random.randint(0, 4, size=input_shape)), 0)
    weight = np.random.randint(0, 4, size=weight_shape)
    if not use_ndarray:
        weight = NPTracer([], NPConstant(weight), 0)
    output_tracer = convolution.conv2d(x, weight, bias, strides=strides, dilations=dilations)
    traced_computation = output_tracer.traced_computation
    assert isinstance(traced_computation, Conv2D)
    # Bias contributes one extra input to the traced node when present.
    if has_bias:
        assert len(output_tracer.inputs) == 3
    else:
        assert len(output_tracer.inputs) == 2
    assert all(
        isinstance(input_, BaseTracer) for input_ in output_tracer.inputs
    ), f"{output_tracer.inputs}"
    assert len(traced_computation.outputs) == 1
    output_value = traced_computation.outputs[0]
    assert isinstance(output_value, TensorValue) and output_value.is_encrypted
    # Use torch as the shape oracle: same conv parameters must give the same shape.
    # pylint: disable=no-member
    expected_shape = torch.conv2d(
        torch.randn(input_shape),
        torch.randn(weight_shape),
        torch.randn((weight_shape[0])),
        stride=strides,
        dilation=dilations,
    ).shape
    # pylint: enable=no-member
    assert output_value.shape == expected_shape
@pytest.mark.parametrize(
    "input_shape, weight_shape",
    [
        pytest.param((1, 1, 4, 4), (1, 1, 2, 2)),
        pytest.param((3, 1, 4, 4), (1, 1, 2, 2)),
        pytest.param((1, 1, 4, 4), (3, 1, 2, 2)),
        pytest.param((1, 3, 4, 4), (1, 3, 2, 2)),
        pytest.param((4, 3, 4, 4), (3, 3, 2, 2)),
        pytest.param((4, 3, 16, 16), (3, 3, 2, 2)),
        pytest.param((4, 3, 16, 16), (3, 3, 3, 3)),
    ],
)
@pytest.mark.parametrize("strides", [(1, 1), (1, 2), (2, 1), (2, 2)])
@pytest.mark.parametrize("dilations", [(1, 1), (1, 2), (2, 1), (2, 2)])
@pytest.mark.parametrize("has_bias", [True, False])
def test_evaluation(input_shape, weight_shape, strides, dilations, has_bias):
    """Test function to make sure evaluation of conv2d on plain data works properly"""
    # torch.conv2d requires a bias, so use zeros when the test case has none.
    if has_bias:
        bias = np.random.randint(0, 4, size=(weight_shape[0],))
    else:
        bias = np.zeros((weight_shape[0],))
    x = np.random.randint(0, 4, size=input_shape)
    weight = np.random.randint(0, 4, size=weight_shape)
    # Use torch as the reference implementation on the same integer data.
    # pylint: disable=no-member
    expected = torch.conv2d(
        torch.tensor(x, dtype=torch.long),
        torch.tensor(weight, dtype=torch.long),
        torch.tensor(bias, dtype=torch.long),
        stride=strides,
        dilation=dilations,
    ).numpy()
    # pylint: enable=no-member
    # conv2d should handle None biases
    if not has_bias:
        bias = None
    result = convolution.conv2d(x, weight, bias, strides=strides, dilations=dilations)
    assert (result == expected).all()

View File

@@ -1,118 +0,0 @@
"""Test file for direct multi table lookups"""
import random
import numpy
import pytest
from concrete.common.data_types.integers import Integer
from concrete.common.extensions.multi_table import MultiLookupTable
from concrete.common.extensions.table import LookupTable
# Fixture tables named table_{input bits}b_to_{output bits}b: the list length
# fixes the input bit-width (2^bits entries) and the largest entry fixes the
# output bit-width.
table_2b_to_2b = LookupTable([1, 2, 0, 3])
table_2b_to_1b = LookupTable([1, 0, 0, 1])
table_2b_to_3b = LookupTable([5, 2, 7, 0])
table_3b_to_2b = LookupTable([1, 2, 0, 3, 0, 3, 1, 2])
table_3b_to_1b = LookupTable([1, 0, 0, 1, 1, 1, 1, 0])
table_3b_to_3b = LookupTable([5, 2, 7, 0, 4, 1, 6, 2])
# Grouped by input bit-width for random selection in the tests below.
tables_2b = [table_2b_to_1b, table_2b_to_2b, table_2b_to_3b]
tables_3b = [table_3b_to_1b, table_3b_to_2b, table_3b_to_3b]
def test_multi_lookup_table_creation_and_indexing():
    """Test function for creating and indexing multi lookup tables"""
    # Build a (3, 2) grid of randomly chosen 2-bit-input tables.
    tables = [
        [tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]],
        [tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]],
        [tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]],
    ]
    multitable = MultiLookupTable(tables)
    assert multitable.input_shape == (3, 2)
    assert isinstance(multitable.output_dtype, Integer)
    # Output width is bounded by the widest table output (3 bits here).
    assert multitable.output_dtype.bit_width <= 3
    # Index with random 2-bit values and check each cell against its own table.
    index = numpy.random.randint(0, 2 ** 2, size=multitable.input_shape).tolist()
    result = multitable[index]
    for i in range(3):
        for j in range(2):
            assert result[i][j] == multitable.tables[i][j][index[i][j]], f"i={i}, j={j}"
@pytest.mark.parametrize(
    "tables,match",
    [
        pytest.param(
            [
                [],
                [table_2b_to_2b, table_2b_to_3b],
            ],
            "MultiLookupTable cannot have an empty array within it",
        ),
        pytest.param(
            [
                [table_2b_to_1b, 42.0],
                [table_2b_to_2b, table_2b_to_3b],
            ],
            "MultiLookupTable should have been made out of LookupTables "
            "but it had an object of type float within it",
        ),
        pytest.param(
            [
                [table_2b_to_2b],
                [table_2b_to_2b, table_2b_to_3b],
                [table_2b_to_2b, table_2b_to_1b],
            ],
            "MultiLookupTable should have the shape (3, 1) but it does not "
            "(an array on dimension 1 has the size 2 but its size should have been 1 "
            "as the expected shape is (3, 1))",
        ),
        pytest.param(
            [
                [table_2b_to_2b, table_3b_to_3b],
                [table_2b_to_2b, table_3b_to_1b],
            ],
            "LookupTables within a MultiLookupTable should have the same size but they do not "
            "(there was a table with the size of 4 and another with the size of 8)",
        ),
    ],
)
def test_multi_lookup_table_creation_failure(tables, match):
    """Test function for failing to create multi lookup tables"""
    # The exact message is asserted, not just a substring.
    with pytest.raises(ValueError) as excinfo:
        MultiLookupTable(tables)
    assert str(excinfo.value) == match
# Indexing a (2, 3) multi-table with a (2, 2) index must fail with the exact message below.
@pytest.mark.parametrize(
    "tables,index,match",
    [
        pytest.param(
            [
                [table_2b_to_2b, table_2b_to_1b, table_2b_to_3b],
                [table_2b_to_1b, table_2b_to_2b, table_2b_to_3b],
            ],
            [
                [1, 2],
                [3, 0],
            ],
            "Multiple Lookup Table of shape (2, 3) cannot be looked up with [[1, 2], [3, 0]] "
            "(you should check your inputset)",
        ),
    ],
)
def test_multi_lookup_table_indexing_failure(tables, index, match):
    """Test function for failing to index multi lookup tables"""
    table = MultiLookupTable(tables)
    with pytest.raises(ValueError) as excinfo:
        # Call __getitem__ directly so the failing expression is explicit.
        table.__getitem__(index)
    assert str(excinfo.value) == match

View File

@@ -1,131 +0,0 @@
"""Test file for direct table lookups"""
from copy import deepcopy
import networkx as nx
import pytest
from concrete.common import is_a_power_of_2
from concrete.common.data_types.integers import Integer
from concrete.common.extensions.table import LookupTable
from concrete.common.representation import intermediate as ir
from concrete.common.values import EncryptedScalar
from concrete.numpy import tracing
def test_lookup_table_size_constraints():
    """Test function to make sure lookup tables have correct size"""
    # An empty lookup table must be rejected.
    with pytest.raises(ValueError):
        LookupTable([])
    for size in range(1, 513):
        entries = [0] * size
        if is_a_power_of_2(size):
            # Tables with 2^N entries are accepted.
            LookupTable(entries)
        else:
            # Any other size is rejected.
            with pytest.raises(ValueError):
                LookupTable(entries)
def test_lookup_table_encrypted_lookup(test_helpers):
    """Test function for tracing with explicit table lookups using encrypted inputs"""
    table = LookupTable([3, 6, 0, 2])

    def f(x):
        return table[x]

    x = EncryptedScalar(Integer(2, is_signed=False))
    op_graph = tracing.trace_numpy_function(f, {"x": x})
    table_node = op_graph.output_nodes[0]
    # The traced TLU node must carry the original table contents.
    assert table_node.get_table(op_graph.get_ordered_preds(table_node)) == [3, 6, 0, 2]
    ref_graph = nx.MultiDiGraph()
    # Here is the ASCII drawing of the expected graph:
    # (x) - (TLU)
    input_x = ir.Input(input_value=x, input_name="x", program_input_idx=0)
    ref_graph.add_node(input_x)
    # The TLU output keeps x's metadata but takes the table's output dtype.
    generic_function_output_value = deepcopy(x)
    generic_function_output_value.dtype = table.output_dtype
    # pylint: disable=protected-access
    # Need access to _checked_indexing to have is_equivalent_to work for ir.GenericFunction
    output_arbitrary_function = ir.GenericFunction(
        inputs=[x],
        arbitrary_func=LookupTable._checked_indexing,
        output_value=generic_function_output_value,
        op_kind="TLU",
        op_kwargs={"table": deepcopy(table.table)},
        op_name="TLU",
    )
    # pylint: enable=protected-access
    ref_graph.add_node(output_arbitrary_function)
    ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0, output_idx=0)
    # TODO: discuss if this check is enough as == is not overloaded properly for GenericFunction
    assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
    """Test function for tracing with explicit table lookups using encrypted and plain inputs"""
    table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7])

    def f(x):
        # table[0] is a lookup with a constant index; it folds to the constant 3.
        return table[x] + table[0]

    x = EncryptedScalar(Integer(3, is_signed=False))
    op_graph = tracing.trace_numpy_function(f, {"x": x})
    ref_graph = nx.MultiDiGraph()
    # Here is the ASCII drawing of the expected graph:
    # (x) - (TLU)
    #            \
    #             (+)
    #            /
    #         (3)
    input_x = ir.Input(input_value=x, input_name="x", program_input_idx=0)
    ref_graph.add_node(input_x)
    # The TLU output keeps x's metadata but takes the table's output dtype.
    generic_function_output_value = deepcopy(x)
    generic_function_output_value.dtype = table.output_dtype
    # pylint: disable=protected-access
    # Need access to _checked_indexing to have is_equivalent_to work for ir.GenericFunction
    intermediate_arbitrary_function = ir.GenericFunction(
        inputs=[x],
        arbitrary_func=LookupTable._checked_indexing,
        output_value=generic_function_output_value,
        op_kind="TLU",
        op_kwargs={"table": deepcopy(table.table)},
        op_name="TLU",
    )
    # pylint: enable=protected-access
    ref_graph.add_node(intermediate_arbitrary_function)
    constant_3 = ir.Constant(3)
    ref_graph.add_node(constant_3)
    output_add = ir.Add((intermediate_arbitrary_function.outputs[0], constant_3.outputs[0]))
    ref_graph.add_node(output_add)
    ref_graph.add_edge(input_x, intermediate_arbitrary_function, input_idx=0, output_idx=0)
    ref_graph.add_edge(intermediate_arbitrary_function, output_add, input_idx=0, output_idx=0)
    ref_graph.add_edge(constant_3, output_add, input_idx=1, output_idx=0)
    # TODO: discuss if this check is enough as == is not overloaded properly for GenericFunction
    assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)

View File

@@ -1,24 +0,0 @@
"""Test file for common python helpers"""
from concrete.common.helpers.python_helpers import catch
def test_catch_failure():
    """Test case for when the function called with catch raises an exception."""

    def raise_zero_division():
        # Always raises ZeroDivisionError.
        return 1 / 0

    # catch is expected to swallow the exception and signal it by returning None.
    assert catch(raise_zero_division) is None
def test_catch():
    """Test case for catch"""

    def echo(*args, **kwargs):
        # Return the positional arguments followed by the keyword arguments as a dict.
        return *args, dict(**kwargs)

    expected = ((1, 2, 3), {"one": 1, "two": 2, "three": 3})
    assert catch(echo, (1, 2, 3), one=1, two=2, three=3) == expected

View File

@@ -1,87 +0,0 @@
"""Test file for MLIR conversion helpers."""
# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
# pylint: disable=no-name-in-module
import concrete.lang as concretelang
import pytest
from mlir.ir import Context, Location
from concrete.common.data_types import Float, SignedInteger, UnsignedInteger
from concrete.common.mlir.conversion_helpers import integer_to_mlir_type, value_to_mlir_type
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor
# pylint: enable=no-name-in-module
# Clear integers map to plain iN; encrypted integers map to !FHE.eint<N> regardless of sign.
@pytest.mark.parametrize(
    "integer,is_encrypted,expected_mlir_type_str",
    [
        pytest.param(SignedInteger(5), False, "i5"),
        pytest.param(UnsignedInteger(5), False, "i5"),
        pytest.param(SignedInteger(32), False, "i32"),
        pytest.param(UnsignedInteger(32), False, "i32"),
        pytest.param(SignedInteger(5), True, "!FHE.eint<5>"),
        pytest.param(UnsignedInteger(5), True, "!FHE.eint<5>"),
    ],
)
def test_integer_to_mlir_type(integer, is_encrypted, expected_mlir_type_str):
    """Test function for integer to MLIR type conversion."""
    with Context() as ctx, Location.unknown():
        concretelang.register_dialects(ctx)
        assert str(integer_to_mlir_type(ctx, integer, is_encrypted)) == expected_mlir_type_str
# Scalars map to their element type; tensors wrap it in a tensor<...> type.
@pytest.mark.parametrize(
    "value,expected_mlir_type_str",
    [
        pytest.param(ClearScalar(SignedInteger(5)), "i5"),
        pytest.param(ClearTensor(SignedInteger(5), shape=(2, 3)), "tensor<2x3xi5>"),
        pytest.param(EncryptedScalar(SignedInteger(5)), "!FHE.eint<5>"),
        pytest.param(EncryptedTensor(SignedInteger(5), shape=(2, 3)), "tensor<2x3x!FHE.eint<5>>"),
        pytest.param(ClearScalar(UnsignedInteger(5)), "i5"),
        pytest.param(ClearTensor(UnsignedInteger(5), shape=(2, 3)), "tensor<2x3xi5>"),
        pytest.param(EncryptedScalar(UnsignedInteger(5)), "!FHE.eint<5>"),
        pytest.param(EncryptedTensor(UnsignedInteger(5), shape=(2, 3)), "tensor<2x3x!FHE.eint<5>>"),
    ],
)
def test_value_to_mlir_type(value, expected_mlir_type_str):
    """Test function for value to MLIR type conversion."""
    with Context() as ctx, Location.unknown():
        concretelang.register_dialects(ctx)
        assert str(value_to_mlir_type(ctx, value)) == expected_mlir_type_str
# Float-typed values have no MLIR representation; each raises TypeError with this exact message.
@pytest.mark.parametrize(
    "value,expected_error_message",
    [
        pytest.param(
            ClearScalar(Float(32)),
            "ClearScalar<float32> is not supported for MLIR conversion",
        ),
        pytest.param(
            ClearTensor(Float(32), shape=(2, 3)),
            "ClearTensor<float32, shape=(2, 3)> is not supported for MLIR conversion",
        ),
        pytest.param(
            EncryptedScalar(Float(32)),
            "EncryptedScalar<float32> is not supported for MLIR conversion",
        ),
        pytest.param(
            EncryptedTensor(Float(32), shape=(2, 3)),
            "EncryptedTensor<float32, shape=(2, 3)> is not supported for MLIR conversion",
        ),
    ],
)
def test_fail_value_to_mlir_type(value, expected_error_message):
    """Test function for failed value to MLIR type conversion."""
    with pytest.raises(TypeError) as excinfo:
        with Context() as ctx, Location.unknown():
            concretelang.register_dialects(ctx)
            value_to_mlir_type(ctx, value)
    assert str(excinfo.value) == expected_error_message

View File

@@ -1,107 +0,0 @@
"""Test file for intermediate node to MLIR converter."""
import random
import numpy
import pytest
from concrete.common.data_types import UnsignedInteger
from concrete.common.values import EncryptedScalar, EncryptedTensor
from concrete.numpy import compile_numpy_function
# Operations not yet supported by the MLIR converter must fail with NotImplementedError and the
# exact message below (the bit widths in the messages reflect bounds measured on the inputset).
@pytest.mark.parametrize(
    "function_to_compile,parameters,inputset,expected_error_type,expected_error_message",
    [
        pytest.param(
            lambda x, y: x * y,
            {
                "x": EncryptedScalar(UnsignedInteger(3)),
                "y": EncryptedScalar(UnsignedInteger(3)),
            },
            # (7, 7) is appended so the measured bounds always reach the maximum.
            [(random.randint(0, 7), random.randint(0, 7)) for _ in range(10)] + [(7, 7)],
            NotImplementedError,
            "Multiplication "
            "between "
            "EncryptedScalar<uint6> "
            "and "
            "EncryptedScalar<uint6> "
            "cannot be converted to MLIR yet",
        ),
        pytest.param(
            lambda x, y: x - y,
            {
                "x": EncryptedScalar(UnsignedInteger(3)),
                "y": EncryptedScalar(UnsignedInteger(3)),
            },
            # x >= y in every sample so the subtraction stays non-negative.
            [(random.randint(5, 7), random.randint(0, 5)) for _ in range(10)],
            NotImplementedError,
            "Subtraction "
            "of "
            "EncryptedScalar<uint3> "
            "from "
            "EncryptedScalar<uint3> "
            "cannot be converted to MLIR yet",
        ),
        pytest.param(
            lambda x, y: numpy.dot(x, y),
            {
                "x": EncryptedTensor(UnsignedInteger(3), shape=(2,)),
                "y": EncryptedTensor(UnsignedInteger(3), shape=(2,)),
            },
            [
                (
                    numpy.random.randint(0, 2 ** 3, size=(2,)),
                    numpy.random.randint(0, 2 ** 3, size=(2,)),
                )
                for _ in range(10)
            ]
            + [(numpy.array([7, 7]), numpy.array([7, 7]))],
            NotImplementedError,
            "Dot product "
            "between "
            "EncryptedTensor<uint7, shape=(2,)> "
            "and "
            "EncryptedTensor<uint7, shape=(2,)> "
            "cannot be converted to MLIR yet",
        ),
        pytest.param(
            lambda x, y: x @ y,
            {
                "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
                "y": EncryptedTensor(UnsignedInteger(3), shape=(2, 1)),
            },
            [
                (
                    numpy.random.randint(0, 2 ** 3, size=(3, 2)),
                    numpy.random.randint(0, 2 ** 3, size=(2, 1)),
                )
                for i in range(10)
            ]
            + [(numpy.array([[7, 7], [7, 7], [7, 7]]), numpy.array([[7], [7]]))],
            NotImplementedError,
            "Matrix multiplication "
            "between "
            "EncryptedTensor<uint7, shape=(3, 2)> "
            "and "
            "EncryptedTensor<uint7, shape=(2, 1)> "
            "cannot be converted to MLIR yet",
        ),
    ],
)
def test_fail_node_conversion(
    function_to_compile,
    parameters,
    inputset,
    expected_error_type,
    expected_error_message,
    default_compilation_configuration,
):
    """Test function for failed intermediate node conversion."""
    with pytest.raises(expected_error_type) as excinfo:
        compile_numpy_function(
            function_to_compile, parameters, inputset, default_compilation_configuration
        )
    assert str(excinfo.value) == expected_error_message

View File

@@ -1,735 +0,0 @@
"""Test file for float subgraph fusing"""
import random
from copy import deepcopy
from inspect import signature
import numpy
import pytest
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import format_operation_graph
from concrete.common.debugging.custom_assert import assert_not_reached
from concrete.common.optimization.topological import fuse_float_operations
from concrete.common.values import EncryptedScalar, EncryptedTensor
from concrete.numpy import tracing
from concrete.numpy.tracing import trace_numpy_function
def no_fuse(x):
    """No fuse"""
    # Purely integer computation: nothing for the float fuser to do.
    shifted = x + 2
    return shifted
def no_fuse_unhandled(x, y):
    """No fuse unhandled"""
    # Two independent float-shifted variables feed the same add, which is not fusable.
    shifted_x = x + 0.7
    shifted_y = y + 1.3
    return (shifted_x + shifted_y).astype(numpy.int32)
def fusable_with_bigger_search(x, y):
    """fusable with bigger search"""
    shifted = x + 1
    # Two branches cast the same value to int32 and add different float offsets.
    branch_a = shifted.astype(numpy.int32) + 1.5
    branch_b = shifted.astype(numpy.int32) + 3.4
    combined = (branch_a + branch_b).astype(numpy.int32)
    return combined + y
def fusable_with_bigger_search_needs_second_iteration(x, y):
    """fusable with bigger search and triggers a second iteration in the fusing"""
    cosine = numpy.cos((x + 1) + 0.5)
    # First branch: cast to int then float offset.
    branch_a = cosine.astype(numpy.int32) + 1.5
    # Second branch: two successive shifts added together before the int cast.
    shift_one = cosine + 1
    shift_two = shift_one + 1
    branch_b = (shift_one + shift_two).astype(numpy.int32) + 3.4
    combined = (branch_a + branch_b).astype(numpy.int32)
    return combined + y
def no_fuse_big_constant_3_10_10(x):
    """Pass an array x with size < 100 to trigger a no fuse condition."""
    # The (3, 10, 10) ones constant broadcasts x to a bigger shape, preventing fusing.
    as_float = x.astype(numpy.float64)
    broadcast_sum = as_float + numpy.ones((3, 10, 10))
    return broadcast_sum.astype(numpy.int32)
def no_fuse_dot(x):
    """No fuse dot"""
    # Dot with a float constant vector collapses the shape, so it cannot be fused.
    weights = numpy.full((10,), 1.33, dtype=numpy.float64)
    return numpy.dot(x, weights).astype(numpy.int32)
def simple_create_fuse_opportunity(f, x):
    """No fuse because the function is explicitely marked as unfusable in our code."""
    # Cast to float, apply f, cast back: a float subgraph around an arbitrary callable.
    float_input = x.astype(numpy.float64)
    return f(float_input).astype(numpy.int32)
def ravel_cases(x):
    """Simple ravel cases"""
    # Route through the shared helper so the ravel happens on floats.
    unfusable_op = numpy.ravel
    return simple_create_fuse_opportunity(unfusable_op, x)
def transpose_cases(x):
    """Simple transpose cases"""
    # Route through the shared helper so the transpose happens on floats.
    unfusable_op = numpy.transpose
    return simple_create_fuse_opportunity(unfusable_op, x)
def reshape_cases(x, newshape):
    """Simple reshape cases"""

    def apply_reshape(values):
        # Bind the requested shape into the reshape call.
        return numpy.reshape(values, newshape)

    return simple_create_fuse_opportunity(apply_reshape, x)
def simple_fuse_not_output(x):
    """Simple fuse not output"""
    # float64 -> int32 round trip that is NOT the graph output (a +2 follows it).
    round_tripped = x.astype(numpy.float64).astype(numpy.int32)
    return round_tripped + 2
def simple_fuse_output(x):
    """Simple fuse output"""
    # float64 -> int32 round trip that IS the graph output.
    as_float = x.astype(numpy.float64)
    return as_float.astype(numpy.int32)
def mix_x_and_y_intricately_and_call_f(function, x, y):
    """Mix x and y in an intricated way, that can't be simplified by
    an optimizer eg, and then call function
    """
    # Mix both inputs, bounce through float32 and back to int32.
    mixed = ((x + y) + 2).astype(numpy.float32).astype(numpy.int32)
    lhs = mixed + 1.5
    rhs = mixed + 2.7
    transformed = function(lhs + rhs)
    return (
        transformed.astype(numpy.int32),
        rhs.astype(numpy.int32),
        (rhs + 3).astype(numpy.int32),
        transformed.astype(numpy.int32) + 67,
        y,
        (y + 4.7).astype(numpy.int32) + 3,
    )
def mix_x_and_y_and_call_f(function, x, y):
    """Mix x and y and then call function"""
    # Two float shifts of x feed the transformed value.
    lhs = x + 0.1
    rhs = x + 0.2
    transformed = function(lhs + rhs)
    return (
        transformed.astype(numpy.int32),
        rhs.astype(numpy.int32),
        (rhs + 3).astype(numpy.int32),
        transformed.astype(numpy.int32) + 67,
        y,
        (y + 4.7).astype(numpy.int32) + 3,
    )
def mix_x_and_y_into_range_0_to_1_and_call_f(function, x, y):
    """Mix x and y and then call function, in such a way that the input to function is between
    0 and 1"""
    lhs = x + 0.1
    rhs = x + 0.2
    # 1 - |sin(.)| is always within [0, 1], so function receives an in-range value.
    squeezed = 1 - numpy.abs(numpy.sin(lhs + rhs + 0.3))
    transformed = function(squeezed)
    return (
        transformed.astype(numpy.int32),
        rhs.astype(numpy.int32),
        (rhs + 3).astype(numpy.int32),
        transformed.astype(numpy.int32) + 67,
        y,
        (y + 4.7).astype(numpy.int32) + 3,
    )
def mix_x_and_y_into_integer_and_call_f(function, x, y):
    """Mix x and y but keep the entry to function as an integer"""
    # Only integer shifts here, so function's argument stays an integer.
    lhs = x + 1
    rhs = x + 2
    transformed = function(lhs + rhs)
    return (
        transformed.astype(numpy.int32),
        rhs.astype(numpy.int32),
        (rhs + 3).astype(numpy.int32),
        transformed.astype(numpy.int32) + 67,
        y,
        (y + 4.7).astype(numpy.int32) + 3,
    )
def get_func_params_int32(func, scalar=True):
    """Returns a dict with parameters as scalar int32"""

    def make_value():
        # Build a fresh value object for every parameter.
        if scalar:
            return EncryptedScalar(Integer(32, True))
        return EncryptedTensor(Integer(32, True), (1,))

    return {param_name: make_value() for param_name in signature(func).parameters}
# Each case pairs a function with whether its float section is expected to fuse and, when fusing
# is impossible, the exact explanation printed to stderr (compared after stripping color codes).
@pytest.mark.parametrize(
    "function_to_trace,fused,params,warning_message",
    [
        pytest.param(no_fuse, False, get_func_params_int32(no_fuse), "", id="no_fuse"),
        pytest.param(
            no_fuse_unhandled,
            False,
            get_func_params_int32(no_fuse_unhandled),
            """
The following subgraph is not fusable:
%0 = x # EncryptedScalar<int32>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
%1 = 0.7 # ClearScalar<float64>
%2 = y # EncryptedScalar<int32>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
%3 = 1.3 # ClearScalar<float64>
%4 = add(%0, %1) # EncryptedScalar<float64>
%5 = add(%2, %3) # EncryptedScalar<float64>
%6 = add(%4, %5) # EncryptedScalar<float64>
%7 = astype(%6, dtype=int32) # EncryptedScalar<int32>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot fuse here as the subgraph has 2 variable inputs
return %7
""".strip(),  # noqa: E501 # pylint: disable=line-too-long
            id="no_fuse_unhandled",
        ),
        pytest.param(
            fusable_with_bigger_search,
            True,
            get_func_params_int32(fusable_with_bigger_search),
            None,
            id="fusable_with_bigger_search",
        ),
        pytest.param(
            fusable_with_bigger_search_needs_second_iteration,
            True,
            get_func_params_int32(fusable_with_bigger_search_needs_second_iteration),
            None,
            id="fusable_with_bigger_search",
        ),
        pytest.param(
            no_fuse_dot,
            False,
            {"x": EncryptedTensor(Integer(32, True), (10,))},
            """
The following subgraph is not fusable:
%0 = x # EncryptedTensor<int32, shape=(10,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10,)
%1 = [1.33 1.33 ... 1.33 1.33] # ClearTensor<float64, shape=(10,)>
%2 = dot(%0, %1) # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,)
%3 = astype(%2, dtype=int32) # EncryptedScalar<int32>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,)
return %3
""".strip(),  # noqa: E501 # pylint: disable=line-too-long
            id="no_fuse_dot",
        ),
        pytest.param(
            ravel_cases,
            False,
            {"x": EncryptedTensor(Integer(32, True), (10, 20))},
            """
The following subgraph is not fusable:
%0 = x # EncryptedTensor<int32, shape=(10, 20)>
%1 = astype(%0, dtype=float64) # EncryptedTensor<float64, shape=(10, 20)>
%2 = ravel(%1) # EncryptedTensor<float64, shape=(200,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(%2, dtype=int32) # EncryptedTensor<int32, shape=(200,)>
return %3
""".strip(),  # noqa: E501 # pylint: disable=line-too-long
            id="no_fuse_explicitely_ravel",
        ),
        pytest.param(
            transpose_cases,
            False,
            {"x": EncryptedTensor(Integer(32, True), (10, 20))},
            """
The following subgraph is not fusable:
%0 = x # EncryptedTensor<int32, shape=(10, 20)>
%1 = astype(%0, dtype=float64) # EncryptedTensor<float64, shape=(10, 20)>
%2 = transpose(%1) # EncryptedTensor<float64, shape=(20, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(%2, dtype=int32) # EncryptedTensor<int32, shape=(20, 10)>
return %3
""".strip(),  # noqa: E501 # pylint: disable=line-too-long
            id="no_fuse_explicitely_transpose",
        ),
        pytest.param(
            lambda x: reshape_cases(x, (20, 10)),
            False,
            {"x": EncryptedTensor(Integer(32, True), (10, 20))},
            """
The following subgraph is not fusable:
%0 = x # EncryptedTensor<int32, shape=(10, 20)>
%1 = astype(%0, dtype=float64) # EncryptedTensor<float64, shape=(10, 20)>
%2 = reshape(%1, newshape=(20, 10)) # EncryptedTensor<float64, shape=(20, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(%2, dtype=int32) # EncryptedTensor<int32, shape=(20, 10)>
return %3
""".strip(),  # noqa: E501 # pylint: disable=line-too-long
            id="no_fuse_explicitely_reshape",
        ),
        pytest.param(
            no_fuse_big_constant_3_10_10,
            False,
            {"x": EncryptedTensor(Integer(32, True), (10, 10))},
            """
The following subgraph is not fusable:
%0 = [[[1. 1. 1 ... . 1. 1.]]] # ClearTensor<float64, shape=(3, 10, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this constant node has a bigger shape (3, 10, 10) than the subgraph's input: (10, 10)
%1 = x # EncryptedTensor<int32, shape=(10, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10, 10)
%2 = astype(%1, dtype=float64) # EncryptedTensor<float64, shape=(10, 10)>
%3 = add(%2, %0) # EncryptedTensor<float64, shape=(3, 10, 10)>
%4 = astype(%3, dtype=int32) # EncryptedTensor<int32, shape=(3, 10, 10)>
return %4
""".strip(),  # noqa: E501 # pylint: disable=line-too-long
            id="no_fuse_big_constant_3_10_10",
        ),
        pytest.param(
            simple_fuse_not_output,
            True,
            get_func_params_int32(simple_fuse_not_output),
            None,
            id="simple_fuse_not_output",
        ),
        pytest.param(
            simple_fuse_output,
            True,
            get_func_params_int32(simple_fuse_output),
            None,
            id="simple_fuse_output",
        ),
        pytest.param(
            lambda x, y: mix_x_and_y_intricately_and_call_f(numpy.rint, x, y),
            True,
            get_func_params_int32(lambda x, y: None),
            None,
            id="mix_x_and_y_intricately_and_call_f_with_rint",
        ),
        pytest.param(
            lambda x, y: mix_x_and_y_and_call_f(numpy.rint, x, y),
            True,
            get_func_params_int32(lambda x, y: None),
            None,
            id="mix_x_and_y_and_call_f_with_rint",
        ),
        pytest.param(
            transpose_cases,
            True,
            get_func_params_int32(transpose_cases),
            None,
            id="transpose_cases scalar",
        ),
        pytest.param(
            transpose_cases,
            True,
            {"x": EncryptedTensor(Integer(32, True), (10,))},
            None,
            id="transpose_cases ndim == 1",
        ),
        pytest.param(
            ravel_cases,
            True,
            {"x": EncryptedTensor(Integer(32, True), (10,))},
            None,
            id="ravel_cases ndim == 1",
        ),
        pytest.param(
            lambda x: reshape_cases(x, (10, 20)),
            True,
            {"x": EncryptedTensor(Integer(32, True), (10, 20))},
            None,
            id="reshape_cases same shape",
        ),
    ],
)
def test_fuse_float_operations(
    function_to_trace,
    fused,
    params,
    warning_message,
    capfd,
    remove_color_codes,
    check_array_equality,
):
    """Test function for fuse_float_operations"""
    op_graph = trace_numpy_function(
        function_to_trace,
        params,
    )
    copied_graph = deepcopy(op_graph)
    orig_num_nodes = len(op_graph.graph)
    fuse_float_operations(op_graph)
    fused_num_nodes = len(op_graph.graph)
    fuse_float_operations(copied_graph)
    # Check determinism
    assert format_operation_graph(copied_graph) == format_operation_graph(op_graph)
    if fused:
        assert fused_num_nodes < orig_num_nodes
    else:
        assert fused_num_nodes == orig_num_nodes
    captured = capfd.readouterr()
    # NOTE(review): warning_message is None for the fusable cases — verify remove_color_codes /
    # the `in` check tolerate that in the original test harness.
    assert warning_message in (output := remove_color_codes(captured.err)), output
    # Evaluate the graph on several scalar-or-broadcast inputs and compare with direct execution.
    for input_ in [0, 2, 42, 44]:
        inputs = ()
        for param_input_value in params.values():
            if param_input_value.is_scalar:
                input_ = numpy.int32(input_)
            else:
                input_ = numpy.full(param_input_value.shape, input_, dtype=numpy.int32)
            inputs += (input_,)
        check_array_equality(function_to_trace(*inputs), op_graph(*inputs))
def subtest_tensor_no_fuse(fun, tensor_shape):
    """Test case to verify float fusing is only applied on functions on scalars."""
    if tensor_shape == ():
        # We want tensors
        return
    if fun in LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES:
        # We need at least one input of the bivariate function to be float
        return
    # Float fusing currently cannot work if the constant in a bivariate operator is bigger than the
    # variable input.
    # Make a broadcastable shape but with the constant being bigger
    variable_tensor_shape = (1,) + tensor_shape
    constant_bigger_shape = (random.randint(2, 10),) + tensor_shape

    def tensor_no_fuse(x):
        # Float round trip around `fun` applied with an oversized constant tensor.
        intermediate = x.astype(numpy.float64)
        intermediate = fun(intermediate, numpy.ones(constant_bigger_shape))
        return intermediate.astype(numpy.int32)

    function_to_trace = tensor_no_fuse
    params_names = signature(function_to_trace).parameters.keys()
    op_graph = trace_numpy_function(
        function_to_trace,
        {
            param_name: EncryptedTensor(Integer(32, True), shape=variable_tensor_shape)
            for param_name in params_names
        },
    )
    # Fusing must be a no-op: the node count stays the same.
    orig_num_nodes = len(op_graph.graph)
    fuse_float_operations(op_graph)
    fused_num_nodes = len(op_graph.graph)
    assert orig_num_nodes == fused_num_nodes
def check_results_are_equal(function_result, op_graph_result):
    """Check the output of function execution and OPGraph evaluation are equal."""
    lhs_is_tuple = isinstance(function_result, tuple)
    rhs_is_tuple = isinstance(op_graph_result, tuple)
    if lhs_is_tuple and rhs_is_tuple:
        # Multi-output case: compare element by element.
        assert len(function_result) == len(op_graph_result)
        comparisons = (lhs == rhs for lhs, rhs in zip(function_result, op_graph_result))
    elif not lhs_is_tuple and not rhs_is_tuple:
        # Single-output case.
        comparisons = (function_result == op_graph_result,)
    else:
        assert_not_reached(f"Incompatible outputs: {function_result}, {op_graph_result}")
    # Array comparisons yield arrays; collapse them with .all().
    return all(item.all() if isinstance(item, numpy.ndarray) else item for item in comparisons)
def subtest_fuse_float_unary_operations_correctness(fun, tensor_shape):
    """Test a unary function with fuse_float_operations."""
    # Some manipulation to avoid issues with domain of definitions of functions
    if fun == numpy.arccosh:
        # 0 is not in the domain of definition
        input_list = [1, 2, 42, 44]
        super_fun_list = [mix_x_and_y_and_call_f]
    elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]:
        # Needs values between 0 and 1 in the call function
        input_list = [0, 2, 42, 44]
        super_fun_list = [mix_x_and_y_into_range_0_to_1_and_call_f]
    elif fun in [numpy.cosh, numpy.sinh, numpy.exp, numpy.exp2, numpy.expm1]:
        # Not too large values to avoid overflows
        input_list = [1, 2, 5, 11]
        super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f]
    else:
        # Regular case
        input_list = [0, 2, 42, 44]
        super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f]
    for super_fun in super_fun_list:
        for input_ in input_list:

            def get_function_to_trace():
                # Factory so each iteration traces a fresh lambda binding fun/super_fun.
                return lambda x, y: super_fun(fun, x, y)

            function_to_trace = get_function_to_trace()
            params_names = signature(function_to_trace).parameters.keys()
            op_graph = trace_numpy_function(
                function_to_trace,
                {
                    param_name: EncryptedTensor(Integer(32, True), tensor_shape)
                    for param_name in params_names
                },
            )
            copied_graph = deepcopy(op_graph)
            orig_num_nodes = len(op_graph.graph)
            fuse_float_operations(op_graph)
            fused_num_nodes = len(op_graph.graph)
            fuse_float_operations(copied_graph)
            # Check determinism
            assert format_operation_graph(copied_graph) == format_operation_graph(op_graph)
            assert fused_num_nodes < orig_num_nodes
            # Check that the call to the function or to the op_graph evaluation give the same
            # result
            tensor_diversifier = (
                # The following +1 in the range is to avoid to have 0's which is not in the
                # domain definition of some of our functions
                numpy.arange(1, numpy.product(tensor_shape) + 1, dtype=numpy.int32).reshape(
                    tensor_shape
                )
                if tensor_shape != ()
                else 1
            )
            if fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]:
                # Domain of definition for these functions
                tensor_diversifier = (
                    numpy.ones(tensor_shape, dtype=numpy.int32) if tensor_shape != () else 1
                )
            input_ = numpy.int32(input_ * tensor_diversifier)
            num_params = len(params_names)
            assert num_params == 2
            # Create inputs which are either of the form [x, x] or [x, y]
            for j in range(4):
                if fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan] and j > 0:
                    # Domain of definition for these functions
                    break
                input_a = input_
                input_b = input_ + j
                if tensor_shape != ():
                    numpy.random.shuffle(input_a)
                    numpy.random.shuffle(input_b)
                inputs = (input_a, input_b) if random.randint(0, 1) == 0 else (input_b, input_a)
                function_result = function_to_trace(*inputs)
                op_graph_result = op_graph(*inputs)
                assert check_results_are_equal(function_result, op_graph_result)
# numpy ufuncs that reject float operands, so the binary-fusing subtests must feed
# them integer constants only.
LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
    numpy.bitwise_and,
    numpy.bitwise_or,
    numpy.bitwise_xor,
    numpy.gcd,
    numpy.lcm,
    numpy.ldexp,
    numpy.left_shift,
    numpy.logical_and,
    numpy.logical_not,
    numpy.logical_or,
    numpy.logical_xor,
    numpy.remainder,
    numpy.right_shift,
}
def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape):
    """Test a binary functions with fuse_float_operations, with a constant as a source."""
    # i selects where the fixed constant goes and whether it is an int or a float.
    for i in range(4):
        # Know if the function is defined for integer inputs
        if fun in LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES:
            if i not in [0, 2]:
                continue
        # The .astype(numpy.float64) that we have in cases 0 and 2 is here to force
        # a float output even for functions which return an integer (eg, XOR), such
        # that our frontend always try to fuse them
        # The .astype(numpy.float64) that we have in cases 1 and 3 is here to force
        # a float output even for functions which return a bool (eg, EQUAL), such
        # that our frontend always try to fuse them
        # For bivariate functions: fix one of the inputs
        if i == 0:
            # With an integer in first position
            ones_0 = numpy.ones(tensor_shape, dtype=numpy.int32) if tensor_shape != () else 1

            def get_function_to_trace():
                return lambda x, y: fun(3 * ones_0, x + y).astype(numpy.float64).astype(numpy.int32)

        elif i == 1:
            # With a float in first position
            ones_1 = numpy.ones(tensor_shape, dtype=numpy.float64) if tensor_shape != () else 1

            def get_function_to_trace():
                return (
                    lambda x, y: fun(2.3 * ones_1, x + y).astype(numpy.float64).astype(numpy.int32)
                )

        elif i == 2:
            # With an integer in second position
            ones_2 = numpy.ones(tensor_shape, dtype=numpy.int32) if tensor_shape != () else 1

            def get_function_to_trace():
                return lambda x, y: fun(x + y, 4 * ones_2).astype(numpy.float64).astype(numpy.int32)

        else:
            # With a float in second position
            ones_else = numpy.ones(tensor_shape, dtype=numpy.float64) if tensor_shape != () else 1

            def get_function_to_trace():
                return (
                    lambda x, y: fun(x + y, 5.7 * ones_else)
                    .astype(numpy.float64)
                    .astype(numpy.int32)
                )

        input_list = [0, 2, 42, 44]
        # Domain of definition
        if fun in [numpy.true_divide, numpy.remainder, numpy.floor_divide, numpy.fmod]:
            input_list = [2, 42, 44]
        for input_ in input_list:
            function_to_trace = get_function_to_trace()
            params_names = signature(function_to_trace).parameters.keys()
            op_graph = trace_numpy_function(
                function_to_trace,
                {
                    param_name: EncryptedTensor(Integer(32, True), tensor_shape)
                    for param_name in params_names
                },
            )
            copied_graph = deepcopy(op_graph)
            orig_num_nodes = len(op_graph.graph)
            fuse_float_operations(op_graph)
            fused_num_nodes = len(op_graph.graph)
            fuse_float_operations(copied_graph)
            # Check determinism
            assert format_operation_graph(copied_graph) == format_operation_graph(op_graph)
            assert fused_num_nodes < orig_num_nodes
            # Check that the call to the function or to the op_graph evaluation give the same
            # result
            tensor_diversifier = (
                # The following +1 in the range is to avoid to have 0's which is not in the
                # domain definition of some of our functions
                numpy.arange(1, numpy.product(tensor_shape) + 1, dtype=numpy.int32).reshape(
                    tensor_shape
                )
                if tensor_shape != ()
                else numpy.int64(1)
            )
            # Make sure the tensor diversifier is a numpy variable, otherwise some cases may fail
            # as python int and float don't have the astype method
            input_ = input_ * tensor_diversifier
            num_params = len(params_names)
            assert num_params == 2
            # Create inputs which are either of the form [x, x] or [x, y]
            for j in range(4):
                inputs = (input_, input_ + j)
                function_result = function_to_trace(*inputs)
                op_graph_result = op_graph(*inputs)
                assert check_results_are_equal(function_result, op_graph_result)
def subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_shape):
    """Test a binary function with fuse_float_operations, with no constant as
    a source."""

    def get_function_to_trace():
        # Both operands are variables: tracing itself is expected to refuse this.
        return lambda x, y: fun(x, y).astype(numpy.int32)

    function_to_trace = get_function_to_trace()
    params_names = signature(function_to_trace).parameters.keys()
    with pytest.raises(
        AssertionError,
        match=r"Can only have 1 non constant predecessor in _np_operator, got 2 for operator",
    ):
        trace_numpy_function(
            function_to_trace,
            {
                param_name: EncryptedTensor(Integer(32, True), tensor_shape)
                for param_name in params_names
            },
        )
# Dispatch every supported ufunc, as scalar and as tensor, to the matching subtests above.
@pytest.mark.parametrize("fun", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC)
@pytest.mark.parametrize(
    "tensor_shape", [pytest.param((), id="scalar"), pytest.param((3, 1, 2), id="tensor")]
)
def test_ufunc_operations(fun, tensor_shape):
    """Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
    if fun.nin == 1:
        subtest_fuse_float_unary_operations_correctness(fun, tensor_shape)
    elif fun.nin == 2:
        subtest_fuse_float_binary_operations_correctness(fun, tensor_shape)
        subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_shape)
        subtest_tensor_no_fuse(fun, tensor_shape)
    else:
        raise NotImplementedError("Only unary and binary functions are tested for now")

View File

@@ -1,433 +0,0 @@
"""Test file for intermediate representation"""
from copy import deepcopy
import numpy
import pytest
from concrete.common.data_types.floats import Float
from concrete.common.data_types.integers import Integer
from concrete.common.representation import intermediate as ir
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor
@pytest.mark.parametrize(
    "node,input_data,expected_result",
    [
        # Arithmetic nodes: input_data is the list of operand values, in order.
        pytest.param(
            ir.Add([EncryptedScalar(Integer(64, False)), EncryptedScalar(Integer(64, False))]),
            [10, 4589],
            4599,
            id="Add",
        ),
        pytest.param(
            ir.Sub([EncryptedScalar(Integer(64, False)), EncryptedScalar(Integer(64, False))]),
            [10, 4589],
            -4579,
            id="Sub",
        ),
        pytest.param(
            ir.Mul([EncryptedScalar(Integer(64, False)), EncryptedScalar(Integer(64, False))]),
            [10, 4589],
            45890,
            id="Mul",
        ),
        # Input passes its single value through; Constant ignores input_data entirely.
        pytest.param(ir.Input(ClearScalar(Integer(32, True)), "in", 0), [42], 42, id="Input"),
        pytest.param(ir.Constant(42), None, 42, id="Constant"),
        pytest.param(ir.Constant(-42), None, -42, id="Constant"),
        # GenericFunction TLU nodes: extra constant operands arrive via op_kwargs.
        pytest.param(
            ir.GenericFunction(
                [EncryptedScalar(Integer(7, False))],
                lambda x: x + 3,
                EncryptedScalar(Integer(7, False)),
                op_kind="TLU",
            ),
            [10],
            13,
            id="GenericFunction, x + 3",
        ),
        pytest.param(
            ir.GenericFunction(
                [EncryptedScalar(Integer(7, False))],
                lambda x, y: x + y,
                EncryptedScalar(Integer(7, False)),
                op_kind="TLU",
                op_kwargs={"y": 3},
            ),
            [10],
            13,
            id="GenericFunction, (x, y) -> x + y, where y is constant == 3",
        ),
        pytest.param(
            ir.GenericFunction(
                [EncryptedScalar(Integer(7, False))],
                lambda x, y: y[x],
                EncryptedScalar(Integer(7, False)),
                op_kind="TLU",
                op_kwargs={"y": (1, 2, 3, 4)},
            ),
            [2],
            3,
            id="GenericFunction, (x, y) -> y[x], where y is constant == (1, 2, 3, 4)",
        ),
        pytest.param(
            ir.GenericFunction(
                [EncryptedScalar(Integer(7, False))],
                lambda x, y: y[3],
                EncryptedScalar(Integer(7, False)),
                op_kind="TLU",
                op_kwargs={"y": (1, 2, 3, 4)},
            ),
            [2],
            4,
            id="GenericFunction, x, y -> y[3], where y is constant == (1, 2, 3, 4)",
        ),
        # Dot nodes: with and without a delegated numpy evaluation function.
        pytest.param(
            ir.Dot(
                [
                    EncryptedTensor(Integer(32, True), shape=(4,)),
                    ClearTensor(Integer(32, True), shape=(4,)),
                ],
                Integer(32, True),
            ),
            [[1, 2, 3, 4], [4, 3, 2, 1]],
            20,
            id="Dot, [1, 2, 3, 4], [4, 3, 2, 1]",
        ),
        pytest.param(
            ir.Dot(
                [
                    EncryptedTensor(Float(32), shape=(4,)),
                    ClearTensor(Float(32), shape=(4,)),
                ],
                Float(32),
            ),
            [[1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]],
            20,
            id="Dot, [1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]",
        ),
        pytest.param(
            ir.Dot(
                [
                    EncryptedTensor(Integer(32, True), shape=(4,)),
                    ClearTensor(Integer(32, True), shape=(4,)),
                ],
                Integer(32, True),
                delegate_evaluation_function=numpy.dot,
            ),
            [
                numpy.array([1, 2, 3, 4], dtype=numpy.int32),
                numpy.array([4, 3, 2, 1], dtype=numpy.int32),
            ],
            20,
            id="Dot, np.array([1, 2, 3, 4]), np.array([4, 3, 2, 1])",
        ),
        # IndexConstant nodes: plain index, forward slice, reverse slice, 2D mixed slices.
        pytest.param(
            ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (0,)),
            [
                numpy.array([1, 2, 3, 4], dtype=numpy.int32),
            ],
            1,
            id="IndexConstant, np.array([1, 2, 3, 4])[0]",
        ),
        pytest.param(
            ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (slice(1, 3, None),)),
            [
                numpy.array([1, 2, 3, 4], dtype=numpy.int32),
            ],
            numpy.array([2, 3]),
            id="IndexConstant, np.array([1, 2, 3, 4])[1:3]",
        ),
        pytest.param(
            ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (slice(3, 1, -1),)),
            [
                numpy.array([1, 2, 3, 4], dtype=numpy.int32),
            ],
            numpy.array([4, 3], dtype=numpy.int32),
            id="IndexConstant, np.array([1, 2, 3, 4])[3:1:-1]",
        ),
        pytest.param(
            ir.IndexConstant(
                EncryptedTensor(Integer(5, True), shape=(4, 4)), (slice(1, 3, 1), slice(2, 0, -1))
            ),
            [
                numpy.array(
                    [
                        [1, 2, 3, 4],
                        [5, 6, 7, 8],
                        [9, 10, 11, 12],
                        [13, 14, 15, 16],
                    ],
                    dtype=numpy.int32,
                ),
            ],
            numpy.array(
                [
                    [7, 6],
                    [11, 10],
                ],
                dtype=numpy.int32,
            ),
            id="IndexConstant, np.array([[1, 2, 3, 4]...[13, 14, 15, 16]])[1:3, 2:0:-1]",
        ),
        pytest.param(
            ir.MatMul(
                [
                    EncryptedTensor(Integer(32, True), shape=(3, 2)),
                    ClearTensor(Integer(32, True), shape=(2, 3)),
                ],
                Integer(32, True),
                (3, 3),
            ),
            [numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)],
            numpy.array([[9, 12, 15], [19, 26, 33], [29, 40, 51]]),
            id="MatMul, numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)",
        ),
        # Memory-kind GenericFunction nodes: pure layout ops (transpose / ravel / reshape).
        pytest.param(
            ir.GenericFunction(
                [EncryptedTensor(Integer(32, False), shape=(3, 5))],
                lambda x: numpy.transpose(x),
                EncryptedTensor(Integer(32, False), shape=(5, 3)),
                op_kind="Memory",
            ),
            [numpy.arange(15).reshape(3, 5)],
            numpy.array([[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8, 13], [4, 9, 14]]),
            id="GenericFunction, x transpose",
        ),
        pytest.param(
            ir.GenericFunction(
                [EncryptedTensor(Integer(32, False), shape=(3, 5))],
                lambda x: numpy.ravel(x),
                EncryptedTensor(Integer(32, False), shape=(5, 3)),
                op_kind="Memory",
            ),
            [numpy.arange(15).reshape(3, 5)],
            numpy.arange(15),
            id="GenericFunction, x ravel",
        ),
        pytest.param(
            ir.GenericFunction(
                [EncryptedTensor(Integer(32, False), shape=(3, 5))],
                lambda x: numpy.reshape(x, (5, 3)),
                output_value=EncryptedTensor(Integer(32, False), shape=(5, 3)),
                op_kind="Memory",
            ),
            [numpy.arange(15).reshape(3, 5)],
            numpy.arange(15).reshape(5, 3),
            id="GenericFunction, x reshape",
        ),
    ],
)
def test_evaluate(
    node: ir.IntermediateNode,
    input_data,
    expected_result: int,
    check_array_equality,
):
    """Test evaluate methods on IntermediateNodes.

    Scalar expectations use plain ==; ndarray expectations go through the
    check_array_equality fixture for element-wise comparison.
    """
    if isinstance(expected_result, numpy.ndarray):
        check_array_equality(node.evaluate(input_data), expected_result)
    else:
        assert node.evaluate(input_data) == expected_result
@pytest.mark.parametrize(
    "node1,node2,expected_result",
    [
        # Commutative ops (Add, Mul): operand order and swapped widths still match.
        (
            ir.Add([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
            ir.Add([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
            True,
        ),
        (
            ir.Add([EncryptedScalar(Integer(16, False)), EncryptedScalar(Integer(32, False))]),
            ir.Add([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(16, False))]),
            True,
        ),
        (
            ir.Add([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
            ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
            False,
        ),
        # Non-commutative Sub: swapped operand widths must NOT match.
        (
            ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
            ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
            True,
        ),
        (
            ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(16, False))]),
            ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(16, False))]),
            True,
        ),
        (
            ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(16, False))]),
            ir.Sub([EncryptedScalar(Integer(16, False)), EncryptedScalar(Integer(32, False))]),
            False,
        ),
        (
            ir.Mul([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
            ir.Mul([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
            True,
        ),
        (
            ir.Mul([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
            ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
            False,
        ),
        # Inputs differ by name, program index, or dtype.
        (
            ir.Input(EncryptedScalar(Integer(32, False)), "x", 0),
            ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
            False,
        ),
        (
            ir.Input(EncryptedScalar(Integer(32, False)), "x", 0),
            ir.Input(EncryptedScalar(Integer(32, False)), "x", 0),
            True,
        ),
        (
            ir.Input(EncryptedScalar(Integer(32, False)), "x", 0),
            ir.Input(EncryptedScalar(Integer(32, False)), "y", 0),
            False,
        ),
        (
            ir.Input(EncryptedScalar(Integer(32, False)), "x", 0),
            ir.Input(EncryptedScalar(Integer(32, False)), "x", 1),
            False,
        ),
        (
            ir.Input(EncryptedScalar(Integer(32, False)), "x", 0),
            ir.Input(EncryptedScalar(Integer(8, False)), "x", 0),
            False,
        ),
        # Constants: 10 (int) vs 10.0 (float) are deliberately not equivalent.
        (
            ir.Constant(10),
            ir.Constant(10),
            True,
        ),
        (
            ir.Constant(10),
            ir.Input(EncryptedScalar(Integer(8, False)), "x", 0),
            False,
        ),
        (
            ir.Constant(10),
            ir.Constant(10.0),
            False,
        ),
        # GenericFunction: op_args / op_kwargs differences break equivalence.
        (
            ir.GenericFunction(
                [EncryptedScalar(Integer(8, False))],
                lambda x: x,
                EncryptedScalar(Integer(8, False)),
                op_kind="TLU",
            ),
            ir.GenericFunction(
                [EncryptedScalar(Integer(8, False))],
                lambda x: x,
                EncryptedScalar(Integer(8, False)),
                op_kind="TLU",
            ),
            True,
        ),
        (
            ir.GenericFunction(
                [EncryptedScalar(Integer(8, False))],
                lambda x: x,
                EncryptedScalar(Integer(8, False)),
                op_kind="TLU",
                op_args=(1, 2, 3),
            ),
            ir.GenericFunction(
                [EncryptedScalar(Integer(8, False))],
                lambda x: x,
                EncryptedScalar(Integer(8, False)),
                op_kind="TLU",
            ),
            False,
        ),
        (
            ir.GenericFunction(
                [EncryptedScalar(Integer(8, False))],
                lambda x: x,
                EncryptedScalar(Integer(8, False)),
                op_kind="TLU",
                op_kwargs={"tuple": (1, 2, 3)},
            ),
            ir.GenericFunction(
                [EncryptedScalar(Integer(8, False))],
                lambda x: x,
                EncryptedScalar(Integer(8, False)),
                op_kind="TLU",
            ),
            False,
        ),
        # Dot: presence/absence of the delegate evaluation function matters.
        (
            ir.Dot(
                [
                    EncryptedTensor(Integer(32, True), shape=(4,)),
                    ClearTensor(Integer(32, True), shape=(4,)),
                ],
                Integer(32, True),
                delegate_evaluation_function=numpy.dot,
            ),
            ir.Dot(
                [
                    EncryptedTensor(Integer(32, True), shape=(4,)),
                    ClearTensor(Integer(32, True), shape=(4,)),
                ],
                Integer(32, True),
                delegate_evaluation_function=numpy.dot,
            ),
            True,
        ),
        (
            ir.Dot(
                [
                    EncryptedTensor(Integer(32, True), shape=(4,)),
                    ClearTensor(Integer(32, True), shape=(4,)),
                ],
                Integer(32, True),
                delegate_evaluation_function=numpy.dot,
            ),
            ir.Dot(
                [
                    EncryptedTensor(Integer(32, True), shape=(4,)),
                    ClearTensor(Integer(32, True), shape=(4,)),
                ],
                Integer(32, True),
            ),
            False,
        ),
    ],
)
def test_is_equivalent_to(
    node1: ir.IntermediateNode,
    node2: ir.IntermediateNode,
    expected_result: bool,
    test_helpers,
):
    """Test is_equivalent_to methods on IntermediateNodes"""
    # Equivalence must be symmetric, so the check is done both ways.
    assert (
        test_helpers.nodes_are_equivalent(node1, node2)
        == test_helpers.nodes_are_equivalent(node2, node1)
        == expected_result
    )
@pytest.mark.parametrize(
    "list_to_fill,expected_list",
    [
        pytest.param([None, 1, 2, 3, None, None], [1, 1, 2, 3, 3, 3]),
        pytest.param([None], None, marks=pytest.mark.xfail(strict=True)),
        pytest.param([None, None, None, None, 7, None, None, None], [7, 7, 7, 7, 7, 7, 7, 7]),
        pytest.param([None, None, 3, None, None, None, 2, None], [3, 3, 3, 3, 3, 2, 2, 2]),
    ],
)
def test_flood_replace_none_values(list_to_fill: list, expected_list: list):
    """Unit test for flood_replace_none_values"""
    # Operate on a copy so the parametrized input itself is never mutated.
    filled = deepcopy(list_to_fill)
    ir.flood_replace_none_values(filled)
    # Every hole must be filled, and with the nearest known neighbour values.
    assert all(value is not None for value in filled)
    assert filled == expected_list

View File

@@ -1,74 +0,0 @@
"""Test file for common helpers"""
from copy import deepcopy
import pytest
from concrete.common import check_op_graph_is_integer_program, is_a_power_of_2
from concrete.common.data_types.floats import Float64
from concrete.common.data_types.integers import Integer
from concrete.common.values import EncryptedScalar
from concrete.numpy.tracing import trace_numpy_function
@pytest.mark.parametrize(
    "x,result",
    [
        (0, False),
        (1, True),
        (2, True),
        (3, False),
        (4, True),
        (10, False),
        (16, True),
    ],
)
def test_is_a_power_of_2(x, result):
    """Check is_a_power_of_2 against known powers of two and non-powers."""
    observed = is_a_power_of_2(x)
    assert observed == result
def test_check_op_graph_is_integer_program():
    """Test function for check_op_graph_is_integer_program.

    Traces an all-integer graph, then mutates dtypes to Float64 to verify the
    check fails and reports exactly the offending nodes.
    """
    def function(x, y):
        return x + y - y * y + x * y
    op_graph = trace_numpy_function(
        function, {"x": EncryptedScalar(Integer(64, True)), "y": EncryptedScalar(Integer(64, True))}
    )
    # Test without and with output list
    offending_nodes = []
    assert check_op_graph_is_integer_program(op_graph)
    assert check_op_graph_is_integer_program(op_graph, offending_nodes)
    assert len(offending_nodes) == 0
    # Force a float output dtype: the check must fail and report that output node.
    op_graph_copy = deepcopy(op_graph)
    op_graph_copy.output_nodes[0].outputs[0].dtype = Float64
    offending_nodes = []
    assert not check_op_graph_is_integer_program(op_graph_copy)
    assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes)
    assert len(offending_nodes) == 1
    assert offending_nodes == [op_graph_copy.output_nodes[0]]
    # Same with a float input dtype.
    op_graph_copy = deepcopy(op_graph)
    op_graph_copy.input_nodes[0].inputs[0].dtype = Float64
    offending_nodes = []
    assert not check_op_graph_is_integer_program(op_graph_copy)
    assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes)
    assert len(offending_nodes) == 1
    assert offending_nodes == [op_graph_copy.input_nodes[0]]
    # Two float inputs: both nodes must be reported (order not guaranteed, hence set).
    op_graph_copy = deepcopy(op_graph)
    op_graph_copy.input_nodes[0].inputs[0].dtype = Float64
    op_graph_copy.input_nodes[1].inputs[0].dtype = Float64
    offending_nodes = []
    assert not check_op_graph_is_integer_program(op_graph_copy)
    assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes)
    assert len(offending_nodes) == 2
    assert set(offending_nodes) == set([op_graph_copy.input_nodes[0], op_graph_copy.input_nodes[1]])

View File

@@ -1,54 +0,0 @@
"""Test module for Circuit class"""
import filecmp
import concrete.numpy as hnp
from concrete.common.debugging import draw_graph, format_operation_graph
def test_circuit_str(default_compilation_configuration):
    """Test function for `__str__` method of `Circuit`"""

    def plus_forty_two(x):
        return x + 42

    circuit = hnp.compile_numpy_function(
        plus_forty_two,
        {"x": hnp.EncryptedScalar(hnp.UnsignedInteger(3))},
        range(2 ** 3),
        default_compilation_configuration,
    )
    # str(circuit) must be the canonical textual rendering of its op graph.
    assert str(circuit) == format_operation_graph(circuit.op_graph)
def test_circuit_draw(default_compilation_configuration):
    """Test function for `draw` method of `Circuit`"""

    def plus_forty_two(x):
        return x + 42

    circuit = hnp.compile_numpy_function(
        plus_forty_two,
        {"x": hnp.EncryptedScalar(hnp.UnsignedInteger(3))},
        range(2 ** 3),
        default_compilation_configuration,
    )
    # Both orientations must produce the same image file as drawing the op graph directly.
    assert filecmp.cmp(circuit.draw(), draw_graph(circuit.op_graph))
    assert filecmp.cmp(circuit.draw(vertical=False), draw_graph(circuit.op_graph, vertical=False))
def test_circuit_run(default_compilation_configuration):
    """Test equivalence of encrypt/run/decrypt and encrypt_run_decrypt"""

    def plus_forty_two(x):
        return x + 42

    circuit = hnp.compile_numpy_function(
        plus_forty_two,
        {"x": hnp.EncryptedScalar(hnp.UnsignedInteger(3))},
        range(2 ** 3),
        default_compilation_configuration,
    )
    circuit.keygen()
    # The one-shot helper must agree with the explicit three-step pipeline.
    for sample in range(2 ** 3):
        decrypted = circuit.decrypt(circuit.run(circuit.encrypt(sample)))
        assert circuit.encrypt_run_decrypt(sample) == decrypted

View File

@@ -1,44 +0,0 @@
"""Test file for common tracing helpers"""
from typing import Any, Dict, List
import pytest
from concrete.common.tracing.tracing_helpers import prepare_function_parameters
@pytest.mark.parametrize(
    "function,function_parameters,ref_dict",
    [
        pytest.param(lambda x: None, {}, {}, id="Missing x", marks=pytest.mark.xfail(strict=True)),
        pytest.param(lambda x: None, {"x": None}, {"x": None}, id="Only x"),
        pytest.param(
            lambda x: None, {"x": None, "y": None}, {"x": None}, id="Additional y filtered"
        ),
    ],
)
def test_prepare_function_parameters(
    function, function_parameters: Dict[str, Any], ref_dict: Dict[str, Any]
):
    """Test prepare_function_parameters"""
    # Parameters not in the function signature are dropped; missing ones fail (xfail case).
    assert prepare_function_parameters(function, function_parameters) == ref_dict
@pytest.mark.parametrize(
    "function,function_parameters,expected_ordered_keys",
    [
        (lambda x: None, {"x": None}, ["x"]),
        (lambda x, y: None, {"x": None, "y": None}, ["x", "y"]),
        (lambda x, y: None, {"y": None, "x": None}, ["x", "y"]),
        (lambda z, x, y: None, {"y": None, "z": None, "x": None}, ["z", "x", "y"]),
    ],
)
def test_prepare_function_parameters_order(
    function, function_parameters: Dict[str, Any], expected_ordered_keys: List[str]
):
    """Test prepare_function_parameters output order"""
    # Keys must come back in the function's signature order, not the input dict order.
    ordered_keys = list(prepare_function_parameters(function, function_parameters))
    assert ordered_keys == expected_ordered_keys

View File

@@ -1,473 +0,0 @@
"""PyTest configuration file"""
import json
import operator
import re
import shutil
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Optional, Type
import networkx as nx
import networkx.algorithms.isomorphism as iso
import numpy
import pytest
from concrete.common.compilation import CompilationConfiguration
from concrete.common.fhe_circuit import FHECircuit
from concrete.common.mlir.utils import (
ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB,
get_op_graph_max_bit_width_and_nodes_over_bit_width_limit,
)
from concrete.common.representation.intermediate import (
ALL_IR_NODES,
Add,
Constant,
Conv2D,
Dot,
GenericFunction,
IndexConstant,
Input,
IntermediateNode,
MatMul,
Mul,
Sub,
)
from concrete.numpy import compile as compile_
def pytest_addoption(parser):
    """Register concrete-numpy specific command line options."""
    # Both options are optional string values with no default.
    shared_kwargs = dict(action="store", default=None, type=str)
    parser.addoption(
        "--global-coverage-infos-json",
        help="To dump pytest-cov term report to a text file.",
        **shared_kwargs,
    )
    parser.addoption(
        "--keyring-dir",
        help="Specify the dir to use to store key cache",
        **shared_kwargs,
    )
DEFAULT_KEYRING_PATH = Path.home().resolve() / ".cache/concrete-numpy_pytest"
def get_keyring_dir_from_session_or_default(
    session: Optional[pytest.Session] = None,
) -> Optional[Path]:
    """Get keyring dir from test session.

    Returns None when caching is explicitly disabled via `--keyring-dir disable`.
    """
    # Without a session there is no CLI option to consult: use the shared default.
    if session is None:
        return DEFAULT_KEYRING_PATH
    option_value = session.config.getoption("--keyring-dir", default=None)
    if option_value is None:
        return DEFAULT_KEYRING_PATH
    # The literal "disable" (case-insensitive) turns the key cache off entirely.
    if option_value.lower() == "disable":
        return None
    return Path(option_value).expanduser().resolve()
@pytest.fixture
def default_keyring_path():
    """Fixture to get test keyring dir."""
    # Exposes the module-level default so tests can inspect/clean the cache location.
    return DEFAULT_KEYRING_PATH
# This is only for doctests where we currently cannot make use of fixtures
# Keep a reference to the untouched __init__ so the monkeypatched version can delegate to it.
original_compilation_config_init = CompilationConfiguration.__init__
def monkeypatched_compilation_configuration_init_for_codeblocks(
    self: CompilationConfiguration, *args, **kwargs
):
    """Monkeypatched compilation configuration init for codeblocks tests.

    Delegates to the original __init__, then forces the test-only settings so
    documentation code blocks run with the same configuration as regular tests.
    """
    original_compilation_config_init(self, *args, **kwargs)
    self.dump_artifacts_on_unexpected_failures = False
    self.enable_unsafe_features = True  # This is for our tests only, never use that in prod
    self.treat_warnings_as_errors = True
    self.use_insecure_key_cache = True  # This is for our tests only, never use that in prod
def pytest_sessionstart(session: pytest.Session):
    """Handle keyring for session and codeblocks CompilationConfiguration if needed."""
    # When running doc code blocks, patch CompilationConfiguration globally since
    # fixtures are not available there.
    if session.config.getoption("--codeblocks", default=False):
        # setattr to avoid mypy complaining
        # Disable the flake8 bug bear warning for the mypy fix
        setattr(  # noqa: B010
            CompilationConfiguration,
            "__init__",
            monkeypatched_compilation_configuration_init_for_codeblocks,
        )
    # None means key caching was explicitly disabled via --keyring-dir disable.
    keyring_dir = get_keyring_dir_from_session_or_default(session)
    if keyring_dir is None:
        return
    keyring_dir.mkdir(parents=True, exist_ok=True)
    keyring_dir_as_str = str(keyring_dir)
    print(f"Using {keyring_dir_as_str} as key cache dir")
    # Point the compiler's insecure key cache at the chosen directory (test-only).
    compile_._COMPILE_FHE_INSECURE_KEY_CACHE_DIR = (  # pylint: disable=protected-access
        keyring_dir_as_str
    )
def pytest_sessionfinish(session: pytest.Session, exitstatus):  # pylint: disable=unused-argument
    """Pytest callback when testing ends.

    Dumps the pytest-cov terminal report to a JSON file (if requested) and
    cleans up incomplete keys from the key cache directory.
    """
    # Hacked together from the source code, they don't have an option to export to file and it's too
    # much work to get a PR in for such a little thing
    # https://github.com/pytest-dev/pytest-cov/blob/
    # ec344d8adf2d78238d8f07cb20ed2463d7536970/src/pytest_cov/plugin.py#L329
    if session.config.pluginmanager.hasplugin("_cov"):
        global_coverage_file = session.config.getoption(
            "--global-coverage-infos-json", default=None
        )
        if global_coverage_file is not None:
            cov_plugin = session.config.pluginmanager.getplugin("_cov")
            coverage_txt = cov_plugin.cov_report.getvalue()
            coverage_status = 0
            if (
                cov_plugin.options.cov_fail_under is not None
                and cov_plugin.options.cov_fail_under > 0
            ):
                failed = cov_plugin.cov_total < cov_plugin.options.cov_fail_under
                # If failed is False coverage_status is 0, if True it's 1
                coverage_status = int(failed)
            global_coverage_file_path = Path(global_coverage_file).resolve()
            with open(global_coverage_file_path, "w", encoding="utf-8") as f:
                json.dump({"exit_code": coverage_status, "content": coverage_txt}, f)
    keyring_dir = get_keyring_dir_from_session_or_default(session)
    if keyring_dir is not None:
        # Remove incomplete keys
        for incomplete_keys in keyring_dir.glob("**/*incomplete*"):
            shutil.rmtree(incomplete_keys, ignore_errors=True)
def _is_equivalent_to_binary_commutative(lhs: IntermediateNode, rhs: object) -> bool:
    """is_equivalent_to for a binary and commutative operation.

    Operands may appear in either order and the nodes are still equivalent.
    """
    if not isinstance(rhs, lhs.__class__):
        return False
    same_inputs_up_to_order = lhs.inputs in (rhs.inputs, rhs.inputs[::-1])
    return same_inputs_up_to_order and lhs.outputs == rhs.outputs
def _is_equivalent_to_binary_non_commutative(lhs: IntermediateNode, rhs: object) -> bool:
    """is_equivalent_to for a binary and non-commutative operation.

    Operand order matters here, so inputs must match positionally.
    """
    if not isinstance(rhs, lhs.__class__):
        return False
    return lhs.inputs == rhs.inputs and lhs.outputs == rhs.outputs
def is_equivalent_add(lhs: Add, rhs: object) -> bool:
    """Helper function to check if an Add node is equivalent to an other object."""
    # Addition is commutative, so operand order is ignored.
    return _is_equivalent_to_binary_commutative(lhs, rhs)
# From https://stackoverflow.com/a/28635464
# Pulls (co_code, co_consts) out of a code object in one call.
_code_and_constants_attr_getter = operator.attrgetter("co_code", "co_consts")


def _code_and_constants(object_):
    """Return the (bytecode, constants) pair of a function's code object."""
    code_object = object_.__code__
    return _code_and_constants_attr_getter(code_object)
def python_functions_are_equal_or_equivalent(lhs: object, rhs: object) -> bool:
    """Helper function to check if two functions are equal or their code are equivalent.
    This is not perfect, but will be good enough for tests.
    """
    if lhs == rhs:
        return True
    # Fall back to comparing bytecode and constants; objects without __code__
    # (non-functions) raise AttributeError and are reported as not equivalent.
    try:
        return _code_and_constants(lhs) == _code_and_constants(rhs)
    except AttributeError:
        return False
def is_equivalent_arbitrary_function(lhs: GenericFunction, rhs: object) -> bool:
    """Helper function to check if an GenericFunction node is equivalent to an other object."""
    if not isinstance(rhs, GenericFunction):
        return False
    # The wrapped callables may be distinct objects with equivalent bytecode.
    if not python_functions_are_equal_or_equivalent(lhs.arbitrary_func, rhs.arbitrary_func):
        return False
    # All op metadata has to match exactly.
    for attribute in ("op_kind", "op_args", "op_kwargs", "op_attributes", "op_name"):
        if getattr(lhs, attribute) != getattr(rhs, attribute):
            return False
    return is_equivalent_intermediate_node(lhs, rhs)
def is_equivalent_constant(lhs: Constant, rhs: object) -> bool:
    """Helper function to check if a Constant node is equivalent to an other object."""
    if not isinstance(rhs, Constant):
        return False
    return lhs.constant_data == rhs.constant_data and is_equivalent_intermediate_node(lhs, rhs)
def is_equivalent_dot(lhs: Dot, rhs: object) -> bool:
    """Helper function to check if a Dot node is equivalent to an other object."""
    if not isinstance(rhs, Dot):
        return False
    # The delegate evaluation function is part of the node's identity.
    if lhs.evaluation_function != rhs.evaluation_function:
        return False
    return is_equivalent_intermediate_node(lhs, rhs)
def is_equivalent_input(lhs: Input, rhs: object) -> bool:
    """Helper function to check if an Input node is equivalent to an other object."""
    if not isinstance(rhs, Input):
        return False
    # Inputs are identified by both their name and their position in the program.
    if lhs.input_name != rhs.input_name or lhs.program_input_idx != rhs.program_input_idx:
        return False
    return is_equivalent_intermediate_node(lhs, rhs)
def is_equivalent_index_constant(lhs: IndexConstant, rhs: object) -> bool:
    """Helper function to check if an IndexConstant node is equivalent to an other object."""
    if not isinstance(rhs, IndexConstant):
        return False
    return lhs.index == rhs.index and is_equivalent_intermediate_node(lhs, rhs)
def is_equivalent_mul(lhs: Mul, rhs: object) -> bool:
    """Helper function to check if a Mul node is equivalent to an other object."""
    # Multiplication is commutative, so operand order is ignored.
    return _is_equivalent_to_binary_commutative(lhs, rhs)
def is_equivalent_sub(lhs: Sub, rhs: object) -> bool:
    """Helper function to check if a Sub node is equivalent to an other object."""
    # Subtraction is not commutative: operands must match positionally.
    return _is_equivalent_to_binary_non_commutative(lhs, rhs)
def is_equivalent_matmul(lhs: MatMul, rhs: object) -> bool:
    """Helper function to check if a MatMul node is equivalent to an other object."""
    # No MatMul-specific attributes: generic node comparison is enough.
    return isinstance(rhs, MatMul) and is_equivalent_intermediate_node(lhs, rhs)
def is_equivalent_conv2d(lhs: Conv2D, rhs: object) -> bool:
    """Helper function to check if a Conv2D node is equivalent to an other object."""
    # No Conv2D-specific attributes compared here: generic node comparison is enough.
    return isinstance(rhs, Conv2D) and is_equivalent_intermediate_node(lhs, rhs)
def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool:
    """Helper function to check if an IntermediateNode node is equivalent to an other object."""
    if not isinstance(rhs, IntermediateNode):
        return False
    # Base-level equivalence: identical input and output value descriptions.
    return lhs.inputs == rhs.inputs and lhs.outputs == rhs.outputs
# Dispatch table mapping each IR node type to its dedicated equivalence checker.
EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
    Add: is_equivalent_add,
    GenericFunction: is_equivalent_arbitrary_function,
    Constant: is_equivalent_constant,
    Conv2D: is_equivalent_conv2d,
    Dot: is_equivalent_dot,
    IndexConstant: is_equivalent_index_constant,
    Input: is_equivalent_input,
    Mul: is_equivalent_mul,
    Sub: is_equivalent_sub,
    MatMul: is_equivalent_matmul,
}
# Fail fast at import time if a new IR node type has no checker registered.
_missing_nodes_in_mapping = ALL_IR_NODES - EQUIVALENT_TEST_FUNC.keys()
assert len(_missing_nodes_in_mapping) == 0, (
    f"Missing IR node in EQUIVALENT_TEST_FUNC : "
    f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}"
)
del _missing_nodes_in_mapping
class TestHelpers:
    """Class allowing to pass helper functions to tests"""
    @staticmethod
    def nodes_are_equivalent(lhs, rhs) -> bool:
        """Helper function for tests to check if two nodes are equivalent."""
        # Use the type-specific checker when one is registered, otherwise fall
        # back to the node's own is_equivalent_to.
        equivalent_func = EQUIVALENT_TEST_FUNC.get(type(lhs), None)
        if equivalent_func is not None:
            return equivalent_func(lhs, rhs)
        # This is a default for the test_conftest.py that should remain separate from the package
        # nodes is_equivalent_* functions
        return lhs.is_equivalent_to(rhs)
    @staticmethod
    def digraphs_are_equivalent(reference: nx.MultiDiGraph, to_compare: nx.MultiDiGraph):
        """Check that two digraphs are equivalent without modifications.

        NOTE(review): this mutates both graphs by attaching a `_test_content`
        attribute to every node — acceptable for tests, not a pure function.
        """
        # edge_match is a copy of node_match
        edge_matcher = iso.categorical_multiedge_match(["input_idx", "output_idx"], [None, None])
        node_matcher = iso.generic_node_match(
            "_test_content", None, TestHelpers.nodes_are_equivalent
        )
        # Set the _test_content for each node in the graphs
        for node in reference.nodes():
            reference.add_node(node, _test_content=node)
        for node in to_compare.nodes():
            to_compare.add_node(node, _test_content=node)
        graphs_are_isomorphic = nx.is_isomorphic(
            reference,
            to_compare,
            node_match=node_matcher,
            edge_match=edge_matcher,
        )
        return graphs_are_isomorphic
    @staticmethod
    def python_functions_are_equal_or_equivalent(lhs, rhs):
        """Helper function to check if two functions are equal or their code are equivalent.
        This is not perfect, but will be good enough for tests.
        """
        # Delegates to the module-level function of the same name.
        return python_functions_are_equal_or_equivalent(lhs, rhs)
@pytest.fixture
def test_helpers():
    """Fixture to return the static helper class"""
    # The class itself is returned; all helpers are static methods.
    return TestHelpers
@pytest.fixture
def default_compilation_configuration():
    """Return the default test compilation configuration.

    A fresh configuration is built per test; unsafe features and the insecure
    key cache are only acceptable in this test context.
    """
    return CompilationConfiguration(
        dump_artifacts_on_unexpected_failures=False,
        enable_unsafe_features=True,  # This is for our tests only, never use that in prod
        treat_warnings_as_errors=True,
        use_insecure_key_cache=True,  # This is for our tests only, never use that in prod
    )
# Matches ANSI escape color sequences so they can be stripped from captured output.
REMOVE_COLOR_CODES_RE = re.compile(r"\x1b[^m]*m")
@pytest.fixture
def remove_color_codes():
    """Return the re object to remove color codes"""
    # Returns a callable str -> str that strips all color escape sequences.
    return lambda x: REMOVE_COLOR_CODES_RE.sub("", x)
def check_is_good_execution_impl(
    fhe_circuit: FHECircuit,
    function: Callable,
    args: Iterable[Any],
    preprocess_input_func: Callable[[Any], Any] = lambda x: x,
    postprocess_output_func: Callable[[Any], Any] = lambda x: x,
    check_function: Callable[[Any, Any], bool] = numpy.equal,
    verbose: bool = True,
):
    """Run several times the check compiler_engine.run(*args) == function(*args). If always wrong,
    return an error. One can set the expected probability of success of one execution and the
    number of tests, to finetune the probability of bad luck, ie that we run several times the
    check and always have a wrong result."""
    max_bit_width, _ = get_op_graph_max_bit_width_and_nodes_over_bit_width_limit(
        fhe_circuit.op_graph
    )
    # Allow tests to pass if cells of the output result are good at least once over the nb_tries
    # Enabled only when we have a circuit that's using the maximum possible bit width
    # >= if there are 8 bits signed integers
    allow_relaxed_tests_passing = max_bit_width >= ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB
    # FIXME: https://github.com/zama-ai/concrete-numpy-internal/issues/1255
    # Increased with compiler accuracy which dropped, make sure to remove once accuracy improves
    nb_tries = 10
    # Prepare the bool array to record if cells were properly computed
    # NOTE(review): preprocess_input_func is re-applied each try — presumably it can be
    # randomized, hence the per-iteration recomputation below.
    preprocessed_args = tuple(preprocess_input_func(val) for val in args)
    cells_were_properly_computed = numpy.zeros_like(function(*preprocessed_args), dtype=bool)
    for i in range(1, nb_tries + 1):
        preprocessed_args = tuple(preprocess_input_func(val) for val in args)
        last_engine_result = postprocess_output_func(
            fhe_circuit.encrypt_run_decrypt(*preprocessed_args)
        )
        last_function_result = postprocess_output_func(function(*preprocessed_args))
        ok_execution = check_function(last_engine_result, last_function_result)
        if isinstance(ok_execution, numpy.ndarray):
            # Record the cells that were well computed
            cells_were_properly_computed = numpy.logical_or(
                cells_were_properly_computed, ok_execution
            )
            # Get a boolean for the execution
            ok_execution = ok_execution.all()
        if ok_execution:
            # Good computation after i tries
            if verbose:
                print(f"Good computation after {i} tries")
            return
    # FIXME: https://github.com/zama-ai/concrete-numpy-internal/issues/1264
    # Remove the relaxed tests once accuracy is good again for 7 bits
    # Here `i` holds its final loop value (nb_tries).
    if allow_relaxed_tests_passing and cells_were_properly_computed.all():
        print(
            "Computation was never good for all output cells at the same time, "
            f"however each was evaluated properly at least once, stopped after {i} tries"
        )
        return
    raise AssertionError(
        f"bad computation after {nb_tries} tries.\nLast engine result:\n{last_engine_result}\n"
        f"Last function result:\n{last_function_result}"
    )
@pytest.fixture
def check_is_good_execution():
    """Fixture returning the probabilistic execution checker (check_is_good_execution_impl)."""
    return check_is_good_execution_impl
def check_array_equality_impl(actual: Any, expected: Any, verbose: bool = True):
    """Assert that `actual` is equal to `expected`.

    When verbose, the assertion message shows both values side by side.
    """
    assert numpy.array_equal(actual, expected), (
        ""
        if not verbose
        else f"""
        Expected Output
        ===============
        {expected}
        Actual Output
        =============
        {actual}
        """
    )
@pytest.fixture
def check_array_equality():
    """Fixture to check array equality"""
    # Returns the implementation function so tests can call it directly.
    return check_array_equality_impl

View File

@@ -1,57 +0,0 @@
"""Test file for conftest helper functions"""
import networkx as nx
def test_digraphs_are_equivalent(test_helpers):
    """Function to test digraphs_are_equivalent helper function"""
    class TestNode:
        """Dummy test node"""
        # Node identity is entirely defined by this string.
        computation: str
        def __init__(self, computation: str) -> None:
            self.computation = computation
        def __hash__(self) -> int:
            return self.computation.__hash__()
        def __eq__(self, other: object) -> bool:
            return isinstance(other, self.__class__) and self.computation == other.computation
        is_equivalent_to = __eq__
    g_1 = nx.MultiDiGraph()
    g_2 = nx.MultiDiGraph()
    t_0 = TestNode("Add")
    t_1 = TestNode("Mul")
    t_2 = TestNode("TLU")
    g_1.add_edge(t_0, t_2, input_idx=0, output_idx=0)
    g_1.add_edge(t_1, t_2, input_idx=1, output_idx=0)
    # Same graph built from distinct-but-equal nodes, in a different insertion order.
    t0p = TestNode("Add")
    t1p = TestNode("Mul")
    t2p = TestNode("TLU")
    g_2.add_edge(t1p, t2p, input_idx=1, output_idx=0)
    g_2.add_edge(t0p, t2p, input_idx=0, output_idx=0)
    # Different node content -> must not be equivalent.
    bad_g2 = nx.MultiDiGraph()
    bad_t0 = TestNode("Not Add")
    bad_g2.add_edge(bad_t0, t_2, input_idx=0, output_idx=0)
    bad_g2.add_edge(t_1, t_2, input_idx=1, output_idx=0)
    # Same nodes but swapped edge attributes -> must not be equivalent.
    bad_g3 = nx.MultiDiGraph()
    bad_g3.add_edge(t_0, t_2, input_idx=1, output_idx=0)
    bad_g3.add_edge(t_1, t_2, input_idx=0, output_idx=0)
    assert test_helpers.digraphs_are_equivalent(g_1, g_2), "Graphs should be equivalent"
    assert not test_helpers.digraphs_are_equivalent(g_1, bad_g2), "Graphs should not be equivalent"
    assert not test_helpers.digraphs_are_equivalent(g_2, bad_g2), "Graphs should not be equivalent"
    assert not test_helpers.digraphs_are_equivalent(g_1, bad_g3), "Graphs should not be equivalent"
    assert not test_helpers.digraphs_are_equivalent(g_2, bad_g3), "Graphs should not be equivalent"

File diff suppressed because it is too large Load Diff

View File

@@ -1,727 +0,0 @@
"""Test module for constant indexing."""
import numpy as np
import pytest
from concrete.common.data_types import UnsignedInteger
from concrete.common.values import EncryptedScalar, EncryptedTensor
from concrete.numpy import (
compile_numpy_function,
compile_numpy_function_into_op_graph_and_measure_bounds,
)
# Each case traces a function that indexes an encrypted tensor with constant
# indices/slices and checks the inferred output value description.
@pytest.mark.parametrize(
    "input_value,function_with_indexing,output_value",
    [
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-3],
            EncryptedScalar(UnsignedInteger(1)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-2],
            EncryptedScalar(UnsignedInteger(1)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-1],
            EncryptedScalar(UnsignedInteger(1)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[0],
            EncryptedScalar(UnsignedInteger(1)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[1],
            EncryptedScalar(UnsignedInteger(1)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[2],
            EncryptedScalar(UnsignedInteger(1)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[:],
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-3:],
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-2:],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-1:],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[0:],
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[1:],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[2:],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[:-1],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[:-2],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[:1],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[:2],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[:3],
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-3:-2],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-3:-1],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-3:1],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-3:2],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-3:3],
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-2:-1],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-2:2],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-2:3],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-1:3],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[0:-2],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[0:-1],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[0:1],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[0:2],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[0:3],
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[1:-1],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[1:2],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[1:3],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[2:3],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[::-1],
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-3::-1],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-2::-1],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-1::-1],
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[0::-1],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[1::-1],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[2::-1],
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[:-3:-1],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[:-2:-1],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[:0:-1],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[:1:-1],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[2:0:-1],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[2:1:-1],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-1:1:-1],
            EncryptedTensor(UnsignedInteger(1), shape=(1,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[-1:0:-1],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
            lambda x: x[:, :, :],
            EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
            lambda x: x[0, :, :],
            EncryptedTensor(UnsignedInteger(1), shape=(4, 5)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
            lambda x: x[:, 0, :],
            EncryptedTensor(UnsignedInteger(1), shape=(3, 5)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
            lambda x: x[:, :, 0],
            EncryptedTensor(UnsignedInteger(1), shape=(3, 4)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
            lambda x: x[0, 0, :],
            EncryptedTensor(UnsignedInteger(1), shape=(5,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
            lambda x: x[0, :, 0],
            EncryptedTensor(UnsignedInteger(1), shape=(4,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
            lambda x: x[:, 0, 0],
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
            lambda x: x[0:, 1:, 2:],
            EncryptedTensor(UnsignedInteger(1), shape=(3, 3, 3)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
            lambda x: x[2:, 1:, 0:],
            EncryptedTensor(UnsignedInteger(1), shape=(1, 3, 5)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
            lambda x: x[0],
            EncryptedTensor(UnsignedInteger(1), shape=(4, 5)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
            lambda x: x[0, 0],
            EncryptedTensor(UnsignedInteger(1), shape=(5,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
            lambda x: x[0, 0, 0],
            EncryptedScalar(UnsignedInteger(1)),
        ),
    ],
)
def test_constant_indexing(
    default_compilation_configuration,
    input_value,
    function_with_indexing,
    output_value,
):
    """Test compile_numpy_function_into_op_graph with constant indexing"""
    # Random samples spanning the full dtype range of the input description.
    inputset = [
        np.random.randint(
            input_value.dtype.min_value(),
            input_value.dtype.max_value() + 1,
            size=input_value.shape,
        )
        for _ in range(10)
    ]
    op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
        function_with_indexing,
        {"x": input_value},
        inputset,
        default_compilation_configuration,
    )
    # The traced graph must have one output whose value description (scalar vs
    # tensor, and shape) matches the expectation.
    assert len(op_graph.output_nodes) == 1
    output_node = op_graph.output_nodes[0]
    assert len(output_node.outputs) == 1
    assert output_value == output_node.outputs[0]
@pytest.mark.parametrize(
    "input_value,function_with_indexing,expected_error_type,expected_error_message",
    [
        pytest.param(
            EncryptedScalar(UnsignedInteger(1)),
            lambda x: x[0],
            TypeError,
            "Only tensors can be indexed but you tried to index EncryptedScalar<uint1>",
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[0.5],
            TypeError,
            "Only integers and integer slices can be used for indexing "
            "but you tried to use 0.5 for indexing",
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[1:5:0.5],  # type: ignore
            TypeError,
            "Only integers and integer slices can be used for indexing "
            "but you tried to use 1:5:0.5 for indexing",
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[0, 1],
            ValueError,
            "Tensor of shape (3,) cannot be indexed with [0, 1] "
            "as the index has more elements than the number of dimensions of the tensor",
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[5],
            ValueError,
            "Tensor of shape (3,) cannot be indexed with [5] "
            "because index is out of range for dimension 0",
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[5:],
            ValueError,
            "Tensor of shape (3,) cannot be indexed with [5:] "
            "because start index is out of range for dimension 0",
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[:10],
            ValueError,
            "Tensor of shape (3,) cannot be indexed with [:10] "
            "because stop index is out of range for dimension 0",
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[2:0],
            ValueError,
            "Tensor of shape (3,) cannot be indexed with [2:0] "
            "because start index is not less than stop index for dimension 0",
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[5::-1],
            ValueError,
            "Tensor of shape (3,) cannot be indexed with [5::-1] "
            "because start index is out of range for dimension 0",
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[:10:-1],
            ValueError,
            "Tensor of shape (3,) cannot be indexed with [:10:-1] "
            "because stop index is out of range for dimension 0",
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[0:2:-1],
            ValueError,
            "Tensor of shape (3,) cannot be indexed with [0:2:-1] "
            "because step is negative and stop index is not less than start index for dimension 0",
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[::0],
            ValueError,
            "Tensor of shape (3,) cannot be indexed with [::0] "
            "because step is zero for dimension 0",
        ),
    ],
)
def test_invalid_constant_indexing(
    default_compilation_configuration,
    input_value,
    function_with_indexing,
    expected_error_type,
    expected_error_message,
):
    """Test compile_numpy_function_into_op_graph with invalid constant indexing"""
    # Build the random inputset before entering the raises-context so that
    # only the compilation call sits inside it; each sample is a 1-tuple
    # because the compiler expects one tuple of arguments per sample.
    inputset = [
        (
            np.random.randint(
                input_value.dtype.min_value(),
                input_value.dtype.max_value() + 1,
                size=input_value.shape,
            ),
        )
        for _ in range(10)
    ]
    # Idiomatic pytest: capture the exception with `as excinfo` and assert on
    # its message afterwards, instead of the try/except/assert/re-raise
    # pattern, where a failed message comparison would be masked by the
    # re-raised exception.
    with pytest.raises(expected_error_type) as excinfo:
        compile_numpy_function_into_op_graph_and_measure_bounds(
            function_with_indexing,
            {"x": input_value},
            inputset,
            default_compilation_configuration,
        )
    assert str(excinfo.value) == expected_error_message
# Same tracing check as `test_constant_indexing`, but the constant indices are
# numpy integer scalars / 0-d arrays instead of Python ints.
@pytest.mark.parametrize(
    "input_value,function_with_indexing,output_value",
    [
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[np.uint32(0)],
            EncryptedScalar(UnsignedInteger(1)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[slice(np.uint32(2), np.int32(0), np.int8(-1))],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[np.array(0)],
            EncryptedScalar(UnsignedInteger(1)),
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[slice(np.array(2), np.array(0), np.array(-1))],
            EncryptedTensor(UnsignedInteger(1), shape=(2,)),
        ),
    ],
)
def test_constant_indexing_with_numpy_integers(
    default_compilation_configuration,
    input_value,
    function_with_indexing,
    output_value,
):
    """Test compile_numpy_function_into_op_graph with constant indexing with numpy integers"""
    # Random samples spanning the full dtype range of the input description.
    inputset = [
        np.random.randint(
            input_value.dtype.min_value(),
            input_value.dtype.max_value() + 1,
            size=input_value.shape,
        )
        for _ in range(10)
    ]
    op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
        function_with_indexing,
        {"x": input_value},
        inputset,
        default_compilation_configuration,
    )
    # Single output whose value description matches the expectation.
    assert len(op_graph.output_nodes) == 1
    output_node = op_graph.output_nodes[0]
    assert len(output_node.outputs) == 1
    assert output_value == output_node.outputs[0]
@pytest.mark.parametrize(
    "input_value,function_with_indexing,expected_error_type,expected_error_message",
    [
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[np.float32(1.5)],
            TypeError,
            "Only integers and integer slices can be used for indexing "
            "but you tried to use 1.5 for indexing",
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[np.array(1.5)],
            TypeError,
            "Only integers and integer slices can be used for indexing "
            "but you tried to use 1.5 for indexing",
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(1), shape=(3,)),
            lambda x: x[np.array([1, 2])],
            TypeError,
            "Only integers and integer slices can be used for indexing "
            "but you tried to use [1 2] for indexing",
        ),
    ],
)
def test_invalid_constant_indexing_with_numpy_values(
    default_compilation_configuration,
    input_value,
    function_with_indexing,
    expected_error_type,
    expected_error_message,
):
    """Test compile_numpy_function_into_op_graph with invalid constant indexing with numpy values"""
    # Build the random inputset before entering the raises-context so that
    # only the compilation call sits inside it; each sample is a 1-tuple
    # because the compiler expects one tuple of arguments per sample.
    inputset = [
        (
            np.random.randint(
                input_value.dtype.min_value(),
                input_value.dtype.max_value() + 1,
                size=input_value.shape,
            ),
        )
        for _ in range(10)
    ]
    # Idiomatic pytest: capture the exception with `as excinfo` and assert on
    # its message afterwards, instead of the try/except/assert/re-raise
    # pattern, where a failed message comparison would be masked by the
    # re-raised exception.
    with pytest.raises(expected_error_type) as excinfo:
        compile_numpy_function_into_op_graph_and_measure_bounds(
            function_with_indexing,
            {"x": input_value},
            inputset,
            default_compilation_configuration,
        )
    assert str(excinfo.value) == expected_error_message
# End-to-end cases: compile, run on ciphertexts, and compare against a known
# plaintext result.
@pytest.mark.parametrize(
    "function,parameters,inputset,test_input,expected_output",
    [
        pytest.param(
            lambda x: x[0],
            {
                "x": EncryptedTensor(UnsignedInteger(3), shape=(3,)),
            },
            [np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(10)],
            ([4, 2, 6],),
            4,
        ),
        pytest.param(
            lambda x: x[-1],
            {
                "x": EncryptedTensor(UnsignedInteger(3), shape=(3,)),
            },
            [np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(10)],
            ([4, 2, 6],),
            6,
        ),
        pytest.param(
            lambda x: x[:3],
            {
                "x": EncryptedTensor(UnsignedInteger(3), shape=(4,)),
            },
            [np.random.randint(0, 2 ** 3, size=(4,)) for _ in range(10)],
            ([4, 2, 6, 1],),
            [4, 2, 6],
        ),
        pytest.param(
            lambda x: x[2:],
            {
                "x": EncryptedTensor(UnsignedInteger(3), shape=(4,)),
            },
            [np.random.randint(0, 2 ** 3, size=(4,)) for _ in range(10)],
            ([4, 2, 6, 1],),
            [6, 1],
        ),
        pytest.param(
            lambda x: x[1:3],
            {
                "x": EncryptedTensor(UnsignedInteger(3), shape=(4,)),
            },
            [np.random.randint(0, 2 ** 3, size=(4,)) for _ in range(10)],
            ([4, 2, 6, 1],),
            [2, 6],
        ),
        pytest.param(
            lambda x: x[::2],
            {
                "x": EncryptedTensor(UnsignedInteger(3), shape=(4,)),
            },
            [np.random.randint(0, 2 ** 3, size=(4,)) for _ in range(10)],
            ([4, 2, 6, 1],),
            [4, 6],
        ),
        pytest.param(
            lambda x: x[::-1],
            {
                "x": EncryptedTensor(UnsignedInteger(3), shape=(4,)),
            },
            [np.random.randint(0, 2 ** 3, size=(4,)) for _ in range(10)],
            ([4, 2, 6, 1],),
            [1, 6, 2, 4],
        ),
        pytest.param(
            lambda x: x[1, 0],
            {
                "x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)),
            },
            [np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)],
            ([[11, 12], [21, 22], [31, 32]],),
            21,
        ),
        pytest.param(
            lambda x: x[:, :],
            {
                "x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)),
            },
            [np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)],
            ([[11, 12], [21, 22], [31, 32]],),
            [[11, 12], [21, 22], [31, 32]],
        ),
        pytest.param(
            lambda x: x[0, :],
            {
                "x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)),
            },
            [np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)],
            ([[11, 12], [21, 22], [31, 32]],),
            [11, 12],
        ),
        pytest.param(
            lambda x: x[:, 0],
            {
                "x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)),
            },
            [np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)],
            ([[11, 12], [21, 22], [31, 32]],),
            [11, 21, 31],
        ),
    ],
)
def test_constant_indexing_run_correctness(
    function,
    parameters,
    inputset,
    test_input,
    expected_output,
    default_compilation_configuration,
    check_array_equality,
):
    """Test correctness of results when running a compiled function with tensor operators"""
    circuit = compile_numpy_function(
        function,
        parameters,
        inputset,
        default_compilation_configuration,
    )
    # Convert list-valued arguments to uint8 arrays; plain ints pass through.
    numpy_test_input = tuple(
        item if isinstance(item, int) else np.array(item, dtype=np.uint8) for item in test_input
    )
    output = circuit.encrypt_run_decrypt(*numpy_test_input)
    expected = np.array(expected_output, dtype=np.uint8)
    check_array_equality(output, expected)

View File

@@ -1,44 +0,0 @@
"""Test module for convolution compilation and execution."""
import numpy as np
import pytest
import concrete.numpy as hnp
from concrete.common.data_types.integers import Integer
from concrete.common.values.tensors import EncryptedTensor
from concrete.numpy.compile import compile_numpy_function
@pytest.mark.parametrize(
    "input_shape, weight_shape",
    [
        pytest.param((1, 1, 4, 4), (1, 1, 2, 2)),
        pytest.param((4, 3, 4, 4), (2, 3, 2, 2)),
    ],
)
@pytest.mark.parametrize("strides", [(2, 2)])
@pytest.mark.parametrize("dilations", [(1, 1)])
@pytest.mark.parametrize("has_bias", [True, False])
def test_compile_and_run(
    input_shape, weight_shape, strides, dilations, has_bias, default_compilation_configuration
):
    """Test function to make sure compilation and execution of conv2d works properly"""
    # Draw a random bias (one per output channel) only when requested, then a
    # random weight tensor.
    bias = np.random.randint(0, 4, size=(weight_shape[0],)) if has_bias else None
    weight = np.random.randint(0, 4, size=weight_shape)

    def conv(x):
        return hnp.conv2d(x, weight, bias, strides=strides, dilations=dilations)

    inputset = [np.random.randint(0, 4, size=input_shape) for _ in range(20)]
    compiler_engine = compile_numpy_function(
        conv,
        {"x": EncryptedTensor(Integer(64, False), input_shape)},
        inputset,
        default_compilation_configuration,
    )

    # The homomorphic evaluation must match the clear-text convolution.
    sample = np.random.randint(0, 4, size=input_shape, dtype=np.uint8)
    expected = conv(sample)
    actual = compiler_engine.encrypt_run_decrypt(sample)
    assert (expected == actual).all()

View File

@@ -1,265 +0,0 @@
"""Test module for memory operations."""
import numpy
import pytest
from concrete.common.data_types import UnsignedInteger
from concrete.common.values import EncryptedTensor
from concrete.numpy import compile_numpy_function
# Each case compiles a flatten/reshape of an encrypted tensor and checks the
# decrypted result against the plaintext numpy equivalent.
@pytest.mark.parametrize(
    "function,parameters,inputset,test_input,expected_output",
    [
        pytest.param(
            lambda x: x.flatten(),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
            [[0, 1], [1, 2], [2, 3]],
            [0, 1, 1, 2, 2, 3],
        ),
        pytest.param(
            lambda x: x.flatten(),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
            (numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
            (numpy.arange(720) % 10),
        ),
        pytest.param(
            lambda x: x.reshape((1, 3)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(3,)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(3,)) for _ in range(10)],
            [5, 9, 1],
            [[5, 9, 1]],
        ),
        pytest.param(
            lambda x: x.reshape((3, 1)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(3,)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(3,)) for _ in range(10)],
            [5, 9, 1],
            [[5], [9], [1]],
        ),
        pytest.param(
            lambda x: x.reshape((3, 2)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
            [[0, 1], [1, 2], [2, 3]],
            [[0, 1], [1, 2], [2, 3]],
        ),
        pytest.param(
            lambda x: x.reshape((3, 2)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3)) for _ in range(10)],
            [[0, 1, 1], [2, 2, 3]],
            [[0, 1], [1, 2], [2, 3]],
        ),
        pytest.param(
            lambda x: x.reshape(-1),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
            [[0, 1], [1, 2], [2, 3]],
            [0, 1, 1, 2, 2, 3],
        ),
        pytest.param(
            lambda x: x.reshape((2, 2, 3)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(4, 3)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(4, 3)) for _ in range(10)],
            (numpy.arange(12) % 10).reshape((4, 3)),
            (numpy.arange(12) % 10).reshape((2, 2, 3)),
        ),
        pytest.param(
            lambda x: x.reshape((4, 3)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 2, 3)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 2, 3)) for _ in range(10)],
            (numpy.arange(12) % 10).reshape((2, 2, 3)),
            (numpy.arange(12) % 10).reshape((4, 3)),
        ),
        pytest.param(
            lambda x: x.reshape((3, 2, 2)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 4)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(3, 4)) for _ in range(10)],
            (numpy.arange(12) % 10).reshape((3, 4)),
            (numpy.arange(12) % 10).reshape((3, 2, 2)),
        ),
        pytest.param(
            lambda x: x.reshape((3, 4)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2, 2)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(3, 2, 2)) for _ in range(10)],
            (numpy.arange(12) % 10).reshape((3, 2, 2)),
            (numpy.arange(12) % 10).reshape((3, 4)),
        ),
        pytest.param(
            lambda x: x.reshape((5, 3, 2)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(6, 5)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(6, 5)) for _ in range(10)],
            (numpy.arange(30) % 10).reshape((6, 5)),
            (numpy.arange(30) % 10).reshape((5, 3, 2)),
        ),
        pytest.param(
            lambda x: x.reshape((5, 6)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 5)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 5)) for _ in range(10)],
            (numpy.arange(30) % 10).reshape((2, 3, 5)),
            (numpy.arange(30) % 10).reshape((5, 6)),
        ),
        pytest.param(
            lambda x: x.reshape((6, 4, 30)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
            (numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
            (numpy.arange(720) % 10).reshape((6, 4, 30)),
        ),
        pytest.param(
            lambda x: x.reshape((2, 60, 6)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
            (numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
            (numpy.arange(720) % 10).reshape((2, 60, 6)),
        ),
        pytest.param(
            lambda x: x.reshape((6, 6, -1)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
            (numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
            (numpy.arange(144) % 10).reshape((6, 6, -1)),
        ),
        pytest.param(
            lambda x: x.reshape((6, -1, 12)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
            (numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
            (numpy.arange(144) % 10).reshape((6, -1, 12)),
        ),
        pytest.param(
            lambda x: x.reshape((-1, 18, 4)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
            (numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
            (numpy.arange(144) % 10).reshape((-1, 18, 4)),
        ),
    ],
)
def test_memory_operation_run_correctness(
    function,
    parameters,
    inputset,
    test_input,
    expected_output,
    default_compilation_configuration,
    check_array_equality,
):
    """
    Test correctness of results when running a compiled function with memory operators.
    e.g.,
    - flatten
    - reshape
    """
    circuit = compile_numpy_function(
        function,
        parameters,
        inputset,
        default_compilation_configuration,
    )
    # Run on ciphertexts and compare against the plaintext expectation.
    actual = circuit.encrypt_run_decrypt(numpy.array(test_input, dtype=numpy.uint8))
    expected = numpy.array(expected_output, dtype=numpy.uint8)
    check_array_equality(actual, expected)
# Reshapes with incompatible target shapes must fail compilation with an
# exact, user-readable error message.
@pytest.mark.parametrize(
    "function,parameters,inputset,error,match",
    [
        pytest.param(
            lambda x: x.reshape((-1, -1, 2)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4)) for _ in range(10)],
            ValueError,
            "shapes are not compatible (old shape (2, 3, 4), new shape (-1, -1, 2))",
        ),
        pytest.param(
            lambda x: x.reshape((3, -1, 3)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4)) for _ in range(10)],
            ValueError,
            "shapes are not compatible (old shape (2, 3, 4), new shape (3, -1, 3))",
        ),
    ],
)
def test_memory_operation_failed_compilation(
    function,
    parameters,
    inputset,
    error,
    match,
    default_compilation_configuration,
):
    """
    Test compilation failures of compiled function with memory operations.
    e.g.,
    - reshape
    """
    with pytest.raises(error) as excinfo:
        compile_numpy_function(
            function,
            parameters,
            inputset,
            default_compilation_configuration,
        )
    # Exact string comparison (not a regex `match=`) because the messages
    # contain regex metacharacters; the f-string shows both sides on failure.
    assert (
        str(excinfo.value) == match
    ), f"""
Actual Output
=============
{excinfo.value}
Expected Output
===============
{match}
"""

View File

@@ -1,280 +0,0 @@
"""Test file for user-friendly numpy compilation functions"""
import numpy
import pytest
from concrete.common.debugging import format_operation_graph
from concrete.numpy.np_fhe_compiler import NPFHECompiler
def complicated_topology(x, y):
    """Mix x in an intricated way."""
    # Build a small diamond of additions so the traced graph has shared
    # intermediate nodes and multiple outputs.
    base = x + y
    plus_one = base + 1
    plus_two = base + 2
    combined = plus_one + plus_two
    return (
        combined.astype(numpy.int32),
        plus_two.astype(numpy.int32),
        (plus_two + 3).astype(numpy.int32),
        combined.astype(numpy.int32) + 67,
    )
@pytest.mark.parametrize("input_shape", [(), (3, 1, 2)])
def test_np_fhe_compiler_op_graph(
    input_shape, default_compilation_configuration, check_array_equality
):
    """Test NPFHECompiler in two subtests."""
    # Run the one-input and two-inputs scenarios in order with the same
    # fixtures.
    for subtest in (
        subtest_np_fhe_compiler_1_input_op_graph,
        subtest_np_fhe_compiler_2_inputs_op_graph,
    ):
        subtest(input_shape, default_compilation_configuration, check_array_equality)
def subtest_np_fhe_compiler_1_input_op_graph(
    input_shape, default_compilation_configuration, check_array_equality
):
    """test for NPFHECompiler on one input function"""

    def function_to_compile(x):
        return complicated_topology(x, 0)

    compiler = NPFHECompiler(
        function_to_compile,
        {"x": "encrypted"},
        default_compilation_configuration,
    )
    # For coverage when the OPGraph is not yet traced
    compiler._patch_op_graph_input_to_accept_any_integer_input()  # pylint: disable=protected-access
    # The compiler must hold a copy of the configuration, not the same object.
    assert compiler.compilation_configuration == default_compilation_configuration
    assert compiler.compilation_configuration is not default_compilation_configuration
    for i in numpy.arange(5):
        i = numpy.ones(input_shape, dtype=numpy.int64) * i
        check_array_equality(compiler(i), function_to_compile(i))
    # For coverage, check that we flush the inputset when we query the OPGraph
    current_op_graph = compiler.op_graph
    assert current_op_graph is not compiler.op_graph
    assert len(compiler._current_inputset) == 0  # pylint: disable=protected-access
    # For coverage, cover case where the current inputset is empty
    compiler._eval_on_current_inputset()  # pylint: disable=protected-access
    # Continue a bit more
    for i in numpy.arange(5, 10):
        i = numpy.ones(input_shape, dtype=numpy.int64) * i
        check_array_equality(compiler(i), function_to_compile(i))
    # NOTE(review): the expected strings below were recovered from a diff
    # render that strips leading whitespace; the original alignment of the
    # `%N = ...` columns may differ — confirm against format_operation_graph.
    if input_shape == ():
        assert (
            (got := format_operation_graph(compiler.op_graph))
            == """ %0 = 67 # ClearScalar<uint7>
%1 = 2 # ClearScalar<uint2>
%2 = 3 # ClearScalar<uint2>
%3 = 1 # ClearScalar<uint1>
%4 = x # EncryptedScalar<uint4>
%5 = 0 # ClearScalar<uint1>
%6 = add(%4, %5) # EncryptedScalar<uint4>
%7 = add(%6, %1) # EncryptedScalar<uint4>
%8 = add(%6, %3) # EncryptedScalar<uint4>
%9 = astype(%7, dtype=int32) # EncryptedScalar<uint4>
%10 = add(%7, %2) # EncryptedScalar<uint4>
%11 = add(%8, %7) # EncryptedScalar<uint5>
%12 = astype(%10, dtype=int32) # EncryptedScalar<uint4>
%13 = astype(%11, dtype=int32) # EncryptedScalar<uint5>
%14 = astype(%11, dtype=int32) # EncryptedScalar<uint5>
%15 = add(%14, %0) # EncryptedScalar<uint7>
(%13, %9, %12, %15)"""
        ), got
    else:
        assert (
            (got := format_operation_graph(compiler.op_graph))
            == """ %0 = 67 # ClearScalar<uint7>
%1 = 2 # ClearScalar<uint2>
%2 = 3 # ClearScalar<uint2>
%3 = 1 # ClearScalar<uint1>
%4 = x # EncryptedTensor<uint4, shape=(3, 1, 2)>
%5 = 0 # ClearScalar<uint1>
%6 = add(%4, %5) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%7 = add(%6, %1) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%8 = add(%6, %3) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%9 = astype(%7, dtype=int32) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%10 = add(%7, %2) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%11 = add(%8, %7) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%12 = astype(%10, dtype=int32) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%13 = astype(%11, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%14 = astype(%11, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%15 = add(%14, %0) # EncryptedTensor<uint7, shape=(3, 1, 2)>
(%13, %9, %12, %15)"""
        ), got
def subtest_np_fhe_compiler_2_inputs_op_graph(
    input_shape, default_compilation_configuration, check_array_equality
):
    """test for NPFHECompiler on two inputs function"""
    compiler = NPFHECompiler(
        complicated_topology,
        {"x": "encrypted", "y": "clear"},
        default_compilation_configuration,
    )
    # For coverage when the OPGraph is not yet traced
    compiler._patch_op_graph_input_to_accept_any_integer_input()  # pylint: disable=protected-access
    # The compiler must hold a copy of the configuration, not the same object.
    assert compiler.compilation_configuration == default_compilation_configuration
    assert compiler.compilation_configuration is not default_compilation_configuration
    for i, j in zip(numpy.arange(5), numpy.arange(5, 10)):
        i = numpy.ones(input_shape, dtype=numpy.int64) * i
        j = numpy.ones(input_shape, dtype=numpy.int64) * j
        check_array_equality(compiler(i, j), complicated_topology(i, j))
    # For coverage, check that we flush the inputset when we query the OPGraph
    current_op_graph = compiler.op_graph
    assert current_op_graph is not compiler.op_graph
    assert len(compiler._current_inputset) == 0  # pylint: disable=protected-access
    # For coverage, cover case where the current inputset is empty
    compiler._eval_on_current_inputset()  # pylint: disable=protected-access
    # Continue a bit more
    for i, j in zip(numpy.arange(5, 10), numpy.arange(5)):
        i = numpy.ones(input_shape, dtype=numpy.int64) * i
        j = numpy.ones(input_shape, dtype=numpy.int64) * j
        check_array_equality(compiler(i, j), complicated_topology(i, j))
    # NOTE(review): the expected strings below were recovered from a diff
    # render that strips leading whitespace; the original alignment of the
    # `%N = ...` columns may differ — confirm against format_operation_graph.
    if input_shape == ():
        assert (
            (got := format_operation_graph(compiler.op_graph))
            == """ %0 = 67 # ClearScalar<uint7>
%1 = 2 # ClearScalar<uint2>
%2 = 3 # ClearScalar<uint2>
%3 = 1 # ClearScalar<uint1>
%4 = x # EncryptedScalar<uint4>
%5 = y # ClearScalar<uint4>
%6 = add(%4, %5) # EncryptedScalar<uint4>
%7 = add(%6, %1) # EncryptedScalar<uint4>
%8 = add(%6, %3) # EncryptedScalar<uint4>
%9 = astype(%7, dtype=int32) # EncryptedScalar<uint4>
%10 = add(%7, %2) # EncryptedScalar<uint5>
%11 = add(%8, %7) # EncryptedScalar<uint5>
%12 = astype(%10, dtype=int32) # EncryptedScalar<uint5>
%13 = astype(%11, dtype=int32) # EncryptedScalar<uint5>
%14 = astype(%11, dtype=int32) # EncryptedScalar<uint5>
%15 = add(%14, %0) # EncryptedScalar<uint7>
(%13, %9, %12, %15)"""
        ), got
    else:
        assert (
            (got := format_operation_graph(compiler.op_graph))
            == """ %0 = 67 # ClearScalar<uint7>
%1 = 2 # ClearScalar<uint2>
%2 = 3 # ClearScalar<uint2>
%3 = 1 # ClearScalar<uint1>
%4 = x # EncryptedTensor<uint4, shape=(3, 1, 2)>
%5 = y # ClearTensor<uint4, shape=(3, 1, 2)>
%6 = add(%4, %5) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%7 = add(%6, %1) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%8 = add(%6, %3) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%9 = astype(%7, dtype=int32) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%10 = add(%7, %2) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%11 = add(%8, %7) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%12 = astype(%10, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%13 = astype(%11, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%14 = astype(%11, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%15 = add(%14, %0) # EncryptedTensor<uint7, shape=(3, 1, 2)>
(%13, %9, %12, %15)"""
        ), got
def remaining_inputset_size(inputset_len):
    """Return how many samples remain buffered after the last automatic flush."""
    flush_period = NPFHECompiler.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE
    return inputset_len % flush_period
@pytest.mark.parametrize(
    "inputset_len, expected_remaining_inputset_len",
    [
        (42, remaining_inputset_size(42)),
        (128, remaining_inputset_size(128)),
        (234, remaining_inputset_size(234)),
    ],
)
def test_np_fhe_compiler_auto_flush(
    inputset_len,
    expected_remaining_inputset_len,
    default_compilation_configuration,
    check_array_equality,
):
    """Test the auto flush of NPFHECompiler once the inputset is 128 elements.

    Feeds ``inputset_len`` samples through the compiler and checks that its
    internal buffer only holds the samples accumulated since the last
    automatic flush.
    """
    def function_to_compile(x):
        return x // 2
    compiler = NPFHECompiler(
        function_to_compile,
        {"x": "encrypted"},
        default_compilation_configuration,
    )
    # Each call both evaluates the function and records the sample in the inputset.
    for i in numpy.arange(inputset_len):
        check_array_equality(compiler(i), function_to_compile(i))
    # Check the inputset was properly flushed
    assert (
        len(compiler._current_inputset) # pylint: disable=protected-access
        == expected_remaining_inputset_len
    )
def test_np_fhe_compiler_full_compilation(default_compilation_configuration, check_array_equality):
    """Test the case where we generate an FHE circuit."""
    def function_to_compile(x):
        return x + 42
    compiler = NPFHECompiler(
        function_to_compile,
        {"x": "encrypted"},
        default_compilation_configuration,
    )
    # For coverage: requesting a circuit before any inputset evaluation must fail loudly.
    with pytest.raises(RuntimeError) as excinfo:
        compiler.get_compiled_fhe_circuit()
    assert str(excinfo.value) == (
        "Requested FHECircuit but no OPGraph was compiled. "
        "Did you forget to evaluate NPFHECompiler over an inputset?"
    )
    # Evaluate over an inputset so bounds are measured and an OPGraph is built.
    for i in numpy.arange(64):
        check_array_equality(compiler(i), function_to_compile(i))
    fhe_circuit = compiler.get_compiled_fhe_circuit()
    # The compiled circuit must agree with the cleartext function on every sample.
    for i in range(64):
        assert fhe_circuit.encrypt_run_decrypt(i) == function_to_compile(i)
def test_np_fhe_compiler_compile_on_inputset(default_compilation_configuration):
    """Test the case where we generate an FHE circuit with a single call."""
    def function_to_compile(x):
        return x + 42
    compiler = NPFHECompiler(
        function_to_compile,
        {"x": "encrypted"},
        default_compilation_configuration,
    )
    # compile_on_inputset evaluates the whole inputset and compiles in one step.
    circuit = compiler.compile_on_inputset(numpy.arange(64))
    for i in range(64):
        assert circuit.encrypt_run_decrypt(i) == function_to_compile(i)

View File

@@ -1,111 +0,0 @@
"""Test file for numpy dtype helpers"""
import numpy
import pytest
from concrete.common.data_types.floats import Float
from concrete.common.data_types.integers import Integer
from concrete.numpy.np_dtypes_helpers import (
convert_base_data_type_to_numpy_dtype,
convert_numpy_dtype_to_base_data_type,
get_base_value_for_numpy_or_python_constant_data,
get_constructor_for_numpy_or_python_constant_data,
)
@pytest.mark.parametrize(
    "numpy_dtype,expected_common_type",
    [
        pytest.param(numpy.int8, Integer(8, is_signed=True)),
        pytest.param("int8", Integer(8, is_signed=True)),
        pytest.param(numpy.int16, Integer(16, is_signed=True)),
        pytest.param("int16", Integer(16, is_signed=True)),
        pytest.param(numpy.int32, Integer(32, is_signed=True)),
        pytest.param("int32", Integer(32, is_signed=True)),
        pytest.param(numpy.int64, Integer(64, is_signed=True)),
        pytest.param("int64", Integer(64, is_signed=True)),
        pytest.param(numpy.uint8, Integer(8, is_signed=False)),
        pytest.param("uint8", Integer(8, is_signed=False)),
        pytest.param(numpy.uint16, Integer(16, is_signed=False)),
        pytest.param("uint16", Integer(16, is_signed=False)),
        pytest.param(numpy.uint32, Integer(32, is_signed=False)),
        pytest.param("uint32", Integer(32, is_signed=False)),
        pytest.param(numpy.uint64, Integer(64, is_signed=False)),
        pytest.param("uint64", Integer(64, is_signed=False)),
        pytest.param(numpy.float32, Float(32)),
        pytest.param("float32", Float(32)),
        pytest.param(numpy.float64, Float(64)),
        pytest.param("float64", Float(64)),
        pytest.param("complex64", None, marks=pytest.mark.xfail(strict=True, raises=ValueError)),
    ],
)
def test_convert_numpy_dtype_to_base_data_type(numpy_dtype, expected_common_type):
    """Test function for convert_numpy_dtype_to_base_data_type.

    Both numpy dtype objects and their string names are accepted; unsupported
    dtypes such as complex64 are expected to raise ValueError.
    """
    assert convert_numpy_dtype_to_base_data_type(numpy_dtype) == expected_common_type
@pytest.mark.parametrize(
    "common_dtype,expected_numpy_dtype",
    [
        pytest.param(Integer(7, is_signed=False), numpy.uint32),
        pytest.param(Integer(7, is_signed=True), numpy.int32),
        pytest.param(Integer(32, is_signed=True), numpy.int32),
        pytest.param(Integer(64, is_signed=True), numpy.int64),
        pytest.param(Integer(32, is_signed=False), numpy.uint32),
        pytest.param(Integer(64, is_signed=False), numpy.uint64),
        pytest.param(Float(32), numpy.float32),
        pytest.param(Float(64), numpy.float64),
        pytest.param(
            Integer(128, is_signed=True),
            None,
            marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
        ),
    ],
)
def test_convert_common_dtype_to_numpy_dtype(common_dtype, expected_numpy_dtype):
    """Test function for convert_base_data_type_to_numpy_dtype.

    Integers narrower than 32 bits are widened to the 32-bit numpy type;
    widths beyond 64 bits are expected to raise NotImplementedError.
    """
    assert expected_numpy_dtype == convert_base_data_type_to_numpy_dtype(common_dtype)
@pytest.mark.parametrize(
    "constant_data,expected_constructor",
    [
        (10, int),
        (42.0, float),
        (numpy.int32(10), numpy.int32),
    ],
)
def test_get_constructor_for_numpy_or_python_constant_data(constant_data, expected_constructor):
    """Test function for get_constructor_for_numpy_or_python_constant_data.

    Scalar constants map straight to their own type as the constructor.
    """
    assert expected_constructor == get_constructor_for_numpy_or_python_constant_data(constant_data)
def test_get_constructor_for_numpy_arrays(test_helpers):
    """Test function for get_constructor_for_numpy_or_python_constant_data for numpy arrays."""
    sample_arrays = [
        numpy.array([[0, 1], [3, 4]], dtype=numpy.uint64),
        numpy.array([[0, 1], [3, 4]], dtype=numpy.float64),
    ]

    def make_reference(array: numpy.ndarray):
        # The expected constructor fills the array's shape/dtype with a scalar.
        return lambda x: numpy.full(array.shape, x, dtype=array.dtype)

    for sample in sample_arrays:
        reference = make_reference(sample)
        computed = get_constructor_for_numpy_or_python_constant_data(sample)
        assert test_helpers.python_functions_are_equal_or_equivalent(reference, computed)
def test_get_base_value_for_numpy_or_python_constant_data_with_list():
    """Test function for get_base_value_for_numpy_or_python_constant_data called with list.

    Plain python lists are rejected; callers must wrap them in numpy.array.
    """
    with pytest.raises(
        AssertionError,
        match="Unsupported constant data of type list "
        "\\(if you meant to use a list as an array, please use numpy\\.array instead\\)",
    ):
        get_base_value_for_numpy_or_python_constant_data([1, 2, 3])

View File

@@ -1,96 +0,0 @@
"""Test file for numpy inputset helpers"""
import numpy as np
import pytest
from concrete.common.compilation import CompilationConfiguration
from concrete.common.data_types import Float, UnsignedInteger
from concrete.common.data_types.base import BaseDataType
from concrete.common.values import BaseValue, EncryptedScalar, EncryptedTensor
from concrete.numpy.np_inputset_helpers import _generate_random_inputset
def test_generate_random_inputset():
    """Test function for generate_random_inputset."""
    inputset = _generate_random_inputset(
        {
            "x1": EncryptedScalar(UnsignedInteger(4)),
            "x2": EncryptedTensor(UnsignedInteger(4), shape=(2, 3)),
            "x3": EncryptedScalar(Float(64)),
            "x4": EncryptedTensor(Float(64), shape=(3, 2)),
        },
        CompilationConfiguration(random_inputset_samples=15),
    )
    # One sample per requested random_inputset_samples.
    assert isinstance(inputset, list)
    assert len(inputset) == 15
    for sample in inputset:
        # Each sample is a tuple with one entry per declared parameter, in order.
        assert isinstance(sample, tuple)
        assert len(sample) == 4
        # Unsigned 4-bit scalar: python int in [0, 2**4).
        assert isinstance(sample[0], int)
        assert 0 <= sample[0] < 2 ** 4
        # Unsigned 4-bit tensor: uint64 ndarray of the declared shape, values in [0, 2**4).
        assert isinstance(sample[1], np.ndarray)
        assert sample[1].dtype == np.uint64
        assert sample[1].shape == (2, 3)
        assert (sample[1] >= 0).all()
        assert (sample[1] < 2 ** 4).all()
        # Float scalar: python float drawn from [0, 1).
        assert isinstance(sample[2], float)
        assert 0 <= sample[2] < 1
        # Float tensor: float64 ndarray of the declared shape, values in [0, 1).
        assert isinstance(sample[3], np.ndarray)
        assert sample[3].dtype == np.float64
        assert sample[3].shape == (3, 2)
        assert (sample[3] >= 0).all()
        assert (sample[3] < 1).all()
def test_fail_generate_random_inputset():
    """Test function for failed generate_random_inputset."""
    class MockDtype(BaseDataType):
        """Unsupported dtype to check error messages"""
        def __eq__(self, o: object) -> bool:
            return False
        def __str__(self):
            return "MockDtype"
    class MockValue(BaseValue):
        """Unsupported value to check error messages"""
        def __init__(self):
            super().__init__(MockDtype(), is_encrypted=True)
        def __eq__(self, other: object) -> bool:
            return False
        def __str__(self):
            return "MockValue"
    # Unknown value kinds must be rejected with a descriptive error message.
    with pytest.raises(ValueError):
        try:
            _generate_random_inputset(
                {"x": MockValue()},
                CompilationConfiguration(random_inputset_samples=15),
            )
        except Exception as error:
            expected = "Random inputset cannot be generated for MockValue parameters"
            assert str(error) == expected
            raise
    # Known value kinds carrying an unknown dtype must also be rejected.
    with pytest.raises(ValueError):
        try:
            _generate_random_inputset(
                {"x": EncryptedScalar(MockDtype())},
                CompilationConfiguration(random_inputset_samples=15),
            )
        except Exception as error:
            expected = "Random inputset cannot be generated for parameters of type MockDtype"
            assert str(error) == expected
            raise

View File

@@ -1,110 +0,0 @@
"""Test file for numpy mlir converter"""
import math
import numpy
import pytest
import concrete.numpy as hnp
from concrete.common.representation.intermediate import GenericFunction
from concrete.numpy.np_mlir_converter import generate_deduplicated_tables
def multi_tlu_func(x, cst):
    """Offset ``x`` by ``cst`` then cast to 32-bit integers (a multi-TLU pattern)."""
    shifted = x + cst
    return shifted.astype(numpy.int32)
# Largest activation tensor shape used in these tests, and its element count.
RESNET_BIGGEST_SHAPE = (64, 112, 112)
RESNET_BIGGEST_SIZE = math.prod(RESNET_BIGGEST_SHAPE)
@pytest.mark.parametrize(
    "function,expected_number_of_tables",
    [
        (
            lambda x: multi_tlu_func(x, numpy.zeros(RESNET_BIGGEST_SHAPE, dtype=numpy.float64)),
            1,
        ),
        (
            lambda x: multi_tlu_func(
                x,
                numpy.arange(RESNET_BIGGEST_SIZE, dtype=numpy.float64).reshape(
                    RESNET_BIGGEST_SHAPE
                ),
            ),
            RESNET_BIGGEST_SIZE,
        ),
    ],
)
def test_generate_deduplicated_tables(
    function, expected_number_of_tables, default_compilation_configuration
):
    """Test function for generate_deduplicated_tables.

    A constant-zero offset should deduplicate to a single table, while a
    distinct offset per cell should yield one table per tensor element.
    """
    op_graph = hnp.compile_numpy_function_into_op_graph_and_measure_bounds(
        function,
        {"x": hnp.EncryptedTensor(hnp.Integer(7, False), RESNET_BIGGEST_SHAPE)},
        (i * numpy.ones(RESNET_BIGGEST_SHAPE, dtype=numpy.int32) for i in range(128)),
        default_compilation_configuration,
    )
    # Exactly one TLU (GenericFunction) node is expected in the traced graph.
    univariate_function_nodes = [
        node for node in op_graph.graph.nodes() if isinstance(node, GenericFunction)
    ]
    assert len(univariate_function_nodes) == 1
    tlu_node = univariate_function_nodes[0]
    deduplication_result = generate_deduplicated_tables(
        tlu_node, op_graph.get_ordered_preds(tlu_node)
    )
    assert len(deduplication_result) == expected_number_of_tables
def test_deduplicated_tables_correctness(default_compilation_configuration):
    """Check the deduplicated tables are the expected ones."""
    tensor_shape = (2, 2)
    # Each cell gets a distinct offset, so each cell needs its own table.
    op_graph = hnp.compile_numpy_function_into_op_graph_and_measure_bounds(
        lambda x: multi_tlu_func(x, numpy.arange(4, dtype=numpy.float64).reshape(tensor_shape)),
        {"x": hnp.EncryptedTensor(hnp.Integer(2, False), tensor_shape)},
        (i * numpy.ones(tensor_shape, dtype=numpy.int32) for i in range(4)),
        default_compilation_configuration,
    )
    univariate_function_nodes = [
        node for node in op_graph.graph.nodes() if isinstance(node, GenericFunction)
    ]
    assert len(univariate_function_nodes) == 1
    tlu_node = univariate_function_nodes[0]
    deduplication_result = generate_deduplicated_tables(
        tlu_node, op_graph.get_ordered_preds(tlu_node)
    )
    # Table i is the identity shifted by i, used at exactly one tensor index.
    expected_result = tuple(
        (
            numpy.arange(i, 4 + i, dtype=numpy.int32),
            [
                numpy.unravel_index(i, tensor_shape),
            ],
        )
        for i in range(4)
    )
    assert len(deduplication_result) == len(expected_result)
    # Order is not guaranteed: each computed (table, indices) pair must appear
    # somewhere in the expected set.
    for computed_array, computed_idx in deduplication_result:
        for expected_array, expected_idx in expected_result:
            if numpy.array_equal(computed_array, expected_array) and computed_idx == expected_idx:
                break
        else:
            raise AssertionError(
                f"Could not find {(computed_array, computed_idx)} "
                f"in expected_result: {expected_result}"
            )

View File

@@ -1,991 +0,0 @@
"""Test file for numpy tracing"""
import inspect
import networkx as nx
import numpy
import pytest
from concrete.common.data_types.dtypes_helpers import broadcast_shapes
from concrete.common.data_types.floats import Float
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import format_operation_graph
from concrete.common.representation import intermediate as ir
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor
from concrete.numpy import tracing
OPERATIONS_TO_TEST = [ir.Add, ir.Sub, ir.Mul]
@pytest.mark.parametrize(
"operation",
OPERATIONS_TO_TEST,
)
@pytest.mark.parametrize(
"x",
[
pytest.param(EncryptedScalar(Integer(64, is_signed=False)), id="x: Encrypted uint"),
pytest.param(
EncryptedScalar(Integer(64, is_signed=True)),
id="x: Encrypted int",
),
pytest.param(
ClearScalar(Integer(64, is_signed=False)),
id="x: Clear uint",
),
pytest.param(
ClearScalar(Integer(64, is_signed=True)),
id="x: Clear int",
),
],
)
@pytest.mark.parametrize(
"y",
[
pytest.param(EncryptedScalar(Integer(64, is_signed=False)), id="y: Encrypted uint"),
pytest.param(
EncryptedScalar(Integer(64, is_signed=True)),
id="y: Encrypted int",
),
pytest.param(
ClearScalar(Integer(64, is_signed=False)),
id="y: Clear uint",
),
pytest.param(
ClearScalar(Integer(64, is_signed=True)),
id="y: Clear int",
),
],
)
def test_numpy_tracing_binary_op(operation, x, y, test_helpers):
"Test numpy tracing a binary operation (in the supported ops)"
# Remark that the functions here have a common structure (which is
# 2x op y), such that creating further the ref_graph is easy, by
# hand
def simple_add_function(x, y):
z = x + x
return z + y
def simple_sub_function(x, y):
z = x + x
return z - y
def simple_mul_function(x, y):
z = x + x
return z * y
assert operation in OPERATIONS_TO_TEST, f"unknown operation {operation}"
if operation == ir.Add:
function_to_compile = simple_add_function
elif operation == ir.Sub:
function_to_compile = simple_sub_function
elif operation == ir.Mul:
function_to_compile = simple_mul_function
op_graph = tracing.trace_numpy_function(function_to_compile, {"x": x, "y": y})
ref_graph = nx.MultiDiGraph()
input_x = ir.Input(x, input_name="x", program_input_idx=0)
input_y = ir.Input(y, input_name="y", program_input_idx=1)
add_node_z = ir.Add(
(
input_x.outputs[0],
input_x.outputs[0],
)
)
returned_final_node = operation(
(
add_node_z.outputs[0],
input_y.outputs[0],
)
)
ref_graph.add_node(input_x)
ref_graph.add_node(input_y)
ref_graph.add_node(add_node_z)
ref_graph.add_node(returned_final_node)
ref_graph.add_edge(input_x, add_node_z, input_idx=0, output_idx=0)
ref_graph.add_edge(input_x, add_node_z, input_idx=1, output_idx=0)
ref_graph.add_edge(add_node_z, returned_final_node, input_idx=0, output_idx=0)
ref_graph.add_edge(input_y, returned_final_node, input_idx=1, output_idx=0)
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
def test_numpy_tracing_tensors():
"Test numpy tracing tensors"
def all_operations(x):
intermediate = x + numpy.array([[1, 2], [3, 4]])
intermediate = numpy.array([[5, 6], [7, 8]]) + intermediate
intermediate = numpy.array([[100, 200], [300, 400]]) - intermediate
intermediate = intermediate - numpy.array([[10, 20], [30, 40]])
intermediate = intermediate * numpy.array([[1, 2], [2, 1]])
intermediate = numpy.array([[2, 1], [1, 2]]) * intermediate
return intermediate
op_graph = tracing.trace_numpy_function(
all_operations, {"x": EncryptedTensor(Integer(32, True), shape=(2, 2))}
)
expected = """ %0 = [[2 1] [1 2]] # ClearTensor<uint2, shape=(2, 2)>
%1 = [[1 2] [2 1]] # ClearTensor<uint2, shape=(2, 2)>
%2 = [[10 20] [30 40]] # ClearTensor<uint6, shape=(2, 2)>
%3 = [[100 200] [300 400]] # ClearTensor<uint9, shape=(2, 2)>
%4 = [[5 6] [7 8]] # ClearTensor<uint4, shape=(2, 2)>
%5 = x # EncryptedTensor<int32, shape=(2, 2)>
%6 = [[1 2] [3 4]] # ClearTensor<uint3, shape=(2, 2)>
%7 = add(%5, %6) # EncryptedTensor<int32, shape=(2, 2)>
%8 = add(%4, %7) # EncryptedTensor<int32, shape=(2, 2)>
%9 = sub(%3, %8) # EncryptedTensor<int32, shape=(2, 2)>
%10 = sub(%9, %2) # EncryptedTensor<int32, shape=(2, 2)>
%11 = mul(%10, %1) # EncryptedTensor<int32, shape=(2, 2)>
%12 = mul(%0, %11) # EncryptedTensor<int32, shape=(2, 2)>
return %12""" # noqa: E501
assert format_operation_graph(op_graph) == expected, format_operation_graph(op_graph)
def test_numpy_explicit_tracing_tensors():
"Test numpy tracing tensors using explicit operations"
def all_explicit_operations(x):
intermediate = numpy.add(x, numpy.array([[1, 2], [3, 4]]))
intermediate = numpy.add(numpy.array([[5, 6], [7, 8]]), intermediate)
intermediate = numpy.subtract(numpy.array([[100, 200], [300, 400]]), intermediate)
intermediate = numpy.subtract(intermediate, numpy.array([[10, 20], [30, 40]]))
intermediate = numpy.multiply(intermediate, numpy.array([[1, 2], [2, 1]]))
intermediate = numpy.multiply(numpy.array([[2, 1], [1, 2]]), intermediate)
return intermediate
op_graph = tracing.trace_numpy_function(
all_explicit_operations, {"x": EncryptedTensor(Integer(32, True), shape=(2, 2))}
)
expected = """ %0 = [[2 1] [1 2]] # ClearTensor<uint2, shape=(2, 2)>
%1 = [[1 2] [2 1]] # ClearTensor<uint2, shape=(2, 2)>
%2 = [[10 20] [30 40]] # ClearTensor<uint6, shape=(2, 2)>
%3 = [[100 200] [300 400]] # ClearTensor<uint9, shape=(2, 2)>
%4 = [[5 6] [7 8]] # ClearTensor<uint4, shape=(2, 2)>
%5 = x # EncryptedTensor<int32, shape=(2, 2)>
%6 = [[1 2] [3 4]] # ClearTensor<uint3, shape=(2, 2)>
%7 = add(%5, %6) # EncryptedTensor<int32, shape=(2, 2)>
%8 = add(%4, %7) # EncryptedTensor<int32, shape=(2, 2)>
%9 = sub(%3, %8) # EncryptedTensor<int32, shape=(2, 2)>
%10 = sub(%9, %2) # EncryptedTensor<int32, shape=(2, 2)>
%11 = mul(%10, %1) # EncryptedTensor<int32, shape=(2, 2)>
%12 = mul(%0, %11) # EncryptedTensor<int32, shape=(2, 2)>
return %12""" # noqa: E501
assert format_operation_graph(op_graph) == expected
@pytest.mark.parametrize(
"x_shape,y_shape",
[
pytest.param((), ()),
pytest.param((3,), ()),
pytest.param((3,), (1,)),
pytest.param((3,), (2,), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((3,), (3,)),
pytest.param((2, 3), ()),
pytest.param((2, 3), (1,)),
pytest.param((2, 3), (2,), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 3), (3,)),
pytest.param((2, 3), (1, 1)),
pytest.param((2, 3), (2, 1)),
pytest.param((2, 3), (3, 1), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 3), (1, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 3), (2, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 3), (3, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 3), (1, 3)),
pytest.param((2, 3), (2, 3)),
pytest.param((2, 3), (3, 3), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 1, 3), (1, 1, 1)),
pytest.param((2, 1, 3), (1, 4, 1)),
pytest.param((2, 1, 3), (2, 4, 3)),
],
)
def test_numpy_tracing_broadcasted_tensors(x_shape, y_shape):
"""Test numpy tracing broadcasted tensors"""
def f(x, y):
return x + y
op_graph = tracing.trace_numpy_function(
f,
{
"x": EncryptedTensor(Integer(3, True), shape=x_shape),
"y": EncryptedTensor(Integer(3, True), shape=y_shape),
},
)
assert op_graph.input_nodes[0].outputs[0].shape == x_shape
assert op_graph.input_nodes[1].outputs[0].shape == y_shape
assert op_graph.output_nodes[0].outputs[0].shape == broadcast_shapes(x_shape, y_shape)
@pytest.mark.parametrize(
"function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples",
[
(
lambda x: x.astype(numpy.int32),
Integer(32, is_signed=True),
[
(14, numpy.int32(14)),
(1.5, numpy.int32(1)),
(2.0, numpy.int32(2)),
(-1.5, numpy.int32(-1)),
(2 ** 31 - 1, numpy.int32(2 ** 31 - 1)),
(-(2 ** 31), numpy.int32(-(2 ** 31))),
],
),
(
lambda x: x.astype(numpy.uint32),
Integer(32, is_signed=False),
[
(14, numpy.uint32(14)),
(1.5, numpy.uint32(1)),
(2.0, numpy.uint32(2)),
(2 ** 32 - 1, numpy.uint32(2 ** 32 - 1)),
],
),
(
lambda x: x.astype(numpy.int64),
Integer(64, is_signed=True),
[
(14, numpy.int64(14)),
(1.5, numpy.int64(1)),
(2.0, numpy.int64(2)),
(-1.5, numpy.int64(-1)),
(2 ** 63 - 1, numpy.int64(2 ** 63 - 1)),
(-(2 ** 63), numpy.int64(-(2 ** 63))),
],
),
(
lambda x: x.astype(numpy.uint64),
Integer(64, is_signed=False),
[
(14, numpy.uint64(14)),
(1.5, numpy.uint64(1)),
(2.0, numpy.uint64(2)),
(2 ** 64 - 1, numpy.uint64(2 ** 64 - 1)),
],
),
(
lambda x: x.astype(numpy.float64),
Float(64),
[
(14, numpy.float64(14.0)),
(1.5, numpy.float64(1.5)),
(2.0, numpy.float64(2.0)),
(-1.5, numpy.float64(-1.5)),
],
),
(
lambda x: x.astype(numpy.float32),
Float(32),
[
(14, numpy.float32(14.0)),
(1.5, numpy.float32(1.5)),
(2.0, numpy.float32(2.0)),
(-1.5, numpy.float32(-1.5)),
],
),
],
)
def test_tracing_astype(
function_to_trace, op_graph_expected_output_type, input_and_expected_output_tuples
):
"""Test function for NPTracer.astype"""
for input_, expected_output in input_and_expected_output_tuples:
input_value = (
EncryptedScalar(Integer(64, is_signed=True))
if isinstance(input_, int)
else EncryptedScalar(Float(64))
)
op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
output_node = op_graph.output_nodes[0]
assert op_graph_expected_output_type == output_node.outputs[0].dtype
node_results = op_graph.evaluate({0: numpy.array(input_)})
evaluated_output = node_results[output_node]
assert evaluated_output.dtype == expected_output.dtype
assert expected_output == evaluated_output
def test_tracing_astype_single_element_array_corner_case(check_array_equality):
    """Test corner case where an array could be transformed to its scalar element."""
    # A (1,)-shaped array must stay an array through astype, not collapse to a scalar.
    a = numpy.array([1], dtype=numpy.float64)
    op_graph = tracing.trace_numpy_function(
        lambda x: x.astype(numpy.int32), {"x": EncryptedTensor(Float(64), (1,))}
    )
    eval_result = op_graph(a)
    check_array_equality(eval_result, numpy.array([1], dtype=numpy.int32))
@pytest.mark.parametrize(
"function_to_trace,inputs,expected_output_node,expected_output_value",
[
pytest.param(
lambda x, y: numpy.dot(x, y),
{
"x": EncryptedTensor(Integer(7, is_signed=False), shape=(10,)),
"y": EncryptedTensor(Integer(7, is_signed=False), shape=(10,)),
},
ir.Dot,
EncryptedScalar(Integer(32, False)),
),
pytest.param(
lambda x, y: numpy.dot(x, y),
{
"x": EncryptedTensor(Float(64), shape=(10,)),
"y": EncryptedTensor(Float(64), shape=(10,)),
},
ir.Dot,
EncryptedScalar(Float(64)),
),
pytest.param(
lambda x, y: numpy.dot(x, y),
{
"x": ClearTensor(Integer(64, is_signed=True), shape=(6,)),
"y": ClearTensor(Integer(64, is_signed=True), shape=(6,)),
},
ir.Dot,
ClearScalar(Integer(64, is_signed=True)),
),
pytest.param(
lambda x: numpy.dot(x, numpy.array([1, 2, 3, 4, 5], dtype=numpy.int64)),
{
"x": EncryptedTensor(Integer(64, is_signed=True), shape=(5,)),
},
ir.Dot,
EncryptedScalar(Integer(64, True)),
),
pytest.param(
lambda x: x.dot(numpy.array([1, 2, 3, 4, 5], dtype=numpy.int64)),
{
"x": EncryptedTensor(Integer(64, is_signed=True), shape=(5,)),
},
ir.Dot,
EncryptedScalar(Integer(64, True)),
),
],
)
def test_trace_numpy_dot(function_to_trace, inputs, expected_output_node, expected_output_value):
"""Function to test dot tracing"""
op_graph = tracing.trace_numpy_function(function_to_trace, inputs)
assert len(op_graph.output_nodes) == 1
assert isinstance(op_graph.output_nodes[0], expected_output_node)
assert len(op_graph.output_nodes[0].outputs) == 1
assert op_graph.output_nodes[0].outputs[0] == expected_output_value
@pytest.mark.parametrize("np_function", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC)
def test_nptracer_get_tracing_func_for_np_functions(np_function):
    """Test NPTracer get_tracing_func_for_np_function.

    Every supported ufunc must resolve to its entry in the UFUNC_ROUTING table.
    """
    expected_tracing_func = tracing.NPTracer.UFUNC_ROUTING[np_function]
    assert tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func
def test_nptracer_get_tracing_func_for_np_functions_not_implemented():
    """Check NPTracer in case of not-implemented function."""
    # numpy.conjugate is deliberately outside the supported ufunc list.
    with pytest.raises(NotImplementedError) as excinfo:
        tracing.NPTracer.get_tracing_func_for_np_function(numpy.conjugate)
    assert "NPTracer does not yet manage the following func: conjugate" in str(excinfo.value)
@pytest.mark.parametrize(
"operation,exception_type,match",
[
pytest.param(
lambda x: x + "fail",
TypeError,
"unsupported operand type(s) for +: 'NPTracer' and 'str'",
),
pytest.param(
lambda x: "fail" + x,
TypeError,
'can only concatenate str (not "NPTracer") to str',
),
pytest.param(
lambda x: x - "fail",
TypeError,
"unsupported operand type(s) for -: 'NPTracer' and 'str'",
),
pytest.param(
lambda x: "fail" - x,
TypeError,
"unsupported operand type(s) for -: 'str' and 'NPTracer'",
),
pytest.param(
lambda x: x * "fail",
TypeError,
"can't multiply sequence by non-int of type 'NPTracer'",
),
pytest.param(
lambda x: "fail" * x,
TypeError,
"can't multiply sequence by non-int of type 'NPTracer'",
),
pytest.param(
lambda x: x / "fail",
TypeError,
"unsupported operand type(s) for /: 'NPTracer' and 'str'",
),
pytest.param(
lambda x: "fail" / x,
TypeError,
"unsupported operand type(s) for /: 'str' and 'NPTracer'",
),
pytest.param(
lambda x: x // "fail",
TypeError,
"unsupported operand type(s) for //: 'NPTracer' and 'str'",
),
pytest.param(
lambda x: "fail" // x,
TypeError,
"unsupported operand type(s) for //: 'str' and 'NPTracer'",
),
pytest.param(
lambda x, y: x / y, NotImplementedError, "Can't manage binary operator truediv"
),
pytest.param(
lambda x, y: x // y, NotImplementedError, "Can't manage binary operator floordiv"
),
],
)
def test_nptracer_unsupported_operands(operation, exception_type, match):
"""Test cases where NPTracer cannot be used with other operands."""
tracers = [
tracing.NPTracer([], ir.Input(ClearScalar(Integer(32, True)), param_name, idx), 0)
for idx, param_name in enumerate(inspect.signature(operation).parameters.keys())
]
with pytest.raises(exception_type) as exc_info:
_ = operation(*tracers)
assert match in str(exc_info)
def subtest_tracing_calls(
    function_to_trace,
    input_value_input_and_expected_output_tuples,
    check_array_equality,
):
    """Trace ``function_to_trace`` for each case and check the evaluated output.

    Each case is a (input value description, concrete input, expected output)
    tuple; the traced graph is evaluated on the concrete input and the result
    is compared, by type and by content, against the expected output.
    """
    for value_desc, concrete_input, expected in input_value_input_and_expected_output_tuples:
        graph = tracing.trace_numpy_function(function_to_trace, {"x": value_desc})
        result_node = graph.output_nodes[0]
        results = graph.evaluate({0: concrete_input})
        actual = results[result_node]
        # The result type must match exactly (e.g. ndarray vs numpy scalar).
        assert isinstance(actual, type(expected)), type(actual)
        check_array_equality(actual, expected)
@pytest.mark.parametrize(
"function_to_trace,input_value_input_and_expected_output_tuples",
[
(
lambda x: numpy.transpose(x),
[
(
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
numpy.arange(4).reshape(2, 2),
numpy.array([[0, 2], [1, 3]]),
),
(
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
numpy.arange(4, 8).reshape(2, 2),
numpy.array([[4, 6], [5, 7]]),
),
(
EncryptedTensor(Integer(6, is_signed=False), shape=()),
numpy.int64(42),
numpy.int64(42),
),
],
),
(
lambda x: numpy.transpose(x) + 42,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(42, 57).reshape(3, 5).transpose(),
),
(
EncryptedTensor(Integer(6, is_signed=False), shape=()),
numpy.int64(42),
numpy.int64(84),
),
],
),
(
lambda x: numpy.ravel(x),
[
(
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
numpy.arange(4),
numpy.array([0, 1, 2, 3]),
),
(
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
numpy.arange(4).reshape(2, 2),
numpy.array([0, 1, 2, 3]),
),
(
EncryptedTensor(Integer(6, is_signed=False), shape=()),
numpy.int64(42),
numpy.array([42], dtype=numpy.int64),
),
],
),
(
lambda x: numpy.reshape(x, (5, 3)) + 42,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(42, 57).reshape(5, 3),
),
],
),
],
)
def test_tracing_numpy_calls(
function_to_trace,
input_value_input_and_expected_output_tuples,
check_array_equality,
):
"""Test memory function managed by GenericFunction node of the form numpy.something"""
subtest_tracing_calls(
function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality
)
@pytest.mark.parametrize(
"function_to_trace,input_value_input_and_expected_output_tuples",
[
(
lambda x: x.transpose() + 42,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(42, 57).reshape(3, 5).transpose(),
),
(
EncryptedTensor(Integer(6, is_signed=False), shape=()),
numpy.int64(42),
numpy.int64(84),
),
],
),
(
lambda x: x.ravel(),
[
(
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
numpy.arange(4),
numpy.array([0, 1, 2, 3]),
),
(
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
numpy.arange(4).reshape(2, 2),
numpy.array([0, 1, 2, 3]),
),
(
EncryptedTensor(Integer(6, is_signed=False), shape=()),
numpy.int64(42),
numpy.array([42], dtype=numpy.int64),
),
],
),
(
lambda x: x.reshape((5, 3)) + 42,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(42, 57).reshape(5, 3),
),
],
),
pytest.param(
lambda x: x.reshape((5, 3)),
[
(
EncryptedTensor(Integer(6, is_signed=False), shape=()),
numpy.int64(42),
None,
)
],
marks=pytest.mark.xfail(strict=True, raises=ValueError),
),
pytest.param(
lambda x: x.flatten(),
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15),
)
],
),
pytest.param(
lambda x: abs(x),
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5),
)
],
),
pytest.param(
lambda x: +x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5),
)
],
),
pytest.param(
lambda x: -x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
(numpy.arange(15).reshape(3, 5)) * (-1),
)
],
),
pytest.param(
lambda x: ~x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5).__invert__(),
)
],
),
pytest.param(
lambda x: x << 3,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5) * 8,
)
],
),
pytest.param(
lambda x: x >> 1,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5) // 2,
)
],
),
pytest.param(
lambda x: 2 << x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5) % 8,
2 << (numpy.arange(15).reshape(3, 5) % 8),
)
],
),
pytest.param(
lambda x: 256 >> x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5) % 8,
256 >> (numpy.arange(15).reshape(3, 5) % 8),
)
],
),
pytest.param(
lambda x: x > 4,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5) > 4,
)
],
),
pytest.param(
lambda x: x < 5,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5) < 5,
)
],
),
pytest.param(
lambda x: x <= 7,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5) <= 7,
)
],
),
pytest.param(
lambda x: x >= 9,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5) >= 9,
)
],
),
pytest.param(
lambda x: x == 11,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5) == 11,
)
],
),
pytest.param(
lambda x: x != 12,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5) != 12,
)
],
),
# Remove misplaced-comparison-constant because precisely, we want to be sure it works fine
# pylint: disable=misplaced-comparison-constant
pytest.param(
lambda x: 4 > x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
4 > numpy.arange(15).reshape(3, 5),
)
],
),
pytest.param(
lambda x: 5 < x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
5 < numpy.arange(15).reshape(3, 5),
)
],
),
pytest.param(
lambda x: 7 <= x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
7 <= numpy.arange(15).reshape(3, 5),
)
],
),
pytest.param(
lambda x: 9 >= x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
9 >= numpy.arange(15).reshape(3, 5),
)
],
),
pytest.param(
lambda x: 11 == x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
11 == numpy.arange(15).reshape(3, 5),
)
],
),
pytest.param(
lambda x: 12 != x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
12 != numpy.arange(15).reshape(3, 5),
)
],
),
# pylint: enable=misplaced-comparison-constant
(
lambda x: x & 11,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([i & 11 for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: 13 & x,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([i & 13 for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: x | 6,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([i | 6 for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: 30 | x,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([i | 30 for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: x ^ 91,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([i ^ 91 for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: 115 ^ x,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([i ^ 115 for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: x % 11,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([i % 11 for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: 150 % (x + 1),
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([150 % (i + 1) for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: x ** 2,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.array([i ** 2 for i in range(15)]).reshape(3, 5),
),
],
),
(
lambda x: 2 ** x,
[
(
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
numpy.arange(15).reshape(3, 5) % 7,
numpy.array([2 ** (i % 7) for i in range(15)]).reshape(3, 5),
),
],
),
],
)
def test_tracing_ndarray_calls(
    function_to_trace,
    input_value_input_and_expected_output_tuples,
    check_array_equality,
):
    """Check GenericFunction tracing of ``ndarray.something`` style method calls.

    Delegates the per-case tracing and evaluation to ``subtest_tracing_calls``.
    """
    subtest_tracing_calls(
        function_to_trace,
        input_value_input_and_expected_output_tuples,
        check_array_equality,
    )
@pytest.mark.parametrize(
    "lambda_f,params",
    [
        (
            lambda x: numpy.reshape(x, (5, 3)),
            {"x": EncryptedTensor(Integer(2, is_signed=False), shape=(7, 5))},
        ),
    ],
)
def test_errors_with_generic_function(lambda_f, params):
    """Check that tracing a shape-incompatible reshape raises a helpful ValueError."""
    expected_message = "shapes are not compatible (old shape (7, 5), new shape (5, 3))"
    with pytest.raises(ValueError) as exc_info:
        tracing.trace_numpy_function(lambda_f, params)
    assert expected_message in str(exc_info.value)

View File

@@ -1,309 +0,0 @@
"""Test file for numpy tracing"""
from copy import deepcopy
import numpy
import pytest
from concrete.common.data_types.floats import Float
from concrete.common.data_types.integers import Integer
from concrete.common.representation import intermediate as ir
from concrete.common.values import EncryptedScalar, EncryptedTensor
from concrete.numpy import tracing
OPERATIONS_TO_TEST = [ir.Add, ir.Sub, ir.Mul]
# Functions from tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC, whose output
# is a float64, whatever the input type.
# Note: a set literal is used instead of the redundant set([...]) wrapper
# (flake8-comprehensions C405); membership and contents are unchanged.
LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64 = {
    numpy.arccos,
    numpy.arccosh,
    numpy.arcsin,
    numpy.arcsinh,
    numpy.arctan,
    numpy.arctanh,
    numpy.cbrt,
    numpy.ceil,
    numpy.cos,
    numpy.cosh,
    numpy.deg2rad,
    numpy.degrees,
    numpy.exp,
    numpy.exp2,
    numpy.expm1,
    numpy.fabs,
    numpy.floor,
    numpy.log,
    numpy.log10,
    numpy.log1p,
    numpy.log2,
    numpy.rad2deg,
    numpy.radians,
    numpy.rint,
    numpy.sin,
    numpy.sinh,
    numpy.spacing,
    numpy.sqrt,
    numpy.tan,
    numpy.tanh,
    numpy.trunc,
}
# Functions from tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC, whose output
# is a boolean, whatever the input type.
# Note: a set literal is used instead of the redundant set([...]) wrapper
# (flake8-comprehensions C405); membership and contents are unchanged.
LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL = {
    numpy.isfinite,
    numpy.isinf,
    numpy.isnan,
    numpy.signbit,
    numpy.logical_not,
}
@pytest.mark.parametrize(
    "inputs,expected_output_node",
    [
        pytest.param(
            {"x": EncryptedScalar(Integer(7, is_signed=False))},
            ir.GenericFunction,
        ),
        pytest.param(
            {"x": EncryptedScalar(Integer(32, is_signed=True))},
            ir.GenericFunction,
        ),
        pytest.param(
            {"x": EncryptedScalar(Integer(64, is_signed=True))},
            ir.GenericFunction,
        ),
        pytest.param(
            {"x": EncryptedScalar(Integer(128, is_signed=True))},
            ir.GenericFunction,
            marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
        ),
        pytest.param(
            {"x": EncryptedScalar(Float(64))},
            ir.GenericFunction,
        ),
    ],
)
@pytest.mark.parametrize(
    "function_to_trace_def",
    [f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1],
)
def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def):
    """Trace each supported unary numpy ufunc and check the resulting output value."""
    # A lambda is required here: numpy ufuncs do not cooperate with
    # inspect.signature, which the tracer relies on; pylint/flake8 dislike it.
    # pylint: disable=cell-var-from-loop
    function_to_trace = lambda x: function_to_trace_def(x)  # noqa: E731
    # pylint: enable=cell-var-from-loop
    op_graph = tracing.trace_numpy_function(function_to_trace, inputs)

    assert len(op_graph.output_nodes) == 1
    output_node = op_graph.output_nodes[0]
    assert isinstance(output_node, expected_output_node)
    assert len(output_node.outputs) == 1

    if function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64:
        # These ufuncs always produce float64, whatever the input type.
        assert output_node.outputs[0] == EncryptedScalar(Float(64))
    elif function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL:
        # Boolean outputs are represented as unsigned 8-bit integers.
        assert output_node.outputs[0] == EncryptedScalar(Integer(8, is_signed=False))
    else:
        # Remaining ufuncs keep (more or less) the input type, widened to >= 32 bits.
        expected_value = deepcopy(inputs["x"])
        expected_value.dtype.bit_width = max(expected_value.dtype.bit_width, 32)
        assert output_node.outputs[0] == expected_value
@pytest.mark.parametrize("np_function", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC)
def test_nptracer_get_tracing_func_for_np_functions(np_function):
    """Check that get_tracing_func_for_np_function agrees with the UFUNC_ROUTING table."""
    resolved_func = tracing.NPTracer.get_tracing_func_for_np_function(np_function)
    assert resolved_func == tracing.NPTracer.UFUNC_ROUTING[np_function]
def subtest_tracing_calls(
    function_to_trace,
    input_value_input_and_expected_output_tuples,
    check_array_equality,
):
    """Trace ``function_to_trace`` once per case and compare graph evaluation results.

    Each case is an ``(input_value, input_, expected_output)`` triple: the value
    description used for tracing, the concrete input fed to the traced graph,
    and the output the graph is expected to produce.
    """
    for value_description, concrete_input, expected in input_value_input_and_expected_output_tuples:
        graph = tracing.trace_numpy_function(function_to_trace, {"x": value_description})
        result = graph.evaluate({0: concrete_input})[graph.output_nodes[0]]
        # Type must match exactly before comparing contents.
        assert isinstance(result, type(expected)), type(result)
        check_array_equality(result, expected)
# Each case below is (traced value spec, concrete input, expected output);
# shape-() entries exercise the numpy scalar edge case.
@pytest.mark.parametrize(
    "function_to_trace,input_value_input_and_expected_output_tuples",
    [
        (
            # numpy.transpose on 2x2 tensors, plus a scalar passthrough
            lambda x: numpy.transpose(x),
            [
                (
                    EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
                    numpy.arange(4).reshape(2, 2),
                    numpy.array([[0, 2], [1, 3]]),
                ),
                (
                    EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
                    numpy.arange(4, 8).reshape(2, 2),
                    numpy.array([[4, 6], [5, 7]]),
                ),
                (
                    EncryptedTensor(Integer(6, is_signed=False), shape=()),
                    numpy.int64(42),
                    numpy.int64(42),
                ),
            ],
        ),
        (
            # transpose combined with a constant addition
            lambda x: numpy.transpose(x) + 42,
            [
                (
                    EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
                    numpy.arange(15).reshape(3, 5),
                    numpy.arange(42, 57).reshape(3, 5).transpose(),
                ),
                (
                    EncryptedTensor(Integer(6, is_signed=False), shape=()),
                    numpy.int64(42),
                    numpy.int64(84),
                ),
            ],
        ),
        (
            # numpy.ravel; note a scalar input ravels to a 1-element array
            lambda x: numpy.ravel(x),
            [
                (
                    EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
                    numpy.arange(4),
                    numpy.array([0, 1, 2, 3]),
                ),
                (
                    EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
                    numpy.arange(4).reshape(2, 2),
                    numpy.array([0, 1, 2, 3]),
                ),
                (
                    EncryptedTensor(Integer(6, is_signed=False), shape=()),
                    numpy.int64(42),
                    numpy.array([42], dtype=numpy.int64),
                ),
            ],
        ),
        (
            # numpy.reshape to a compatible shape, combined with an addition
            lambda x: numpy.reshape(x, (5, 3)) + 42,
            [
                (
                    EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
                    numpy.arange(15).reshape(3, 5),
                    numpy.arange(42, 57).reshape(5, 3),
                ),
            ],
        ),
    ],
)
def test_tracing_numpy_calls(
    function_to_trace,
    input_value_input_and_expected_output_tuples,
    check_array_equality,
):
    """Test memory function managed by GenericFunction node of the form numpy.something"""
    subtest_tracing_calls(
        function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality
    )
# Same cases as the numpy.* variants above, but written as ndarray method
# calls (x.transpose(), x.ravel(), x.reshape(...)).
@pytest.mark.parametrize(
    "function_to_trace,input_value_input_and_expected_output_tuples",
    [
        (
            lambda x: x.transpose() + 42,
            [
                (
                    EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
                    numpy.arange(15).reshape(3, 5),
                    numpy.arange(42, 57).reshape(3, 5).transpose(),
                ),
                (
                    EncryptedTensor(Integer(6, is_signed=False), shape=()),
                    numpy.int64(42),
                    numpy.int64(84),
                ),
            ],
        ),
        (
            lambda x: x.ravel(),
            [
                (
                    EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
                    numpy.arange(4),
                    numpy.array([0, 1, 2, 3]),
                ),
                (
                    EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
                    numpy.arange(4).reshape(2, 2),
                    numpy.array([0, 1, 2, 3]),
                ),
                (
                    EncryptedTensor(Integer(6, is_signed=False), shape=()),
                    numpy.int64(42),
                    numpy.array([42], dtype=numpy.int64),
                ),
            ],
        ),
        (
            lambda x: x.reshape((5, 3)) + 42,
            [
                (
                    EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
                    numpy.arange(15).reshape(3, 5),
                    numpy.arange(42, 57).reshape(5, 3),
                ),
            ],
        ),
        # Reshaping a scalar (shape ()) to (5, 3) is invalid: strict xfail
        # expecting ValueError, so the expected output is None.
        pytest.param(
            lambda x: x.reshape((5, 3)),
            [
                (
                    EncryptedTensor(Integer(6, is_signed=False), shape=()),
                    numpy.int64(42),
                    None,
                )
            ],
            marks=pytest.mark.xfail(strict=True, raises=ValueError),
        ),
    ],
)
def test_tracing_ndarray_calls(
    function_to_trace,
    input_value_input_and_expected_output_tuples,
    check_array_equality,
):
    """Test memory function managed by GenericFunction node of the form ndarray.something"""
    subtest_tracing_calls(
        function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality
    )

View File

@@ -1,184 +0,0 @@
"""Test file for numpy tracing"""
import inspect
import numpy
import pytest
from concrete.common.data_types.integers import Integer
from concrete.common.representation import intermediate as ir
from concrete.common.values import ClearScalar, EncryptedScalar, EncryptedTensor
from concrete.numpy import tracing
@pytest.mark.parametrize(
    "inputs",
    [
        pytest.param(
            {"x": EncryptedScalar(Integer(32, is_signed=True))},
        ),
    ],
)
@pytest.mark.parametrize(
    "function_to_trace",
    # A lambda is required here: numpy functions do not cooperate with
    # inspect.signature, which the tracer relies on.
    [lambda x: numpy.invert(x), lambda x: numpy.bitwise_not(x)],
)
def test_trace_numpy_fails_for_invert(inputs, function_to_trace):
    """Ensure numpy.invert/bitwise_not are rejected with an actionable message."""
    with pytest.raises(RuntimeError) as exc_info:
        tracing.trace_numpy_function(function_to_trace, inputs)
    expected_message = (
        "NPTracer does not manage the following func: invert. Please replace by calls to "
        "bitwise_xor with appropriate mask"
    )
    assert expected_message in str(exc_info.value)
def test_trace_numpy_ufuncs_not_supported():
    """Check that non-__call__ ufunc methods (e.g. numpy.add.reduce) are rejected."""
    scalar_inputs = {"x": EncryptedScalar(Integer(128, is_signed=True))}
    # A lambda is required here: numpy functions do not cooperate with
    # inspect.signature, which the tracer relies on; flake8 dislikes it.
    reduce_call = lambda x: numpy.add.reduce(x)  # noqa: E731
    with pytest.raises(NotImplementedError) as exc_info:
        tracing.trace_numpy_function(reduce_call, scalar_inputs)
    assert "Only __call__ method is supported currently" in str(exc_info.value)
def test_trace_numpy_ufuncs_no_kwargs_no_extra_args():
    """Check that extra positional args and kwargs to a traced ufunc are rejected."""
    inputs = {
        name: EncryptedScalar(Integer(32, is_signed=True)) for name in ("x", "y", "z")
    }
    expected = "**kwargs are currently not supported for numpy ufuncs, ufunc: add"

    # Lambdas are required here: numpy functions do not cooperate with
    # inspect.signature, which the tracer relies on; flake8 dislikes them.

    # numpy only passes ufunc.nin tracers positionally, so the extra argument
    # is forwarded as a kwarg and triggers the same assertion.
    extra_positional = lambda x, y, z: numpy.add(x, y, z)  # noqa: E731
    with pytest.raises(AssertionError) as exc_info:
        tracing.trace_numpy_function(extra_positional, inputs)
    assert expected in str(exc_info.value)

    explicit_kwarg = lambda x, y, z: numpy.add(x, y, out=z)  # noqa: E731
    with pytest.raises(AssertionError) as exc_info:
        tracing.trace_numpy_function(explicit_kwarg, inputs)
    assert expected in str(exc_info.value)
def test_nptracer_get_tracing_func_for_np_functions_not_implemented():
    """Check the error raised when asking NPTracer for an unmanaged function."""
    with pytest.raises(NotImplementedError) as exc_info:
        tracing.NPTracer.get_tracing_func_for_np_function(numpy.conjugate)
    assert "NPTracer does not yet manage the following func: conjugate" in str(exc_info.value)
# Mixing an NPTracer with a str raises Python's own TypeError for each
# arithmetic operator; true/floor division between two tracers is explicitly
# unsupported by the tracer itself (NotImplementedError).
@pytest.mark.parametrize(
    "operation,exception_type,match",
    [
        pytest.param(
            lambda x: x + "fail",
            TypeError,
            "unsupported operand type(s) for +: 'NPTracer' and 'str'",
        ),
        pytest.param(
            lambda x: "fail" + x,
            TypeError,
            'can only concatenate str (not "NPTracer") to str',
        ),
        pytest.param(
            lambda x: x - "fail",
            TypeError,
            "unsupported operand type(s) for -: 'NPTracer' and 'str'",
        ),
        pytest.param(
            lambda x: "fail" - x,
            TypeError,
            "unsupported operand type(s) for -: 'str' and 'NPTracer'",
        ),
        pytest.param(
            lambda x: x * "fail",
            TypeError,
            "can't multiply sequence by non-int of type 'NPTracer'",
        ),
        pytest.param(
            lambda x: "fail" * x,
            TypeError,
            "can't multiply sequence by non-int of type 'NPTracer'",
        ),
        pytest.param(
            lambda x: x / "fail",
            TypeError,
            "unsupported operand type(s) for /: 'NPTracer' and 'str'",
        ),
        pytest.param(
            lambda x: "fail" / x,
            TypeError,
            "unsupported operand type(s) for /: 'str' and 'NPTracer'",
        ),
        pytest.param(
            lambda x: x // "fail",
            TypeError,
            "unsupported operand type(s) for //: 'NPTracer' and 'str'",
        ),
        pytest.param(
            lambda x: "fail" // x,
            TypeError,
            "unsupported operand type(s) for //: 'str' and 'NPTracer'",
        ),
        pytest.param(
            lambda x, y: x / y, NotImplementedError, "Can't manage binary operator truediv"
        ),
        pytest.param(
            lambda x, y: x // y, NotImplementedError, "Can't manage binary operator floordiv"
        ),
    ],
)
def test_nptracer_unsupported_operands(operation, exception_type, match):
    """Test cases where NPTracer cannot be used with other operands."""
    # Build one input tracer per parameter of the operation under test.
    tracers = [
        tracing.NPTracer([], ir.Input(ClearScalar(Integer(32, True)), param_name, idx), 0)
        for idx, param_name in enumerate(inspect.signature(operation).parameters.keys())
    ]
    with pytest.raises(exception_type) as exc_info:
        _ = operation(*tracers)
    # NOTE(review): matches against str(exc_info) (the ExceptionInfo repr),
    # not str(exc_info.value); the repr does include the message text.
    assert match in str(exc_info)
@pytest.mark.parametrize(
    "lambda_f,params",
    [
        (
            lambda x: numpy.reshape(x, (5, 3)),
            {"x": EncryptedTensor(Integer(2, is_signed=False), shape=(7, 5))},
        ),
    ],
)
def test_errors_with_generic_function(lambda_f, params):
    """Check that tracing a shape-incompatible reshape raises a helpful ValueError."""
    expected_message = "shapes are not compatible (old shape (7, 5), new shape (5, 3))"
    with pytest.raises(ValueError) as exc_info:
        tracing.trace_numpy_function(lambda_f, params)
    assert expected_message in str(exc_info.value)