mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
chore: remove the old implementation and its tests
This commit is contained in:
@@ -1,5 +0,0 @@
|
||||
"""Top level import."""
|
||||
# Do not modify, this is to have a compatible namespace package
|
||||
# https://packaging.python.org/en/latest/guides/packaging-namespace-packages/
|
||||
# #pkg-resources-style-namespace-packages
|
||||
__import__("pkg_resources").declare_namespace(__name__) # pragma: no cover
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Module for shared data structures and code."""
|
||||
from . import compilation, data_types, debugging, representation, values
|
||||
from .common_helpers import check_op_graph_is_integer_program, is_a_power_of_2
|
||||
@@ -1,2 +0,0 @@
|
||||
"""Bounds measurement module."""
|
||||
from . import inputset_eval
|
||||
@@ -1,260 +0,0 @@
|
||||
"""Code to evaluate the IR graph on inputsets."""
|
||||
|
||||
import sys
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
|
||||
|
||||
from ..compilation import CompilationConfiguration
|
||||
from ..data_types.dtypes_helpers import (
|
||||
get_base_value_for_python_constant_data,
|
||||
is_data_type_compatible_with,
|
||||
)
|
||||
from ..debugging import assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import IntermediateNode
|
||||
|
||||
|
||||
def _check_input_coherency(
|
||||
input_to_check: Dict[str, Any],
|
||||
parameters: Dict[str, Any],
|
||||
get_base_value_for_constant_data_func: Callable[[Any], Any],
|
||||
):
|
||||
"""Check whether `input_to_check` is coherent with `parameters`.
|
||||
|
||||
This function works by iterating over each constant of the input,
|
||||
determining base value of the constant using `get_base_value_for_constant_data_func` and
|
||||
checking if the base value of the contant is compatible with the base value of the parameter.
|
||||
|
||||
Args:
|
||||
input_to_check (Dict[str, Any]): input to check coherency of
|
||||
parameters (Dict[str, Any]): parameters and their expected base values
|
||||
get_base_value_for_constant_data_func (Callable[[Any], Any]):
|
||||
function to get the base value of python objects.
|
||||
|
||||
Returns:
|
||||
List[str]: List of warnings about the coherency
|
||||
"""
|
||||
|
||||
warnings = []
|
||||
for parameter_name, value in input_to_check.items():
|
||||
parameter_base_value = parameters[parameter_name]
|
||||
|
||||
base_value_class = get_base_value_for_constant_data_func(value)
|
||||
base_value = base_value_class(is_encrypted=parameter_base_value.is_encrypted)
|
||||
|
||||
if base_value.shape != parameter_base_value.shape or not is_data_type_compatible_with(
|
||||
base_value.dtype, parameter_base_value.dtype
|
||||
):
|
||||
warnings.append(
|
||||
f"expected {str(parameter_base_value)} "
|
||||
f"for parameter `{parameter_name}` "
|
||||
f"but got {str(base_value)} "
|
||||
f"which is not compatible"
|
||||
)
|
||||
return warnings
|
||||
|
||||
|
||||
def _print_input_coherency_warnings(
|
||||
current_input_index: int,
|
||||
current_input_data: Dict[int, Any],
|
||||
parameters: Dict[str, Any],
|
||||
parameter_index_to_parameter_name: Dict[int, str],
|
||||
get_base_value_for_constant_data_func: Callable[[Any], Any],
|
||||
treat_warnings_as_errors: bool,
|
||||
):
|
||||
"""Print coherency warning for `input_to_check` against `parameters`.
|
||||
|
||||
Args:
|
||||
current_input_index (int): index of the current input on the inputset
|
||||
current_input_data (Dict[int, Any]): input to print coherency warnings of
|
||||
parameters (Dict[str, Any]): parameters and their expected base values
|
||||
parameter_index_to_parameter_name (Dict[int, str]):
|
||||
dict to get parameter names from parameter indices
|
||||
get_base_value_for_constant_data_func (Callable[[Any], Any]):
|
||||
function to get the base value of python objects.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
current_input_named_data = {
|
||||
parameter_index_to_parameter_name[index]: data for index, data in current_input_data.items()
|
||||
}
|
||||
|
||||
problems = _check_input_coherency(
|
||||
current_input_named_data,
|
||||
parameters,
|
||||
get_base_value_for_constant_data_func,
|
||||
)
|
||||
messages = [
|
||||
(
|
||||
f"Input #{current_input_index} (0-indexed) "
|
||||
f"is not coherent with the hinted parameters ({problem})\n"
|
||||
)
|
||||
for problem in problems
|
||||
]
|
||||
|
||||
if len(messages) > 0:
|
||||
if treat_warnings_as_errors:
|
||||
raise ValueError(", ".join(messages))
|
||||
|
||||
for message in messages:
|
||||
sys.stderr.write(f"Warning: {message}")
|
||||
|
||||
|
||||
def eval_op_graph_bounds_on_inputset(
|
||||
op_graph: OPGraph,
|
||||
inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
min_func: Callable[[Any, Any], Any] = min,
|
||||
max_func: Callable[[Any, Any], Any] = max,
|
||||
get_base_value_for_constant_data_func: Callable[
|
||||
[Any], Any
|
||||
] = get_base_value_for_python_constant_data,
|
||||
prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None,
|
||||
) -> Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]:
|
||||
"""Evaluate the bounds with a inputset.
|
||||
|
||||
Evaluate the bounds for all output values of the operators in the graph op_graph over data
|
||||
coming from the inputset
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The graph for which we want to determine the bounds
|
||||
inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset over which op_graph
|
||||
is evaluated. It needs to be an iterable on tuples (can be single values in case the
|
||||
function has only one argument) which are of the same length than the number of
|
||||
parameters in the function, and in the same order than these same parameters
|
||||
compilation_configuration (CompilationConfiguration): Configuration object to use
|
||||
during determining input checking strategy
|
||||
min_func (Callable[[Any, Any], Any], optional): custom function to compute a scalar minimum
|
||||
between two values that can be encountered during evaluation (for e.g. numpy or torch
|
||||
tensors). Defaults to min.
|
||||
max_func (Callable[[Any, Any], Any], optional): custom function to compute a scalar maximum
|
||||
between two values that can be encountered during evaluation (for e.g. numpy or torch
|
||||
tensors). Defaults to max.
|
||||
get_base_value_for_constant_data_func (Callable[[Any], Any], optional): custom function
|
||||
to compute the base value of a python object.
|
||||
prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional):
|
||||
Bounds and samples from a previous run. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: number of inputs in the inputset and
|
||||
a dict containing the bounds for each node from op_graph, stored with the node
|
||||
as key and a dict with keys "min", "max" and "sample" as value.
|
||||
"""
|
||||
|
||||
num_input_nodes = len(op_graph.input_nodes)
|
||||
|
||||
def check_inputset_input_len_is_valid(data_to_check):
|
||||
# Only check if there are more than one input node, otherwise accept the value as the sole
|
||||
# argument passed to the OPGraph for evaluation
|
||||
if num_input_nodes > 1:
|
||||
assert_true(
|
||||
len(data_to_check) == num_input_nodes,
|
||||
(
|
||||
f"Got input data from inputset of len: {len(data_to_check)}, "
|
||||
f"function being evaluated has {num_input_nodes} inputs, please make "
|
||||
f"sure your data generator returns valid tuples of input values"
|
||||
),
|
||||
)
|
||||
|
||||
def generate_input_values_dict(input_data) -> Dict[int, Any]:
|
||||
if num_input_nodes > 1:
|
||||
return dict(enumerate(input_data))
|
||||
# TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/772
|
||||
# update this to support tuple in case of 1-input functions accepting tuples
|
||||
assert_true(
|
||||
not isinstance(input_data, tuple),
|
||||
"Tuples are unsupported for single input inputset evaluation",
|
||||
)
|
||||
return {0: input_data}
|
||||
|
||||
# TODO: do we want to check coherence between the input data type and the corresponding Input ir
|
||||
# node expected data type ? Not considering bit_width as they may not make sense at this stage
|
||||
|
||||
parameter_index_to_parameter_name = {
|
||||
index: input_node.input_name for index, input_node in op_graph.input_nodes.items()
|
||||
}
|
||||
parameters = {
|
||||
input_node.input_name: input_node.inputs[0] for input_node in op_graph.input_nodes.values()
|
||||
}
|
||||
|
||||
inputset_iterator = iter(inputset)
|
||||
inputset_size = 0
|
||||
|
||||
current_input_data = generate_input_values_dict(next(inputset_iterator))
|
||||
inputset_size += 1
|
||||
|
||||
check_inputset_input_len_is_valid(current_input_data.values())
|
||||
_print_input_coherency_warnings(
|
||||
inputset_size - 1,
|
||||
current_input_data,
|
||||
parameters,
|
||||
parameter_index_to_parameter_name,
|
||||
get_base_value_for_constant_data_func,
|
||||
compilation_configuration.treat_warnings_as_errors,
|
||||
)
|
||||
|
||||
first_output = op_graph.evaluate(current_input_data)
|
||||
|
||||
prev_node_bounds_and_samples = (
|
||||
{} if prev_node_bounds_and_samples is None else prev_node_bounds_and_samples
|
||||
)
|
||||
|
||||
def get_previous_value_for_key_or_default_for_dict(
|
||||
dict_: Dict[IntermediateNode, Dict[str, Any]],
|
||||
node: IntermediateNode,
|
||||
key: str,
|
||||
default: Any,
|
||||
) -> Any:
|
||||
return_value = default
|
||||
|
||||
previous_value_dict = dict_.get(node, None)
|
||||
|
||||
if previous_value_dict is not None:
|
||||
return_value = previous_value_dict.get(key, default)
|
||||
|
||||
return return_value
|
||||
|
||||
get_previous_value_for_key_or_default = partial(
|
||||
get_previous_value_for_key_or_default_for_dict, prev_node_bounds_and_samples
|
||||
)
|
||||
|
||||
# We evaluate the min and max func to be able to resolve the tensors min and max rather than
|
||||
# having the tensor itself as the stored min and max values.
|
||||
# As we don't know the integrity of prev_node_bounds_and_samples we make sure we can
|
||||
# populate the new node_bounds_and_samples
|
||||
node_bounds_and_samples = {
|
||||
node: {
|
||||
"min": min_func(value, get_previous_value_for_key_or_default(node, "min", value)),
|
||||
"max": max_func(value, get_previous_value_for_key_or_default(node, "max", value)),
|
||||
"sample": get_previous_value_for_key_or_default(node, "sample", value),
|
||||
}
|
||||
for node, value in first_output.items()
|
||||
}
|
||||
|
||||
for input_data in inputset_iterator:
|
||||
inputset_size += 1
|
||||
current_input_data = generate_input_values_dict(input_data)
|
||||
|
||||
check_inputset_input_len_is_valid(current_input_data.values())
|
||||
if compilation_configuration.check_every_input_in_inputset:
|
||||
_print_input_coherency_warnings(
|
||||
inputset_size - 1,
|
||||
current_input_data,
|
||||
parameters,
|
||||
parameter_index_to_parameter_name,
|
||||
get_base_value_for_constant_data_func,
|
||||
compilation_configuration.treat_warnings_as_errors,
|
||||
)
|
||||
|
||||
current_output = op_graph.evaluate(current_input_data)
|
||||
for node, value in current_output.items():
|
||||
node_bounds_and_samples[node]["min"] = min_func(
|
||||
node_bounds_and_samples[node]["min"], value
|
||||
)
|
||||
node_bounds_and_samples[node]["max"] = max_func(
|
||||
node_bounds_and_samples[node]["max"], value
|
||||
)
|
||||
|
||||
return inputset_size, node_bounds_and_samples
|
||||
@@ -1,67 +0,0 @@
|
||||
"""File to hold some helper code."""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from .data_types.integers import Integer
|
||||
from .debugging import assert_true
|
||||
from .operator_graph import OPGraph
|
||||
from .representation.intermediate import IntermediateNode
|
||||
|
||||
|
||||
def is_a_power_of_2(x: int) -> bool:
|
||||
"""Check if an integer is a power of two.
|
||||
|
||||
Args:
|
||||
x (int): Number to check
|
||||
|
||||
Returns:
|
||||
bool: True if the number is a power of two
|
||||
"""
|
||||
# https://stackoverflow.com/questions/57025836/how-to-check-if-a-given-number-is-a-power-of-two
|
||||
|
||||
return x > 0 and (x & (x - 1)) == 0
|
||||
|
||||
|
||||
def ir_nodes_has_integer_input_and_output(node: IntermediateNode) -> bool:
|
||||
"""Check if an ir node has Integer inputs and outputs.
|
||||
|
||||
Args:
|
||||
node (IntermediateNode): Node to check
|
||||
|
||||
Returns:
|
||||
bool: True if all input and output values hold Integers
|
||||
"""
|
||||
return all(isinstance(x.dtype, Integer) for x in node.inputs) and all(
|
||||
isinstance(x.dtype, Integer) for x in node.outputs
|
||||
)
|
||||
|
||||
|
||||
# This check makes sense as long as the compiler backend only manages integers, to be removed in the
|
||||
# long run probably
|
||||
def check_op_graph_is_integer_program(
|
||||
op_graph: OPGraph,
|
||||
offending_nodes_out: Optional[List[IntermediateNode]] = None,
|
||||
) -> bool:
|
||||
"""Check if an op_graph inputs, outputs and intermediate values are Integers.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The OPGraph to check
|
||||
offending_nodes_out (Optional[List[IntermediateNode]]): Optionally pass a list that will
|
||||
be populated with offending nodes, the list will be cleared before being filled
|
||||
|
||||
Returns:
|
||||
bool: True if inputs, outputs and intermediate values are Integers, False otherwise
|
||||
"""
|
||||
offending_nodes = [] if offending_nodes_out is None else offending_nodes_out
|
||||
|
||||
assert_true(
|
||||
isinstance(offending_nodes, list),
|
||||
f"offending_nodes_out must be a list, got {type(offending_nodes_out)}",
|
||||
)
|
||||
|
||||
offending_nodes.clear()
|
||||
offending_nodes.extend(
|
||||
node for node in op_graph.graph.nodes() if not ir_nodes_has_integer_input_and_output(node)
|
||||
)
|
||||
|
||||
return len(offending_nodes) == 0
|
||||
@@ -1,4 +0,0 @@
|
||||
"""Module for compilation related types."""
|
||||
|
||||
from .artifacts import CompilationArtifacts
|
||||
from .configuration import CompilationConfiguration
|
||||
@@ -1,222 +0,0 @@
|
||||
"""Module for compilation artifacts."""
|
||||
|
||||
import inspect
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import networkx as nx
|
||||
from loguru import logger
|
||||
|
||||
from ..debugging import assert_true, draw_graph, format_operation_graph
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import IntermediateNode
|
||||
from ..values import BaseValue
|
||||
|
||||
DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts")
|
||||
|
||||
|
||||
class CompilationArtifacts:
|
||||
"""Class that conveys information about compilation process."""
|
||||
|
||||
output_directory: Path
|
||||
|
||||
source_code_of_the_function_to_compile: Optional[str]
|
||||
parameters_of_the_function_to_compile: Dict[str, str]
|
||||
|
||||
drawings_of_operation_graphs: Dict[str, str]
|
||||
textual_representations_of_operation_graphs: Dict[str, str]
|
||||
|
||||
final_operation_graph: Optional[OPGraph]
|
||||
bounds_of_the_final_operation_graph: Optional[Dict[IntermediateNode, Dict[str, Any]]]
|
||||
mlir_of_the_final_operation_graph: Optional[str]
|
||||
|
||||
def __init__(self, output_directory: Union[Path, str] = DEFAULT_OUTPUT_DIRECTORY):
|
||||
self.output_directory = Path(output_directory)
|
||||
|
||||
self.source_code_of_the_function_to_compile = None
|
||||
self.parameters_of_the_function_to_compile = {}
|
||||
|
||||
self.drawings_of_operation_graphs = {}
|
||||
self.textual_representations_of_operation_graphs = {}
|
||||
|
||||
self.final_operation_graph = None
|
||||
self.bounds_of_the_final_operation_graph = None
|
||||
self.mlir_of_the_final_operation_graph = None
|
||||
|
||||
def add_function_to_compile(self, function: Union[Callable, str]):
|
||||
"""Add the function to compile to artifacts.
|
||||
|
||||
Args:
|
||||
function (Union[Callable, str]): the function to compile or source code of it
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
try:
|
||||
self.source_code_of_the_function_to_compile = (
|
||||
function if isinstance(function, str) else inspect.getsource(function)
|
||||
)
|
||||
# When using the python console we cannot use getsource, so catch that and emit an error
|
||||
except OSError: # pragma: no cover
|
||||
function_str = function if isinstance(function, str) else function.__name__
|
||||
logger.error(f"Could not get source for function: {function_str}")
|
||||
self.source_code_of_the_function_to_compile = "unavailable"
|
||||
|
||||
def add_parameter_of_function_to_compile(self, name: str, value: Union[BaseValue, str]):
|
||||
"""Add a parameter of the function to compile to the artifacts.
|
||||
|
||||
Args:
|
||||
name (str): name of the parameter
|
||||
value (Union[BaseValue, str]): value of the parameter or textual representation of it
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
self.parameters_of_the_function_to_compile[name] = str(value)
|
||||
|
||||
def add_operation_graph(self, name: str, operation_graph: OPGraph):
|
||||
"""Add an operation graph to the artifacts.
|
||||
|
||||
Args:
|
||||
name (str): name of the graph
|
||||
operation_graph (OPGraph): the operation graph itself
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
try:
|
||||
drawing = draw_graph(operation_graph)
|
||||
self.drawings_of_operation_graphs[name] = drawing
|
||||
# Do not crash on imports ourselves for drawings if the package is not installed
|
||||
except ImportError as e: # pragma: no cover
|
||||
if "pygraphviz" in str(e):
|
||||
pass
|
||||
else:
|
||||
raise e
|
||||
textual_representation = format_operation_graph(operation_graph)
|
||||
|
||||
self.textual_representations_of_operation_graphs[name] = textual_representation
|
||||
|
||||
self.final_operation_graph = operation_graph
|
||||
|
||||
def add_final_operation_graph_bounds(self, bounds: Dict[IntermediateNode, Dict[str, Any]]):
|
||||
"""Add the bounds of the final operation graph to the artifacts.
|
||||
|
||||
Args:
|
||||
bounds (Dict[IntermediateNode, Dict[str, Any]]): the bound dictionary
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
assert_true(self.final_operation_graph is not None)
|
||||
self.bounds_of_the_final_operation_graph = bounds
|
||||
|
||||
def add_final_operation_graph_mlir(self, mlir: str):
|
||||
"""Add the mlir of the final operation graph to the artifacts.
|
||||
|
||||
Args:
|
||||
mlir (str): the mlir code of the final operation graph
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
assert_true(self.final_operation_graph is not None)
|
||||
self.mlir_of_the_final_operation_graph = mlir
|
||||
|
||||
def export(self):
|
||||
"""Export the artifacts to a the output directory.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
output_directory = self.output_directory
|
||||
if output_directory.exists():
|
||||
shutil.rmtree(output_directory)
|
||||
output_directory.mkdir(parents=True)
|
||||
|
||||
with open(output_directory.joinpath("environment.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{platform.platform()} {platform.version()}\n")
|
||||
f.write(f"Python {platform.python_version()}\n")
|
||||
|
||||
with open(output_directory.joinpath("requirements.txt"), "w", encoding="utf-8") as f:
|
||||
# example `pip list` output
|
||||
|
||||
# Package Version
|
||||
# ----------------------------- ---------
|
||||
# alabaster 0.7.12
|
||||
# appdirs 1.4.4
|
||||
# ... ...
|
||||
# ... ...
|
||||
# wrapt 1.12.1
|
||||
# zipp 3.5.0
|
||||
|
||||
pip_process = subprocess.run(
|
||||
["pip", "--disable-pip-version-check", "list"], stdout=subprocess.PIPE, check=True
|
||||
)
|
||||
dependencies = iter(pip_process.stdout.decode("utf-8").split("\n"))
|
||||
|
||||
# skip 'Package ... Version' line
|
||||
next(dependencies)
|
||||
|
||||
# skip '------- ... -------' line
|
||||
next(dependencies)
|
||||
|
||||
for dependency in dependencies:
|
||||
tokens = [token for token in dependency.split(" ") if token != ""]
|
||||
if len(tokens) == 0:
|
||||
continue
|
||||
|
||||
name = tokens[0]
|
||||
version = tokens[1]
|
||||
|
||||
f.write(f"{name}=={version}\n")
|
||||
|
||||
if self.source_code_of_the_function_to_compile is not None:
|
||||
with open(output_directory.joinpath("function.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(self.source_code_of_the_function_to_compile)
|
||||
|
||||
if len(self.parameters_of_the_function_to_compile) > 0:
|
||||
with open(output_directory.joinpath("parameters.txt"), "w", encoding="utf-8") as f:
|
||||
for name, parameter in self.parameters_of_the_function_to_compile.items():
|
||||
f.write(f"{name} :: {parameter}\n")
|
||||
|
||||
drawings = self.drawings_of_operation_graphs.items()
|
||||
for index, (name, drawing_filename) in enumerate(drawings):
|
||||
identifier = CompilationArtifacts._identifier(index, name)
|
||||
shutil.copy(drawing_filename, output_directory.joinpath(f"{identifier}.png"))
|
||||
|
||||
textual_representations = self.textual_representations_of_operation_graphs.items()
|
||||
for index, (name, representation) in enumerate(textual_representations):
|
||||
identifier = CompilationArtifacts._identifier(index, name)
|
||||
with open(output_directory.joinpath(f"{identifier}.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{representation}")
|
||||
|
||||
if self.bounds_of_the_final_operation_graph is not None:
|
||||
assert_true(self.final_operation_graph is not None)
|
||||
with open(output_directory.joinpath("bounds.txt"), "w", encoding="utf-8") as f:
|
||||
# TODO:
|
||||
# if nx.topological_sort is not deterministic between calls,
|
||||
# the lines below will not work properly
|
||||
# thus, we may want to change this in the future
|
||||
for index, node in enumerate(nx.topological_sort(self.final_operation_graph.graph)):
|
||||
bounds = self.bounds_of_the_final_operation_graph.get(node)
|
||||
assert_true(bounds is not None)
|
||||
f.write(f"%{index} :: [{bounds.get('min')}, {bounds.get('max')}]\n")
|
||||
|
||||
if self.mlir_of_the_final_operation_graph is not None:
|
||||
assert_true(self.final_operation_graph is not None)
|
||||
with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(self.mlir_of_the_final_operation_graph)
|
||||
|
||||
@staticmethod
|
||||
def _identifier(index, name):
|
||||
return f"{index + 1}.{name}.graph"
|
||||
@@ -1,46 +0,0 @@
|
||||
"""Module for compilation configuration."""
|
||||
|
||||
|
||||
class CompilationConfiguration:
|
||||
"""Class that allows the compilation process to be customized."""
|
||||
|
||||
dump_artifacts_on_unexpected_failures: bool
|
||||
enable_topological_optimizations: bool
|
||||
check_every_input_in_inputset: bool
|
||||
treat_warnings_as_errors: bool
|
||||
enable_unsafe_features: bool
|
||||
random_inputset_samples: int
|
||||
use_insecure_key_cache: bool
|
||||
auto_parallelize: bool
|
||||
loop_parallelize: bool
|
||||
dataflow_parallelize: bool
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
def __init__(
|
||||
self,
|
||||
dump_artifacts_on_unexpected_failures: bool = True,
|
||||
enable_topological_optimizations: bool = True,
|
||||
check_every_input_in_inputset: bool = False,
|
||||
treat_warnings_as_errors: bool = False,
|
||||
enable_unsafe_features: bool = False,
|
||||
random_inputset_samples: int = 30,
|
||||
use_insecure_key_cache: bool = False,
|
||||
auto_parallelize: bool = False,
|
||||
loop_parallelize: bool = True,
|
||||
dataflow_parallelize: bool = False,
|
||||
):
|
||||
self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures
|
||||
self.enable_topological_optimizations = enable_topological_optimizations
|
||||
self.check_every_input_in_inputset = check_every_input_in_inputset
|
||||
self.treat_warnings_as_errors = treat_warnings_as_errors
|
||||
self.enable_unsafe_features = enable_unsafe_features
|
||||
self.random_inputset_samples = random_inputset_samples
|
||||
self.use_insecure_key_cache = use_insecure_key_cache
|
||||
self.auto_parallelize = auto_parallelize
|
||||
self.loop_parallelize = loop_parallelize
|
||||
self.dataflow_parallelize = dataflow_parallelize
|
||||
|
||||
# pylint: enable=too-many-arguments
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return isinstance(other, CompilationConfiguration) and self.__dict__ == other.__dict__
|
||||
@@ -1,4 +0,0 @@
|
||||
"""Module for data types code and data structures."""
|
||||
from . import dtypes_helpers, floats, integers
|
||||
from .floats import Float, Float16, Float32, Float64
|
||||
from .integers import Integer, SignedInteger, UnsignedInteger
|
||||
@@ -1,11 +0,0 @@
|
||||
"""File holding code to represent data types in a program."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseDataType(ABC):
|
||||
"""Base class to represent a data type."""
|
||||
|
||||
@abstractmethod
|
||||
def __eq__(self, o: object) -> bool:
|
||||
"""No default implementation."""
|
||||
@@ -1,393 +0,0 @@
|
||||
"""File to hold helper functions for data types related stuff."""
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Tuple, Union, cast
|
||||
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue
|
||||
from .base import BaseDataType
|
||||
from .floats import Float
|
||||
from .integers import Integer, get_bits_to_represent_value_as_integer
|
||||
|
||||
INTEGER_TYPES = (Integer,)
|
||||
FLOAT_TYPES = (Float,)
|
||||
BASE_DATA_TYPES = INTEGER_TYPES + FLOAT_TYPES
|
||||
|
||||
|
||||
def value_is_encrypted_scalar_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is an encrypted scalar of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is an encrypted scalar of type Integer
|
||||
"""
|
||||
return value_is_scalar_integer(value_to_check) and value_to_check.is_encrypted
|
||||
|
||||
|
||||
def value_is_encrypted_scalar_unsigned_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is an encrypted scalar of type unsigned Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is an encrypted scalar of type Integer and
|
||||
unsigned
|
||||
"""
|
||||
return (
|
||||
value_is_encrypted_scalar_integer(value_to_check)
|
||||
and not cast(Integer, value_to_check.dtype).is_signed
|
||||
)
|
||||
|
||||
|
||||
def value_is_clear_scalar_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is a clear scalar of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is a clear scalar of type Integer
|
||||
"""
|
||||
return value_is_scalar_integer(value_to_check) and value_to_check.is_clear
|
||||
|
||||
|
||||
def value_is_scalar_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is a scalar of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is a scalar of type Integer
|
||||
"""
|
||||
return (
|
||||
isinstance(value_to_check, TensorValue)
|
||||
and value_to_check.is_scalar
|
||||
and isinstance(value_to_check.dtype, INTEGER_TYPES)
|
||||
)
|
||||
|
||||
|
||||
def value_is_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is of type Integer
|
||||
"""
|
||||
|
||||
return isinstance(value_to_check.dtype, INTEGER_TYPES)
|
||||
|
||||
|
||||
def value_is_unsigned_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is of type Integer and is unsigned.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is of type Integer and is unsigned
|
||||
"""
|
||||
|
||||
return (
|
||||
isinstance(value_to_check.dtype, INTEGER_TYPES)
|
||||
and not cast(Integer, value_to_check.dtype).is_signed
|
||||
)
|
||||
|
||||
|
||||
def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is an encrypted TensorValue of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is an encrypted TensorValue of type Integer
|
||||
"""
|
||||
return value_is_tensor_integer(value_to_check) and value_to_check.is_encrypted
|
||||
|
||||
|
||||
def value_is_clear_tensor_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is a clear TensorValue of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is a clear TensorValue of type Integer
|
||||
"""
|
||||
return value_is_tensor_integer(value_to_check) and value_to_check.is_clear
|
||||
|
||||
|
||||
def value_is_tensor_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is a TensorValue of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is a TensorValue of type Integer
|
||||
"""
|
||||
return (
|
||||
isinstance(value_to_check, TensorValue)
|
||||
and not value_to_check.is_scalar
|
||||
and isinstance(value_to_check.dtype, INTEGER_TYPES)
|
||||
)
|
||||
|
||||
|
||||
def find_type_to_hold_both_lossy(
|
||||
dtype1: BaseDataType,
|
||||
dtype2: BaseDataType,
|
||||
) -> BaseDataType:
|
||||
"""Determine the type that can represent both dtype1 and dtype2 separately.
|
||||
|
||||
This is lossy with floating point types.
|
||||
|
||||
Args:
|
||||
dtype1 (BaseDataType): first dtype to hold
|
||||
dtype2 (BaseDataType): second dtype to hold
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Raised if one of the two input dtypes is not an Integer as they are the
|
||||
only type supported for now
|
||||
|
||||
Returns:
|
||||
BaseDataType: The dtype able to hold (potentially lossy) dtype1 and dtype2
|
||||
"""
|
||||
assert_true(isinstance(dtype1, BASE_DATA_TYPES), f"Unsupported dtype1: {type(dtype1)}")
|
||||
assert_true(isinstance(dtype2, BASE_DATA_TYPES), f"Unsupported dtype2: {type(dtype2)}")
|
||||
|
||||
type_to_return: BaseDataType
|
||||
|
||||
if isinstance(dtype1, Integer) and isinstance(dtype2, Integer):
|
||||
d1_signed = dtype1.is_signed
|
||||
d2_signed = dtype2.is_signed
|
||||
max_bits = max(dtype1.bit_width, dtype2.bit_width)
|
||||
|
||||
if d1_signed and d2_signed:
|
||||
type_to_return = Integer(max_bits, is_signed=True)
|
||||
elif not d1_signed and not d2_signed:
|
||||
type_to_return = Integer(max_bits, is_signed=False)
|
||||
elif d1_signed and not d2_signed:
|
||||
# 2 is unsigned, if it has the bigger bit_width, we need a signed integer that can hold
|
||||
# it, so add 1 bit of sign to its bit_width
|
||||
if dtype2.bit_width >= dtype1.bit_width:
|
||||
new_bit_width = dtype2.bit_width + 1
|
||||
type_to_return = Integer(new_bit_width, is_signed=True)
|
||||
else:
|
||||
type_to_return = Integer(dtype1.bit_width, is_signed=True)
|
||||
elif not d1_signed and d2_signed:
|
||||
# Same as above, with 1 and 2 switched around
|
||||
if dtype1.bit_width >= dtype2.bit_width:
|
||||
new_bit_width = dtype1.bit_width + 1
|
||||
type_to_return = Integer(new_bit_width, is_signed=True)
|
||||
else:
|
||||
type_to_return = Integer(dtype2.bit_width, is_signed=True)
|
||||
elif isinstance(dtype1, Float) and isinstance(dtype2, Float):
|
||||
max_bits = max(dtype1.bit_width, dtype2.bit_width)
|
||||
type_to_return = Float(max_bits)
|
||||
elif isinstance(dtype1, Float):
|
||||
type_to_return = deepcopy(dtype1)
|
||||
elif isinstance(dtype2, Float):
|
||||
type_to_return = deepcopy(dtype2)
|
||||
|
||||
return type_to_return
|
||||
|
||||
|
||||
def mix_tensor_values_determine_holding_dtype(
|
||||
value1: TensorValue,
|
||||
value2: TensorValue,
|
||||
) -> TensorValue:
|
||||
"""Return mixed TensorValue with data type able to hold both value1 and value2 dtypes.
|
||||
|
||||
Returns a TensorValue that would result from computation on both value1 and value2 while
|
||||
determining the data type able to hold both value1 and value2 data type (this can be lossy
|
||||
with floats).
|
||||
|
||||
Args:
|
||||
value1 (TensorValue): first TensorValue to mix.
|
||||
value2 (TensorValue): second TensorValue to mix.
|
||||
|
||||
Returns:
|
||||
TensorValue: The resulting mixed TensorValue with data type able to hold both value1 and
|
||||
value2 dtypes.
|
||||
"""
|
||||
|
||||
assert_true(
|
||||
isinstance(value1, TensorValue), f"Unsupported value1: {value1}, expected TensorValue"
|
||||
)
|
||||
assert_true(
|
||||
isinstance(value2, TensorValue), f"Unsupported value2: {value2}, expected TensorValue"
|
||||
)
|
||||
|
||||
resulting_shape = broadcast_shapes(value1.shape, value2.shape)
|
||||
assert_true(
|
||||
resulting_shape is not None,
|
||||
(
|
||||
f"Tensors have incompatible shapes which is not supported.\n"
|
||||
f"value1: {value1.shape}, value2: {value2.shape}"
|
||||
),
|
||||
)
|
||||
assert resulting_shape is not None # this is to make mypy happy
|
||||
|
||||
holding_type = find_type_to_hold_both_lossy(value1.dtype, value2.dtype)
|
||||
if value1.is_encrypted or value2.is_encrypted:
|
||||
mixed_value = EncryptedTensor(dtype=holding_type, shape=resulting_shape)
|
||||
else:
|
||||
mixed_value = ClearTensor(dtype=holding_type, shape=resulting_shape)
|
||||
|
||||
return mixed_value
|
||||
|
||||
|
||||
def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> BaseValue:
|
||||
"""Return mixed BaseValue with data type able to hold both value1 and value2 dtypes.
|
||||
|
||||
Returns a BaseValue that would result from computation on both value1 and value2 while
|
||||
determining the data type able to hold both value1 and value2 data type (this can be lossy
|
||||
with floats). Supports only mixing instances from the same class.
|
||||
|
||||
Args:
|
||||
value1 (BaseValue): first BaseValue to mix.
|
||||
value2 (BaseValue): second BaseValue to mix.
|
||||
|
||||
Raises:
|
||||
ValueError: raised if the BaseValue is not one of (TensorValue)
|
||||
|
||||
Returns:
|
||||
BaseValue: The resulting mixed BaseValue with data type able to hold both value1 and value2
|
||||
dtypes.
|
||||
"""
|
||||
|
||||
assert_true(
|
||||
(value1.__class__ == value2.__class__),
|
||||
f"Cannot mix values of different types: value 1:{type(value1)}, value2: {type(value2)}",
|
||||
)
|
||||
|
||||
if isinstance(value1, TensorValue) and isinstance(value2, TensorValue):
|
||||
return mix_tensor_values_determine_holding_dtype(value1, value2)
|
||||
|
||||
raise ValueError(
|
||||
f"{mix_values_determine_holding_dtype.__name__} does not support value {type(value1)}"
|
||||
)
|
||||
|
||||
|
||||
def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType:
|
||||
"""Determine the BaseDataType to hold the input constant data.
|
||||
|
||||
Args:
|
||||
constant_data (Union[int, float]): The constant data for which to determine the
|
||||
corresponding BaseDataType.
|
||||
|
||||
Returns:
|
||||
BaseDataType: The corresponding BaseDataType
|
||||
"""
|
||||
constant_data_type: BaseDataType
|
||||
assert_true(
|
||||
isinstance(constant_data, (int, float)),
|
||||
f"Unsupported constant data of type {type(constant_data)}",
|
||||
)
|
||||
|
||||
if isinstance(constant_data, int):
|
||||
is_signed = constant_data < 0
|
||||
constant_data_type = Integer(
|
||||
get_bits_to_represent_value_as_integer(constant_data, is_signed), is_signed
|
||||
)
|
||||
elif isinstance(constant_data, float):
|
||||
constant_data_type = Float(64)
|
||||
|
||||
return constant_data_type
|
||||
|
||||
|
||||
def get_base_value_for_python_constant_data(
|
||||
constant_data: Union[int, float]
|
||||
) -> Callable[..., BaseValue]:
|
||||
"""Wrap the BaseDataType to hold the input constant data in BaseValue partial.
|
||||
|
||||
The returned object can then be instantiated as an Encrypted or Clear version
|
||||
by calling it with the proper arguments forwarded to the BaseValue `__init__` function
|
||||
|
||||
Args:
|
||||
constant_data (Union[int, float]): The constant data for which to determine the
|
||||
corresponding Value.
|
||||
|
||||
Returns:
|
||||
Callable[..., BaseValue]: A partial object that will return the proper BaseValue when
|
||||
called with `is_encrypted` as keyword argument (forwarded to the BaseValue `__init__`
|
||||
method).
|
||||
"""
|
||||
|
||||
constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
|
||||
return partial(TensorValue, dtype=constant_data_type, shape=())
|
||||
|
||||
|
||||
def get_constructor_for_python_constant_data(constant_data: Union[int, float]):
|
||||
"""Get the constructor for the passed python constant data.
|
||||
|
||||
Args:
|
||||
constant_data (Any): The data for which we want to determine the type constructor.
|
||||
"""
|
||||
return type(constant_data)
|
||||
|
||||
|
||||
def is_data_type_compatible_with(
|
||||
dtype: BaseDataType,
|
||||
other: BaseDataType,
|
||||
) -> bool:
|
||||
"""Determine whether dtype is compatible with other.
|
||||
|
||||
`dtype` being compatible with `other` means `other` can hold every value of `dtype`
|
||||
(e.g., uint2 is compatible with uint4 and int4)
|
||||
(e.g., int2 is compatible with int4 but not with uint4)
|
||||
|
||||
Note that this function is not symetric.
|
||||
(e.g., uint2 is compatible with uint4, but uint4 is not compatible with uint2)
|
||||
|
||||
Args:
|
||||
dtype (BaseDataType): dtype to check compatiblity
|
||||
other (BaseDataType): dtype to check compatiblity against
|
||||
|
||||
Returns:
|
||||
bool: Whether the dtype is compatible with other or not
|
||||
"""
|
||||
|
||||
combination = find_type_to_hold_both_lossy(dtype, other)
|
||||
return other == combination
|
||||
|
||||
|
||||
def broadcast_shapes(shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> Optional[Tuple[int, ...]]:
|
||||
"""Broadcast two shapes into a single shape.
|
||||
|
||||
We are mimicing the exact semantics of broadcasting in numpy.
|
||||
You can learn more about it here: https://numpy.org/doc/stable/user/theory.broadcasting.html
|
||||
|
||||
Args:
|
||||
shape1 (Tuple[int, ...]): first shape to broadcast
|
||||
shape2 (Tuple[int, ...]): second shape to broadcast
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[int, ...]]: None if the shapes are not broadcastable else broadcasted shape
|
||||
"""
|
||||
|
||||
result = []
|
||||
for size1, size2 in zip(shape1[::-1], shape2[::-1]):
|
||||
if size1 != size2 and size1 != 1 and size2 != 1 and size1 != 0 and size2 != 0:
|
||||
return None
|
||||
|
||||
if size1 == 0 or size2 == 0:
|
||||
result.append(0)
|
||||
else:
|
||||
result.append(max(size1, size2))
|
||||
|
||||
if len(result) < len(shape1):
|
||||
for i in reversed(range(len(shape1) - len(result))):
|
||||
result.append(shape1[i])
|
||||
elif len(result) < len(shape2):
|
||||
for i in reversed(range(len(shape2) - len(result))):
|
||||
result.append(shape2[i])
|
||||
|
||||
return tuple(reversed(result))
|
||||
@@ -1,33 +0,0 @@
|
||||
"""This file holds the definitions for floating point types."""
|
||||
|
||||
from functools import partial
|
||||
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from . import base
|
||||
|
||||
|
||||
class Float(base.BaseDataType):
|
||||
"""Class representing a float."""
|
||||
|
||||
# bit_width is the total number of bits used to represent a floating point number, including
|
||||
# sign bit, exponent and mantissa
|
||||
bit_width: int
|
||||
|
||||
def __init__(self, bit_width: int) -> None:
|
||||
super().__init__()
|
||||
assert_true(bit_width in (16, 32, 64), "Only 16, 32 and 64 bits floats are supported")
|
||||
self.bit_width = bit_width
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}<{self.bit_width} bits>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"float{self.bit_width}"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, self.__class__) and self.bit_width == other.bit_width
|
||||
|
||||
|
||||
Float16 = partial(Float, 16)
|
||||
Float32 = partial(Float, 32)
|
||||
Float64 = partial(Float, 64)
|
||||
@@ -1,144 +0,0 @@
|
||||
"""This file holds the definitions for integer types."""
|
||||
|
||||
import math
|
||||
from typing import Any, Iterable
|
||||
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from . import base
|
||||
|
||||
|
||||
class Integer(base.BaseDataType):
|
||||
"""Class representing an integer."""
|
||||
|
||||
bit_width: int
|
||||
is_signed: bool
|
||||
|
||||
def __init__(self, bit_width: int, is_signed: bool) -> None:
|
||||
super().__init__()
|
||||
assert_true(bit_width > 0, "bit_width must be > 0")
|
||||
self.bit_width = bit_width
|
||||
self.is_signed = is_signed
|
||||
|
||||
def __repr__(self) -> str:
|
||||
signed_str = "signed" if self.is_signed else "unsigned"
|
||||
return f"{self.__class__.__name__}<{signed_str}, {self.bit_width} bits>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{('int' if self.is_signed else 'uint')}{self.bit_width}"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self.bit_width == other.bit_width
|
||||
and self.is_signed == other.is_signed
|
||||
)
|
||||
|
||||
def min_value(self) -> int:
|
||||
"""Minimum value representable by the Integer."""
|
||||
if self.is_signed:
|
||||
return -(2 ** (self.bit_width - 1))
|
||||
|
||||
return 0
|
||||
|
||||
def max_value(self) -> int:
|
||||
"""Maximum value representable by the Integer."""
|
||||
if self.is_signed:
|
||||
return 2 ** (self.bit_width - 1) - 1
|
||||
|
||||
return 2 ** self.bit_width - 1
|
||||
|
||||
def can_represent_value(self, value_to_represent: int) -> bool:
|
||||
"""Check if a value is representable by the Integer.
|
||||
|
||||
Args:
|
||||
value_to_represent (int): Value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the value can be represented by this integer
|
||||
"""
|
||||
return self.min_value() <= value_to_represent <= self.max_value()
|
||||
|
||||
|
||||
def create_signed_integer(bit_width: int) -> Integer:
|
||||
"""Create a signed integer.
|
||||
|
||||
Args:
|
||||
bit_width (int): width of the integer
|
||||
|
||||
Returns:
|
||||
Integer: A signed integer with the requested bit_width
|
||||
"""
|
||||
return Integer(bit_width, is_signed=True)
|
||||
|
||||
|
||||
SignedInteger = create_signed_integer
|
||||
|
||||
|
||||
def create_unsigned_integer(bit_width: int) -> Integer:
|
||||
"""Create an unsigned integer.
|
||||
|
||||
Args:
|
||||
bit_width (int): width of the integer
|
||||
|
||||
Returns:
|
||||
Integer: An unsigned integer with the requested bit_width
|
||||
"""
|
||||
return Integer(bit_width, is_signed=False)
|
||||
|
||||
|
||||
UnsignedInteger = create_unsigned_integer
|
||||
|
||||
|
||||
def make_integer_to_hold(values: Iterable[Any], force_signed: bool) -> Integer:
|
||||
"""Return an Integer able to hold all values, it is possible to force the Integer to be signed.
|
||||
|
||||
Args:
|
||||
values (Iterable[Any]): The values to hold
|
||||
force_signed (bool): Set to True to force the result to be a signed Integer
|
||||
|
||||
Returns:
|
||||
Integer: The Integer able to hold values
|
||||
"""
|
||||
min_value = min(values)
|
||||
max_value = max(values)
|
||||
|
||||
make_signed_integer = force_signed or min_value < 0
|
||||
|
||||
num_bits = max(
|
||||
get_bits_to_represent_value_as_integer(min_value, make_signed_integer),
|
||||
get_bits_to_represent_value_as_integer(max_value, make_signed_integer),
|
||||
)
|
||||
|
||||
return Integer(num_bits, is_signed=make_signed_integer)
|
||||
|
||||
|
||||
def get_bits_to_represent_value_as_integer(value: Any, force_signed: bool) -> int:
|
||||
"""Return how many bits are required to represent a numerical Value.
|
||||
|
||||
Args:
|
||||
value (Any): The value for which we want to know how many bits are required.
|
||||
force_signed (bool): Set to True to force the result to be a signed integer.
|
||||
|
||||
Returns:
|
||||
int: required amount of bits
|
||||
"""
|
||||
# Writing this in a very dumb way
|
||||
num_bits: int
|
||||
if value < 0:
|
||||
abs_value = abs(value)
|
||||
if abs_value > 1:
|
||||
num_bits = math.ceil(math.log2(abs_value)) + 1
|
||||
else:
|
||||
# -1 case
|
||||
num_bits = 2
|
||||
else:
|
||||
if value > 1:
|
||||
num_bits = math.ceil(math.log2(value + 1))
|
||||
else:
|
||||
# 0 and 1 case
|
||||
num_bits = 1
|
||||
|
||||
if force_signed:
|
||||
num_bits += 1
|
||||
|
||||
return num_bits
|
||||
@@ -1,4 +0,0 @@
|
||||
"""Module for debugging."""
|
||||
from .custom_assert import assert_true
|
||||
from .drawing import draw_graph
|
||||
from .formatting import format_operation_graph
|
||||
@@ -1,49 +0,0 @@
|
||||
"""Provide some variants of assert."""
|
||||
|
||||
|
||||
def _custom_assert(condition: bool, on_error_msg: str = "") -> None:
|
||||
"""Provide a custom assert which is kept even if the optimized python mode is used.
|
||||
|
||||
See https://docs.python.org/3/reference/simple_stmts.html#assert for the documentation
|
||||
on the classical assert function
|
||||
|
||||
Args:
|
||||
condition(bool): the condition. If False, raise AssertionError
|
||||
on_error_msg(str): optional message for precising the error, in case of error
|
||||
|
||||
"""
|
||||
|
||||
if not condition:
|
||||
raise AssertionError(on_error_msg)
|
||||
|
||||
|
||||
def assert_true(condition: bool, on_error_msg: str = ""):
|
||||
"""Provide a custom assert to check that the condition is True.
|
||||
|
||||
Args:
|
||||
condition(bool): the condition. If False, raise AssertionError
|
||||
on_error_msg(str): optional message for precising the error, in case of error
|
||||
|
||||
"""
|
||||
return _custom_assert(condition, on_error_msg)
|
||||
|
||||
|
||||
def assert_false(condition: bool, on_error_msg: str = ""):
|
||||
"""Provide a custom assert to check that the condition is False.
|
||||
|
||||
Args:
|
||||
condition(bool): the condition. If True, raise AssertionError
|
||||
on_error_msg(str): optional message for precising the error, in case of error
|
||||
|
||||
"""
|
||||
return _custom_assert(not condition, on_error_msg)
|
||||
|
||||
|
||||
def assert_not_reached(on_error_msg: str):
|
||||
"""Provide a custom assert to check that a piece of code is never reached.
|
||||
|
||||
Args:
|
||||
on_error_msg(str): message for precising the error
|
||||
|
||||
"""
|
||||
return _custom_assert(False, on_error_msg)
|
||||
@@ -1,156 +0,0 @@
|
||||
"""functions to draw the different graphs we can generate in the package, eg to debug."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
from PIL import Image
|
||||
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import (
|
||||
ALL_IR_NODES,
|
||||
Add,
|
||||
Constant,
|
||||
Conv2D,
|
||||
Dot,
|
||||
GenericFunction,
|
||||
IndexConstant,
|
||||
Input,
|
||||
MatMul,
|
||||
Mul,
|
||||
Sub,
|
||||
)
|
||||
|
||||
IR_NODE_COLOR_MAPPING = {
|
||||
Input: "blue",
|
||||
Constant: "cyan",
|
||||
Conv2D: "brown",
|
||||
Add: "red",
|
||||
Sub: "yellow",
|
||||
Mul: "green",
|
||||
GenericFunction: "orange",
|
||||
IndexConstant: "black",
|
||||
Dot: "purple",
|
||||
MatMul: "brown",
|
||||
"GenericFunction": "orange",
|
||||
"TLU": "grey",
|
||||
"output": "magenta",
|
||||
}
|
||||
|
||||
_missing_nodes_in_mapping = ALL_IR_NODES - IR_NODE_COLOR_MAPPING.keys()
|
||||
assert_true(
|
||||
len(_missing_nodes_in_mapping) == 0,
|
||||
(
|
||||
f"Missing IR node in IR_NODE_COLOR_MAPPING : "
|
||||
f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}"
|
||||
),
|
||||
)
|
||||
|
||||
del _missing_nodes_in_mapping
|
||||
|
||||
|
||||
def draw_graph(
|
||||
op_graph: OPGraph,
|
||||
show: bool = False,
|
||||
vertical: bool = True,
|
||||
save_to: Optional[Path] = None,
|
||||
) -> str:
|
||||
"""Draws operation graphs and optionally saves/shows the drawing.
|
||||
|
||||
Note that this function requires the python `pygraphviz` package which itself requires the
|
||||
installation of `graphviz` packages, see
|
||||
https://pygraphviz.github.io/documentation/stable/install.html
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): the operation graph to be drawn and optionally saved/shown
|
||||
show (bool): if set to True, the drawing will be shown using matplotlib
|
||||
vertical (bool): if set to True, the orientation will be vertical
|
||||
save_to (Optional[Path]): if specified, the drawn graph will be saved to this path; else
|
||||
it is saved in a temporary file
|
||||
|
||||
Returns:
|
||||
The path of the file where the drawn graph is saved
|
||||
|
||||
"""
|
||||
|
||||
def get_color(node, output_nodes):
|
||||
value_to_return = IR_NODE_COLOR_MAPPING[type(node)]
|
||||
if node in output_nodes:
|
||||
value_to_return = IR_NODE_COLOR_MAPPING["output"]
|
||||
elif isinstance(node, GenericFunction):
|
||||
value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return)
|
||||
return value_to_return
|
||||
|
||||
graph = op_graph.graph
|
||||
output_nodes = set(op_graph.output_nodes.values())
|
||||
|
||||
attributes = {
|
||||
node: {
|
||||
"label": node.text_for_drawing(),
|
||||
"color": get_color(node, output_nodes),
|
||||
"penwidth": 2, # double thickness for circles
|
||||
"peripheries": 2 if node in output_nodes else 1, # double circle for output nodes
|
||||
}
|
||||
for node in graph.nodes
|
||||
}
|
||||
nx.set_node_attributes(graph, attributes)
|
||||
|
||||
# TODO: #639 adapt drawing routine to manage output_idx
|
||||
for edge in graph.edges(keys=True):
|
||||
idx = graph.edges[edge]["input_idx"]
|
||||
graph.edges[edge]["label"] = f" {idx} " # spaces are there intentionally for a better look
|
||||
|
||||
try:
|
||||
agraph = nx.nx_agraph.to_agraph(graph)
|
||||
except ImportError as e: # pragma: no cover
|
||||
if "pygraphviz" in str(e):
|
||||
err_msg = (
|
||||
f"{draw_graph.__name__} requires pygraphviz, install your OS graphviz distribution "
|
||||
"https://pygraphviz.github.io/documentation/stable/install.html "
|
||||
"and reinstall with extras: `pip install --force-reinstall "
|
||||
"concrete-numpy[full]`"
|
||||
)
|
||||
raise ImportError(err_msg) from e
|
||||
agraph.graph_attr["rankdir"] = "TB" if vertical else "LR"
|
||||
agraph.layout("dot")
|
||||
|
||||
if save_to is None:
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
|
||||
# we need to change the permissions of the temporary file
|
||||
# so that it can be read by all users
|
||||
|
||||
# (https://stackoverflow.com/a/44130605)
|
||||
|
||||
# get the old umask and replace it with 0o666
|
||||
old_umask = os.umask(0o666)
|
||||
|
||||
# restore the old umask back
|
||||
os.umask(old_umask)
|
||||
|
||||
# combine the old umask with the wanted permissions
|
||||
permissions = 0o666 & ~old_umask
|
||||
|
||||
# set new permissions
|
||||
os.chmod(tmp.name, permissions)
|
||||
|
||||
save_to_str = str(tmp.name)
|
||||
else:
|
||||
save_to_str = str(save_to)
|
||||
|
||||
agraph.draw(save_to_str)
|
||||
|
||||
if show: # pragma: no cover
|
||||
# We can't have coverage in this branch as `plt.show()` blocks and waits for user action.
|
||||
plt.close("all")
|
||||
plt.figure()
|
||||
img = Image.open(save_to_str)
|
||||
plt.imshow(img)
|
||||
img.close()
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
|
||||
return save_to_str
|
||||
@@ -1,151 +0,0 @@
|
||||
"""Functions to format operation graphs for debugging purposes."""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import GenericFunction, IntermediateNode
|
||||
|
||||
|
||||
def format_operation_graph(
|
||||
op_graph: OPGraph,
|
||||
maximum_constant_length: int = 25,
|
||||
highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None,
|
||||
) -> str:
|
||||
"""Format an operation graph.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph):
|
||||
the operation graph to format
|
||||
|
||||
maximum_constant_length (int):
|
||||
maximum length of the constant throughout the formatting
|
||||
|
||||
highlighted_nodes (Optional[Dict[IntermediateNode, List[str]]] = None):
|
||||
the dict of nodes and their corresponding messages which will be highlighted
|
||||
|
||||
Returns:
|
||||
str: formatted operation graph
|
||||
"""
|
||||
|
||||
# This function is well documented and split into very readable sections
|
||||
# Thus, splitting it to multiple functions doesn't increase readability
|
||||
|
||||
# pylint: disable=too-many-locals,too-many-branches
|
||||
|
||||
assert_true(isinstance(op_graph, OPGraph))
|
||||
|
||||
# (node, output_index) -> identifier
|
||||
# e.g., id_map[(node1, 0)] = 2 and id_map[(node1, 1)] = 3
|
||||
# means line for node1 is in this form (%2, %3) = node1.format(...)
|
||||
id_map: Dict[Tuple[IntermediateNode, int], int] = {}
|
||||
|
||||
# lines that will be merged at the end
|
||||
lines: List[str] = []
|
||||
|
||||
# type information to add to each line (for alingment, this is done after lines are determined)
|
||||
type_informations: List[str] = []
|
||||
|
||||
# default highlighted nodes is empty
|
||||
highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {}
|
||||
|
||||
# highlight information for lines, this is required because highlights are added to lines
|
||||
# after their type information is added and we only have line numbers, not nodes
|
||||
highlighted_lines: Dict[int, List[str]] = {}
|
||||
|
||||
# subgraphs to format after the main graph is formatted
|
||||
subgraphs: Dict[str, OPGraph] = {}
|
||||
|
||||
# format nodes
|
||||
for node in nx.topological_sort(op_graph.graph):
|
||||
# assign a unique id to outputs of node
|
||||
assert_true(len(node.outputs) > 0)
|
||||
for i in range(len(node.outputs)):
|
||||
id_map[(node, i)] = len(id_map)
|
||||
|
||||
# remember highlights of the node
|
||||
if node in highlighted_nodes:
|
||||
highlighted_lines[len(lines)] = highlighted_nodes[node]
|
||||
|
||||
# extract predecessors and their ids
|
||||
predecessors = []
|
||||
for predecessor, output_idx in op_graph.get_ordered_preds_and_inputs_of(node):
|
||||
predecessors.append(f"%{id_map[(predecessor, output_idx)]}")
|
||||
|
||||
# start the build the line for the node
|
||||
line = ""
|
||||
|
||||
# add output information to the line
|
||||
outputs = ", ".join(f"%{id_map[(node, i)]}" for i in range(len(node.outputs)))
|
||||
line += outputs if len(node.outputs) == 1 else f"({outputs})"
|
||||
|
||||
# add node information to the line
|
||||
line += " = "
|
||||
line += node.text_for_formatting(predecessors, maximum_constant_length)
|
||||
|
||||
# append line to list of lines
|
||||
lines.append(line)
|
||||
|
||||
# if exists, save the subgraph
|
||||
if isinstance(node, GenericFunction) and "float_op_subgraph" in node.op_kwargs:
|
||||
subgraphs[line] = node.op_kwargs["float_op_subgraph"]
|
||||
|
||||
# remember type information of the node
|
||||
types = ", ".join(str(output) for output in node.outputs)
|
||||
type_informations.append(types if len(node.outputs) == 1 else f"({types})")
|
||||
|
||||
# align = signs
|
||||
#
|
||||
# e.g.,
|
||||
#
|
||||
# %1 = ...
|
||||
# %2 = ...
|
||||
# ...
|
||||
# %8 = ...
|
||||
# %9 = ...
|
||||
# %10 = ...
|
||||
# %11 = ...
|
||||
# ...
|
||||
longest_length_before_equals_sign = max(len(line.split("=")[0]) for line in lines)
|
||||
for i, line in enumerate(lines):
|
||||
length_before_equals_sign = len(line.split("=")[0])
|
||||
lines[i] = (" " * (longest_length_before_equals_sign - length_before_equals_sign)) + line
|
||||
|
||||
# add type informations
|
||||
longest_line_length = max(len(line) for line in lines)
|
||||
for i, line in enumerate(lines):
|
||||
lines[i] += " " * (longest_line_length - len(line))
|
||||
lines[i] += f" # {type_informations[i]}"
|
||||
|
||||
# add highlights (this is done in reverse to keep indices consistent)
|
||||
for i in reversed(range(len(lines))):
|
||||
if i in highlighted_lines:
|
||||
for j, message in enumerate(highlighted_lines[i]):
|
||||
highlight = "^" if j == 0 else " "
|
||||
lines.insert(i + 1 + j, f"{highlight * len(lines[i])} {message}")
|
||||
|
||||
# add return information
|
||||
# (if there is a single return, it's in the form `return %id`
|
||||
# (otherwise, it's in the form `return (%id1, %id2, ..., %idN)`
|
||||
returns: List[str] = []
|
||||
for node in op_graph.output_nodes.values():
|
||||
outputs = ", ".join(f"%{id_map[(node, i)]}" for i in range(len(node.outputs)))
|
||||
returns.append(outputs if len(node.outputs) == 1 else f"({outputs})")
|
||||
lines.append("return " + returns[0] if len(returns) == 1 else f"({', '.join(returns)})")
|
||||
|
||||
# format subgraphs after the actual graph
|
||||
result = "\n".join(lines)
|
||||
if len(subgraphs) > 0:
|
||||
result += "\n\n"
|
||||
result += "Subgraphs:"
|
||||
for line, subgraph in subgraphs.items():
|
||||
subgraph_lines = format_operation_graph(subgraph, maximum_constant_length).split("\n")
|
||||
result += "\n\n"
|
||||
result += f" {line}:\n\n"
|
||||
result += "\n".join(f" {line}" for line in subgraph_lines)
|
||||
|
||||
# pylint: enable=too-many-locals,too-many-branches
|
||||
|
||||
return result
|
||||
@@ -1,2 +0,0 @@
|
||||
"""Extensions module to provide additional functionality to our users."""
|
||||
from . import convolution, multi_table, table
|
||||
@@ -1,161 +0,0 @@
|
||||
"""This file contains tracers for convolution operations."""
|
||||
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...numpy.tracing import NPConstant, NPTracer
|
||||
from ..representation.intermediate import Conv2D
|
||||
from ..tracing.base_tracer import BaseTracer
|
||||
|
||||
SUPPORTED_AUTO_PAD = [
|
||||
"NOTSET",
|
||||
]
|
||||
|
||||
|
||||
def conv2d(
|
||||
x: Union[np.ndarray, BaseTracer],
|
||||
weight: Union[np.ndarray, BaseTracer],
|
||||
bias: Optional[Union[np.ndarray, BaseTracer]] = None,
|
||||
pads: Union[Tuple[int, int, int, int], List[int]] = (0, 0, 0, 0),
|
||||
strides: Union[Tuple[int, int], List[int]] = (1, 1),
|
||||
dilations: Union[Tuple[int, int], List[int]] = (1, 1),
|
||||
auto_pad: str = "NOTSET",
|
||||
) -> Union[np.ndarray, NPTracer]:
|
||||
"""Trace or evaluate 2D convolution.
|
||||
|
||||
Args:
|
||||
x (Union[np.ndarray, BaseTracer]): Input of shape (NxCxHxW)
|
||||
weight (Union[np.ndarray, BaseTracer]): Weight (kernel) of shape (FxCxHxW)
|
||||
bias (Optional[Union[np.ndarray, BaseTracer]], optional): Bias vector of size (F).
|
||||
Defaults to None.
|
||||
pads (Union[Tuple[int, int, int, int], List[int]], optional): Padding over each axis
|
||||
(H_beg, W_beg, H_end, W_end). Defaults to (0, 0, 0, 0).
|
||||
strides (Union[Tuple[int, int], List[int]], optional): Stride over each axis
|
||||
(height and width). Defaults to (1, 1).
|
||||
dilations (Union[Tuple[int, int], List[int]], optional): Dilation over each axis
|
||||
(height and width). Defaults to (1, 1).
|
||||
auto_pad (str, optional): Padding strategy. Defaults to "NOTSET".
|
||||
|
||||
Raises:
|
||||
ValueError: If one argument isn't in the range of expected values.
|
||||
TypeError: If one argument isn't of the appropriate type.
|
||||
|
||||
Returns:
|
||||
Union[np.ndarray, BaseTracer]: Evaluation result, or traced computation
|
||||
"""
|
||||
if auto_pad not in SUPPORTED_AUTO_PAD:
|
||||
raise ValueError("invalid auto_pad is specified")
|
||||
|
||||
if not isinstance(x, (np.ndarray, BaseTracer)):
|
||||
raise TypeError(f"input x must be an ndarray, or a BaseTracer, not a {type(x)}")
|
||||
if not isinstance(weight, (np.ndarray, BaseTracer)):
|
||||
raise TypeError(f"weight must be an ndarray, or a BaseTracer, not a {type(weight)}")
|
||||
if not isinstance(bias, (np.ndarray, BaseTracer, type(None))):
|
||||
raise TypeError(f"bias must be an ndarray, a BaseTracer, or None, not a {type(bias)}")
|
||||
if not isinstance(pads, (tuple, list)):
|
||||
raise TypeError(f"padding must be a tuple, or list, not a {type(pads)}")
|
||||
if not isinstance(strides, (tuple, list)):
|
||||
raise TypeError(f"strides must be a tuple, or list, not a {type(strides)}")
|
||||
if not isinstance(dilations, (tuple, list)):
|
||||
raise TypeError(f"dilations must be a tuple, or list, not a {type(dilations)}")
|
||||
|
||||
if len(pads) != 4:
|
||||
raise ValueError(
|
||||
f"padding should be of the form (pad_height_begin, pad_width_begin, pad_height_end, "
|
||||
f" pad_width_end), but got {type(pads)} of length {len(pads)}"
|
||||
)
|
||||
if len(strides) != 2:
|
||||
raise ValueError(
|
||||
f"strides should be of the form (stride_height, stride_width), but got {type(strides)}"
|
||||
f" of length {len(strides)}"
|
||||
)
|
||||
if len(dilations) != 2:
|
||||
raise ValueError(
|
||||
f"dilations should be of the form (dilation_height, dilation_width), but got"
|
||||
f" {type(dilations)} of length {len(dilations)}"
|
||||
)
|
||||
|
||||
assert len(x.shape) == 4, f"input x should have size (N x C x H x W), not {x.shape}"
|
||||
assert len(weight.shape) == 4, f"weight should have size (F x C x H x W), not {weight.shape}"
|
||||
if bias is not None:
|
||||
assert len(bias.shape) == 1, f"bias should have size (F), not {bias.shape}"
|
||||
|
||||
if isinstance(x, BaseTracer):
|
||||
return _trace_conv2d(x, weight, bias, pads, strides, dilations)
|
||||
# X is an ndarray
|
||||
bias = np.zeros(weight.shape[0]) if bias is None else bias
|
||||
# For mypy
|
||||
weight = cast(np.ndarray, weight)
|
||||
bias = cast(np.ndarray, bias)
|
||||
return _evaluate_conv2d(x, weight, bias, pads, strides, dilations)
|
||||
|
||||
|
||||
def _trace_conv2d(
|
||||
x: BaseTracer,
|
||||
weight: Union[np.ndarray, BaseTracer],
|
||||
bias: Optional[Union[np.ndarray, BaseTracer]],
|
||||
pads: Union[Tuple[int, int, int, int], List[int]],
|
||||
strides: Union[Tuple[int, int], List[int]],
|
||||
dilations: Union[Tuple[int, int], List[int]],
|
||||
) -> NPTracer:
|
||||
"""Trace 2D convolution.
|
||||
|
||||
Args:
|
||||
x (BaseTracer): Input of shape (NxCxHxW)
|
||||
weight (Union[np.ndarray, BaseTracer]): Weight (kernel) of shape (FxCxHxW)
|
||||
bias (Optional[Union[np.ndarray, BaseTracer]]): Bias vector of size (F)
|
||||
pads (Union[Tuple[int, int, int, int], List[int]]): Padding over each
|
||||
axis (H_beg, W_beg, H_end, W_end)
|
||||
strides (Union[Tuple[int, int], List[int]]): Stride over each
|
||||
axis (height and width)
|
||||
dilations (Union[Tuple[int, int], List[int]]): Dilation over each
|
||||
axis (height and width)
|
||||
|
||||
Returns:
|
||||
BaseTracer: Traced computation
|
||||
"""
|
||||
weight_tracer = (
|
||||
weight if isinstance(weight, BaseTracer) else NPTracer([], NPConstant(weight), 0)
|
||||
)
|
||||
inputs = [x.output, weight_tracer.output]
|
||||
output_tracer_inputs = [x, weight_tracer]
|
||||
if bias is not None:
|
||||
bias_tracer = bias if isinstance(bias, BaseTracer) else NPTracer([], NPConstant(bias), 0)
|
||||
inputs.append(bias_tracer.output)
|
||||
# For mypy
|
||||
bias = cast(BaseTracer, bias_tracer)
|
||||
output_tracer_inputs.append(bias)
|
||||
|
||||
traced_computation = Conv2D(inputs, x.output.dtype, pads, strides, dilations)
|
||||
output_tracer = x.__class__(
|
||||
output_tracer_inputs, traced_computation=traced_computation, output_idx=0
|
||||
)
|
||||
# For mypy
|
||||
assert isinstance(output_tracer, NPTracer)
|
||||
return output_tracer
|
||||
|
||||
|
||||
def _evaluate_conv2d(
|
||||
x: np.ndarray,
|
||||
weight: np.ndarray,
|
||||
bias: np.ndarray,
|
||||
pads: Union[Tuple[int, int, int, int], List[int]],
|
||||
strides: Union[Tuple[int, int], List[int]],
|
||||
dilations: Union[Tuple[int, int], List[int]],
|
||||
) -> np.ndarray:
|
||||
"""Evaluate 2D convolution.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Input of shape (NxCxHxW)
|
||||
weight (np.ndarray): Weight (kernel) of shape (FxCxHxW)
|
||||
bias (np.ndarray): Bias vector of size (F)
|
||||
pads (Union[Tuple[int, int, int, int], List[int]]): Padding over each
|
||||
axis (H_beg, W_beg, H_end, W_end)
|
||||
strides (Union[Tuple[int, int], List[int]]): Stride over each axis (height and width)
|
||||
dilations (Union[Tuple[int, int], List[int]]): Dilation over each axis (height and width)
|
||||
|
||||
Returns:
|
||||
np.ndarray: Result of the convolution of shape (NxCxHxW)
|
||||
"""
|
||||
return Conv2D.evaluate_conv2d(x, weight, bias, pads, strides, dilations)
|
||||
@@ -1,233 +0,0 @@
|
||||
"""This file contains a wrapper class for direct multi table lookups."""
|
||||
|
||||
import itertools
|
||||
from copy import deepcopy
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from ..data_types.base import BaseDataType
|
||||
from ..data_types.dtypes_helpers import find_type_to_hold_both_lossy
|
||||
from ..representation.intermediate import GenericFunction
|
||||
from ..tracing.base_tracer import BaseTracer
|
||||
from ..values import TensorValue
|
||||
from .table import LookupTable
|
||||
|
||||
|
||||
class MultiLookupTable:
|
||||
"""Class representing a multi lookup table."""
|
||||
|
||||
# Multi table lookup is needed when you want to perform a lookup on a tensor,
|
||||
# but you want each element to be used with a different lookup table.
|
||||
#
|
||||
# Here is an example:
|
||||
#
|
||||
# You have x which is of shape (2, 3),
|
||||
# you want the first row to be indexed with `table1 = LookupTable([2, 3, 1, 0])`
|
||||
# and the second row to be indexed with `table1 = LookupTable([0, 1, 3, 2])`
|
||||
#
|
||||
# You can create such a multi lookup table
|
||||
# multitable = MultiLookupTable(
|
||||
# [
|
||||
# [table1, table1, table1],
|
||||
# [table2, table2, table2],
|
||||
# ],
|
||||
# )
|
||||
# (notice the shape of multitable matches with the shape of x)
|
||||
#
|
||||
# and use multitable[x] toget the following result
|
||||
# assert multitable[x] == [
|
||||
# [table1[x[0, 0]], table1[x[0, 1]], table1[x[0, 2]]],
|
||||
# [table2[x[1, 0]], table2[x[1, 1]], table2[x[1, 2]]],
|
||||
# ]
|
||||
|
||||
# underlying lookup tables
|
||||
tables: List
|
||||
|
||||
# shape of the input of the lookup
|
||||
input_shape: Tuple[int, ...]
|
||||
|
||||
# type of the result of the lookup
|
||||
output_dtype: BaseDataType
|
||||
|
||||
def __init__(self, tables: List):
|
||||
input_shape_list: List[int] = []
|
||||
MultiLookupTable._extract_shape_using_first_elements_only(tables, input_shape_list)
|
||||
input_shape: Tuple[int, ...] = tuple(input_shape_list)
|
||||
|
||||
table_sizes: List[int] = []
|
||||
table_output_dtypes: List[BaseDataType] = []
|
||||
MultiLookupTable._check_shape_and_record_luts(
|
||||
tables,
|
||||
0,
|
||||
input_shape,
|
||||
table_sizes,
|
||||
table_output_dtypes,
|
||||
)
|
||||
|
||||
for i in range(1, len(table_sizes)):
|
||||
if table_sizes[i - 1] != table_sizes[i]:
|
||||
# this branch is for such a case:
|
||||
#
|
||||
# table1 = hnp.LookupTable([1, 3])
|
||||
# table2 = hnp.LookupTable([0, 2, 3, 1])
|
||||
#
|
||||
# multitable = hnp.MultiLookupTable(
|
||||
# [
|
||||
# [table1, table2, table1],
|
||||
# [table2, table1, table2],
|
||||
# ],
|
||||
# )
|
||||
raise ValueError(
|
||||
f"LookupTables within a MultiLookupTable "
|
||||
f"should have the same size but they do not "
|
||||
f"(there was a table with the size of {table_sizes[i - 1]} "
|
||||
f"and another with the size of {table_sizes[i]})"
|
||||
)
|
||||
|
||||
output_dtype = table_output_dtypes[0]
|
||||
for table_output_dtype in table_output_dtypes:
|
||||
output_dtype = find_type_to_hold_both_lossy(output_dtype, table_output_dtype)
|
||||
|
||||
self.tables = tables
|
||||
self.input_shape = input_shape
|
||||
self.output_dtype = output_dtype
|
||||
|
||||
def __getitem__(self, key: Union[int, BaseTracer]):
|
||||
# this branch is used during tracing and the regular flow is used during evaluation
|
||||
if isinstance(key, BaseTracer):
|
||||
out_dtype = deepcopy(key.output.dtype)
|
||||
out_shape = deepcopy(self.input_shape)
|
||||
|
||||
generic_function_output_value = TensorValue(
|
||||
out_dtype,
|
||||
key.output.is_encrypted,
|
||||
out_shape,
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[key.output],
|
||||
arbitrary_func=MultiLookupTable._checked_indexing,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
op_kwargs={
|
||||
"input_shape": deepcopy(self.input_shape),
|
||||
"tables": deepcopy(self.tables),
|
||||
},
|
||||
op_name="MultiTLU",
|
||||
)
|
||||
return key.__class__(
|
||||
inputs=[key],
|
||||
traced_computation=traced_computation,
|
||||
output_idx=0,
|
||||
)
|
||||
|
||||
# if not, it means table is indexed with a constant
|
||||
# thus, the result of the lookup is a constant
|
||||
# so, we can propagate it directly
|
||||
return MultiLookupTable._checked_indexing(key, self.input_shape, self.tables)
|
||||
|
||||
@staticmethod
|
||||
def _extract_shape_using_first_elements_only(array, shape):
|
||||
if not isinstance(array, list):
|
||||
# base case for recursion
|
||||
# the shape is already accumulated up to this point
|
||||
# so we just return
|
||||
return
|
||||
|
||||
if len(array) == 0:
|
||||
# this branch is for such a case:
|
||||
#
|
||||
# table1 = hnp.LookupTable([1, 3, 2, 0])
|
||||
# table2 = hnp.LookupTable([0, 2, 3, 1])
|
||||
#
|
||||
# multitable = hnp.MultiLookupTable(
|
||||
# [
|
||||
# [],
|
||||
# [table1, table2, table1],
|
||||
# [table2, table1, table2],
|
||||
# ],
|
||||
# )
|
||||
|
||||
raise ValueError("MultiLookupTable cannot have an empty array within it")
|
||||
|
||||
shape.append(len(array))
|
||||
MultiLookupTable._extract_shape_using_first_elements_only(array[0], shape)
|
||||
|
||||
@staticmethod
|
||||
def _check_shape_and_record_luts(array, dimension, shape, table_sizes, table_output_dtypes):
|
||||
if dimension == len(shape):
|
||||
if not isinstance(array, LookupTable):
|
||||
# this branch is for such a case:
|
||||
#
|
||||
# table1 = hnp.LookupTable([1, 3, 2, 0])
|
||||
# table2 = hnp.LookupTable([0, 2, 3, 1])
|
||||
#
|
||||
# multitable = hnp.MultiLookupTable(
|
||||
# [
|
||||
# [table1, table2, 4],
|
||||
# [table2, table1, table2],
|
||||
# ],
|
||||
# )
|
||||
raise ValueError(
|
||||
f"MultiLookupTable should have been made out of LookupTables "
|
||||
f"but it had an object of type {array.__class__.__name__} within it"
|
||||
)
|
||||
|
||||
table_sizes.append(len(array.table))
|
||||
table_output_dtypes.append(array.output_dtype)
|
||||
return
|
||||
|
||||
if not isinstance(array, list) or len(array) != shape[dimension]:
|
||||
# this branch is for such a case:
|
||||
#
|
||||
# table1 = hnp.LookupTable([1, 3, 2, 0])
|
||||
# table2 = hnp.LookupTable([0, 2, 3, 1])
|
||||
#
|
||||
# multitable = hnp.MultiLookupTable(
|
||||
# [
|
||||
# [table1, table2],
|
||||
# [table2, table1, table2],
|
||||
# ],
|
||||
# )
|
||||
raise ValueError(
|
||||
f"MultiLookupTable should have the shape {shape} but it does not "
|
||||
f"(an array on dimension {dimension} has the size {len(array)} "
|
||||
f"but its size should have been {shape[dimension]} "
|
||||
f"as the expected shape is {shape})"
|
||||
)
|
||||
|
||||
for item in array:
|
||||
MultiLookupTable._check_shape_and_record_luts(
|
||||
item,
|
||||
dimension + 1,
|
||||
shape,
|
||||
table_sizes,
|
||||
table_output_dtypes,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _checked_indexing(x, input_shape, tables):
|
||||
try:
|
||||
result = []
|
||||
for indices in itertools.product(*[range(dimension) for dimension in input_shape]):
|
||||
which_table_to_use = tables
|
||||
what_value_to_use = x
|
||||
where_to_append = result
|
||||
|
||||
for index in indices[:-1]:
|
||||
which_table_to_use = tables[index]
|
||||
what_value_to_use = x[index]
|
||||
|
||||
if len(where_to_append) == index:
|
||||
where_to_append.append([])
|
||||
where_to_append = result[index]
|
||||
|
||||
which_table_to_use = which_table_to_use[indices[-1]]
|
||||
what_value_to_use = what_value_to_use[indices[-1]]
|
||||
where_to_append.append(which_table_to_use[what_value_to_use])
|
||||
except Exception as error:
|
||||
raise ValueError(
|
||||
f"Multiple Lookup Table of shape {input_shape} cannot be looked up with {x} "
|
||||
f"(you should check your inputset)",
|
||||
) from error
|
||||
|
||||
return result
|
||||
@@ -1,118 +0,0 @@
|
||||
"""This file contains a wrapper class for direct table lookups."""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Iterable, List, Tuple, Union
|
||||
|
||||
from ..common_helpers import is_a_power_of_2
|
||||
from ..data_types.base import BaseDataType
|
||||
from ..data_types.integers import make_integer_to_hold
|
||||
from ..representation.intermediate import GenericFunction
|
||||
from ..tracing.base_tracer import BaseTracer
|
||||
|
||||
|
||||
class LookupTable:
|
||||
"""Class representing a lookup table."""
|
||||
|
||||
# lookup table itself, has 2^N entries
|
||||
table: Tuple[int, ...]
|
||||
|
||||
# type of the result of the lookup
|
||||
output_dtype: BaseDataType
|
||||
|
||||
def __init__(self, table: Iterable[int]):
|
||||
table = tuple(table)
|
||||
|
||||
if not is_a_power_of_2(len(table)):
|
||||
raise ValueError(
|
||||
f"Desired lookup table has inappropriate number of entries ({len(table)})"
|
||||
)
|
||||
|
||||
self.table = table
|
||||
self.output_dtype = make_integer_to_hold(table, force_signed=False)
|
||||
|
||||
def __repr__(self):
|
||||
return str(list(self.table))
|
||||
|
||||
def __getitem__(self, key: Union[int, Iterable, BaseTracer]):
|
||||
# if a tracer is used for indexing,
|
||||
# we need to create an `GenericFunction` node
|
||||
# because the result will be determined during the runtime
|
||||
if isinstance(key, BaseTracer):
|
||||
generic_function_output_value = deepcopy(key.output)
|
||||
generic_function_output_value.dtype = self.output_dtype
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[key.output],
|
||||
arbitrary_func=LookupTable._checked_indexing,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
op_kwargs={"table": deepcopy(self.table)},
|
||||
op_name="TLU",
|
||||
)
|
||||
return key.__class__(
|
||||
inputs=[key],
|
||||
traced_computation=traced_computation,
|
||||
output_idx=0,
|
||||
)
|
||||
|
||||
# if not, it means table is indexed with a constant
|
||||
# thus, the result of the lookup is a constant
|
||||
# so, we can propagate it directly
|
||||
return LookupTable._checked_indexing(key, self.table)
|
||||
|
||||
@staticmethod
|
||||
def _check_index_out_of_range(x, table):
|
||||
if not -len(table) <= x < len(table):
|
||||
raise ValueError(
|
||||
f"Lookup table with {len(table)} entries cannot be indexed with {x} "
|
||||
f"(you should check your inputset)",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _checked_indexing(x, table):
|
||||
"""Index `table` using `x`.
|
||||
|
||||
There is a single table and the indexing works with the following semantics:
|
||||
- when x == c
|
||||
- table[x] == table[c]
|
||||
- when x == [c1, c2]
|
||||
- table[x] == [table[c1], table[c2]]
|
||||
- when x == [[c1, c2], [c3, c4], [c5, c6]]
|
||||
- table[x] == [[table[c1], table[c2]], [table[c3], table[c4]], [table[c5], table[c6]]]
|
||||
|
||||
Args:
|
||||
x (Union[int, Iterable]): index to use
|
||||
table (Tuple[int, ...]): table to index
|
||||
|
||||
Returns:
|
||||
Union[int, List[int]]: result of indexing
|
||||
"""
|
||||
|
||||
if not isinstance(x, Iterable):
|
||||
LookupTable._check_index_out_of_range(x, table)
|
||||
return table[x]
|
||||
|
||||
def fill_result(partial_result: List[Any], partial_x: Iterable[Any]):
|
||||
"""Fill partial result with partial x.
|
||||
|
||||
This function implements the recursive indexing of nested iterables.
|
||||
|
||||
Args:
|
||||
partial_result (List[Any]): currently accumulated result
|
||||
partial_x (Iterable[Any]): current index to use
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
for item in partial_x:
|
||||
if isinstance(item, Iterable):
|
||||
partial_result.append([])
|
||||
fill_result(partial_result[-1], item)
|
||||
else:
|
||||
LookupTable._check_index_out_of_range(item, table)
|
||||
partial_result.append(table[item])
|
||||
|
||||
result = []
|
||||
fill_result(result, x)
|
||||
return result
|
||||
@@ -1,144 +0,0 @@
|
||||
"""Module to hold the result of compilation."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy
|
||||
from concrete.compiler import (
|
||||
ClientParameters,
|
||||
ClientSupport,
|
||||
CompilationOptions,
|
||||
JITCompilationResult,
|
||||
JITLambda,
|
||||
JITSupport,
|
||||
KeySet,
|
||||
KeySetCache,
|
||||
PublicArguments,
|
||||
PublicResult,
|
||||
)
|
||||
|
||||
from .debugging import draw_graph, format_operation_graph
|
||||
from .operator_graph import OPGraph
|
||||
|
||||
|
||||
class FHECircuit:
|
||||
"""Class which is the result of compilation."""
|
||||
|
||||
op_graph: OPGraph
|
||||
_jit_support: JITSupport
|
||||
_compilation_result: JITCompilationResult
|
||||
_client_parameters: ClientParameters
|
||||
_server_lambda: JITLambda
|
||||
_keyset_cache: KeySetCache
|
||||
_keyset: KeySet
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
op_graph: OPGraph,
|
||||
mlir_str: str,
|
||||
unsecure_key_set_cache_path: Optional[str] = None,
|
||||
auto_parallelize: bool = False,
|
||||
loop_parallelize: bool = False,
|
||||
dataflow_parallelize: bool = False,
|
||||
):
|
||||
self.op_graph = op_graph
|
||||
self._jit_support = JITSupport.new()
|
||||
# Set compilation options
|
||||
options = CompilationOptions.new("main")
|
||||
options.set_auto_parallelize(auto_parallelize)
|
||||
options.set_loop_parallelize(loop_parallelize)
|
||||
options.set_dataflow_parallelize(dataflow_parallelize)
|
||||
# Compile
|
||||
self._compilation_result = self._jit_support.compile(mlir_str, options)
|
||||
self._client_parameters = self._jit_support.load_client_parameters(self._compilation_result)
|
||||
self._server_lambda = self._jit_support.load_server_lambda(self._compilation_result)
|
||||
# Setup keyset cache
|
||||
self._keyset_cache = None
|
||||
if unsecure_key_set_cache_path:
|
||||
self._keyset_cache = KeySetCache.new(unsecure_key_set_cache_path)
|
||||
self._keyset = None
|
||||
|
||||
def __str__(self):
|
||||
return format_operation_graph(self.op_graph)
|
||||
|
||||
def draw(
|
||||
self,
|
||||
show: bool = False,
|
||||
vertical: bool = True,
|
||||
save_to: Optional[Path] = None,
|
||||
) -> str:
|
||||
"""Draw operation graph of the circuit and optionally save/show the drawing.
|
||||
|
||||
Args:
|
||||
show (bool): if set to True, the drawing will be shown using matplotlib
|
||||
vertical (bool): if set to True, the orientation will be vertical
|
||||
save_to (Optional[Path]): if specified, the drawn graph will be saved to this path;
|
||||
otherwise it will be saved to a temporary file
|
||||
|
||||
Returns:
|
||||
str: path of the file where the drawn graph is saved
|
||||
|
||||
"""
|
||||
|
||||
return draw_graph(self.op_graph, show, vertical, save_to)
|
||||
|
||||
def keygen(self, force: bool = False):
|
||||
"""Generate the keys required for the encrypted circuit.
|
||||
|
||||
Args:
|
||||
force (bool, optional): generate even if keyset already exists. Defaults to False.
|
||||
"""
|
||||
if self._keyset is None or force:
|
||||
self._keyset = ClientSupport.key_set(self._client_parameters, self._keyset_cache)
|
||||
|
||||
def encrypt(self, *args: Union[int, numpy.ndarray]) -> PublicArguments:
|
||||
"""Encrypt the inputs of the circuit.
|
||||
|
||||
Args:
|
||||
*args (Union[int, numpy.ndarray]): plain input of the circuit
|
||||
|
||||
Returns:
|
||||
PublicArguments: encrypted and plain arguments as well as public keys
|
||||
"""
|
||||
# Make sure keys are available: shouldn't regenerate if they already exist
|
||||
self.keygen(force=False)
|
||||
return ClientSupport.encrypt_arguments(self._client_parameters, self._keyset, args)
|
||||
|
||||
def run(self, args: PublicArguments) -> PublicResult:
|
||||
"""Evaluate the the encrypted circuit (no encryption or decryption involved).
|
||||
|
||||
Args:
|
||||
args (PublicArguments): encrypted inputs to the circuit
|
||||
|
||||
Returns:
|
||||
PublicResult: encrypted result
|
||||
"""
|
||||
return self._jit_support.server_call(self._server_lambda, args)
|
||||
|
||||
def decrypt(self, result: PublicResult) -> Union[int, numpy.ndarray]:
|
||||
"""Decrypt the result of the circuit.
|
||||
|
||||
Args:
|
||||
result (PublicResult): encrypted result of the circuit
|
||||
|
||||
Returns:
|
||||
Union[int, numpy.ndarray]: plain result of the circuit
|
||||
"""
|
||||
return ClientSupport.decrypt_result(self._keyset, result)
|
||||
|
||||
def encrypt_run_decrypt(self, *args: Union[int, numpy.ndarray]) -> Union[int, numpy.ndarray]:
|
||||
"""Encrypt, evaluate, and decrypt the inputs on the circuit.
|
||||
|
||||
Generate keyset automatically if not yet done.
|
||||
|
||||
Args:
|
||||
*args (Union[int, numpy.ndarray]): plain inputs of the circuit
|
||||
|
||||
Returns:
|
||||
Union[int, numpy.ndarray]: plain result of the circuit
|
||||
"""
|
||||
self.keygen(force=False)
|
||||
public_args = self.encrypt(*args)
|
||||
encrypted_result = self.run(public_args)
|
||||
decrypted_result = self.decrypt(encrypted_result)
|
||||
return decrypted_result
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Helpers for all kinds of tasks."""
|
||||
|
||||
from . import indexing_helpers, python_helpers
|
||||
@@ -1,47 +0,0 @@
|
||||
"""Helpers for formatting functionality."""
|
||||
|
||||
from typing import Any, Dict, Hashable
|
||||
|
||||
import numpy
|
||||
|
||||
from ..debugging.custom_assert import assert_true
|
||||
|
||||
SPECIAL_OBJECT_MAPPING: Dict[Any, str] = {
|
||||
numpy.float32: "float32",
|
||||
numpy.float64: "float64",
|
||||
numpy.int8: "int8",
|
||||
numpy.int16: "int16",
|
||||
numpy.int32: "int32",
|
||||
numpy.int64: "int64",
|
||||
numpy.uint8: "uint8",
|
||||
numpy.uint16: "uint16",
|
||||
numpy.uint32: "uint32",
|
||||
numpy.uint64: "uint64",
|
||||
}
|
||||
|
||||
|
||||
def format_constant(constant: Any, maximum_length: int = 45) -> str:
|
||||
"""Format a constant.
|
||||
|
||||
Args:
|
||||
constant (Any): the constant to format
|
||||
maximum_length (int): maximum length of the resulting string
|
||||
|
||||
Returns:
|
||||
str: the formatted constant
|
||||
"""
|
||||
|
||||
if isinstance(constant, Hashable) and constant in SPECIAL_OBJECT_MAPPING:
|
||||
return SPECIAL_OBJECT_MAPPING[constant]
|
||||
|
||||
# maximum_length should not be smaller than 7 characters because
|
||||
# the constant will be formatted to `x ... y`
|
||||
# where x and y are part of the constant and they are at least 1 character
|
||||
assert_true(maximum_length >= 7)
|
||||
|
||||
content = str(constant).replace("\n", "")
|
||||
if len(content) > maximum_length:
|
||||
from_start = (maximum_length - 5) // 2
|
||||
from_end = (maximum_length - 5) - from_start
|
||||
content = f"{content[:from_start]} ... {content[-from_end:]}"
|
||||
return content
|
||||
@@ -1,277 +0,0 @@
|
||||
"""Helpers for indexing functionality."""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
|
||||
def format_indexing_element(indexing_element: Union[int, slice]) -> str:
|
||||
"""Format an indexing element.
|
||||
|
||||
This is required mainly for slices. The reason is that string representation of slices
|
||||
are very long and verbose. To give an example, `x[:, 2:]` will have the following index
|
||||
`[slice(None, None, None), slice(2, None, None)]` if printed naively. With this helper,
|
||||
it will be formatted as `[:, 2:]`.
|
||||
|
||||
Args:
|
||||
indexing_element (Union[int, slice]): indexing element to be formatted
|
||||
|
||||
Returns:
|
||||
str: formatted element
|
||||
"""
|
||||
|
||||
result = ""
|
||||
if isinstance(indexing_element, slice):
|
||||
if indexing_element.start is not None:
|
||||
result += str(indexing_element.start)
|
||||
result += ":"
|
||||
if indexing_element.stop is not None:
|
||||
result += str(indexing_element.stop)
|
||||
if indexing_element.step is not None:
|
||||
result += ":"
|
||||
result += str(indexing_element.step)
|
||||
else:
|
||||
result += str(indexing_element)
|
||||
return result.replace("\n", " ")
|
||||
|
||||
|
||||
def validate_index(
|
||||
index: Union[int, slice, Tuple[Union[int, slice], ...]],
|
||||
) -> Tuple[Union[int, slice], ...]:
|
||||
"""Make sure index is valid and convert it to the tuple form.
|
||||
|
||||
For example in `x[2]`, `index` is passed as `2`.
|
||||
To make it easier to work with, this function converts index to `(2,)`.
|
||||
|
||||
Args:
|
||||
index (Union[int, slice, Tuple[Union[int, slice], ...]]): index to validate, improve
|
||||
and return
|
||||
|
||||
Returns:
|
||||
Tuple[Union[int, slice], ...]: validated and improved index
|
||||
"""
|
||||
|
||||
if not isinstance(index, tuple):
|
||||
index = (index,)
|
||||
|
||||
for indexing_element in index:
|
||||
valid = isinstance(indexing_element, (int, slice))
|
||||
|
||||
if isinstance(indexing_element, slice):
|
||||
if (
|
||||
not (indexing_element.start is None or isinstance(indexing_element.start, int))
|
||||
or not (indexing_element.stop is None or isinstance(indexing_element.stop, int))
|
||||
or not (indexing_element.step is None or isinstance(indexing_element.step, int))
|
||||
):
|
||||
valid = False
|
||||
|
||||
if not valid:
|
||||
raise TypeError(
|
||||
f"Only integers and integer slices can be used for indexing "
|
||||
f"but you tried to use {format_indexing_element(indexing_element)} for indexing"
|
||||
)
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def determine_output_shape(
|
||||
input_shape: Tuple[int, ...],
|
||||
index: Tuple[Union[int, slice], ...],
|
||||
) -> Tuple[int, ...]:
|
||||
"""Determine the output shape from the input shape and the index.
|
||||
|
||||
e.g., for `input_shape=(3, 2)` and `index=(:, 0)`, returns `(3,)`
|
||||
for `input_shape=(4, 3, 2)` and `index=(2:,)`, returns `(2, 3, 2)`
|
||||
|
||||
Args:
|
||||
input_shape (Tuple[int, ...]): shape of the input tensor that is indexed
|
||||
index (Tuple[Union[int, slice], ...]): desired and validated index
|
||||
|
||||
Returns:
|
||||
Tuple[int, ...]: shape of the result of indexing
|
||||
"""
|
||||
|
||||
indexing_elements = [format_indexing_element(indexing_element) for indexing_element in index]
|
||||
index_str = f"[{', '.join(indexing_elements)}]"
|
||||
|
||||
if len(index) > len(input_shape):
|
||||
raise ValueError(
|
||||
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
|
||||
f"as the index has more elements than the number of dimensions of the tensor"
|
||||
)
|
||||
|
||||
# indexing (3, 4, 5) with [1] is the same as indexing it with [1, :, :]
|
||||
# indexing (3, 4, 5) with [1, 2] is the same as indexing it with [1, 2, :]
|
||||
|
||||
# so let's replicate that behavior to make the rest of the code generic
|
||||
index += (slice(None, None, None),) * (len(input_shape) - len(index))
|
||||
|
||||
output_shape = []
|
||||
for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)):
|
||||
if isinstance(indexing_element, int): # indexing removes the dimension
|
||||
indexing_element = (
|
||||
indexing_element if indexing_element >= 0 else indexing_element + dimension_size
|
||||
)
|
||||
if not 0 <= indexing_element < dimension_size:
|
||||
raise ValueError(
|
||||
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
|
||||
f"because index is out of range for dimension {dimension}"
|
||||
)
|
||||
elif isinstance(indexing_element, slice): # indexing possibly shrinks the dimension
|
||||
output_shape.append(
|
||||
determine_new_dimension_size(
|
||||
indexing_element,
|
||||
dimension_size,
|
||||
dimension,
|
||||
input_shape,
|
||||
index_str,
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(output_shape)
|
||||
|
||||
|
||||
def sanitize_start_index(
|
||||
start: int,
|
||||
dimension_size: int,
|
||||
# the rest is used for detailed exception message
|
||||
dimension: int,
|
||||
input_shape: Tuple[int, ...],
|
||||
index_str: str,
|
||||
) -> int:
|
||||
"""Sanitize and check start index of a slice.
|
||||
|
||||
Args:
|
||||
start (int): start index being sanitized
|
||||
dimension_size (int): size of the dimension the slice is applied to
|
||||
dimension (int): index of the dimension being sliced (for better messages)
|
||||
input_shape (Tuple[int, ...]): shape of the whole input (for better messages)
|
||||
index_str (str): string representation of the whole index (for better messages)
|
||||
|
||||
Returns:
|
||||
int: sanitized start index
|
||||
"""
|
||||
|
||||
start = start if start >= 0 else start + dimension_size
|
||||
if not 0 <= start < dimension_size:
|
||||
raise ValueError(
|
||||
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
|
||||
f"because start index is out of range for dimension {dimension}"
|
||||
)
|
||||
return start
|
||||
|
||||
|
||||
def sanitize_stop_index(
|
||||
stop: int,
|
||||
dimension_size: int,
|
||||
# the rest is used for detailed exception message
|
||||
dimension: int,
|
||||
input_shape: Tuple[int, ...],
|
||||
index_str: str,
|
||||
) -> int:
|
||||
"""Sanitize and check stop index of a slice.
|
||||
|
||||
Args:
|
||||
stop (int): stop index being sanitized
|
||||
dimension_size (int): size of the dimension the slice is applied to
|
||||
dimension (int): index of the dimension being sliced (for better messages)
|
||||
input_shape (Tuple[int, ...]): shape of the whole input (for better messages)
|
||||
index_str (str): string representation of the whole index (for better messages)
|
||||
|
||||
Returns:
|
||||
int: sanitized stop index
|
||||
"""
|
||||
|
||||
stop = stop if stop >= 0 else stop + dimension_size
|
||||
if not 0 <= stop <= dimension_size:
|
||||
raise ValueError(
|
||||
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
|
||||
f"because stop index is out of range for dimension {dimension}"
|
||||
)
|
||||
return stop
|
||||
|
||||
|
||||
def determine_new_dimension_size(
|
||||
slice_: slice,
|
||||
dimension_size: int,
|
||||
# the rest is used for detailed exception message
|
||||
dimension: int,
|
||||
input_shape: Tuple[int, ...],
|
||||
index_str: str,
|
||||
) -> int:
|
||||
"""Determine the new size of a dimension from the old size and the slice applied to it.
|
||||
|
||||
e.g., for `slice_=1:4` and `dimension_size=5`, returns `3`
|
||||
for `slice_=::-1` and `dimension_size=5`, returns `5`
|
||||
|
||||
You may want to check this page to learn more about how this function works
|
||||
https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing
|
||||
|
||||
Args:
|
||||
slice_ (slice): slice being applied to the dimension
|
||||
dimension_size (int): size of the dimension the slice is applied to
|
||||
dimension (int): index of the dimension being sliced (for better messages)
|
||||
input_shape (Tuple[int, ...]): shape of the whole input (for better messages)
|
||||
index_str (str): string representation of the whole index (for better messages)
|
||||
|
||||
Returns:
|
||||
int: new size of the dimension
|
||||
"""
|
||||
|
||||
step = slice_.step if slice_.step is not None else 1
|
||||
|
||||
if step > 0:
|
||||
start = slice_.start if slice_.start is not None else 0
|
||||
stop = slice_.stop if slice_.stop is not None else dimension_size
|
||||
|
||||
start = sanitize_start_index(start, dimension_size, dimension, input_shape, index_str)
|
||||
stop = sanitize_stop_index(stop, dimension_size, dimension, input_shape, index_str)
|
||||
|
||||
if start >= stop:
|
||||
raise ValueError(
|
||||
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
|
||||
f"because start index is not less than stop index for dimension {dimension}"
|
||||
)
|
||||
|
||||
size_before_stepping = stop - start
|
||||
elif step < 0:
|
||||
start = slice_.start if slice_.start is not None else dimension_size - 1
|
||||
stop = slice_.stop
|
||||
|
||||
start = sanitize_start_index(start, dimension_size, dimension, input_shape, index_str)
|
||||
|
||||
if stop is None:
|
||||
# this is a weird case but it works as expected
|
||||
# the issue is that it's impossible to slice whole vector reversed
|
||||
# with a stop value different than none
|
||||
|
||||
# if `x.shape == (6,)` the only one that works is `x[::-1].shape == (6,)`
|
||||
# here is what doesn't work (and this is expected it's just weird)
|
||||
#
|
||||
# ...
|
||||
# `x[:-2:-1].shape == (1,)`
|
||||
# `x[:-1:-1].shape == (0,)` (note that this is a hard error for us)
|
||||
# `x[:0:-1].shape == (5,)`
|
||||
# `x[:1:-1].shape == (4,)`
|
||||
# ...
|
||||
|
||||
size_before_stepping = start + 1
|
||||
else:
|
||||
stop = sanitize_stop_index(stop, dimension_size, dimension, input_shape, index_str)
|
||||
|
||||
if stop >= start:
|
||||
raise ValueError(
|
||||
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
|
||||
f"because step is negative and "
|
||||
f"stop index is not less than start index for dimension {dimension}"
|
||||
)
|
||||
|
||||
size_before_stepping = start - stop
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Tensor of shape {input_shape} cannot be indexed with {index_str} "
|
||||
f"because step is zero for dimension {dimension}"
|
||||
)
|
||||
|
||||
quotient = size_before_stepping // abs(step)
|
||||
remainder = size_before_stepping % abs(step)
|
||||
|
||||
return quotient + (remainder != 0)
|
||||
@@ -1,35 +0,0 @@
|
||||
"""Common python helpers."""
|
||||
|
||||
from typing import Any, Callable, Iterable, Mapping, Tuple, Union
|
||||
|
||||
|
||||
def update_and_return_dict(
|
||||
dict_to_update: dict, update_values: Union[Mapping, Iterable[Tuple[Any, Any]]]
|
||||
) -> dict:
|
||||
"""Update a dictionary and return the ref to the dictionary that was updated.
|
||||
|
||||
Args:
|
||||
dict_to_update (dict): the dict to update
|
||||
update_values (Union[Mapping, Iterable[Tuple[Any, Any]]]): the values to update the dict
|
||||
with
|
||||
|
||||
Returns:
|
||||
dict: the dict that was just updated.
|
||||
"""
|
||||
dict_to_update.update(update_values)
|
||||
return dict_to_update
|
||||
|
||||
|
||||
def catch(func: Callable, *args, **kwargs) -> Union[Any, None]:
|
||||
"""Execute func by passing args and kwargs. Catch exceptions and return None in case of failure.
|
||||
|
||||
Args:
|
||||
func (Callable): function to execute and catch exceptions from
|
||||
|
||||
Returns:
|
||||
Union[Any, None]: the function result if there was no exception, None otherwise.
|
||||
"""
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
return None
|
||||
@@ -1,3 +0,0 @@
|
||||
"""MLIR conversion module."""
|
||||
|
||||
from .graph_converter import OPGraphConverter
|
||||
@@ -1,61 +0,0 @@
|
||||
"""Helpers for MLIR conversion functionality."""
|
||||
|
||||
# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
|
||||
|
||||
# pylint: disable=no-name-in-module
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from concrete.lang.dialects.fhe import EncryptedIntegerType
|
||||
from mlir.ir import Context, IntegerType, RankedTensorType, Type
|
||||
|
||||
from ..data_types import Integer
|
||||
from ..values import BaseValue, TensorValue
|
||||
|
||||
# pylint: enable=no-name-in-module
|
||||
|
||||
|
||||
def integer_to_mlir_type(ctx: Context, integer: Integer, is_encrypted: bool) -> Optional[Type]:
|
||||
"""Convert an integer to its corresponding MLIR type.
|
||||
|
||||
Args:
|
||||
ctx (Context): the MLIR context to perform the conversion
|
||||
integer (Integer): the integer to convert
|
||||
is_encrypted (bool): whether the integer is encrypted or not
|
||||
|
||||
Returns:
|
||||
Type:
|
||||
the MLIR type corresponding to given integer and encryption status
|
||||
if it's supported otherwise None
|
||||
"""
|
||||
|
||||
bit_width = integer.bit_width
|
||||
|
||||
if is_encrypted:
|
||||
result = EncryptedIntegerType.get(ctx, bit_width)
|
||||
else:
|
||||
result = IntegerType.get_signless(bit_width)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def value_to_mlir_type(ctx: Context, value: BaseValue) -> Type:
|
||||
"""Convert a value to its corresponding MLIR type.
|
||||
|
||||
Args:
|
||||
ctx (Context): the MLIR context to perform the conversion
|
||||
value (BaseValue): the value to convert
|
||||
|
||||
Returns:
|
||||
Type: the MLIR type corresponding to given value
|
||||
"""
|
||||
|
||||
dtype = value.dtype
|
||||
if isinstance(dtype, Integer):
|
||||
mlir_type = integer_to_mlir_type(ctx, dtype, value.is_encrypted)
|
||||
if isinstance(value, TensorValue):
|
||||
if not value.is_scalar:
|
||||
mlir_type = RankedTensorType.get(value.shape, mlir_type)
|
||||
return mlir_type
|
||||
|
||||
raise TypeError(f"{value} is not supported for MLIR conversion")
|
||||
@@ -1,138 +0,0 @@
|
||||
"""Module that provides OPGraph conversion functionality."""
|
||||
|
||||
# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
|
||||
|
||||
# pylint: disable=no-name-in-module
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import concrete.lang as concretelang
|
||||
import networkx as nx
|
||||
from mlir.dialects import builtin
|
||||
from mlir.ir import Context, InsertionPoint, Location, Module
|
||||
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import Input, IntermediateNode
|
||||
from .conversion_helpers import value_to_mlir_type
|
||||
from .node_converter import IntermediateNodeConverter
|
||||
|
||||
# pylint: enable=no-name-in-module
|
||||
|
||||
|
||||
class OPGraphConverter(ABC):
|
||||
"""Converter of OPGraph to MLIR."""
|
||||
|
||||
def convert(self, op_graph: OPGraph) -> str:
|
||||
"""Convert an operation graph to its corresponding MLIR representation.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): the operation graph to be converted
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to given operation graph
|
||||
"""
|
||||
|
||||
additional_conversion_info = self._generate_additional_info_dict(op_graph)
|
||||
|
||||
# There are no tensor +*- scalar operations in the compiler
|
||||
# But such operations are used commonly so we need to support them
|
||||
# So, we implemented some workarounds (pull request #970)
|
||||
# Once we have native support, this workaround shall be removed (issue #837)
|
||||
# (most changes in #970 shall be reverted)
|
||||
|
||||
# { node1: "%arg0", node2: "%0", node3: "%1" }
|
||||
nodes_to_mlir_names: Dict[IntermediateNode, str] = {}
|
||||
|
||||
# { "%arg0": "i5", "%0": "tensor<2x3x!FHE.eint<4>>" }
|
||||
mlir_names_to_mlir_types: Dict[str, str] = {}
|
||||
|
||||
# { "%0": ["%c1_i5"] } == for %0 we need to convert %c1_i5 to 1d tensor
|
||||
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] = {}
|
||||
|
||||
with Context() as ctx, Location.unknown():
|
||||
concretelang.register_dialects(ctx)
|
||||
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
parameters = [
|
||||
value_to_mlir_type(ctx, input_node.outputs[0])
|
||||
for input_node in op_graph.get_ordered_inputs()
|
||||
]
|
||||
|
||||
@builtin.FuncOp.from_py_func(*parameters)
|
||||
def main(*arg):
|
||||
ir_to_mlir = {}
|
||||
for arg_num, node in op_graph.input_nodes.items():
|
||||
ir_to_mlir[node] = arg[arg_num]
|
||||
|
||||
mlir_name = f"%arg{arg_num}"
|
||||
nodes_to_mlir_names[node] = mlir_name
|
||||
mlir_names_to_mlir_types[mlir_name] = str(parameters[arg_num])
|
||||
|
||||
for node in nx.topological_sort(op_graph.graph):
|
||||
if isinstance(node, Input):
|
||||
continue
|
||||
|
||||
preds = [ir_to_mlir[pred] for pred in op_graph.get_ordered_preds(node)]
|
||||
node_converter = IntermediateNodeConverter(
|
||||
ctx,
|
||||
op_graph,
|
||||
node,
|
||||
preds,
|
||||
nodes_to_mlir_names,
|
||||
mlir_names_to_mlir_types,
|
||||
scalar_to_1d_tensor_conversion_hacks,
|
||||
)
|
||||
ir_to_mlir[node] = node_converter.convert(additional_conversion_info)
|
||||
|
||||
results = (
|
||||
ir_to_mlir[output_node] for output_node in op_graph.get_ordered_outputs()
|
||||
)
|
||||
return results
|
||||
|
||||
module_lines_after_hacks_are_applied = []
|
||||
for line in str(module).split("\n"):
|
||||
mlir_name = line.split("=")[0].strip()
|
||||
if mlir_name not in scalar_to_1d_tensor_conversion_hacks:
|
||||
module_lines_after_hacks_are_applied.append(line)
|
||||
continue
|
||||
|
||||
to_be_replaced = scalar_to_1d_tensor_conversion_hacks[mlir_name]
|
||||
for arg_name in to_be_replaced:
|
||||
new_name = f"%hack_{mlir_name.replace('%', '')}_{arg_name.replace('%', '')}"
|
||||
mlir_type = mlir_names_to_mlir_types[arg_name]
|
||||
|
||||
hack_line = (
|
||||
f" {new_name} = tensor.from_elements {arg_name} : tensor<1x{mlir_type}>"
|
||||
)
|
||||
module_lines_after_hacks_are_applied.append(hack_line)
|
||||
|
||||
line = line.replace(arg_name, new_name)
|
||||
|
||||
new_arg_types = []
|
||||
|
||||
arg_types = line.split(":")[1].split("->")[0].strip()[1:-1]
|
||||
for arg in arg_types.split(", "):
|
||||
if arg.startswith("tensor"):
|
||||
new_arg_types.append(arg)
|
||||
else:
|
||||
new_arg_types.append(f"tensor<1x{arg}>")
|
||||
|
||||
line = line.replace(arg_types, ", ".join(new_arg_types))
|
||||
|
||||
module_lines_after_hacks_are_applied.append(line)
|
||||
|
||||
return "\n".join(module_lines_after_hacks_are_applied)
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]:
|
||||
"""Generate additional conversion info dict for the MLIR converter.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): the operation graph from which the additional info will be generated
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: dict of additional conversion info
|
||||
"""
|
||||
@@ -1,873 +0,0 @@
|
||||
"""Module that provides IntermediateNode conversion functionality."""
|
||||
|
||||
# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
|
||||
|
||||
# pylint: disable=no-name-in-module
|
||||
|
||||
from typing import Any, Dict, List, Tuple, cast
|
||||
|
||||
import numpy
|
||||
from concrete.lang.dialects import fhe, fhelinalg
|
||||
from mlir.dialects import arith, linalg, tensor
|
||||
from mlir.ir import (
|
||||
ArrayAttr,
|
||||
Attribute,
|
||||
BoolAttr,
|
||||
Context,
|
||||
DenseElementsAttr,
|
||||
IndexType,
|
||||
IntegerAttr,
|
||||
IntegerType,
|
||||
OpResult,
|
||||
RankedTensorType,
|
||||
)
|
||||
|
||||
from ..data_types import Integer
|
||||
from ..debugging import assert_true
|
||||
from ..helpers.indexing_helpers import determine_new_dimension_size
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import (
|
||||
Add,
|
||||
Constant,
|
||||
Conv2D,
|
||||
Dot,
|
||||
GenericFunction,
|
||||
IndexConstant,
|
||||
IntermediateNode,
|
||||
MatMul,
|
||||
Mul,
|
||||
Sub,
|
||||
)
|
||||
from ..values import TensorValue
|
||||
from .conversion_helpers import integer_to_mlir_type, value_to_mlir_type
|
||||
|
||||
# pylint: enable=no-name-in-module
|
||||
|
||||
|
||||
class IntermediateNodeConverter:
|
||||
"""Converter of IntermediateNode to MLIR."""
|
||||
|
||||
ctx: Context
|
||||
op_graph: OPGraph
|
||||
node: IntermediateNode
|
||||
preds: List[OpResult]
|
||||
|
||||
all_of_the_inputs_are_encrypted: bool
|
||||
all_of_the_inputs_are_tensors: bool
|
||||
one_of_the_inputs_is_a_tensor: bool
|
||||
|
||||
nodes_to_mlir_names: Dict[IntermediateNode, str]
|
||||
mlir_names_to_mlir_types: Dict[str, str]
|
||||
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ctx: Context,
|
||||
op_graph: OPGraph,
|
||||
node: IntermediateNode,
|
||||
preds: List[OpResult],
|
||||
nodes_to_mlir_names: Dict[OpResult, str],
|
||||
mlir_names_to_mlir_types: Dict[str, str],
|
||||
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]],
|
||||
):
|
||||
self.ctx = ctx
|
||||
self.op_graph = op_graph
|
||||
self.node = node
|
||||
self.preds = preds
|
||||
|
||||
self.all_of_the_inputs_are_encrypted = True
|
||||
self.all_of_the_inputs_are_tensors = True
|
||||
self.one_of_the_inputs_is_a_tensor = False
|
||||
|
||||
for inp in node.inputs:
|
||||
if inp.is_clear:
|
||||
self.all_of_the_inputs_are_encrypted = False
|
||||
|
||||
if isinstance(inp, TensorValue):
|
||||
if inp.is_scalar:
|
||||
self.all_of_the_inputs_are_tensors = False
|
||||
else:
|
||||
self.one_of_the_inputs_is_a_tensor = True
|
||||
else: # pragma: no cover
|
||||
# this branch is not covered as there are only TensorValues for now
|
||||
self.all_of_the_inputs_are_tensors = False
|
||||
|
||||
self.nodes_to_mlir_names = nodes_to_mlir_names
|
||||
self.mlir_names_to_mlir_types = mlir_names_to_mlir_types
|
||||
self.scalar_to_1d_tensor_conversion_hacks = scalar_to_1d_tensor_conversion_hacks
|
||||
|
||||
def convert(self, additional_conversion_info: Dict[str, Any]) -> OpResult:
|
||||
"""Convert an intermediate node to its corresponding MLIR representation.
|
||||
|
||||
Args:
|
||||
additional_conversion_info (Dict[str, Any]):
|
||||
external info that the converted node might need
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-branches
|
||||
|
||||
if isinstance(self.node, Add):
|
||||
result = self.convert_add()
|
||||
|
||||
elif isinstance(self.node, Constant):
|
||||
result = self.convert_constant()
|
||||
|
||||
elif isinstance(self.node, Dot):
|
||||
result = self.convert_dot()
|
||||
|
||||
elif isinstance(self.node, GenericFunction):
|
||||
if self.node.op_name in ["flatten", "reshape"]:
|
||||
# notice flatten() == reshape(-1) and convert_reshape can handle that
|
||||
result = self.convert_reshape()
|
||||
elif self.node.op_name == "sum":
|
||||
result = self.convert_sum()
|
||||
elif self.node.op_name == "concat":
|
||||
result = self.convert_concat()
|
||||
elif self.node.op_name == "transpose":
|
||||
result = self.convert_transpose()
|
||||
else:
|
||||
result = self.convert_generic_function(additional_conversion_info)
|
||||
|
||||
elif isinstance(self.node, IndexConstant):
|
||||
result = self.convert_index_constant()
|
||||
|
||||
elif isinstance(self.node, MatMul):
|
||||
result = self.convert_matmul()
|
||||
|
||||
elif isinstance(self.node, Mul):
|
||||
result = self.convert_mul()
|
||||
|
||||
elif isinstance(self.node, Sub):
|
||||
result = self.convert_sub()
|
||||
|
||||
elif isinstance(self.node, Conv2D):
|
||||
result = self.convert_conv2d()
|
||||
|
||||
else: # pragma: no cover
|
||||
# this branch is not covered as unsupported opeations fail on check mlir compatibility
|
||||
raise NotImplementedError(f"{type(self.node)} nodes cannot be converted to MLIR yet")
|
||||
|
||||
# pylint: enable=too-many-branches
|
||||
|
||||
mlir_name = str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip()
|
||||
|
||||
self.nodes_to_mlir_names[self.node] = mlir_name
|
||||
self.mlir_names_to_mlir_types[mlir_name] = str(result.type)
|
||||
|
||||
if isinstance(self.node, (Add, Mul, Sub, Dot)):
|
||||
if self.one_of_the_inputs_is_a_tensor and not self.all_of_the_inputs_are_tensors:
|
||||
to_be_converted = []
|
||||
for (pred, output) in self.op_graph.get_ordered_preds_and_inputs_of(self.node):
|
||||
inp = pred.outputs[output]
|
||||
if isinstance(inp, TensorValue) and inp.is_scalar:
|
||||
to_be_converted.append(self.nodes_to_mlir_names[pred])
|
||||
self.scalar_to_1d_tensor_conversion_hacks[mlir_name] = to_be_converted
|
||||
|
||||
return result
|
||||
|
||||
def convert_add(self) -> OpResult:
|
||||
"""Convert an Add node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
assert_true(len(self.node.inputs) == 2)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
|
||||
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
preds = self.preds
|
||||
|
||||
if self.all_of_the_inputs_are_encrypted:
|
||||
if self.one_of_the_inputs_is_a_tensor:
|
||||
result = fhelinalg.AddEintOp(resulting_type, *preds).result
|
||||
else:
|
||||
result = fhe.AddEintOp(resulting_type, *preds).result
|
||||
else:
|
||||
if self.node.inputs[0].is_clear: # pragma: no cover
|
||||
# this branch is not covered as it's impossible to get into due to how tracing works
|
||||
# however, it doesn't hurt to keep it as an extra measure
|
||||
preds = preds[::-1]
|
||||
|
||||
if self.one_of_the_inputs_is_a_tensor:
|
||||
result = fhelinalg.AddEintIntOp(resulting_type, *preds).result
|
||||
else:
|
||||
result = fhe.AddEintIntOp(resulting_type, *preds).result
|
||||
|
||||
return result
|
||||
|
||||
def convert_concat(self) -> OpResult:
|
||||
"""Convert a "concat" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
assert_true(len(self.node.inputs) >= 2)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
|
||||
node = cast(GenericFunction, self.node)
|
||||
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
|
||||
axis = node.op_kwargs.get("axis", 0)
|
||||
if axis is not None:
|
||||
if axis < 0:
|
||||
axis += len(cast(TensorValue, self.node.inputs[0]).shape)
|
||||
return fhelinalg.ConcatOp(
|
||||
resulting_type,
|
||||
self.preds,
|
||||
IntegerAttr.get(IntegerType.get_signless(64), axis),
|
||||
).result
|
||||
|
||||
flattened_preds = []
|
||||
for pred, input_value in zip(self.preds, self.node.inputs):
|
||||
input_shape = cast(TensorValue, input_value).shape
|
||||
input_size = numpy.prod(input_shape)
|
||||
input_dtype = cast(Integer, input_value.dtype)
|
||||
|
||||
flattened_pred_type = RankedTensorType.get(
|
||||
[input_size],
|
||||
integer_to_mlir_type(self.ctx, input_dtype, input_value.is_encrypted),
|
||||
)
|
||||
flattened_pred = linalg.TensorCollapseShapeOp(
|
||||
flattened_pred_type,
|
||||
pred,
|
||||
ArrayAttr.get(
|
||||
[
|
||||
ArrayAttr.get(
|
||||
[
|
||||
IntegerAttr.get(IndexType.parse("index"), i)
|
||||
for i in range(len(input_shape))
|
||||
]
|
||||
)
|
||||
]
|
||||
),
|
||||
).result
|
||||
flattened_preds.append(flattened_pred)
|
||||
|
||||
return fhelinalg.ConcatOp(
|
||||
resulting_type,
|
||||
flattened_preds,
|
||||
IntegerAttr.get(IntegerType.get_signless(64), 0),
|
||||
).result
|
||||
|
||||
def convert_constant(self) -> OpResult:
|
||||
"""Convert a Constant node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
assert_true(len(self.node.inputs) == 0)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
|
||||
value = self.node.outputs[0]
|
||||
if not isinstance(value, TensorValue): # pragma: no cover
|
||||
# this branch is not covered as there are only TensorValues for now
|
||||
raise NotImplementedError(f"{value} constants cannot be converted to MLIR yet")
|
||||
|
||||
resulting_type = value_to_mlir_type(self.ctx, value)
|
||||
data = cast(Constant, self.node).constant_data
|
||||
|
||||
if value.is_scalar:
|
||||
attr = IntegerAttr.get(resulting_type, data)
|
||||
else:
|
||||
# usage of `Attribute.parse` is the result of some limitations in the MLIR module
|
||||
# provided by LLVM
|
||||
|
||||
# what should have been used is `DenseElementsAttr` but it's impossible to assign
|
||||
# custom bit-widths using it (e.g., uint5)
|
||||
|
||||
# since we coudn't create a `DenseElementsAttr` with a custom bit width using python api
|
||||
# we use `Attribute.parse` to let the underlying library do it by itself
|
||||
|
||||
attr = Attribute.parse(f"dense<{str(data.tolist())}> : {resulting_type}")
|
||||
|
||||
return arith.ConstantOp(resulting_type, attr).result
|
||||
|
||||
def convert_conv2d(self) -> OpResult:
|
||||
"""Convert a Conv2D node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
assert_true(len(self.node.inputs) == 2 or len(self.node.inputs) == 3)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
has_bias = len(self.node.inputs) == 3
|
||||
|
||||
x = self.node.inputs[0]
|
||||
weight = self.node.inputs[1]
|
||||
if not (x.is_encrypted and weight.is_clear): # pragma: no cover
|
||||
raise NotImplementedError(
|
||||
f"Conv2D with input {x} and weight {weight} cannot be converted to MLIR yet",
|
||||
)
|
||||
|
||||
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
preds = self.preds
|
||||
|
||||
node = cast(Conv2D, self.node)
|
||||
integer_type = IntegerType.get_signless(64, context=self.ctx)
|
||||
strides = DenseElementsAttr.get(
|
||||
numpy.array(list(node.strides), dtype=numpy.uint64),
|
||||
context=self.ctx,
|
||||
type=integer_type,
|
||||
)
|
||||
dilations = DenseElementsAttr.get(
|
||||
numpy.array(list(node.dilations), dtype=numpy.uint64),
|
||||
context=self.ctx,
|
||||
type=integer_type,
|
||||
)
|
||||
pads = DenseElementsAttr.get(
|
||||
numpy.array(list(node.pads), dtype=numpy.uint64), context=self.ctx, type=integer_type
|
||||
)
|
||||
if has_bias:
|
||||
result = fhelinalg.Conv2dOp(resulting_type, *preds, pads, strides, dilations).result
|
||||
else:
|
||||
result = fhelinalg.Conv2dOp(
|
||||
resulting_type, *preds, None, pads, strides, dilations
|
||||
).result
|
||||
|
||||
return result
|
||||
|
||||
def convert_dot(self) -> OpResult:
|
||||
"""Convert a Dot node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
assert_true(len(self.node.inputs) == 2)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
|
||||
if self.all_of_the_inputs_are_encrypted:
|
||||
lhs = self.node.inputs[0]
|
||||
rhs = self.node.inputs[1]
|
||||
raise NotImplementedError(
|
||||
f"Dot product between {lhs} and {rhs} cannot be converted to MLIR yet",
|
||||
)
|
||||
|
||||
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
preds = self.preds
|
||||
|
||||
if self.node.inputs[0].is_clear:
|
||||
preds = preds[::-1]
|
||||
|
||||
if self.all_of_the_inputs_are_tensors:
|
||||
# numpy.dot(x, y) where x and y are both vectors = regular dot product
|
||||
result = fhelinalg.Dot(resulting_type, *preds).result
|
||||
|
||||
elif not self.one_of_the_inputs_is_a_tensor:
|
||||
# numpy.dot(x, y) where x and y are both scalars = x * y
|
||||
result = fhe.MulEintIntOp(resulting_type, *preds).result
|
||||
|
||||
else:
|
||||
# numpy.dot(x, y) where one of x or y is a scalar and the other one is a vector = x * y
|
||||
result = fhelinalg.MulEintIntOp(resulting_type, *preds).result
|
||||
|
||||
return result
|
||||
|
||||
def convert_generic_function(self, additional_conversion_info: Dict[str, Any]) -> OpResult:
|
||||
"""Convert a GenericFunction node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
variable_input_indices = [
|
||||
idx
|
||||
for idx, inp in enumerate(self.op_graph.get_ordered_preds(self.node))
|
||||
if not isinstance(inp, Constant)
|
||||
]
|
||||
if len(variable_input_indices) != 1: # pragma: no cover
|
||||
# this branch is not covered as it's impossible to get into due to how tracing works
|
||||
# however, it doesn't hurt to keep it as an extra measure
|
||||
raise NotImplementedError(
|
||||
"Table lookups with more than one variable input cannot be converted to MLIR yet"
|
||||
)
|
||||
variable_input_index = variable_input_indices[0]
|
||||
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
output = self.node.outputs[0]
|
||||
|
||||
value = self.node.inputs[variable_input_index]
|
||||
assert_true(value.is_encrypted)
|
||||
|
||||
if not isinstance(value.dtype, Integer): # pragma: no cover
|
||||
# this branch is not covered as it's impossible to get into due to how compilation works
|
||||
# however, it doesn't hurt to keep it as an extra measure
|
||||
raise NotImplementedError(f"Table lookup on {value} cannot be converted to MLIR yet")
|
||||
|
||||
tables = additional_conversion_info["tables"][self.node]
|
||||
assert_true(len(tables) > 0)
|
||||
|
||||
lut_shape: Tuple[int, ...] = ()
|
||||
map_shape: Tuple[int, ...] = ()
|
||||
|
||||
if len(tables) == 1:
|
||||
table = tables[0][0]
|
||||
|
||||
# The reduction on 63b is to avoid problems like doing a TLU of
|
||||
# the form T[j] = 2<<j, for j which is supposed to be 7b as per
|
||||
# constraint of the compiler, while in practice, it is a small
|
||||
# value. Reducing on 64b was not ok for some reason
|
||||
lut_shape = (len(table),)
|
||||
lut_values = numpy.array(table % (2 << 63), dtype=numpy.uint64)
|
||||
|
||||
map_shape = ()
|
||||
map_values = None
|
||||
else:
|
||||
assert_true(isinstance(output, TensorValue))
|
||||
assert isinstance(output, TensorValue)
|
||||
|
||||
individual_table_size = len(tables[0][0])
|
||||
|
||||
lut_shape = (len(tables), individual_table_size)
|
||||
map_shape = output.shape
|
||||
|
||||
lut_values = numpy.zeros(lut_shape, dtype=numpy.uint64)
|
||||
map_values = numpy.zeros(map_shape, dtype=numpy.intp)
|
||||
|
||||
for i, (table, indices) in enumerate(tables):
|
||||
assert_true(len(table) == individual_table_size)
|
||||
lut_values[i, :] = table
|
||||
for index in indices:
|
||||
map_values[index] = i
|
||||
|
||||
lut_type = RankedTensorType.get(lut_shape, IntegerType.get_signless(64, context=self.ctx))
|
||||
lut_attr = DenseElementsAttr.get(lut_values, context=self.ctx)
|
||||
lut = arith.ConstantOp(lut_type, lut_attr).result
|
||||
|
||||
resulting_type = value_to_mlir_type(self.ctx, output)
|
||||
pred = self.preds[variable_input_index]
|
||||
|
||||
if self.one_of_the_inputs_is_a_tensor:
|
||||
if len(tables) == 1:
|
||||
result = fhelinalg.ApplyLookupTableEintOp(resulting_type, pred, lut).result
|
||||
else:
|
||||
assert_true(map_shape != ())
|
||||
assert_true(map_values is not None)
|
||||
|
||||
index_type = IndexType.parse("index")
|
||||
map_type = RankedTensorType.get(map_shape, index_type)
|
||||
map_attr = DenseElementsAttr.get(map_values, context=self.ctx, type=index_type)
|
||||
|
||||
result = fhelinalg.ApplyMappedLookupTableEintOp(
|
||||
resulting_type,
|
||||
pred,
|
||||
lut,
|
||||
arith.ConstantOp(map_type, map_attr).result,
|
||||
).result
|
||||
else:
|
||||
result = fhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result
|
||||
|
||||
return result
|
||||
|
||||
def convert_index_constant(self) -> OpResult:
|
||||
"""Convert a IndexConstant node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-locals
|
||||
|
||||
assert_true(len(self.node.inputs) == 1)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
|
||||
tensor_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
pred = self.preds[0]
|
||||
|
||||
input_value = cast(TensorValue, self.node.inputs[0])
|
||||
input_shape = input_value.shape
|
||||
|
||||
index = cast(IndexConstant, self.node).index
|
||||
index_str = self.node.text_for_formatting([""], 0)
|
||||
|
||||
index_type = IndexType.parse("index")
|
||||
|
||||
if len(index) == len(input_shape) and all(isinstance(i, int) for i in index):
|
||||
indices = []
|
||||
for value, dimension_size in zip(index, input_shape):
|
||||
assert isinstance(value, int) # mypy
|
||||
attr = IntegerAttr.get(index_type, value if value >= 0 else value + dimension_size)
|
||||
indices.append(arith.ConstantOp(index_type, attr).result)
|
||||
return tensor.ExtractOp(tensor_type, pred, indices).result
|
||||
|
||||
offsets = []
|
||||
sizes = []
|
||||
strides = []
|
||||
|
||||
destroyed_dimensions = []
|
||||
for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)):
|
||||
|
||||
if isinstance(indexing_element, int):
|
||||
destroyed_dimensions.append(dimension)
|
||||
size = 1
|
||||
stride = 1
|
||||
offset = (
|
||||
indexing_element if indexing_element >= 0 else indexing_element + dimension_size
|
||||
)
|
||||
|
||||
elif isinstance(indexing_element, slice):
|
||||
size = determine_new_dimension_size(
|
||||
indexing_element,
|
||||
dimension_size,
|
||||
dimension,
|
||||
input_shape,
|
||||
index_str,
|
||||
)
|
||||
stride = indexing_element.step if isinstance(indexing_element.step, int) else 1
|
||||
offset = (
|
||||
(
|
||||
indexing_element.start
|
||||
if indexing_element.start >= 0
|
||||
else indexing_element.start + dimension_size
|
||||
)
|
||||
if isinstance(indexing_element.start, int)
|
||||
else (0 if stride > 0 else dimension_size - 1)
|
||||
)
|
||||
|
||||
else: # pragma: no cover
|
||||
# this branch is impossible to reach with all the previous checks
|
||||
# but let's keep it as an extra measure
|
||||
raise NotImplementedError(
|
||||
f"Indexing of {input_value} with {index_str} cannot be converted to MLIR",
|
||||
)
|
||||
|
||||
offsets.append(offset)
|
||||
sizes.append(size)
|
||||
strides.append(stride)
|
||||
|
||||
if len(destroyed_dimensions) == 0:
|
||||
return tensor.ExtractSliceOp(
|
||||
tensor_type,
|
||||
pred,
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]),
|
||||
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]),
|
||||
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]),
|
||||
).result
|
||||
|
||||
output_value = cast(TensorValue, self.node.outputs[0])
|
||||
|
||||
intermediate_shape = list(output_value.shape)
|
||||
for dimension in destroyed_dimensions:
|
||||
intermediate_shape.insert(dimension, 1)
|
||||
|
||||
intermediate_type = RankedTensorType.get(
|
||||
intermediate_shape,
|
||||
integer_to_mlir_type(
|
||||
self.ctx,
|
||||
cast(Integer, output_value.dtype),
|
||||
output_value.is_encrypted,
|
||||
),
|
||||
)
|
||||
|
||||
intermediate = tensor.ExtractSliceOp(
|
||||
intermediate_type,
|
||||
pred,
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]),
|
||||
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]),
|
||||
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]),
|
||||
).result
|
||||
|
||||
reassociaton = []
|
||||
|
||||
current_intermediate_dimension = 0
|
||||
for _ in range(len(output_value.shape)):
|
||||
indices = [current_intermediate_dimension]
|
||||
while current_intermediate_dimension in destroyed_dimensions:
|
||||
current_intermediate_dimension += 1
|
||||
indices.append(current_intermediate_dimension)
|
||||
|
||||
reassociaton.append(indices)
|
||||
current_intermediate_dimension += 1
|
||||
while current_intermediate_dimension < len(intermediate_shape):
|
||||
reassociaton[-1].append(current_intermediate_dimension)
|
||||
current_intermediate_dimension += 1
|
||||
|
||||
return linalg.TensorCollapseShapeOp(
|
||||
tensor_type,
|
||||
intermediate,
|
||||
ArrayAttr.get(
|
||||
[
|
||||
ArrayAttr.get(
|
||||
[IntegerAttr.get(index_type, index) for index in indices],
|
||||
)
|
||||
for indices in reassociaton
|
||||
],
|
||||
),
|
||||
).result
|
||||
|
||||
# pylint: enable=too-many-locals
|
||||
|
||||
def convert_matmul(self) -> OpResult:
|
||||
"""Convert a MatMul node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
assert_true(len(self.node.inputs) == 2)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
|
||||
if self.all_of_the_inputs_are_encrypted:
|
||||
lhs = self.node.inputs[0]
|
||||
rhs = self.node.inputs[1]
|
||||
raise NotImplementedError(
|
||||
f"Matrix multiplication between {lhs} and {rhs} cannot be converted to MLIR yet",
|
||||
)
|
||||
|
||||
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
preds = self.preds
|
||||
|
||||
assert isinstance(self.node.outputs[0], TensorValue)
|
||||
if self.node.outputs[0].shape == ():
|
||||
if self.node.inputs[0].is_clear:
|
||||
preds = preds[::-1]
|
||||
result = fhelinalg.Dot(resulting_type, *preds).result
|
||||
|
||||
elif self.node.inputs[0].is_clear:
|
||||
result = fhelinalg.MatMulIntEintOp(resulting_type, *preds).result
|
||||
else:
|
||||
result = fhelinalg.MatMulEintIntOp(resulting_type, *preds).result
|
||||
|
||||
return result
|
||||
|
||||
def convert_mul(self) -> OpResult:
|
||||
"""Convert a Mul node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
assert_true(len(self.node.inputs) == 2)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
|
||||
if self.all_of_the_inputs_are_encrypted:
|
||||
lhs = self.node.inputs[0]
|
||||
rhs = self.node.inputs[1]
|
||||
raise NotImplementedError(
|
||||
f"Multiplication between {lhs} and {rhs} cannot be converted to MLIR yet",
|
||||
)
|
||||
|
||||
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
preds = self.preds
|
||||
|
||||
if self.node.inputs[0].is_clear: # pragma: no cover
|
||||
# this branch is not covered as it's impossible to get into due to how tracing works
|
||||
# however, it doesn't hurt to keep it as an extra measure
|
||||
preds = preds[::-1]
|
||||
|
||||
if self.one_of_the_inputs_is_a_tensor:
|
||||
result = fhelinalg.MulEintIntOp(resulting_type, *preds).result
|
||||
else:
|
||||
result = fhe.MulEintIntOp(resulting_type, *preds).result
|
||||
|
||||
return result
|
||||
|
||||
def convert_reshape(self) -> OpResult:
|
||||
"""Convert a "reshape" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
assert_true(len(self.node.inputs) == 1)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
|
||||
assert_true(isinstance(self.node.inputs[0], TensorValue))
|
||||
input_shape = cast(TensorValue, self.node.inputs[0]).shape
|
||||
|
||||
assert_true(isinstance(self.node.outputs[0], TensorValue))
|
||||
output_shape = cast(TensorValue, self.node.outputs[0]).shape
|
||||
|
||||
pred = self.preds[0]
|
||||
if input_shape == output_shape:
|
||||
return pred
|
||||
|
||||
# we can either collapse or expand, which changes the number of dimensions
|
||||
# this is a limitation of the current compiler and it will be improved in the future (#1060)
|
||||
can_be_converted_directly = len(input_shape) != len(output_shape)
|
||||
|
||||
reassociation: List[List[int]] = []
|
||||
if can_be_converted_directly:
|
||||
if len(output_shape) == 1:
|
||||
# output is 1 dimensional so collapse every dimension into the same dimension
|
||||
reassociation.append(list(range(len(input_shape))))
|
||||
else:
|
||||
# input is m dimensional
|
||||
# output is n dimensional
|
||||
# and m is different than n
|
||||
|
||||
# we don't want to duplicate code so we forget about input and output
|
||||
# and we focus on smaller shape and bigger shape
|
||||
|
||||
smaller_shape, bigger_shape = (
|
||||
(output_shape, input_shape)
|
||||
if len(output_shape) < len(input_shape)
|
||||
else (input_shape, output_shape)
|
||||
)
|
||||
s_index, b_index = 0, 0
|
||||
|
||||
# now we will figure out how to group the bigger shape to get the smaller shape
|
||||
# think of the algorithm below as
|
||||
# keep merging the dimensions of the bigger shape
|
||||
# until we have a match on the smaller shape
|
||||
# then try to match the next dimension of the smaller shape
|
||||
# if all dimensions of the smaller shape is matched
|
||||
# we can convert it
|
||||
|
||||
group = []
|
||||
size = 1
|
||||
while s_index < len(smaller_shape) and b_index < len(bigger_shape):
|
||||
# dimension `b_index` of `bigger_shape` belongs to current group
|
||||
group.append(b_index)
|
||||
|
||||
# and current group has `size * bigger_shape[b_index]` elements now
|
||||
size *= bigger_shape[b_index]
|
||||
|
||||
# if current group size matches the dimension `s_index` of `smaller_shape`
|
||||
if size == smaller_shape[s_index]:
|
||||
# we finalize this group and reset everything
|
||||
size = 1
|
||||
reassociation.append(group)
|
||||
group = []
|
||||
|
||||
# now try to match the next dimension of `smaller_shape`
|
||||
s_index += 1
|
||||
|
||||
# now process the next dimension of `bigger_shape`
|
||||
b_index += 1
|
||||
|
||||
# handle the case where bigger shape has proceeding 1s
|
||||
# e.g., (5,) -> (5, 1)
|
||||
while b_index < len(bigger_shape) and bigger_shape[b_index] == 1:
|
||||
reassociation[-1].append(b_index)
|
||||
b_index += 1
|
||||
|
||||
# if not all dimensions of both shapes are processed exactly
|
||||
if s_index != len(smaller_shape) or b_index != len(bigger_shape):
|
||||
# we cannot convert
|
||||
can_be_converted_directly = False
|
||||
|
||||
index_type = IndexType.parse("index")
|
||||
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
|
||||
if can_be_converted_directly:
|
||||
reassociation_attr = ArrayAttr.get(
|
||||
[
|
||||
ArrayAttr.get([IntegerAttr.get(index_type, dimension) for dimension in group])
|
||||
for group in reassociation
|
||||
]
|
||||
)
|
||||
if len(output_shape) < len(input_shape):
|
||||
return linalg.TensorCollapseShapeOp(resulting_type, pred, reassociation_attr).result
|
||||
return linalg.TensorExpandShapeOp(resulting_type, pred, reassociation_attr).result
|
||||
|
||||
flattened_type = value_to_mlir_type(
|
||||
self.ctx,
|
||||
TensorValue(
|
||||
self.node.inputs[0].dtype,
|
||||
self.node.inputs[0].is_encrypted,
|
||||
(numpy.prod(input_shape),),
|
||||
),
|
||||
)
|
||||
flattened_result = linalg.TensorCollapseShapeOp(
|
||||
flattened_type,
|
||||
pred,
|
||||
ArrayAttr.get(
|
||||
[ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(input_shape))])]
|
||||
),
|
||||
).result
|
||||
|
||||
return linalg.TensorExpandShapeOp(
|
||||
resulting_type,
|
||||
flattened_result,
|
||||
ArrayAttr.get(
|
||||
[ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(output_shape))])]
|
||||
),
|
||||
).result
|
||||
|
||||
def convert_sub(self) -> OpResult:
|
||||
"""Convert a Sub node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
assert_true(len(self.node.inputs) == 2)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
|
||||
lhs = self.node.inputs[0]
|
||||
rhs = self.node.inputs[1]
|
||||
if not (lhs.is_clear and rhs.is_encrypted):
|
||||
raise NotImplementedError(
|
||||
f"Subtraction of {rhs} from {lhs} cannot be converted to MLIR yet",
|
||||
)
|
||||
|
||||
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
preds = self.preds
|
||||
|
||||
if self.one_of_the_inputs_is_a_tensor:
|
||||
result = fhelinalg.SubIntEintOp(resulting_type, *preds).result
|
||||
else:
|
||||
result = fhe.SubIntEintOp(resulting_type, *preds).result
|
||||
|
||||
return result
|
||||
|
||||
def convert_sum(self) -> OpResult:
|
||||
"""Convert a "sum" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
assert_true(len(self.node.inputs) == 1)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
|
||||
node = cast(GenericFunction, self.node)
|
||||
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
|
||||
axes = node.op_kwargs.get("axis", [])
|
||||
keep_dims = node.op_kwargs.get("keepdims", False)
|
||||
|
||||
if isinstance(axes, int):
|
||||
axes = [axes]
|
||||
elif isinstance(axes, tuple):
|
||||
axes = list(axes)
|
||||
|
||||
input_dimensions = len(cast(TensorValue, self.node.inputs[0]).shape)
|
||||
for i, axis in enumerate(axes):
|
||||
if axis < 0:
|
||||
axes[i] += input_dimensions
|
||||
|
||||
return fhelinalg.SumOp(
|
||||
resulting_type,
|
||||
self.preds[0],
|
||||
ArrayAttr.get([IntegerAttr.get(IntegerType.get_signless(64), axis) for axis in axes]),
|
||||
BoolAttr.get(keep_dims),
|
||||
).result
|
||||
|
||||
def convert_transpose(self) -> OpResult:
|
||||
"""Convert a Transpose node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
assert_true(len(self.node.inputs) == 1)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
|
||||
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
preds = self.preds
|
||||
return fhelinalg.TransposeOp(resulting_type, *preds).result
|
||||
@@ -1,235 +0,0 @@
|
||||
"""Utilities for MLIR conversion."""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..data_types.dtypes_helpers import (
|
||||
value_is_clear_scalar_integer,
|
||||
value_is_clear_tensor_integer,
|
||||
value_is_encrypted_scalar_integer,
|
||||
value_is_encrypted_tensor_integer,
|
||||
value_is_integer,
|
||||
value_is_unsigned_integer,
|
||||
)
|
||||
from ..debugging import format_operation_graph
|
||||
from ..debugging.custom_assert import assert_not_reached, assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate
|
||||
from ..representation.intermediate import Conv2D, IntermediateNode
|
||||
|
||||
# TODO: should be removed as the supported bit-width is now dynamic
|
||||
ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 8
|
||||
|
||||
|
||||
def check_node_compatibility_with_mlir(
|
||||
node: IntermediateNode,
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
is_output: bool,
|
||||
) -> Optional[str]:
|
||||
"""Check if node is compatible with MLIR.
|
||||
|
||||
Args:
|
||||
node (IntermediateNode): node to check
|
||||
nx_graph (nx.MultiDiGraph): the networkx graph to which node belongs
|
||||
is_output (bool): whether the node is an output node or not
|
||||
|
||||
Returns:
|
||||
Optional[str]: None if the node is compatible else reason for incompatibility
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-branches,too-many-return-statements
|
||||
|
||||
inputs = node.inputs
|
||||
outputs = node.outputs
|
||||
|
||||
if isinstance(node, intermediate.Add): # constraints for addition
|
||||
for inp in inputs:
|
||||
if not value_is_integer(inp):
|
||||
return "only integer addition is supported"
|
||||
|
||||
elif isinstance(node, intermediate.Sub): # constraints for subtraction
|
||||
for inp in inputs:
|
||||
if not value_is_integer(inp):
|
||||
return "only integer subtraction is supported"
|
||||
|
||||
elif isinstance(node, intermediate.Mul): # constraints for multiplication
|
||||
for inp in inputs:
|
||||
if not value_is_integer(inp):
|
||||
return "only integer multiplication is supported"
|
||||
|
||||
elif isinstance(node, intermediate.Input): # constraints for inputs
|
||||
assert_true(len(outputs) == 1)
|
||||
if not value_is_unsigned_integer(outputs[0]):
|
||||
return "only unsigned integer inputs are supported"
|
||||
|
||||
elif isinstance(node, intermediate.Constant): # constraints for constants
|
||||
assert_true(len(outputs) == 1)
|
||||
# We currently can't fail on the following assert, but let it for possible changes in the
|
||||
# future
|
||||
if not value_is_integer(outputs[0]):
|
||||
return "only integer constants are supported" # pragma: no cover
|
||||
|
||||
elif isinstance(node, intermediate.GenericFunction): # constraints for univariate functions
|
||||
for inp in inputs:
|
||||
if not value_is_integer(inp):
|
||||
return (
|
||||
f"{node.op_name} with floating-point inputs "
|
||||
f"is required to be fused to be supported"
|
||||
)
|
||||
|
||||
if node.op_kind == "TLU":
|
||||
assert_true(
|
||||
len(
|
||||
[
|
||||
pred_node
|
||||
for pred_node in nx_graph.pred[node]
|
||||
if not isinstance(pred_node, intermediate.Constant)
|
||||
]
|
||||
)
|
||||
== 1
|
||||
)
|
||||
else:
|
||||
if node.op_name not in ["flatten", "reshape", "sum", "concat", "transpose"]:
|
||||
return f"{node.op_name} is not supported for the time being"
|
||||
|
||||
elif isinstance(node, intermediate.Dot): # constraints for dot product
|
||||
assert_true(len(inputs) == 2)
|
||||
if not value_is_integer(inputs[0]) or not value_is_integer(inputs[1]):
|
||||
return "only integer dot product is supported"
|
||||
|
||||
elif isinstance(node, intermediate.IndexConstant): # constraints for constant indexing
|
||||
assert_true(len(outputs) == 1)
|
||||
|
||||
elif isinstance(node, intermediate.MatMul): # constraints for matrix multiplication
|
||||
assert_true(len(inputs) == 2)
|
||||
|
||||
elif isinstance(node, Conv2D):
|
||||
assert_true(len(inputs) in [2, 3])
|
||||
|
||||
else: # pragma: no cover
|
||||
assert_not_reached("Non IntermediateNode object in the OPGraph")
|
||||
|
||||
if is_output:
|
||||
for out in outputs:
|
||||
# For signed values and waiting for a real fix (#845): what is returned by the compiler
|
||||
# is not the (possibly negative) result r, but the always-positive (r mod 2**t), where t
|
||||
# is the bitwidth of r
|
||||
|
||||
# We currently can't fail on the following assert, but let it for possible changes in
|
||||
# the future
|
||||
if not value_is_integer(out):
|
||||
return "only integer outputs are supported" # pragma: no cover
|
||||
else:
|
||||
for out in outputs:
|
||||
# We currently can't fail on the following assert, but let it for possible changes in
|
||||
# the future
|
||||
if not value_is_integer(out):
|
||||
return "only integer intermediates are supported" # pragma: no cover
|
||||
|
||||
# pylint: enable=too-many-branches,too-many-return-statements
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def check_graph_values_compatibility_with_mlir(
|
||||
op_graph: OPGraph,
|
||||
) -> Optional[Dict[IntermediateNode, List[str]]]:
|
||||
"""Make sure the graph outputs are unsigned integers, which is what the compiler supports.
|
||||
|
||||
Args:
|
||||
op_graph: computation graph to check
|
||||
|
||||
Returns:
|
||||
Dict[IntermediateNode, str]: None if the graph is compatible
|
||||
information about offending nodes otherwise
|
||||
"""
|
||||
|
||||
offending_nodes = {}
|
||||
|
||||
for node in op_graph.graph.nodes:
|
||||
is_output = node in op_graph.output_nodes.values()
|
||||
if (
|
||||
reason := check_node_compatibility_with_mlir(node, op_graph.graph, is_output)
|
||||
) is not None:
|
||||
offending_nodes[node] = [reason]
|
||||
|
||||
return None if len(offending_nodes) == 0 else offending_nodes
|
||||
|
||||
|
||||
def _set_all_bit_width(op_graph: OPGraph, p: int):
|
||||
"""Set all bit_width in the graph to `p` and `p+1` for clear and encrypted values respectively.
|
||||
|
||||
Args:
|
||||
op_graph: graph to set bit_width for
|
||||
p: bit_width to set everywhere
|
||||
"""
|
||||
for node in op_graph.graph.nodes:
|
||||
for value in node.outputs + node.inputs:
|
||||
if value_is_clear_scalar_integer(value) or value_is_clear_tensor_integer(value):
|
||||
value.dtype.bit_width = p + 1
|
||||
elif value_is_encrypted_scalar_integer(value) or value_is_encrypted_tensor_integer(
|
||||
value
|
||||
):
|
||||
value.dtype.bit_width = p
|
||||
|
||||
|
||||
def get_op_graph_max_bit_width_and_nodes_over_bit_width_limit(
|
||||
op_graph: OPGraph,
|
||||
) -> Tuple[int, Dict[IntermediateNode, List[str]]]:
|
||||
"""Get the maximum bit width of integer nodes in the given OPGraph.
|
||||
|
||||
Also returns a dictionary with nodes having an unsupported bit width.
|
||||
|
||||
Args:
|
||||
op_graph: graph to update bit_width for
|
||||
|
||||
Returns:
|
||||
Tuple[int, Dict[IntermediateNode, List[str]]]: a tuple containing the maximum bit width of
|
||||
integer values in the OPGraph as well as a dictionary with nodes and the list of issues
|
||||
that the nodes have, in this case having an unsupported bit width.
|
||||
"""
|
||||
max_bit_width = 0
|
||||
offending_nodes: Dict[IntermediateNode, List[str]] = {}
|
||||
for node in op_graph.graph.nodes:
|
||||
for value_out in node.outputs:
|
||||
if value_is_clear_scalar_integer(value_out) or value_is_clear_tensor_integer(value_out):
|
||||
current_node_out_bit_width = value_out.dtype.bit_width - 1
|
||||
else:
|
||||
|
||||
assert_true(
|
||||
value_is_encrypted_scalar_integer(value_out)
|
||||
or value_is_encrypted_tensor_integer(value_out)
|
||||
)
|
||||
|
||||
current_node_out_bit_width = value_out.dtype.bit_width
|
||||
|
||||
max_bit_width = max(max_bit_width, current_node_out_bit_width)
|
||||
|
||||
# Check that current_node_out_bit_width is supported by the compiler
|
||||
if current_node_out_bit_width > ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB:
|
||||
offending_nodes[node] = [
|
||||
f"{current_node_out_bit_width} bits is not supported for the time being"
|
||||
]
|
||||
|
||||
return max_bit_width, offending_nodes
|
||||
|
||||
|
||||
def update_bit_width_for_mlir(op_graph: OPGraph):
|
||||
"""Prepare bit_width of all nodes to be the same, set to the maximum value in the graph.
|
||||
|
||||
Args:
|
||||
op_graph: graph to update bit_width for
|
||||
"""
|
||||
max_bit_width, offending_nodes = get_op_graph_max_bit_width_and_nodes_over_bit_width_limit(
|
||||
op_graph
|
||||
)
|
||||
|
||||
if len(offending_nodes) != 0:
|
||||
raise RuntimeError(
|
||||
f"max_bit_width of some nodes is too high for the current version of "
|
||||
f"the compiler (maximum must be {ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB}) "
|
||||
f"which is not compatible with:\n\n"
|
||||
+ format_operation_graph(op_graph, highlighted_nodes=offending_nodes)
|
||||
)
|
||||
|
||||
_set_all_bit_width(op_graph, max_bit_width)
|
||||
@@ -1,319 +0,0 @@
|
||||
"""Code to wrap and make manipulating networkx graphs easier."""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from .data_types.base import BaseDataType
|
||||
from .data_types.dtypes_helpers import (
|
||||
get_base_data_type_for_python_constant_data,
|
||||
get_constructor_for_python_constant_data,
|
||||
)
|
||||
from .data_types.floats import Float
|
||||
from .data_types.integers import Integer, make_integer_to_hold
|
||||
from .debugging.custom_assert import assert_true
|
||||
from .representation.intermediate import Input, IntermediateNode
|
||||
from .tracing import BaseTracer
|
||||
from .tracing.tracing_helpers import create_graph_from_output_tracers
|
||||
|
||||
|
||||
class OPGraph:
|
||||
"""Class to make work with nx graphs easier."""
|
||||
|
||||
graph: nx.MultiDiGraph
|
||||
input_nodes: Dict[int, Input]
|
||||
output_nodes: Dict[int, IntermediateNode]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: nx.MultiDiGraph,
|
||||
input_nodes: Dict[int, Input],
|
||||
output_nodes: Dict[int, IntermediateNode],
|
||||
) -> None:
|
||||
assert_true(
|
||||
all(isinstance(node, Input) for node in input_nodes.values()),
|
||||
"Got input nodes that were not Input, which is not supported",
|
||||
)
|
||||
assert_true(
|
||||
all(isinstance(node, IntermediateNode) for node in output_nodes.values()),
|
||||
"Got output nodes which were not IntermediateNode, which is not supported",
|
||||
)
|
||||
|
||||
self.graph = graph
|
||||
self.input_nodes = input_nodes
|
||||
self.output_nodes = output_nodes
|
||||
self.prune_nodes()
|
||||
|
||||
def __call__(self, *args) -> Union[Any, Tuple[Any, ...]]:
|
||||
assert_true(len(self.input_nodes) > 0, "Cannot evaluate a graph with no input nodes")
|
||||
inputs = dict(enumerate(args))
|
||||
|
||||
assert_true(
|
||||
len(inputs) == len(self.input_nodes),
|
||||
f"Expected {len(self.input_nodes)} arguments, got {len(inputs)} : {args}",
|
||||
)
|
||||
|
||||
results = self.evaluate(inputs)
|
||||
tuple_result = tuple(results[output_node] for output_node in self.get_ordered_outputs())
|
||||
return tuple_result if len(tuple_result) > 1 else tuple_result[0]
|
||||
|
||||
@staticmethod
|
||||
def from_output_tracers(output_tracers: Iterable[BaseTracer]) -> "OPGraph":
|
||||
"""Construct OPGraph from output tracers.
|
||||
|
||||
Args:
|
||||
output_tracers (Iterable[BaseTracer]): The tracers output by the function that was
|
||||
traced.
|
||||
|
||||
Returns:
|
||||
OPGraph: The resulting OPGraph.
|
||||
"""
|
||||
graph = create_graph_from_output_tracers(output_tracers)
|
||||
input_nodes = {
|
||||
node.program_input_idx: node
|
||||
for node in graph.nodes()
|
||||
if len(graph.pred[node]) == 0 and isinstance(node, Input)
|
||||
}
|
||||
output_nodes = {
|
||||
output_idx: tracer.traced_computation
|
||||
for output_idx, tracer in enumerate(output_tracers)
|
||||
}
|
||||
return OPGraph(graph, input_nodes, output_nodes)
|
||||
|
||||
@staticmethod
|
||||
def from_graph(
|
||||
graph: nx.MultiDiGraph,
|
||||
input_nodes: Iterable[Input],
|
||||
output_nodes: Iterable[IntermediateNode],
|
||||
) -> "OPGraph":
|
||||
"""Construct OPGraph from an existing networkx MultiDiGraph.
|
||||
|
||||
Args:
|
||||
graph (nx.MultiDiGraph): The networkx MultiDiGraph to use.
|
||||
input_nodes (Iterable[Input]): The input nodes of the MultiDiGraph.
|
||||
output_nodes (Iterable[IntermediateNode]): The output nodes of the MultiDiGraph.
|
||||
|
||||
Returns:
|
||||
OPGraph: The resulting OPGraph.
|
||||
"""
|
||||
return OPGraph(graph, dict(enumerate(input_nodes)), dict(enumerate(output_nodes)))
|
||||
|
||||
def get_ordered_inputs(self) -> List[Input]:
|
||||
"""Get the input nodes of the graph, ordered by their index.
|
||||
|
||||
Returns:
|
||||
List[Input]: ordered input nodes
|
||||
"""
|
||||
return [self.input_nodes[idx] for idx in range(len(self.input_nodes))]
|
||||
|
||||
def get_ordered_outputs(self) -> List[IntermediateNode]:
|
||||
"""Get the output nodes of the graph, ordered by their index.
|
||||
|
||||
Returns:
|
||||
List[IntermediateNode]: ordered input nodes
|
||||
"""
|
||||
return [self.output_nodes[idx] for idx in range(len(self.output_nodes))]
|
||||
|
||||
def get_ordered_preds(self, node: IntermediateNode) -> List[IntermediateNode]:
|
||||
"""Get node predecessors ordered by their indices.
|
||||
|
||||
Args:
|
||||
node (IntermediateNode): The node for which we want the ordered predecessors.
|
||||
|
||||
Returns:
|
||||
List[IntermediateNode]: The list of predecessors ordered by input index.
|
||||
"""
|
||||
# Replication of pred is managed e.g. x + x will yield the proper pred x twice
|
||||
idx_to_pred: Dict[int, IntermediateNode] = {}
|
||||
for pred in self.graph.predecessors(node):
|
||||
edge_data = self.graph.get_edge_data(pred, node)
|
||||
idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values())
|
||||
return [idx_to_pred[i] for i in range(len(idx_to_pred))]
|
||||
|
||||
def get_ordered_preds_and_inputs_of(
|
||||
self, node: IntermediateNode
|
||||
) -> List[Tuple[IntermediateNode, int]]:
|
||||
"""Get node preds and inputs ordered by their indices.
|
||||
|
||||
Args:
|
||||
node (IntermediateNode): the node for which we want the ordered inputs
|
||||
|
||||
Returns:
|
||||
List[Tuple[IntermediateNode, int]]: the ordered list of preds and inputs
|
||||
"""
|
||||
|
||||
idx_to_inp: Dict[int, Tuple[IntermediateNode, int]] = {}
|
||||
for pred in self.graph.predecessors(node):
|
||||
edge_data = self.graph.get_edge_data(pred, node)
|
||||
idx_to_inp.update(
|
||||
(data["input_idx"], (pred, data["output_idx"])) for data in edge_data.values()
|
||||
)
|
||||
return [idx_to_inp[i] for i in range(len(idx_to_inp))]
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Dict[IntermediateNode, Any]:
|
||||
"""Evaluate a graph and get intermediate values for all nodes.
|
||||
|
||||
Args:
|
||||
inputs (Dict[int, Any]): The inputs to the program
|
||||
|
||||
Returns:
|
||||
Dict[IntermediateNode, Any]: Dictionary with node as keys and resulting values
|
||||
"""
|
||||
node_results: Dict[IntermediateNode, Any] = {}
|
||||
|
||||
def get_result_of_node_at_index(node: IntermediateNode, output_idx: int) -> Any:
|
||||
"""Get the output result at index output_idx for a node.
|
||||
|
||||
Args:
|
||||
node (IntermediateNode): the node from which we want the output.
|
||||
output_idx (int): which output we want.
|
||||
|
||||
Returns:
|
||||
Any: the output value of the evaluation of node.
|
||||
"""
|
||||
result = node_results[node]
|
||||
# TODO: #81 remove no cover once we have nodes with multiple outputs
|
||||
if isinstance(result, tuple): # pragma: no cover
|
||||
# If the node has multiple outputs (i.e. the result is a tuple), return the
|
||||
# requested output
|
||||
return result[output_idx]
|
||||
# If the result is not a tuple, then the result is the node's only output. Check that
|
||||
# the requested index is 0 (as it's the only valid value) and return the result itself.
|
||||
assert_true(
|
||||
output_idx == 0,
|
||||
f"Unable to get output at index {output_idx} for node {node}.\n"
|
||||
f"Node result: {result}",
|
||||
)
|
||||
return result
|
||||
|
||||
for node in nx.topological_sort(self.graph):
|
||||
if not isinstance(node, Input):
|
||||
curr_inputs = {}
|
||||
for pred_node in self.graph.predecessors(node):
|
||||
edges = self.graph.get_edge_data(pred_node, node)
|
||||
curr_inputs.update(
|
||||
{
|
||||
edge["input_idx"]: get_result_of_node_at_index(
|
||||
pred_node,
|
||||
output_idx=edge["output_idx"],
|
||||
)
|
||||
for edge in edges.values()
|
||||
}
|
||||
)
|
||||
node_results[node] = node.evaluate(curr_inputs)
|
||||
else:
|
||||
node_results[node] = node.evaluate({0: inputs[node.program_input_idx]})
|
||||
|
||||
return node_results
|
||||
|
||||
def update_values_with_bounds_and_samples(
|
||||
self,
|
||||
node_bounds_and_samples: dict,
|
||||
get_base_data_type_for_constant_data: Callable[
|
||||
[Any], BaseDataType
|
||||
] = get_base_data_type_for_python_constant_data,
|
||||
get_constructor_for_constant_data: Callable[
|
||||
..., Callable
|
||||
] = get_constructor_for_python_constant_data,
|
||||
):
|
||||
"""Update values with bounds.
|
||||
|
||||
Update nodes inputs and outputs values with data types able to hold data ranges measured
|
||||
and passed in nodes_bounds
|
||||
|
||||
Args:
|
||||
node_bounds_and_samples (dict): Dictionary with nodes as keys, holding dicts with a
|
||||
'min', 'max' and 'sample' keys. Those bounds will be taken as the data range to be
|
||||
represented, per node. The sample allows to determine the data constructors to
|
||||
prepare the GenericFunction nodes for table generation.
|
||||
get_base_data_type_for_constant_data (Callable[ [Any], BaseDataType ], optional): This
|
||||
is a callback function to convert data encountered during value updates to
|
||||
BaseDataType. This allows to manage data coming from foreign frameworks without
|
||||
specialising OPGraph. Defaults to get_base_data_type_for_python_constant_data.
|
||||
get_constructor_for_constant_data (Callable[ ..., Callable ], optional): This is a
|
||||
callback function to determine the type constructor of the data encountered while
|
||||
updating the graph bounds. Defaults to get_constructor_for_python_constant_data.
|
||||
"""
|
||||
node: IntermediateNode
|
||||
|
||||
for node in self.graph.nodes():
|
||||
current_node_bounds_and_samples = node_bounds_and_samples[node]
|
||||
min_bound, max_bound, sample = (
|
||||
current_node_bounds_and_samples["min"],
|
||||
current_node_bounds_and_samples["max"],
|
||||
current_node_bounds_and_samples["sample"],
|
||||
)
|
||||
|
||||
min_data_type = get_base_data_type_for_constant_data(min_bound)
|
||||
max_data_type = get_base_data_type_for_constant_data(max_bound)
|
||||
|
||||
# This is a sanity check
|
||||
min_value_constructor = get_constructor_for_constant_data(min_bound)
|
||||
max_value_constructor = get_constructor_for_constant_data(max_bound)
|
||||
|
||||
assert_true(
|
||||
max_value_constructor == min_value_constructor,
|
||||
(
|
||||
f"Got two different type constructors for min and max bound: "
|
||||
f"{min_value_constructor}, {max_value_constructor}"
|
||||
),
|
||||
)
|
||||
|
||||
value_constructor = get_constructor_for_constant_data(sample)
|
||||
|
||||
if not isinstance(node, Input):
|
||||
for output_value in node.outputs:
|
||||
if isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer):
|
||||
output_value.dtype = make_integer_to_hold(
|
||||
(min_bound, max_bound), force_signed=False
|
||||
)
|
||||
else:
|
||||
assert_true(
|
||||
isinstance(min_data_type, Float) and isinstance(max_data_type, Float),
|
||||
(
|
||||
"min_bound and max_bound have different common types, "
|
||||
"this should never happen.\n"
|
||||
f"min_bound: {min_data_type}, max_bound: {max_data_type}"
|
||||
),
|
||||
)
|
||||
output_value.dtype = Float(64)
|
||||
output_value.underlying_constructor = value_constructor
|
||||
else:
|
||||
# Currently variable inputs are only allowed to be integers
|
||||
assert_true(
|
||||
isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer),
|
||||
(
|
||||
f"Inputs to a graph should be integers, got bounds that were float, \n"
|
||||
f"min: {min_bound} ({type(min_bound)}), "
|
||||
f"max: {max_bound} ({type(max_bound)})"
|
||||
),
|
||||
)
|
||||
node.inputs[0].dtype = make_integer_to_hold(
|
||||
(min_bound, max_bound), force_signed=False
|
||||
)
|
||||
node.inputs[0].underlying_constructor = value_constructor
|
||||
|
||||
node.outputs[0] = deepcopy(node.inputs[0])
|
||||
|
||||
successors = self.graph.successors(node)
|
||||
for succ in successors:
|
||||
edge_data = self.graph.get_edge_data(node, succ)
|
||||
for edge in edge_data.values():
|
||||
input_idx, output_idx = edge["input_idx"], edge["output_idx"]
|
||||
succ.inputs[input_idx] = deepcopy(node.outputs[output_idx])
|
||||
|
||||
def prune_nodes(self):
|
||||
"""Remove unreachable nodes from outputs."""
|
||||
|
||||
current_nodes = {node: None for node in self.get_ordered_outputs()}
|
||||
useful_nodes: Dict[IntermediateNode, None] = {}
|
||||
while current_nodes:
|
||||
next_nodes: Dict[IntermediateNode, None] = {}
|
||||
useful_nodes.update(current_nodes)
|
||||
for node in current_nodes:
|
||||
next_nodes.update({node: None for node in self.graph.predecessors(node)})
|
||||
current_nodes = next_nodes
|
||||
|
||||
useless_nodes = [node for node in self.graph.nodes() if node not in useful_nodes]
|
||||
self.graph.remove_nodes_from(useless_nodes)
|
||||
@@ -1 +0,0 @@
|
||||
"""Module holding various optimization/simplification code."""
|
||||
@@ -1,594 +0,0 @@
|
||||
"""File holding topological optimization/simplification code."""
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple, cast
|
||||
|
||||
import networkx as nx
|
||||
from loguru import logger
|
||||
|
||||
from ..compilation.artifacts import CompilationArtifacts
|
||||
from ..data_types.floats import Float
|
||||
from ..data_types.integers import Integer
|
||||
from ..debugging import format_operation_graph
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import Constant, GenericFunction, Input, IntermediateNode
|
||||
from ..values import TensorValue
|
||||
|
||||
|
||||
def fuse_float_operations(
|
||||
op_graph: OPGraph,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
):
|
||||
"""Find and fuse float domains into single Integer to Integer GenericFunction.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The OPGraph to simplify
|
||||
compilation_artifacts (Optional[CompilationArtifacts]): The CompilationArtifacts of the
|
||||
current compilation, this argument is optional as it's not required to execute float
|
||||
fusing.
|
||||
"""
|
||||
|
||||
nx_graph = op_graph.graph
|
||||
processed_terminal_nodes: Set[IntermediateNode] = set()
|
||||
number_of_fuse = 0
|
||||
while True:
|
||||
float_subgraph_search_result = find_float_subgraph_with_unique_terminal_node(
|
||||
nx_graph, processed_terminal_nodes
|
||||
)
|
||||
if float_subgraph_search_result is None:
|
||||
break
|
||||
|
||||
float_subgraph_start_nodes, terminal_node, subgraph_all_nodes = float_subgraph_search_result
|
||||
processed_terminal_nodes.add(terminal_node)
|
||||
|
||||
subgraph_conversion_result = convert_float_subgraph_to_fused_node(
|
||||
op_graph,
|
||||
float_subgraph_start_nodes,
|
||||
terminal_node,
|
||||
subgraph_all_nodes,
|
||||
)
|
||||
|
||||
# Not a subgraph we can handle, continue
|
||||
if subgraph_conversion_result is None:
|
||||
continue
|
||||
|
||||
fused_node, node_before_subgraph = subgraph_conversion_result
|
||||
|
||||
nx_graph.add_node(fused_node)
|
||||
|
||||
if terminal_node in op_graph.output_nodes.values():
|
||||
# Output value replace it
|
||||
# As the graph changes recreate the output_node_to_idx dict
|
||||
output_node_to_idx: Dict[IntermediateNode, List[int]] = {
|
||||
out_node: [] for out_node in op_graph.output_nodes.values()
|
||||
}
|
||||
for output_idx, output_node in op_graph.output_nodes.items():
|
||||
output_node_to_idx[output_node].append(output_idx)
|
||||
|
||||
for output_idx in output_node_to_idx.get(terminal_node, []):
|
||||
op_graph.output_nodes[output_idx] = fused_node
|
||||
|
||||
# Disconnect after terminal node and connect fused node instead
|
||||
terminal_node_succ = list(nx_graph.successors(terminal_node))
|
||||
for succ in terminal_node_succ:
|
||||
succ_edge_data = deepcopy(nx_graph.get_edge_data(terminal_node, succ))
|
||||
for edge_key, edge_data in succ_edge_data.items():
|
||||
nx_graph.remove_edge(terminal_node, succ, key=edge_key)
|
||||
# fused_node is always a GenericFunction so output_idx == 0 always
|
||||
new_edge_data = deepcopy(edge_data)
|
||||
new_edge_data["output_idx"] = 0
|
||||
nx_graph.add_edge(fused_node, succ, key=edge_key, **new_edge_data)
|
||||
|
||||
# Connect the node feeding the subgraph contained in fused_node
|
||||
# node_before_subgraph has a single integer output currently so output_idx == 0
|
||||
nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0, output_idx=0)
|
||||
|
||||
op_graph.prune_nodes()
|
||||
if compilation_artifacts is not None:
|
||||
compilation_artifacts.add_operation_graph(
|
||||
f"after-float-fuse-{number_of_fuse}", op_graph
|
||||
)
|
||||
|
||||
number_of_fuse += 1
|
||||
|
||||
|
||||
def convert_float_subgraph_to_fused_node(
|
||||
op_graph: OPGraph,
|
||||
float_subgraph_start_nodes: Dict[IntermediateNode, None],
|
||||
terminal_node: IntermediateNode,
|
||||
subgraph_all_nodes: Dict[IntermediateNode, None],
|
||||
) -> Optional[Tuple[GenericFunction, IntermediateNode]]:
|
||||
"""Convert a float subgraph to an equivalent fused GenericFunction node.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The OPGraph the float subgraph is part of.
|
||||
float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the float
|
||||
subgraph in `op_graph`.
|
||||
terminal_node (IntermediateNode): The node ending the float subgraph.
|
||||
subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph.
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[GenericFunction, IntermediateNode]]: None if the float subgraph
|
||||
cannot be fused, otherwise returns a tuple containing the fused node and the node whose
|
||||
output must be plugged as the input to the subgraph.
|
||||
"""
|
||||
|
||||
node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]] = defaultdict(list)
|
||||
|
||||
subgraph_can_be_fused = subgraph_has_unique_variable_input(
|
||||
float_subgraph_start_nodes, terminal_node, node_with_issues_for_fusing
|
||||
)
|
||||
|
||||
if subgraph_can_be_fused:
|
||||
# subgraph_values_allow_fusing can be called iff the subgraph has a unique variable input
|
||||
subgraph_can_be_fused = subgraph_nodes_and_values_allow_fusing(
|
||||
float_subgraph_start_nodes, subgraph_all_nodes, node_with_issues_for_fusing
|
||||
)
|
||||
|
||||
# This test is separate from the previous one to only handle printing issues once
|
||||
if not subgraph_can_be_fused:
|
||||
float_subgraph = nx.MultiDiGraph(op_graph.graph.subgraph(subgraph_all_nodes))
|
||||
float_subgraph_as_op_graph = OPGraph.from_graph(float_subgraph, [], [terminal_node])
|
||||
|
||||
printable_graph = format_operation_graph(
|
||||
float_subgraph_as_op_graph,
|
||||
highlighted_nodes=node_with_issues_for_fusing,
|
||||
)
|
||||
message = f"The following subgraph is not fusable:\n\n{printable_graph}"
|
||||
logger.warning(message)
|
||||
return None
|
||||
|
||||
# Only one variable input node, find which node feeds its input
|
||||
variable_input_nodes = [
|
||||
node for node in float_subgraph_start_nodes if not isinstance(node, Constant)
|
||||
]
|
||||
assert_true(len(variable_input_nodes) == 1)
|
||||
|
||||
current_subgraph_variable_input = variable_input_nodes[0]
|
||||
assert_true(len(current_subgraph_variable_input.outputs) == 1)
|
||||
new_input_value = deepcopy(current_subgraph_variable_input.outputs[0])
|
||||
|
||||
nx_graph = op_graph.graph
|
||||
|
||||
nodes_after_input_set = {
|
||||
node: None
|
||||
for node in subgraph_all_nodes
|
||||
if node in nx_graph.succ[current_subgraph_variable_input]
|
||||
}
|
||||
|
||||
# # Previous non-deterministic implementation :
|
||||
# # For some reason creating a graph from a subgraph this way is not deterministic
|
||||
# float_subgraph = nx.MultiDiGraph(nx_graph.subgraph(subgraph_all_nodes))
|
||||
|
||||
# Create a copy of the graph, remove nodes that are not in all the subgraph nodes in order to
|
||||
# get a subgraph deterministically
|
||||
float_subgraph = nx.MultiDiGraph(nx_graph)
|
||||
nodes_to_remove = [node for node in float_subgraph.nodes() if node not in subgraph_all_nodes]
|
||||
float_subgraph.remove_nodes_from(nodes_to_remove)
|
||||
|
||||
new_subgraph_variable_input = Input(new_input_value, "float_subgraph_input", 0)
|
||||
float_subgraph.add_node(new_subgraph_variable_input)
|
||||
|
||||
for node_after_input in nodes_after_input_set:
|
||||
# Connect the new input to our subgraph
|
||||
edge_data_input_to_subgraph = deepcopy(
|
||||
float_subgraph.get_edge_data(
|
||||
current_subgraph_variable_input,
|
||||
node_after_input,
|
||||
)
|
||||
)
|
||||
for edge_key, edge_data in edge_data_input_to_subgraph.items():
|
||||
float_subgraph.remove_edge(
|
||||
current_subgraph_variable_input, node_after_input, key=edge_key
|
||||
)
|
||||
# new_subgraph_variable_input is always an Input so output_idx == 0 always
|
||||
new_edge_data = deepcopy(edge_data)
|
||||
new_edge_data["output_idx"] = 0
|
||||
float_subgraph.add_edge(
|
||||
new_subgraph_variable_input,
|
||||
node_after_input,
|
||||
key=edge_key,
|
||||
**new_edge_data,
|
||||
)
|
||||
|
||||
float_op_subgraph = OPGraph.from_graph(
|
||||
float_subgraph,
|
||||
[new_subgraph_variable_input],
|
||||
[terminal_node],
|
||||
)
|
||||
|
||||
assert_true(len(terminal_node.outputs) == 1)
|
||||
|
||||
# Create fused_node
|
||||
fused_node = GenericFunction(
|
||||
inputs=[new_subgraph_variable_input.inputs[0]],
|
||||
arbitrary_func=lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate(
|
||||
{0: x}
|
||||
)[terminal_node],
|
||||
output_value=terminal_node.outputs[0],
|
||||
op_kind="TLU",
|
||||
op_kwargs={
|
||||
"float_op_subgraph": float_op_subgraph,
|
||||
"terminal_node": terminal_node,
|
||||
},
|
||||
op_name="subgraph",
|
||||
)
|
||||
|
||||
return (
|
||||
fused_node,
|
||||
current_subgraph_variable_input,
|
||||
)
|
||||
|
||||
|
||||
def is_single_int_output_node(node: IntermediateNode) -> bool:
|
||||
"""Check if a node has a single output and that output is an integer.
|
||||
|
||||
Args:
|
||||
node (IntermediateNode): the node to check.
|
||||
|
||||
Returns:
|
||||
bool: returns True if the node has a single integer output, False otherwise.
|
||||
"""
|
||||
return len(node.outputs) == 1 and isinstance(node.outputs[0].dtype, Integer)
|
||||
|
||||
|
||||
def find_closest_single_int_output_nodes(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
start_nodes: List[IntermediateNode],
|
||||
subgraph_all_nodes: Dict[IntermediateNode, None],
|
||||
) -> Tuple[Dict[IntermediateNode, None], Dict[IntermediateNode, None]]:
|
||||
"""Find in nx_graph the closest upstream single integer output nodes to some start nodes.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph): the networkx graph to search in.
|
||||
start_nodes (List[IntermediateNode]): the nodes from which to start the search.
|
||||
subgraph_all_nodes (Dict[IntermediateNode, None]): a set that will be updated with all the
|
||||
nodes visited during the search.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[IntermediateNode, None], Dict[IntermediateNode, None]]: returns the dict used as
|
||||
an ordered set containing the found single output nodes and the updated set of the
|
||||
visited nodes during the search.
|
||||
"""
|
||||
|
||||
# Use dict as ordered set
|
||||
current_nodes = {start_node: None for start_node in start_nodes}
|
||||
closest_single_int_output_nodes: Dict[IntermediateNode, None] = {}
|
||||
visited_nodes: Set[IntermediateNode] = set()
|
||||
while current_nodes:
|
||||
next_nodes: Dict[IntermediateNode, None] = {}
|
||||
for node in current_nodes:
|
||||
if node in visited_nodes:
|
||||
continue
|
||||
visited_nodes.add(node)
|
||||
subgraph_all_nodes.update({node: None})
|
||||
predecessors = nx_graph.predecessors(node)
|
||||
for pred in predecessors:
|
||||
if is_single_int_output_node(pred):
|
||||
# Limit of subgraph, record that and record the node as we won't visit it
|
||||
closest_single_int_output_nodes.update({pred: None})
|
||||
subgraph_all_nodes.update({pred: None})
|
||||
else:
|
||||
next_nodes.update({pred: None})
|
||||
current_nodes = next_nodes
|
||||
|
||||
return closest_single_int_output_nodes, subgraph_all_nodes
|
||||
|
||||
|
||||
def add_nodes_from_to(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
from_nodes: Iterable[IntermediateNode],
|
||||
to_nodes: Dict[IntermediateNode, None],
|
||||
subgraph_all_nodes: Dict[IntermediateNode, None],
|
||||
) -> Dict[IntermediateNode, None]:
|
||||
"""Add nodes from from_nodes to to_nodes to the subgraph_all_nodes set.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph): the graph to traverse.
|
||||
from_nodes (Iterable[IntermediateNode]): the nodes from which we will add nodes to
|
||||
subgraph_all_nodes.
|
||||
to_nodes (Dict[IntermediateNode, None]): the nodes we should stop at.
|
||||
subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph, will
|
||||
be updated and returned.
|
||||
|
||||
Returns:
|
||||
Dict[IntermediateNode, None]: returns the updated subgraph_all_nodes.
|
||||
"""
|
||||
|
||||
# Add the end nodes we won't visit
|
||||
subgraph_all_nodes.update(to_nodes)
|
||||
|
||||
current_nodes = {from_node: None for from_node in from_nodes}
|
||||
visited_nodes: Set[IntermediateNode] = set()
|
||||
while current_nodes:
|
||||
next_nodes: Dict[IntermediateNode, None] = {}
|
||||
for node in current_nodes:
|
||||
if node in visited_nodes:
|
||||
continue
|
||||
visited_nodes.add(node)
|
||||
subgraph_all_nodes.update({node: None})
|
||||
predecessors = nx_graph.predecessors(node)
|
||||
# Add nodes to explore next if they are not indicated as end nodes
|
||||
next_nodes.update({pred: None for pred in predecessors if pred not in to_nodes})
|
||||
current_nodes = next_nodes
|
||||
|
||||
return subgraph_all_nodes
|
||||
|
||||
|
||||
def find_float_subgraph_with_unique_terminal_node(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
processed_terminal_nodes: Set[IntermediateNode],
|
||||
) -> Optional[Tuple[Dict[IntermediateNode, None], IntermediateNode, Dict[IntermediateNode, None]]]:
|
||||
"""Find a subgraph of the graph with float computations.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph): The networkx graph to search in.
|
||||
processed_terminal_nodes (Dict[IntermediateNode, None]): The set of terminal nodes for which
|
||||
subgraphs have already been searched, those will be skipped.
|
||||
|
||||
Returns:
|
||||
Optional[
|
||||
Tuple[Dict[IntermediateNode, None], IntermediateNode, Dict[IntermediateNode, None]]]:
|
||||
None if there are no float subgraphs to process in `nx_graph`. Otherwise returns a
|
||||
tuple containing the set of nodes beginning a float subgraph, the terminal node of
|
||||
the subgraph and the set of all the nodes in the subgraph.
|
||||
"""
|
||||
|
||||
def is_float_to_single_int_node(node: IntermediateNode) -> bool:
|
||||
return (
|
||||
any(isinstance(input_.dtype, Float) for input_ in node.inputs)
|
||||
and len(node.outputs) == 1
|
||||
and isinstance(node.outputs[0].dtype, Integer)
|
||||
)
|
||||
|
||||
float_subgraphs_terminal_nodes = (
|
||||
node
|
||||
for node in nx_graph.nodes()
|
||||
if is_float_to_single_int_node(node) and node not in processed_terminal_nodes
|
||||
)
|
||||
|
||||
terminal_node: IntermediateNode
|
||||
|
||||
try:
|
||||
terminal_node = next(float_subgraphs_terminal_nodes)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
# networkx does not implement lowest common ancestor search for multidigraph, but we only care
|
||||
# about parent relationship here and not the meaning of edges, so we can convert our
|
||||
# multidigraph to a digraph and use the lca search algorithm (if needed), we create the
|
||||
# equivalent digraph here as it will avoid recreating it in a loop. Constant nodes could cause
|
||||
# issues in our search so we remove them.
|
||||
equivalent_digraph_without_constants = nx.DiGraph(nx_graph)
|
||||
constant_graph_nodes = [
|
||||
constant_node
|
||||
for constant_node in equivalent_digraph_without_constants.nodes()
|
||||
if isinstance(constant_node, Constant)
|
||||
]
|
||||
equivalent_digraph_without_constants.remove_nodes_from(constant_graph_nodes)
|
||||
|
||||
# Use dict as ordered set
|
||||
subgraph_all_nodes: Dict[IntermediateNode, None] = {}
|
||||
|
||||
start_single_int_output_nodes_search_from = terminal_node
|
||||
|
||||
while True:
|
||||
float_subgraph_start_nodes, subgraph_all_nodes = find_closest_single_int_output_nodes(
|
||||
nx_graph,
|
||||
[start_single_int_output_nodes_search_from],
|
||||
subgraph_all_nodes,
|
||||
)
|
||||
|
||||
variable_start_nodes = [
|
||||
start_node
|
||||
for start_node in float_subgraph_start_nodes
|
||||
if not isinstance(start_node, Constant)
|
||||
]
|
||||
|
||||
# We found a single input variable node
|
||||
if len(variable_start_nodes) == 1:
|
||||
break
|
||||
|
||||
# Otherwise find a common ancestor as we need a single variable input node
|
||||
# lca == lowest common ancestor
|
||||
# lca search only works for node pairs in networkx, so we progressively find the ancestors
|
||||
# setting the lca by default to one of the nodes we are searching the lca for
|
||||
lca = variable_start_nodes.pop()
|
||||
|
||||
while len(variable_start_nodes) > 0 and lca is not None:
|
||||
node_to_find_lca = variable_start_nodes.pop()
|
||||
lca = nx.algorithms.lowest_common_ancestors.lowest_common_ancestor(
|
||||
equivalent_digraph_without_constants, lca, node_to_find_lca, default=None
|
||||
)
|
||||
|
||||
# The subgraph cannot be fused as there is no way to find a common ancestor
|
||||
if lca is None:
|
||||
break
|
||||
|
||||
# if lca is not None, add the nodes from the current start nodes to the lca to
|
||||
# subgraph_all_nodes
|
||||
subgraph_all_nodes = add_nodes_from_to(
|
||||
nx_graph, float_subgraph_start_nodes, {lca: None}, subgraph_all_nodes
|
||||
)
|
||||
|
||||
# if the lca is a valid starting node for fusing break
|
||||
if is_single_int_output_node(lca):
|
||||
# the lca is our new start node
|
||||
float_subgraph_start_nodes = {lca: None}
|
||||
break
|
||||
|
||||
# otherwise push a little bit further the search (if there is a node just before that has an
|
||||
# integer output e.g.)
|
||||
start_single_int_output_nodes_search_from = lca
|
||||
|
||||
return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes
|
||||
|
||||
|
||||
def subgraph_nodes_and_values_allow_fusing(
|
||||
float_subgraph_start_nodes: Dict[IntermediateNode, None],
|
||||
subgraph_all_nodes: Dict[IntermediateNode, None],
|
||||
node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]],
|
||||
) -> bool:
|
||||
"""Check if a subgraph's values are compatible with fusing.
|
||||
|
||||
A fused subgraph for example only works on an input tensor if the resulting GenericFunction
|
||||
can be applied per cell, hence shuffling or tensor shape changes make fusing impossible.
|
||||
|
||||
Args:
|
||||
float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the float
|
||||
subgraph.
|
||||
subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph.
|
||||
node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill
|
||||
with potential nodes issues preventing fusing.
|
||||
|
||||
Returns:
|
||||
bool: True if all inputs and outputs of the nodes in the subgraph are compatible with fusing
|
||||
i.e. outputs have the same shapes equal to the variable input.
|
||||
"""
|
||||
|
||||
node: IntermediateNode
|
||||
|
||||
variable_input_nodes = [
|
||||
node for node in float_subgraph_start_nodes if not isinstance(node, Constant)
|
||||
]
|
||||
|
||||
assert_true(
|
||||
(num_variable_input_nodes := len(variable_input_nodes)) == 1,
|
||||
f"{subgraph_nodes_and_values_allow_fusing.__name__} "
|
||||
f"only works for subgraphs with 1 variable input node, got {num_variable_input_nodes}",
|
||||
)
|
||||
|
||||
explicitely_non_fusable = [
|
||||
node
|
||||
for node in subgraph_all_nodes
|
||||
if isinstance(node, GenericFunction) and not node.op_attributes["fusable"]
|
||||
]
|
||||
for node in explicitely_non_fusable:
|
||||
node_with_issues_for_fusing[node].append(
|
||||
"this node is explicitely marked by the package as non-fusable"
|
||||
)
|
||||
if len(explicitely_non_fusable) > 0:
|
||||
return False
|
||||
|
||||
all_values_are_tensors = all(
|
||||
all(isinstance(input_, TensorValue) for input_ in node.inputs)
|
||||
and all(isinstance(output, TensorValue) for output in node.outputs)
|
||||
for node in subgraph_all_nodes
|
||||
)
|
||||
|
||||
if not all_values_are_tensors:
|
||||
# This cannot be reached today as scalars are Tensors with shape == () (numpy convention)
|
||||
return False # pragma: no cover
|
||||
|
||||
variable_input_node = variable_input_nodes[0]
|
||||
|
||||
# A cheap check is that the variable input node must have the biggest size, i.e. have the most
|
||||
# elements, meaning all constants will broadcast to its shape. This is because the
|
||||
# GenericFunction input and output must have the same shape so that it can be applied to each
|
||||
# of the input tensor cells.
|
||||
# There *may* be a way to manage the other case by simulating the broadcast of the smaller input
|
||||
# array and then concatenating/stacking the results. This is not currently doable as we don't
|
||||
# have a concatenate operator on the compiler side.
|
||||
# TODO: #587 https://github.com/zama-ai/concrete-numpy-internal/issues/587
|
||||
|
||||
variable_input_node_output = cast(TensorValue, variable_input_node.outputs[0])
|
||||
variable_input_node_output_size, variable_input_node_output_shape = (
|
||||
variable_input_node_output.size,
|
||||
variable_input_node_output.shape,
|
||||
)
|
||||
|
||||
constant_nodes_with_bigger_size_than_variable_input = [
|
||||
constant_input_node
|
||||
for constant_input_node in subgraph_all_nodes
|
||||
if isinstance(constant_input_node, Constant)
|
||||
and cast(TensorValue, constant_input_node.outputs[0]).size > variable_input_node_output_size
|
||||
]
|
||||
|
||||
for bigger_constant_node in constant_nodes_with_bigger_size_than_variable_input:
|
||||
bigger_constant_node_shape = cast(TensorValue, bigger_constant_node.outputs[0]).shape
|
||||
node_with_issues_for_fusing[bigger_constant_node].append(
|
||||
f"this constant node has a bigger shape {bigger_constant_node_shape} "
|
||||
f"than the subgraph's input: {variable_input_node_output_shape}"
|
||||
)
|
||||
|
||||
if len(constant_nodes_with_bigger_size_than_variable_input) > 0:
|
||||
node_with_issues_for_fusing[variable_input_node].append(
|
||||
f"input node with shape {variable_input_node_output_shape}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Now that we know the variable input node has the biggest size we can check shapes are
|
||||
# consistent throughout the subgraph: outputs of ir nodes that are not constant must be equal.
|
||||
|
||||
non_constant_nodes = (node for node in subgraph_all_nodes if not isinstance(node, Constant))
|
||||
|
||||
nodes_with_different_output_shapes = {
|
||||
node: [
|
||||
(output_idx, output.shape)
|
||||
for output_idx, output in enumerate(node.outputs)
|
||||
if isinstance(output, TensorValue) and output.shape != variable_input_node
|
||||
]
|
||||
for node in non_constant_nodes
|
||||
if any(
|
||||
isinstance(output, TensorValue) and output.shape != variable_input_node_output_shape
|
||||
for output in node.outputs
|
||||
)
|
||||
}
|
||||
|
||||
for node, node_shape_infos in nodes_with_different_output_shapes.items():
|
||||
shape_issue_details = "; ".join(
|
||||
f"#{output_idx}, {output_shape}" for output_idx, output_shape in node_shape_infos
|
||||
)
|
||||
node_with_issues_for_fusing[node].append(
|
||||
f"output shapes: {shape_issue_details} are not the same as the subgraph's input: "
|
||||
f"{variable_input_node_output_shape}"
|
||||
)
|
||||
|
||||
all_nodes_have_same_shape_as_input = len(nodes_with_different_output_shapes) == 0
|
||||
|
||||
if not all_nodes_have_same_shape_as_input:
|
||||
node_with_issues_for_fusing[variable_input_node].append(
|
||||
f"input node with shape {variable_input_node_output_shape}"
|
||||
)
|
||||
|
||||
# All non constant node outputs currently need to have the same shape
|
||||
return all_nodes_have_same_shape_as_input
|
||||
|
||||
|
||||
def subgraph_has_unique_variable_input(
|
||||
float_subgraph_start_nodes: Dict[IntermediateNode, None],
|
||||
terminal_node: IntermediateNode,
|
||||
node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]],
|
||||
) -> bool:
|
||||
"""Check that only one of the nodes starting the subgraph is variable.
|
||||
|
||||
Args:
|
||||
float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the subgraph.
|
||||
terminal_node (IntermediateNode): The node ending the float subgraph.
|
||||
node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill
|
||||
with potential nodes issues preventing fusing.
|
||||
|
||||
Returns:
|
||||
bool: True if only one of the nodes is not an Constant
|
||||
"""
|
||||
|
||||
variable_inputs_list = [
|
||||
node for node in float_subgraph_start_nodes if not isinstance(node, Constant)
|
||||
]
|
||||
variable_inputs_num = len(variable_inputs_list)
|
||||
|
||||
# Only one input to the subgraph where computations are done in floats can be variable, this
|
||||
# is the only case we can manage with GenericFunction fusing
|
||||
has_unique_variable_input = variable_inputs_num == 1
|
||||
|
||||
if not has_unique_variable_input:
|
||||
for node in variable_inputs_list:
|
||||
node_with_issues_for_fusing[node].append(
|
||||
f"one of {variable_inputs_num} variable inputs (can only have 1 for fusing)"
|
||||
)
|
||||
node_with_issues_for_fusing[terminal_node].append(
|
||||
f"cannot fuse here as the subgraph has {variable_inputs_num} variable inputs"
|
||||
)
|
||||
|
||||
return has_unique_variable_input
|
||||
@@ -1,2 +0,0 @@
|
||||
"""Representation module to represent source programs."""
|
||||
from . import intermediate
|
||||
@@ -1,650 +0,0 @@
|
||||
"""File containing code to represent source programs operations."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from enum import Enum, unique
|
||||
from math import floor
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from ..data_types.base import BaseDataType
|
||||
from ..data_types.dtypes_helpers import (
|
||||
get_base_value_for_python_constant_data,
|
||||
mix_values_determine_holding_dtype,
|
||||
)
|
||||
from ..data_types.integers import Integer
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..helpers import indexing_helpers
|
||||
from ..helpers.formatting_helpers import format_constant
|
||||
from ..helpers.python_helpers import catch, update_and_return_dict
|
||||
from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue
|
||||
|
||||
IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func"
|
||||
|
||||
ALL_IR_NODES: Set[Type] = set()
|
||||
|
||||
|
||||
class IntermediateNode(ABC):
|
||||
"""Abstract Base Class to derive from to represent source program operations."""
|
||||
|
||||
inputs: List[BaseValue]
|
||||
outputs: List[BaseValue]
|
||||
_n_in: int # _n_in indicates how many inputs are required to evaluate the IntermediateNode
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs: Iterable[BaseValue],
|
||||
**_kwargs, # This is to be able to feed arbitrary arguments to IntermediateNodes
|
||||
) -> None:
|
||||
self.inputs = list(inputs)
|
||||
assert_true(all(isinstance(x, BaseValue) for x in self.inputs))
|
||||
|
||||
# Register all IR nodes
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
ALL_IR_NODES.add(cls)
|
||||
|
||||
def _init_binary(
|
||||
self,
|
||||
inputs: Iterable[BaseValue],
|
||||
mix_values_func: Callable[..., BaseValue] = mix_values_determine_holding_dtype,
|
||||
**_kwargs, # Required to conform to __init__ typing
|
||||
) -> None:
|
||||
"""__init__ for a binary operation, ie two inputs."""
|
||||
IntermediateNode.__init__(self, inputs)
|
||||
|
||||
assert_true(len(self.inputs) == 2)
|
||||
|
||||
self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])]
|
||||
|
||||
def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str:
|
||||
"""Get the formatted node (used in formatting operation graphs).
|
||||
|
||||
Args:
|
||||
predecessors (List[str]): predecessor names to this node
|
||||
_maximum_constant_length (int): desired maximum constant length
|
||||
|
||||
Returns:
|
||||
str: the formatted node
|
||||
"""
|
||||
|
||||
return f"{self.__class__.__name__.lower()}({', '.join(predecessors)})"
|
||||
|
||||
@abstractmethod
|
||||
def text_for_drawing(self) -> str:
|
||||
"""Get the label of the node (used in drawing operation graphs).
|
||||
|
||||
Returns:
|
||||
str: the label of the node
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
"""Simulate what the represented computation would output for the given inputs.
|
||||
|
||||
Args:
|
||||
inputs (Dict[int, Any]): Dict containing the inputs for the evaluation
|
||||
|
||||
Returns:
|
||||
Any: the result of the computation
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def n_in(cls) -> int:
|
||||
"""Return how many inputs the node has.
|
||||
|
||||
Returns:
|
||||
int: The number of inputs of the node.
|
||||
"""
|
||||
return cls._n_in
|
||||
|
||||
@classmethod
|
||||
def requires_mix_values_func(cls) -> bool:
|
||||
"""Determine whether the Class requires a mix_values_func to be built.
|
||||
|
||||
Returns:
|
||||
bool: True if __init__ expects a mix_values_func argument.
|
||||
"""
|
||||
return cls.n_in() > 1
|
||||
|
||||
|
||||
class Add(IntermediateNode):
|
||||
"""Addition between two values."""
|
||||
|
||||
_n_in: int = 2
|
||||
|
||||
__init__ = IntermediateNode._init_binary
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return "+"
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] + inputs[1]
|
||||
|
||||
|
||||
class Sub(IntermediateNode):
|
||||
"""Subtraction between two values."""
|
||||
|
||||
_n_in: int = 2
|
||||
|
||||
__init__ = IntermediateNode._init_binary
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return "-"
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] - inputs[1]
|
||||
|
||||
|
||||
class Mul(IntermediateNode):
|
||||
"""Multiplication between two values."""
|
||||
|
||||
_n_in: int = 2
|
||||
|
||||
__init__ = IntermediateNode._init_binary
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return "*"
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] * inputs[1]
|
||||
|
||||
|
||||
class Input(IntermediateNode):
|
||||
"""Node representing an input of the program."""
|
||||
|
||||
input_name: str
|
||||
program_input_idx: int
|
||||
_n_in: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_value: BaseValue,
|
||||
input_name: str,
|
||||
program_input_idx: int,
|
||||
) -> None:
|
||||
super().__init__((input_value,))
|
||||
assert_true(len(self.inputs) == 1)
|
||||
self.input_name = input_name
|
||||
self.program_input_idx = program_input_idx
|
||||
self.outputs = [deepcopy(self.inputs[0])]
|
||||
|
||||
def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str:
|
||||
assert_true(len(predecessors) == 0)
|
||||
return self.input_name
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return self.input_name
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0]
|
||||
|
||||
|
||||
class Constant(IntermediateNode):
|
||||
"""Node representing a constant of the program."""
|
||||
|
||||
_constant_data: Any
|
||||
_n_in: int = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
constant_data: Any,
|
||||
get_base_value_for_data_func: Callable[
|
||||
[Any], Callable[..., BaseValue]
|
||||
] = get_base_value_for_python_constant_data,
|
||||
) -> None:
|
||||
super().__init__([])
|
||||
|
||||
base_value_class = get_base_value_for_data_func(constant_data)
|
||||
|
||||
self._constant_data = constant_data
|
||||
self.outputs = [base_value_class(is_encrypted=False)]
|
||||
|
||||
def text_for_formatting(self, predecessors: List[str], maximum_constant_length: int) -> str:
|
||||
assert_true(len(predecessors) == 0)
|
||||
return format_constant(self.constant_data, maximum_constant_length)
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return format_constant(self.constant_data)
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return self.constant_data
|
||||
|
||||
@property
|
||||
def constant_data(self) -> Any:
|
||||
"""Return the constant_data stored in the Constant node.
|
||||
|
||||
Returns:
|
||||
Any: The constant data that was stored.
|
||||
"""
|
||||
return self._constant_data
|
||||
|
||||
|
||||
class Conv2D(IntermediateNode):
|
||||
"""Return the node representing a 2d-convolution."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs: Iterable[BaseValue],
|
||||
output_dtype: BaseDataType,
|
||||
pads: Union[List[int], Tuple[int, int, int, int]],
|
||||
strides: Union[List[int], Tuple[int, int]],
|
||||
dilations: Union[List[int], Tuple[int, int]],
|
||||
) -> None:
|
||||
|
||||
# TODO: remove this when padding is supported (#427)
|
||||
assert all(pad == 0 for pad in pads), "conv2d doesn't support padding yet"
|
||||
|
||||
super().__init__(inputs)
|
||||
self.pads = pads
|
||||
self.strides = strides
|
||||
self.dilations = dilations
|
||||
|
||||
self._n_in = len(self.inputs)
|
||||
assert_true(len(self.inputs) == 2 or len(self.inputs) == 3)
|
||||
|
||||
assert_true(
|
||||
all(
|
||||
isinstance(input_value, TensorValue) and input_value.ndim == 4
|
||||
for input_value in self.inputs[:2]
|
||||
),
|
||||
f"Conv2D only supports input and weight tensors of 4 dimensions"
|
||||
f"({TensorValue.__name__} with ndim == 4)",
|
||||
)
|
||||
bias = cast(TensorValue, self.inputs[2]) if len(self.inputs) == 3 else None
|
||||
if bias is not None:
|
||||
assert_true(
|
||||
isinstance(bias, TensorValue) and bias.ndim == 1,
|
||||
f"Conv2D only supports bias 1 dimension ({TensorValue.__name__} with ndim == 1)",
|
||||
)
|
||||
|
||||
x = cast(TensorValue, self.inputs[0])
|
||||
weight = cast(TensorValue, self.inputs[1])
|
||||
|
||||
# Compute output shape
|
||||
input_n, _, input_h, input_w = x.shape
|
||||
weight_f, _, weight_h, weight_w = weight.shape
|
||||
pads_h = pads[0] + pads[2]
|
||||
pads_w = pads[1] + pads[3]
|
||||
output_h = floor((input_h + pads_h - dilations[0] * (weight_h - 1) - 1) / strides[0]) + 1
|
||||
output_w = floor((input_w + pads_w - dilations[1] * (weight_w - 1) - 1) / strides[1]) + 1
|
||||
output_shape = (input_n, weight_f, output_h, output_w)
|
||||
|
||||
output_value = EncryptedTensor(dtype=output_dtype, shape=output_shape)
|
||||
self.outputs = [output_value]
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return "conv2d"
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
|
||||
assert_true(
|
||||
len(inputs) == self._n_in, f"expected {self.n_in} inputs, but got {len(inputs)}"
|
||||
)
|
||||
x, weight = inputs[0], inputs[1]
|
||||
bias = inputs[2] if len(inputs) == 3 else np.zeros(weight.shape[0])
|
||||
|
||||
return self.evaluate_conv2d(x, weight, bias, self.pads, self.strides, self.dilations)
|
||||
|
||||
@staticmethod
|
||||
def evaluate_conv2d(
|
||||
x: np.ndarray,
|
||||
weight: np.ndarray,
|
||||
bias: np.ndarray,
|
||||
# TODO: use padding when supported (#427)
|
||||
_: Union[Tuple[int, int, int, int], List[int]],
|
||||
strides: Union[Tuple[int, int], List[int]],
|
||||
dilations: Union[Tuple[int, int], List[int]],
|
||||
):
|
||||
"""Evaluate 2D convolution.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Input of shape (NxCxHxW)
|
||||
weight (np.ndarray): Weight (kernel) of shape (FxCxHxW)
|
||||
bias (np.ndarray): Bias vector of size (F)
|
||||
pads (Union[Tuple[int, int, int, int], List[int]]): Padding over each
|
||||
axis (H_beg, W_beg, H_end, W_end)
|
||||
strides (Union[Tuple[int, int], List[int]]): Stride over each
|
||||
axis (height and width)
|
||||
dilations (Union[Tuple[int, int], List[int]]): Dilation over each
|
||||
axis (height and width)
|
||||
|
||||
Returns:
|
||||
np.ndarray: Result of the convolution of shape (NxCxHxW)
|
||||
"""
|
||||
# pylint: disable=no-member
|
||||
return torch.conv2d(
|
||||
torch.tensor(x, dtype=torch.long),
|
||||
torch.tensor(weight, dtype=torch.long),
|
||||
torch.tensor(bias, dtype=torch.long),
|
||||
stride=strides,
|
||||
dilation=dilations,
|
||||
).numpy()
|
||||
# pylint: enable=no-member
|
||||
|
||||
|
||||
class IndexConstant(IntermediateNode):
|
||||
"""Node representing a constant indexing in the program.
|
||||
|
||||
What we mean by constant indexing is that the index part of the operation is a constant.
|
||||
Here are some examples: `x[2]`, `x[0, 1]`, `y[:, 0]`, `y[3:, :5]`
|
||||
|
||||
The opposite is to have dynamic indexing, which this node does not support.
|
||||
Some examples of dynamic indexing are: `x[y]`, `x[y, z]`, `x[:, y]`
|
||||
"""
|
||||
|
||||
_n_in: int = 1
|
||||
|
||||
index: Tuple[Union[int, slice], ...]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_: BaseValue,
|
||||
index: Union[int, slice, Tuple[Union[int, slice], ...]],
|
||||
) -> None:
|
||||
super().__init__((input_,))
|
||||
|
||||
if not isinstance(self.inputs[0], TensorValue) or self.inputs[0].is_scalar:
|
||||
raise TypeError(f"Only tensors can be indexed but you tried to index {self.inputs[0]}")
|
||||
|
||||
self.index = indexing_helpers.validate_index(index)
|
||||
|
||||
output_dtype = self.inputs[0].dtype
|
||||
output_shape = indexing_helpers.determine_output_shape(self.inputs[0].shape, self.index)
|
||||
|
||||
self.outputs = [
|
||||
EncryptedTensor(output_dtype, output_shape)
|
||||
if self.inputs[0].is_encrypted
|
||||
else ClearTensor(output_dtype, output_shape)
|
||||
]
|
||||
|
||||
def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str:
|
||||
assert_true(len(predecessors) == 1)
|
||||
elements = [indexing_helpers.format_indexing_element(element) for element in self.index]
|
||||
index = ", ".join(elements)
|
||||
return f"{predecessors[0]}[{index}]"
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return self.text_for_formatting(["value"], 0) # 0 is unused
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0][self.index]
|
||||
|
||||
|
||||
def flood_replace_none_values(table: list):
|
||||
"""Use a flooding algorithm to replace None values.
|
||||
|
||||
Args:
|
||||
table (list): the list in which there are None values that need to be replaced by copies of
|
||||
the closest non None data from the list.
|
||||
"""
|
||||
assert_true(any(value is not None for value in table))
|
||||
|
||||
not_none_values_idx = deque(idx for idx, value in enumerate(table) if value is not None)
|
||||
while not_none_values_idx:
|
||||
current_idx = not_none_values_idx.popleft()
|
||||
current_value = table[current_idx]
|
||||
previous_idx = current_idx - 1
|
||||
next_idx = current_idx + 1
|
||||
if previous_idx >= 0 and table[previous_idx] is None:
|
||||
table[previous_idx] = deepcopy(current_value)
|
||||
not_none_values_idx.append(previous_idx)
|
||||
if next_idx < len(table) and table[next_idx] is None:
|
||||
table[next_idx] = deepcopy(current_value)
|
||||
not_none_values_idx.append(next_idx)
|
||||
|
||||
assert_true(all(value is not None for value in table))
|
||||
|
||||
|
||||
@unique
|
||||
class GenericFunctionKind(str, Enum):
|
||||
"""Enum to validate GenericFunction op_kind."""
|
||||
|
||||
TLU = "TLU"
|
||||
MEMORY = "Memory"
|
||||
|
||||
|
||||
class GenericFunction(IntermediateNode):
|
||||
"""Node representing an arbitrary function with a single output, e.g. sin(x)."""
|
||||
|
||||
# The arbitrary_func is not optional but mypy has a long standing bug and is not able to
|
||||
# understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
|
||||
# arbitrary_func can take more than one argument but during evaluation the input variable will
|
||||
# be the first argument passed to it. You can add other constant arguments needed for the proper
|
||||
# execution of the function through op_args and op_kwargs.
|
||||
arbitrary_func: Optional[Callable]
|
||||
op_kind: GenericFunctionKind
|
||||
op_name: str
|
||||
op_args: Tuple[Any, ...]
|
||||
op_kwargs: Dict[str, Any]
|
||||
op_attributes: Dict[str, Any]
|
||||
_n_in: int
|
||||
|
||||
# TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/798 have a proper
|
||||
# attribute system
|
||||
DEFAULT_OP_ATTRIBUTES: Dict[str, Any] = {"fusable": True}
|
||||
|
||||
KWARGS_IGNORED_IN_FORMATTING: Set[str] = {
|
||||
"float_op_subgraph",
|
||||
"terminal_node",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs: Iterable[BaseValue],
|
||||
arbitrary_func: Callable,
|
||||
output_value: BaseValue,
|
||||
op_kind: Union[str, GenericFunctionKind],
|
||||
op_name: Optional[str] = None,
|
||||
op_args: Optional[Tuple[Any, ...]] = None,
|
||||
op_kwargs: Optional[Dict[str, Any]] = None,
|
||||
op_attributes: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__([deepcopy(i) for i in inputs])
|
||||
self._n_in = len(self.inputs)
|
||||
self.arbitrary_func = arbitrary_func
|
||||
self.op_kind = GenericFunctionKind(op_kind)
|
||||
self.op_args = op_args if op_args is not None else ()
|
||||
self.op_kwargs = op_kwargs if op_kwargs is not None else {}
|
||||
self.op_attributes = deepcopy(self.DEFAULT_OP_ATTRIBUTES)
|
||||
if op_attributes is not None:
|
||||
self.op_attributes.update(op_attributes)
|
||||
|
||||
self.outputs = [output_value]
|
||||
|
||||
self.op_name = op_name if op_name is not None else self.__class__.__name__
|
||||
|
||||
def text_for_formatting(self, predecessors: List[str], maximum_constant_length: int) -> str:
|
||||
if self.op_name == "concat":
|
||||
all_args = ["(" + ", ".join(predecessors) + ")"]
|
||||
else:
|
||||
all_args = deepcopy(predecessors)
|
||||
|
||||
all_args.extend(format_constant(value, maximum_constant_length) for value in self.op_args)
|
||||
all_args.extend(
|
||||
f"{name}={format_constant(value, maximum_constant_length)}"
|
||||
for name, value in self.op_kwargs.items()
|
||||
if name not in GenericFunction.KWARGS_IGNORED_IN_FORMATTING
|
||||
)
|
||||
|
||||
return f"{self.op_name}({', '.join(all_args)})"
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return self.op_name
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
# This is the continuation of the mypy bug workaround
|
||||
assert self.arbitrary_func is not None
|
||||
ordered_inputs = [inputs[idx] for idx in range(len(inputs))]
|
||||
if self.op_name == "concat":
|
||||
return self.arbitrary_func(tuple(ordered_inputs), *self.op_args, **self.op_kwargs)
|
||||
return self.arbitrary_func(*ordered_inputs, *self.op_args, **self.op_kwargs)
|
||||
|
||||
def get_table(self, ordered_preds: List[IntermediateNode]) -> List[Any]:
|
||||
"""Get the table for the current input value of this GenericFunction.
|
||||
|
||||
This function only works if the GenericFunction variable input value is an Integer.
|
||||
This function only works if there is a single variable input node among ordered_preds.
|
||||
|
||||
Args:
|
||||
ordered_preds (List[IntermediateNode]): List of predecessors of the node. This list must
|
||||
contain a single non constant node and any number of Constant nodes.
|
||||
|
||||
Returns:
|
||||
List[Any]: The table.
|
||||
"""
|
||||
|
||||
variable_input_indices = [
|
||||
idx for idx, pred in enumerate(ordered_preds) if not isinstance(pred, Constant)
|
||||
]
|
||||
|
||||
assert_true(
|
||||
(non_constant_pred_count := len(variable_input_indices)) == 1,
|
||||
f"Can only have 1 non constant predecessor in {self.get_table.__name__}, "
|
||||
f"got {non_constant_pred_count}",
|
||||
)
|
||||
|
||||
variable_input_idx = variable_input_indices[0]
|
||||
variable_input_dtype = self.inputs[variable_input_idx].dtype
|
||||
# Check the input is an integer to be able to build a table
|
||||
assert_true(
|
||||
isinstance(variable_input_dtype, Integer),
|
||||
f"{self.get_table.__name__} only works for an unsigned Integer input",
|
||||
)
|
||||
variable_input_dtype = cast(Integer, variable_input_dtype)
|
||||
|
||||
input_value_constructor = self.inputs[variable_input_idx].underlying_constructor
|
||||
if input_value_constructor is None:
|
||||
logger.info(
|
||||
f"{self.__class__.__name__} input data type constructor was None, defaulting to int"
|
||||
)
|
||||
input_value_constructor = int
|
||||
|
||||
min_input_range = variable_input_dtype.min_value()
|
||||
max_input_range = variable_input_dtype.max_value() + 1
|
||||
|
||||
template_input_dict = {
|
||||
idx: node.evaluate({}) if isinstance(node, Constant) else None
|
||||
for idx, node in enumerate(ordered_preds)
|
||||
}
|
||||
|
||||
table = [
|
||||
catch(
|
||||
self.evaluate,
|
||||
update_and_return_dict(
|
||||
template_input_dict, {variable_input_idx: input_value_constructor(input_value)}
|
||||
),
|
||||
)
|
||||
for input_value in range(min_input_range, max_input_range)
|
||||
]
|
||||
|
||||
flood_replace_none_values(table)
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def default_dot_evaluation_function(lhs: Any, rhs: Any) -> Any:
|
||||
"""Return the default python dot implementation for 1D iterable arrays.
|
||||
|
||||
Args:
|
||||
lhs (Any): lhs vector of the dot.
|
||||
rhs (Any): rhs vector of the dot.
|
||||
|
||||
Returns:
|
||||
Any: the result of the dot operation.
|
||||
"""
|
||||
return sum(lhs * rhs for lhs, rhs in zip(lhs, rhs))
|
||||
|
||||
|
||||
class Dot(IntermediateNode):
|
||||
"""Return the node representing a dot product."""
|
||||
|
||||
_n_in: int = 2
|
||||
# Optional, same issue as in GenericFunction for mypy
|
||||
evaluation_function: Optional[Callable[[Any, Any], Any]]
|
||||
# Allows to use specialized implementations from e.g. numpy
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs: Iterable[BaseValue],
|
||||
output_dtype: BaseDataType,
|
||||
delegate_evaluation_function: Optional[
|
||||
Callable[[Any, Any], Any]
|
||||
] = default_dot_evaluation_function,
|
||||
) -> None:
|
||||
super().__init__(inputs)
|
||||
assert_true(len(self.inputs) == 2)
|
||||
|
||||
assert_true(
|
||||
all(
|
||||
isinstance(input_value, TensorValue) and input_value.ndim <= 1
|
||||
for input_value in self.inputs
|
||||
),
|
||||
f"Dot only supports two scalars or vectors ({TensorValue.__name__} with ndim up to 1)",
|
||||
)
|
||||
|
||||
lhs = cast(TensorValue, self.inputs[0])
|
||||
rhs = cast(TensorValue, self.inputs[1])
|
||||
|
||||
if lhs.ndim == 1 and rhs.ndim == 1:
|
||||
assert_true(
|
||||
lhs.shape[0] == rhs.shape[0],
|
||||
f"Dot between vectors of shapes {lhs.shape} and {rhs.shape} is not supported",
|
||||
)
|
||||
|
||||
output_shape: Tuple[int, ...]
|
||||
if (lhs.ndim == 1 and rhs.ndim == 1) or (lhs.ndim == 0 and rhs.ndim == 0):
|
||||
# numpy.dot(x, y) where x and y are both vectors or both scalars
|
||||
output_shape = ()
|
||||
elif lhs.ndim == 1:
|
||||
# numpy.dot(x, y) where x is a vector and y is a scalar
|
||||
output_shape = lhs.shape
|
||||
else:
|
||||
# numpy.dot(x, y) where x is a scalar and y is a vector
|
||||
output_shape = rhs.shape
|
||||
|
||||
output_value = EncryptedTensor if (lhs.is_encrypted or rhs.is_encrypted) else ClearTensor
|
||||
|
||||
self.outputs = [output_value(output_dtype, output_shape)]
|
||||
self.evaluation_function = delegate_evaluation_function
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return "dot"
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
# This is the continuation of the mypy bug workaround
|
||||
assert self.evaluation_function is not None
|
||||
return self.evaluation_function(inputs[0], inputs[1])
|
||||
|
||||
|
||||
class MatMul(IntermediateNode):
|
||||
"""Return the node representing a matrix multiplication."""
|
||||
|
||||
_n_in: int = 2
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs: Iterable[BaseValue],
|
||||
output_dtype: BaseDataType,
|
||||
output_shape: Tuple[int, ...],
|
||||
) -> None:
|
||||
super().__init__(inputs)
|
||||
assert_true(len(self.inputs) == 2)
|
||||
|
||||
output_value = (
|
||||
EncryptedTensor(dtype=output_dtype, shape=output_shape)
|
||||
if (self.inputs[0].is_encrypted or self.inputs[1].is_encrypted)
|
||||
else ClearTensor(dtype=output_dtype, shape=output_shape)
|
||||
)
|
||||
|
||||
self.outputs = [output_value]
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return "matmul"
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] @ inputs[1]
|
||||
@@ -1,8 +0,0 @@
|
||||
"""Module for basic tracing facilities."""
|
||||
from .base_tracer import BaseTracer
|
||||
from .tracing_helpers import (
|
||||
create_graph_from_output_tracers,
|
||||
make_input_tracer,
|
||||
make_input_tracers,
|
||||
prepare_function_parameters,
|
||||
)
|
||||
@@ -1,462 +0,0 @@
|
||||
"""This file holds the code that can be shared between tracers."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union, cast
|
||||
|
||||
from ..data_types import Float
|
||||
from ..data_types.base import BaseDataType
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..representation.intermediate import (
|
||||
IR_MIX_VALUES_FUNC_ARG_NAME,
|
||||
Add,
|
||||
Constant,
|
||||
GenericFunction,
|
||||
IndexConstant,
|
||||
IntermediateNode,
|
||||
Mul,
|
||||
Sub,
|
||||
)
|
||||
from ..values import BaseValue, TensorValue
|
||||
|
||||
|
||||
class BaseTracer(ABC):
|
||||
"""Base class for implementing tracers."""
|
||||
|
||||
# this variable changes the behavior of __eq__ so that it can be traced but still allows to hash
|
||||
# BaseTracers when not tracing.
|
||||
_is_tracing: bool = False
|
||||
|
||||
inputs: List["BaseTracer"]
|
||||
traced_computation: IntermediateNode
|
||||
output_idx: int
|
||||
output: BaseValue
|
||||
_mix_values_func: Callable[..., BaseValue]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs: Iterable["BaseTracer"],
|
||||
traced_computation: IntermediateNode,
|
||||
output_idx: int,
|
||||
) -> None:
|
||||
self.inputs = list(inputs)
|
||||
self.traced_computation = traced_computation
|
||||
self.output_idx = output_idx
|
||||
self.output = traced_computation.outputs[output_idx]
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
"""Get the shape of the output of the tracer.
|
||||
|
||||
Returns:
|
||||
Tuple[int, ...]: the shape of the output
|
||||
"""
|
||||
|
||||
if isinstance(self.output, TensorValue):
|
||||
return self.output.shape
|
||||
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object "
|
||||
f"with '{self.output}' output "
|
||||
f"has no attribute 'shape'"
|
||||
) # pragma: no cover
|
||||
|
||||
# this error cannot be covered because we only have TensorValue for now
|
||||
|
||||
@abstractmethod
|
||||
def _supports_other_operand(self, other: Any) -> bool:
|
||||
"""Check if the current class supports tracing with the other operand.
|
||||
|
||||
Args:
|
||||
other (Any): the operand to check compatibility with.
|
||||
|
||||
Returns:
|
||||
bool: True if the tracer can manage operations with the other operand.
|
||||
"""
|
||||
return isinstance(other, self.__class__)
|
||||
|
||||
@abstractmethod
|
||||
def _make_const_input_tracer(self, constant_data: Any) -> "BaseTracer":
|
||||
"""Create a tracer for a constant input.
|
||||
|
||||
Args:
|
||||
constant_data (Any): The constant to store.
|
||||
|
||||
Returns:
|
||||
BaseTracer: The BaseTracer for that constant.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def set_is_tracing(cls, is_tracing: bool) -> None:
|
||||
"""Set whether we are in a tracing context to change __eq__ behavior.
|
||||
|
||||
Args:
|
||||
is_tracing (bool): boolean to use to set whether we are tracing
|
||||
"""
|
||||
cls._is_tracing = is_tracing
|
||||
|
||||
@classmethod
|
||||
def _get_mix_values_func(cls):
|
||||
return cls._mix_values_func
|
||||
|
||||
def _sanitize(self, inp) -> "BaseTracer":
|
||||
if not isinstance(inp, BaseTracer) and not (
|
||||
isinstance(inp, Tuple) # type: ignore
|
||||
and all(isinstance(item, BaseTracer) for item in inp) # type: ignore
|
||||
):
|
||||
return self._make_const_input_tracer(inp)
|
||||
return inp
|
||||
|
||||
def instantiate_output_tracers(
|
||||
self,
|
||||
inputs: Iterable[Union["BaseTracer", Any]],
|
||||
computation_to_trace: Type[IntermediateNode],
|
||||
) -> Tuple["BaseTracer", ...]:
|
||||
"""Instantiate all output BaseTracer for a given computation.
|
||||
|
||||
Args:
|
||||
inputs (Iterable[Union[BaseTracer, Any]]): Previous BaseTracer or data used as inputs
|
||||
for a new node.
|
||||
computation_to_trace (Type[IntermediateNode]): The IntermediateNode class
|
||||
to instantiate for the computation being traced
|
||||
|
||||
Returns:
|
||||
Tuple[BaseTracer, ...]: A tuple containing an BaseTracer per output function
|
||||
"""
|
||||
|
||||
# For inputs which are actually constant, first convert into a tracer
|
||||
sanitized_inputs = [self._sanitize(inp) for inp in inputs]
|
||||
|
||||
additional_parameters = (
|
||||
{IR_MIX_VALUES_FUNC_ARG_NAME: self._get_mix_values_func()}
|
||||
if computation_to_trace.requires_mix_values_func()
|
||||
else {}
|
||||
)
|
||||
|
||||
traced_computation = computation_to_trace(
|
||||
(x.output for x in sanitized_inputs),
|
||||
**additional_parameters,
|
||||
)
|
||||
|
||||
output_tracers = tuple(
|
||||
self.__class__(sanitized_inputs, traced_computation, output_idx)
|
||||
for output_idx in range(len(traced_computation.outputs))
|
||||
)
|
||||
|
||||
return output_tracers
|
||||
|
||||
def _helper_for_unary_functions(self, op_lambda: Callable, op_name: str) -> "BaseTracer":
|
||||
"""Trace a unary operator which maintains the shape, which will thus be replaced by a TLU.
|
||||
|
||||
Returns:
|
||||
BaseTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
first_arg_output = self.output
|
||||
assert_true(isinstance(first_arg_output, TensorValue))
|
||||
first_arg_output = cast(TensorValue, first_arg_output)
|
||||
|
||||
out_dtype = first_arg_output.dtype
|
||||
out_shape = first_arg_output.shape
|
||||
|
||||
generic_function_output_value = TensorValue(
|
||||
out_dtype,
|
||||
first_arg_output.is_encrypted,
|
||||
out_shape,
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[first_arg_output],
|
||||
arbitrary_func=op_lambda,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
op_name=f"{op_name}",
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
[self],
|
||||
traced_computation=traced_computation,
|
||||
output_idx=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def _helper_for_binary_functions_with_one_cst_input(
|
||||
self,
|
||||
lhs: Union["BaseTracer", Any],
|
||||
rhs: Union["BaseTracer", Any],
|
||||
op_lambda: Callable,
|
||||
op_name: str,
|
||||
output_dtype: Optional[BaseDataType] = None,
|
||||
) -> "BaseTracer":
|
||||
"""Trace a binary operator which maintains the shape, when one input is a constant.
|
||||
|
||||
This function is helpful to convert an operation with two inputs, one of which being a
|
||||
constant, into a TLU, while maintaining the constant somewhere in the graph, eg to simplify
|
||||
debugging.
|
||||
|
||||
Returns:
|
||||
BaseTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
if isinstance(lhs, BaseTracer):
|
||||
if not self._supports_other_operand(rhs):
|
||||
return NotImplemented
|
||||
elif isinstance(rhs, BaseTracer):
|
||||
if not self._supports_other_operand(lhs):
|
||||
return NotImplemented
|
||||
|
||||
sanitized_inputs = [self._sanitize(inp) for inp in [lhs, rhs]]
|
||||
|
||||
# One of the inputs has to be constant
|
||||
if not (
|
||||
isinstance(sanitized_inputs[0].traced_computation, Constant)
|
||||
or isinstance(sanitized_inputs[1].traced_computation, Constant)
|
||||
):
|
||||
raise NotImplementedError(f"Can't manage binary operator {op_name}")
|
||||
|
||||
sanitized_input_values = [san_input.output for san_input in sanitized_inputs]
|
||||
output_value = self._get_mix_values_func()(*sanitized_input_values)
|
||||
if output_dtype is not None:
|
||||
output_value.dtype = deepcopy(output_dtype)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=sanitized_input_values,
|
||||
arbitrary_func=op_lambda,
|
||||
output_value=output_value,
|
||||
op_kind="TLU",
|
||||
op_name=op_name,
|
||||
)
|
||||
|
||||
result_tracer = self.__class__(sanitized_inputs, traced_computation, 0)
|
||||
|
||||
return result_tracer
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return id(self)
|
||||
|
||||
def __add__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
if not self._supports_other_operand(other):
|
||||
return NotImplemented
|
||||
|
||||
result_tracer = self.instantiate_output_tracers(
|
||||
[self, other],
|
||||
Add,
|
||||
)
|
||||
|
||||
assert_true(len(result_tracer) == 1)
|
||||
return result_tracer[0]
|
||||
|
||||
# With that is that x + 1 and 1 + x have the same graph. If we want to keep
|
||||
# the order, we need to do as in __rsub__, ie mostly a copy of __sub__ +
|
||||
# some changes
|
||||
__radd__ = __add__
|
||||
|
||||
def __neg__(self) -> "BaseTracer":
|
||||
return 0 - self
|
||||
|
||||
def __pos__(self) -> "BaseTracer":
|
||||
# Remark that we don't want to return 'self' since we want the result to be a copy, ie not
|
||||
# a reference to the same object
|
||||
return 0 + self
|
||||
|
||||
def _lshift(self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
lhs, rhs, lambda x, y: x << y, "lshift"
|
||||
)
|
||||
|
||||
def __lshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
# x << shift
|
||||
return self._lshift(self, other)
|
||||
|
||||
def __rlshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
# cst << x
|
||||
return self._lshift(other, self)
|
||||
|
||||
def _rshift(self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
lhs, rhs, lambda x, y: x >> y, "rshift"
|
||||
)
|
||||
|
||||
def __rshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
# x >> shift
|
||||
return self._rshift(self, other)
|
||||
|
||||
def __rrshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
# cst >> x
|
||||
return self._rshift(other, self)
|
||||
|
||||
def __gt__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
# x > cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x > y, "gt"
|
||||
)
|
||||
|
||||
def __ge__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
# x >= cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x >= y, "ge"
|
||||
)
|
||||
|
||||
def __lt__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
# x < cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x < y, "lt"
|
||||
)
|
||||
|
||||
def __le__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
# x <= cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x <= y, "le"
|
||||
)
|
||||
|
||||
def __eq__(self, other: Union["BaseTracer", Any]):
|
||||
# x == cst
|
||||
# Return the tracer if we are tracing, else return the result of the default __eq__ function
|
||||
# allows to have hash capabilities outside of tracing
|
||||
return (
|
||||
self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x == y, "eq"
|
||||
)
|
||||
if self._is_tracing
|
||||
else self is other
|
||||
)
|
||||
|
||||
def __ne__(self, other: Union["BaseTracer", Any]):
|
||||
# x != cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x != y, "ne"
|
||||
)
|
||||
|
||||
def __pow__(self, other: Union["BaseTracer", Any]):
|
||||
# x ** cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x ** y, "pow"
|
||||
)
|
||||
|
||||
def __rpow__(self, other: Union["BaseTracer", Any]):
|
||||
# cst ** x
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
other, self, lambda x, y: x ** y, "pow"
|
||||
)
|
||||
|
||||
def __mod__(self, other: Union["BaseTracer", Any]):
|
||||
# x % cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x % y, "mod"
|
||||
)
|
||||
|
||||
def __rmod__(self, other: Union["BaseTracer", Any]):
|
||||
# cst % x
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
other, self, lambda x, y: x % y, "mod"
|
||||
)
|
||||
|
||||
def __and__(self, other: Union["BaseTracer", Any]):
|
||||
# x & cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x & y, "and"
|
||||
)
|
||||
|
||||
def __rand__(self, other: Union["BaseTracer", Any]):
|
||||
# cst & x
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
other, self, lambda x, y: x & y, "and"
|
||||
)
|
||||
|
||||
def __or__(self, other: Union["BaseTracer", Any]):
|
||||
# x | cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x | y, "or"
|
||||
)
|
||||
|
||||
def __ror__(self, other: Union["BaseTracer", Any]):
|
||||
# cst | x
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
other, self, lambda x, y: x | y, "or"
|
||||
)
|
||||
|
||||
def __xor__(self, other: Union["BaseTracer", Any]):
|
||||
# x ^ cst
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
self, other, lambda x, y: x ^ y, "xor"
|
||||
)
|
||||
|
||||
def __rxor__(self, other: Union["BaseTracer", Any]):
|
||||
# cst ^ x
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
other, self, lambda x, y: x ^ y, "xor"
|
||||
)
|
||||
|
||||
def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
if not self._supports_other_operand(other):
|
||||
return NotImplemented
|
||||
|
||||
result_tracer = self.instantiate_output_tracers(
|
||||
[self, other],
|
||||
Sub,
|
||||
)
|
||||
|
||||
assert_true(len(result_tracer) == 1)
|
||||
return result_tracer[0]
|
||||
|
||||
def __rsub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
if not self._supports_other_operand(other):
|
||||
return NotImplemented
|
||||
|
||||
result_tracer = self.instantiate_output_tracers(
|
||||
[other, self],
|
||||
Sub,
|
||||
)
|
||||
|
||||
assert_true(len(result_tracer) == 1)
|
||||
return result_tracer[0]
|
||||
|
||||
def __mul__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
if not self._supports_other_operand(other):
|
||||
return NotImplemented
|
||||
|
||||
result_tracer = self.instantiate_output_tracers(
|
||||
[self, other],
|
||||
Mul,
|
||||
)
|
||||
|
||||
assert_true(len(result_tracer) == 1)
|
||||
return result_tracer[0]
|
||||
|
||||
# With that is that x * 3 and 3 * x have the same graph. If we want to keep
|
||||
# the order, we need to do as in __rmul__, ie mostly a copy of __mul__ +
|
||||
# some changes
|
||||
__rmul__ = __mul__
|
||||
|
||||
def __abs__(self):
|
||||
return self._helper_for_unary_functions(lambda x: x.__abs__(), "__abs__")
|
||||
|
||||
def __invert__(self):
|
||||
return self._helper_for_unary_functions(lambda x: x.__invert__(), "__invert__")
|
||||
|
||||
def __getitem__(self, item):
|
||||
traced_computation = IndexConstant(self.output, item)
|
||||
return self.__class__([self], traced_computation, 0)
|
||||
|
||||
def _truediv(
|
||||
self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]
|
||||
) -> "BaseTracer":
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
lhs, rhs, lambda x, y: x / y, "truediv", Float(64)
|
||||
)
|
||||
|
||||
def __truediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
return self._truediv(self, other)
|
||||
|
||||
def __rtruediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
return self._truediv(other, self)
|
||||
|
||||
def _floordiv(
|
||||
self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]
|
||||
) -> "BaseTracer":
|
||||
return self._helper_for_binary_functions_with_one_cst_input(
|
||||
lhs, rhs, lambda x, y: x // y, "floordiv"
|
||||
)
|
||||
|
||||
def __floordiv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
return self._floordiv(self, other)
|
||||
|
||||
def __rfloordiv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
return self._floordiv(other, self)
|
||||
@@ -1,162 +0,0 @@
|
||||
"""Helper functions for tracing."""
|
||||
import collections
|
||||
from contextlib import contextmanager
|
||||
from inspect import signature
|
||||
from typing import Callable, Dict, Iterable, List, OrderedDict, Set, Type
|
||||
|
||||
import networkx as nx
|
||||
from networkx.algorithms.dag import is_directed_acyclic_graph
|
||||
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..representation.intermediate import Input
|
||||
from ..values import BaseValue
|
||||
from .base_tracer import BaseTracer
|
||||
|
||||
|
||||
def make_input_tracers(
|
||||
tracer_class: Type[BaseTracer],
|
||||
function_parameters: OrderedDict[str, BaseValue],
|
||||
) -> OrderedDict[str, BaseTracer]:
|
||||
"""Create tracers for a function's parameters.
|
||||
|
||||
Args:
|
||||
tracer_class (Type[BaseTracer]): the class of tracer to create an Input for
|
||||
function_parameters (OrderedDict[str, BaseValue]): the dictionary with the parameters names
|
||||
and corresponding Values
|
||||
|
||||
Returns:
|
||||
OrderedDict[str, BaseTracer]: the dictionary containing the Input Tracers for each parameter
|
||||
"""
|
||||
return collections.OrderedDict(
|
||||
(param_name, make_input_tracer(tracer_class, param_name, input_idx, param))
|
||||
for input_idx, (param_name, param) in enumerate(function_parameters.items())
|
||||
)
|
||||
|
||||
|
||||
def make_input_tracer(
|
||||
tracer_class: Type[BaseTracer],
|
||||
input_name: str,
|
||||
input_idx: int,
|
||||
input_value: BaseValue,
|
||||
) -> BaseTracer:
|
||||
"""Create a tracer for an input value.
|
||||
|
||||
Args:
|
||||
tracer_class (Type[BaseTracer]): the class of tracer to create an Input for
|
||||
input_name (str): the name of the input in the traced function
|
||||
input_idx (int): the input index in the function parameters
|
||||
input_value (BaseValue): the Value that is an input and needs to be wrapped in an
|
||||
BaseTracer
|
||||
|
||||
Returns:
|
||||
BaseTracer: The BaseTracer for that input value
|
||||
"""
|
||||
return tracer_class([], Input(input_value, input_name, input_idx), 0)
|
||||
|
||||
|
||||
def prepare_function_parameters(
|
||||
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
|
||||
) -> OrderedDict[str, BaseValue]:
|
||||
"""Filter the passed function_parameters to trace function_to_trace.
|
||||
|
||||
Args:
|
||||
function_to_trace (Callable): function that will be traced for which parameters are checked
|
||||
function_parameters (Dict[str, BaseValue]): parameters given to trace the function
|
||||
|
||||
Raises:
|
||||
ValueError: Raised when some parameters are missing to trace function_to_trace
|
||||
|
||||
Returns:
|
||||
OrderedDict[str, BaseValue]: filtered function_parameters dictionary
|
||||
"""
|
||||
function_signature = signature(function_to_trace)
|
||||
|
||||
missing_args = function_signature.parameters.keys() - function_parameters.keys()
|
||||
|
||||
if len(missing_args) > 0:
|
||||
raise ValueError(
|
||||
f"The function '{function_to_trace.__name__}' requires the following parameters"
|
||||
f"that were not provided: {', '.join(sorted(missing_args))}"
|
||||
)
|
||||
|
||||
# This convoluted way of creating the dict is to ensure key order is maintained
|
||||
return collections.OrderedDict(
|
||||
(param_name, function_parameters[param_name])
|
||||
for param_name in function_signature.parameters.keys()
|
||||
)
|
||||
|
||||
|
||||
def create_graph_from_output_tracers(
|
||||
output_tracers: Iterable[BaseTracer],
|
||||
) -> nx.MultiDiGraph:
|
||||
"""Generate a networkx Directed Graph that represents the computation from a traced function.
|
||||
|
||||
Args:
|
||||
output_tracers (Iterable[BaseTracer]): the output tracers resulting from running the
|
||||
function over the proper input tracers
|
||||
|
||||
Returns:
|
||||
nx.MultiDiGraph: Directed Graph that is guaranteed to be a DAG containing the ir nodes
|
||||
representing the traced program/function
|
||||
"""
|
||||
graph = nx.MultiDiGraph()
|
||||
|
||||
visited_tracers: Set[BaseTracer] = set()
|
||||
# use dict as ordered set
|
||||
current_tracers = {tracer: None for tracer in output_tracers}
|
||||
|
||||
while current_tracers:
|
||||
# use dict as ordered set
|
||||
next_tracers: Dict[BaseTracer, None] = {}
|
||||
for tracer in current_tracers:
|
||||
if tracer in visited_tracers:
|
||||
continue
|
||||
current_ir_node = tracer.traced_computation
|
||||
graph.add_node(current_ir_node)
|
||||
|
||||
for input_idx, input_tracer in enumerate(tracer.inputs):
|
||||
input_ir_node = input_tracer.traced_computation
|
||||
output_idx = input_tracer.output_idx
|
||||
graph.add_node(input_ir_node)
|
||||
graph.add_edge(
|
||||
input_ir_node,
|
||||
current_ir_node,
|
||||
input_idx=input_idx,
|
||||
output_idx=output_idx,
|
||||
)
|
||||
if input_tracer not in visited_tracers:
|
||||
next_tracers.update({input_tracer: None})
|
||||
|
||||
visited_tracers.add(tracer)
|
||||
|
||||
current_tracers = next_tracers
|
||||
|
||||
assert_true(is_directed_acyclic_graph(graph))
|
||||
|
||||
# Check each edge is unique
|
||||
unique_edges = set(
|
||||
(pred, succ, tuple((k, v) for k, v in edge_data.items()))
|
||||
for pred, succ, edge_data in graph.edges(data=True)
|
||||
)
|
||||
number_of_edges = len(graph.edges)
|
||||
assert_true(len(unique_edges) == number_of_edges)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tracing_context(tracer_classes: List[Type[BaseTracer]]):
|
||||
"""Set tracer classes in tracing mode.
|
||||
|
||||
Args:
|
||||
tracer_classes (List[Type[BaseTracer]]): The list of tracers for which we should enable
|
||||
tracing.
|
||||
"""
|
||||
|
||||
try:
|
||||
for tracer_class in tracer_classes:
|
||||
tracer_class.set_is_tracing(True)
|
||||
yield
|
||||
finally:
|
||||
for tracer_class in tracer_classes:
|
||||
tracer_class.set_is_tracing(False)
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Module for value structures."""
|
||||
|
||||
from . import tensors
|
||||
from .base import BaseValue
|
||||
from .tensors import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue
|
||||
@@ -1,45 +0,0 @@
|
||||
"""Module that defines the values in a program."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Optional
|
||||
|
||||
from ..data_types.base import BaseDataType
|
||||
|
||||
|
||||
class BaseValue(ABC):
|
||||
"""Abstract base class to represent any kind of value in a program."""
|
||||
|
||||
dtype: BaseDataType
|
||||
_is_encrypted: bool
|
||||
underlying_constructor: Optional[Callable]
|
||||
|
||||
def __init__(self, dtype: BaseDataType, is_encrypted: bool) -> None:
|
||||
self.dtype = deepcopy(dtype)
|
||||
self._is_encrypted = is_encrypted
|
||||
self.underlying_constructor = None
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return str(self)
|
||||
|
||||
@abstractmethod
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, self.__class__) and self.dtype == other.dtype
|
||||
|
||||
@property
|
||||
def is_encrypted(self) -> bool:
|
||||
"""Whether Value is encrypted or not.
|
||||
|
||||
Returns:
|
||||
bool: True if encrypted False otherwise
|
||||
"""
|
||||
return self._is_encrypted
|
||||
|
||||
@property
|
||||
def is_clear(self) -> bool:
|
||||
"""Whether Value is clear or not.
|
||||
|
||||
Returns:
|
||||
bool: True if clear False otherwise
|
||||
"""
|
||||
return not self._is_encrypted
|
||||
@@ -1,142 +0,0 @@
|
||||
"""Module that defines the tensor values in a program."""
|
||||
|
||||
from math import prod
|
||||
from typing import Tuple
|
||||
|
||||
from ..data_types.base import BaseDataType
|
||||
from .base import BaseValue
|
||||
|
||||
|
||||
class TensorValue(BaseValue):
|
||||
"""Class representing a tensor value."""
|
||||
|
||||
_shape: Tuple[int, ...]
|
||||
_ndim: int
|
||||
_size: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dtype: BaseDataType,
|
||||
is_encrypted: bool,
|
||||
shape: Tuple[int, ...],
|
||||
):
|
||||
super().__init__(dtype, is_encrypted)
|
||||
# Managing tensors as in numpy, shape of () means the value is scalar
|
||||
self._shape = shape
|
||||
self._ndim = len(self._shape)
|
||||
self._size = prod(self._shape) if self._shape != () else 1
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self.shape == other.shape
|
||||
and self.ndim == other.ndim
|
||||
and self.size == other.size
|
||||
and super().__eq__(other)
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
encrypted_str = "Encrypted" if self._is_encrypted else "Clear"
|
||||
tensor_or_scalar_str = "Scalar" if self.is_scalar else "Tensor"
|
||||
shape_str = f", shape={self.shape}" if self.shape != () else ""
|
||||
return f"{encrypted_str}{tensor_or_scalar_str}<{str(self.dtype)}{shape_str}>"
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
"""Return the TensorValue shape property.
|
||||
|
||||
Returns:
|
||||
Tuple[int, ...]: The TensorValue shape.
|
||||
"""
|
||||
return self._shape
|
||||
|
||||
@property
|
||||
def ndim(self) -> int:
|
||||
"""Return the TensorValue ndim property.
|
||||
|
||||
Returns:
|
||||
int: The TensorValue ndim.
|
||||
"""
|
||||
return self._ndim
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""Return the TensorValue size property.
|
||||
|
||||
Returns:
|
||||
int: The TensorValue size.
|
||||
"""
|
||||
return self._size
|
||||
|
||||
@property
|
||||
def is_scalar(self) -> bool:
|
||||
"""Whether Value is scalar or not.
|
||||
|
||||
Returns:
|
||||
bool: True if scalar False otherwise
|
||||
"""
|
||||
return self.shape == ()
|
||||
|
||||
|
||||
def make_clear_tensor(
|
||||
dtype: BaseDataType,
|
||||
shape: Tuple[int, ...],
|
||||
) -> TensorValue:
|
||||
"""Create a clear TensorValue.
|
||||
|
||||
Args:
|
||||
dtype (BaseDataType): The data type for the tensor.
|
||||
shape (Optional[Tuple[int, ...]], optional): The tensor shape. Defaults to None.
|
||||
|
||||
Returns:
|
||||
TensorValue: The corresponding TensorValue.
|
||||
"""
|
||||
return TensorValue(dtype=dtype, is_encrypted=False, shape=shape)
|
||||
|
||||
|
||||
def make_encrypted_tensor(
|
||||
dtype: BaseDataType,
|
||||
shape: Tuple[int, ...],
|
||||
) -> TensorValue:
|
||||
"""Create an encrypted TensorValue.
|
||||
|
||||
Args:
|
||||
dtype (BaseDataType): The data type for the tensor.
|
||||
shape (Optional[Tuple[int, ...]], optional): The tensor shape. Defaults to None.
|
||||
|
||||
Returns:
|
||||
TensorValue: The corresponding TensorValue.
|
||||
"""
|
||||
return TensorValue(dtype=dtype, is_encrypted=True, shape=shape)
|
||||
|
||||
|
||||
ClearTensor = make_clear_tensor
|
||||
EncryptedTensor = make_encrypted_tensor
|
||||
|
||||
|
||||
def make_clear_scalar(dtype: BaseDataType) -> TensorValue:
|
||||
"""Create a clear scalar value.
|
||||
|
||||
Args:
|
||||
dtype (BaseDataType): The data type for the value.
|
||||
|
||||
Returns:
|
||||
TensorValue: The corresponding TensorValue.
|
||||
"""
|
||||
return TensorValue(dtype=dtype, is_encrypted=False, shape=())
|
||||
|
||||
|
||||
def make_encrypted_scalar(dtype: BaseDataType) -> TensorValue:
|
||||
"""Create an encrypted scalar value.
|
||||
|
||||
Args:
|
||||
dtype (BaseDataType): The data type for the value.
|
||||
|
||||
Returns:
|
||||
TensorValue: The corresponding TensorValue.
|
||||
"""
|
||||
return TensorValue(dtype=dtype, is_encrypted=True, shape=())
|
||||
|
||||
|
||||
ClearScalar = make_clear_scalar
|
||||
EncryptedScalar = make_encrypted_scalar
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Module for compiling numpy functions to homomorphic equivalents."""
|
||||
|
||||
# Import differently to put at the top, and avoid circular import issues
|
||||
from concrete.numpy.compile import (
|
||||
compile_numpy_function,
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds,
|
||||
)
|
||||
from concrete.numpy.np_fhe_compiler import NPFHECompiler
|
||||
from concrete.numpy.tracing import trace_numpy_function
|
||||
|
||||
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.data_types import (
|
||||
Float,
|
||||
Float16,
|
||||
Float32,
|
||||
Float64,
|
||||
Integer,
|
||||
SignedInteger,
|
||||
UnsignedInteger,
|
||||
)
|
||||
from ..common.debugging import draw_graph, format_operation_graph
|
||||
from ..common.extensions.convolution import conv2d
|
||||
from ..common.extensions.multi_table import MultiLookupTable
|
||||
from ..common.extensions.table import LookupTable
|
||||
from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue
|
||||
@@ -1,805 +0,0 @@
|
||||
"""numpy compilation function."""
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy
|
||||
|
||||
from ..common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_inputset
|
||||
from ..common.common_helpers import check_op_graph_is_integer_program
|
||||
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.data_types import Integer
|
||||
from ..common.debugging import format_operation_graph
|
||||
from ..common.debugging.custom_assert import assert_true
|
||||
from ..common.fhe_circuit import FHECircuit
|
||||
from ..common.mlir.utils import (
|
||||
check_graph_values_compatibility_with_mlir,
|
||||
update_bit_width_for_mlir,
|
||||
)
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.optimization.topological import fuse_float_operations
|
||||
from ..common.representation.intermediate import Add, Constant, GenericFunction, IntermediateNode
|
||||
from ..common.values import BaseValue, ClearScalar
|
||||
from ..numpy.tracing import trace_numpy_function
|
||||
from .np_dtypes_helpers import (
|
||||
get_base_data_type_for_numpy_or_python_constant_data,
|
||||
get_base_value_for_numpy_or_python_constant_data,
|
||||
get_constructor_for_numpy_or_python_constant_data,
|
||||
)
|
||||
from .np_inputset_helpers import _check_special_inputset_availability, _generate_random_inputset
|
||||
from .np_mlir_converter import NPMLIRConverter
|
||||
|
||||
_COMPILE_FHE_INSECURE_KEY_CACHE_DIR: Optional[str] = None
|
||||
|
||||
|
||||
def numpy_max_func(lhs: Any, rhs: Any) -> Any:
|
||||
"""Compute the maximum value between two values which can be numpy classes (e.g. ndarray).
|
||||
|
||||
Args:
|
||||
lhs (Any): lhs value to compute max from.
|
||||
rhs (Any): rhs value to compute max from.
|
||||
|
||||
Returns:
|
||||
Any: maximum scalar value between lhs and rhs.
|
||||
"""
|
||||
return numpy.maximum(lhs, rhs).max()
|
||||
|
||||
|
||||
def numpy_min_func(lhs: Any, rhs: Any) -> Any:
|
||||
"""Compute the minimum value between two values which can be numpy classes (e.g. ndarray).
|
||||
|
||||
Args:
|
||||
lhs (Any): lhs value to compute min from.
|
||||
rhs (Any): rhs value to compute min from.
|
||||
|
||||
Returns:
|
||||
Any: minimum scalar value between lhs and rhs.
|
||||
"""
|
||||
return numpy.minimum(lhs, rhs).min()
|
||||
|
||||
|
||||
def sanitize_compilation_configuration_and_artifacts(
|
||||
compilation_configuration: Optional[CompilationConfiguration] = None,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
) -> Tuple[CompilationConfiguration, CompilationArtifacts]:
|
||||
"""Return the proper compilation configuration and artifacts.
|
||||
|
||||
Default values are returned if None is passed for each argument.
|
||||
|
||||
Args:
|
||||
compilation_configuration (Optional[CompilationConfiguration], optional): the compilation
|
||||
configuration to sanitize. Defaults to None.
|
||||
compilation_artifacts (Optional[CompilationArtifacts], optional): the compilation artifacts
|
||||
to sanitize. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tuple[CompilationConfiguration, CompilationArtifacts]: the tuple of sanitized configuration
|
||||
and artifacts.
|
||||
"""
|
||||
# Create default configuration if custom configuration is not specified
|
||||
compilation_configuration = (
|
||||
CompilationConfiguration()
|
||||
if compilation_configuration is None
|
||||
else compilation_configuration
|
||||
)
|
||||
|
||||
# Create temporary artifacts if custom artifacts is not specified (in case of exceptions)
|
||||
if compilation_artifacts is None:
|
||||
compilation_artifacts = CompilationArtifacts()
|
||||
|
||||
return compilation_configuration, compilation_artifacts
|
||||
|
||||
|
||||
def get_inputset_to_use(
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]], str],
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
) -> Union[Iterable[Any], Iterable[Tuple[Any, ...]]]:
|
||||
"""Get the proper inputset to use for compilation.
|
||||
|
||||
Args:
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]], str]): The inputset over which
|
||||
op_graph is evaluated. It needs to be an iterable on tuples which are of the same length
|
||||
than the number of parameters in the function, and in the same order than these same
|
||||
parameters
|
||||
compilation_configuration (CompilationConfiguration): Configuration object to use during
|
||||
compilation
|
||||
|
||||
Returns:
|
||||
Union[Iterable[Any], Iterable[Tuple[Any, ...]]]: the inputset to use.
|
||||
"""
|
||||
# Generate random inputset if it is requested and available
|
||||
if isinstance(inputset, str):
|
||||
_check_special_inputset_availability(inputset, compilation_configuration)
|
||||
return _generate_random_inputset(function_parameters, compilation_configuration)
|
||||
return inputset
|
||||
|
||||
|
||||
def run_compilation_function_with_error_management(
|
||||
compilation_function: Callable,
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
compilation_artifacts: CompilationArtifacts,
|
||||
) -> Any:
|
||||
"""Call compilation_function() and manage exceptions that may occur.
|
||||
|
||||
Args:
|
||||
compilation_function (Callable): the compilation function to call.
|
||||
compilation_configuration (CompilationConfiguration): the current compilation configuration.
|
||||
compilation_artifacts (CompilationArtifacts): the current compilation artifacts.
|
||||
|
||||
Returns:
|
||||
Any: returns the result of the call to compilation_function
|
||||
"""
|
||||
|
||||
# Try to compile the function and save partial artifacts on failure
|
||||
try:
|
||||
# Use context manager to restore numpy error handling
|
||||
with numpy.errstate(**numpy.geterr()):
|
||||
return compilation_function()
|
||||
except Exception: # pragma: no cover
|
||||
# This branch is reserved for unexpected issues and hence it shouldn't be tested.
|
||||
# If it could be tested, we would have fixed the underlying issue.
|
||||
|
||||
# We need to export all the information we have about the compilation
|
||||
# If the user wants them to be exported
|
||||
|
||||
if compilation_configuration.dump_artifacts_on_unexpected_failures:
|
||||
compilation_artifacts.export()
|
||||
|
||||
traceback_path = compilation_artifacts.output_directory.joinpath("traceback.txt")
|
||||
with open(traceback_path, "w", encoding="utf-8") as f:
|
||||
f.write(traceback.format_exc())
|
||||
|
||||
raise
|
||||
|
||||
|
||||
def _compile_numpy_function_into_op_graph_internal(
|
||||
function_to_compile: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
compilation_artifacts: CompilationArtifacts,
|
||||
) -> OPGraph:
|
||||
"""Compile a function into an OPGraph without evaluating the intermediate nodes bounds.
|
||||
|
||||
Args:
|
||||
function_to_compile (Callable): The function to compile
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
compilation_configuration (CompilationConfiguration): Configuration object to use
|
||||
during compilation
|
||||
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
|
||||
during compilation
|
||||
|
||||
Returns:
|
||||
OPGraph: compiled function into a graph, node values are not representative of the values
|
||||
that can be observed during execution.
|
||||
Use _compile_numpy_function_into_op_graph_and_measure_bounds_internal if you need bounds
|
||||
estimation.
|
||||
"""
|
||||
# Check function parameters
|
||||
wrong_inputs = {
|
||||
inp: function_parameters[inp]
|
||||
for inp in function_parameters.keys()
|
||||
if not isinstance(function_parameters[inp], BaseValue)
|
||||
}
|
||||
list_of_possible_basevalue = [
|
||||
"ClearTensor",
|
||||
"EncryptedTensor",
|
||||
"ClearScalar",
|
||||
"EncryptedScalar",
|
||||
]
|
||||
assert_true(
|
||||
len(wrong_inputs.keys()) == 0,
|
||||
f"wrong type for inputs {wrong_inputs}, needs to be one of {list_of_possible_basevalue}",
|
||||
)
|
||||
|
||||
# Add the function to compile as an artifact
|
||||
compilation_artifacts.add_function_to_compile(function_to_compile)
|
||||
|
||||
# Add the parameters of function to compile as artifacts
|
||||
for name, value in function_parameters.items():
|
||||
compilation_artifacts.add_parameter_of_function_to_compile(name, str(value))
|
||||
|
||||
# Trace the function
|
||||
op_graph = trace_numpy_function(function_to_compile, function_parameters)
|
||||
|
||||
# Add the initial graph as an artifact
|
||||
compilation_artifacts.add_operation_graph("initial", op_graph)
|
||||
|
||||
# Apply topological optimizations if they are enabled
|
||||
if compilation_configuration.enable_topological_optimizations:
|
||||
# Fuse float operations to have int to int GenericFunction
|
||||
if not check_op_graph_is_integer_program(op_graph):
|
||||
fuse_float_operations(op_graph, compilation_artifacts)
|
||||
|
||||
return op_graph
|
||||
|
||||
|
||||
def compile_numpy_function_into_op_graph(
|
||||
function_to_compile: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
compilation_configuration: Optional[CompilationConfiguration] = None,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
) -> OPGraph:
|
||||
"""Compile a function into an OPGraph.
|
||||
|
||||
Args:
|
||||
function_to_compile (Callable): The function to compile
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
|
||||
during compilation
|
||||
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
|
||||
during compilation
|
||||
|
||||
Returns:
|
||||
OPGraph: compiled function into a graph
|
||||
"""
|
||||
|
||||
(
|
||||
compilation_configuration,
|
||||
compilation_artifacts,
|
||||
) = sanitize_compilation_configuration_and_artifacts(
|
||||
compilation_configuration, compilation_artifacts
|
||||
)
|
||||
|
||||
def compilation_function():
|
||||
return _compile_numpy_function_into_op_graph_internal(
|
||||
function_to_compile,
|
||||
function_parameters,
|
||||
compilation_configuration,
|
||||
compilation_artifacts,
|
||||
)
|
||||
|
||||
result = run_compilation_function_with_error_management(
|
||||
compilation_function, compilation_configuration, compilation_artifacts
|
||||
)
|
||||
|
||||
# for mypy
|
||||
assert isinstance(result, OPGraph)
|
||||
return result
|
||||
|
||||
|
||||
def _measure_op_graph_bounds_and_update_internal(
|
||||
op_graph: OPGraph,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
compilation_artifacts: CompilationArtifacts,
|
||||
prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None,
|
||||
warn_on_inputset_length: bool = True,
|
||||
) -> Dict[IntermediateNode, Dict[str, Any]]:
|
||||
"""Measure the intermediate values and update the OPGraph accordingly for the given inputset.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): the OPGraph for which to measure bounds and update node values.
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset over which op_graph
|
||||
is evaluated. It needs to be an iterable on tuples which are of the same length than the
|
||||
number of parameters in the function, and in the same order than these same parameters
|
||||
compilation_configuration (CompilationConfiguration): Configuration object to use
|
||||
during compilation
|
||||
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
|
||||
during compilation
|
||||
prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional):
|
||||
Bounds and samples from a previous run. Defaults to None.
|
||||
warn_on_inputset_length (bool, optional): Set to True to get a warning if inputset is not
|
||||
long enough. Defaults to True.
|
||||
|
||||
Raises:
|
||||
ValueError: Raises an error if the inputset is too small and the compilation configuration
|
||||
treats warnings as error.
|
||||
|
||||
Returns:
|
||||
Dict[IntermediateNode, Dict[str, Any]]: a dict containing the bounds for each node from
|
||||
op_graph, stored with the node as key and a dict with keys "min", "max" and "sample" as
|
||||
value.
|
||||
"""
|
||||
|
||||
# Find bounds with the inputset
|
||||
inputset_size, node_bounds_and_samples = eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
compilation_configuration=compilation_configuration,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
prev_node_bounds_and_samples=prev_node_bounds_and_samples,
|
||||
)
|
||||
|
||||
if warn_on_inputset_length:
|
||||
# Check inputset size
|
||||
inputset_size_upper_limit = 1
|
||||
|
||||
# this loop will determine the number of possible inputs of the function
|
||||
# if a function have a single 3-bit input, for example, inputset_size_upper_limit will be 8
|
||||
for parameter_value in function_parameters.values():
|
||||
if isinstance(parameter_value.dtype, Integer):
|
||||
# multiple parameter bit-widths are multiplied as they can be combined into an input
|
||||
inputset_size_upper_limit *= 2 ** parameter_value.dtype.bit_width
|
||||
|
||||
# if the upper limit of the inputset size goes above 10,
|
||||
# break the loop as we will require at least 10 inputs in this case
|
||||
if inputset_size_upper_limit > 10:
|
||||
break
|
||||
|
||||
minimum_required_inputset_size = min(inputset_size_upper_limit, 10)
|
||||
if inputset_size < minimum_required_inputset_size:
|
||||
message = (
|
||||
f"Provided inputset contains too few inputs "
|
||||
f"(it should have had at least {minimum_required_inputset_size} "
|
||||
f"but it only had {inputset_size})\n"
|
||||
)
|
||||
|
||||
if compilation_configuration.treat_warnings_as_errors:
|
||||
raise ValueError(message)
|
||||
|
||||
sys.stderr.write(f"Warning: {message}")
|
||||
|
||||
# Add the bounds as an artifact
|
||||
compilation_artifacts.add_final_operation_graph_bounds(node_bounds_and_samples)
|
||||
|
||||
# Update the graph accordingly: after that, we have the compilable graph
|
||||
op_graph.update_values_with_bounds_and_samples(
|
||||
node_bounds_and_samples,
|
||||
get_base_data_type_for_numpy_or_python_constant_data,
|
||||
get_constructor_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
return node_bounds_and_samples
|
||||
|
||||
|
||||
def measure_op_graph_bounds_and_update(
|
||||
op_graph: OPGraph,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]], str],
|
||||
compilation_configuration: Optional[CompilationConfiguration] = None,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None,
|
||||
warn_on_inputset_length: bool = True,
|
||||
) -> Dict[IntermediateNode, Dict[str, Any]]:
|
||||
"""Measure the intermediate values and update the OPGraph accordingly for the given inputset.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): the OPGraph for which to measure bounds and update node values.
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]], str]): The inputset over which
|
||||
op_graph is evaluated. It needs to be an iterable on tuples which are of the same length
|
||||
than the number of parameters in the function, and in the same order than these same
|
||||
parameters
|
||||
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
|
||||
during compilation
|
||||
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
|
||||
during compilation
|
||||
prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional):
|
||||
Bounds and samples from a previous run. Defaults to None.
|
||||
warn_on_inputset_length (bool, optional): Set to True to get a warning if inputset is not
|
||||
long enough. Defaults to True.
|
||||
|
||||
Raises:
|
||||
ValueError: Raises an error if the inputset is too small and the compilation configuration
|
||||
treats warnings as error.
|
||||
|
||||
Returns:
|
||||
Dict[IntermediateNode, Dict[str, Any]]: a dict containing the bounds for each node from
|
||||
op_graph, stored with the node as key and a dict with keys "min", "max" and "sample" as
|
||||
value.
|
||||
"""
|
||||
|
||||
(
|
||||
compilation_configuration,
|
||||
compilation_artifacts,
|
||||
) = sanitize_compilation_configuration_and_artifacts(
|
||||
compilation_configuration, compilation_artifacts
|
||||
)
|
||||
|
||||
inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration)
|
||||
|
||||
def compilation_function():
|
||||
return _measure_op_graph_bounds_and_update_internal(
|
||||
op_graph,
|
||||
function_parameters,
|
||||
inputset,
|
||||
compilation_configuration,
|
||||
compilation_artifacts,
|
||||
prev_node_bounds_and_samples,
|
||||
warn_on_inputset_length,
|
||||
)
|
||||
|
||||
result = run_compilation_function_with_error_management(
|
||||
compilation_function, compilation_configuration, compilation_artifacts
|
||||
)
|
||||
|
||||
# for mypy
|
||||
assert isinstance(result, dict)
|
||||
return result
|
||||
|
||||
|
||||
def _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
|
||||
function_to_compile: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
compilation_artifacts: CompilationArtifacts,
|
||||
) -> OPGraph:
|
||||
"""Compile a function into an OPGraph and evaluate the intermediate nodes bounds.
|
||||
|
||||
Args:
|
||||
function_to_compile (Callable): The function to compile
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset over which op_graph
|
||||
is evaluated. It needs to be an iterable on tuples which are of the same length than the
|
||||
number of parameters in the function, and in the same order than these same parameters
|
||||
compilation_configuration (CompilationConfiguration): Configuration object to use
|
||||
during compilation
|
||||
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
|
||||
during compilation
|
||||
|
||||
Returns:
|
||||
OPGraph: compiled function into a graph with estimated bounds in node values.
|
||||
"""
|
||||
|
||||
op_graph = _compile_numpy_function_into_op_graph_internal(
|
||||
function_to_compile,
|
||||
function_parameters,
|
||||
compilation_configuration,
|
||||
compilation_artifacts,
|
||||
)
|
||||
|
||||
_measure_op_graph_bounds_and_update_internal(
|
||||
op_graph,
|
||||
function_parameters,
|
||||
inputset,
|
||||
compilation_configuration,
|
||||
compilation_artifacts,
|
||||
)
|
||||
|
||||
# Add the final graph as an artifact
|
||||
compilation_artifacts.add_operation_graph("final", op_graph)
|
||||
|
||||
return op_graph
|
||||
|
||||
|
||||
def compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function_to_compile: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]], str],
|
||||
compilation_configuration: Optional[CompilationConfiguration] = None,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
) -> OPGraph:
|
||||
"""Compile a function into an OPGraph.
|
||||
|
||||
Args:
|
||||
function_to_compile (Callable): The function to compile
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]], str]): The inputset over which
|
||||
op_graph is evaluated. It needs to be an iterable on tuples which are of the same length
|
||||
than the number of parameters in the function, and in the same order than these same
|
||||
parameters. Alternatively, it can be "random" but that's an unstable feature and should
|
||||
not be used in production.
|
||||
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
|
||||
during compilation
|
||||
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
|
||||
during compilation
|
||||
|
||||
Returns:
|
||||
OPGraph: compiled function into a graph
|
||||
"""
|
||||
|
||||
(
|
||||
compilation_configuration,
|
||||
compilation_artifacts,
|
||||
) = sanitize_compilation_configuration_and_artifacts(
|
||||
compilation_configuration, compilation_artifacts
|
||||
)
|
||||
|
||||
inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration)
|
||||
|
||||
def compilation_function():
|
||||
return _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
|
||||
function_to_compile,
|
||||
function_parameters,
|
||||
inputset,
|
||||
compilation_configuration,
|
||||
compilation_artifacts,
|
||||
)
|
||||
|
||||
result = run_compilation_function_with_error_management(
|
||||
compilation_function, compilation_configuration, compilation_artifacts
|
||||
)
|
||||
|
||||
# for mypy
|
||||
assert isinstance(result, OPGraph)
|
||||
return result
|
||||
|
||||
|
||||
# HACK
|
||||
# TODO: remove this ugly hack when
|
||||
# https://github.com/zama-ai/concrete-numpy-internal/issues/1001 is done
|
||||
# TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/1015
|
||||
def hack_offset_negative_inputs_to_lookup_tables(op_graph: OPGraph) -> None:
|
||||
"""Hack the op_graph to add offsets to signed inputs to TLUs.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): the OPGraph to hack.
|
||||
"""
|
||||
# Ugly hack to add an offset before entering a TLU if its variable input node has a signed
|
||||
# output.
|
||||
# This is ugly as this makes hardcoded assumptions about the way bit widths are handled in MLIR.
|
||||
# This does not update the TLU input values to allow for proper table generation.
|
||||
# Thankfully we are not supposed to touch the op_graph beyond that point
|
||||
for node in list((nx_graph := op_graph.graph).nodes):
|
||||
if isinstance(node, GenericFunction) and node.op_kind == "TLU":
|
||||
ordered_preds_and_inputs = op_graph.get_ordered_preds_and_inputs_of(node)
|
||||
variable_input_indices = [
|
||||
idx
|
||||
for idx, (pred, _) in enumerate(ordered_preds_and_inputs)
|
||||
if not isinstance(pred, Constant)
|
||||
]
|
||||
assert_true(len(variable_input_indices) == 1)
|
||||
variable_input_idx = variable_input_indices[0]
|
||||
variable_input_node = ordered_preds_and_inputs[variable_input_idx][0]
|
||||
variable_input_value = variable_input_node.outputs[0]
|
||||
variable_input_dtype = variable_input_value.dtype
|
||||
assert_true(isinstance(variable_input_dtype, Integer))
|
||||
variable_input_dtype = cast(Integer, variable_input_dtype)
|
||||
if not variable_input_dtype.is_signed:
|
||||
continue
|
||||
|
||||
# input_bit_width + 1 to be MLIR compliant
|
||||
input_bit_width = variable_input_dtype.bit_width
|
||||
mlir_compliant_int_type = Integer(input_bit_width + 1, True)
|
||||
|
||||
# Manually fix the output values to be MLIR compliant
|
||||
# offset_constant is set to abs(min_value) for the variable input so that the values
|
||||
# [- 2 ** (n - 1); 2 ** (n - 1) - 1] is mapped to [0; 2 ** n - 1], changing the signed
|
||||
# TLU to an actual unsigned TLU. The get_table function creates the table from the min
|
||||
# value to the max value. As we keep the input value as a signed value, it will be from
|
||||
# - 2 ** (n - 1) to 2 ** (n - 1) - 1. Then, the get_table function stores corresponding
|
||||
# values in increasing indexes from 0 to 2 ** n - 1. As our signed values have been
|
||||
# shifted by 2 ** (n - 1), the table will be usable as-is, without needing any change in
|
||||
# the lambda function of the GenericFunction.
|
||||
offset_constant = Constant(abs(variable_input_dtype.min_value()))
|
||||
offset_constant.outputs[0].dtype = deepcopy(mlir_compliant_int_type)
|
||||
add_offset = Add(
|
||||
[deepcopy(variable_input_value), ClearScalar(deepcopy(mlir_compliant_int_type))]
|
||||
)
|
||||
add_offset.outputs[0] = deepcopy(variable_input_value)
|
||||
|
||||
nx_graph.remove_edge(variable_input_node, node)
|
||||
nx_graph.add_edge(variable_input_node, add_offset, input_idx=0, output_idx=0)
|
||||
nx_graph.add_edge(offset_constant, add_offset, input_idx=1, output_idx=0)
|
||||
nx_graph.add_edge(add_offset, node, input_idx=variable_input_idx, output_idx=0)
|
||||
|
||||
|
||||
def prepare_op_graph_for_mlir(op_graph: OPGraph):
|
||||
"""Prepare OPGraph for MLIR lowering.
|
||||
|
||||
This includes checking compatibility, changing bit-widths, and modifying lookup tables.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The operation graph to prepare
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
# Make sure the graph can be lowered to MLIR
|
||||
offending_nodes = check_graph_values_compatibility_with_mlir(op_graph)
|
||||
if offending_nodes is not None:
|
||||
raise RuntimeError(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n\n"
|
||||
+ format_operation_graph(op_graph, highlighted_nodes=offending_nodes)
|
||||
)
|
||||
|
||||
# Update bit_width for MLIR
|
||||
update_bit_width_for_mlir(op_graph)
|
||||
|
||||
# HACK
|
||||
# TODO: remove this ugly hack when
|
||||
# https://github.com/zama-ai/concrete-numpy-internal/issues/1001 is done
|
||||
# TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/1015
|
||||
hack_offset_negative_inputs_to_lookup_tables(op_graph)
|
||||
|
||||
|
||||
def _compile_op_graph_to_fhe_circuit_internal(
|
||||
op_graph: OPGraph,
|
||||
show_mlir: bool,
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
compilation_artifacts: CompilationArtifacts,
|
||||
) -> FHECircuit:
|
||||
"""Compile the OPGraph to an FHECircuit.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): the OPGraph to compile.
|
||||
show_mlir (bool): determine whether we print the mlir string.
|
||||
compilation_configuration (CompilationConfiguration): Configuration object to use
|
||||
during compilation
|
||||
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
|
||||
during compilation
|
||||
|
||||
Returns:
|
||||
FHECircuit: the compiled FHECircuit
|
||||
"""
|
||||
prepare_op_graph_for_mlir(op_graph)
|
||||
|
||||
# Convert graph to an MLIR representation
|
||||
converter = NPMLIRConverter()
|
||||
mlir_result = converter.convert(op_graph)
|
||||
|
||||
# Show MLIR representation if requested
|
||||
if show_mlir:
|
||||
print(f"MLIR which is going to be compiled: \n{mlir_result}")
|
||||
|
||||
# Add MLIR representation as an artifact
|
||||
compilation_artifacts.add_final_operation_graph_mlir(mlir_result)
|
||||
|
||||
if _COMPILE_FHE_INSECURE_KEY_CACHE_DIR is not None and not (
|
||||
compilation_configuration.use_insecure_key_cache
|
||||
and compilation_configuration.enable_unsafe_features
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"Unable to use insecure key cache {_COMPILE_FHE_INSECURE_KEY_CACHE_DIR} "
|
||||
"as use_insecure_key_cache or enable_unsafe_features are not set to True in"
|
||||
"compilation_configuration"
|
||||
)
|
||||
|
||||
return FHECircuit(
|
||||
op_graph,
|
||||
mlir_result,
|
||||
unsecure_key_set_cache_path=_COMPILE_FHE_INSECURE_KEY_CACHE_DIR,
|
||||
auto_parallelize=compilation_configuration.auto_parallelize,
|
||||
loop_parallelize=compilation_configuration.loop_parallelize,
|
||||
dataflow_parallelize=compilation_configuration.dataflow_parallelize,
|
||||
)
|
||||
|
||||
|
||||
def compile_op_graph_to_fhe_circuit(
|
||||
op_graph: OPGraph,
|
||||
show_mlir: bool,
|
||||
compilation_configuration: Optional[CompilationConfiguration] = None,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
) -> FHECircuit:
|
||||
"""Compile the OPGraph to an FHECircuit.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): the OPGraph to compile.
|
||||
show_mlir (bool): determine whether we print the mlir string.
|
||||
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
|
||||
during compilation
|
||||
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
|
||||
during compilation
|
||||
|
||||
Returns:
|
||||
FHECircuit: the compiled circuit and the compiled FHECircuit
|
||||
"""
|
||||
|
||||
(
|
||||
compilation_configuration,
|
||||
compilation_artifacts,
|
||||
) = sanitize_compilation_configuration_and_artifacts(
|
||||
compilation_configuration, compilation_artifacts
|
||||
)
|
||||
|
||||
def compilation_function():
|
||||
return _compile_op_graph_to_fhe_circuit_internal(
|
||||
op_graph, show_mlir, compilation_configuration, compilation_artifacts
|
||||
)
|
||||
|
||||
result = run_compilation_function_with_error_management(
|
||||
compilation_function, compilation_configuration, compilation_artifacts
|
||||
)
|
||||
|
||||
# for mypy
|
||||
assert isinstance(result, FHECircuit)
|
||||
return result
|
||||
|
||||
|
||||
def _compile_numpy_function_internal(
|
||||
function_to_compile: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
compilation_artifacts: CompilationArtifacts,
|
||||
show_mlir: bool,
|
||||
) -> FHECircuit:
|
||||
"""Compile an homomorphic program (internal part of the API).
|
||||
|
||||
Args:
|
||||
function_to_compile (Callable): The function you want to compile
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset over which op_graph
|
||||
is evaluated. It needs to be an iterable on tuples which are of the same length than the
|
||||
number of parameters in the function, and in the same order than these same parameters
|
||||
compilation_configuration (CompilationConfiguration): Configuration object to use
|
||||
during compilation
|
||||
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
|
||||
during compilation
|
||||
show_mlir (bool): if set, the MLIR produced by the converter and which is going
|
||||
to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo
|
||||
|
||||
Returns:
|
||||
CompilerEngine: engine to run and debug the compiled graph
|
||||
"""
|
||||
|
||||
# Compile into an OPGraph
|
||||
op_graph = _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
|
||||
function_to_compile,
|
||||
function_parameters,
|
||||
inputset,
|
||||
compilation_configuration,
|
||||
compilation_artifacts,
|
||||
)
|
||||
|
||||
fhe_circuit = _compile_op_graph_to_fhe_circuit_internal(
|
||||
op_graph, show_mlir, compilation_configuration, compilation_artifacts
|
||||
)
|
||||
|
||||
return fhe_circuit
|
||||
|
||||
|
||||
def compile_numpy_function(
|
||||
function_to_compile: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]], str],
|
||||
compilation_configuration: Optional[CompilationConfiguration] = None,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
show_mlir: bool = False,
|
||||
) -> FHECircuit:
|
||||
"""Compile an homomorphic program (main API).
|
||||
|
||||
Args:
|
||||
function_to_compile (Callable): The function to compile
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]], str]): The inputset over which
|
||||
op_graph is evaluated. It needs to be an iterable on tuples which are of the same length
|
||||
than the number of parameters in the function, and in the same order than these same
|
||||
parameters. Alternatively, it can be "random" but that's an unstable feature and should
|
||||
not be used in production.
|
||||
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
|
||||
during compilation
|
||||
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
|
||||
during compilation
|
||||
show_mlir (bool): if set, the MLIR produced by the converter and which is going
|
||||
to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo
|
||||
|
||||
Returns:
|
||||
CompilerEngine: engine to run and debug the compiled graph
|
||||
"""
|
||||
|
||||
(
|
||||
compilation_configuration,
|
||||
compilation_artifacts,
|
||||
) = sanitize_compilation_configuration_and_artifacts(
|
||||
compilation_configuration, compilation_artifacts
|
||||
)
|
||||
|
||||
inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration)
|
||||
|
||||
def compilation_function():
|
||||
return _compile_numpy_function_internal(
|
||||
function_to_compile,
|
||||
function_parameters,
|
||||
inputset,
|
||||
compilation_configuration,
|
||||
compilation_artifacts,
|
||||
show_mlir,
|
||||
)
|
||||
|
||||
result = run_compilation_function_with_error_management(
|
||||
compilation_function, compilation_configuration, compilation_artifacts
|
||||
)
|
||||
|
||||
# for mypy
|
||||
assert isinstance(result, FHECircuit)
|
||||
return result
|
||||
@@ -1,308 +0,0 @@
|
||||
"""File to hold code to manage package and numpy dtypes."""
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Tuple, Union
|
||||
|
||||
import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
from ..common.data_types.base import BaseDataType
|
||||
from ..common.data_types.dtypes_helpers import (
|
||||
BASE_DATA_TYPES,
|
||||
find_type_to_hold_both_lossy,
|
||||
get_base_data_type_for_python_constant_data,
|
||||
get_base_value_for_python_constant_data,
|
||||
get_constructor_for_python_constant_data,
|
||||
)
|
||||
from ..common.data_types.floats import Float
|
||||
from ..common.data_types.integers import Integer
|
||||
from ..common.debugging.custom_assert import assert_true
|
||||
from ..common.tracing import BaseTracer
|
||||
from ..common.values import BaseValue, TensorValue
|
||||
|
||||
NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
|
||||
numpy.dtype(numpy.byte): Integer(numpy.byte(0).nbytes * 8, is_signed=True),
|
||||
numpy.dtype(numpy.short): Integer(numpy.short(0).nbytes * 8, is_signed=True),
|
||||
numpy.dtype(numpy.intc): Integer(numpy.intc(0).nbytes * 8, is_signed=True),
|
||||
numpy.dtype(numpy.int_): Integer(numpy.int_(0).nbytes * 8, is_signed=True),
|
||||
numpy.dtype(numpy.longlong): Integer(numpy.longlong(0).nbytes * 8, is_signed=True),
|
||||
numpy.dtype(numpy.int8): Integer(numpy.int8(0).nbytes * 8, is_signed=True),
|
||||
numpy.dtype(numpy.int16): Integer(numpy.int16(0).nbytes * 8, is_signed=True),
|
||||
numpy.dtype(numpy.int32): Integer(numpy.int32(0).nbytes * 8, is_signed=True),
|
||||
numpy.dtype(numpy.int64): Integer(numpy.int64(0).nbytes * 8, is_signed=True),
|
||||
numpy.dtype(numpy.ubyte): Integer(numpy.ubyte(0).nbytes * 8, is_signed=False),
|
||||
numpy.dtype(numpy.ushort): Integer(numpy.ushort(0).nbytes * 8, is_signed=False),
|
||||
numpy.dtype(numpy.uintc): Integer(numpy.uintc(0).nbytes * 8, is_signed=False),
|
||||
numpy.dtype(numpy.uint): Integer(numpy.uint(0).nbytes * 8, is_signed=False),
|
||||
numpy.dtype(numpy.ulonglong): Integer(numpy.ulonglong(0).nbytes * 8, is_signed=False),
|
||||
numpy.dtype(numpy.uint8): Integer(numpy.uint8(0).nbytes * 8, is_signed=False),
|
||||
numpy.dtype(numpy.uint16): Integer(numpy.uint16(0).nbytes * 8, is_signed=False),
|
||||
numpy.dtype(numpy.uint32): Integer(numpy.uint32(0).nbytes * 8, is_signed=False),
|
||||
numpy.dtype(numpy.uint64): Integer(numpy.uint64(0).nbytes * 8, is_signed=False),
|
||||
numpy.dtype(numpy.float16): Float(16),
|
||||
numpy.dtype(numpy.float32): Float(32),
|
||||
numpy.dtype(numpy.float64): Float(64),
|
||||
numpy.dtype(bool): Integer(8, is_signed=False),
|
||||
}
|
||||
|
||||
SUPPORTED_NUMPY_DTYPES = tuple(NUMPY_TO_COMMON_DTYPE_MAPPING)
|
||||
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES = tuple(dtype.type for dtype in NUMPY_TO_COMMON_DTYPE_MAPPING)
|
||||
|
||||
SUPPORTED_DTYPE_MSG_STRING = ", ".join(sorted(str(dtype) for dtype in SUPPORTED_NUMPY_DTYPES))
|
||||
|
||||
|
||||
def convert_numpy_dtype_to_base_data_type(numpy_dtype: DTypeLike) -> BaseDataType:
|
||||
"""Get the corresponding BaseDataType from a numpy dtype.
|
||||
|
||||
Args:
|
||||
numpy_dtype (DTypeLike): Any python object that can be translated to a numpy.dtype
|
||||
|
||||
Raises:
|
||||
ValueError: If the numpy_dtype is not supported
|
||||
|
||||
Returns:
|
||||
BaseDataType: The corresponding data type corresponding to the input numpy_dtype
|
||||
"""
|
||||
# Normalize numpy_dtype
|
||||
normalized_numpy_dtype = numpy.dtype(numpy_dtype)
|
||||
corresponding_common_dtype = NUMPY_TO_COMMON_DTYPE_MAPPING.get(normalized_numpy_dtype, None)
|
||||
|
||||
if corresponding_common_dtype is None:
|
||||
raise ValueError(
|
||||
f"Unsupported numpy type: {numpy_dtype} ({normalized_numpy_dtype}), "
|
||||
f"supported numpy types: "
|
||||
f"{SUPPORTED_DTYPE_MSG_STRING}"
|
||||
)
|
||||
|
||||
# deepcopy to avoid having the value from the dict modified
|
||||
return deepcopy(corresponding_common_dtype)
|
||||
|
||||
|
||||
def convert_base_data_type_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.dtype:
|
||||
"""Convert a BaseDataType to corresponding numpy.dtype.
|
||||
|
||||
Args:
|
||||
common_dtype (BaseDataType): dtype to convert to numpy.dtype
|
||||
|
||||
Returns:
|
||||
numpy.dtype: The resulting numpy.dtype
|
||||
"""
|
||||
assert_true(
|
||||
isinstance(common_dtype, BASE_DATA_TYPES), f"Unsupported common_dtype: {type(common_dtype)}"
|
||||
)
|
||||
type_to_return: numpy.dtype
|
||||
|
||||
if isinstance(common_dtype, Float):
|
||||
assert_true(
|
||||
(bit_width := common_dtype.bit_width)
|
||||
in (
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
),
|
||||
"Only converting Float(16), Float(32) or Float(64) is supported",
|
||||
)
|
||||
if bit_width == 64:
|
||||
type_to_return = numpy.dtype(numpy.float64)
|
||||
elif bit_width == 32:
|
||||
type_to_return = numpy.dtype(numpy.float32)
|
||||
else:
|
||||
type_to_return = numpy.dtype(numpy.float16)
|
||||
elif isinstance(common_dtype, Integer):
|
||||
signed = common_dtype.is_signed
|
||||
if common_dtype.bit_width <= 32:
|
||||
type_to_return = numpy.dtype(numpy.int32) if signed else numpy.dtype(numpy.uint32)
|
||||
elif common_dtype.bit_width <= 64:
|
||||
type_to_return = numpy.dtype(numpy.int64) if signed else numpy.dtype(numpy.uint64)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Conversion to numpy dtype only supports Integers with bit_width <= 64, "
|
||||
f"got {common_dtype!r}"
|
||||
)
|
||||
|
||||
return type_to_return
|
||||
|
||||
|
||||
def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) -> BaseDataType:
|
||||
"""Determine the BaseDataType to hold the input constant data.
|
||||
|
||||
Args:
|
||||
constant_data (Any): The constant data for which to determine the
|
||||
corresponding BaseDataType.
|
||||
|
||||
Returns:
|
||||
BaseDataType: The corresponding BaseDataType
|
||||
"""
|
||||
base_dtype: BaseDataType
|
||||
assert_true(
|
||||
isinstance(
|
||||
constant_data, (int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
|
||||
),
|
||||
f"Unsupported constant data of type {type(constant_data)}",
|
||||
)
|
||||
if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)):
|
||||
native_type = float if (constant_data.dtype in (numpy.float32, numpy.float64)) else int
|
||||
|
||||
min_value = native_type(constant_data.min())
|
||||
max_value = native_type(constant_data.max())
|
||||
|
||||
min_value_dtype = get_base_data_type_for_python_constant_data(min_value)
|
||||
max_value_dtype = get_base_data_type_for_python_constant_data(max_value)
|
||||
|
||||
# numpy
|
||||
base_dtype = find_type_to_hold_both_lossy(min_value_dtype, max_value_dtype)
|
||||
else:
|
||||
# python
|
||||
base_dtype = get_base_data_type_for_python_constant_data(constant_data)
|
||||
return base_dtype
|
||||
|
||||
|
||||
def get_base_value_for_numpy_or_python_constant_data(
|
||||
constant_data: Any,
|
||||
) -> Callable[..., BaseValue]:
|
||||
"""Determine the BaseValue and BaseDataType to hold the input constant data.
|
||||
|
||||
This function is able to handle numpy types
|
||||
|
||||
Args:
|
||||
constant_data (Any): The constant data for which to determine the
|
||||
corresponding BaseValue and BaseDataType.
|
||||
|
||||
Raises:
|
||||
AssertionError: If `constant_data` is of an unsupported type.
|
||||
|
||||
Returns:
|
||||
Callable[..., BaseValue]: A partial object that will return the proper BaseValue when called
|
||||
with `encrypted` as keyword argument (forwarded to the BaseValue `__init__` method).
|
||||
"""
|
||||
constant_data_value: Callable[..., BaseValue]
|
||||
assert_true(
|
||||
not isinstance(constant_data, list),
|
||||
"Unsupported constant data of type list "
|
||||
"(if you meant to use a list as an array, please use numpy.array instead)",
|
||||
)
|
||||
assert_true(
|
||||
isinstance(
|
||||
constant_data,
|
||||
(int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES),
|
||||
),
|
||||
f"Unsupported constant data of type {type(constant_data)}",
|
||||
)
|
||||
|
||||
base_dtype = get_base_data_type_for_numpy_or_python_constant_data(constant_data)
|
||||
if isinstance(constant_data, numpy.ndarray):
|
||||
constant_data_value = partial(TensorValue, dtype=base_dtype, shape=constant_data.shape)
|
||||
elif isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
|
||||
constant_data_value = partial(TensorValue, dtype=base_dtype, shape=())
|
||||
else:
|
||||
constant_data_value = get_base_value_for_python_constant_data(constant_data)
|
||||
return constant_data_value
|
||||
|
||||
|
||||
def get_numpy_function_output_dtype_and_shape_from_input_dtypes(
|
||||
function: Union[numpy.ufunc, Callable],
|
||||
input_dtypes: List[BaseDataType],
|
||||
input_shapes: List[Tuple[int, ...]],
|
||||
) -> List[Tuple[numpy.dtype, Tuple[int, ...]]]:
|
||||
"""Record the output dtype of a numpy function given some input types.
|
||||
|
||||
Args:
|
||||
function (Union[numpy.ufunc, Callable]): The numpy function whose output types need to
|
||||
be recorded
|
||||
input_dtypes (List[BaseDataType]): BaseDataTypes in the same order as they will be used with
|
||||
the function inputs
|
||||
input_shapes (List[Tuple[int, ...]]): Shapes in the same order as they will be used with
|
||||
the function inputs
|
||||
|
||||
Returns:
|
||||
List[Tuple[numpy.dtype, Tuple[int, ...]]]: appropriate (numpy.dtype, shape) tuple for each
|
||||
output of the function
|
||||
"""
|
||||
if isinstance(function, numpy.ufunc):
|
||||
assert_true(
|
||||
(len(input_dtypes) == function.nin),
|
||||
f"Expected {function.nin} types, got {len(input_dtypes)}: {input_dtypes}",
|
||||
)
|
||||
|
||||
input_numpy_dtypes = [convert_base_data_type_to_numpy_dtype(dtype) for dtype in input_dtypes]
|
||||
|
||||
dummy_inputs = tuple(
|
||||
(
|
||||
dtype.type(10.0 * numpy.random.random_sample())
|
||||
if shape == ()
|
||||
else numpy.abs(numpy.random.randn(*shape) * 10.0).astype(dtype)
|
||||
)
|
||||
for dtype, shape in zip(input_numpy_dtypes, input_shapes)
|
||||
)
|
||||
|
||||
# We ignore errors as we may call functions with invalid inputs just to get the proper output
|
||||
# dtypes
|
||||
with numpy.errstate(all="ignore"):
|
||||
outputs = function(*dummy_inputs)
|
||||
|
||||
if not isinstance(outputs, tuple):
|
||||
outputs = (outputs,)
|
||||
|
||||
return [(output.dtype, output.shape) for output in outputs]
|
||||
|
||||
|
||||
def get_numpy_function_output_dtype_and_shape_from_input_tracers(
|
||||
func: Union[numpy.ufunc, Callable],
|
||||
*input_tracers: BaseTracer,
|
||||
) -> List[Tuple[BaseDataType, Tuple[int, ...]]]:
|
||||
"""Determine output dtypes and shapes for a numpy function.
|
||||
|
||||
This function is responsible for determining the output dtype
|
||||
of a numpy function after inputs with specific dtypes are passed to it.
|
||||
|
||||
Args:
|
||||
func (Union[numpy.ufunc, Callable]): function that is being managed
|
||||
*input_tracers (BaseTracer): inputs to the function
|
||||
|
||||
Returns:
|
||||
List[Tuple[BaseDataType, Tuple[int, ...]]]: appropriate (BaseDataType, shape) tuple for each
|
||||
output of the function
|
||||
"""
|
||||
|
||||
input_shapes = [
|
||||
input_tracer.output.shape if isinstance(input_tracer.output, TensorValue) else ()
|
||||
for input_tracer in input_tracers
|
||||
]
|
||||
output_dtypes_and_shapes = get_numpy_function_output_dtype_and_shape_from_input_dtypes(
|
||||
func,
|
||||
[input_tracer.output.dtype for input_tracer in input_tracers],
|
||||
input_shapes,
|
||||
)
|
||||
common_output_dtypes = [
|
||||
(convert_numpy_dtype_to_base_data_type(dtype), shape)
|
||||
for dtype, shape in output_dtypes_and_shapes
|
||||
]
|
||||
return common_output_dtypes
|
||||
|
||||
|
||||
def get_constructor_for_numpy_or_python_constant_data(constant_data: Any):
|
||||
"""Get the constructor for the numpy constant data or python dtype.
|
||||
|
||||
Args:
|
||||
constant_data (Any): The data for which we want to determine the type constructor.
|
||||
"""
|
||||
|
||||
assert_true(
|
||||
isinstance(
|
||||
constant_data, (int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
|
||||
),
|
||||
f"Unsupported constant data of type {type(constant_data)}",
|
||||
)
|
||||
|
||||
if isinstance(constant_data, list):
|
||||
# this is required because some operations return python lists from their evaluate function
|
||||
# an example of such operation is evaluation of multi tlu during bound measurements
|
||||
constant_data = numpy.array(constant_data)
|
||||
|
||||
if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)):
|
||||
if isinstance(constant_data, numpy.ndarray):
|
||||
return lambda x: numpy.full(constant_data.shape, x, dtype=constant_data.dtype)
|
||||
return constant_data.dtype.type
|
||||
|
||||
return get_constructor_for_python_constant_data(constant_data)
|
||||
@@ -1,309 +0,0 @@
|
||||
"""Module to hold a user friendly class to compile programs."""
|
||||
|
||||
import itertools
|
||||
from copy import deepcopy
|
||||
from enum import Enum, unique
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.data_types import Integer
|
||||
from ..common.debugging import draw_graph, format_operation_graph
|
||||
from ..common.fhe_circuit import FHECircuit
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation.intermediate import IntermediateNode
|
||||
from ..common.values import BaseValue
|
||||
from .compile import (
|
||||
compile_numpy_function_into_op_graph,
|
||||
compile_op_graph_to_fhe_circuit,
|
||||
measure_op_graph_bounds_and_update,
|
||||
)
|
||||
from .np_dtypes_helpers import get_base_value_for_numpy_or_python_constant_data
|
||||
|
||||
|
||||
@unique
|
||||
class EncryptedStatus(str, Enum):
|
||||
"""Enum to validate GenericFunction op_kind."""
|
||||
|
||||
CLEAR = "clear"
|
||||
ENCRYPTED = "encrypted"
|
||||
|
||||
|
||||
class NPFHECompiler:
|
||||
"""Class to ease the conversion of a numpy program to its FHE equivalent."""
|
||||
|
||||
INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE: int = 128
|
||||
|
||||
# _function_to_compile is not optional but mypy has a long standing bug and is not able to
|
||||
# understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
|
||||
_function_to_compile: Optional[Callable]
|
||||
_function_parameters_encrypted_status: Dict[str, bool]
|
||||
_current_inputset: List[Union[Any, Tuple]]
|
||||
_op_graph: Optional[OPGraph]
|
||||
_nodes_and_bounds: Dict[IntermediateNode, Dict[str, Any]]
|
||||
|
||||
_compilation_configuration: CompilationConfiguration
|
||||
|
||||
compilation_artifacts: CompilationArtifacts
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
function_to_compile: Callable,
|
||||
function_parameters_encrypted_status: Dict[str, Union[str, EncryptedStatus]],
|
||||
compilation_configuration: Optional[CompilationConfiguration] = None,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
) -> None:
|
||||
self._function_to_compile = function_to_compile
|
||||
self._function_parameters_encrypted_status = {
|
||||
param_name: EncryptedStatus(status.lower()) == EncryptedStatus.ENCRYPTED
|
||||
for param_name, status in function_parameters_encrypted_status.items()
|
||||
}
|
||||
|
||||
self._current_inputset = []
|
||||
self._op_graph = None
|
||||
self._nodes_and_bounds = {}
|
||||
|
||||
self._compilation_configuration = (
|
||||
deepcopy(compilation_configuration)
|
||||
if compilation_configuration is not None
|
||||
else CompilationConfiguration()
|
||||
)
|
||||
self.compilation_artifacts = (
|
||||
compilation_artifacts if compilation_artifacts is not None else CompilationArtifacts()
|
||||
)
|
||||
|
||||
@property
|
||||
def function_to_compile(self) -> Callable:
|
||||
"""Get the function to compile.
|
||||
|
||||
Returns:
|
||||
Callable: the function to compile.
|
||||
"""
|
||||
# Continuation of mypy bug
|
||||
assert self._function_to_compile is not None
|
||||
return self._function_to_compile
|
||||
|
||||
@property
|
||||
def op_graph(self) -> Optional[OPGraph]:
|
||||
"""Return a copy of the OPGraph.
|
||||
|
||||
Returns:
|
||||
Optional[OPGraph]: the held OPGraph or None
|
||||
"""
|
||||
# To keep consistency with what the user expects, we make sure to evaluate on the remaining
|
||||
# inputset values if any before giving a copy of the OPGraph we trace
|
||||
self._eval_on_current_inputset()
|
||||
return deepcopy(self._op_graph)
|
||||
|
||||
@property
|
||||
def compilation_configuration(self) -> Optional[CompilationConfiguration]:
|
||||
"""Get a copy of the compilation configuration.
|
||||
|
||||
Returns:
|
||||
Optional[CompilationConfiguration]: copy of the current compilation configuration.
|
||||
"""
|
||||
return deepcopy(self._compilation_configuration)
|
||||
|
||||
def __call__(self, *args: Any) -> Any:
|
||||
"""Evaluate the OPGraph corresponding to the function being compiled and return result.
|
||||
|
||||
Returns:
|
||||
Any: the result of the OPGraph evaluation.
|
||||
"""
|
||||
self._current_inputset.append(deepcopy(args) if len(args) > 1 else deepcopy(args[0]))
|
||||
|
||||
inferred_args = {
|
||||
param_name: get_base_value_for_numpy_or_python_constant_data(val)(
|
||||
is_encrypted=is_encrypted
|
||||
)
|
||||
for (param_name, is_encrypted), val in zip(
|
||||
self._function_parameters_encrypted_status.items(), args
|
||||
)
|
||||
}
|
||||
|
||||
if len(self._current_inputset) >= self.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE:
|
||||
self._eval_on_current_inputset()
|
||||
|
||||
self._trace_op_graph_if_needed(inferred_args)
|
||||
|
||||
# For mypy
|
||||
assert self._op_graph is not None
|
||||
return self._op_graph(*args)
|
||||
|
||||
def __str__(self) -> str:
|
||||
self._eval_on_current_inputset()
|
||||
if self._op_graph is None:
|
||||
warning_msg = (
|
||||
f"__str__ failed: OPGraph is None, {self.__class__.__name__} "
|
||||
"needs evaluation on an inputset"
|
||||
)
|
||||
logger.warning(warning_msg)
|
||||
return warning_msg
|
||||
return format_operation_graph(self._op_graph)
|
||||
|
||||
def draw_graph(
|
||||
self,
|
||||
show: bool = False,
|
||||
vertical: bool = True,
|
||||
save_to: Optional[Path] = None,
|
||||
) -> Optional[str]:
|
||||
"""Draws operation graphs and optionally saves/shows the drawing.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): the operation graph to be drawn and optionally saved/shown
|
||||
show (bool): if set to True, the drawing will be shown using matplotlib
|
||||
vertical (bool): if set to True, the orientation will be vertical
|
||||
save_to (Optional[Path]): if specified, the drawn graph will be saved to this path; else
|
||||
it is saved in a temporary file
|
||||
|
||||
Returns:
|
||||
Optional[str]: if OPGraph was not None returns the path as a string of the file where
|
||||
the drawn graph is saved
|
||||
"""
|
||||
self._eval_on_current_inputset()
|
||||
if self._op_graph is None:
|
||||
logger.warning(
|
||||
f"{self.draw_graph.__name__} failed: OPGraph is None, {self.__class__.__name__} "
|
||||
"needs evaluation on an inputset"
|
||||
)
|
||||
return None
|
||||
return draw_graph(self._op_graph, show, vertical, save_to)
|
||||
|
||||
def eval_on_inputset(
|
||||
self,
|
||||
inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
|
||||
warn_on_inputset_length: bool = False,
|
||||
) -> None:
|
||||
"""Evaluate the underlying function on an inputset in one go, populates OPGraph and bounds.
|
||||
|
||||
Args:
|
||||
inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset on which the
|
||||
function should be evaluated.
|
||||
warn_on_inputset_length (bool, optional): Set to True to get a warning
|
||||
if inputset is not long enough. Defaults to False.
|
||||
"""
|
||||
|
||||
inputset_iter = iter(inputset)
|
||||
try:
|
||||
first_sample = next(inputset_iter)
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
inferred_args = {
|
||||
param_name: get_base_value_for_numpy_or_python_constant_data(val)(
|
||||
is_encrypted=is_encrypted
|
||||
)
|
||||
for (param_name, is_encrypted), val in zip(
|
||||
self._function_parameters_encrypted_status.items(),
|
||||
first_sample
|
||||
if len(self._function_parameters_encrypted_status) > 1
|
||||
else (first_sample,),
|
||||
)
|
||||
}
|
||||
|
||||
self._trace_op_graph_if_needed(inferred_args)
|
||||
|
||||
# For mypy
|
||||
assert self._op_graph is not None
|
||||
|
||||
self._patch_op_graph_input_to_accept_any_integer_input()
|
||||
|
||||
self._nodes_and_bounds = measure_op_graph_bounds_and_update(
|
||||
self._op_graph,
|
||||
inferred_args,
|
||||
itertools.chain((first_sample,), inputset_iter),
|
||||
self._compilation_configuration,
|
||||
self.compilation_artifacts,
|
||||
self._nodes_and_bounds,
|
||||
warn_on_inputset_length,
|
||||
)
|
||||
|
||||
def _eval_on_current_inputset(self) -> None:
|
||||
"""Evaluate OPGraph on _current_inputset."""
|
||||
self.eval_on_inputset(self._current_inputset)
|
||||
self._current_inputset.clear()
|
||||
|
||||
def _needs_tracing(self) -> bool:
|
||||
"""Return whether we need to trace the function and populate the OPGraph."""
|
||||
return self._op_graph is None
|
||||
|
||||
def _trace_op_graph_if_needed(self, inferred_args: Dict[str, BaseValue]) -> None:
|
||||
"""Populate _op_graph with the OPGraph for _function_to_compile."""
|
||||
if not self._needs_tracing():
|
||||
return
|
||||
|
||||
self._op_graph = compile_numpy_function_into_op_graph(
|
||||
self.function_to_compile,
|
||||
inferred_args,
|
||||
self._compilation_configuration,
|
||||
self.compilation_artifacts,
|
||||
)
|
||||
|
||||
def _patch_op_graph_input_to_accept_any_integer_input(self) -> None:
|
||||
"""Patch inputs as we don't know what data we expect."""
|
||||
|
||||
# Can only do that if the OPGraph was created hence the test.
|
||||
if self._needs_tracing():
|
||||
return
|
||||
|
||||
# For mypy
|
||||
assert self._op_graph is not None
|
||||
|
||||
# Cheat on Input nodes to avoid issues during inputset eval as we do not know in advance
|
||||
# what the final bit width for the inputs should be
|
||||
for node in self._op_graph.input_nodes.values():
|
||||
for input_ in node.inputs:
|
||||
if isinstance(dtype := (input_.dtype), Integer):
|
||||
dtype.bit_width = 128
|
||||
dtype.is_signed = True
|
||||
|
||||
def compile_on_inputset(
|
||||
self,
|
||||
inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
|
||||
show_mlir: bool = False,
|
||||
) -> FHECircuit:
|
||||
"""Compile the function on an inputset and get resulting FHECircuit.
|
||||
|
||||
Args:
|
||||
inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]):
|
||||
The inputset on which the function is evaluated.
|
||||
show_mlir (bool, optional, defaults to False):
|
||||
The flag to enable printing the MLIR that is being compiled for debugging purposes.
|
||||
|
||||
Returns:
|
||||
FHECircuit: the compiled FHECircuit
|
||||
"""
|
||||
|
||||
self.eval_on_inputset(inputset)
|
||||
return self.get_compiled_fhe_circuit(show_mlir)
|
||||
|
||||
def get_compiled_fhe_circuit(self, show_mlir: bool = False) -> FHECircuit:
|
||||
"""Return a compiled FHECircuit if the instance was evaluated on an inputset.
|
||||
|
||||
Args:
|
||||
show_mlir (bool, optional): if set, the MLIR produced by the converter and which is
|
||||
going to be sent to the compiler backend is shown on the screen, e.g., for debugging
|
||||
or demo. Defaults to False.
|
||||
|
||||
Raises:
|
||||
RuntimeError: raised if no inputset was passed to the instance.
|
||||
|
||||
Returns:
|
||||
FHECircuit: the compiled FHECircuit
|
||||
"""
|
||||
self._eval_on_current_inputset()
|
||||
|
||||
if self._op_graph is None:
|
||||
raise RuntimeError(
|
||||
"Requested FHECircuit but no OPGraph was compiled. "
|
||||
f"Did you forget to evaluate {self.__class__.__name__} over an inputset?"
|
||||
)
|
||||
|
||||
return compile_op_graph_to_fhe_circuit(
|
||||
self._op_graph,
|
||||
show_mlir,
|
||||
self.compilation_configuration,
|
||||
self.compilation_artifacts,
|
||||
)
|
||||
@@ -1,59 +0,0 @@
|
||||
"""Helpers for indexing with numpy values functionality."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy
|
||||
|
||||
|
||||
def should_sanitize(indexing_element: Any) -> bool:
|
||||
"""Decide whether to sanitize an indexing element or not.
|
||||
|
||||
Sanitizing in this context means converting supported numpy values into python values.
|
||||
|
||||
Args:
|
||||
indexing_element (Any): the indexing element to decide sanitization.
|
||||
|
||||
Returns:
|
||||
bool: True if indexing element should be sanitized otherwise False.
|
||||
"""
|
||||
|
||||
return isinstance(indexing_element, numpy.integer) or (
|
||||
isinstance(indexing_element, numpy.ndarray)
|
||||
and issubclass(indexing_element.dtype.type, numpy.integer)
|
||||
and indexing_element.shape == ()
|
||||
)
|
||||
|
||||
|
||||
def process_indexing_element(indexing_element: Any) -> Any:
|
||||
"""Process an indexing element.
|
||||
|
||||
Processing in this context means converting supported numpy values into python values.
|
||||
(if they are decided to be sanitized)
|
||||
|
||||
Args:
|
||||
indexing_element (Any): the indexing element to sanitize.
|
||||
|
||||
Returns:
|
||||
Any: the sanitized indexing element.
|
||||
"""
|
||||
|
||||
if isinstance(indexing_element, slice):
|
||||
|
||||
start = indexing_element.start
|
||||
if should_sanitize(start):
|
||||
start = int(start)
|
||||
|
||||
stop = indexing_element.stop
|
||||
if should_sanitize(stop):
|
||||
stop = int(stop)
|
||||
|
||||
step = indexing_element.step
|
||||
if should_sanitize(step):
|
||||
step = int(step)
|
||||
|
||||
indexing_element = slice(start, stop, step)
|
||||
|
||||
elif should_sanitize(indexing_element):
|
||||
indexing_element = int(indexing_element)
|
||||
|
||||
return indexing_element
|
||||
@@ -1,157 +0,0 @@
|
||||
"""Helpers for numpy inputset related functionality."""
|
||||
|
||||
import random
|
||||
from typing import Any, Dict, Iterable, Tuple, Union
|
||||
|
||||
import numpy
|
||||
|
||||
from ..common.compilation import CompilationConfiguration
|
||||
from ..common.data_types import Float, Integer
|
||||
from ..common.values import BaseValue, TensorValue
|
||||
|
||||
|
||||
def _generate_random_integer_scalar(dtype: Integer) -> int:
|
||||
"""Generate a random integer scalar.
|
||||
|
||||
Args:
|
||||
dtype (Integer): the data type to extract bounds
|
||||
|
||||
Returns:
|
||||
int: a random value within the range [dtype.min_value(), dtype.max_value()]
|
||||
"""
|
||||
|
||||
return random.randint(dtype.min_value(), dtype.max_value())
|
||||
|
||||
|
||||
def _generate_random_integer_tensor(dtype: Integer, shape: Tuple[int, ...]) -> numpy.ndarray:
|
||||
"""Generate a random integer tensor.
|
||||
|
||||
Args:
|
||||
dtype (Integer): the data type to extract bounds
|
||||
shape (Tuple[int, ...]): the shape of the generated tensor
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: a random array of the specified shape where each value of it
|
||||
is within the range [dtype.min_value(), dtype.max_value()]
|
||||
"""
|
||||
|
||||
return numpy.random.randint(
|
||||
dtype.min_value(),
|
||||
dtype.max_value() + 1,
|
||||
size=shape,
|
||||
dtype=numpy.int64 if dtype.is_signed else numpy.uint64, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def _generate_random_float_scalar() -> float:
|
||||
"""Generate a random float scalar.
|
||||
|
||||
Returns:
|
||||
float: a random value within the range [0, 1)
|
||||
"""
|
||||
|
||||
return random.random()
|
||||
|
||||
|
||||
def _generate_random_float_tensor(dtype: Float, shape: Tuple[int, ...]) -> numpy.ndarray:
|
||||
"""Generate a random float tensor.
|
||||
|
||||
Args:
|
||||
dtype (Integer): the data type to extract resulting numpy data type
|
||||
shape (Tuple[int, ...]): the shape of the generated tensor
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: a random array of the specified shape where each value of it
|
||||
is within the range [0, 1)
|
||||
"""
|
||||
|
||||
result = numpy.random.rand(*shape)
|
||||
return result.astype(numpy.float32 if dtype.bit_width == 32 else numpy.float64)
|
||||
|
||||
|
||||
def _generate_random_inputset(
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
) -> Union[Iterable[Any], Iterable[Tuple[Any, ...]]]:
|
||||
"""Generate a random inputset from function parameters.
|
||||
|
||||
Using this function is not a good practice since the randomly generated inputset
|
||||
might not reflect real world data. We have it to speed up our development workflow
|
||||
and we also don't use it in any of our tests, benchmarks, or examples.
|
||||
|
||||
Args:
|
||||
function_parameters (Dict[str, BaseValue]): the function parameters
|
||||
to extract data types and shapes
|
||||
compilation_configuration (CompilationConfiguration): the compilation configuration
|
||||
to extract the sample size of the resulting inputset
|
||||
|
||||
Raises:
|
||||
ValueError: if the provided function arguments cannot be used for random inputset generation
|
||||
|
||||
Returns:
|
||||
Union[Iterable[Any], Iterable[Tuple[Any, ...]]]: the inputset
|
||||
"""
|
||||
|
||||
inputset = []
|
||||
for _ in range(compilation_configuration.random_inputset_samples):
|
||||
sample = []
|
||||
for parameter in function_parameters.values():
|
||||
if not isinstance(parameter, TensorValue):
|
||||
raise ValueError(f"Random inputset cannot be generated for {parameter} parameters")
|
||||
|
||||
if isinstance(parameter.dtype, Integer):
|
||||
sample.append(
|
||||
_generate_random_integer_scalar(parameter.dtype)
|
||||
if parameter.is_scalar
|
||||
else _generate_random_integer_tensor(parameter.dtype, parameter.shape)
|
||||
)
|
||||
elif isinstance(parameter.dtype, Float):
|
||||
sample.append(
|
||||
_generate_random_float_scalar()
|
||||
if parameter.is_scalar
|
||||
else _generate_random_float_tensor(parameter.dtype, parameter.shape)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Random inputset cannot be generated "
|
||||
f"for parameters of type {parameter.dtype}"
|
||||
)
|
||||
inputset.append(tuple(sample) if len(sample) > 1 else sample[0])
|
||||
return inputset
|
||||
|
||||
|
||||
def _check_special_inputset_availability(
|
||||
inputset: str,
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
):
|
||||
"""Check special inputset is valid and is available.
|
||||
|
||||
This function makes sure the provided special inputset is valid and can be used with the
|
||||
provided compilation configuration.
|
||||
|
||||
Currently, the only special inputset is "random" but this can be extended in the future.
|
||||
|
||||
Args:
|
||||
inputset (str): the special inputset to check
|
||||
compilation_configuration (CompilationConfiguration): the compilation configuration
|
||||
to check the availability of the provided special inputset
|
||||
|
||||
Raises:
|
||||
ValueError: if the provided special inputset is not valid
|
||||
RuntimeError: if the provided special inputset is not available
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
if inputset != "random":
|
||||
raise ValueError(
|
||||
f"inputset can only be an iterable of tuples or the string 'random' "
|
||||
f"but you specified '{inputset}' for it"
|
||||
)
|
||||
|
||||
if not compilation_configuration.enable_unsafe_features:
|
||||
raise RuntimeError(
|
||||
"Random inputset generation is an unsafe feature and should not be used "
|
||||
"if you don't know what you are doing"
|
||||
)
|
||||
@@ -1,85 +0,0 @@
|
||||
"""Numpy-specific MLIR converter."""
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from itertools import product
|
||||
from typing import Any, DefaultDict, Dict, List, Tuple
|
||||
|
||||
import numpy
|
||||
|
||||
from ..common.debugging import assert_true
|
||||
from ..common.mlir.graph_converter import OPGraphConverter
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation.intermediate import GenericFunction, IntermediateNode
|
||||
|
||||
|
||||
class HashableNPArray:
|
||||
"""Class to easily manipulate numpy arrays for hashing.
|
||||
|
||||
Note that the hash behavior won't work if the array is modified after being hashed, as it will
|
||||
have been hashed to a certain value and the new array content will be hashed to a different one.
|
||||
"""
|
||||
|
||||
array: numpy.ndarray
|
||||
|
||||
def __init__(self, array: numpy.ndarray) -> None:
|
||||
self.array = array
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.array.tobytes())
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, HashableNPArray) and numpy.array_equal(self.array, other.array)
|
||||
|
||||
|
||||
def generate_deduplicated_tables(
|
||||
node: GenericFunction, ordered_preds: List[IntermediateNode]
|
||||
) -> Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]:
|
||||
"""Deduplicate the tables for the different cells of a tensor if needed.
|
||||
|
||||
Args:
|
||||
node (GenericFunction): the node for which to deduplicate the table.
|
||||
ordered_preds (List[IntermediateNode]): ordered list of predecessors of the node.
|
||||
|
||||
Returns:
|
||||
Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]: A tuple containing tuples whose
|
||||
first element is a table and the second element is a list of tuples indicating which
|
||||
cells in the tensor will use that table.
|
||||
"""
|
||||
# This is the tensor containing the tables for each cell of the tensor for node
|
||||
node_complete_table = numpy.concatenate(
|
||||
tuple(numpy.expand_dims(array, -1) for array in node.get_table(ordered_preds)), axis=-1
|
||||
)
|
||||
|
||||
all_cells_idx = product(*tuple(range(max_val) for max_val in node_complete_table.shape[:-1]))
|
||||
tables_to_cell_idx: DefaultDict[HashableNPArray, List[Tuple[int, ...]]] = defaultdict(list)
|
||||
idx: Tuple[int, ...]
|
||||
all_idx_set = set()
|
||||
for idx in all_cells_idx:
|
||||
hashable_array = HashableNPArray(node_complete_table[idx])
|
||||
tables_to_cell_idx[hashable_array].append(idx)
|
||||
all_idx_set.add(idx)
|
||||
|
||||
assert_true(len(all_idx_set) == math.prod(node_complete_table.shape[:-1]))
|
||||
|
||||
return tuple(
|
||||
(hashable_array.array, indices) for hashable_array, indices in tables_to_cell_idx.items()
|
||||
)
|
||||
|
||||
|
||||
class NPMLIRConverter(OPGraphConverter):
|
||||
"""Numpy-specific MLIR converter."""
|
||||
|
||||
@staticmethod
|
||||
def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]:
|
||||
additional_conversion_info = {}
|
||||
|
||||
# Disable numpy warnings during conversion to avoid issues during TLU generation
|
||||
with numpy.errstate(all="ignore"):
|
||||
additional_conversion_info["tables"] = {
|
||||
node: generate_deduplicated_tables(node, op_graph.get_ordered_preds(node))
|
||||
for node in op_graph.graph.nodes()
|
||||
if isinstance(node, GenericFunction) and node.op_kind == "TLU"
|
||||
}
|
||||
|
||||
return additional_conversion_info
|
||||
@@ -1,818 +0,0 @@
|
||||
"""numpy tracing utilities."""
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
|
||||
from ..common.debugging.custom_assert import assert_true
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation.intermediate import Constant, Dot, GenericFunction, MatMul
|
||||
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
|
||||
from ..common.tracing.tracing_helpers import tracing_context
|
||||
from ..common.values import BaseValue, TensorValue
|
||||
from .np_dtypes_helpers import (
|
||||
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
|
||||
convert_numpy_dtype_to_base_data_type,
|
||||
get_base_value_for_numpy_or_python_constant_data,
|
||||
get_numpy_function_output_dtype_and_shape_from_input_tracers,
|
||||
)
|
||||
from .np_indexing_helpers import process_indexing_element
|
||||
|
||||
SUPPORTED_TYPES_FOR_TRACING = (int, float, numpy.ndarray) + tuple(
|
||||
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES
|
||||
)
|
||||
|
||||
NPConstant = partial(
|
||||
Constant,
|
||||
get_base_value_for_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
|
||||
class NPTracer(BaseTracer):
|
||||
"""Tracer class for numpy operations."""
|
||||
|
||||
_mix_values_func: Callable[..., BaseValue] = mix_values_determine_holding_dtype
|
||||
|
||||
def __array_ufunc__(self, ufunc: numpy.ufunc, method, *args, **kwargs):
|
||||
"""Catch calls to numpy ufunc and routes them to tracing functions if supported.
|
||||
|
||||
Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch
|
||||
"""
|
||||
if method == "__call__":
|
||||
tracing_func = self.get_tracing_func_for_np_function(ufunc)
|
||||
assert_true(
|
||||
(len(kwargs) == 0),
|
||||
f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc.__name__}",
|
||||
)
|
||||
|
||||
# Create constant tracers for args, numpy only passes ufunc.nin args so we can
|
||||
# sanitize all of them without issues
|
||||
sanitized_args = [self._sanitize(arg) for arg in args]
|
||||
return tracing_func(*sanitized_args, **kwargs)
|
||||
raise NotImplementedError("Only __call__ method is supported currently")
|
||||
|
||||
def __array_function__(self, func, _types, args, kwargs):
|
||||
"""Catch calls to numpy function in routes them to tracing functions if supported.
|
||||
|
||||
Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch
|
||||
"""
|
||||
tracing_func = self.get_tracing_func_for_np_function(func)
|
||||
assert_true(
|
||||
(tracing_func in [NPTracer.numpy_sum, NPTracer.numpy_concatenate]) or len(kwargs) == 0,
|
||||
f"**kwargs are currently not supported for numpy functions, func: {func}",
|
||||
)
|
||||
|
||||
# Fixme: Special case to be removed once #772 is done
|
||||
if func is not numpy.reshape:
|
||||
sanitized_args = [self._sanitize(arg) for arg in args]
|
||||
else:
|
||||
# In numpy.reshape, the second argument is the new shape
|
||||
sanitized_args = [self._sanitize(args[0]), args[1]]
|
||||
return tracing_func(self, sanitized_args[0], sanitized_args[1], **kwargs)
|
||||
|
||||
return tracing_func(self, *sanitized_args, **kwargs)
|
||||
|
||||
def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer":
|
||||
r"""Support numpy astype feature.
|
||||
|
||||
For now it only accepts a dtype and no additional parameters, \*args and
|
||||
\*\*kwargs are accepted for interface compatibility only
|
||||
|
||||
Args:
|
||||
numpy_dtype (DTypeLike): The object describing a numpy type
|
||||
|
||||
Returns:
|
||||
NPTracer: The NPTracer representing the casting operation
|
||||
"""
|
||||
assert_true(
|
||||
len(args) == 0, f"astype currently only supports tracing without *args, got {args}"
|
||||
)
|
||||
assert_true(
|
||||
(len(kwargs) == 0),
|
||||
f"astype currently only supports tracing without **kwargs, got {kwargs}",
|
||||
)
|
||||
|
||||
normalized_numpy_dtype = numpy.dtype(numpy_dtype)
|
||||
output_dtype = convert_numpy_dtype_to_base_data_type(numpy_dtype)
|
||||
generic_function_output_value = deepcopy(self.output)
|
||||
generic_function_output_value.dtype = output_dtype
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[self.output],
|
||||
arbitrary_func=lambda x, dtype: x.astype(dtype),
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
op_kwargs={"dtype": normalized_numpy_dtype.type},
|
||||
op_name="astype",
|
||||
)
|
||||
output_tracer = self.__class__([self], traced_computation=traced_computation, output_idx=0)
|
||||
return output_tracer
|
||||
|
||||
@staticmethod
|
||||
def get_tracing_func_for_np_function(func: Union[numpy.ufunc, Callable]) -> Callable:
|
||||
"""Get the tracing function for a numpy function.
|
||||
|
||||
Args:
|
||||
func (Union[numpy.ufunc, Callable]): The numpy function that will be traced
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Raised if the passed function is not supported by NPTracer
|
||||
|
||||
Returns:
|
||||
Callable: the tracing function that needs to be called to trace func
|
||||
"""
|
||||
tracing_func: Optional[Callable]
|
||||
|
||||
# numpy.invert is not great in term of types it supports, so we've decided not to support it
|
||||
# and to propose to the user to use numpy.bitwise_not
|
||||
if func == numpy.invert:
|
||||
raise RuntimeError(
|
||||
f"NPTracer does not manage the following func: {func.__name__}. Please replace by "
|
||||
f"calls to bitwise_xor with appropriate mask"
|
||||
)
|
||||
|
||||
if isinstance(func, numpy.ufunc):
|
||||
tracing_func = NPTracer.UFUNC_ROUTING.get(func, None)
|
||||
else:
|
||||
tracing_func = NPTracer.FUNC_ROUTING.get(func, None)
|
||||
|
||||
if tracing_func is None:
|
||||
raise NotImplementedError(
|
||||
f"NPTracer does not yet manage the following func: {func.__name__}"
|
||||
)
|
||||
return tracing_func
|
||||
|
||||
def _supports_other_operand(self, other: Any) -> bool:
|
||||
return super()._supports_other_operand(other) or isinstance(
|
||||
other, SUPPORTED_TYPES_FOR_TRACING
|
||||
)
|
||||
|
||||
def _make_const_input_tracer(self, constant_data: Any) -> "NPTracer":
|
||||
return self.__class__([], NPConstant(constant_data), 0)
|
||||
|
||||
@classmethod
|
||||
def _np_operator(
|
||||
cls,
|
||||
numpy_operator,
|
||||
numpy_operator_string,
|
||||
numpy_operator_nin,
|
||||
*input_tracers: "NPTracer",
|
||||
**kwargs,
|
||||
) -> "NPTracer":
|
||||
"""Trace a numpy operator.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
assert_true(len(input_tracers) == numpy_operator_nin)
|
||||
|
||||
common_output_dtypes_and_shapes = (
|
||||
get_numpy_function_output_dtype_and_shape_from_input_tracers(
|
||||
numpy_operator,
|
||||
*input_tracers,
|
||||
)
|
||||
)
|
||||
assert_true(len(common_output_dtypes_and_shapes) == 1)
|
||||
|
||||
variable_input_indices = [
|
||||
idx
|
||||
for idx, pred in enumerate(input_tracers)
|
||||
if not isinstance(pred.traced_computation, Constant)
|
||||
]
|
||||
assert_true(
|
||||
(non_constant_pred_count := len(variable_input_indices)) == 1,
|
||||
f"Can only have 1 non constant predecessor in {cls._np_operator.__name__}, "
|
||||
f"got {non_constant_pred_count} for operator {numpy_operator}",
|
||||
)
|
||||
|
||||
variable_input_idx = variable_input_indices[0]
|
||||
output_dtype, output_shape = common_output_dtypes_and_shapes[0]
|
||||
|
||||
generic_function_output_value = TensorValue(
|
||||
output_dtype,
|
||||
input_tracers[variable_input_idx].output.is_encrypted,
|
||||
output_shape,
|
||||
)
|
||||
|
||||
op_kwargs = deepcopy(kwargs)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[input_tracer.output for input_tracer in input_tracers],
|
||||
arbitrary_func=numpy_operator,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
op_kwargs=op_kwargs,
|
||||
op_name=numpy_operator_string,
|
||||
)
|
||||
output_tracer = cls(
|
||||
input_tracers,
|
||||
traced_computation=traced_computation,
|
||||
output_idx=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def numpy_dot(self, *args: "NPTracer", **_kwargs) -> "NPTracer":
|
||||
"""Trace numpy.dot.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
assert_true((num_args := len(args)) == 2, f"dot expects 2 inputs got {num_args}")
|
||||
|
||||
common_output_dtypes_and_shapes = (
|
||||
get_numpy_function_output_dtype_and_shape_from_input_tracers(numpy.dot, *args)
|
||||
)
|
||||
assert_true(len(common_output_dtypes_and_shapes) == 1)
|
||||
|
||||
traced_computation = Dot(
|
||||
[input_tracer.output for input_tracer in args],
|
||||
common_output_dtypes_and_shapes[0][0],
|
||||
delegate_evaluation_function=numpy.dot,
|
||||
)
|
||||
|
||||
output_tracer = self.__class__(
|
||||
args,
|
||||
traced_computation=traced_computation,
|
||||
output_idx=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def clip(self, *args: Union["NPTracer", Any], **kwargs) -> "NPTracer":
|
||||
"""Trace x.clip.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
sanitized_args = [cast(NPTracer, self._sanitize(arg)) for arg in args]
|
||||
return self.numpy_clip(self, *sanitized_args, **kwargs)
|
||||
|
||||
def numpy_clip(self, *args: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Trace numpy.clip.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self._np_operator(numpy.clip, "clip", 3, *args, **kwargs)
|
||||
|
||||
def dot(self, *args: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Trace x.dot.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
assert len(args) == 1
|
||||
arg0 = self._sanitize(args[0])
|
||||
assert_true(isinstance(arg0, NPTracer))
|
||||
arg0 = cast(NPTracer, arg0)
|
||||
return self.numpy_dot(self, arg0, **kwargs)
|
||||
|
||||
def transpose(self, *args: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Trace x.transpose.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self.numpy_transpose(self, *args, **kwargs)
|
||||
|
||||
def numpy_transpose(self, *args: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Trace numpy.transpose.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
assert_true((num_args := len(args)) == 1, f"transpose expect 1 input got {num_args}")
|
||||
|
||||
first_arg_output = args[0].output
|
||||
assert_true(isinstance(first_arg_output, TensorValue))
|
||||
first_arg_output = cast(TensorValue, first_arg_output)
|
||||
|
||||
transpose_is_fusable = first_arg_output.is_scalar or first_arg_output.ndim == 1
|
||||
|
||||
out_dtype = first_arg_output.dtype
|
||||
out_shape = first_arg_output.shape[::-1]
|
||||
|
||||
generic_function_output_value = TensorValue(
|
||||
out_dtype,
|
||||
first_arg_output.is_encrypted,
|
||||
out_shape,
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[first_arg_output],
|
||||
arbitrary_func=numpy.transpose,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="Memory",
|
||||
op_kwargs=deepcopy(kwargs),
|
||||
op_name="transpose",
|
||||
op_attributes={"fusable": transpose_is_fusable},
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
args,
|
||||
traced_computation=traced_computation,
|
||||
output_idx=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def ravel(self, *args: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Trace x.ravel.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self.numpy_ravel(self, *args, **kwargs)
|
||||
|
||||
def numpy_ravel(self, *args: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Trace numpy.ravel.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
assert_true((num_args := len(args)) == 1, f"ravel expect 1 input got {num_args}")
|
||||
|
||||
first_arg_output = args[0].output
|
||||
assert_true(isinstance(first_arg_output, TensorValue))
|
||||
first_arg_output = cast(TensorValue, first_arg_output)
|
||||
|
||||
ravel_is_fusable = first_arg_output.ndim == 1
|
||||
|
||||
out_dtype = first_arg_output.dtype
|
||||
out_shape = (1,) if first_arg_output.is_scalar else (numpy.product(first_arg_output.shape),)
|
||||
|
||||
generic_function_output_value = TensorValue(
|
||||
out_dtype,
|
||||
first_arg_output.is_encrypted,
|
||||
out_shape,
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[first_arg_output],
|
||||
arbitrary_func=numpy.ravel,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="Memory",
|
||||
op_kwargs=deepcopy(kwargs),
|
||||
op_name="ravel",
|
||||
op_attributes={"fusable": ravel_is_fusable},
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
args,
|
||||
traced_computation=traced_computation,
|
||||
output_idx=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def reshape(self, newshape: Tuple[Any, ...], **kwargs) -> "NPTracer":
|
||||
"""Trace x.reshape.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self.numpy_reshape(self, newshape, **kwargs)
|
||||
|
||||
def numpy_reshape(self, arg0: "NPTracer", arg1: Tuple[Any, ...], **kwargs) -> "NPTracer":
|
||||
"""Trace numpy.reshape.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
|
||||
# FIXME: #772, restore reshape(self, *args, **kwargs) signature when possible, with mypy
|
||||
# types
|
||||
|
||||
# FIXME: #772, restore
|
||||
# assert_true((num_args := len(args)) == 2, f"reshape expect 2 input got {num_args}")
|
||||
# when possible
|
||||
|
||||
assert_true((num_kwargs := len(kwargs)) == 0, f"reshape expect 0 kwargs got {num_kwargs}")
|
||||
|
||||
first_arg_output = arg0.output
|
||||
assert_true(isinstance(first_arg_output, TensorValue))
|
||||
first_arg_output = cast(TensorValue, first_arg_output)
|
||||
|
||||
try:
|
||||
# calculate a newshape using numpy to handle edge cases such as `-1`s within new shape
|
||||
newshape = numpy.zeros(first_arg_output.shape).reshape(arg1).shape
|
||||
except Exception as error:
|
||||
raise ValueError(
|
||||
f"shapes are not compatible (old shape {first_arg_output.shape}, new shape {arg1})"
|
||||
) from error
|
||||
|
||||
reshape_is_fusable = newshape == first_arg_output.shape
|
||||
|
||||
out_dtype = first_arg_output.dtype
|
||||
out_shape = newshape
|
||||
|
||||
generic_function_output_value = TensorValue(
|
||||
out_dtype,
|
||||
first_arg_output.is_encrypted,
|
||||
out_shape,
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[first_arg_output],
|
||||
arbitrary_func=numpy.reshape,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="Memory",
|
||||
op_kwargs={"newshape": newshape},
|
||||
op_name="reshape",
|
||||
op_attributes={"fusable": reshape_is_fusable},
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
[arg0],
|
||||
traced_computation=traced_computation,
|
||||
output_idx=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def flatten(self, *args: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Trace x.flatten.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
assert_true((num_args := len(args)) == 0, f"flatten expect 0 input got {num_args}")
|
||||
|
||||
first_arg_output = self.output
|
||||
assert_true(isinstance(first_arg_output, TensorValue))
|
||||
first_arg_output = cast(TensorValue, first_arg_output)
|
||||
|
||||
flatten_is_fusable = first_arg_output.ndim == 1
|
||||
|
||||
out_dtype = first_arg_output.dtype
|
||||
out_shape = (1,) if first_arg_output.is_scalar else (numpy.product(first_arg_output.shape),)
|
||||
|
||||
generic_function_output_value = TensorValue(
|
||||
out_dtype,
|
||||
first_arg_output.is_encrypted,
|
||||
out_shape,
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[first_arg_output],
|
||||
arbitrary_func=lambda x: x.flatten(),
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="Memory",
|
||||
op_kwargs=deepcopy(kwargs),
|
||||
op_name="flatten",
|
||||
op_attributes={"fusable": flatten_is_fusable},
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
[self],
|
||||
traced_computation=traced_computation,
|
||||
output_idx=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def numpy_sum(self, inp: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Trace numpy.sum.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
|
||||
input_value = inp.output
|
||||
|
||||
def supported(value):
|
||||
if not value.is_encrypted or not isinstance(input_value, TensorValue):
|
||||
return False
|
||||
|
||||
value = cast(TensorValue, value)
|
||||
if value.shape == ():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
if not supported(input_value):
|
||||
raise ValueError(
|
||||
f"only encrypted tensor sum is supported but you tried to sum {input_value}"
|
||||
)
|
||||
|
||||
try:
|
||||
# calculate a newshape using numpy to handle all cases
|
||||
newshape = numpy.sum(numpy.zeros(input_value.shape), **kwargs).shape # type: ignore
|
||||
except Exception as error:
|
||||
raise ValueError(
|
||||
f"invalid sum on {input_value} with "
|
||||
f"{', '.join('='.join([key, str(value)]) for key, value in kwargs.items())}"
|
||||
) from error
|
||||
|
||||
output_value = TensorValue(
|
||||
input_value.dtype,
|
||||
input_value.is_encrypted,
|
||||
newshape,
|
||||
)
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[input_value],
|
||||
arbitrary_func=numpy.sum,
|
||||
output_value=output_value,
|
||||
op_kind="Memory",
|
||||
op_kwargs=kwargs,
|
||||
op_name="sum",
|
||||
op_attributes={"fusable": False},
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
[inp],
|
||||
traced_computation=traced_computation,
|
||||
output_idx=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def numpy_concatenate(self, inputs: Tuple["NPTracer", ...], **kwargs) -> "NPTracer":
|
||||
"""Trace numpy.concatenate.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
|
||||
input_values = [tracer.output for tracer in inputs]
|
||||
|
||||
def supported(values):
|
||||
if any(
|
||||
not value.is_encrypted or not isinstance(value, TensorValue) for value in values
|
||||
):
|
||||
return False
|
||||
|
||||
values = [cast(TensorValue, value) for value in values]
|
||||
if any(value.shape == () for value in values):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
if not supported(input_values):
|
||||
raise ValueError(
|
||||
f"only encrypted tensor concatenation is supported "
|
||||
f"but you tried to concatenate "
|
||||
f"{', '.join(str(input_value) for input_value in input_values)}"
|
||||
)
|
||||
|
||||
input_tensor_values = [cast(TensorValue, value) for value in input_values]
|
||||
|
||||
try:
|
||||
# calculate a newshape using numpy to handle all cases
|
||||
sample = tuple(numpy.zeros(input_value.shape) for input_value in input_tensor_values)
|
||||
newshape = numpy.concatenate(sample, **kwargs).shape
|
||||
except Exception as error:
|
||||
kwarg_info = ""
|
||||
if len(kwargs) != 0:
|
||||
kwarg_info += " with "
|
||||
kwarg_info += ", ".join(
|
||||
"=".join([key, str(value)]) for key, value in kwargs.items()
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"invalid concatenation of "
|
||||
f"{', '.join(str(input_value) for input_value in input_values)}{kwarg_info}"
|
||||
) from error
|
||||
|
||||
output_value = TensorValue(
|
||||
input_tensor_values[0].dtype,
|
||||
input_tensor_values[0].is_encrypted,
|
||||
newshape,
|
||||
)
|
||||
traced_computation = GenericFunction(
|
||||
inputs=input_values,
|
||||
arbitrary_func=numpy.concatenate,
|
||||
output_value=output_value,
|
||||
op_kind="Memory",
|
||||
op_kwargs=kwargs,
|
||||
op_name="concat",
|
||||
op_attributes={"fusable": False},
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
list(inputs),
|
||||
traced_computation=traced_computation,
|
||||
output_idx=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def __getitem__(self, item):
|
||||
if isinstance(item, tuple):
|
||||
item = tuple(process_indexing_element(indexing_element) for indexing_element in item)
|
||||
else:
|
||||
item = process_indexing_element(item)
|
||||
|
||||
return BaseTracer.__getitem__(self, item)
|
||||
|
||||
def __matmul__(self, other):
|
||||
"""Trace numpy.matmul."""
|
||||
return self.__array_ufunc__(numpy.matmul, "__call__", self, other)
|
||||
|
||||
# Supported functions are either univariate or bivariate for which one of the two
|
||||
# sources is a constant
|
||||
#
|
||||
# numpy.add, numpy.multiply and numpy.subtract are not there since already managed
|
||||
# by leveled operations
|
||||
#
|
||||
# numpy.conjugate is not there since working on complex numbers
|
||||
#
|
||||
# numpy.isnat is not there since it is about timings
|
||||
#
|
||||
# numpy.divmod, numpy.modf and numpy.frexp are not there since output two values
|
||||
#
|
||||
# numpy.invert (as known as numpy.bitwise_not) is not here, because it has strange input type.
|
||||
# We ask the user to replace bitwise_xor instead
|
||||
LIST_OF_SUPPORTED_UFUNC: List[numpy.ufunc] = [
|
||||
numpy.absolute,
|
||||
numpy.arccos,
|
||||
numpy.arccosh,
|
||||
numpy.arcsin,
|
||||
numpy.arcsinh,
|
||||
numpy.arctan,
|
||||
numpy.arctan2,
|
||||
numpy.arctanh,
|
||||
numpy.bitwise_and,
|
||||
numpy.bitwise_or,
|
||||
numpy.bitwise_xor,
|
||||
numpy.cbrt,
|
||||
numpy.ceil,
|
||||
numpy.copysign,
|
||||
numpy.cos,
|
||||
numpy.cosh,
|
||||
numpy.deg2rad,
|
||||
numpy.degrees,
|
||||
numpy.equal,
|
||||
numpy.exp,
|
||||
numpy.exp2,
|
||||
numpy.expm1,
|
||||
numpy.fabs,
|
||||
numpy.float_power,
|
||||
numpy.floor,
|
||||
numpy.floor_divide,
|
||||
numpy.fmax,
|
||||
numpy.fmin,
|
||||
numpy.fmod,
|
||||
numpy.gcd,
|
||||
numpy.greater,
|
||||
numpy.greater_equal,
|
||||
numpy.heaviside,
|
||||
numpy.hypot,
|
||||
numpy.isfinite,
|
||||
numpy.isinf,
|
||||
numpy.isnan,
|
||||
numpy.lcm,
|
||||
numpy.ldexp,
|
||||
numpy.left_shift,
|
||||
numpy.less,
|
||||
numpy.less_equal,
|
||||
numpy.log,
|
||||
numpy.log10,
|
||||
numpy.log1p,
|
||||
numpy.log2,
|
||||
numpy.logaddexp,
|
||||
numpy.logaddexp2,
|
||||
numpy.logical_and,
|
||||
numpy.logical_not,
|
||||
numpy.logical_or,
|
||||
numpy.logical_xor,
|
||||
numpy.maximum,
|
||||
numpy.minimum,
|
||||
numpy.negative,
|
||||
numpy.nextafter,
|
||||
numpy.not_equal,
|
||||
numpy.positive,
|
||||
numpy.power,
|
||||
numpy.rad2deg,
|
||||
numpy.radians,
|
||||
numpy.reciprocal,
|
||||
numpy.remainder,
|
||||
numpy.right_shift,
|
||||
numpy.rint,
|
||||
numpy.sign,
|
||||
numpy.signbit,
|
||||
numpy.sin,
|
||||
numpy.sinh,
|
||||
numpy.spacing,
|
||||
numpy.sqrt,
|
||||
numpy.square,
|
||||
numpy.tan,
|
||||
numpy.tanh,
|
||||
numpy.true_divide,
|
||||
numpy.trunc,
|
||||
]
|
||||
|
||||
# We build UFUNC_ROUTING dynamically after the creation of the class,
|
||||
# because of some limits of python or our unability to do it properly
|
||||
# in the class with techniques which are compatible with the different
|
||||
# coding checks we use
|
||||
UFUNC_ROUTING: Dict[numpy.ufunc, Callable] = {}
|
||||
|
||||
FUNC_ROUTING: Dict[Callable, Callable] = {
|
||||
numpy.dot: numpy_dot,
|
||||
numpy.transpose: numpy_transpose,
|
||||
numpy.reshape: numpy_reshape,
|
||||
numpy.ravel: numpy_ravel,
|
||||
numpy.clip: numpy_clip,
|
||||
numpy.sum: numpy_sum,
|
||||
numpy.concatenate: numpy_concatenate,
|
||||
}
|
||||
|
||||
|
||||
def _get_unary_fun(function: numpy.ufunc):
|
||||
"""Wrap _unary_operator in a lambda to populate NPTRACER.UFUNC_ROUTING."""
|
||||
|
||||
# We have to access this method to be able to build NPTracer.UFUNC_ROUTING
|
||||
# dynamically
|
||||
# pylint: disable=protected-access
|
||||
return lambda *input_tracers, **kwargs: NPTracer._np_operator(
|
||||
function, f"{function.__name__}", 1, *input_tracers, **kwargs
|
||||
)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def _get_binary_fun(function: numpy.ufunc):
|
||||
"""Wrap _binary_operator in a lambda to populate NPTRACER.UFUNC_ROUTING."""
|
||||
|
||||
# We have to access this method to be able to build NPTracer.UFUNC_ROUTING
|
||||
# dynamically
|
||||
# pylint: disable=protected-access
|
||||
return lambda *input_tracers, **kwargs: NPTracer._np_operator(
|
||||
function, f"{function.__name__}", 2, *input_tracers, **kwargs
|
||||
)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
# We are populating NPTracer.UFUNC_ROUTING dynamically
|
||||
NPTracer.UFUNC_ROUTING = {
|
||||
fun: _get_unary_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC if fun.nin == 1
|
||||
}
|
||||
|
||||
NPTracer.UFUNC_ROUTING.update(
|
||||
{fun: _get_binary_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC if fun.nin == 2}
|
||||
)
|
||||
|
||||
list_of_not_supported = [
|
||||
(ufunc.__name__, ufunc.nin)
|
||||
for ufunc in NPTracer.LIST_OF_SUPPORTED_UFUNC
|
||||
if ufunc.nin not in [1, 2]
|
||||
]
|
||||
|
||||
assert_true(len(list_of_not_supported) == 0, f"Not supported nin's, {list_of_not_supported}")
|
||||
del list_of_not_supported
|
||||
|
||||
# We are adding initial support for `np.array(...)` +,-,* `BaseTracer`
|
||||
# (note that this is not the proper complete handling of these functions)
|
||||
|
||||
|
||||
def _on_numpy_add(lhs, rhs):
|
||||
return lhs.__add__(rhs)
|
||||
|
||||
|
||||
def _on_numpy_subtract(lhs, rhs):
|
||||
return lhs.__sub__(rhs)
|
||||
|
||||
|
||||
def _on_numpy_multiply(lhs, rhs):
|
||||
return lhs.__mul__(rhs)
|
||||
|
||||
|
||||
def _on_numpy_matmul(lhs: NPTracer, rhs: NPTracer):
|
||||
common_output_dtypes_and_shapes = get_numpy_function_output_dtype_and_shape_from_input_tracers(
|
||||
numpy.matmul, lhs, rhs
|
||||
)
|
||||
assert_true(len(common_output_dtypes_and_shapes) == 1)
|
||||
|
||||
output_shape = common_output_dtypes_and_shapes[0][1]
|
||||
traced_computation = MatMul(
|
||||
[lhs.output, rhs.output],
|
||||
common_output_dtypes_and_shapes[0][0],
|
||||
output_shape,
|
||||
)
|
||||
return NPTracer([lhs, rhs], traced_computation, output_idx=0)
|
||||
|
||||
|
||||
NPTracer.UFUNC_ROUTING[numpy.add] = _on_numpy_add
|
||||
NPTracer.UFUNC_ROUTING[numpy.subtract] = _on_numpy_subtract
|
||||
NPTracer.UFUNC_ROUTING[numpy.multiply] = _on_numpy_multiply
|
||||
NPTracer.UFUNC_ROUTING[numpy.matmul] = _on_numpy_matmul
|
||||
|
||||
|
||||
def trace_numpy_function(
|
||||
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
|
||||
) -> OPGraph:
|
||||
"""Trace a numpy function.
|
||||
|
||||
Args:
|
||||
function_to_trace (Callable): The function you want to trace
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
|
||||
Returns:
|
||||
OPGraph: The graph containing the ir nodes representing the computation done in the input
|
||||
function
|
||||
"""
|
||||
function_parameters = prepare_function_parameters(function_to_trace, function_parameters)
|
||||
|
||||
input_tracers = make_input_tracers(NPTracer, function_parameters)
|
||||
|
||||
# We could easily create a graph of NPTracer, but we may end up with dead nodes starting from
|
||||
# the inputs that's why we create the graph starting from the outputs
|
||||
with tracing_context([NPTracer]):
|
||||
output_tracers = function_to_trace(**input_tracers)
|
||||
|
||||
if isinstance(output_tracers, NPTracer):
|
||||
output_tracers = (output_tracers,)
|
||||
|
||||
op_graph = OPGraph.from_output_tracers(output_tracers)
|
||||
|
||||
return op_graph
|
||||
@@ -1,533 +0,0 @@
|
||||
"""Test file for bounds evaluation with a inputset"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from concrete.common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_inputset
|
||||
from concrete.common.compilation import CompilationConfiguration
|
||||
from concrete.common.data_types.floats import Float
|
||||
from concrete.common.data_types.integers import Integer, UnsignedInteger
|
||||
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy.compile import numpy_max_func, numpy_min_func
|
||||
from concrete.numpy.np_dtypes_helpers import get_base_value_for_numpy_or_python_constant_data
|
||||
from concrete.numpy.tracing import trace_numpy_function
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,expected_output_bounds,expected_output_data_type",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
((-10, 10), (-10, 10)),
|
||||
(-20, 20),
|
||||
Integer(6, is_signed=True),
|
||||
id="x + y, (-10, 10), (-10, 10), (-20, 20)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-14, 7),
|
||||
Integer(5, is_signed=True),
|
||||
id="x + y, (-10, 2), (-4, 5), (-14, 7)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y + 1.7,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-12.3, 8.7),
|
||||
Float(64),
|
||||
id="x + y + 1.7, (-10, 2), (-4, 5), (-12.3, 8.7)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y + 1,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-13, 8),
|
||||
Integer(5, is_signed=True),
|
||||
id="x + y + 1, (-10, 2), (-4, 5), (-13, 8)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y + (-3),
|
||||
((-10, 2), (-4, 5)),
|
||||
(-17, 4),
|
||||
Integer(6, is_signed=True),
|
||||
id="x + y + 1, (-10, 2), (-4, 5), (-17, 4)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (1 + x) + y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-13, 8),
|
||||
Integer(5, is_signed=True),
|
||||
id="(1 + x) + y, (-10, 2), (-4, 5), (-13, 8)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x - y,
|
||||
((-10, 10), (-10, 10)),
|
||||
(-20, 20),
|
||||
Integer(6, is_signed=True),
|
||||
id="x - y, (-10, 10), (-10, 10), (-20, 20)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x - y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-15, 6),
|
||||
Integer(5, is_signed=True),
|
||||
id="x - y, (-10, 2), (-4, 5), (-15, 6)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x - y - 42,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-57, -36),
|
||||
Integer(7, is_signed=True),
|
||||
id="x - y - 42, (-10, 2), (-4, 5), (-57, -36)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x - y - 41.5,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-56.5, -35.5),
|
||||
Float(64),
|
||||
id="x - y - 41.5, (-10, 2), (-4, 5), (-56.5, -35.5)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: 3 - x + y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-3, 18),
|
||||
Integer(6, is_signed=True),
|
||||
id="3 - x + y, (-10, 2), (-4, 5), (-3, 18)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: 2.8 - x + y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-3.2, 17.8),
|
||||
Float(64),
|
||||
id="2.8 - x + y, (-10, 2), (-4, 5), (-3.2, 17.8)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (-13) - x + y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-19, 2),
|
||||
Integer(6, is_signed=True),
|
||||
id="(-13) - x + y, (-10, 2), (-4, 5), (-19, 2)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (-13.5) - x + y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-19.5, 1.5),
|
||||
Float(64),
|
||||
id="(-13.5) - x + y, (-10, 2), (-4, 5), (-19.5, 1.5)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
((-10, 10), (-10, 10)),
|
||||
(-100, 100),
|
||||
Integer(8, is_signed=True),
|
||||
id="x * y, (-10, 10), (-10, 10), (-100, 100)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-50, 40),
|
||||
Integer(7, is_signed=True),
|
||||
id="x * y, (-10, 2), (-4, 5), (-50, 40)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (3 * x) * y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-150, 120),
|
||||
Integer(9, is_signed=True),
|
||||
id="(3 * x) * y, (-10, 2), (-4, 5), (-150, 120)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (3.0 * x) * y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-150.0, 120.0),
|
||||
Float(64),
|
||||
id="(3.0 * x) * y, (-10, 2), (-4, 5), (-150.0, 120.0)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x * 11) * y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-550, 440),
|
||||
Integer(11, is_signed=True),
|
||||
id="x * y, (-10, 2), (-4, 5), (-550, 440)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x * (-11)) * y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-440, 550),
|
||||
Integer(11, is_signed=True),
|
||||
id="(x * (-11)) * y, (-10, 2), (-4, 5), (-440, 550)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x * (-11.0)) * y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-440.0, 550.0),
|
||||
Float(64),
|
||||
id="(x * (-11.0)) * y, (-10, 2), (-4, 5), (-440.0, 550.0)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + x + y,
|
||||
((-10, 10), (-10, 10)),
|
||||
(-30, 30),
|
||||
Integer(6, is_signed=True),
|
||||
id="x + x + y, (-10, 10), (-10, 10), (-30, 30)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x - x + y,
|
||||
((-10, 10), (-10, 10)),
|
||||
(-10, 10),
|
||||
Integer(5, is_signed=True),
|
||||
id="x - x + y, (-10, 10), (-10, 10), (-10, 10)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x - x + y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-4, 5),
|
||||
Integer(4, is_signed=True),
|
||||
id="x - x + y, (-10, 2), (-4, 5), (-4, 5)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y - x,
|
||||
((-10, 10), (-10, 10)),
|
||||
(-110, 110),
|
||||
Integer(8, is_signed=True),
|
||||
id="x * y - x, (-10, 10), (-10, 10), (-110, 110)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y - x,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-40, 50),
|
||||
Integer(7, is_signed=True),
|
||||
id="x * y - x, (-10, 2), (-4, 5), (-40, 50),",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x * 3) * y - (x + 3) + (y - 13) + x * (11 + y) * (12 + y) + (15 - x),
|
||||
((-10, 2), (-4, 5)),
|
||||
(-2846, 574),
|
||||
Integer(13, is_signed=True),
|
||||
id="x * y - x, (-10, 2), (-4, 5), (-2846, 574),",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_eval_op_graph_bounds_on_inputset(
|
||||
function,
|
||||
input_ranges,
|
||||
expected_output_bounds,
|
||||
expected_output_data_type: Integer,
|
||||
):
|
||||
"""Test function for eval_op_graph_bounds_on_inputset"""
|
||||
|
||||
test_eval_op_graph_bounds_on_inputset_multiple_output(
|
||||
function,
|
||||
input_ranges,
|
||||
(expected_output_bounds,),
|
||||
(expected_output_data_type,),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,expected_output_bounds,expected_output_data_type",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: (x + 1, y + 10),
|
||||
((-1, 1), (3, 4)),
|
||||
((0, 2), (13, 14)),
|
||||
(Integer(2, is_signed=False), Integer(4, is_signed=False)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x + 1.5, y + 9.6),
|
||||
((-1, 1), (3, 4)),
|
||||
((0.5, 2.5), (12.6, 13.6)),
|
||||
(Float(64), Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x + y + 1, x * y + 42),
|
||||
((-1, 1), (3, 4)),
|
||||
((3, 6), (38, 46)),
|
||||
(Integer(3, is_signed=False), Integer(6, is_signed=False)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x + y + 0.4, x * y + 41.7),
|
||||
((-1, 1), (3, 4)),
|
||||
((2.4, 5.4), (37.7, 45.7)),
|
||||
(Float(64), Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x + y + 1, x * y + 41.7),
|
||||
((-1, 1), (3, 4)),
|
||||
((3, 6), (37.7, 45.7)),
|
||||
(Integer(3, is_signed=False), Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x + y + 0.4, x * y + 42),
|
||||
((-1, 1), (3, 4)),
|
||||
((2.4, 5.4), (38, 46)),
|
||||
(Float(64), Integer(6, is_signed=False)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_eval_op_graph_bounds_on_inputset_multiple_output(
|
||||
function,
|
||||
input_ranges,
|
||||
expected_output_bounds,
|
||||
expected_output_data_type: Tuple[Integer],
|
||||
):
|
||||
"""Test function for eval_op_graph_bounds_on_inputset"""
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
function, {"x": EncryptedScalar(Integer(64, True)), "y": EncryptedScalar(Integer(64, True))}
|
||||
)
|
||||
|
||||
def data_gen(range_x, range_y):
|
||||
for x_gen in range_x:
|
||||
for y_gen in range_y:
|
||||
yield (x_gen, y_gen)
|
||||
|
||||
_, node_bounds_and_samples = eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
data_gen(*tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
CompilationConfiguration(),
|
||||
)
|
||||
|
||||
for i, output_node in op_graph.output_nodes.items():
|
||||
output_node_bounds = node_bounds_and_samples[output_node]
|
||||
assert (output_node_bounds["min"], output_node_bounds["max"]) == expected_output_bounds[i]
|
||||
|
||||
op_graph.update_values_with_bounds_and_samples(node_bounds_and_samples)
|
||||
|
||||
for i, output_node in op_graph.output_nodes.items():
|
||||
assert expected_output_data_type[i] == output_node.outputs[0].dtype
|
||||
|
||||
|
||||
def test_eval_op_graph_bounds_on_non_conformant_inputset_default(capsys):
|
||||
"""Test function for eval_op_graph_bounds_on_inputset with non conformant inputset"""
|
||||
|
||||
def f(x, y):
|
||||
return np.dot(x, y)
|
||||
|
||||
x = EncryptedTensor(UnsignedInteger(2), (3,))
|
||||
y = ClearTensor(UnsignedInteger(2), (3,))
|
||||
|
||||
inputset = [
|
||||
(np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])),
|
||||
(np.array([3, 3, 3]), np.array([3, 3, 5])),
|
||||
]
|
||||
|
||||
op_graph = trace_numpy_function(f, {"x": x, "y": y})
|
||||
|
||||
configuration = CompilationConfiguration()
|
||||
eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
compilation_configuration=configuration,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert (
|
||||
captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected EncryptedTensor<uint2, shape=(3,)> for parameter `x` "
|
||||
"but got EncryptedTensor<uint2, shape=(4,)> which is not compatible)\n"
|
||||
"Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<uint2, shape=(4,)> which is not compatible)\n"
|
||||
)
|
||||
|
||||
|
||||
def test_eval_op_graph_bounds_on_non_conformant_inputset_check_all(capsys):
|
||||
"""Test function for eval_op_graph_bounds_on_inputset with non conformant inputset, check all"""
|
||||
|
||||
def f(x, y):
|
||||
return np.dot(x, y)
|
||||
|
||||
x = EncryptedTensor(UnsignedInteger(2), (3,))
|
||||
y = ClearTensor(UnsignedInteger(2), (3,))
|
||||
|
||||
inputset = [
|
||||
(np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])),
|
||||
(np.array([3, 3, 3]), np.array([3, 3, 5])),
|
||||
]
|
||||
|
||||
op_graph = trace_numpy_function(f, {"x": x, "y": y})
|
||||
|
||||
configuration = CompilationConfiguration(check_every_input_in_inputset=True)
|
||||
eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
compilation_configuration=configuration,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert (
|
||||
captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected EncryptedTensor<uint2, shape=(3,)> for parameter `x` "
|
||||
"but got EncryptedTensor<uint2, shape=(4,)> which is not compatible)\n"
|
||||
"Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<uint2, shape=(4,)> which is not compatible)\n"
|
||||
"Warning: Input #1 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<uint3, shape=(3,)> which is not compatible)\n"
|
||||
)
|
||||
|
||||
|
||||
def test_eval_op_graph_bounds_on_conformant_numpy_inputset_check_all(capsys):
|
||||
"""Test function for eval_op_graph_bounds_on_inputset
|
||||
with conformant inputset of numpy arrays, check all"""
|
||||
|
||||
def f(x, y):
|
||||
return np.dot(x, y)
|
||||
|
||||
x = EncryptedTensor(UnsignedInteger(2), (3,))
|
||||
y = ClearTensor(UnsignedInteger(2), (3,))
|
||||
|
||||
inputset = [
|
||||
(np.array([2, 1, 3]), np.array([1, 2, 1])),
|
||||
(np.array([3, 3, 3]), np.array([3, 3, 1])),
|
||||
]
|
||||
|
||||
op_graph = trace_numpy_function(f, {"x": x, "y": y})
|
||||
|
||||
configuration = CompilationConfiguration(check_every_input_in_inputset=True)
|
||||
eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
compilation_configuration=configuration,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.err == ""
|
||||
|
||||
|
||||
def test_eval_op_graph_bounds_on_non_conformant_numpy_inputset_check_all(capsys):
|
||||
"""Test function for eval_op_graph_bounds_on_inputset with non conformant inputset, check all"""
|
||||
|
||||
def f(x, y):
|
||||
return np.dot(x, y)
|
||||
|
||||
x = EncryptedTensor(UnsignedInteger(2), (3,))
|
||||
y = ClearTensor(UnsignedInteger(2), (3,))
|
||||
|
||||
inputset = [
|
||||
(np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])),
|
||||
(np.array([3, 3, 3]), np.array([3, 3, 5])),
|
||||
]
|
||||
|
||||
op_graph = trace_numpy_function(f, {"x": x, "y": y})
|
||||
|
||||
configuration = CompilationConfiguration(check_every_input_in_inputset=True)
|
||||
eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
compilation_configuration=configuration,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert (
|
||||
captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected EncryptedTensor<uint2, shape=(3,)> for parameter `x` "
|
||||
"but got EncryptedTensor<uint2, shape=(4,)> which is not compatible)\n"
|
||||
"Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<uint2, shape=(4,)> which is not compatible)\n"
|
||||
"Warning: Input #1 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<uint3, shape=(3,)> which is not compatible)\n"
|
||||
)
|
||||
|
||||
|
||||
def test_eval_op_graph_bounds_on_non_conformant_inputset_treating_warnings_as_errors():
|
||||
"""Test function for eval_op_graph_bounds_on_inputset with non conformant inputset and errors"""
|
||||
|
||||
def f(x, y):
|
||||
return np.dot(x, y)
|
||||
|
||||
x = EncryptedTensor(UnsignedInteger(2), (3,))
|
||||
y = ClearTensor(UnsignedInteger(2), (3,))
|
||||
|
||||
inputset = [
|
||||
(np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])),
|
||||
(np.array([3, 3, 3]), np.array([3, 3, 5])),
|
||||
]
|
||||
|
||||
op_graph = trace_numpy_function(f, {"x": x, "y": y})
|
||||
|
||||
with pytest.raises(ValueError, match=".* is not coherent with the hinted parameters .*"):
|
||||
configuration = CompilationConfiguration(treat_warnings_as_errors=True)
|
||||
eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
compilation_configuration=configuration,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
|
||||
def test_inpuset_eval_1_input(default_compilation_configuration):
|
||||
"""Test case for a function with a single parameter and passing the inputset without tuples."""
|
||||
|
||||
def f(x):
|
||||
return x + 42
|
||||
|
||||
x = EncryptedScalar(UnsignedInteger(4))
|
||||
|
||||
inputset = range(10)
|
||||
|
||||
op_graph = trace_numpy_function(f, {"x": x})
|
||||
|
||||
eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
compilation_configuration=default_compilation_configuration,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
input_node = op_graph.input_nodes[0]
|
||||
|
||||
assert input_node.inputs[0] == input_node.outputs[0]
|
||||
assert input_node.inputs[0] == EncryptedScalar(UnsignedInteger(4))
|
||||
|
||||
output_node = op_graph.output_nodes[0]
|
||||
|
||||
assert output_node.outputs[0] == EncryptedScalar(UnsignedInteger(6))
|
||||
|
||||
|
||||
# TODO: https://github.com/zama-ai/concrete-numpy-internal/issues/772
|
||||
# Remove once this issue is done
|
||||
def test_inpuset_eval_1_input_refuse_tuple(default_compilation_configuration):
|
||||
"""Test case for a function with a single parameter and passing the inputset with tuples."""
|
||||
|
||||
def f(x):
|
||||
return x + 42
|
||||
|
||||
x = EncryptedScalar(UnsignedInteger(4))
|
||||
|
||||
inputset = [(i,) for i in range(10)]
|
||||
|
||||
op_graph = trace_numpy_function(f, {"x": x})
|
||||
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
compilation_configuration=default_compilation_configuration,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
assert str(excinfo.value) == "Tuples are unsupported for single input inputset evaluation"
|
||||
@@ -1,48 +0,0 @@
|
||||
"""Test file for compilation artifacts"""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from concrete.common.compilation import CompilationArtifacts
|
||||
from concrete.common.data_types.integers import UnsignedInteger
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy.compile import compile_numpy_function
|
||||
|
||||
|
||||
def test_artifacts_export(default_compilation_configuration):
|
||||
"""Test function to check exporting compilation artifacts"""
|
||||
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
output_directory = Path(tmp)
|
||||
artifacts = CompilationArtifacts(output_directory)
|
||||
|
||||
compile_numpy_function(
|
||||
function,
|
||||
{"x": EncryptedScalar(UnsignedInteger(7))},
|
||||
range(10),
|
||||
default_compilation_configuration,
|
||||
compilation_artifacts=artifacts,
|
||||
)
|
||||
|
||||
artifacts.export()
|
||||
|
||||
assert output_directory.joinpath("environment.txt").exists()
|
||||
assert output_directory.joinpath("requirements.txt").exists()
|
||||
|
||||
assert output_directory.joinpath("function.txt").exists()
|
||||
assert output_directory.joinpath("parameters.txt").exists()
|
||||
|
||||
assert output_directory.joinpath("1.initial.graph.txt").exists()
|
||||
assert output_directory.joinpath("1.initial.graph.png").exists()
|
||||
|
||||
assert output_directory.joinpath("2.final.graph.txt").exists()
|
||||
assert output_directory.joinpath("2.final.graph.png").exists()
|
||||
|
||||
assert output_directory.joinpath("bounds.txt").exists()
|
||||
assert output_directory.joinpath("mlir.txt").exists()
|
||||
|
||||
# format of those files might change in the future
|
||||
# so it is sufficient to test their existance
|
||||
@@ -1,77 +0,0 @@
|
||||
"""Test file for compilation configuration"""
|
||||
|
||||
from inspect import signature
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.compilation import CompilationConfiguration
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds
|
||||
|
||||
|
||||
def no_fuse(x):
|
||||
"""No fuse"""
|
||||
return x + 2
|
||||
|
||||
|
||||
def simple_fuse_not_output(x):
|
||||
"""Simple fuse not output"""
|
||||
intermediate = x.astype(numpy.float64)
|
||||
intermediate = intermediate.astype(numpy.uint32)
|
||||
return intermediate + 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,fused",
|
||||
[
|
||||
pytest.param(
|
||||
no_fuse,
|
||||
False,
|
||||
id="no_fuse",
|
||||
),
|
||||
pytest.param(
|
||||
simple_fuse_not_output,
|
||||
True,
|
||||
id="simple_fuse_not_output",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_enable_topological_optimizations(
|
||||
test_helpers, function_to_trace, fused, default_compilation_configuration
|
||||
):
|
||||
"""Test function for enable_topological_optimizations flag of compilation configuration"""
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function_to_trace,
|
||||
{
|
||||
param: EncryptedScalar(Integer(32, is_signed=False))
|
||||
for param in signature(function_to_trace).parameters.keys()
|
||||
},
|
||||
[numpy.array(i) for i in range(10)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
op_graph_not_optimized = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function_to_trace,
|
||||
{
|
||||
param: EncryptedScalar(Integer(32, is_signed=False))
|
||||
for param in signature(function_to_trace).parameters.keys()
|
||||
},
|
||||
[numpy.array(i) for i in range(10)],
|
||||
CompilationConfiguration(
|
||||
dump_artifacts_on_unexpected_failures=False,
|
||||
enable_topological_optimizations=False,
|
||||
treat_warnings_as_errors=True,
|
||||
),
|
||||
)
|
||||
|
||||
graph = op_graph.graph
|
||||
not_optimized_graph = op_graph_not_optimized.graph
|
||||
|
||||
if fused:
|
||||
assert not test_helpers.digraphs_are_equivalent(graph, not_optimized_graph)
|
||||
assert len(graph) < len(not_optimized_graph)
|
||||
else:
|
||||
assert test_helpers.digraphs_are_equivalent(graph, not_optimized_graph)
|
||||
assert len(graph) == len(not_optimized_graph)
|
||||
@@ -1,299 +0,0 @@
|
||||
"""Test file for data types helpers"""
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.base import BaseDataType
|
||||
from concrete.common.data_types.dtypes_helpers import (
|
||||
broadcast_shapes,
|
||||
find_type_to_hold_both_lossy,
|
||||
mix_values_determine_holding_dtype,
|
||||
value_is_encrypted_scalar_integer,
|
||||
value_is_encrypted_scalar_unsigned_integer,
|
||||
)
|
||||
from concrete.common.data_types.floats import Float
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.values import (
|
||||
BaseValue,
|
||||
ClearScalar,
|
||||
ClearTensor,
|
||||
EncryptedScalar,
|
||||
EncryptedTensor,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value,expected_result",
|
||||
[
|
||||
pytest.param(
|
||||
ClearScalar(Integer(8, is_signed=False)),
|
||||
False,
|
||||
id="ClearScalar 8 bits unsigned Integer",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedScalar(Integer(8, is_signed=True)),
|
||||
True,
|
||||
id="EncryptedScalar 8 bits signed Integer",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_value_is_encrypted_integer(value: BaseValue, expected_result: bool):
|
||||
"""Test value_is_encrypted_integer helper"""
|
||||
assert value_is_encrypted_scalar_integer(value) == expected_result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value,expected_result",
|
||||
[
|
||||
pytest.param(
|
||||
ClearScalar(Integer(8, is_signed=False)),
|
||||
False,
|
||||
id="ClearScalar 8 bits unsigned Integer",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedScalar(Integer(8, is_signed=True)),
|
||||
False,
|
||||
id="EncryptedScalar 8 bits signed Integer",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedScalar(Integer(8, is_signed=False)),
|
||||
True,
|
||||
id="EncryptedScalar 8 bits unsigned Integer",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_value_is_encrypted_unsigned_integer(value: BaseValue, expected_result: bool):
|
||||
"""Test value_is_encrypted_unsigned_integer helper"""
|
||||
assert value_is_encrypted_scalar_unsigned_integer(value) == expected_result
|
||||
|
||||
|
||||
class UnsupportedDataType(BaseDataType):
|
||||
"""Test helper class to represent an UnsupportedDataType"""
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
return isinstance(o, self.__class__)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dtype1,dtype2,expected_mixed_dtype",
|
||||
[
|
||||
pytest.param(Integer(6, True), Integer(6, True), Integer(6, True), id="int6, int6, int6"),
|
||||
pytest.param(
|
||||
Integer(6, False), Integer(6, False), Integer(6, False), id="uint6, uint6, uint6"
|
||||
),
|
||||
pytest.param(Integer(6, True), Integer(6, False), Integer(7, True), id="int6, uint6, int7"),
|
||||
pytest.param(Integer(6, False), Integer(6, True), Integer(7, True), id="uint6, int6, int7"),
|
||||
pytest.param(Integer(6, True), Integer(5, False), Integer(6, True), id="int6, uint5, int6"),
|
||||
pytest.param(Integer(5, False), Integer(6, True), Integer(6, True), id="uint5, int6, int6"),
|
||||
pytest.param(Integer(32, True), Float(32), Float(32), id="int32, float32, float32"),
|
||||
pytest.param(Integer(64, True), Float(32), Float(32), id="int64, float32, float32"),
|
||||
pytest.param(Integer(64, True), Float(64), Float(64), id="int64, float64, float64"),
|
||||
pytest.param(Integer(32, True), Float(64), Float(64), id="int32, float64, float64"),
|
||||
pytest.param(Float(64), Integer(32, True), Float(64), id="float64, int32, float64"),
|
||||
pytest.param(Float(64), Integer(7, False), Float(64), id="float64, uint7, float64"),
|
||||
pytest.param(Float(32), Float(32), Float(32), id="float32, float32, float32"),
|
||||
pytest.param(Float(32), Float(64), Float(64), id="float32, float64, float64"),
|
||||
pytest.param(Float(64), Float(32), Float(64), id="float64, float32, float64"),
|
||||
pytest.param(Float(64), Float(64), Float(64), id="float64, float64, float64"),
|
||||
pytest.param(
|
||||
UnsupportedDataType(),
|
||||
UnsupportedDataType(),
|
||||
None,
|
||||
id="unsupported, unsupported, xfail",
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
pytest.param(
|
||||
Integer(6, True),
|
||||
UnsupportedDataType(),
|
||||
None,
|
||||
id="int6, unsupported, xfail",
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
pytest.param(
|
||||
UnsupportedDataType(),
|
||||
Integer(6, True),
|
||||
None,
|
||||
id="unsupported, int6, xfail",
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
pytest.param(
|
||||
UnsupportedDataType(),
|
||||
Float(32),
|
||||
None,
|
||||
id="unsupported, float32, xfail",
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mix_data_types(
|
||||
dtype1: BaseDataType,
|
||||
dtype2: BaseDataType,
|
||||
expected_mixed_dtype: BaseDataType,
|
||||
):
|
||||
"""Test find_type_to_hold_both_lossy helper"""
|
||||
assert expected_mixed_dtype == find_type_to_hold_both_lossy(dtype1, dtype2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value1,value2,expected_mixed_value",
|
||||
[
|
||||
pytest.param(
|
||||
EncryptedScalar(Integer(7, False)),
|
||||
EncryptedScalar(Integer(7, False)),
|
||||
EncryptedScalar(Integer(7, False)),
|
||||
id="euint7, euint7, euint7",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedScalar(Integer(7, False)),
|
||||
ClearScalar(Integer(7, False)),
|
||||
EncryptedScalar(Integer(7, False)),
|
||||
id="euint7, cuint7, euint7",
|
||||
),
|
||||
pytest.param(
|
||||
ClearScalar(Integer(7, False)),
|
||||
EncryptedScalar(Integer(7, False)),
|
||||
EncryptedScalar(Integer(7, False)),
|
||||
id="cuint7, euint7, euint7",
|
||||
),
|
||||
pytest.param(
|
||||
ClearScalar(Integer(7, False)),
|
||||
ClearScalar(Integer(7, False)),
|
||||
ClearScalar(Integer(7, False)),
|
||||
id="cuint7, cuint7, cuint7",
|
||||
),
|
||||
pytest.param(
|
||||
ClearScalar(Float(32)),
|
||||
ClearScalar(Float(32)),
|
||||
ClearScalar(Float(32)),
|
||||
id="cfloat32, cfloat32, cfloat32",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedScalar(Float(32)),
|
||||
ClearScalar(Float(32)),
|
||||
EncryptedScalar(Float(32)),
|
||||
id="efloat32, cfloat32, efloat32",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mix_scalar_values(value1, value2, expected_mixed_value):
|
||||
"""Test mix_values_determine_holding_dtype helper with scalars"""
|
||||
|
||||
assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value1,value2,expected_mixed_value",
|
||||
[
|
||||
pytest.param(
|
||||
EncryptedTensor(Integer(7, False), (1, 2, 3)),
|
||||
EncryptedTensor(Integer(7, False), (1, 2, 3)),
|
||||
EncryptedTensor(Integer(7, False), (1, 2, 3)),
|
||||
),
|
||||
pytest.param(
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
EncryptedTensor(Integer(7, False), (1, 2, 3)),
|
||||
EncryptedTensor(Integer(7, False), (1, 2, 3)),
|
||||
),
|
||||
pytest.param(
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
),
|
||||
pytest.param(
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
),
|
||||
pytest.param(
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
EncryptedScalar(Integer(7, False)),
|
||||
None,
|
||||
marks=pytest.mark.xfail(strict=True, raises=AssertionError),
|
||||
),
|
||||
pytest.param(
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
ClearTensor(Integer(7, False), (3, 2, 1)),
|
||||
None,
|
||||
marks=pytest.mark.xfail(strict=True, raises=AssertionError),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mix_tensor_values(value1, value2, expected_mixed_value):
|
||||
"""Test mix_values_determine_holding_dtype helper with tensors"""
|
||||
|
||||
assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2)
|
||||
|
||||
|
||||
class DummyValue(BaseValue):
|
||||
"""DummyValue"""
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return BaseValue.__eq__(self, other)
|
||||
|
||||
|
||||
def test_fail_mix_values_determine_holding_dtype():
|
||||
"""Test function for failure case of mix_values_determine_holding_dtype"""
|
||||
|
||||
with pytest.raises(ValueError, match=r".* does not support value .*"):
|
||||
mix_values_determine_holding_dtype(
|
||||
DummyValue(Integer(32, True), True),
|
||||
DummyValue(Integer(32, True), True),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shape1,shape2,expected_shape",
|
||||
[
|
||||
pytest.param((), (), ()),
|
||||
pytest.param((3,), (), (3,)),
|
||||
pytest.param((3,), (1,), (3,)),
|
||||
pytest.param((3,), (2,), None),
|
||||
pytest.param((3,), (3,), (3,)),
|
||||
pytest.param((2, 3), (), (2, 3)),
|
||||
pytest.param((2, 3), (1,), (2, 3)),
|
||||
pytest.param((2, 3), (2,), None),
|
||||
pytest.param((2, 3), (3,), (2, 3)),
|
||||
pytest.param((2, 3), (1, 1), (2, 3)),
|
||||
pytest.param((2, 3), (2, 1), (2, 3)),
|
||||
pytest.param((2, 3), (3, 1), None),
|
||||
pytest.param((2, 3), (1, 2), None),
|
||||
pytest.param((2, 3), (2, 2), None),
|
||||
pytest.param((2, 3), (3, 2), None),
|
||||
pytest.param((2, 3), (1, 3), (2, 3)),
|
||||
pytest.param((2, 3), (2, 3), (2, 3)),
|
||||
pytest.param((2, 3), (3, 3), None),
|
||||
pytest.param((2, 1, 3), (1, 1, 1), (2, 1, 3)),
|
||||
pytest.param((2, 1, 3), (1, 4, 1), (2, 4, 3)),
|
||||
pytest.param((2, 1, 3), (2, 4, 3), (2, 4, 3)),
|
||||
# Tests cases taken from `numpy`
|
||||
# https://github.com/numpy/numpy/blob/623bc1fae1d47df24e7f1e29321d0c0ba2771ce0/numpy/lib/tests/test_stride_tricks.py#L296-L351
|
||||
pytest.param((1, 2), (2,), (1, 2)),
|
||||
pytest.param((1, 1), (3, 4), (3, 4)),
|
||||
pytest.param((1, 3), (3, 1), (3, 3)),
|
||||
pytest.param((1, 0), (0, 0), (0, 0)),
|
||||
pytest.param((0, 1), (0, 0), (0, 0)),
|
||||
pytest.param((1, 0), (0, 1), (0, 0)),
|
||||
pytest.param((1, 1), (0, 0), (0, 0)),
|
||||
pytest.param((1, 1), (1, 0), (1, 0)),
|
||||
pytest.param((1, 1), (0, 1), (0, 1)),
|
||||
pytest.param((), (0,), (0,)),
|
||||
pytest.param((0,), (0, 0), (0, 0)),
|
||||
pytest.param((0,), (0, 1), (0, 0)),
|
||||
pytest.param((1,), (0, 0), (0, 0)),
|
||||
pytest.param((2,), (0, 0), (0, 0)),
|
||||
pytest.param((), (0, 0), (0, 0)),
|
||||
pytest.param((1, 1), (0,), (1, 0)),
|
||||
pytest.param((1,), (0, 1), (0, 1)),
|
||||
pytest.param((1,), (1, 0), (1, 0)),
|
||||
pytest.param((), (1, 0), (1, 0)),
|
||||
pytest.param((), (0, 1), (0, 1)),
|
||||
pytest.param((1,), (3,), (3,)),
|
||||
pytest.param((2,), (3, 2), (3, 2)),
|
||||
pytest.param((3,), (4,), None),
|
||||
pytest.param((2, 3), (2,), None),
|
||||
pytest.param((1, 3, 4), (2, 3, 3), None),
|
||||
pytest.param((2,), (2, 3), None),
|
||||
],
|
||||
)
|
||||
def test_broadcast_shapes(shape1, shape2, expected_shape):
|
||||
"""Test function for `broadcast_shapes` helper"""
|
||||
assert broadcast_shapes(shape1=shape1, shape2=shape2) == expected_shape
|
||||
assert broadcast_shapes(shape1=shape2, shape2=shape1) == expected_shape
|
||||
@@ -1,50 +0,0 @@
|
||||
"""Test file for float data types"""
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.floats import Float, Float32, Float64
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"float_,expected_repr_str",
|
||||
[
|
||||
pytest.param(
|
||||
Float32(),
|
||||
"Float<32 bits>",
|
||||
id="Float32",
|
||||
),
|
||||
pytest.param(
|
||||
Float(32),
|
||||
"Float<32 bits>",
|
||||
id="32 bits Float",
|
||||
),
|
||||
pytest.param(
|
||||
Float64(),
|
||||
"Float<64 bits>",
|
||||
id="Float64",
|
||||
),
|
||||
pytest.param(
|
||||
Float(64),
|
||||
"Float<64 bits>",
|
||||
id="64 bits Float",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_floats_repr(float_: Float, expected_repr_str: str):
|
||||
"""Test float repr"""
|
||||
assert float_.__repr__() == expected_repr_str
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"float_1,float_2,expected_equal",
|
||||
[
|
||||
pytest.param(Float32(), Float(32), True),
|
||||
pytest.param(Float(64), Float32(), False),
|
||||
pytest.param(Float64(), Float(64), True),
|
||||
],
|
||||
)
|
||||
def test_floats_eq(float_1: Float, float_2: Float, expected_equal: bool):
|
||||
"""Test float eq"""
|
||||
assert expected_equal == (float_1 == float_2)
|
||||
assert expected_equal == (float_2 == float_1)
|
||||
@@ -1,112 +0,0 @@
|
||||
"""Test file for integers data types"""
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.integers import (
|
||||
Integer,
|
||||
SignedInteger,
|
||||
UnsignedInteger,
|
||||
make_integer_to_hold,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"integer,expected_min,expected_max",
|
||||
[
|
||||
pytest.param(Integer(8, is_signed=False), 0, 255, id="8 bits unsigned Integer"),
|
||||
pytest.param(UnsignedInteger(8), 0, 255, id="8 bits UnsignedInteger"),
|
||||
pytest.param(Integer(8, is_signed=True), -128, 127, id="8 bits signed Integer"),
|
||||
pytest.param(SignedInteger(8), -128, 127, id="8 bits SignedInteger"),
|
||||
pytest.param(Integer(32, is_signed=False), 0, 4_294_967_295, id="32 bits unsigned Integer"),
|
||||
pytest.param(UnsignedInteger(32), 0, 4_294_967_295, id="32 bits UnsignedInteger"),
|
||||
pytest.param(
|
||||
Integer(32, is_signed=True),
|
||||
-2_147_483_648,
|
||||
2_147_483_647,
|
||||
id="32 bits signed Integer",
|
||||
),
|
||||
pytest.param(
|
||||
SignedInteger(32),
|
||||
-2_147_483_648,
|
||||
2_147_483_647,
|
||||
id="32 bits SignedInteger",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_basic_integers(integer: Integer, expected_min: int, expected_max: int):
|
||||
"""Test integer class basic functions"""
|
||||
assert integer.min_value() == expected_min
|
||||
assert integer.max_value() == expected_max
|
||||
|
||||
assert integer.can_represent_value(random.randint(expected_min, expected_max))
|
||||
assert not integer.can_represent_value(expected_min - 1)
|
||||
assert not integer.can_represent_value(expected_max + 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"integer,expected_repr_str",
|
||||
[
|
||||
pytest.param(
|
||||
Integer(8, is_signed=False),
|
||||
"Integer<unsigned, 8 bits>",
|
||||
id="8 bits unsigned Integer",
|
||||
),
|
||||
pytest.param(
|
||||
Integer(8, is_signed=True),
|
||||
"Integer<signed, 8 bits>",
|
||||
id="8 bits signed Integer",
|
||||
),
|
||||
pytest.param(
|
||||
Integer(32, is_signed=False),
|
||||
"Integer<unsigned, 32 bits>",
|
||||
id="32 bits unsigned Integer",
|
||||
),
|
||||
pytest.param(
|
||||
Integer(32, is_signed=True),
|
||||
"Integer<signed, 32 bits>",
|
||||
id="32 bits signed Integer",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_integers_repr(integer: Integer, expected_repr_str: str):
|
||||
"""Test integer repr"""
|
||||
assert integer.__repr__() == expected_repr_str
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"values,force_signed,expected_result",
|
||||
[
|
||||
([0], False, Integer(1, is_signed=False)),
|
||||
([0], True, Integer(2, is_signed=True)),
|
||||
([1], False, Integer(1, is_signed=False)),
|
||||
([1], True, Integer(2, is_signed=True)),
|
||||
([-1], False, Integer(2, is_signed=True)),
|
||||
([-2], False, Integer(2, is_signed=True)),
|
||||
([0, 1], False, Integer(1, is_signed=False)),
|
||||
([0, 1], True, Integer(2, is_signed=True)),
|
||||
([7], False, Integer(3, is_signed=False)),
|
||||
([7], True, Integer(4, is_signed=True)),
|
||||
([8], False, Integer(4, is_signed=False)),
|
||||
([8], True, Integer(5, is_signed=True)),
|
||||
([-7], False, Integer(4, is_signed=True)),
|
||||
([-8], False, Integer(4, is_signed=True)),
|
||||
([-7, -8], False, Integer(4, is_signed=True)),
|
||||
([-9], False, Integer(5, is_signed=True)),
|
||||
([-9], True, Integer(5, is_signed=True)),
|
||||
([0, 127], False, Integer(7, is_signed=False)),
|
||||
([0, 127], True, Integer(8, is_signed=True)),
|
||||
([0, 128], False, Integer(8, is_signed=False)),
|
||||
([0, 128], True, Integer(9, is_signed=True)),
|
||||
([-1, 127], False, Integer(8, is_signed=True)),
|
||||
([-256, 127], False, Integer(9, is_signed=True)),
|
||||
([-128, 127], False, Integer(8, is_signed=True)),
|
||||
([-128, 128], False, Integer(9, is_signed=True)),
|
||||
([-13, 4], False, Integer(5, is_signed=True)),
|
||||
([42, 1019], False, Integer(10, is_signed=False)),
|
||||
],
|
||||
)
|
||||
def test_make_integer_to_hold(values, force_signed, expected_result):
|
||||
"""Test make_integer_to_hold"""
|
||||
assert expected_result == make_integer_to_hold(values, force_signed)
|
||||
@@ -1,86 +0,0 @@
|
||||
"""Test file for values related code."""
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.base import BaseDataType
|
||||
from concrete.common.data_types.floats import Float
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.values import ClearTensor, EncryptedTensor, TensorValue
|
||||
|
||||
|
||||
class DummyDtype(BaseDataType):
|
||||
"""Dummy Helper Dtype"""
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
return isinstance(o, self.__class__)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tensor_constructor,expected_is_encrypted",
|
||||
[
|
||||
(ClearTensor, False),
|
||||
(partial(TensorValue, is_encrypted=False), False),
|
||||
(EncryptedTensor, True),
|
||||
(partial(TensorValue, is_encrypted=True), True),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"shape,expected_shape,expected_ndim,expected_size",
|
||||
[
|
||||
((), (), 0, 1),
|
||||
((3, 256, 256), (3, 256, 256), 3, 196_608),
|
||||
((1920, 1080, 3), (1920, 1080, 3), 3, 6_220_800),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"data_type",
|
||||
[
|
||||
Integer(7, False),
|
||||
Integer(32, True),
|
||||
Integer(32, False),
|
||||
Integer(64, True),
|
||||
Integer(64, False),
|
||||
Float(32),
|
||||
Float(64),
|
||||
],
|
||||
)
|
||||
def test_tensor_value(
|
||||
tensor_constructor: Callable[..., TensorValue],
|
||||
expected_is_encrypted: bool,
|
||||
shape: Optional[Tuple[int, ...]],
|
||||
expected_shape: Tuple[int, ...],
|
||||
expected_ndim: int,
|
||||
expected_size: int,
|
||||
data_type: Union[Integer, Float],
|
||||
):
|
||||
"""Test function for TensorValue"""
|
||||
|
||||
tensor_value = tensor_constructor(dtype=data_type, shape=shape)
|
||||
|
||||
assert expected_is_encrypted == tensor_value.is_encrypted
|
||||
assert expected_shape == tensor_value.shape
|
||||
assert expected_ndim == tensor_value.ndim
|
||||
assert expected_size == tensor_value.size
|
||||
|
||||
assert data_type == tensor_value.dtype
|
||||
|
||||
other_tensor = deepcopy(tensor_value)
|
||||
|
||||
assert other_tensor == tensor_value
|
||||
|
||||
other_tensor_value = deepcopy(other_tensor)
|
||||
other_tensor_value.dtype = DummyDtype()
|
||||
assert other_tensor_value != tensor_value
|
||||
|
||||
other_shape = tuple(val + 1 for val in shape) if shape is not None else ()
|
||||
other_shape += (2,)
|
||||
other_tensor_value = tensor_constructor(dtype=data_type, shape=other_shape)
|
||||
|
||||
assert other_tensor_value.shape != tensor_value.shape
|
||||
assert other_tensor_value.ndim != tensor_value.ndim
|
||||
assert other_tensor_value.size != tensor_value.size
|
||||
assert other_tensor_value != tensor_value
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Test custom assert functions."""
|
||||
import pytest
|
||||
|
||||
from concrete.common.debugging.custom_assert import assert_false, assert_not_reached, assert_true
|
||||
|
||||
|
||||
def test_assert_not_functions():
|
||||
"""Test custom assert functions"""
|
||||
assert_true(True, "one check")
|
||||
assert_false(False, "another check")
|
||||
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
assert_not_reached("yet another one")
|
||||
|
||||
assert "yet another one" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
assert_true(False, "one failing check")
|
||||
|
||||
assert "one failing check" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
assert_false(True, "another failing check")
|
||||
|
||||
assert "another failing check" in str(excinfo.value)
|
||||
@@ -1,46 +0,0 @@
|
||||
"""Test file for drawing"""
|
||||
|
||||
import filecmp
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import draw_graph
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy import NPFHECompiler
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds
|
||||
|
||||
|
||||
def test_draw_graph_with_saving(default_compilation_configuration):
|
||||
"""Tests drawing and saving a graph"""
|
||||
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(7, True))},
|
||||
range(-5, 5),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
compiler = NPFHECompiler(function, {"x": "encrypted"}, default_compilation_configuration)
|
||||
|
||||
assert (got := compiler.draw_graph()) is None, got
|
||||
|
||||
compiler.eval_on_inputset(range(-5, 5))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
output_directory = Path(tmp)
|
||||
output_file = output_directory.joinpath("test.png")
|
||||
draw_graph(op_graph, save_to=output_file)
|
||||
assert output_file.exists()
|
||||
|
||||
output_file_compiler = output_directory.joinpath("test_compiler.png")
|
||||
compiler_output_file = compiler.draw_graph(save_to=output_file_compiler)
|
||||
assert compiler_output_file is not None
|
||||
compiler_output_file = Path(compiler_output_file)
|
||||
assert compiler_output_file == output_file_compiler
|
||||
assert compiler_output_file.exists()
|
||||
|
||||
assert filecmp.cmp(output_file, compiler_output_file)
|
||||
@@ -1,160 +0,0 @@
|
||||
"""Test file for formatting"""
|
||||
|
||||
import numpy
|
||||
|
||||
from concrete.common.data_types.integers import Integer, UnsignedInteger
|
||||
from concrete.common.debugging import format_operation_graph
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy import NPFHECompiler
|
||||
from concrete.numpy.compile import (
|
||||
compile_numpy_function,
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds,
|
||||
)
|
||||
|
||||
|
||||
def test_format_operation_graph_with_multiple_edges(default_compilation_configuration):
|
||||
"""Test format_operation_graph with multiple edges"""
|
||||
|
||||
def function(x):
|
||||
return x + x
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(4, True))},
|
||||
range(0, 10),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
formatted_graph = format_operation_graph(op_graph)
|
||||
assert (
|
||||
formatted_graph
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<uint4>
|
||||
%1 = add(%0, %0) # EncryptedScalar<uint5>
|
||||
return %1
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_format_operation_graph_with_offending_nodes(default_compilation_configuration):
|
||||
"""Test format_operation_graph with offending nodes"""
|
||||
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(7, True))},
|
||||
range(-5, 5),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
highlighted_nodes = {op_graph.input_nodes[0]: ["foo"]}
|
||||
formatted_graph = format_operation_graph(op_graph, highlighted_nodes=highlighted_nodes).strip()
|
||||
assert (
|
||||
formatted_graph
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<int4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
|
||||
%1 = 42 # ClearScalar<uint6>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint6>
|
||||
return %2
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
highlighted_nodes = {op_graph.input_nodes[0]: ["foo"], op_graph.output_nodes[0]: ["bar", "baz"]}
|
||||
formatted_graph = format_operation_graph(op_graph, highlighted_nodes=highlighted_nodes).strip()
|
||||
assert (
|
||||
formatted_graph
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<int4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
|
||||
%1 = 42 # ClearScalar<uint6>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint6>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar
|
||||
baz
|
||||
return %2
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_format_operation_graph_with_fusing(default_compilation_configuration):
|
||||
"""Test format_operation_graph with fusing"""
|
||||
|
||||
def function(x):
|
||||
return (10 * (numpy.cos(x + 1) + 1)).astype(numpy.uint32)
|
||||
|
||||
circuit = compile_numpy_function(
|
||||
function,
|
||||
{
|
||||
"x": EncryptedScalar(UnsignedInteger(3)),
|
||||
},
|
||||
range(2 ** 3),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
assert (got := str(circuit)) == (
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint5>
|
||||
%1 = 1 # ClearScalar<uint6>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint5>
|
||||
%3 = subgraph(%2) # EncryptedScalar<uint5>
|
||||
return %3
|
||||
|
||||
Subgraphs:
|
||||
|
||||
%3 = subgraph(%2):
|
||||
|
||||
%0 = 10 # ClearScalar<uint4>
|
||||
%1 = 1 # ClearScalar<uint1>
|
||||
%2 = float_subgraph_input # EncryptedScalar<uint3>
|
||||
%3 = cos(%2) # EncryptedScalar<float64>
|
||||
%4 = add(%3, %1) # EncryptedScalar<float64>
|
||||
%5 = mul(%4, %0) # EncryptedScalar<float64>
|
||||
%6 = astype(%5, dtype=uint32) # EncryptedScalar<uint5>
|
||||
return %6
|
||||
|
||||
""".strip()
|
||||
), got
|
||||
|
||||
compiler = NPFHECompiler(function, {"x": "encrypted"}, default_compilation_configuration)
|
||||
|
||||
assert (
|
||||
got := str(compiler)
|
||||
) == "__str__ failed: OPGraph is None, NPFHECompiler needs evaluation on an inputset", got
|
||||
|
||||
compiler.eval_on_inputset(range(2 ** 3))
|
||||
|
||||
# String is different here as the type that is first propagated to trace the opgraph is not the
|
||||
# same
|
||||
|
||||
assert (got := str(compiler)) == (
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint3>
|
||||
%1 = 1 # ClearScalar<uint1>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint4>
|
||||
%3 = subgraph(%2) # EncryptedScalar<uint5>
|
||||
return %3
|
||||
|
||||
Subgraphs:
|
||||
|
||||
%3 = subgraph(%2):
|
||||
|
||||
%0 = 10 # ClearScalar<uint4>
|
||||
%1 = 1 # ClearScalar<uint1>
|
||||
%2 = float_subgraph_input # EncryptedScalar<uint1>
|
||||
%3 = cos(%2) # EncryptedScalar<float64>
|
||||
%4 = add(%3, %1) # EncryptedScalar<float64>
|
||||
%5 = mul(%4, %0) # EncryptedScalar<float64>
|
||||
%6 = astype(%5, dtype=uint32) # EncryptedScalar<uint5>
|
||||
return %6
|
||||
""".strip()
|
||||
), got
|
||||
@@ -1,193 +0,0 @@
|
||||
"""Test file for convolution"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from concrete.common.extensions import convolution
|
||||
from concrete.common.representation.intermediate import Conv2D
|
||||
from concrete.common.tracing.base_tracer import BaseTracer
|
||||
from concrete.common.values.tensors import TensorValue
|
||||
from concrete.numpy.tracing import NPConstant, NPTracer
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kwargs, error_msg",
|
||||
[
|
||||
pytest.param(
|
||||
{"x": None, "weight": np.zeros(1)},
|
||||
"input x must be an ndarray, or a BaseTracer, not a",
|
||||
),
|
||||
pytest.param(
|
||||
{"x": np.zeros(1), "weight": None},
|
||||
"weight must be an ndarray, or a BaseTracer, not a",
|
||||
),
|
||||
pytest.param(
|
||||
{"x": np.zeros(1), "weight": np.zeros(1), "bias": 0},
|
||||
"bias must be an ndarray, a BaseTracer, or None, not a",
|
||||
),
|
||||
pytest.param(
|
||||
{"x": np.zeros(1), "weight": np.zeros(1), "strides": None},
|
||||
"strides must be a tuple, or list, not a",
|
||||
),
|
||||
pytest.param(
|
||||
{"x": np.zeros(1), "weight": np.zeros(1), "dilations": None},
|
||||
"dilations must be a tuple, or list, not a",
|
||||
),
|
||||
pytest.param(
|
||||
{"x": np.zeros(1), "weight": np.zeros(1), "pads": None},
|
||||
"padding must be a tuple, or list, not a",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invalid_arg_types(kwargs, error_msg):
|
||||
"""Test function to make sure convolution doesn't accept invalid types"""
|
||||
|
||||
with pytest.raises(TypeError) as err:
|
||||
convolution.conv2d(**kwargs)
|
||||
|
||||
assert error_msg in str(err)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kwargs, error_msg",
|
||||
[
|
||||
pytest.param(
|
||||
{"x": np.zeros(1), "weight": np.zeros(1)},
|
||||
"input x should have size (N x C x H x W), not",
|
||||
),
|
||||
pytest.param(
|
||||
{"x": np.zeros((1, 2, 3, 4)), "weight": np.zeros(1)},
|
||||
"weight should have size (F x C x H x W), not",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"x": np.zeros((1, 2, 3, 4)),
|
||||
"weight": np.zeros((1, 2, 3, 4)),
|
||||
"bias": np.zeros((1, 2)),
|
||||
},
|
||||
"bias should have size (F), not",
|
||||
),
|
||||
pytest.param(
|
||||
{"x": np.zeros(1), "weight": np.zeros(1), "strides": (1,)},
|
||||
"strides should be of the form",
|
||||
),
|
||||
pytest.param(
|
||||
{"x": np.zeros(1), "weight": np.zeros(1), "dilations": (1,)},
|
||||
"dilations should be of the form",
|
||||
),
|
||||
pytest.param(
|
||||
{"x": np.zeros(1), "weight": np.zeros(1), "pads": (1,)},
|
||||
"padding should be of the form",
|
||||
),
|
||||
pytest.param(
|
||||
{"x": np.zeros(1), "weight": np.zeros(1), "auto_pad": None},
|
||||
"invalid auto_pad is specified",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invalid_input_shape(kwargs, error_msg):
|
||||
"""Test function to make sure convolution doesn't accept invalid shapes"""
|
||||
|
||||
with pytest.raises((ValueError, AssertionError)) as err:
|
||||
convolution.conv2d(**kwargs)
|
||||
|
||||
assert error_msg in str(err)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_shape, weight_shape",
|
||||
[
|
||||
pytest.param((1, 1, 4, 4), (1, 1, 2, 2)),
|
||||
pytest.param((3, 1, 4, 4), (1, 1, 2, 2)),
|
||||
pytest.param((1, 1, 4, 4), (3, 1, 2, 2)),
|
||||
pytest.param((1, 3, 4, 4), (1, 3, 2, 2)),
|
||||
pytest.param((4, 3, 4, 4), (3, 3, 2, 2)),
|
||||
pytest.param((4, 3, 16, 16), (3, 3, 2, 2)),
|
||||
pytest.param((4, 3, 16, 16), (3, 3, 3, 3)),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("strides", [(1, 1), (1, 2), (2, 1), (2, 2)])
|
||||
@pytest.mark.parametrize("dilations", [(1, 1), (1, 2), (2, 1), (2, 2)])
|
||||
@pytest.mark.parametrize("has_bias", [True, False])
|
||||
@pytest.mark.parametrize("use_ndarray", [True, False])
|
||||
def test_tracing(input_shape, weight_shape, strides, dilations, has_bias, use_ndarray):
|
||||
"""Test function to make sure tracong of conv2d works properly"""
|
||||
if has_bias:
|
||||
bias = np.random.randint(0, 4, size=(weight_shape[0],))
|
||||
if not use_ndarray:
|
||||
bias = NPTracer([], NPConstant(bias), 0)
|
||||
else:
|
||||
bias = None
|
||||
|
||||
x = NPTracer([], NPConstant(np.random.randint(0, 4, size=input_shape)), 0)
|
||||
weight = np.random.randint(0, 4, size=weight_shape)
|
||||
if not use_ndarray:
|
||||
weight = NPTracer([], NPConstant(weight), 0)
|
||||
|
||||
output_tracer = convolution.conv2d(x, weight, bias, strides=strides, dilations=dilations)
|
||||
traced_computation = output_tracer.traced_computation
|
||||
assert isinstance(traced_computation, Conv2D)
|
||||
|
||||
if has_bias:
|
||||
assert len(output_tracer.inputs) == 3
|
||||
else:
|
||||
assert len(output_tracer.inputs) == 2
|
||||
|
||||
assert all(
|
||||
isinstance(input_, BaseTracer) for input_ in output_tracer.inputs
|
||||
), f"{output_tracer.inputs}"
|
||||
|
||||
assert len(traced_computation.outputs) == 1
|
||||
output_value = traced_computation.outputs[0]
|
||||
assert isinstance(output_value, TensorValue) and output_value.is_encrypted
|
||||
# pylint: disable=no-member
|
||||
expected_shape = torch.conv2d(
|
||||
torch.randn(input_shape),
|
||||
torch.randn(weight_shape),
|
||||
torch.randn((weight_shape[0])),
|
||||
stride=strides,
|
||||
dilation=dilations,
|
||||
).shape
|
||||
# pylint: enable=no-member
|
||||
|
||||
assert output_value.shape == expected_shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_shape, weight_shape",
|
||||
[
|
||||
pytest.param((1, 1, 4, 4), (1, 1, 2, 2)),
|
||||
pytest.param((3, 1, 4, 4), (1, 1, 2, 2)),
|
||||
pytest.param((1, 1, 4, 4), (3, 1, 2, 2)),
|
||||
pytest.param((1, 3, 4, 4), (1, 3, 2, 2)),
|
||||
pytest.param((4, 3, 4, 4), (3, 3, 2, 2)),
|
||||
pytest.param((4, 3, 16, 16), (3, 3, 2, 2)),
|
||||
pytest.param((4, 3, 16, 16), (3, 3, 3, 3)),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("strides", [(1, 1), (1, 2), (2, 1), (2, 2)])
|
||||
@pytest.mark.parametrize("dilations", [(1, 1), (1, 2), (2, 1), (2, 2)])
|
||||
@pytest.mark.parametrize("has_bias", [True, False])
|
||||
def test_evaluation(input_shape, weight_shape, strides, dilations, has_bias):
|
||||
"""Test function to make sure evaluation of conv2d on plain data works properly"""
|
||||
if has_bias:
|
||||
bias = np.random.randint(0, 4, size=(weight_shape[0],))
|
||||
else:
|
||||
bias = np.zeros((weight_shape[0],))
|
||||
x = np.random.randint(0, 4, size=input_shape)
|
||||
weight = np.random.randint(0, 4, size=weight_shape)
|
||||
# pylint: disable=no-member
|
||||
expected = torch.conv2d(
|
||||
torch.tensor(x, dtype=torch.long),
|
||||
torch.tensor(weight, dtype=torch.long),
|
||||
torch.tensor(bias, dtype=torch.long),
|
||||
stride=strides,
|
||||
dilation=dilations,
|
||||
).numpy()
|
||||
# pylint: enable=no-member
|
||||
# conv2d should handle None biases
|
||||
if not has_bias:
|
||||
bias = None
|
||||
result = convolution.conv2d(x, weight, bias, strides=strides, dilations=dilations)
|
||||
assert (result == expected).all()
|
||||
@@ -1,118 +0,0 @@
|
||||
"""Test file for direct multi table lookups"""
|
||||
|
||||
import random
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.extensions.multi_table import MultiLookupTable
|
||||
from concrete.common.extensions.table import LookupTable
|
||||
|
||||
table_2b_to_2b = LookupTable([1, 2, 0, 3])
|
||||
table_2b_to_1b = LookupTable([1, 0, 0, 1])
|
||||
table_2b_to_3b = LookupTable([5, 2, 7, 0])
|
||||
|
||||
table_3b_to_2b = LookupTable([1, 2, 0, 3, 0, 3, 1, 2])
|
||||
table_3b_to_1b = LookupTable([1, 0, 0, 1, 1, 1, 1, 0])
|
||||
table_3b_to_3b = LookupTable([5, 2, 7, 0, 4, 1, 6, 2])
|
||||
|
||||
tables_2b = [table_2b_to_1b, table_2b_to_2b, table_2b_to_3b]
|
||||
tables_3b = [table_3b_to_1b, table_3b_to_2b, table_3b_to_3b]
|
||||
|
||||
|
||||
def test_multi_lookup_table_creation_and_indexing():
|
||||
"""Test function for creating and indexing multi lookup tables"""
|
||||
tables = [
|
||||
[tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]],
|
||||
[tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]],
|
||||
[tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]],
|
||||
]
|
||||
multitable = MultiLookupTable(tables)
|
||||
|
||||
assert multitable.input_shape == (3, 2)
|
||||
|
||||
assert isinstance(multitable.output_dtype, Integer)
|
||||
assert multitable.output_dtype.bit_width <= 3
|
||||
|
||||
index = numpy.random.randint(0, 2 ** 2, size=multitable.input_shape).tolist()
|
||||
result = multitable[index]
|
||||
|
||||
for i in range(3):
|
||||
for j in range(2):
|
||||
assert result[i][j] == multitable.tables[i][j][index[i][j]], f"i={i}, j={j}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tables,match",
|
||||
[
|
||||
pytest.param(
|
||||
[
|
||||
[],
|
||||
[table_2b_to_2b, table_2b_to_3b],
|
||||
],
|
||||
"MultiLookupTable cannot have an empty array within it",
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
[table_2b_to_1b, 42.0],
|
||||
[table_2b_to_2b, table_2b_to_3b],
|
||||
],
|
||||
"MultiLookupTable should have been made out of LookupTables "
|
||||
"but it had an object of type float within it",
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
[table_2b_to_2b],
|
||||
[table_2b_to_2b, table_2b_to_3b],
|
||||
[table_2b_to_2b, table_2b_to_1b],
|
||||
],
|
||||
"MultiLookupTable should have the shape (3, 1) but it does not "
|
||||
"(an array on dimension 1 has the size 2 but its size should have been 1 "
|
||||
"as the expected shape is (3, 1))",
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
[table_2b_to_2b, table_3b_to_3b],
|
||||
[table_2b_to_2b, table_3b_to_1b],
|
||||
],
|
||||
"LookupTables within a MultiLookupTable should have the same size but they do not "
|
||||
"(there was a table with the size of 4 and another with the size of 8)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_multi_lookup_table_creation_failure(tables, match):
|
||||
"""Test function for failing to create multi lookup tables"""
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
MultiLookupTable(tables)
|
||||
|
||||
assert str(excinfo.value) == match
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tables,index,match",
|
||||
[
|
||||
pytest.param(
|
||||
[
|
||||
[table_2b_to_2b, table_2b_to_1b, table_2b_to_3b],
|
||||
[table_2b_to_1b, table_2b_to_2b, table_2b_to_3b],
|
||||
],
|
||||
[
|
||||
[1, 2],
|
||||
[3, 0],
|
||||
],
|
||||
"Multiple Lookup Table of shape (2, 3) cannot be looked up with [[1, 2], [3, 0]] "
|
||||
"(you should check your inputset)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_multi_lookup_table_indexing_failure(tables, index, match):
|
||||
"""Test function for failing to index multi lookup tables"""
|
||||
|
||||
table = MultiLookupTable(tables)
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
table.__getitem__(index)
|
||||
|
||||
assert str(excinfo.value) == match
|
||||
@@ -1,131 +0,0 @@
|
||||
"""Test file for direct table lookups"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import networkx as nx
|
||||
import pytest
|
||||
|
||||
from concrete.common import is_a_power_of_2
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.extensions.table import LookupTable
|
||||
from concrete.common.representation import intermediate as ir
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy import tracing
|
||||
|
||||
|
||||
def test_lookup_table_size_constraints():
|
||||
"""Test function to make sure lookup tables have correct size"""
|
||||
|
||||
table = []
|
||||
|
||||
# creating empty lookup table is not acceptable
|
||||
with pytest.raises(ValueError):
|
||||
LookupTable(table)
|
||||
|
||||
for _ in range(512):
|
||||
table.append(0)
|
||||
|
||||
if is_a_power_of_2(len(table)):
|
||||
# creating lookup table with 2^N entries are acceptable
|
||||
LookupTable(table)
|
||||
else:
|
||||
# creating lookup table with anything other than 2^N entries are not acceptable
|
||||
with pytest.raises(ValueError):
|
||||
LookupTable(table)
|
||||
|
||||
|
||||
def test_lookup_table_encrypted_lookup(test_helpers):
|
||||
"""Test function for tracing with explicit table lookups using encrypted inputs"""
|
||||
|
||||
table = LookupTable([3, 6, 0, 2])
|
||||
|
||||
def f(x):
|
||||
return table[x]
|
||||
|
||||
x = EncryptedScalar(Integer(2, is_signed=False))
|
||||
op_graph = tracing.trace_numpy_function(f, {"x": x})
|
||||
|
||||
table_node = op_graph.output_nodes[0]
|
||||
|
||||
assert table_node.get_table(op_graph.get_ordered_preds(table_node)) == [3, 6, 0, 2]
|
||||
|
||||
ref_graph = nx.MultiDiGraph()
|
||||
# Here is the ASCII drawing of the expected graph:
|
||||
# (x) - (TLU)
|
||||
|
||||
input_x = ir.Input(input_value=x, input_name="x", program_input_idx=0)
|
||||
ref_graph.add_node(input_x)
|
||||
|
||||
generic_function_output_value = deepcopy(x)
|
||||
generic_function_output_value.dtype = table.output_dtype
|
||||
|
||||
# pylint: disable=protected-access
|
||||
# Need access to _checked_indexing to have is_equivalent_to work for ir.GenericFunction
|
||||
output_arbitrary_function = ir.GenericFunction(
|
||||
inputs=[x],
|
||||
arbitrary_func=LookupTable._checked_indexing,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
op_kwargs={"table": deepcopy(table.table)},
|
||||
op_name="TLU",
|
||||
)
|
||||
# pylint: enable=protected-access
|
||||
ref_graph.add_node(output_arbitrary_function)
|
||||
|
||||
ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0, output_idx=0)
|
||||
|
||||
# TODO: discuss if this check is enough as == is not overloaded properly for GenericFunction
|
||||
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
|
||||
|
||||
|
||||
def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
|
||||
"""Test function for tracing with explicit table lookups using encrypted and plain inputs"""
|
||||
|
||||
table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7])
|
||||
|
||||
def f(x):
|
||||
return table[x] + table[0]
|
||||
|
||||
x = EncryptedScalar(Integer(3, is_signed=False))
|
||||
op_graph = tracing.trace_numpy_function(f, {"x": x})
|
||||
|
||||
ref_graph = nx.MultiDiGraph()
|
||||
# Here is the ASCII drawing of the expected graph:
|
||||
# (x) - (TLU)
|
||||
# \
|
||||
# (+)
|
||||
# /
|
||||
# (3)
|
||||
|
||||
input_x = ir.Input(input_value=x, input_name="x", program_input_idx=0)
|
||||
ref_graph.add_node(input_x)
|
||||
|
||||
generic_function_output_value = deepcopy(x)
|
||||
generic_function_output_value.dtype = table.output_dtype
|
||||
|
||||
# pylint: disable=protected-access
|
||||
# Need access to _checked_indexing to have is_equivalent_to work for ir.GenericFunction
|
||||
intermediate_arbitrary_function = ir.GenericFunction(
|
||||
inputs=[x],
|
||||
arbitrary_func=LookupTable._checked_indexing,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
op_kwargs={"table": deepcopy(table.table)},
|
||||
op_name="TLU",
|
||||
)
|
||||
# pylint: enable=protected-access
|
||||
ref_graph.add_node(intermediate_arbitrary_function)
|
||||
|
||||
constant_3 = ir.Constant(3)
|
||||
ref_graph.add_node(constant_3)
|
||||
|
||||
output_add = ir.Add((intermediate_arbitrary_function.outputs[0], constant_3.outputs[0]))
|
||||
ref_graph.add_node(output_add)
|
||||
|
||||
ref_graph.add_edge(input_x, intermediate_arbitrary_function, input_idx=0, output_idx=0)
|
||||
|
||||
ref_graph.add_edge(intermediate_arbitrary_function, output_add, input_idx=0, output_idx=0)
|
||||
ref_graph.add_edge(constant_3, output_add, input_idx=1, output_idx=0)
|
||||
|
||||
# TODO: discuss if this check is enough as == is not overloaded properly for GenericFunction
|
||||
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
|
||||
@@ -1,24 +0,0 @@
|
||||
"""Test file for common python helpers"""
|
||||
|
||||
from concrete.common.helpers.python_helpers import catch
|
||||
|
||||
|
||||
def test_catch_failure():
|
||||
"""Test case for when the function called with catch raises an exception."""
|
||||
|
||||
def f_fail():
|
||||
return 1 / 0
|
||||
|
||||
assert catch(f_fail) is None
|
||||
|
||||
|
||||
def test_catch():
|
||||
"""Test case for catch"""
|
||||
|
||||
def f(*args, **kwargs):
|
||||
return *args, dict(**kwargs)
|
||||
|
||||
assert catch(f, (1, 2, 3,), **{"one": 1, "two": 2, "three": 3}) == (
|
||||
(1, 2, 3),
|
||||
{"one": 1, "two": 2, "three": 3},
|
||||
)
|
||||
@@ -1,87 +0,0 @@
|
||||
"""Test file for MLIR conversion helpers."""
|
||||
|
||||
# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
|
||||
|
||||
# pylint: disable=no-name-in-module
|
||||
|
||||
import concrete.lang as concretelang
|
||||
import pytest
|
||||
from mlir.ir import Context, Location
|
||||
|
||||
from concrete.common.data_types import Float, SignedInteger, UnsignedInteger
|
||||
from concrete.common.mlir.conversion_helpers import integer_to_mlir_type, value_to_mlir_type
|
||||
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
|
||||
# pylint: enable=no-name-in-module
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"integer,is_encrypted,expected_mlir_type_str",
|
||||
[
|
||||
pytest.param(SignedInteger(5), False, "i5"),
|
||||
pytest.param(UnsignedInteger(5), False, "i5"),
|
||||
pytest.param(SignedInteger(32), False, "i32"),
|
||||
pytest.param(UnsignedInteger(32), False, "i32"),
|
||||
pytest.param(SignedInteger(5), True, "!FHE.eint<5>"),
|
||||
pytest.param(UnsignedInteger(5), True, "!FHE.eint<5>"),
|
||||
],
|
||||
)
|
||||
def test_integer_to_mlir_type(integer, is_encrypted, expected_mlir_type_str):
|
||||
"""Test function for integer to MLIR type conversion."""
|
||||
|
||||
with Context() as ctx, Location.unknown():
|
||||
concretelang.register_dialects(ctx)
|
||||
assert str(integer_to_mlir_type(ctx, integer, is_encrypted)) == expected_mlir_type_str
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value,expected_mlir_type_str",
|
||||
[
|
||||
pytest.param(ClearScalar(SignedInteger(5)), "i5"),
|
||||
pytest.param(ClearTensor(SignedInteger(5), shape=(2, 3)), "tensor<2x3xi5>"),
|
||||
pytest.param(EncryptedScalar(SignedInteger(5)), "!FHE.eint<5>"),
|
||||
pytest.param(EncryptedTensor(SignedInteger(5), shape=(2, 3)), "tensor<2x3x!FHE.eint<5>>"),
|
||||
pytest.param(ClearScalar(UnsignedInteger(5)), "i5"),
|
||||
pytest.param(ClearTensor(UnsignedInteger(5), shape=(2, 3)), "tensor<2x3xi5>"),
|
||||
pytest.param(EncryptedScalar(UnsignedInteger(5)), "!FHE.eint<5>"),
|
||||
pytest.param(EncryptedTensor(UnsignedInteger(5), shape=(2, 3)), "tensor<2x3x!FHE.eint<5>>"),
|
||||
],
|
||||
)
|
||||
def test_value_to_mlir_type(value, expected_mlir_type_str):
|
||||
"""Test function for value to MLIR type conversion."""
|
||||
|
||||
with Context() as ctx, Location.unknown():
|
||||
concretelang.register_dialects(ctx)
|
||||
assert str(value_to_mlir_type(ctx, value)) == expected_mlir_type_str
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value,expected_error_message",
|
||||
[
|
||||
pytest.param(
|
||||
ClearScalar(Float(32)),
|
||||
"ClearScalar<float32> is not supported for MLIR conversion",
|
||||
),
|
||||
pytest.param(
|
||||
ClearTensor(Float(32), shape=(2, 3)),
|
||||
"ClearTensor<float32, shape=(2, 3)> is not supported for MLIR conversion",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedScalar(Float(32)),
|
||||
"EncryptedScalar<float32> is not supported for MLIR conversion",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(Float(32), shape=(2, 3)),
|
||||
"EncryptedTensor<float32, shape=(2, 3)> is not supported for MLIR conversion",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_fail_value_to_mlir_type(value, expected_error_message):
|
||||
"""Test function for failed value to MLIR type conversion."""
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
with Context() as ctx, Location.unknown():
|
||||
concretelang.register_dialects(ctx)
|
||||
value_to_mlir_type(ctx, value)
|
||||
|
||||
assert str(excinfo.value) == expected_error_message
|
||||
@@ -1,107 +0,0 @@
|
||||
"""Test file for intermediate node to MLIR converter."""
|
||||
|
||||
import random
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types import UnsignedInteger
|
||||
from concrete.common.values import EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import compile_numpy_function
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_compile,parameters,inputset,expected_error_type,expected_error_message",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
{
|
||||
"x": EncryptedScalar(UnsignedInteger(3)),
|
||||
"y": EncryptedScalar(UnsignedInteger(3)),
|
||||
},
|
||||
[(random.randint(0, 7), random.randint(0, 7)) for _ in range(10)] + [(7, 7)],
|
||||
NotImplementedError,
|
||||
"Multiplication "
|
||||
"between "
|
||||
"EncryptedScalar<uint6> "
|
||||
"and "
|
||||
"EncryptedScalar<uint6> "
|
||||
"cannot be converted to MLIR yet",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x - y,
|
||||
{
|
||||
"x": EncryptedScalar(UnsignedInteger(3)),
|
||||
"y": EncryptedScalar(UnsignedInteger(3)),
|
||||
},
|
||||
[(random.randint(5, 7), random.randint(0, 5)) for _ in range(10)],
|
||||
NotImplementedError,
|
||||
"Subtraction "
|
||||
"of "
|
||||
"EncryptedScalar<uint3> "
|
||||
"from "
|
||||
"EncryptedScalar<uint3> "
|
||||
"cannot be converted to MLIR yet",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: numpy.dot(x, y),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(2,)),
|
||||
"y": EncryptedTensor(UnsignedInteger(3), shape=(2,)),
|
||||
},
|
||||
[
|
||||
(
|
||||
numpy.random.randint(0, 2 ** 3, size=(2,)),
|
||||
numpy.random.randint(0, 2 ** 3, size=(2,)),
|
||||
)
|
||||
for _ in range(10)
|
||||
]
|
||||
+ [(numpy.array([7, 7]), numpy.array([7, 7]))],
|
||||
NotImplementedError,
|
||||
"Dot product "
|
||||
"between "
|
||||
"EncryptedTensor<uint7, shape=(2,)> "
|
||||
"and "
|
||||
"EncryptedTensor<uint7, shape=(2,)> "
|
||||
"cannot be converted to MLIR yet",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x @ y,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
"y": EncryptedTensor(UnsignedInteger(3), shape=(2, 1)),
|
||||
},
|
||||
[
|
||||
(
|
||||
numpy.random.randint(0, 2 ** 3, size=(3, 2)),
|
||||
numpy.random.randint(0, 2 ** 3, size=(2, 1)),
|
||||
)
|
||||
for i in range(10)
|
||||
]
|
||||
+ [(numpy.array([[7, 7], [7, 7], [7, 7]]), numpy.array([[7], [7]]))],
|
||||
NotImplementedError,
|
||||
"Matrix multiplication "
|
||||
"between "
|
||||
"EncryptedTensor<uint7, shape=(3, 2)> "
|
||||
"and "
|
||||
"EncryptedTensor<uint7, shape=(2, 1)> "
|
||||
"cannot be converted to MLIR yet",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_fail_node_conversion(
|
||||
function_to_compile,
|
||||
parameters,
|
||||
inputset,
|
||||
expected_error_type,
|
||||
expected_error_message,
|
||||
default_compilation_configuration,
|
||||
):
|
||||
"""Test function for failed intermediate node conversion."""
|
||||
|
||||
with pytest.raises(expected_error_type) as excinfo:
|
||||
compile_numpy_function(
|
||||
function_to_compile, parameters, inputset, default_compilation_configuration
|
||||
)
|
||||
|
||||
assert str(excinfo.value) == expected_error_message
|
||||
@@ -1,735 +0,0 @@
|
||||
"""Test file for float subgraph fusing"""
|
||||
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from inspect import signature
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import format_operation_graph
|
||||
from concrete.common.debugging.custom_assert import assert_not_reached
|
||||
from concrete.common.optimization.topological import fuse_float_operations
|
||||
from concrete.common.values import EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
from concrete.numpy.tracing import trace_numpy_function
|
||||
|
||||
|
||||
def no_fuse(x):
|
||||
"""No fuse"""
|
||||
return x + 2
|
||||
|
||||
|
||||
def no_fuse_unhandled(x, y):
|
||||
"""No fuse unhandled"""
|
||||
x_1 = x + 0.7
|
||||
y_1 = y + 1.3
|
||||
intermediate = x_1 + y_1
|
||||
return intermediate.astype(numpy.int32)
|
||||
|
||||
|
||||
def fusable_with_bigger_search(x, y):
|
||||
"""fusable with bigger search"""
|
||||
x = x + 1
|
||||
x_1 = x.astype(numpy.int32)
|
||||
x_1 = x_1 + 1.5
|
||||
x_2 = x.astype(numpy.int32)
|
||||
x_2 = x_2 + 3.4
|
||||
add = x_1 + x_2
|
||||
add_int = add.astype(numpy.int32)
|
||||
return add_int + y
|
||||
|
||||
|
||||
def fusable_with_bigger_search_needs_second_iteration(x, y):
|
||||
"""fusable with bigger search and triggers a second iteration in the fusing"""
|
||||
x = x + 1
|
||||
x = x + 0.5
|
||||
x = numpy.cos(x)
|
||||
x_1 = x.astype(numpy.int32)
|
||||
x_1 = x_1 + 1.5
|
||||
x_p = x + 1
|
||||
x_p2 = x_p + 1
|
||||
x_2 = (x_p + x_p2).astype(numpy.int32)
|
||||
x_2 = x_2 + 3.4
|
||||
add = x_1 + x_2
|
||||
add_int = add.astype(numpy.int32)
|
||||
return add_int + y
|
||||
|
||||
|
||||
def no_fuse_big_constant_3_10_10(x):
|
||||
"""Pass an array x with size < 100 to trigger a no fuse condition."""
|
||||
x = x.astype(numpy.float64)
|
||||
return (x + numpy.ones((3, 10, 10))).astype(numpy.int32)
|
||||
|
||||
|
||||
def no_fuse_dot(x):
|
||||
"""No fuse dot"""
|
||||
return numpy.dot(x, numpy.full((10,), 1.33, dtype=numpy.float64)).astype(numpy.int32)
|
||||
|
||||
|
||||
def simple_create_fuse_opportunity(f, x):
|
||||
"""No fuse because the function is explicitely marked as unfusable in our code."""
|
||||
return f(x.astype(numpy.float64)).astype(numpy.int32)
|
||||
|
||||
|
||||
def ravel_cases(x):
|
||||
"""Simple ravel cases"""
|
||||
return simple_create_fuse_opportunity(numpy.ravel, x)
|
||||
|
||||
|
||||
def transpose_cases(x):
|
||||
"""Simple transpose cases"""
|
||||
return simple_create_fuse_opportunity(numpy.transpose, x)
|
||||
|
||||
|
||||
def reshape_cases(x, newshape):
|
||||
"""Simple reshape cases"""
|
||||
return simple_create_fuse_opportunity(lambda x: numpy.reshape(x, newshape), x)
|
||||
|
||||
|
||||
def simple_fuse_not_output(x):
|
||||
"""Simple fuse not output"""
|
||||
intermediate = x.astype(numpy.float64)
|
||||
intermediate = intermediate.astype(numpy.int32)
|
||||
return intermediate + 2
|
||||
|
||||
|
||||
def simple_fuse_output(x):
|
||||
"""Simple fuse output"""
|
||||
return x.astype(numpy.float64).astype(numpy.int32)
|
||||
|
||||
|
||||
def mix_x_and_y_intricately_and_call_f(function, x, y):
|
||||
"""Mix x and y in an intricated way, that can't be simplified by
|
||||
an optimizer eg, and then call function
|
||||
"""
|
||||
intermediate = x + y
|
||||
intermediate = intermediate + 2
|
||||
intermediate = intermediate.astype(numpy.float32)
|
||||
intermediate = intermediate.astype(numpy.int32)
|
||||
x_p_1 = intermediate + 1.5
|
||||
x_p_2 = intermediate + 2.7
|
||||
x_p_3 = function(x_p_1 + x_p_2)
|
||||
return (
|
||||
x_p_3.astype(numpy.int32),
|
||||
x_p_2.astype(numpy.int32),
|
||||
(x_p_2 + 3).astype(numpy.int32),
|
||||
x_p_3.astype(numpy.int32) + 67,
|
||||
y,
|
||||
(y + 4.7).astype(numpy.int32) + 3,
|
||||
)
|
||||
|
||||
|
||||
def mix_x_and_y_and_call_f(function, x, y):
|
||||
"""Mix x and y and then call function"""
|
||||
x_p_1 = x + 0.1
|
||||
x_p_2 = x + 0.2
|
||||
x_p_3 = function(x_p_1 + x_p_2)
|
||||
return (
|
||||
x_p_3.astype(numpy.int32),
|
||||
x_p_2.astype(numpy.int32),
|
||||
(x_p_2 + 3).astype(numpy.int32),
|
||||
x_p_3.astype(numpy.int32) + 67,
|
||||
y,
|
||||
(y + 4.7).astype(numpy.int32) + 3,
|
||||
)
|
||||
|
||||
|
||||
def mix_x_and_y_into_range_0_to_1_and_call_f(function, x, y):
|
||||
"""Mix x and y and then call function, in such a way that the input to function is between
|
||||
0 and 1"""
|
||||
x_p_1 = x + 0.1
|
||||
x_p_2 = x + 0.2
|
||||
x_p_4 = 1 - numpy.abs(numpy.sin(x_p_1 + x_p_2 + 0.3))
|
||||
x_p_3 = function(x_p_4)
|
||||
return (
|
||||
x_p_3.astype(numpy.int32),
|
||||
x_p_2.astype(numpy.int32),
|
||||
(x_p_2 + 3).astype(numpy.int32),
|
||||
x_p_3.astype(numpy.int32) + 67,
|
||||
y,
|
||||
(y + 4.7).astype(numpy.int32) + 3,
|
||||
)
|
||||
|
||||
|
||||
def mix_x_and_y_into_integer_and_call_f(function, x, y):
|
||||
"""Mix x and y but keep the entry to function as an integer"""
|
||||
x_p_1 = x + 1
|
||||
x_p_2 = x + 2
|
||||
x_p_3 = function(x_p_1 + x_p_2)
|
||||
return (
|
||||
x_p_3.astype(numpy.int32),
|
||||
x_p_2.astype(numpy.int32),
|
||||
(x_p_2 + 3).astype(numpy.int32),
|
||||
x_p_3.astype(numpy.int32) + 67,
|
||||
y,
|
||||
(y + 4.7).astype(numpy.int32) + 3,
|
||||
)
|
||||
|
||||
|
||||
def get_func_params_int32(func, scalar=True):
|
||||
"""Returns a dict with parameters as scalar int32"""
|
||||
|
||||
return {
|
||||
param_name: EncryptedScalar(Integer(32, True))
|
||||
if scalar
|
||||
else EncryptedTensor(Integer(32, True), (1,))
|
||||
for param_name in signature(func).parameters.keys()
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,fused,params,warning_message",
|
||||
[
|
||||
pytest.param(no_fuse, False, get_func_params_int32(no_fuse), "", id="no_fuse"),
|
||||
pytest.param(
|
||||
no_fuse_unhandled,
|
||||
False,
|
||||
get_func_params_int32(no_fuse_unhandled),
|
||||
"""
|
||||
|
||||
The following subgraph is not fusable:
|
||||
|
||||
%0 = x # EncryptedScalar<int32>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
|
||||
%1 = 0.7 # ClearScalar<float64>
|
||||
%2 = y # EncryptedScalar<int32>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
|
||||
%3 = 1.3 # ClearScalar<float64>
|
||||
%4 = add(%0, %1) # EncryptedScalar<float64>
|
||||
%5 = add(%2, %3) # EncryptedScalar<float64>
|
||||
%6 = add(%4, %5) # EncryptedScalar<float64>
|
||||
%7 = astype(%6, dtype=int32) # EncryptedScalar<int32>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot fuse here as the subgraph has 2 variable inputs
|
||||
return %7
|
||||
|
||||
""".strip(), # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_unhandled",
|
||||
),
|
||||
pytest.param(
|
||||
fusable_with_bigger_search,
|
||||
True,
|
||||
get_func_params_int32(fusable_with_bigger_search),
|
||||
None,
|
||||
id="fusable_with_bigger_search",
|
||||
),
|
||||
pytest.param(
|
||||
fusable_with_bigger_search_needs_second_iteration,
|
||||
True,
|
||||
get_func_params_int32(fusable_with_bigger_search_needs_second_iteration),
|
||||
None,
|
||||
id="fusable_with_bigger_search",
|
||||
),
|
||||
pytest.param(
|
||||
no_fuse_dot,
|
||||
False,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10,))},
|
||||
"""
|
||||
|
||||
The following subgraph is not fusable:
|
||||
|
||||
%0 = x # EncryptedTensor<int32, shape=(10,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10,)
|
||||
%1 = [1.33 1.33 ... 1.33 1.33] # ClearTensor<float64, shape=(10,)>
|
||||
%2 = dot(%0, %1) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,)
|
||||
%3 = astype(%2, dtype=int32) # EncryptedScalar<int32>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,)
|
||||
return %3
|
||||
|
||||
""".strip(), # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_dot",
|
||||
),
|
||||
pytest.param(
|
||||
ravel_cases,
|
||||
False,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
|
||||
"""
|
||||
|
||||
The following subgraph is not fusable:
|
||||
|
||||
%0 = x # EncryptedTensor<int32, shape=(10, 20)>
|
||||
%1 = astype(%0, dtype=float64) # EncryptedTensor<float64, shape=(10, 20)>
|
||||
%2 = ravel(%1) # EncryptedTensor<float64, shape=(200,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
|
||||
%3 = astype(%2, dtype=int32) # EncryptedTensor<int32, shape=(200,)>
|
||||
return %3
|
||||
|
||||
""".strip(), # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_explicitely_ravel",
|
||||
),
|
||||
pytest.param(
|
||||
transpose_cases,
|
||||
False,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
|
||||
"""
|
||||
|
||||
The following subgraph is not fusable:
|
||||
|
||||
%0 = x # EncryptedTensor<int32, shape=(10, 20)>
|
||||
%1 = astype(%0, dtype=float64) # EncryptedTensor<float64, shape=(10, 20)>
|
||||
%2 = transpose(%1) # EncryptedTensor<float64, shape=(20, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
|
||||
%3 = astype(%2, dtype=int32) # EncryptedTensor<int32, shape=(20, 10)>
|
||||
return %3
|
||||
|
||||
""".strip(), # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_explicitely_transpose",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: reshape_cases(x, (20, 10)),
|
||||
False,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
|
||||
"""
|
||||
|
||||
The following subgraph is not fusable:
|
||||
|
||||
%0 = x # EncryptedTensor<int32, shape=(10, 20)>
|
||||
%1 = astype(%0, dtype=float64) # EncryptedTensor<float64, shape=(10, 20)>
|
||||
%2 = reshape(%1, newshape=(20, 10)) # EncryptedTensor<float64, shape=(20, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
|
||||
%3 = astype(%2, dtype=int32) # EncryptedTensor<int32, shape=(20, 10)>
|
||||
return %3
|
||||
|
||||
""".strip(), # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_explicitely_reshape",
|
||||
),
|
||||
pytest.param(
|
||||
no_fuse_big_constant_3_10_10,
|
||||
False,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10, 10))},
|
||||
"""
|
||||
|
||||
The following subgraph is not fusable:
|
||||
|
||||
%0 = [[[1. 1. 1 ... . 1. 1.]]] # ClearTensor<float64, shape=(3, 10, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this constant node has a bigger shape (3, 10, 10) than the subgraph's input: (10, 10)
|
||||
%1 = x # EncryptedTensor<int32, shape=(10, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10, 10)
|
||||
%2 = astype(%1, dtype=float64) # EncryptedTensor<float64, shape=(10, 10)>
|
||||
%3 = add(%2, %0) # EncryptedTensor<float64, shape=(3, 10, 10)>
|
||||
%4 = astype(%3, dtype=int32) # EncryptedTensor<int32, shape=(3, 10, 10)>
|
||||
return %4
|
||||
|
||||
""".strip(), # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_big_constant_3_10_10",
|
||||
),
|
||||
pytest.param(
|
||||
simple_fuse_not_output,
|
||||
True,
|
||||
get_func_params_int32(simple_fuse_not_output),
|
||||
None,
|
||||
id="simple_fuse_not_output",
|
||||
),
|
||||
pytest.param(
|
||||
simple_fuse_output,
|
||||
True,
|
||||
get_func_params_int32(simple_fuse_output),
|
||||
None,
|
||||
id="simple_fuse_output",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: mix_x_and_y_intricately_and_call_f(numpy.rint, x, y),
|
||||
True,
|
||||
get_func_params_int32(lambda x, y: None),
|
||||
None,
|
||||
id="mix_x_and_y_intricately_and_call_f_with_rint",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: mix_x_and_y_and_call_f(numpy.rint, x, y),
|
||||
True,
|
||||
get_func_params_int32(lambda x, y: None),
|
||||
None,
|
||||
id="mix_x_and_y_and_call_f_with_rint",
|
||||
),
|
||||
pytest.param(
|
||||
transpose_cases,
|
||||
True,
|
||||
get_func_params_int32(transpose_cases),
|
||||
None,
|
||||
id="transpose_cases scalar",
|
||||
),
|
||||
pytest.param(
|
||||
transpose_cases,
|
||||
True,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10,))},
|
||||
None,
|
||||
id="transpose_cases ndim == 1",
|
||||
),
|
||||
pytest.param(
|
||||
ravel_cases,
|
||||
True,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10,))},
|
||||
None,
|
||||
id="ravel_cases ndim == 1",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: reshape_cases(x, (10, 20)),
|
||||
True,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
|
||||
None,
|
||||
id="reshape_cases same shape",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_fuse_float_operations(
|
||||
function_to_trace,
|
||||
fused,
|
||||
params,
|
||||
warning_message,
|
||||
capfd,
|
||||
remove_color_codes,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test function for fuse_float_operations"""
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
function_to_trace,
|
||||
params,
|
||||
)
|
||||
copied_graph = deepcopy(op_graph)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
fused_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(copied_graph)
|
||||
|
||||
# Check determinism
|
||||
assert format_operation_graph(copied_graph) == format_operation_graph(op_graph)
|
||||
|
||||
if fused:
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
else:
|
||||
assert fused_num_nodes == orig_num_nodes
|
||||
captured = capfd.readouterr()
|
||||
assert warning_message in (output := remove_color_codes(captured.err)), output
|
||||
|
||||
for input_ in [0, 2, 42, 44]:
|
||||
inputs = ()
|
||||
for param_input_value in params.values():
|
||||
if param_input_value.is_scalar:
|
||||
input_ = numpy.int32(input_)
|
||||
else:
|
||||
input_ = numpy.full(param_input_value.shape, input_, dtype=numpy.int32)
|
||||
inputs += (input_,)
|
||||
|
||||
check_array_equality(function_to_trace(*inputs), op_graph(*inputs))
|
||||
|
||||
|
||||
def subtest_tensor_no_fuse(fun, tensor_shape):
|
||||
"""Test case to verify float fusing is only applied on functions on scalars."""
|
||||
|
||||
if tensor_shape == ():
|
||||
# We want tensors
|
||||
return
|
||||
|
||||
if fun in LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES:
|
||||
# We need at least one input of the bivariate function to be float
|
||||
return
|
||||
|
||||
# Float fusing currently cannot work if the constant in a bivariate operator is bigger than the
|
||||
# variable input.
|
||||
# Make a broadcastable shape but with the constant being bigger
|
||||
variable_tensor_shape = (1,) + tensor_shape
|
||||
constant_bigger_shape = (random.randint(2, 10),) + tensor_shape
|
||||
|
||||
def tensor_no_fuse(x):
|
||||
intermediate = x.astype(numpy.float64)
|
||||
intermediate = fun(intermediate, numpy.ones(constant_bigger_shape))
|
||||
return intermediate.astype(numpy.int32)
|
||||
|
||||
function_to_trace = tensor_no_fuse
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
function_to_trace,
|
||||
{
|
||||
param_name: EncryptedTensor(Integer(32, True), shape=variable_tensor_shape)
|
||||
for param_name in params_names
|
||||
},
|
||||
)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
fused_num_nodes = len(op_graph.graph)
|
||||
|
||||
assert orig_num_nodes == fused_num_nodes
|
||||
|
||||
|
||||
def check_results_are_equal(function_result, op_graph_result):
|
||||
"""Check the output of function execution and OPGraph evaluation are equal."""
|
||||
|
||||
if isinstance(function_result, tuple) and isinstance(op_graph_result, tuple):
|
||||
assert len(function_result) == len(op_graph_result)
|
||||
are_equal = (
|
||||
function_output == op_graph_output
|
||||
for function_output, op_graph_output in zip(function_result, op_graph_result)
|
||||
)
|
||||
elif not isinstance(function_result, tuple) and not isinstance(op_graph_result, tuple):
|
||||
are_equal = (function_result == op_graph_result,)
|
||||
else:
|
||||
assert_not_reached(f"Incompatible outputs: {function_result}, {op_graph_result}")
|
||||
|
||||
return all(value.all() if isinstance(value, numpy.ndarray) else value for value in are_equal)
|
||||
|
||||
|
||||
def subtest_fuse_float_unary_operations_correctness(fun, tensor_shape):
|
||||
"""Test a unary function with fuse_float_operations."""
|
||||
|
||||
# Some manipulation to avoid issues with domain of definitions of functions
|
||||
if fun == numpy.arccosh:
|
||||
# 0 is not in the domain of definition
|
||||
input_list = [1, 2, 42, 44]
|
||||
super_fun_list = [mix_x_and_y_and_call_f]
|
||||
elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]:
|
||||
# Needs values between 0 and 1 in the call function
|
||||
input_list = [0, 2, 42, 44]
|
||||
super_fun_list = [mix_x_and_y_into_range_0_to_1_and_call_f]
|
||||
elif fun in [numpy.cosh, numpy.sinh, numpy.exp, numpy.exp2, numpy.expm1]:
|
||||
# Not too large values to avoid overflows
|
||||
input_list = [1, 2, 5, 11]
|
||||
super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f]
|
||||
else:
|
||||
# Regular case
|
||||
input_list = [0, 2, 42, 44]
|
||||
super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f]
|
||||
|
||||
for super_fun in super_fun_list:
|
||||
|
||||
for input_ in input_list:
|
||||
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: super_fun(fun, x, y)
|
||||
|
||||
function_to_trace = get_function_to_trace()
|
||||
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
function_to_trace,
|
||||
{
|
||||
param_name: EncryptedTensor(Integer(32, True), tensor_shape)
|
||||
for param_name in params_names
|
||||
},
|
||||
)
|
||||
copied_graph = deepcopy(op_graph)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
fused_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(copied_graph)
|
||||
|
||||
# Check determinism
|
||||
assert format_operation_graph(copied_graph) == format_operation_graph(op_graph)
|
||||
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
|
||||
# Check that the call to the function or to the op_graph evaluation give the same
|
||||
# result
|
||||
tensor_diversifier = (
|
||||
# The following +1 in the range is to avoid to have 0's which is not in the
|
||||
# domain definition of some of our functions
|
||||
numpy.arange(1, numpy.product(tensor_shape) + 1, dtype=numpy.int32).reshape(
|
||||
tensor_shape
|
||||
)
|
||||
if tensor_shape != ()
|
||||
else 1
|
||||
)
|
||||
|
||||
if fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]:
|
||||
# Domain of definition for these functions
|
||||
tensor_diversifier = (
|
||||
numpy.ones(tensor_shape, dtype=numpy.int32) if tensor_shape != () else 1
|
||||
)
|
||||
|
||||
input_ = numpy.int32(input_ * tensor_diversifier)
|
||||
|
||||
num_params = len(params_names)
|
||||
assert num_params == 2
|
||||
|
||||
# Create inputs which are either of the form [x, x] or [x, y]
|
||||
for j in range(4):
|
||||
|
||||
if fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan] and j > 0:
|
||||
# Domain of definition for these functions
|
||||
break
|
||||
|
||||
input_a = input_
|
||||
input_b = input_ + j
|
||||
|
||||
if tensor_shape != ():
|
||||
numpy.random.shuffle(input_a)
|
||||
numpy.random.shuffle(input_b)
|
||||
|
||||
inputs = (input_a, input_b) if random.randint(0, 1) == 0 else (input_b, input_a)
|
||||
|
||||
function_result = function_to_trace(*inputs)
|
||||
op_graph_result = op_graph(*inputs)
|
||||
|
||||
assert check_results_are_equal(function_result, op_graph_result)
|
||||
|
||||
|
||||
LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
|
||||
numpy.bitwise_and,
|
||||
numpy.bitwise_or,
|
||||
numpy.bitwise_xor,
|
||||
numpy.gcd,
|
||||
numpy.lcm,
|
||||
numpy.ldexp,
|
||||
numpy.left_shift,
|
||||
numpy.logical_and,
|
||||
numpy.logical_not,
|
||||
numpy.logical_or,
|
||||
numpy.logical_xor,
|
||||
numpy.remainder,
|
||||
numpy.right_shift,
|
||||
}
|
||||
|
||||
|
||||
def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape):
|
||||
"""Test a binary functions with fuse_float_operations, with a constant as a source."""
|
||||
|
||||
for i in range(4):
|
||||
|
||||
# Know if the function is defined for integer inputs
|
||||
if fun in LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES:
|
||||
if i not in [0, 2]:
|
||||
continue
|
||||
|
||||
# The .astype(numpy.float64) that we have in cases 0 and 2 is here to force
|
||||
# a float output even for functions which return an integer (eg, XOR), such
|
||||
# that our frontend always try to fuse them
|
||||
|
||||
# The .astype(numpy.float64) that we have in cases 1 and 3 is here to force
|
||||
# a float output even for functions which return a bool (eg, EQUAL), such
|
||||
# that our frontend always try to fuse them
|
||||
|
||||
# For bivariate functions: fix one of the inputs
|
||||
if i == 0:
|
||||
# With an integer in first position
|
||||
ones_0 = numpy.ones(tensor_shape, dtype=numpy.int32) if tensor_shape != () else 1
|
||||
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(3 * ones_0, x + y).astype(numpy.float64).astype(numpy.int32)
|
||||
|
||||
elif i == 1:
|
||||
# With a float in first position
|
||||
ones_1 = numpy.ones(tensor_shape, dtype=numpy.float64) if tensor_shape != () else 1
|
||||
|
||||
def get_function_to_trace():
|
||||
return (
|
||||
lambda x, y: fun(2.3 * ones_1, x + y).astype(numpy.float64).astype(numpy.int32)
|
||||
)
|
||||
|
||||
elif i == 2:
|
||||
# With an integer in second position
|
||||
ones_2 = numpy.ones(tensor_shape, dtype=numpy.int32) if tensor_shape != () else 1
|
||||
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(x + y, 4 * ones_2).astype(numpy.float64).astype(numpy.int32)
|
||||
|
||||
else:
|
||||
# With a float in second position
|
||||
ones_else = numpy.ones(tensor_shape, dtype=numpy.float64) if tensor_shape != () else 1
|
||||
|
||||
def get_function_to_trace():
|
||||
return (
|
||||
lambda x, y: fun(x + y, 5.7 * ones_else)
|
||||
.astype(numpy.float64)
|
||||
.astype(numpy.int32)
|
||||
)
|
||||
|
||||
input_list = [0, 2, 42, 44]
|
||||
|
||||
# Domain of definition
|
||||
if fun in [numpy.true_divide, numpy.remainder, numpy.floor_divide, numpy.fmod]:
|
||||
input_list = [2, 42, 44]
|
||||
|
||||
for input_ in input_list:
|
||||
function_to_trace = get_function_to_trace()
|
||||
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
function_to_trace,
|
||||
{
|
||||
param_name: EncryptedTensor(Integer(32, True), tensor_shape)
|
||||
for param_name in params_names
|
||||
},
|
||||
)
|
||||
copied_graph = deepcopy(op_graph)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
fused_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(copied_graph)
|
||||
|
||||
# Check determinism
|
||||
assert format_operation_graph(copied_graph) == format_operation_graph(op_graph)
|
||||
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
|
||||
# Check that the call to the function or to the op_graph evaluation give the same
|
||||
# result
|
||||
tensor_diversifier = (
|
||||
# The following +1 in the range is to avoid to have 0's which is not in the
|
||||
# domain definition of some of our functions
|
||||
numpy.arange(1, numpy.product(tensor_shape) + 1, dtype=numpy.int32).reshape(
|
||||
tensor_shape
|
||||
)
|
||||
if tensor_shape != ()
|
||||
else numpy.int64(1)
|
||||
)
|
||||
# Make sure the tensor diversifier is a numpy variable, otherwise some cases may fail
|
||||
# as python int and float don't have the astype method
|
||||
input_ = input_ * tensor_diversifier
|
||||
|
||||
num_params = len(params_names)
|
||||
assert num_params == 2
|
||||
|
||||
# Create inputs which are either of the form [x, x] or [x, y]
|
||||
for j in range(4):
|
||||
inputs = (input_, input_ + j)
|
||||
|
||||
function_result = function_to_trace(*inputs)
|
||||
op_graph_result = op_graph(*inputs)
|
||||
|
||||
assert check_results_are_equal(function_result, op_graph_result)
|
||||
|
||||
|
||||
def subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_shape):
|
||||
"""Test a binary function with fuse_float_operations, with no constant as
|
||||
a source."""
|
||||
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(x, y).astype(numpy.int32)
|
||||
|
||||
function_to_trace = get_function_to_trace()
|
||||
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match=r"Can only have 1 non constant predecessor in _np_operator, got 2 for operator",
|
||||
):
|
||||
trace_numpy_function(
|
||||
function_to_trace,
|
||||
{
|
||||
param_name: EncryptedTensor(Integer(32, True), tensor_shape)
|
||||
for param_name in params_names
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fun", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC)
|
||||
@pytest.mark.parametrize(
|
||||
"tensor_shape", [pytest.param((), id="scalar"), pytest.param((3, 1, 2), id="tensor")]
|
||||
)
|
||||
def test_ufunc_operations(fun, tensor_shape):
|
||||
"""Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
|
||||
|
||||
if fun.nin == 1:
|
||||
subtest_fuse_float_unary_operations_correctness(fun, tensor_shape)
|
||||
elif fun.nin == 2:
|
||||
subtest_fuse_float_binary_operations_correctness(fun, tensor_shape)
|
||||
subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_shape)
|
||||
subtest_tensor_no_fuse(fun, tensor_shape)
|
||||
else:
|
||||
raise NotImplementedError("Only unary and binary functions are tested for now")
|
||||
@@ -1,433 +0,0 @@
|
||||
"""Test file for intermediate representation"""
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.floats import Float
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.representation import intermediate as ir
|
||||
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"node,input_data,expected_result",
|
||||
[
|
||||
pytest.param(
|
||||
ir.Add([EncryptedScalar(Integer(64, False)), EncryptedScalar(Integer(64, False))]),
|
||||
[10, 4589],
|
||||
4599,
|
||||
id="Add",
|
||||
),
|
||||
pytest.param(
|
||||
ir.Sub([EncryptedScalar(Integer(64, False)), EncryptedScalar(Integer(64, False))]),
|
||||
[10, 4589],
|
||||
-4579,
|
||||
id="Sub",
|
||||
),
|
||||
pytest.param(
|
||||
ir.Mul([EncryptedScalar(Integer(64, False)), EncryptedScalar(Integer(64, False))]),
|
||||
[10, 4589],
|
||||
45890,
|
||||
id="Mul",
|
||||
),
|
||||
pytest.param(ir.Input(ClearScalar(Integer(32, True)), "in", 0), [42], 42, id="Input"),
|
||||
pytest.param(ir.Constant(42), None, 42, id="Constant"),
|
||||
pytest.param(ir.Constant(-42), None, -42, id="Constant"),
|
||||
pytest.param(
|
||||
ir.GenericFunction(
|
||||
[EncryptedScalar(Integer(7, False))],
|
||||
lambda x: x + 3,
|
||||
EncryptedScalar(Integer(7, False)),
|
||||
op_kind="TLU",
|
||||
),
|
||||
[10],
|
||||
13,
|
||||
id="GenericFunction, x + 3",
|
||||
),
|
||||
pytest.param(
|
||||
ir.GenericFunction(
|
||||
[EncryptedScalar(Integer(7, False))],
|
||||
lambda x, y: x + y,
|
||||
EncryptedScalar(Integer(7, False)),
|
||||
op_kind="TLU",
|
||||
op_kwargs={"y": 3},
|
||||
),
|
||||
[10],
|
||||
13,
|
||||
id="GenericFunction, (x, y) -> x + y, where y is constant == 3",
|
||||
),
|
||||
pytest.param(
|
||||
ir.GenericFunction(
|
||||
[EncryptedScalar(Integer(7, False))],
|
||||
lambda x, y: y[x],
|
||||
EncryptedScalar(Integer(7, False)),
|
||||
op_kind="TLU",
|
||||
op_kwargs={"y": (1, 2, 3, 4)},
|
||||
),
|
||||
[2],
|
||||
3,
|
||||
id="GenericFunction, (x, y) -> y[x], where y is constant == (1, 2, 3, 4)",
|
||||
),
|
||||
pytest.param(
|
||||
ir.GenericFunction(
|
||||
[EncryptedScalar(Integer(7, False))],
|
||||
lambda x, y: y[3],
|
||||
EncryptedScalar(Integer(7, False)),
|
||||
op_kind="TLU",
|
||||
op_kwargs={"y": (1, 2, 3, 4)},
|
||||
),
|
||||
[2],
|
||||
4,
|
||||
id="GenericFunction, x, y -> y[3], where y is constant == (1, 2, 3, 4)",
|
||||
),
|
||||
pytest.param(
|
||||
ir.Dot(
|
||||
[
|
||||
EncryptedTensor(Integer(32, True), shape=(4,)),
|
||||
ClearTensor(Integer(32, True), shape=(4,)),
|
||||
],
|
||||
Integer(32, True),
|
||||
),
|
||||
[[1, 2, 3, 4], [4, 3, 2, 1]],
|
||||
20,
|
||||
id="Dot, [1, 2, 3, 4], [4, 3, 2, 1]",
|
||||
),
|
||||
pytest.param(
|
||||
ir.Dot(
|
||||
[
|
||||
EncryptedTensor(Float(32), shape=(4,)),
|
||||
ClearTensor(Float(32), shape=(4,)),
|
||||
],
|
||||
Float(32),
|
||||
),
|
||||
[[1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]],
|
||||
20,
|
||||
id="Dot, [1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]",
|
||||
),
|
||||
pytest.param(
|
||||
ir.Dot(
|
||||
[
|
||||
EncryptedTensor(Integer(32, True), shape=(4,)),
|
||||
ClearTensor(Integer(32, True), shape=(4,)),
|
||||
],
|
||||
Integer(32, True),
|
||||
delegate_evaluation_function=numpy.dot,
|
||||
),
|
||||
[
|
||||
numpy.array([1, 2, 3, 4], dtype=numpy.int32),
|
||||
numpy.array([4, 3, 2, 1], dtype=numpy.int32),
|
||||
],
|
||||
20,
|
||||
id="Dot, np.array([1, 2, 3, 4]), np.array([4, 3, 2, 1])",
|
||||
),
|
||||
pytest.param(
|
||||
ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (0,)),
|
||||
[
|
||||
numpy.array([1, 2, 3, 4], dtype=numpy.int32),
|
||||
],
|
||||
1,
|
||||
id="IndexConstant, np.array([1, 2, 3, 4])[0]",
|
||||
),
|
||||
pytest.param(
|
||||
ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (slice(1, 3, None),)),
|
||||
[
|
||||
numpy.array([1, 2, 3, 4], dtype=numpy.int32),
|
||||
],
|
||||
numpy.array([2, 3]),
|
||||
id="IndexConstant, np.array([1, 2, 3, 4])[1:3]",
|
||||
),
|
||||
pytest.param(
|
||||
ir.IndexConstant(EncryptedTensor(Integer(4, True), shape=(4,)), (slice(3, 1, -1),)),
|
||||
[
|
||||
numpy.array([1, 2, 3, 4], dtype=numpy.int32),
|
||||
],
|
||||
numpy.array([4, 3], dtype=numpy.int32),
|
||||
id="IndexConstant, np.array([1, 2, 3, 4])[3:1:-1]",
|
||||
),
|
||||
pytest.param(
|
||||
ir.IndexConstant(
|
||||
EncryptedTensor(Integer(5, True), shape=(4, 4)), (slice(1, 3, 1), slice(2, 0, -1))
|
||||
),
|
||||
[
|
||||
numpy.array(
|
||||
[
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12],
|
||||
[13, 14, 15, 16],
|
||||
],
|
||||
dtype=numpy.int32,
|
||||
),
|
||||
],
|
||||
numpy.array(
|
||||
[
|
||||
[7, 6],
|
||||
[11, 10],
|
||||
],
|
||||
dtype=numpy.int32,
|
||||
),
|
||||
id="IndexConstant, np.array([[1, 2, 3, 4]...[13, 14, 15, 16]])[1:3, 2:0:-1]",
|
||||
),
|
||||
pytest.param(
|
||||
ir.MatMul(
|
||||
[
|
||||
EncryptedTensor(Integer(32, True), shape=(3, 2)),
|
||||
ClearTensor(Integer(32, True), shape=(2, 3)),
|
||||
],
|
||||
Integer(32, True),
|
||||
(3, 3),
|
||||
),
|
||||
[numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)],
|
||||
numpy.array([[9, 12, 15], [19, 26, 33], [29, 40, 51]]),
|
||||
id="MatMul, numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)",
|
||||
),
|
||||
pytest.param(
|
||||
ir.GenericFunction(
|
||||
[EncryptedTensor(Integer(32, False), shape=(3, 5))],
|
||||
lambda x: numpy.transpose(x),
|
||||
EncryptedTensor(Integer(32, False), shape=(5, 3)),
|
||||
op_kind="Memory",
|
||||
),
|
||||
[numpy.arange(15).reshape(3, 5)],
|
||||
numpy.array([[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8, 13], [4, 9, 14]]),
|
||||
id="GenericFunction, x transpose",
|
||||
),
|
||||
pytest.param(
|
||||
ir.GenericFunction(
|
||||
[EncryptedTensor(Integer(32, False), shape=(3, 5))],
|
||||
lambda x: numpy.ravel(x),
|
||||
EncryptedTensor(Integer(32, False), shape=(5, 3)),
|
||||
op_kind="Memory",
|
||||
),
|
||||
[numpy.arange(15).reshape(3, 5)],
|
||||
numpy.arange(15),
|
||||
id="GenericFunction, x ravel",
|
||||
),
|
||||
pytest.param(
|
||||
ir.GenericFunction(
|
||||
[EncryptedTensor(Integer(32, False), shape=(3, 5))],
|
||||
lambda x: numpy.reshape(x, (5, 3)),
|
||||
output_value=EncryptedTensor(Integer(32, False), shape=(5, 3)),
|
||||
op_kind="Memory",
|
||||
),
|
||||
[numpy.arange(15).reshape(3, 5)],
|
||||
numpy.arange(15).reshape(5, 3),
|
||||
id="GenericFunction, x reshape",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_evaluate(
|
||||
node: ir.IntermediateNode,
|
||||
input_data,
|
||||
expected_result: int,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test evaluate methods on IntermediateNodes"""
|
||||
if isinstance(expected_result, numpy.ndarray):
|
||||
check_array_equality(node.evaluate(input_data), expected_result)
|
||||
else:
|
||||
assert node.evaluate(input_data) == expected_result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"node1,node2,expected_result",
|
||||
[
|
||||
(
|
||||
ir.Add([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
|
||||
ir.Add([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.Add([EncryptedScalar(Integer(16, False)), EncryptedScalar(Integer(32, False))]),
|
||||
ir.Add([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(16, False))]),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.Add([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
|
||||
ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
|
||||
ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(16, False))]),
|
||||
ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(16, False))]),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(16, False))]),
|
||||
ir.Sub([EncryptedScalar(Integer(16, False)), EncryptedScalar(Integer(32, False))]),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.Mul([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
|
||||
ir.Mul([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.Mul([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
|
||||
ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.Input(EncryptedScalar(Integer(32, False)), "x", 0),
|
||||
ir.Sub([EncryptedScalar(Integer(32, False)), EncryptedScalar(Integer(32, False))]),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.Input(EncryptedScalar(Integer(32, False)), "x", 0),
|
||||
ir.Input(EncryptedScalar(Integer(32, False)), "x", 0),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.Input(EncryptedScalar(Integer(32, False)), "x", 0),
|
||||
ir.Input(EncryptedScalar(Integer(32, False)), "y", 0),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.Input(EncryptedScalar(Integer(32, False)), "x", 0),
|
||||
ir.Input(EncryptedScalar(Integer(32, False)), "x", 1),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.Input(EncryptedScalar(Integer(32, False)), "x", 0),
|
||||
ir.Input(EncryptedScalar(Integer(8, False)), "x", 0),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.Constant(10),
|
||||
ir.Constant(10),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.Constant(10),
|
||||
ir.Input(EncryptedScalar(Integer(8, False)), "x", 0),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.Constant(10),
|
||||
ir.Constant(10.0),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.GenericFunction(
|
||||
[EncryptedScalar(Integer(8, False))],
|
||||
lambda x: x,
|
||||
EncryptedScalar(Integer(8, False)),
|
||||
op_kind="TLU",
|
||||
),
|
||||
ir.GenericFunction(
|
||||
[EncryptedScalar(Integer(8, False))],
|
||||
lambda x: x,
|
||||
EncryptedScalar(Integer(8, False)),
|
||||
op_kind="TLU",
|
||||
),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.GenericFunction(
|
||||
[EncryptedScalar(Integer(8, False))],
|
||||
lambda x: x,
|
||||
EncryptedScalar(Integer(8, False)),
|
||||
op_kind="TLU",
|
||||
op_args=(1, 2, 3),
|
||||
),
|
||||
ir.GenericFunction(
|
||||
[EncryptedScalar(Integer(8, False))],
|
||||
lambda x: x,
|
||||
EncryptedScalar(Integer(8, False)),
|
||||
op_kind="TLU",
|
||||
),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.GenericFunction(
|
||||
[EncryptedScalar(Integer(8, False))],
|
||||
lambda x: x,
|
||||
EncryptedScalar(Integer(8, False)),
|
||||
op_kind="TLU",
|
||||
op_kwargs={"tuple": (1, 2, 3)},
|
||||
),
|
||||
ir.GenericFunction(
|
||||
[EncryptedScalar(Integer(8, False))],
|
||||
lambda x: x,
|
||||
EncryptedScalar(Integer(8, False)),
|
||||
op_kind="TLU",
|
||||
),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.Dot(
|
||||
[
|
||||
EncryptedTensor(Integer(32, True), shape=(4,)),
|
||||
ClearTensor(Integer(32, True), shape=(4,)),
|
||||
],
|
||||
Integer(32, True),
|
||||
delegate_evaluation_function=numpy.dot,
|
||||
),
|
||||
ir.Dot(
|
||||
[
|
||||
EncryptedTensor(Integer(32, True), shape=(4,)),
|
||||
ClearTensor(Integer(32, True), shape=(4,)),
|
||||
],
|
||||
Integer(32, True),
|
||||
delegate_evaluation_function=numpy.dot,
|
||||
),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.Dot(
|
||||
[
|
||||
EncryptedTensor(Integer(32, True), shape=(4,)),
|
||||
ClearTensor(Integer(32, True), shape=(4,)),
|
||||
],
|
||||
Integer(32, True),
|
||||
delegate_evaluation_function=numpy.dot,
|
||||
),
|
||||
ir.Dot(
|
||||
[
|
||||
EncryptedTensor(Integer(32, True), shape=(4,)),
|
||||
ClearTensor(Integer(32, True), shape=(4,)),
|
||||
],
|
||||
Integer(32, True),
|
||||
),
|
||||
False,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_is_equivalent_to(
|
||||
node1: ir.IntermediateNode,
|
||||
node2: ir.IntermediateNode,
|
||||
expected_result: bool,
|
||||
test_helpers,
|
||||
):
|
||||
"""Test is_equivalent_to methods on IntermediateNodes"""
|
||||
assert (
|
||||
test_helpers.nodes_are_equivalent(node1, node2)
|
||||
== test_helpers.nodes_are_equivalent(node2, node1)
|
||||
== expected_result
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"list_to_fill,expected_list",
|
||||
[
|
||||
pytest.param([None, 1, 2, 3, None, None], [1, 1, 2, 3, 3, 3]),
|
||||
pytest.param([None], None, marks=pytest.mark.xfail(strict=True)),
|
||||
pytest.param([None, None, None, None, 7, None, None, None], [7, 7, 7, 7, 7, 7, 7, 7]),
|
||||
pytest.param([None, None, 3, None, None, None, 2, None], [3, 3, 3, 3, 3, 2, 2, 2]),
|
||||
],
|
||||
)
|
||||
def test_flood_replace_none_values(list_to_fill: list, expected_list: list):
|
||||
"""Unit test for flood_replace_none_values"""
|
||||
|
||||
# avoid modifying the test input
|
||||
list_to_fill_copy = deepcopy(list_to_fill)
|
||||
ir.flood_replace_none_values(list_to_fill_copy)
|
||||
|
||||
assert all(value is not None for value in list_to_fill_copy)
|
||||
assert list_to_fill_copy == expected_list
|
||||
@@ -1,74 +0,0 @@
|
||||
"""Test file for common helpers"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
|
||||
from concrete.common import check_op_graph_is_integer_program, is_a_power_of_2
|
||||
from concrete.common.data_types.floats import Float64
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy.tracing import trace_numpy_function
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"x,result",
|
||||
[
|
||||
(0, False),
|
||||
(1, True),
|
||||
(2, True),
|
||||
(3, False),
|
||||
(4, True),
|
||||
(10, False),
|
||||
(16, True),
|
||||
],
|
||||
)
|
||||
def test_is_a_power_of_2(x, result):
|
||||
"""Test function for test_is_a_power_of_2"""
|
||||
|
||||
assert is_a_power_of_2(x) == result
|
||||
|
||||
|
||||
def test_check_op_graph_is_integer_program():
|
||||
"""Test function for check_op_graph_is_integer_program"""
|
||||
|
||||
def function(x, y):
|
||||
return x + y - y * y + x * y
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
function, {"x": EncryptedScalar(Integer(64, True)), "y": EncryptedScalar(Integer(64, True))}
|
||||
)
|
||||
|
||||
# Test without and with output list
|
||||
offending_nodes = []
|
||||
assert check_op_graph_is_integer_program(op_graph)
|
||||
assert check_op_graph_is_integer_program(op_graph, offending_nodes)
|
||||
assert len(offending_nodes) == 0
|
||||
|
||||
op_graph_copy = deepcopy(op_graph)
|
||||
op_graph_copy.output_nodes[0].outputs[0].dtype = Float64
|
||||
|
||||
offending_nodes = []
|
||||
assert not check_op_graph_is_integer_program(op_graph_copy)
|
||||
assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes)
|
||||
assert len(offending_nodes) == 1
|
||||
assert offending_nodes == [op_graph_copy.output_nodes[0]]
|
||||
|
||||
op_graph_copy = deepcopy(op_graph)
|
||||
op_graph_copy.input_nodes[0].inputs[0].dtype = Float64
|
||||
|
||||
offending_nodes = []
|
||||
assert not check_op_graph_is_integer_program(op_graph_copy)
|
||||
assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes)
|
||||
assert len(offending_nodes) == 1
|
||||
assert offending_nodes == [op_graph_copy.input_nodes[0]]
|
||||
|
||||
op_graph_copy = deepcopy(op_graph)
|
||||
op_graph_copy.input_nodes[0].inputs[0].dtype = Float64
|
||||
op_graph_copy.input_nodes[1].inputs[0].dtype = Float64
|
||||
|
||||
offending_nodes = []
|
||||
assert not check_op_graph_is_integer_program(op_graph_copy)
|
||||
assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes)
|
||||
assert len(offending_nodes) == 2
|
||||
assert set(offending_nodes) == set([op_graph_copy.input_nodes[0], op_graph_copy.input_nodes[1]])
|
||||
@@ -1,54 +0,0 @@
|
||||
"""Test module for Circuit class"""
|
||||
|
||||
import filecmp
|
||||
|
||||
import concrete.numpy as hnp
|
||||
from concrete.common.debugging import draw_graph, format_operation_graph
|
||||
|
||||
|
||||
def test_circuit_str(default_compilation_configuration):
|
||||
"""Test function for `__str__` method of `Circuit`"""
|
||||
|
||||
def f(x):
|
||||
return x + 42
|
||||
|
||||
x = hnp.EncryptedScalar(hnp.UnsignedInteger(3))
|
||||
|
||||
inputset = range(2 ** 3)
|
||||
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
|
||||
|
||||
assert str(circuit) == format_operation_graph(circuit.op_graph)
|
||||
|
||||
|
||||
def test_circuit_draw(default_compilation_configuration):
|
||||
"""Test function for `draw` method of `Circuit`"""
|
||||
|
||||
def f(x):
|
||||
return x + 42
|
||||
|
||||
x = hnp.EncryptedScalar(hnp.UnsignedInteger(3))
|
||||
|
||||
inputset = range(2 ** 3)
|
||||
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
|
||||
|
||||
assert filecmp.cmp(circuit.draw(), draw_graph(circuit.op_graph))
|
||||
assert filecmp.cmp(circuit.draw(vertical=False), draw_graph(circuit.op_graph, vertical=False))
|
||||
|
||||
|
||||
def test_circuit_run(default_compilation_configuration):
|
||||
"""Test equivalence of encrypt/run/decrypt and encrypt_run_decrypt"""
|
||||
|
||||
def f(x):
|
||||
return x + 42
|
||||
|
||||
x = hnp.EncryptedScalar(hnp.UnsignedInteger(3))
|
||||
|
||||
inputset = range(2 ** 3)
|
||||
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
|
||||
|
||||
circuit.keygen()
|
||||
for x in inputset:
|
||||
enc_x = circuit.encrypt(x)
|
||||
enc_res = circuit.run(enc_x)
|
||||
res = circuit.decrypt(enc_res)
|
||||
assert circuit.encrypt_run_decrypt(x) == res
|
||||
@@ -1,44 +0,0 @@
|
||||
"""Test file for common tracing helpers"""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from concrete.common.tracing.tracing_helpers import prepare_function_parameters
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,function_parameters,ref_dict",
|
||||
[
|
||||
pytest.param(lambda x: None, {}, {}, id="Missing x", marks=pytest.mark.xfail(strict=True)),
|
||||
pytest.param(lambda x: None, {"x": None}, {"x": None}, id="Only x"),
|
||||
pytest.param(
|
||||
lambda x: None, {"x": None, "y": None}, {"x": None}, id="Additional y filtered"
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_prepare_function_parameters(
|
||||
function, function_parameters: Dict[str, Any], ref_dict: Dict[str, Any]
|
||||
):
|
||||
"""Test prepare_function_parameters"""
|
||||
prepared_dict = prepare_function_parameters(function, function_parameters)
|
||||
|
||||
assert prepared_dict == ref_dict
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,function_parameters,expected_ordered_keys",
|
||||
[
|
||||
(lambda x: None, {"x": None}, ["x"]),
|
||||
(lambda x, y: None, {"x": None, "y": None}, ["x", "y"]),
|
||||
(lambda x, y: None, {"y": None, "x": None}, ["x", "y"]),
|
||||
(lambda z, x, y: None, {"y": None, "z": None, "x": None}, ["z", "x", "y"]),
|
||||
],
|
||||
)
|
||||
def test_prepare_function_parameters_order(
|
||||
function, function_parameters: Dict[str, Any], expected_ordered_keys: List[str]
|
||||
):
|
||||
"""Test prepare_function_parameters output order"""
|
||||
prepared_dict = prepare_function_parameters(function, function_parameters)
|
||||
|
||||
assert list(prepared_dict.keys()) == expected_ordered_keys
|
||||
@@ -1,473 +0,0 @@
|
||||
"""PyTest configuration file"""
|
||||
import json
|
||||
import operator
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, Optional, Type
|
||||
|
||||
import networkx as nx
|
||||
import networkx.algorithms.isomorphism as iso
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.compilation import CompilationConfiguration
|
||||
from concrete.common.fhe_circuit import FHECircuit
|
||||
from concrete.common.mlir.utils import (
|
||||
ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB,
|
||||
get_op_graph_max_bit_width_and_nodes_over_bit_width_limit,
|
||||
)
|
||||
from concrete.common.representation.intermediate import (
|
||||
ALL_IR_NODES,
|
||||
Add,
|
||||
Constant,
|
||||
Conv2D,
|
||||
Dot,
|
||||
GenericFunction,
|
||||
IndexConstant,
|
||||
Input,
|
||||
IntermediateNode,
|
||||
MatMul,
|
||||
Mul,
|
||||
Sub,
|
||||
)
|
||||
from concrete.numpy import compile as compile_
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
"""Options for pytest"""
|
||||
|
||||
parser.addoption(
|
||||
"--global-coverage-infos-json",
|
||||
action="store",
|
||||
default=None,
|
||||
type=str,
|
||||
help="To dump pytest-cov term report to a text file.",
|
||||
)
|
||||
|
||||
parser.addoption(
|
||||
"--keyring-dir",
|
||||
action="store",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Specify the dir to use to store key cache",
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_KEYRING_PATH = Path.home().resolve() / ".cache/concrete-numpy_pytest"
|
||||
|
||||
|
||||
def get_keyring_dir_from_session_or_default(
|
||||
session: Optional[pytest.Session] = None,
|
||||
) -> Optional[Path]:
|
||||
"""Get keyring dir from test session."""
|
||||
if session is None:
|
||||
return DEFAULT_KEYRING_PATH
|
||||
|
||||
keyring_dir = session.config.getoption("--keyring-dir", default=None)
|
||||
if keyring_dir is not None:
|
||||
if keyring_dir.lower() == "disable":
|
||||
return None
|
||||
keyring_dir = Path(keyring_dir).expanduser().resolve()
|
||||
else:
|
||||
keyring_dir = DEFAULT_KEYRING_PATH
|
||||
return keyring_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_keyring_path():
|
||||
"""Fixture to get test keyring dir."""
|
||||
return DEFAULT_KEYRING_PATH
|
||||
|
||||
|
||||
# This is only for doctests where we currently cannot make use of fixtures
|
||||
original_compilation_config_init = CompilationConfiguration.__init__
|
||||
|
||||
|
||||
def monkeypatched_compilation_configuration_init_for_codeblocks(
|
||||
self: CompilationConfiguration, *args, **kwargs
|
||||
):
|
||||
"""Monkeypatched compilation configuration init for codeblocks tests."""
|
||||
original_compilation_config_init(self, *args, **kwargs)
|
||||
self.dump_artifacts_on_unexpected_failures = False
|
||||
self.enable_unsafe_features = True # This is for our tests only, never use that in prod
|
||||
self.treat_warnings_as_errors = True
|
||||
self.use_insecure_key_cache = True # This is for our tests only, never use that in prod
|
||||
|
||||
|
||||
def pytest_sessionstart(session: pytest.Session):
|
||||
"""Handle keyring for session and codeblocks CompilationConfiguration if needed."""
|
||||
if session.config.getoption("--codeblocks", default=False):
|
||||
# setattr to avoid mypy complaining
|
||||
# Disable the flake8 bug bear warning for the mypy fix
|
||||
setattr( # noqa: B010
|
||||
CompilationConfiguration,
|
||||
"__init__",
|
||||
monkeypatched_compilation_configuration_init_for_codeblocks,
|
||||
)
|
||||
|
||||
keyring_dir = get_keyring_dir_from_session_or_default(session)
|
||||
if keyring_dir is None:
|
||||
return
|
||||
keyring_dir.mkdir(parents=True, exist_ok=True)
|
||||
keyring_dir_as_str = str(keyring_dir)
|
||||
print(f"Using {keyring_dir_as_str} as key cache dir")
|
||||
compile_._COMPILE_FHE_INSECURE_KEY_CACHE_DIR = ( # pylint: disable=protected-access
|
||||
keyring_dir_as_str
|
||||
)
|
||||
|
||||
|
||||
def pytest_sessionfinish(session: pytest.Session, exitstatus): # pylint: disable=unused-argument
|
||||
"""Pytest callback when testing ends."""
|
||||
# Hacked together from the source code, they don't have an option to export to file and it's too
|
||||
# much work to get a PR in for such a little thing
|
||||
# https://github.com/pytest-dev/pytest-cov/blob/
|
||||
# ec344d8adf2d78238d8f07cb20ed2463d7536970/src/pytest_cov/plugin.py#L329
|
||||
if session.config.pluginmanager.hasplugin("_cov"):
|
||||
global_coverage_file = session.config.getoption(
|
||||
"--global-coverage-infos-json", default=None
|
||||
)
|
||||
if global_coverage_file is not None:
|
||||
cov_plugin = session.config.pluginmanager.getplugin("_cov")
|
||||
coverage_txt = cov_plugin.cov_report.getvalue()
|
||||
coverage_status = 0
|
||||
if (
|
||||
cov_plugin.options.cov_fail_under is not None
|
||||
and cov_plugin.options.cov_fail_under > 0
|
||||
):
|
||||
failed = cov_plugin.cov_total < cov_plugin.options.cov_fail_under
|
||||
# If failed is False coverage_status is 0, if True it's 1
|
||||
coverage_status = int(failed)
|
||||
global_coverage_file_path = Path(global_coverage_file).resolve()
|
||||
with open(global_coverage_file_path, "w", encoding="utf-8") as f:
|
||||
json.dump({"exit_code": coverage_status, "content": coverage_txt}, f)
|
||||
|
||||
keyring_dir = get_keyring_dir_from_session_or_default(session)
|
||||
if keyring_dir is not None:
|
||||
# Remove incomplete keys
|
||||
for incomplete_keys in keyring_dir.glob("**/*incomplete*"):
|
||||
shutil.rmtree(incomplete_keys, ignore_errors=True)
|
||||
|
||||
|
||||
def _is_equivalent_to_binary_commutative(lhs: IntermediateNode, rhs: object) -> bool:
|
||||
"""is_equivalent_to for a binary and commutative operation."""
|
||||
return (
|
||||
isinstance(rhs, lhs.__class__)
|
||||
and (lhs.inputs in (rhs.inputs, rhs.inputs[::-1]))
|
||||
and lhs.outputs == rhs.outputs
|
||||
)
|
||||
|
||||
|
||||
def _is_equivalent_to_binary_non_commutative(lhs: IntermediateNode, rhs: object) -> bool:
|
||||
"""is_equivalent_to for a binary and non-commutative operation."""
|
||||
return (
|
||||
isinstance(rhs, lhs.__class__) and lhs.inputs == rhs.inputs and lhs.outputs == rhs.outputs
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_add(lhs: Add, rhs: object) -> bool:
|
||||
"""Helper function to check if an Add node is equivalent to an other object."""
|
||||
return _is_equivalent_to_binary_commutative(lhs, rhs)
|
||||
|
||||
|
||||
# From https://stackoverflow.com/a/28635464
|
||||
_code_and_constants_attr_getter = operator.attrgetter("co_code", "co_consts")
|
||||
|
||||
|
||||
def _code_and_constants(object_):
|
||||
"""Helper function to get python code and constants"""
|
||||
return _code_and_constants_attr_getter(object_.__code__)
|
||||
|
||||
|
||||
def python_functions_are_equal_or_equivalent(lhs: object, rhs: object) -> bool:
|
||||
"""Helper function to check if two functions are equal or their code are equivalent.
|
||||
|
||||
This is not perfect, but will be good enough for tests.
|
||||
"""
|
||||
|
||||
if lhs == rhs:
|
||||
return True
|
||||
|
||||
try:
|
||||
lhs_code_and_constants = _code_and_constants(lhs)
|
||||
rhs_code_and_constants = _code_and_constants(rhs)
|
||||
return lhs_code_and_constants == rhs_code_and_constants
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
||||
def is_equivalent_arbitrary_function(lhs: GenericFunction, rhs: object) -> bool:
|
||||
"""Helper function to check if an GenericFunction node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, GenericFunction)
|
||||
and python_functions_are_equal_or_equivalent(lhs.arbitrary_func, rhs.arbitrary_func)
|
||||
and lhs.op_kind == rhs.op_kind
|
||||
and lhs.op_args == rhs.op_args
|
||||
and lhs.op_kwargs == rhs.op_kwargs
|
||||
and lhs.op_attributes == rhs.op_attributes
|
||||
and lhs.op_name == rhs.op_name
|
||||
and is_equivalent_intermediate_node(lhs, rhs)
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_constant(lhs: Constant, rhs: object) -> bool:
|
||||
"""Helper function to check if a Constant node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, Constant)
|
||||
and lhs.constant_data == rhs.constant_data
|
||||
and is_equivalent_intermediate_node(lhs, rhs)
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_dot(lhs: Dot, rhs: object) -> bool:
|
||||
"""Helper function to check if a Dot node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, Dot)
|
||||
and lhs.evaluation_function == rhs.evaluation_function
|
||||
and is_equivalent_intermediate_node(lhs, rhs)
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_input(lhs: Input, rhs: object) -> bool:
|
||||
"""Helper function to check if an Input node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, Input)
|
||||
and lhs.input_name == rhs.input_name
|
||||
and lhs.program_input_idx == rhs.program_input_idx
|
||||
and is_equivalent_intermediate_node(lhs, rhs)
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_index_constant(lhs: IndexConstant, rhs: object) -> bool:
|
||||
"""Helper function to check if an IndexConstant node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, IndexConstant)
|
||||
and lhs.index == rhs.index
|
||||
and is_equivalent_intermediate_node(lhs, rhs)
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_mul(lhs: Mul, rhs: object) -> bool:
|
||||
"""Helper function to check if a Mul node is equivalent to an other object."""
|
||||
return _is_equivalent_to_binary_commutative(lhs, rhs)
|
||||
|
||||
|
||||
def is_equivalent_sub(lhs: Sub, rhs: object) -> bool:
|
||||
"""Helper function to check if a Sub node is equivalent to an other object."""
|
||||
return _is_equivalent_to_binary_non_commutative(lhs, rhs)
|
||||
|
||||
|
||||
def is_equivalent_matmul(lhs: MatMul, rhs: object) -> bool:
|
||||
"""Helper function to check if a MatMul node is equivalent to an other object."""
|
||||
return isinstance(rhs, MatMul) and is_equivalent_intermediate_node(lhs, rhs)
|
||||
|
||||
|
||||
def is_equivalent_conv2d(lhs: Conv2D, rhs: object) -> bool:
|
||||
"""Helper function to check if a Conv2D node is equivalent to an other object."""
|
||||
return isinstance(rhs, Conv2D) and is_equivalent_intermediate_node(lhs, rhs)
|
||||
|
||||
|
||||
def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool:
|
||||
"""Helper function to check if an IntermediateNode node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, IntermediateNode)
|
||||
and lhs.inputs == rhs.inputs
|
||||
and lhs.outputs == rhs.outputs
|
||||
)
|
||||
|
||||
|
||||
EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
|
||||
Add: is_equivalent_add,
|
||||
GenericFunction: is_equivalent_arbitrary_function,
|
||||
Constant: is_equivalent_constant,
|
||||
Conv2D: is_equivalent_conv2d,
|
||||
Dot: is_equivalent_dot,
|
||||
IndexConstant: is_equivalent_index_constant,
|
||||
Input: is_equivalent_input,
|
||||
Mul: is_equivalent_mul,
|
||||
Sub: is_equivalent_sub,
|
||||
MatMul: is_equivalent_matmul,
|
||||
}
|
||||
|
||||
_missing_nodes_in_mapping = ALL_IR_NODES - EQUIVALENT_TEST_FUNC.keys()
|
||||
assert len(_missing_nodes_in_mapping) == 0, (
|
||||
f"Missing IR node in EQUIVALENT_TEST_FUNC : "
|
||||
f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}"
|
||||
)
|
||||
|
||||
del _missing_nodes_in_mapping
|
||||
|
||||
|
||||
class TestHelpers:
|
||||
"""Class allowing to pass helper functions to tests"""
|
||||
|
||||
@staticmethod
|
||||
def nodes_are_equivalent(lhs, rhs) -> bool:
|
||||
"""Helper function for tests to check if two nodes are equivalent."""
|
||||
equivalent_func = EQUIVALENT_TEST_FUNC.get(type(lhs), None)
|
||||
if equivalent_func is not None:
|
||||
return equivalent_func(lhs, rhs)
|
||||
|
||||
# This is a default for the test_conftest.py that should remain separate from the package
|
||||
# nodes is_equivalent_* functions
|
||||
return lhs.is_equivalent_to(rhs)
|
||||
|
||||
@staticmethod
|
||||
def digraphs_are_equivalent(reference: nx.MultiDiGraph, to_compare: nx.MultiDiGraph):
|
||||
"""Check that two digraphs are equivalent without modifications"""
|
||||
# edge_match is a copy of node_match
|
||||
edge_matcher = iso.categorical_multiedge_match(["input_idx", "output_idx"], [None, None])
|
||||
node_matcher = iso.generic_node_match(
|
||||
"_test_content", None, TestHelpers.nodes_are_equivalent
|
||||
)
|
||||
|
||||
# Set the _test_content for each node in the graphs
|
||||
for node in reference.nodes():
|
||||
reference.add_node(node, _test_content=node)
|
||||
|
||||
for node in to_compare.nodes():
|
||||
to_compare.add_node(node, _test_content=node)
|
||||
|
||||
graphs_are_isomorphic = nx.is_isomorphic(
|
||||
reference,
|
||||
to_compare,
|
||||
node_match=node_matcher,
|
||||
edge_match=edge_matcher,
|
||||
)
|
||||
|
||||
return graphs_are_isomorphic
|
||||
|
||||
@staticmethod
|
||||
def python_functions_are_equal_or_equivalent(lhs, rhs):
|
||||
"""Helper function to check if two functions are equal or their code are equivalent.
|
||||
|
||||
This is not perfect, but will be good enough for tests.
|
||||
"""
|
||||
return python_functions_are_equal_or_equivalent(lhs, rhs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_helpers():
|
||||
"""Fixture to return the static helper class"""
|
||||
return TestHelpers
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_compilation_configuration():
|
||||
"""Return the default test compilation configuration"""
|
||||
return CompilationConfiguration(
|
||||
dump_artifacts_on_unexpected_failures=False,
|
||||
enable_unsafe_features=True, # This is for our tests only, never use that in prod
|
||||
treat_warnings_as_errors=True,
|
||||
use_insecure_key_cache=True, # This is for our tests only, never use that in prod
|
||||
)
|
||||
|
||||
|
||||
REMOVE_COLOR_CODES_RE = re.compile(r"\x1b[^m]*m")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def remove_color_codes():
|
||||
"""Return the re object to remove color codes"""
|
||||
return lambda x: REMOVE_COLOR_CODES_RE.sub("", x)
|
||||
|
||||
|
||||
def check_is_good_execution_impl(
|
||||
fhe_circuit: FHECircuit,
|
||||
function: Callable,
|
||||
args: Iterable[Any],
|
||||
preprocess_input_func: Callable[[Any], Any] = lambda x: x,
|
||||
postprocess_output_func: Callable[[Any], Any] = lambda x: x,
|
||||
check_function: Callable[[Any, Any], bool] = numpy.equal,
|
||||
verbose: bool = True,
|
||||
):
|
||||
"""Run several times the check compiler_engine.run(*args) == function(*args). If always wrong,
|
||||
return an error. One can set the expected probability of success of one execution and the
|
||||
number of tests, to finetune the probability of bad luck, ie that we run several times the
|
||||
check and always have a wrong result."""
|
||||
max_bit_width, _ = get_op_graph_max_bit_width_and_nodes_over_bit_width_limit(
|
||||
fhe_circuit.op_graph
|
||||
)
|
||||
|
||||
# Allow tests to pass if cells of the output result are good at least once over the nb_tries
|
||||
# Enabled only when we have a circuit that's using the maximum possible bit width
|
||||
# >= if there are 8 bits signed integers
|
||||
allow_relaxed_tests_passing = max_bit_width >= ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB
|
||||
|
||||
# FIXME: https://github.com/zama-ai/concrete-numpy-internal/issues/1255
|
||||
# Increased with compiler accuracy which dropped, make sure to remove once accuracy improves
|
||||
nb_tries = 10
|
||||
|
||||
# Prepare the bool array to record if cells were properly computed
|
||||
preprocessed_args = tuple(preprocess_input_func(val) for val in args)
|
||||
cells_were_properly_computed = numpy.zeros_like(function(*preprocessed_args), dtype=bool)
|
||||
|
||||
for i in range(1, nb_tries + 1):
|
||||
preprocessed_args = tuple(preprocess_input_func(val) for val in args)
|
||||
last_engine_result = postprocess_output_func(
|
||||
fhe_circuit.encrypt_run_decrypt(*preprocessed_args)
|
||||
)
|
||||
last_function_result = postprocess_output_func(function(*preprocessed_args))
|
||||
|
||||
ok_execution = check_function(last_engine_result, last_function_result)
|
||||
if isinstance(ok_execution, numpy.ndarray):
|
||||
# Record the cells that were well computed
|
||||
cells_were_properly_computed = numpy.logical_or(
|
||||
cells_were_properly_computed, ok_execution
|
||||
)
|
||||
|
||||
# Get a boolean for the execution
|
||||
ok_execution = ok_execution.all()
|
||||
|
||||
if ok_execution:
|
||||
# Good computation after i tries
|
||||
if verbose:
|
||||
print(f"Good computation after {i} tries")
|
||||
return
|
||||
# FIXME: https://github.com/zama-ai/concrete-numpy-internal/issues/1264
|
||||
# Remove the relaxed tests once accuracy is good again for 7 bits
|
||||
if allow_relaxed_tests_passing and cells_were_properly_computed.all():
|
||||
print(
|
||||
"Computation was never good for all output cells at the same time, "
|
||||
f"however each was evaluated properly at least once, stopped after {i} tries"
|
||||
)
|
||||
return
|
||||
|
||||
raise AssertionError(
|
||||
f"bad computation after {nb_tries} tries.\nLast engine result:\n{last_engine_result}\n"
|
||||
f"Last function result:\n{last_function_result}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def check_is_good_execution():
|
||||
"""Fixture to seed torch"""
|
||||
|
||||
return check_is_good_execution_impl
|
||||
|
||||
|
||||
def check_array_equality_impl(actual: Any, expected: Any, verbose: bool = True):
|
||||
"""Assert that `actual` is equal to `expected`."""
|
||||
|
||||
assert numpy.array_equal(actual, expected), (
|
||||
""
|
||||
if not verbose
|
||||
else f"""
|
||||
|
||||
Expected Output
|
||||
===============
|
||||
{expected}
|
||||
|
||||
Actual Output
|
||||
=============
|
||||
{actual}
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def check_array_equality():
|
||||
"""Fixture to check array equality"""
|
||||
|
||||
return check_array_equality_impl
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Test file for conftest helper functions"""
|
||||
import networkx as nx
|
||||
|
||||
|
||||
def test_digraphs_are_equivalent(test_helpers):
|
||||
"""Function to test digraphs_are_equivalent helper function"""
|
||||
|
||||
class TestNode:
|
||||
"""Dummy test node"""
|
||||
|
||||
computation: str
|
||||
|
||||
def __init__(self, computation: str) -> None:
|
||||
self.computation = computation
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.computation.__hash__()
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, self.__class__) and self.computation == other.computation
|
||||
|
||||
is_equivalent_to = __eq__
|
||||
|
||||
g_1 = nx.MultiDiGraph()
|
||||
g_2 = nx.MultiDiGraph()
|
||||
|
||||
t_0 = TestNode("Add")
|
||||
t_1 = TestNode("Mul")
|
||||
t_2 = TestNode("TLU")
|
||||
|
||||
g_1.add_edge(t_0, t_2, input_idx=0, output_idx=0)
|
||||
g_1.add_edge(t_1, t_2, input_idx=1, output_idx=0)
|
||||
|
||||
t0p = TestNode("Add")
|
||||
t1p = TestNode("Mul")
|
||||
t2p = TestNode("TLU")
|
||||
|
||||
g_2.add_edge(t1p, t2p, input_idx=1, output_idx=0)
|
||||
g_2.add_edge(t0p, t2p, input_idx=0, output_idx=0)
|
||||
|
||||
bad_g2 = nx.MultiDiGraph()
|
||||
|
||||
bad_t0 = TestNode("Not Add")
|
||||
|
||||
bad_g2.add_edge(bad_t0, t_2, input_idx=0, output_idx=0)
|
||||
bad_g2.add_edge(t_1, t_2, input_idx=1, output_idx=0)
|
||||
|
||||
bad_g3 = nx.MultiDiGraph()
|
||||
|
||||
bad_g3.add_edge(t_0, t_2, input_idx=1, output_idx=0)
|
||||
bad_g3.add_edge(t_1, t_2, input_idx=0, output_idx=0)
|
||||
|
||||
assert test_helpers.digraphs_are_equivalent(g_1, g_2), "Graphs should be equivalent"
|
||||
assert not test_helpers.digraphs_are_equivalent(g_1, bad_g2), "Graphs should not be equivalent"
|
||||
assert not test_helpers.digraphs_are_equivalent(g_2, bad_g2), "Graphs should not be equivalent"
|
||||
assert not test_helpers.digraphs_are_equivalent(g_1, bad_g3), "Graphs should not be equivalent"
|
||||
assert not test_helpers.digraphs_are_equivalent(g_2, bad_g3), "Graphs should not be equivalent"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,727 +0,0 @@
|
||||
"""Test module for constant indexing."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types import UnsignedInteger
|
||||
from concrete.common.values import EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import (
|
||||
compile_numpy_function,
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_value,function_with_indexing,output_value",
|
||||
[
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-3],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-2],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-1],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[1],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[2],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-3:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-2:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-1:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[1:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[2:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:-2],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:2],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:3],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-3:-2],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-3:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-3:1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-3:2],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-3:3],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-2:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-2:2],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-2:3],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-1:3],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0:-2],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0:1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0:2],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0:3],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[1:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[1:2],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[1:3],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[2:3],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[::-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-3::-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-2::-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-1::-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0::-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[1::-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[2::-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:-3:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:-2:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:0:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:1:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[2:0:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[2:1:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-1:1:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[-1:0:-1],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[:, :, :],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[0, :, :],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(4, 5)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[:, 0, :],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 5)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[:, :, 0],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[0, 0, :],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(5,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[0, :, 0],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(4,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[:, 0, 0],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[0:, 1:, 2:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 3, 3)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[2:, 1:, 0:],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(1, 3, 5)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[0],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(4, 5)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[0, 0],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(5,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3, 4, 5)),
|
||||
lambda x: x[0, 0, 0],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_constant_indexing(
|
||||
default_compilation_configuration,
|
||||
input_value,
|
||||
function_with_indexing,
|
||||
output_value,
|
||||
):
|
||||
"""Test compile_numpy_function_into_op_graph with constant indexing"""
|
||||
|
||||
inputset = [
|
||||
np.random.randint(
|
||||
input_value.dtype.min_value(),
|
||||
input_value.dtype.max_value() + 1,
|
||||
size=input_value.shape,
|
||||
)
|
||||
for _ in range(10)
|
||||
]
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function_with_indexing,
|
||||
{"x": input_value},
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
assert len(op_graph.output_nodes) == 1
|
||||
output_node = op_graph.output_nodes[0]
|
||||
|
||||
assert len(output_node.outputs) == 1
|
||||
assert output_value == output_node.outputs[0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_value,function_with_indexing,expected_error_type,expected_error_message",
|
||||
[
|
||||
pytest.param(
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
lambda x: x[0],
|
||||
TypeError,
|
||||
"Only tensors can be indexed but you tried to index EncryptedScalar<uint1>",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0.5],
|
||||
TypeError,
|
||||
"Only integers and integer slices can be used for indexing "
|
||||
"but you tried to use 0.5 for indexing",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[1:5:0.5], # type: ignore
|
||||
TypeError,
|
||||
"Only integers and integer slices can be used for indexing "
|
||||
"but you tried to use 1:5:0.5 for indexing",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0, 1],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [0, 1] "
|
||||
"as the index has more elements than the number of dimensions of the tensor",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[5],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [5] "
|
||||
"because index is out of range for dimension 0",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[5:],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [5:] "
|
||||
"because start index is out of range for dimension 0",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:10],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [:10] "
|
||||
"because stop index is out of range for dimension 0",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[2:0],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [2:0] "
|
||||
"because start index is not less than stop index for dimension 0",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[5::-1],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [5::-1] "
|
||||
"because start index is out of range for dimension 0",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[:10:-1],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [:10:-1] "
|
||||
"because stop index is out of range for dimension 0",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[0:2:-1],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [0:2:-1] "
|
||||
"because step is negative and stop index is not less than start index for dimension 0",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[::0],
|
||||
ValueError,
|
||||
"Tensor of shape (3,) cannot be indexed with [::0] "
|
||||
"because step is zero for dimension 0",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invalid_constant_indexing(
|
||||
default_compilation_configuration,
|
||||
input_value,
|
||||
function_with_indexing,
|
||||
expected_error_type,
|
||||
expected_error_message,
|
||||
):
|
||||
"""Test compile_numpy_function_into_op_graph with invalid constant indexing"""
|
||||
|
||||
with pytest.raises(expected_error_type):
|
||||
try:
|
||||
inputset = [
|
||||
(
|
||||
np.random.randint(
|
||||
input_value.dtype.min_value(),
|
||||
input_value.dtype.max_value() + 1,
|
||||
size=input_value.shape,
|
||||
),
|
||||
)
|
||||
for _ in range(10)
|
||||
]
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function_with_indexing,
|
||||
{"x": input_value},
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
except Exception as error:
|
||||
assert str(error) == expected_error_message
|
||||
raise
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_value,function_with_indexing,output_value",
|
||||
[
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[np.uint32(0)],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[slice(np.uint32(2), np.int32(0), np.int8(-1))],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[np.array(0)],
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[slice(np.array(2), np.array(0), np.array(-1))],
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(2,)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_constant_indexing_with_numpy_integers(
|
||||
default_compilation_configuration,
|
||||
input_value,
|
||||
function_with_indexing,
|
||||
output_value,
|
||||
):
|
||||
"""Test compile_numpy_function_into_op_graph with constant indexing with numpy integers"""
|
||||
|
||||
inputset = [
|
||||
np.random.randint(
|
||||
input_value.dtype.min_value(),
|
||||
input_value.dtype.max_value() + 1,
|
||||
size=input_value.shape,
|
||||
)
|
||||
for _ in range(10)
|
||||
]
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function_with_indexing,
|
||||
{"x": input_value},
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
assert len(op_graph.output_nodes) == 1
|
||||
output_node = op_graph.output_nodes[0]
|
||||
|
||||
assert len(output_node.outputs) == 1
|
||||
assert output_value == output_node.outputs[0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_value,function_with_indexing,expected_error_type,expected_error_message",
|
||||
[
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[np.float32(1.5)],
|
||||
TypeError,
|
||||
"Only integers and integer slices can be used for indexing "
|
||||
"but you tried to use 1.5 for indexing",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[np.array(1.5)],
|
||||
TypeError,
|
||||
"Only integers and integer slices can be used for indexing "
|
||||
"but you tried to use 1.5 for indexing",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
lambda x: x[np.array([1, 2])],
|
||||
TypeError,
|
||||
"Only integers and integer slices can be used for indexing "
|
||||
"but you tried to use [1 2] for indexing",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invalid_constant_indexing_with_numpy_values(
|
||||
default_compilation_configuration,
|
||||
input_value,
|
||||
function_with_indexing,
|
||||
expected_error_type,
|
||||
expected_error_message,
|
||||
):
|
||||
"""Test compile_numpy_function_into_op_graph with invalid constant indexing with numpy values"""
|
||||
|
||||
with pytest.raises(expected_error_type):
|
||||
try:
|
||||
inputset = [
|
||||
(
|
||||
np.random.randint(
|
||||
input_value.dtype.min_value(),
|
||||
input_value.dtype.max_value() + 1,
|
||||
size=input_value.shape,
|
||||
),
|
||||
)
|
||||
for _ in range(10)
|
||||
]
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function_with_indexing,
|
||||
{"x": input_value},
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
except Exception as error:
|
||||
assert str(error) == expected_error_message
|
||||
raise
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,inputset,test_input,expected_output",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x[0],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3,)),
|
||||
},
|
||||
[np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(10)],
|
||||
([4, 2, 6],),
|
||||
4,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[-1],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3,)),
|
||||
},
|
||||
[np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(10)],
|
||||
([4, 2, 6],),
|
||||
6,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[:3],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(4,)),
|
||||
},
|
||||
[np.random.randint(0, 2 ** 3, size=(4,)) for _ in range(10)],
|
||||
([4, 2, 6, 1],),
|
||||
[4, 2, 6],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[2:],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(4,)),
|
||||
},
|
||||
[np.random.randint(0, 2 ** 3, size=(4,)) for _ in range(10)],
|
||||
([4, 2, 6, 1],),
|
||||
[6, 1],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[1:3],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(4,)),
|
||||
},
|
||||
[np.random.randint(0, 2 ** 3, size=(4,)) for _ in range(10)],
|
||||
([4, 2, 6, 1],),
|
||||
[2, 6],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[::2],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(4,)),
|
||||
},
|
||||
[np.random.randint(0, 2 ** 3, size=(4,)) for _ in range(10)],
|
||||
([4, 2, 6, 1],),
|
||||
[4, 6],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[::-1],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(4,)),
|
||||
},
|
||||
[np.random.randint(0, 2 ** 3, size=(4,)) for _ in range(10)],
|
||||
([4, 2, 6, 1],),
|
||||
[1, 6, 2, 4],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[1, 0],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)),
|
||||
},
|
||||
[np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)],
|
||||
([[11, 12], [21, 22], [31, 32]],),
|
||||
21,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[:, :],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)),
|
||||
},
|
||||
[np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)],
|
||||
([[11, 12], [21, 22], [31, 32]],),
|
||||
[[11, 12], [21, 22], [31, 32]],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[0, :],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)),
|
||||
},
|
||||
[np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)],
|
||||
([[11, 12], [21, 22], [31, 32]],),
|
||||
[11, 12],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[:, 0],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)),
|
||||
},
|
||||
[np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)],
|
||||
([[11, 12], [21, 22], [31, 32]],),
|
||||
[11, 21, 31],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_constant_indexing_run_correctness(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
test_input,
|
||||
expected_output,
|
||||
default_compilation_configuration,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test correctness of results when running a compiled function with tensor operators"""
|
||||
circuit = compile_numpy_function(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
numpy_test_input = tuple(
|
||||
item if isinstance(item, int) else np.array(item, dtype=np.uint8) for item in test_input
|
||||
)
|
||||
|
||||
output = circuit.encrypt_run_decrypt(*numpy_test_input)
|
||||
expected = np.array(expected_output, dtype=np.uint8)
|
||||
|
||||
check_array_equality(output, expected)
|
||||
@@ -1,44 +0,0 @@
|
||||
"""Test module for convolution compilation and execution."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as hnp
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.values.tensors import EncryptedTensor
|
||||
from concrete.numpy.compile import compile_numpy_function
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_shape, weight_shape",
|
||||
[
|
||||
pytest.param((1, 1, 4, 4), (1, 1, 2, 2)),
|
||||
pytest.param((4, 3, 4, 4), (2, 3, 2, 2)),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("strides", [(2, 2)])
|
||||
@pytest.mark.parametrize("dilations", [(1, 1)])
|
||||
@pytest.mark.parametrize("has_bias", [True, False])
|
||||
def test_compile_and_run(
|
||||
input_shape, weight_shape, strides, dilations, has_bias, default_compilation_configuration
|
||||
):
|
||||
"""Test function to make sure compilation and execution of conv2d works properly"""
|
||||
if has_bias:
|
||||
bias = np.random.randint(0, 4, size=(weight_shape[0],))
|
||||
else:
|
||||
bias = None
|
||||
weight = np.random.randint(0, 4, size=weight_shape)
|
||||
|
||||
def conv(x):
|
||||
return hnp.conv2d(x, weight, bias, strides=strides, dilations=dilations)
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
conv,
|
||||
{"x": EncryptedTensor(Integer(64, False), input_shape)},
|
||||
[np.random.randint(0, 4, size=input_shape) for i in range(20)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
x = np.random.randint(0, 4, size=input_shape, dtype=np.uint8)
|
||||
expected = conv(x)
|
||||
result = compiler_engine.encrypt_run_decrypt(x)
|
||||
assert (expected == result).all()
|
||||
@@ -1,265 +0,0 @@
|
||||
"""Test module for memory operations."""
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types import UnsignedInteger
|
||||
from concrete.common.values import EncryptedTensor
|
||||
from concrete.numpy import compile_numpy_function
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,inputset,test_input,expected_output",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x.flatten(),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
|
||||
[[0, 1], [1, 2], [2, 3]],
|
||||
[0, 1, 1, 2, 2, 3],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.flatten(),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
|
||||
(numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
|
||||
(numpy.arange(720) % 10),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((1, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3,)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3,)) for _ in range(10)],
|
||||
[5, 9, 1],
|
||||
[[5, 9, 1]],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, 1)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3,)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3,)) for _ in range(10)],
|
||||
[5, 9, 1],
|
||||
[[5], [9], [1]],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, 2)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
|
||||
[[0, 1], [1, 2], [2, 3]],
|
||||
[[0, 1], [1, 2], [2, 3]],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, 2)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3)) for _ in range(10)],
|
||||
[[0, 1, 1], [2, 2, 3]],
|
||||
[[0, 1], [1, 2], [2, 3]],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape(-1),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
|
||||
[[0, 1], [1, 2], [2, 3]],
|
||||
[0, 1, 1, 2, 2, 3],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((2, 2, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(4, 3)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(4, 3)) for _ in range(10)],
|
||||
(numpy.arange(12) % 10).reshape((4, 3)),
|
||||
(numpy.arange(12) % 10).reshape((2, 2, 3)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((4, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 2, 3)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 2, 3)) for _ in range(10)],
|
||||
(numpy.arange(12) % 10).reshape((2, 2, 3)),
|
||||
(numpy.arange(12) % 10).reshape((4, 3)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, 2, 2)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3, 4)) for _ in range(10)],
|
||||
(numpy.arange(12) % 10).reshape((3, 4)),
|
||||
(numpy.arange(12) % 10).reshape((3, 2, 2)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, 4)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2, 2)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3, 2, 2)) for _ in range(10)],
|
||||
(numpy.arange(12) % 10).reshape((3, 2, 2)),
|
||||
(numpy.arange(12) % 10).reshape((3, 4)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((5, 3, 2)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(6, 5)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(6, 5)) for _ in range(10)],
|
||||
(numpy.arange(30) % 10).reshape((6, 5)),
|
||||
(numpy.arange(30) % 10).reshape((5, 3, 2)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((5, 6)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 5)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 5)) for _ in range(10)],
|
||||
(numpy.arange(30) % 10).reshape((2, 3, 5)),
|
||||
(numpy.arange(30) % 10).reshape((5, 6)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((6, 4, 30)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
|
||||
(numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
|
||||
(numpy.arange(720) % 10).reshape((6, 4, 30)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((2, 60, 6)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
|
||||
(numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
|
||||
(numpy.arange(720) % 10).reshape((2, 60, 6)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((6, 6, -1)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
|
||||
(numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
|
||||
(numpy.arange(144) % 10).reshape((6, 6, -1)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((6, -1, 12)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
|
||||
(numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
|
||||
(numpy.arange(144) % 10).reshape((6, -1, 12)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((-1, 18, 4)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
|
||||
(numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
|
||||
(numpy.arange(144) % 10).reshape((-1, 18, 4)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_memory_operation_run_correctness(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
test_input,
|
||||
expected_output,
|
||||
default_compilation_configuration,
|
||||
check_array_equality,
|
||||
):
|
||||
"""
|
||||
Test correctness of results when running a compiled function with memory operators.
|
||||
|
||||
e.g.,
|
||||
- flatten
|
||||
- reshape
|
||||
"""
|
||||
circuit = compile_numpy_function(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
actual = circuit.encrypt_run_decrypt(numpy.array(test_input, dtype=numpy.uint8))
|
||||
expected = numpy.array(expected_output, dtype=numpy.uint8)
|
||||
|
||||
check_array_equality(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,inputset,error,match",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x.reshape((-1, -1, 2)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4)) for _ in range(10)],
|
||||
ValueError,
|
||||
"shapes are not compatible (old shape (2, 3, 4), new shape (-1, -1, 2))",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, -1, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4)) for _ in range(10)],
|
||||
ValueError,
|
||||
"shapes are not compatible (old shape (2, 3, 4), new shape (3, -1, 3))",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_memory_operation_failed_compilation(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
error,
|
||||
match,
|
||||
default_compilation_configuration,
|
||||
):
|
||||
"""
|
||||
Test compilation failures of compiled function with memory operations.
|
||||
|
||||
e.g.,
|
||||
- reshape
|
||||
"""
|
||||
|
||||
with pytest.raises(error) as excinfo:
|
||||
compile_numpy_function(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
assert (
|
||||
str(excinfo.value) == match
|
||||
), f"""
|
||||
|
||||
Actual Output
|
||||
=============
|
||||
{excinfo.value}
|
||||
|
||||
Expected Output
|
||||
===============
|
||||
{match}
|
||||
|
||||
"""
|
||||
@@ -1,280 +0,0 @@
|
||||
"""Test file for user-friendly numpy compilation functions"""
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.debugging import format_operation_graph
|
||||
from concrete.numpy.np_fhe_compiler import NPFHECompiler
|
||||
|
||||
|
||||
def complicated_topology(x, y):
|
||||
"""Mix x in an intricated way."""
|
||||
intermediate = x + y
|
||||
x_p_1 = intermediate + 1
|
||||
x_p_2 = intermediate + 2
|
||||
x_p_3 = x_p_1 + x_p_2
|
||||
return (
|
||||
x_p_3.astype(numpy.int32),
|
||||
x_p_2.astype(numpy.int32),
|
||||
(x_p_2 + 3).astype(numpy.int32),
|
||||
x_p_3.astype(numpy.int32) + 67,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input_shape", [(), (3, 1, 2)])
|
||||
def test_np_fhe_compiler_op_graph(
|
||||
input_shape, default_compilation_configuration, check_array_equality
|
||||
):
|
||||
"""Test NPFHECompiler in two subtests."""
|
||||
subtest_np_fhe_compiler_1_input_op_graph(
|
||||
input_shape, default_compilation_configuration, check_array_equality
|
||||
)
|
||||
subtest_np_fhe_compiler_2_inputs_op_graph(
|
||||
input_shape, default_compilation_configuration, check_array_equality
|
||||
)
|
||||
|
||||
|
||||
def subtest_np_fhe_compiler_1_input_op_graph(
|
||||
input_shape, default_compilation_configuration, check_array_equality
|
||||
):
|
||||
"""test for NPFHECompiler on one input function"""
|
||||
|
||||
def function_to_compile(x):
|
||||
return complicated_topology(x, 0)
|
||||
|
||||
compiler = NPFHECompiler(
|
||||
function_to_compile,
|
||||
{"x": "encrypted"},
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
# For coverage when the OPGraph is not yet traced
|
||||
compiler._patch_op_graph_input_to_accept_any_integer_input() # pylint: disable=protected-access
|
||||
|
||||
assert compiler.compilation_configuration == default_compilation_configuration
|
||||
assert compiler.compilation_configuration is not default_compilation_configuration
|
||||
|
||||
for i in numpy.arange(5):
|
||||
i = numpy.ones(input_shape, dtype=numpy.int64) * i
|
||||
check_array_equality(compiler(i), function_to_compile(i))
|
||||
|
||||
# For coverage, check that we flush the inputset when we query the OPGraph
|
||||
current_op_graph = compiler.op_graph
|
||||
assert current_op_graph is not compiler.op_graph
|
||||
assert len(compiler._current_inputset) == 0 # pylint: disable=protected-access
|
||||
# For coverage, cover case where the current inputset is empty
|
||||
compiler._eval_on_current_inputset() # pylint: disable=protected-access
|
||||
|
||||
# Continue a bit more
|
||||
for i in numpy.arange(5, 10):
|
||||
i = numpy.ones(input_shape, dtype=numpy.int64) * i
|
||||
check_array_equality(compiler(i), function_to_compile(i))
|
||||
|
||||
if input_shape == ():
|
||||
assert (
|
||||
(got := format_operation_graph(compiler.op_graph))
|
||||
== """ %0 = 67 # ClearScalar<uint7>
|
||||
%1 = 2 # ClearScalar<uint2>
|
||||
%2 = 3 # ClearScalar<uint2>
|
||||
%3 = 1 # ClearScalar<uint1>
|
||||
%4 = x # EncryptedScalar<uint4>
|
||||
%5 = 0 # ClearScalar<uint1>
|
||||
%6 = add(%4, %5) # EncryptedScalar<uint4>
|
||||
%7 = add(%6, %1) # EncryptedScalar<uint4>
|
||||
%8 = add(%6, %3) # EncryptedScalar<uint4>
|
||||
%9 = astype(%7, dtype=int32) # EncryptedScalar<uint4>
|
||||
%10 = add(%7, %2) # EncryptedScalar<uint4>
|
||||
%11 = add(%8, %7) # EncryptedScalar<uint5>
|
||||
%12 = astype(%10, dtype=int32) # EncryptedScalar<uint4>
|
||||
%13 = astype(%11, dtype=int32) # EncryptedScalar<uint5>
|
||||
%14 = astype(%11, dtype=int32) # EncryptedScalar<uint5>
|
||||
%15 = add(%14, %0) # EncryptedScalar<uint7>
|
||||
(%13, %9, %12, %15)"""
|
||||
), got
|
||||
else:
|
||||
assert (
|
||||
(got := format_operation_graph(compiler.op_graph))
|
||||
== """ %0 = 67 # ClearScalar<uint7>
|
||||
%1 = 2 # ClearScalar<uint2>
|
||||
%2 = 3 # ClearScalar<uint2>
|
||||
%3 = 1 # ClearScalar<uint1>
|
||||
%4 = x # EncryptedTensor<uint4, shape=(3, 1, 2)>
|
||||
%5 = 0 # ClearScalar<uint1>
|
||||
%6 = add(%4, %5) # EncryptedTensor<uint4, shape=(3, 1, 2)>
|
||||
%7 = add(%6, %1) # EncryptedTensor<uint4, shape=(3, 1, 2)>
|
||||
%8 = add(%6, %3) # EncryptedTensor<uint4, shape=(3, 1, 2)>
|
||||
%9 = astype(%7, dtype=int32) # EncryptedTensor<uint4, shape=(3, 1, 2)>
|
||||
%10 = add(%7, %2) # EncryptedTensor<uint4, shape=(3, 1, 2)>
|
||||
%11 = add(%8, %7) # EncryptedTensor<uint5, shape=(3, 1, 2)>
|
||||
%12 = astype(%10, dtype=int32) # EncryptedTensor<uint4, shape=(3, 1, 2)>
|
||||
%13 = astype(%11, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
|
||||
%14 = astype(%11, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
|
||||
%15 = add(%14, %0) # EncryptedTensor<uint7, shape=(3, 1, 2)>
|
||||
(%13, %9, %12, %15)"""
|
||||
), got
|
||||
|
||||
|
||||
def subtest_np_fhe_compiler_2_inputs_op_graph(
|
||||
input_shape, default_compilation_configuration, check_array_equality
|
||||
):
|
||||
"""test for NPFHECompiler on two inputs function"""
|
||||
|
||||
compiler = NPFHECompiler(
|
||||
complicated_topology,
|
||||
{"x": "encrypted", "y": "clear"},
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
# For coverage when the OPGraph is not yet traced
|
||||
compiler._patch_op_graph_input_to_accept_any_integer_input() # pylint: disable=protected-access
|
||||
|
||||
assert compiler.compilation_configuration == default_compilation_configuration
|
||||
assert compiler.compilation_configuration is not default_compilation_configuration
|
||||
|
||||
for i, j in zip(numpy.arange(5), numpy.arange(5, 10)):
|
||||
i = numpy.ones(input_shape, dtype=numpy.int64) * i
|
||||
j = numpy.ones(input_shape, dtype=numpy.int64) * j
|
||||
check_array_equality(compiler(i, j), complicated_topology(i, j))
|
||||
|
||||
# For coverage, check that we flush the inputset when we query the OPGraph
|
||||
current_op_graph = compiler.op_graph
|
||||
assert current_op_graph is not compiler.op_graph
|
||||
assert len(compiler._current_inputset) == 0 # pylint: disable=protected-access
|
||||
# For coverage, cover case where the current inputset is empty
|
||||
compiler._eval_on_current_inputset() # pylint: disable=protected-access
|
||||
|
||||
# Continue a bit more
|
||||
for i, j in zip(numpy.arange(5, 10), numpy.arange(5)):
|
||||
i = numpy.ones(input_shape, dtype=numpy.int64) * i
|
||||
j = numpy.ones(input_shape, dtype=numpy.int64) * j
|
||||
check_array_equality(compiler(i, j), complicated_topology(i, j))
|
||||
|
||||
if input_shape == ():
|
||||
assert (
|
||||
(got := format_operation_graph(compiler.op_graph))
|
||||
== """ %0 = 67 # ClearScalar<uint7>
|
||||
%1 = 2 # ClearScalar<uint2>
|
||||
%2 = 3 # ClearScalar<uint2>
|
||||
%3 = 1 # ClearScalar<uint1>
|
||||
%4 = x # EncryptedScalar<uint4>
|
||||
%5 = y # ClearScalar<uint4>
|
||||
%6 = add(%4, %5) # EncryptedScalar<uint4>
|
||||
%7 = add(%6, %1) # EncryptedScalar<uint4>
|
||||
%8 = add(%6, %3) # EncryptedScalar<uint4>
|
||||
%9 = astype(%7, dtype=int32) # EncryptedScalar<uint4>
|
||||
%10 = add(%7, %2) # EncryptedScalar<uint5>
|
||||
%11 = add(%8, %7) # EncryptedScalar<uint5>
|
||||
%12 = astype(%10, dtype=int32) # EncryptedScalar<uint5>
|
||||
%13 = astype(%11, dtype=int32) # EncryptedScalar<uint5>
|
||||
%14 = astype(%11, dtype=int32) # EncryptedScalar<uint5>
|
||||
%15 = add(%14, %0) # EncryptedScalar<uint7>
|
||||
(%13, %9, %12, %15)"""
|
||||
), got
|
||||
else:
|
||||
assert (
|
||||
(got := format_operation_graph(compiler.op_graph))
|
||||
== """ %0 = 67 # ClearScalar<uint7>
|
||||
%1 = 2 # ClearScalar<uint2>
|
||||
%2 = 3 # ClearScalar<uint2>
|
||||
%3 = 1 # ClearScalar<uint1>
|
||||
%4 = x # EncryptedTensor<uint4, shape=(3, 1, 2)>
|
||||
%5 = y # ClearTensor<uint4, shape=(3, 1, 2)>
|
||||
%6 = add(%4, %5) # EncryptedTensor<uint4, shape=(3, 1, 2)>
|
||||
%7 = add(%6, %1) # EncryptedTensor<uint4, shape=(3, 1, 2)>
|
||||
%8 = add(%6, %3) # EncryptedTensor<uint4, shape=(3, 1, 2)>
|
||||
%9 = astype(%7, dtype=int32) # EncryptedTensor<uint4, shape=(3, 1, 2)>
|
||||
%10 = add(%7, %2) # EncryptedTensor<uint5, shape=(3, 1, 2)>
|
||||
%11 = add(%8, %7) # EncryptedTensor<uint5, shape=(3, 1, 2)>
|
||||
%12 = astype(%10, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
|
||||
%13 = astype(%11, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
|
||||
%14 = astype(%11, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
|
||||
%15 = add(%14, %0) # EncryptedTensor<uint7, shape=(3, 1, 2)>
|
||||
(%13, %9, %12, %15)"""
|
||||
), got
|
||||
|
||||
|
||||
def remaining_inputset_size(inputset_len):
|
||||
"""Small function to generate test cases below for remaining inputset length."""
|
||||
return inputset_len % NPFHECompiler.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inputset_len, expected_remaining_inputset_len",
|
||||
[
|
||||
(42, remaining_inputset_size(42)),
|
||||
(128, remaining_inputset_size(128)),
|
||||
(234, remaining_inputset_size(234)),
|
||||
],
|
||||
)
|
||||
def test_np_fhe_compiler_auto_flush(
|
||||
inputset_len,
|
||||
expected_remaining_inputset_len,
|
||||
default_compilation_configuration,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test the auto flush of NPFHECompiler once the inputset is 128 elements."""
|
||||
|
||||
def function_to_compile(x):
|
||||
return x // 2
|
||||
|
||||
compiler = NPFHECompiler(
|
||||
function_to_compile,
|
||||
{"x": "encrypted"},
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
for i in numpy.arange(inputset_len):
|
||||
check_array_equality(compiler(i), function_to_compile(i))
|
||||
|
||||
# Check the inputset was properly flushed
|
||||
assert (
|
||||
len(compiler._current_inputset) # pylint: disable=protected-access
|
||||
== expected_remaining_inputset_len
|
||||
)
|
||||
|
||||
|
||||
def test_np_fhe_compiler_full_compilation(default_compilation_configuration, check_array_equality):
|
||||
"""Test the case where we generate an FHE circuit."""
|
||||
|
||||
def function_to_compile(x):
|
||||
return x + 42
|
||||
|
||||
compiler = NPFHECompiler(
|
||||
function_to_compile,
|
||||
{"x": "encrypted"},
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
# For coverage
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
compiler.get_compiled_fhe_circuit()
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Requested FHECircuit but no OPGraph was compiled. "
|
||||
"Did you forget to evaluate NPFHECompiler over an inputset?"
|
||||
)
|
||||
|
||||
for i in numpy.arange(64):
|
||||
check_array_equality(compiler(i), function_to_compile(i))
|
||||
|
||||
fhe_circuit = compiler.get_compiled_fhe_circuit()
|
||||
|
||||
for i in range(64):
|
||||
assert fhe_circuit.encrypt_run_decrypt(i) == function_to_compile(i)
|
||||
|
||||
|
||||
def test_np_fhe_compiler_compile_on_inputset(default_compilation_configuration):
|
||||
"""Test the case where we generate an FHE circuit with a single call."""
|
||||
|
||||
def function_to_compile(x):
|
||||
return x + 42
|
||||
|
||||
compiler = NPFHECompiler(
|
||||
function_to_compile,
|
||||
{"x": "encrypted"},
|
||||
default_compilation_configuration,
|
||||
)
|
||||
circuit = compiler.compile_on_inputset(numpy.arange(64))
|
||||
|
||||
for i in range(64):
|
||||
assert circuit.encrypt_run_decrypt(i) == function_to_compile(i)
|
||||
@@ -1,111 +0,0 @@
|
||||
"""Test file for numpy dtype helpers"""
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.floats import Float
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.numpy.np_dtypes_helpers import (
|
||||
convert_base_data_type_to_numpy_dtype,
|
||||
convert_numpy_dtype_to_base_data_type,
|
||||
get_base_value_for_numpy_or_python_constant_data,
|
||||
get_constructor_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"numpy_dtype,expected_common_type",
|
||||
[
|
||||
pytest.param(numpy.int8, Integer(8, is_signed=True)),
|
||||
pytest.param("int8", Integer(8, is_signed=True)),
|
||||
pytest.param(numpy.int16, Integer(16, is_signed=True)),
|
||||
pytest.param("int16", Integer(16, is_signed=True)),
|
||||
pytest.param(numpy.int32, Integer(32, is_signed=True)),
|
||||
pytest.param("int32", Integer(32, is_signed=True)),
|
||||
pytest.param(numpy.int64, Integer(64, is_signed=True)),
|
||||
pytest.param("int64", Integer(64, is_signed=True)),
|
||||
pytest.param(numpy.uint8, Integer(8, is_signed=False)),
|
||||
pytest.param("uint8", Integer(8, is_signed=False)),
|
||||
pytest.param(numpy.uint16, Integer(16, is_signed=False)),
|
||||
pytest.param("uint16", Integer(16, is_signed=False)),
|
||||
pytest.param(numpy.uint32, Integer(32, is_signed=False)),
|
||||
pytest.param("uint32", Integer(32, is_signed=False)),
|
||||
pytest.param(numpy.uint64, Integer(64, is_signed=False)),
|
||||
pytest.param("uint64", Integer(64, is_signed=False)),
|
||||
pytest.param(numpy.float32, Float(32)),
|
||||
pytest.param("float32", Float(32)),
|
||||
pytest.param(numpy.float64, Float(64)),
|
||||
pytest.param("float64", Float(64)),
|
||||
pytest.param("complex64", None, marks=pytest.mark.xfail(strict=True, raises=ValueError)),
|
||||
],
|
||||
)
|
||||
def test_convert_numpy_dtype_to_base_data_type(numpy_dtype, expected_common_type):
|
||||
"""Test function for convert_numpy_dtype_to_base_data_type"""
|
||||
assert convert_numpy_dtype_to_base_data_type(numpy_dtype) == expected_common_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_dtype,expected_numpy_dtype",
|
||||
[
|
||||
pytest.param(Integer(7, is_signed=False), numpy.uint32),
|
||||
pytest.param(Integer(7, is_signed=True), numpy.int32),
|
||||
pytest.param(Integer(32, is_signed=True), numpy.int32),
|
||||
pytest.param(Integer(64, is_signed=True), numpy.int64),
|
||||
pytest.param(Integer(32, is_signed=False), numpy.uint32),
|
||||
pytest.param(Integer(64, is_signed=False), numpy.uint64),
|
||||
pytest.param(Float(32), numpy.float32),
|
||||
pytest.param(Float(64), numpy.float64),
|
||||
pytest.param(
|
||||
Integer(128, is_signed=True),
|
||||
None,
|
||||
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_convert_common_dtype_to_numpy_dtype(common_dtype, expected_numpy_dtype):
|
||||
"""Test function for convert_common_dtype_to_numpy_dtype"""
|
||||
assert expected_numpy_dtype == convert_base_data_type_to_numpy_dtype(common_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"constant_data,expected_constructor",
|
||||
[
|
||||
(10, int),
|
||||
(42.0, float),
|
||||
(numpy.int32(10), numpy.int32),
|
||||
],
|
||||
)
|
||||
def test_get_constructor_for_numpy_or_python_constant_data(constant_data, expected_constructor):
|
||||
"""Test function for get_constructor_for_numpy_or_python_constant_data"""
|
||||
|
||||
assert expected_constructor == get_constructor_for_numpy_or_python_constant_data(constant_data)
|
||||
|
||||
|
||||
def test_get_constructor_for_numpy_arrays(test_helpers):
|
||||
"""Test function for get_constructor_for_numpy_or_python_constant_data for numpy arrays."""
|
||||
|
||||
arrays = [
|
||||
numpy.array([[0, 1], [3, 4]], dtype=numpy.uint64),
|
||||
numpy.array([[0, 1], [3, 4]], dtype=numpy.float64),
|
||||
]
|
||||
|
||||
def get_expected_constructor(array: numpy.ndarray):
|
||||
return lambda x: numpy.full(array.shape, x, dtype=array.dtype)
|
||||
|
||||
expected_constructors = [get_expected_constructor(array) for array in arrays]
|
||||
|
||||
for array, expected_constructor in zip(arrays, expected_constructors):
|
||||
assert test_helpers.python_functions_are_equal_or_equivalent(
|
||||
expected_constructor, get_constructor_for_numpy_or_python_constant_data(array)
|
||||
)
|
||||
|
||||
|
||||
def test_get_base_value_for_numpy_or_python_constant_data_with_list():
|
||||
"""Test function for get_base_value_for_numpy_or_python_constant_data called with list"""
|
||||
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match="Unsupported constant data of type list "
|
||||
"\\(if you meant to use a list as an array, please use numpy\\.array instead\\)",
|
||||
):
|
||||
get_base_value_for_numpy_or_python_constant_data([1, 2, 3])
|
||||
@@ -1,96 +0,0 @@
|
||||
"""Test file for numpy inputset helpers"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from concrete.common.compilation import CompilationConfiguration
|
||||
from concrete.common.data_types import Float, UnsignedInteger
|
||||
from concrete.common.data_types.base import BaseDataType
|
||||
from concrete.common.values import BaseValue, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy.np_inputset_helpers import _generate_random_inputset
|
||||
|
||||
|
||||
def test_generate_random_inputset():
|
||||
"""Test function for generate_random_inputset"""
|
||||
|
||||
inputset = _generate_random_inputset(
|
||||
{
|
||||
"x1": EncryptedScalar(UnsignedInteger(4)),
|
||||
"x2": EncryptedTensor(UnsignedInteger(4), shape=(2, 3)),
|
||||
"x3": EncryptedScalar(Float(64)),
|
||||
"x4": EncryptedTensor(Float(64), shape=(3, 2)),
|
||||
},
|
||||
CompilationConfiguration(random_inputset_samples=15),
|
||||
)
|
||||
|
||||
assert isinstance(inputset, list)
|
||||
assert len(inputset) == 15
|
||||
|
||||
for sample in inputset:
|
||||
assert isinstance(sample, tuple)
|
||||
assert len(sample) == 4
|
||||
|
||||
assert isinstance(sample[0], int)
|
||||
assert 0 <= sample[0] < 2 ** 4
|
||||
|
||||
assert isinstance(sample[1], np.ndarray)
|
||||
assert sample[1].dtype == np.uint64
|
||||
assert sample[1].shape == (2, 3)
|
||||
assert (sample[1] >= 0).all()
|
||||
assert (sample[1] < 2 ** 4).all()
|
||||
|
||||
assert isinstance(sample[2], float)
|
||||
assert 0 <= sample[2] < 1
|
||||
|
||||
assert isinstance(sample[3], np.ndarray)
|
||||
assert sample[3].dtype == np.float64
|
||||
assert sample[3].shape == (3, 2)
|
||||
assert (sample[3] >= 0).all()
|
||||
assert (sample[3] < 1).all()
|
||||
|
||||
|
||||
def test_fail_generate_random_inputset():
|
||||
"""Test function for failed generate_random_inputset"""
|
||||
|
||||
class MockDtype(BaseDataType):
|
||||
"""Unsupported dtype to check error messages"""
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
return False
|
||||
|
||||
def __str__(self):
|
||||
return "MockDtype"
|
||||
|
||||
class MockValue(BaseValue):
|
||||
"""Unsupported value to check error messages"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(MockDtype(), is_encrypted=True)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return False
|
||||
|
||||
def __str__(self):
|
||||
return "MockValue"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
try:
|
||||
_generate_random_inputset(
|
||||
{"x": MockValue()},
|
||||
CompilationConfiguration(random_inputset_samples=15),
|
||||
)
|
||||
except Exception as error:
|
||||
expected = "Random inputset cannot be generated for MockValue parameters"
|
||||
assert str(error) == expected
|
||||
raise
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
try:
|
||||
_generate_random_inputset(
|
||||
{"x": EncryptedScalar(MockDtype())},
|
||||
CompilationConfiguration(random_inputset_samples=15),
|
||||
)
|
||||
except Exception as error:
|
||||
expected = "Random inputset cannot be generated for parameters of type MockDtype"
|
||||
assert str(error) == expected
|
||||
raise
|
||||
@@ -1,110 +0,0 @@
|
||||
"""Test file for numpy mlir converter"""
|
||||
|
||||
import math
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as hnp
|
||||
from concrete.common.representation.intermediate import GenericFunction
|
||||
from concrete.numpy.np_mlir_converter import generate_deduplicated_tables
|
||||
|
||||
|
||||
def multi_tlu_func(x, cst):
|
||||
"""Multi TLU function"""
|
||||
y = x + cst
|
||||
return y.astype(numpy.int32)
|
||||
|
||||
|
||||
RESNET_BIGGEST_SHAPE = (64, 112, 112)
|
||||
RESNET_BIGGEST_SIZE = math.prod(RESNET_BIGGEST_SHAPE)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,expected_number_of_tables",
|
||||
[
|
||||
(
|
||||
lambda x: multi_tlu_func(x, numpy.zeros(RESNET_BIGGEST_SHAPE, dtype=numpy.float64)),
|
||||
1,
|
||||
),
|
||||
(
|
||||
lambda x: multi_tlu_func(
|
||||
x,
|
||||
numpy.arange(RESNET_BIGGEST_SIZE, dtype=numpy.float64).reshape(
|
||||
RESNET_BIGGEST_SHAPE
|
||||
),
|
||||
),
|
||||
RESNET_BIGGEST_SIZE,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_generate_deduplicated_tables(
|
||||
function, expected_number_of_tables, default_compilation_configuration
|
||||
):
|
||||
"""Test function for generate_deduplicated_tables"""
|
||||
op_graph = hnp.compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function,
|
||||
{"x": hnp.EncryptedTensor(hnp.Integer(7, False), RESNET_BIGGEST_SHAPE)},
|
||||
(i * numpy.ones(RESNET_BIGGEST_SHAPE, dtype=numpy.int32) for i in range(128)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
univariate_function_nodes = [
|
||||
node for node in op_graph.graph.nodes() if isinstance(node, GenericFunction)
|
||||
]
|
||||
|
||||
assert len(univariate_function_nodes) == 1
|
||||
|
||||
tlu_node = univariate_function_nodes[0]
|
||||
|
||||
deduplication_result = generate_deduplicated_tables(
|
||||
tlu_node, op_graph.get_ordered_preds(tlu_node)
|
||||
)
|
||||
|
||||
assert len(deduplication_result) == expected_number_of_tables
|
||||
|
||||
|
||||
def test_deduplicated_tables_correctness(default_compilation_configuration):
|
||||
"""Check the deduplicated tables are the expected ones"""
|
||||
|
||||
tensor_shape = (2, 2)
|
||||
|
||||
op_graph = hnp.compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
lambda x: multi_tlu_func(x, numpy.arange(4, dtype=numpy.float64).reshape(tensor_shape)),
|
||||
{"x": hnp.EncryptedTensor(hnp.Integer(2, False), tensor_shape)},
|
||||
(i * numpy.ones(tensor_shape, dtype=numpy.int32) for i in range(4)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
univariate_function_nodes = [
|
||||
node for node in op_graph.graph.nodes() if isinstance(node, GenericFunction)
|
||||
]
|
||||
|
||||
assert len(univariate_function_nodes) == 1
|
||||
|
||||
tlu_node = univariate_function_nodes[0]
|
||||
|
||||
deduplication_result = generate_deduplicated_tables(
|
||||
tlu_node, op_graph.get_ordered_preds(tlu_node)
|
||||
)
|
||||
|
||||
expected_result = tuple(
|
||||
(
|
||||
numpy.arange(i, 4 + i, dtype=numpy.int32),
|
||||
[
|
||||
numpy.unravel_index(i, tensor_shape),
|
||||
],
|
||||
)
|
||||
for i in range(4)
|
||||
)
|
||||
|
||||
assert len(deduplication_result) == len(expected_result)
|
||||
for computed_array, computed_idx in deduplication_result:
|
||||
for expected_array, expected_idx in expected_result:
|
||||
if numpy.array_equal(computed_array, expected_array) and computed_idx == expected_idx:
|
||||
break
|
||||
else:
|
||||
raise AssertionError(
|
||||
f"Could not find {(computed_array, computed_idx)} "
|
||||
f"in expected_result: {expected_result}"
|
||||
)
|
||||
@@ -1,991 +0,0 @@
|
||||
"""Test file for numpy tracing"""
|
||||
|
||||
import inspect
|
||||
|
||||
import networkx as nx
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.dtypes_helpers import broadcast_shapes
|
||||
from concrete.common.data_types.floats import Float
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import format_operation_graph
|
||||
from concrete.common.representation import intermediate as ir
|
||||
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
|
||||
OPERATIONS_TO_TEST = [ir.Add, ir.Sub, ir.Mul]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"operation",
|
||||
OPERATIONS_TO_TEST,
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"x",
|
||||
[
|
||||
pytest.param(EncryptedScalar(Integer(64, is_signed=False)), id="x: Encrypted uint"),
|
||||
pytest.param(
|
||||
EncryptedScalar(Integer(64, is_signed=True)),
|
||||
id="x: Encrypted int",
|
||||
),
|
||||
pytest.param(
|
||||
ClearScalar(Integer(64, is_signed=False)),
|
||||
id="x: Clear uint",
|
||||
),
|
||||
pytest.param(
|
||||
ClearScalar(Integer(64, is_signed=True)),
|
||||
id="x: Clear int",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"y",
|
||||
[
|
||||
pytest.param(EncryptedScalar(Integer(64, is_signed=False)), id="y: Encrypted uint"),
|
||||
pytest.param(
|
||||
EncryptedScalar(Integer(64, is_signed=True)),
|
||||
id="y: Encrypted int",
|
||||
),
|
||||
pytest.param(
|
||||
ClearScalar(Integer(64, is_signed=False)),
|
||||
id="y: Clear uint",
|
||||
),
|
||||
pytest.param(
|
||||
ClearScalar(Integer(64, is_signed=True)),
|
||||
id="y: Clear int",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_numpy_tracing_binary_op(operation, x, y, test_helpers):
|
||||
"Test numpy tracing a binary operation (in the supported ops)"
|
||||
|
||||
# Remark that the functions here have a common structure (which is
|
||||
# 2x op y), such that creating further the ref_graph is easy, by
|
||||
# hand
|
||||
def simple_add_function(x, y):
|
||||
z = x + x
|
||||
return z + y
|
||||
|
||||
def simple_sub_function(x, y):
|
||||
z = x + x
|
||||
return z - y
|
||||
|
||||
def simple_mul_function(x, y):
|
||||
z = x + x
|
||||
return z * y
|
||||
|
||||
assert operation in OPERATIONS_TO_TEST, f"unknown operation {operation}"
|
||||
if operation == ir.Add:
|
||||
function_to_compile = simple_add_function
|
||||
elif operation == ir.Sub:
|
||||
function_to_compile = simple_sub_function
|
||||
elif operation == ir.Mul:
|
||||
function_to_compile = simple_mul_function
|
||||
|
||||
op_graph = tracing.trace_numpy_function(function_to_compile, {"x": x, "y": y})
|
||||
|
||||
ref_graph = nx.MultiDiGraph()
|
||||
|
||||
input_x = ir.Input(x, input_name="x", program_input_idx=0)
|
||||
input_y = ir.Input(y, input_name="y", program_input_idx=1)
|
||||
|
||||
add_node_z = ir.Add(
|
||||
(
|
||||
input_x.outputs[0],
|
||||
input_x.outputs[0],
|
||||
)
|
||||
)
|
||||
|
||||
returned_final_node = operation(
|
||||
(
|
||||
add_node_z.outputs[0],
|
||||
input_y.outputs[0],
|
||||
)
|
||||
)
|
||||
|
||||
ref_graph.add_node(input_x)
|
||||
ref_graph.add_node(input_y)
|
||||
ref_graph.add_node(add_node_z)
|
||||
ref_graph.add_node(returned_final_node)
|
||||
|
||||
ref_graph.add_edge(input_x, add_node_z, input_idx=0, output_idx=0)
|
||||
ref_graph.add_edge(input_x, add_node_z, input_idx=1, output_idx=0)
|
||||
|
||||
ref_graph.add_edge(add_node_z, returned_final_node, input_idx=0, output_idx=0)
|
||||
ref_graph.add_edge(input_y, returned_final_node, input_idx=1, output_idx=0)
|
||||
|
||||
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
|
||||
|
||||
|
||||
def test_numpy_tracing_tensors():
|
||||
"Test numpy tracing tensors"
|
||||
|
||||
def all_operations(x):
|
||||
intermediate = x + numpy.array([[1, 2], [3, 4]])
|
||||
intermediate = numpy.array([[5, 6], [7, 8]]) + intermediate
|
||||
|
||||
intermediate = numpy.array([[100, 200], [300, 400]]) - intermediate
|
||||
intermediate = intermediate - numpy.array([[10, 20], [30, 40]])
|
||||
|
||||
intermediate = intermediate * numpy.array([[1, 2], [2, 1]])
|
||||
intermediate = numpy.array([[2, 1], [1, 2]]) * intermediate
|
||||
|
||||
return intermediate
|
||||
|
||||
op_graph = tracing.trace_numpy_function(
|
||||
all_operations, {"x": EncryptedTensor(Integer(32, True), shape=(2, 2))}
|
||||
)
|
||||
|
||||
expected = """ %0 = [[2 1] [1 2]] # ClearTensor<uint2, shape=(2, 2)>
|
||||
%1 = [[1 2] [2 1]] # ClearTensor<uint2, shape=(2, 2)>
|
||||
%2 = [[10 20] [30 40]] # ClearTensor<uint6, shape=(2, 2)>
|
||||
%3 = [[100 200] [300 400]] # ClearTensor<uint9, shape=(2, 2)>
|
||||
%4 = [[5 6] [7 8]] # ClearTensor<uint4, shape=(2, 2)>
|
||||
%5 = x # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%6 = [[1 2] [3 4]] # ClearTensor<uint3, shape=(2, 2)>
|
||||
%7 = add(%5, %6) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%8 = add(%4, %7) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%9 = sub(%3, %8) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%10 = sub(%9, %2) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%11 = mul(%10, %1) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%12 = mul(%0, %11) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
return %12""" # noqa: E501
|
||||
|
||||
assert format_operation_graph(op_graph) == expected, format_operation_graph(op_graph)
|
||||
|
||||
|
||||
def test_numpy_explicit_tracing_tensors():
|
||||
"Test numpy tracing tensors using explicit operations"
|
||||
|
||||
def all_explicit_operations(x):
|
||||
intermediate = numpy.add(x, numpy.array([[1, 2], [3, 4]]))
|
||||
intermediate = numpy.add(numpy.array([[5, 6], [7, 8]]), intermediate)
|
||||
|
||||
intermediate = numpy.subtract(numpy.array([[100, 200], [300, 400]]), intermediate)
|
||||
intermediate = numpy.subtract(intermediate, numpy.array([[10, 20], [30, 40]]))
|
||||
|
||||
intermediate = numpy.multiply(intermediate, numpy.array([[1, 2], [2, 1]]))
|
||||
intermediate = numpy.multiply(numpy.array([[2, 1], [1, 2]]), intermediate)
|
||||
|
||||
return intermediate
|
||||
|
||||
op_graph = tracing.trace_numpy_function(
|
||||
all_explicit_operations, {"x": EncryptedTensor(Integer(32, True), shape=(2, 2))}
|
||||
)
|
||||
|
||||
expected = """ %0 = [[2 1] [1 2]] # ClearTensor<uint2, shape=(2, 2)>
|
||||
%1 = [[1 2] [2 1]] # ClearTensor<uint2, shape=(2, 2)>
|
||||
%2 = [[10 20] [30 40]] # ClearTensor<uint6, shape=(2, 2)>
|
||||
%3 = [[100 200] [300 400]] # ClearTensor<uint9, shape=(2, 2)>
|
||||
%4 = [[5 6] [7 8]] # ClearTensor<uint4, shape=(2, 2)>
|
||||
%5 = x # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%6 = [[1 2] [3 4]] # ClearTensor<uint3, shape=(2, 2)>
|
||||
%7 = add(%5, %6) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%8 = add(%4, %7) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%9 = sub(%3, %8) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%10 = sub(%9, %2) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%11 = mul(%10, %1) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%12 = mul(%0, %11) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
return %12""" # noqa: E501
|
||||
|
||||
assert format_operation_graph(op_graph) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"x_shape,y_shape",
|
||||
[
|
||||
pytest.param((), ()),
|
||||
pytest.param((3,), ()),
|
||||
pytest.param((3,), (1,)),
|
||||
pytest.param((3,), (2,), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((3,), (3,)),
|
||||
pytest.param((2, 3), ()),
|
||||
pytest.param((2, 3), (1,)),
|
||||
pytest.param((2, 3), (2,), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 3), (3,)),
|
||||
pytest.param((2, 3), (1, 1)),
|
||||
pytest.param((2, 3), (2, 1)),
|
||||
pytest.param((2, 3), (3, 1), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 3), (1, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 3), (2, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 3), (3, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 3), (1, 3)),
|
||||
pytest.param((2, 3), (2, 3)),
|
||||
pytest.param((2, 3), (3, 3), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 1, 3), (1, 1, 1)),
|
||||
pytest.param((2, 1, 3), (1, 4, 1)),
|
||||
pytest.param((2, 1, 3), (2, 4, 3)),
|
||||
],
|
||||
)
|
||||
def test_numpy_tracing_broadcasted_tensors(x_shape, y_shape):
|
||||
"""Test numpy tracing broadcasted tensors"""
|
||||
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
op_graph = tracing.trace_numpy_function(
|
||||
f,
|
||||
{
|
||||
"x": EncryptedTensor(Integer(3, True), shape=x_shape),
|
||||
"y": EncryptedTensor(Integer(3, True), shape=y_shape),
|
||||
},
|
||||
)
|
||||
|
||||
assert op_graph.input_nodes[0].outputs[0].shape == x_shape
|
||||
assert op_graph.input_nodes[1].outputs[0].shape == y_shape
|
||||
assert op_graph.output_nodes[0].outputs[0].shape == broadcast_shapes(x_shape, y_shape)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples",
|
||||
[
|
||||
(
|
||||
lambda x: x.astype(numpy.int32),
|
||||
Integer(32, is_signed=True),
|
||||
[
|
||||
(14, numpy.int32(14)),
|
||||
(1.5, numpy.int32(1)),
|
||||
(2.0, numpy.int32(2)),
|
||||
(-1.5, numpy.int32(-1)),
|
||||
(2 ** 31 - 1, numpy.int32(2 ** 31 - 1)),
|
||||
(-(2 ** 31), numpy.int32(-(2 ** 31))),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x.astype(numpy.uint32),
|
||||
Integer(32, is_signed=False),
|
||||
[
|
||||
(14, numpy.uint32(14)),
|
||||
(1.5, numpy.uint32(1)),
|
||||
(2.0, numpy.uint32(2)),
|
||||
(2 ** 32 - 1, numpy.uint32(2 ** 32 - 1)),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x.astype(numpy.int64),
|
||||
Integer(64, is_signed=True),
|
||||
[
|
||||
(14, numpy.int64(14)),
|
||||
(1.5, numpy.int64(1)),
|
||||
(2.0, numpy.int64(2)),
|
||||
(-1.5, numpy.int64(-1)),
|
||||
(2 ** 63 - 1, numpy.int64(2 ** 63 - 1)),
|
||||
(-(2 ** 63), numpy.int64(-(2 ** 63))),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x.astype(numpy.uint64),
|
||||
Integer(64, is_signed=False),
|
||||
[
|
||||
(14, numpy.uint64(14)),
|
||||
(1.5, numpy.uint64(1)),
|
||||
(2.0, numpy.uint64(2)),
|
||||
(2 ** 64 - 1, numpy.uint64(2 ** 64 - 1)),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x.astype(numpy.float64),
|
||||
Float(64),
|
||||
[
|
||||
(14, numpy.float64(14.0)),
|
||||
(1.5, numpy.float64(1.5)),
|
||||
(2.0, numpy.float64(2.0)),
|
||||
(-1.5, numpy.float64(-1.5)),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x.astype(numpy.float32),
|
||||
Float(32),
|
||||
[
|
||||
(14, numpy.float32(14.0)),
|
||||
(1.5, numpy.float32(1.5)),
|
||||
(2.0, numpy.float32(2.0)),
|
||||
(-1.5, numpy.float32(-1.5)),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tracing_astype(
|
||||
function_to_trace, op_graph_expected_output_type, input_and_expected_output_tuples
|
||||
):
|
||||
"""Test function for NPTracer.astype"""
|
||||
for input_, expected_output in input_and_expected_output_tuples:
|
||||
input_value = (
|
||||
EncryptedScalar(Integer(64, is_signed=True))
|
||||
if isinstance(input_, int)
|
||||
else EncryptedScalar(Float(64))
|
||||
)
|
||||
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
|
||||
output_node = op_graph.output_nodes[0]
|
||||
assert op_graph_expected_output_type == output_node.outputs[0].dtype
|
||||
|
||||
node_results = op_graph.evaluate({0: numpy.array(input_)})
|
||||
evaluated_output = node_results[output_node]
|
||||
assert evaluated_output.dtype == expected_output.dtype
|
||||
assert expected_output == evaluated_output
|
||||
|
||||
|
||||
def test_tracing_astype_single_element_array_corner_case(check_array_equality):
|
||||
"""Test corner case where an array could be transformed to its scalar element"""
|
||||
a = numpy.array([1], dtype=numpy.float64)
|
||||
|
||||
op_graph = tracing.trace_numpy_function(
|
||||
lambda x: x.astype(numpy.int32), {"x": EncryptedTensor(Float(64), (1,))}
|
||||
)
|
||||
|
||||
eval_result = op_graph(a)
|
||||
check_array_equality(eval_result, numpy.array([1], dtype=numpy.int32))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,inputs,expected_output_node,expected_output_value",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: numpy.dot(x, y),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(7, is_signed=False), shape=(10,)),
|
||||
"y": EncryptedTensor(Integer(7, is_signed=False), shape=(10,)),
|
||||
},
|
||||
ir.Dot,
|
||||
EncryptedScalar(Integer(32, False)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: numpy.dot(x, y),
|
||||
{
|
||||
"x": EncryptedTensor(Float(64), shape=(10,)),
|
||||
"y": EncryptedTensor(Float(64), shape=(10,)),
|
||||
},
|
||||
ir.Dot,
|
||||
EncryptedScalar(Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: numpy.dot(x, y),
|
||||
{
|
||||
"x": ClearTensor(Integer(64, is_signed=True), shape=(6,)),
|
||||
"y": ClearTensor(Integer(64, is_signed=True), shape=(6,)),
|
||||
},
|
||||
ir.Dot,
|
||||
ClearScalar(Integer(64, is_signed=True)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.dot(x, numpy.array([1, 2, 3, 4, 5], dtype=numpy.int64)),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(64, is_signed=True), shape=(5,)),
|
||||
},
|
||||
ir.Dot,
|
||||
EncryptedScalar(Integer(64, True)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.dot(numpy.array([1, 2, 3, 4, 5], dtype=numpy.int64)),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(64, is_signed=True), shape=(5,)),
|
||||
},
|
||||
ir.Dot,
|
||||
EncryptedScalar(Integer(64, True)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_trace_numpy_dot(function_to_trace, inputs, expected_output_node, expected_output_value):
|
||||
"""Function to test dot tracing"""
|
||||
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert len(op_graph.output_nodes) == 1
|
||||
assert isinstance(op_graph.output_nodes[0], expected_output_node)
|
||||
assert len(op_graph.output_nodes[0].outputs) == 1
|
||||
assert op_graph.output_nodes[0].outputs[0] == expected_output_value
|
||||
|
||||
|
||||
@pytest.mark.parametrize("np_function", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC)
|
||||
def test_nptracer_get_tracing_func_for_np_functions(np_function):
|
||||
"""Test NPTracer get_tracing_func_for_np_function"""
|
||||
|
||||
expected_tracing_func = tracing.NPTracer.UFUNC_ROUTING[np_function]
|
||||
|
||||
assert tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func
|
||||
|
||||
|
||||
def test_nptracer_get_tracing_func_for_np_functions_not_implemented():
|
||||
"""Check NPTracer in case of not-implemented function"""
|
||||
with pytest.raises(NotImplementedError) as excinfo:
|
||||
tracing.NPTracer.get_tracing_func_for_np_function(numpy.conjugate)
|
||||
|
||||
assert "NPTracer does not yet manage the following func: conjugate" in str(excinfo.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"operation,exception_type,match",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x + "fail",
|
||||
TypeError,
|
||||
"unsupported operand type(s) for +: 'NPTracer' and 'str'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" + x,
|
||||
TypeError,
|
||||
'can only concatenate str (not "NPTracer") to str',
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x - "fail",
|
||||
TypeError,
|
||||
"unsupported operand type(s) for -: 'NPTracer' and 'str'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" - x,
|
||||
TypeError,
|
||||
"unsupported operand type(s) for -: 'str' and 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x * "fail",
|
||||
TypeError,
|
||||
"can't multiply sequence by non-int of type 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" * x,
|
||||
TypeError,
|
||||
"can't multiply sequence by non-int of type 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x / "fail",
|
||||
TypeError,
|
||||
"unsupported operand type(s) for /: 'NPTracer' and 'str'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" / x,
|
||||
TypeError,
|
||||
"unsupported operand type(s) for /: 'str' and 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x // "fail",
|
||||
TypeError,
|
||||
"unsupported operand type(s) for //: 'NPTracer' and 'str'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" // x,
|
||||
TypeError,
|
||||
"unsupported operand type(s) for //: 'str' and 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x / y, NotImplementedError, "Can't manage binary operator truediv"
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x // y, NotImplementedError, "Can't manage binary operator floordiv"
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_nptracer_unsupported_operands(operation, exception_type, match):
|
||||
"""Test cases where NPTracer cannot be used with other operands."""
|
||||
tracers = [
|
||||
tracing.NPTracer([], ir.Input(ClearScalar(Integer(32, True)), param_name, idx), 0)
|
||||
for idx, param_name in enumerate(inspect.signature(operation).parameters.keys())
|
||||
]
|
||||
|
||||
with pytest.raises(exception_type) as exc_info:
|
||||
_ = operation(*tracers)
|
||||
|
||||
assert match in str(exc_info)
|
||||
|
||||
|
||||
def subtest_tracing_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form numpy.something"""
|
||||
for input_value, input_, expected_output in input_value_input_and_expected_output_tuples:
|
||||
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
|
||||
output_node = op_graph.output_nodes[0]
|
||||
|
||||
node_results = op_graph.evaluate({0: input_})
|
||||
evaluated_output = node_results[output_node]
|
||||
assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output)
|
||||
check_array_equality(evaluated_output, expected_output)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,input_value_input_and_expected_output_tuples",
|
||||
[
|
||||
(
|
||||
lambda x: numpy.transpose(x),
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4).reshape(2, 2),
|
||||
numpy.array([[0, 2], [1, 3]]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4, 8).reshape(2, 2),
|
||||
numpy.array([[4, 6], [5, 7]]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
numpy.int64(42),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: numpy.transpose(x) + 42,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(42, 57).reshape(3, 5).transpose(),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
numpy.int64(84),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: numpy.ravel(x),
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4),
|
||||
numpy.array([0, 1, 2, 3]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4).reshape(2, 2),
|
||||
numpy.array([0, 1, 2, 3]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
numpy.array([42], dtype=numpy.int64),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: numpy.reshape(x, (5, 3)) + 42,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(42, 57).reshape(5, 3),
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tracing_numpy_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form numpy.something"""
|
||||
subtest_tracing_calls(
|
||||
function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,input_value_input_and_expected_output_tuples",
|
||||
[
|
||||
(
|
||||
lambda x: x.transpose() + 42,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(42, 57).reshape(3, 5).transpose(),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
numpy.int64(84),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x.ravel(),
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4),
|
||||
numpy.array([0, 1, 2, 3]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4).reshape(2, 2),
|
||||
numpy.array([0, 1, 2, 3]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
numpy.array([42], dtype=numpy.int64),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x.reshape((5, 3)) + 42,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(42, 57).reshape(5, 3),
|
||||
),
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((5, 3)),
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
None,
|
||||
)
|
||||
],
|
||||
marks=pytest.mark.xfail(strict=True, raises=ValueError),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.flatten(),
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15),
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: abs(x),
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: +x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: -x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
(numpy.arange(15).reshape(3, 5)) * (-1),
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: ~x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5).__invert__(),
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x << 3,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5) * 8,
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x >> 1,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5) // 2,
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 2 << x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5) % 8,
|
||||
2 << (numpy.arange(15).reshape(3, 5) % 8),
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 256 >> x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5) % 8,
|
||||
256 >> (numpy.arange(15).reshape(3, 5) % 8),
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x > 4,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5) > 4,
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x < 5,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5) < 5,
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x <= 7,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5) <= 7,
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x >= 9,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5) >= 9,
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x == 11,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5) == 11,
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x != 12,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(15).reshape(3, 5) != 12,
|
||||
)
|
||||
],
|
||||
),
|
||||
# Remove misplaced-comparison-constant because precisely, we want to be sure it works fine
|
||||
# pylint: disable=misplaced-comparison-constant
|
||||
pytest.param(
|
||||
lambda x: 4 > x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
4 > numpy.arange(15).reshape(3, 5),
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 5 < x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
5 < numpy.arange(15).reshape(3, 5),
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 7 <= x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
7 <= numpy.arange(15).reshape(3, 5),
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 9 >= x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
9 >= numpy.arange(15).reshape(3, 5),
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 11 == x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
11 == numpy.arange(15).reshape(3, 5),
|
||||
)
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 12 != x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
12 != numpy.arange(15).reshape(3, 5),
|
||||
)
|
||||
],
|
||||
),
|
||||
# pylint: enable=misplaced-comparison-constant
|
||||
(
|
||||
lambda x: x & 11,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([i & 11 for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: 13 & x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([i & 13 for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x | 6,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([i | 6 for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: 30 | x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([i | 30 for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x ^ 91,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([i ^ 91 for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: 115 ^ x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([i ^ 115 for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x % 11,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([i % 11 for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: 150 % (x + 1),
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([150 % (i + 1) for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x ** 2,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.array([i ** 2 for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: 2 ** x,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5) % 7,
|
||||
numpy.array([2 ** (i % 7) for i in range(15)]).reshape(3, 5),
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tracing_ndarray_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form ndarray.something"""
|
||||
subtest_tracing_calls(
|
||||
function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,params",
|
||||
[
|
||||
(
|
||||
lambda x: numpy.reshape(x, (5, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(7, 5)),
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_errors_with_generic_function(lambda_f, params):
|
||||
"Test some errors with generic function"
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)
|
||||
@@ -1,309 +0,0 @@
|
||||
"""Test file for numpy tracing"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.floats import Float
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.representation import intermediate as ir
|
||||
from concrete.common.values import EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
|
||||
OPERATIONS_TO_TEST = [ir.Add, ir.Sub, ir.Mul]
|
||||
|
||||
# Functions from tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC, whose output
|
||||
# is a float64, whatever the input type
|
||||
LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64 = set(
|
||||
[
|
||||
numpy.arccos,
|
||||
numpy.arccosh,
|
||||
numpy.arcsin,
|
||||
numpy.arcsinh,
|
||||
numpy.arctan,
|
||||
numpy.arctanh,
|
||||
numpy.cbrt,
|
||||
numpy.ceil,
|
||||
numpy.cos,
|
||||
numpy.cosh,
|
||||
numpy.deg2rad,
|
||||
numpy.degrees,
|
||||
numpy.exp,
|
||||
numpy.exp2,
|
||||
numpy.expm1,
|
||||
numpy.fabs,
|
||||
numpy.floor,
|
||||
numpy.log,
|
||||
numpy.log10,
|
||||
numpy.log1p,
|
||||
numpy.log2,
|
||||
numpy.rad2deg,
|
||||
numpy.radians,
|
||||
numpy.rint,
|
||||
numpy.sin,
|
||||
numpy.sinh,
|
||||
numpy.spacing,
|
||||
numpy.sqrt,
|
||||
numpy.tan,
|
||||
numpy.tanh,
|
||||
numpy.trunc,
|
||||
]
|
||||
)
|
||||
|
||||
# Functions from tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC, whose output
|
||||
# is a boolean, whatever the input type
|
||||
LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL = set(
|
||||
[
|
||||
numpy.isfinite,
|
||||
numpy.isinf,
|
||||
numpy.isnan,
|
||||
numpy.signbit,
|
||||
numpy.logical_not,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inputs,expected_output_node",
|
||||
[
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(7, is_signed=False))},
|
||||
ir.GenericFunction,
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(32, is_signed=True))},
|
||||
ir.GenericFunction,
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(64, is_signed=True))},
|
||||
ir.GenericFunction,
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(128, is_signed=True))},
|
||||
ir.GenericFunction,
|
||||
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Float(64))},
|
||||
ir.GenericFunction,
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace_def",
|
||||
[f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1],
|
||||
)
|
||||
def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def):
|
||||
"""Function to trace supported numpy ufuncs"""
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
# pylint: disable=cell-var-from-loop
|
||||
function_to_trace = lambda x: function_to_trace_def(x) # noqa: E731
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert len(op_graph.output_nodes) == 1
|
||||
assert isinstance(op_graph.output_nodes[0], expected_output_node)
|
||||
assert len(op_graph.output_nodes[0].outputs) == 1
|
||||
|
||||
if function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64:
|
||||
assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Float(64))
|
||||
elif function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL:
|
||||
|
||||
# Boolean function
|
||||
assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Integer(8, is_signed=False))
|
||||
else:
|
||||
|
||||
# Function keeping more or less input type
|
||||
input_node_type = inputs["x"]
|
||||
|
||||
expected_output_node_type = deepcopy(input_node_type)
|
||||
|
||||
expected_output_node_type.dtype.bit_width = max(
|
||||
expected_output_node_type.dtype.bit_width, 32
|
||||
)
|
||||
|
||||
assert op_graph.output_nodes[0].outputs[0] == expected_output_node_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize("np_function", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC)
|
||||
def test_nptracer_get_tracing_func_for_np_functions(np_function):
|
||||
"""Test NPTracer get_tracing_func_for_np_function"""
|
||||
|
||||
expected_tracing_func = tracing.NPTracer.UFUNC_ROUTING[np_function]
|
||||
|
||||
assert tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func
|
||||
|
||||
|
||||
def subtest_tracing_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form numpy.something"""
|
||||
for input_value, input_, expected_output in input_value_input_and_expected_output_tuples:
|
||||
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
|
||||
output_node = op_graph.output_nodes[0]
|
||||
|
||||
node_results = op_graph.evaluate({0: input_})
|
||||
evaluated_output = node_results[output_node]
|
||||
assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output)
|
||||
check_array_equality(evaluated_output, expected_output)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,input_value_input_and_expected_output_tuples",
|
||||
[
|
||||
(
|
||||
lambda x: numpy.transpose(x),
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4).reshape(2, 2),
|
||||
numpy.array([[0, 2], [1, 3]]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4, 8).reshape(2, 2),
|
||||
numpy.array([[4, 6], [5, 7]]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
numpy.int64(42),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: numpy.transpose(x) + 42,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(42, 57).reshape(3, 5).transpose(),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
numpy.int64(84),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: numpy.ravel(x),
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4),
|
||||
numpy.array([0, 1, 2, 3]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4).reshape(2, 2),
|
||||
numpy.array([0, 1, 2, 3]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
numpy.array([42], dtype=numpy.int64),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: numpy.reshape(x, (5, 3)) + 42,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(42, 57).reshape(5, 3),
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tracing_numpy_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form numpy.something"""
|
||||
subtest_tracing_calls(
|
||||
function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,input_value_input_and_expected_output_tuples",
|
||||
[
|
||||
(
|
||||
lambda x: x.transpose() + 42,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(42, 57).reshape(3, 5).transpose(),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
numpy.int64(84),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x.ravel(),
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4),
|
||||
numpy.array([0, 1, 2, 3]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4).reshape(2, 2),
|
||||
numpy.array([0, 1, 2, 3]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
numpy.array([42], dtype=numpy.int64),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x.reshape((5, 3)) + 42,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(42, 57).reshape(5, 3),
|
||||
),
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((5, 3)),
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
None,
|
||||
)
|
||||
],
|
||||
marks=pytest.mark.xfail(strict=True, raises=ValueError),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tracing_ndarray_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form ndarray.something"""
|
||||
subtest_tracing_calls(
|
||||
function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality
|
||||
)
|
||||
@@ -1,184 +0,0 @@
|
||||
"""Test file for numpy tracing"""
|
||||
|
||||
import inspect
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.representation import intermediate as ir
|
||||
from concrete.common.values import ClearScalar, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inputs",
|
||||
[
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(32, is_signed=True))},
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace",
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint is not happy
|
||||
# with it
|
||||
[lambda x: numpy.invert(x), lambda x: numpy.bitwise_not(x)],
|
||||
)
|
||||
def test_trace_numpy_fails_for_invert(inputs, function_to_trace):
|
||||
"""Check we catch calls to numpy.invert and tell user to change their code"""
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert (
|
||||
"NPTracer does not manage the following func: invert. Please replace by calls to "
|
||||
"bitwise_xor with appropriate mask" in str(excinfo.value)
|
||||
)
|
||||
|
||||
|
||||
def test_trace_numpy_ufuncs_not_supported():
|
||||
"""Testing a failure case of trace_numpy_function"""
|
||||
inputs = {"x": EncryptedScalar(Integer(128, is_signed=True))}
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
function_to_trace = lambda x: numpy.add.reduce(x) # noqa: E731
|
||||
|
||||
with pytest.raises(NotImplementedError) as excinfo:
|
||||
tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert "Only __call__ method is supported currently" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_trace_numpy_ufuncs_no_kwargs_no_extra_args():
|
||||
"""Test a case where kwargs are not allowed and too many inputs are passed"""
|
||||
inputs = {
|
||||
"x": EncryptedScalar(Integer(32, is_signed=True)),
|
||||
"y": EncryptedScalar(Integer(32, is_signed=True)),
|
||||
"z": EncryptedScalar(Integer(32, is_signed=True)),
|
||||
}
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
function_to_trace = lambda x, y, z: numpy.add(x, y, z) # noqa: E731
|
||||
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
# numpy only passes ufunc.nin tracers so the extra arguments are passed as kwargs
|
||||
assert "**kwargs are currently not supported for numpy ufuncs, ufunc: add" in str(excinfo.value)
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
function_to_trace = lambda x, y, z: numpy.add(x, y, out=z) # noqa: E731
|
||||
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert "**kwargs are currently not supported for numpy ufuncs, ufunc: add" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_nptracer_get_tracing_func_for_np_functions_not_implemented():
|
||||
"""Check NPTracer in case of not-implemented function"""
|
||||
with pytest.raises(NotImplementedError) as excinfo:
|
||||
tracing.NPTracer.get_tracing_func_for_np_function(numpy.conjugate)
|
||||
|
||||
assert "NPTracer does not yet manage the following func: conjugate" in str(excinfo.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"operation,exception_type,match",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x + "fail",
|
||||
TypeError,
|
||||
"unsupported operand type(s) for +: 'NPTracer' and 'str'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" + x,
|
||||
TypeError,
|
||||
'can only concatenate str (not "NPTracer") to str',
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x - "fail",
|
||||
TypeError,
|
||||
"unsupported operand type(s) for -: 'NPTracer' and 'str'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" - x,
|
||||
TypeError,
|
||||
"unsupported operand type(s) for -: 'str' and 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x * "fail",
|
||||
TypeError,
|
||||
"can't multiply sequence by non-int of type 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" * x,
|
||||
TypeError,
|
||||
"can't multiply sequence by non-int of type 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x / "fail",
|
||||
TypeError,
|
||||
"unsupported operand type(s) for /: 'NPTracer' and 'str'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" / x,
|
||||
TypeError,
|
||||
"unsupported operand type(s) for /: 'str' and 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x // "fail",
|
||||
TypeError,
|
||||
"unsupported operand type(s) for //: 'NPTracer' and 'str'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" // x,
|
||||
TypeError,
|
||||
"unsupported operand type(s) for //: 'str' and 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x / y, NotImplementedError, "Can't manage binary operator truediv"
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x // y, NotImplementedError, "Can't manage binary operator floordiv"
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_nptracer_unsupported_operands(operation, exception_type, match):
|
||||
"""Test cases where NPTracer cannot be used with other operands."""
|
||||
tracers = [
|
||||
tracing.NPTracer([], ir.Input(ClearScalar(Integer(32, True)), param_name, idx), 0)
|
||||
for idx, param_name in enumerate(inspect.signature(operation).parameters.keys())
|
||||
]
|
||||
|
||||
with pytest.raises(exception_type) as exc_info:
|
||||
_ = operation(*tracers)
|
||||
|
||||
assert match in str(exc_info)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,params",
|
||||
[
|
||||
(
|
||||
lambda x: numpy.reshape(x, (5, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(7, 5)),
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_errors_with_generic_function(lambda_f, params):
|
||||
"Test some errors with generic function"
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)
|
||||
Reference in New Issue
Block a user