feat: implement checking coherence between inputset and parameters

This commit is contained in:
Umut
2021-09-22 12:10:11 +03:00
committed by Arthur Meyre
parent eaf8cfb933
commit 0061e01d62
10 changed files with 360 additions and 28 deletions

View File

@@ -1,3 +1,3 @@
"""Module for shared data structures and code."""
from . import compilation, data_types, debugging, representation
from . import compilation, data_types, debugging, representation, values
from .common_helpers import check_op_graph_is_integer_program, is_a_power_of_2

View File

@@ -1,17 +1,105 @@
"""Code to evaluate the IR graph on inputsets."""
import sys
from typing import Any, Callable, Dict, Iterable, Tuple
from ..compilation import CompilationConfiguration
from ..data_types.dtypes_helpers import (
get_base_value_for_python_constant_data,
is_data_type_compatible_with,
)
from ..debugging import custom_assert
from ..operator_graph import OPGraph
from ..representation.intermediate import IntermediateNode
def _check_input_coherency(
    input_to_check: Dict[str, Any],
    parameters: Dict[str, Any],
    get_base_value_for_constant_data_func: Callable[[Any], Any],
):
    """Check whether `input_to_check` is coherent with `parameters`.

    Every constant of the input is converted to its base value with
    `get_base_value_for_constant_data_func`, and that base value is then compared
    against the hinted parameter value (shape equality and data type compatibility).

    Args:
        input_to_check (Dict[str, Any]): input to check coherency of
        parameters (Dict[str, Any]): parameters and their expected base values
        get_base_value_for_constant_data_func (Callable[[Any], Any]):
            function to get the base value of python objects.

    Returns:
        List[str]: list of warnings about the coherency
    """
    coherency_warnings = []
    for name, constant in input_to_check.items():
        hinted_value = parameters[name]
        # Instantiate the constant's base value with the same encryption status as the
        # hinted parameter, so that only shape and data type can differ in the comparison.
        actual_value = get_base_value_for_constant_data_func(constant)(
            is_encrypted=hinted_value.is_encrypted
        )
        shape_matches = actual_value.shape == hinted_value.shape
        if not (
            shape_matches
            and is_data_type_compatible_with(actual_value.data_type, hinted_value.data_type)
        ):
            coherency_warnings.append(
                f"expected {str(hinted_value)} "
                f"for parameter `{name}` "
                f"but got {str(actual_value)} "
                f"which is not compatible"
            )
    return coherency_warnings
def _print_input_coherency_warnings(
    current_input_index: int,
    current_input_data: Dict[int, Any],
    parameters: Dict[str, Any],
    parameter_index_to_parameter_name: Dict[int, str],
    get_base_value_for_constant_data_func: Callable[[Any], Any],
):
    """Print coherency warnings for `current_input_data` against `parameters`.

    Args:
        current_input_index (int): index of the current input on the inputset
        current_input_data (Dict[int, Any]): input to print coherency warnings of
        parameters (Dict[str, Any]): parameters and their expected base values
        parameter_index_to_parameter_name (Dict[int, str]):
            dict to get parameter names from parameter indices
        get_base_value_for_constant_data_func (Callable[[Any], Any]):
            function to get the base value of python objects.

    Returns:
        None
    """
    # Translate positional indices into parameter names before delegating the check.
    named_input = {}
    for position, data in current_input_data.items():
        named_input[parameter_index_to_parameter_name[position]] = data

    for problem in _check_input_coherency(
        named_input,
        parameters,
        get_base_value_for_constant_data_func,
    ):
        # Warnings go to stderr so they do not pollute regular program output.
        sys.stderr.write(
            f"Warning: Input #{current_input_index} (0-indexed) "
            f"is not coherent with the hinted parameters ({problem})\n",
        )
def eval_op_graph_bounds_on_inputset(
op_graph: OPGraph,
inputset: Iterable[Tuple[Any, ...]],
compilation_configuration: CompilationConfiguration,
min_func: Callable[[Any, Any], Any] = min,
max_func: Callable[[Any, Any], Any] = max,
get_base_value_for_constant_data_func: Callable[
[Any], Any
] = get_base_value_for_python_constant_data,
) -> Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]:
"""Evaluate the bounds with a inputset.
@@ -23,12 +111,16 @@ def eval_op_graph_bounds_on_inputset(
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
needs to be an iterable on tuples which are of the same length than the number of
parameters in the function, and in the same order than these same parameters
compilation_configuration (CompilationConfiguration): Configuration object to use
during determining input checking strategy
min_func (Callable[[Any, Any], Any], optional): custom function to compute a scalar minimum
between two values that can be encountered during evaluation (for e.g. numpy or torch
tensors). Defaults to min.
max_func (Callable[[Any, Any], Any], optional): custom function to compute a scalar maximum
between two values that can be encountered during evaluation (for e.g. numpy or torch
tensors). Defaults to max.
get_base_value_for_constant_data_func (Callable[[Any], Any], optional): custom function
to compute the base value of a python object.
Returns:
Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: number of inputs in the inputset and
@@ -49,14 +141,29 @@ def eval_op_graph_bounds_on_inputset(
# TODO: do we want to check coherence between the input data type and the corresponding Input ir
# node expected data type ? Not considering bit_width as they may not make sense at this stage
inputset_size = 0
inputset_iterator = iter(inputset)
parameter_index_to_parameter_name = {
index: input_node.input_name for index, input_node in op_graph.input_nodes.items()
}
parameters = {
input_node.input_name: input_node.inputs[0] for input_node in op_graph.input_nodes.values()
}
first_input_data = dict(enumerate(next(inputset_iterator)))
inputset_iterator = iter(inputset)
inputset_size = 0
current_input_data = dict(enumerate(next(inputset_iterator)))
inputset_size += 1
check_inputset_input_len_is_valid(first_input_data.values())
first_output = op_graph.evaluate(first_input_data)
check_inputset_input_len_is_valid(current_input_data.values())
_print_input_coherency_warnings(
inputset_size - 1,
current_input_data,
parameters,
parameter_index_to_parameter_name,
get_base_value_for_constant_data_func,
)
first_output = op_graph.evaluate(current_input_data)
# We evaluate the min and max func to be able to resolve the tensors min and max rather than
# having the tensor itself as the stored min and max values.
@@ -68,7 +175,17 @@ def eval_op_graph_bounds_on_inputset(
for input_data in inputset_iterator:
inputset_size += 1
current_input_data = dict(enumerate(input_data))
check_inputset_input_len_is_valid(current_input_data.values())
if compilation_configuration.check_every_input_in_inputset:
_print_input_coherency_warnings(
inputset_size - 1,
current_input_data,
parameters,
parameter_index_to_parameter_name,
get_base_value_for_constant_data_func,
)
current_output = op_graph.evaluate(current_input_data)
for node, value in current_output.items():
node_bounds[node]["min"] = min_func(node_bounds[node]["min"], value)

View File

@@ -6,11 +6,14 @@ class CompilationConfiguration:
dump_artifacts_on_unexpected_failures: bool
enable_topological_optimizations: bool
check_every_input_in_inputset: bool
def __init__(
    self,
    dump_artifacts_on_unexpected_failures: bool = True,
    enable_topological_optimizations: bool = True,
    check_every_input_in_inputset: bool = False,
):
    """Store the flags controlling compilation behavior.

    Args:
        dump_artifacts_on_unexpected_failures (bool): whether to dump debug
            artifacts when an unexpected failure happens. Defaults to True.
        enable_topological_optimizations (bool): whether to enable topological
            optimizations. Defaults to True.
        check_every_input_in_inputset (bool): whether to check the coherency of
            every input in the inputset (instead of only the first one).
            Defaults to False.
    """
    # Plain attributes, one per flag, so callers can read or toggle them directly.
    self.check_every_input_in_inputset = check_every_input_in_inputset
    self.enable_topological_optimizations = enable_topological_optimizations
    self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures

View File

@@ -2,7 +2,7 @@
from copy import deepcopy
from functools import partial
from typing import Callable, Union, cast
from typing import Callable, List, Union, cast
from ..debugging.custom_assert import custom_assert
from ..values import (
@@ -306,7 +306,9 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) ->
)
def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType:
def get_base_data_type_for_python_constant_data(
constant_data: Union[int, float, List[int], List[float]]
) -> BaseDataType:
"""Determine the BaseDataType to hold the input constant data.
Args:
@@ -318,10 +320,17 @@ def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]
"""
constant_data_type: BaseDataType
custom_assert(
isinstance(constant_data, (int, float)),
isinstance(constant_data, (int, float, list)),
f"Unsupported constant data of type {type(constant_data)}",
)
if isinstance(constant_data, int):
if isinstance(constant_data, list):
custom_assert(len(constant_data) > 0, "Data type of empty list cannot be detected")
constant_data_type = get_base_data_type_for_python_constant_data(constant_data[0])
for value in constant_data:
other_data_type = get_base_data_type_for_python_constant_data(value)
constant_data_type = find_type_to_hold_both_lossy(constant_data_type, other_data_type)
elif isinstance(constant_data, int):
is_signed = constant_data < 0
constant_data_type = Integer(
get_bits_to_represent_value_as_integer(constant_data, is_signed), is_signed
@@ -332,22 +341,29 @@ def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]
def get_base_value_for_python_constant_data(
constant_data: Union[int, float]
) -> Callable[..., ScalarValue]:
"""Wrap the BaseDataType to hold the input constant data in a ScalarValue partial.
constant_data: Union[int, float, List[int], List[float]]
) -> Callable[..., BaseValue]:
"""Wrap the BaseDataType to hold the input constant data in BaseValue partial.
The returned object can then be instantiated as an Encrypted or Clear version of the ScalarValue
by calling it with the proper arguments forwarded to the ScalarValue `__init__` function
The returned object can then be instantiated as an Encrypted or Clear version
by calling it with the proper arguments forwarded to the BaseValue `__init__` function
Args:
constant_data (Union[int, float]): The constant data for which to determine the
corresponding ScalarValue and BaseDataType.
constant_data (Union[int, float, List[int], List[float]]): The constant data
for which to determine the corresponding Value.
Returns:
Callable[..., ScalarValue]: A partial object that will return the proper ScalarValue when
called with `encrypted` as keyword argument (forwarded to the ScalarValue `__init__`
Callable[..., BaseValue]: A partial object that will return the proper BaseValue when
called with `is_encrypted` as keyword argument (forwarded to the BaseValue `__init__`
method).
"""
if isinstance(constant_data, list):
assert len(constant_data) > 0
constant_shape = (len(constant_data),)
constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
return partial(TensorValue, data_type=constant_data_type, shape=constant_shape)
constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
return partial(ScalarValue, data_type=constant_data_type)
@@ -359,3 +375,28 @@ def get_type_constructor_for_python_constant_data(constant_data: Union[int, floa
constant_data (Any): The data for which we want to determine the type constructor.
"""
return type(constant_data)
def is_data_type_compatible_with(
    dtype: BaseDataType,
    other: BaseDataType,
) -> bool:
    """Determine whether `dtype` is compatible with `other`.

    `dtype` being compatible with `other` means `other` can hold every value of `dtype`
    (e.g., uint2 is compatible with uint4 and int4)
    (e.g., int2 is compatible with int4 but not with uint4)

    Note that this relation is not symmetric:
    uint2 is compatible with uint4, but uint4 is not compatible with uint2.

    Args:
        dtype (BaseDataType): dtype to check compatibility
        other (BaseDataType): dtype to check compatibility against

    Returns:
        bool: Whether the dtype is compatible with other or not
    """
    # If losslessly holding both types requires nothing beyond `other`,
    # then `other` can already represent every value of `dtype`.
    return find_type_to_hold_both_lossy(dtype, other) == other

View File

@@ -1,5 +1,7 @@
"""Module that defines the scalar values in a program."""
from typing import Tuple
from ..data_types.base import BaseDataType
from .base import BaseValue
@@ -14,6 +16,15 @@ class ScalarValue(BaseValue):
encrypted_str = "Encrypted" if self._is_encrypted else "Clear"
return f"{encrypted_str}Scalar<{self.data_type!r}>"
@property
def shape(self) -> Tuple[int, ...]:
    """Return the ScalarValue shape property.

    Returns:
        Tuple[int, ...]: the empty tuple `()`, as a scalar has no dimensions.
    """
    # Scalars are 0-dimensional, so the shape is always the empty tuple.
    return ()
def make_clear_scalar(data_type: BaseDataType) -> ScalarValue:
"""Create a clear ScalarValue.

View File

@@ -24,6 +24,7 @@ from ..common.values import BaseValue
from ..numpy.tracing import trace_numpy_function
from .np_dtypes_helpers import (
get_base_data_type_for_numpy_or_python_constant_data,
get_base_value_for_numpy_or_python_constant_data,
get_type_constructor_for_numpy_or_python_constant_data,
)
@@ -112,8 +113,10 @@ def _compile_numpy_function_into_op_graph_internal(
inputset_size, node_bounds = eval_op_graph_bounds_on_inputset(
op_graph,
inputset,
compilation_configuration=compilation_configuration,
min_func=numpy_min_func,
max_func=numpy_max_func,
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
)
# Check inputset size
@@ -134,7 +137,7 @@ def _compile_numpy_function_into_op_graph_internal(
minimum_required_inputset_size = min(inputset_size_upper_limit, 10)
if inputset_size < minimum_required_inputset_size:
sys.stderr.write(
f"Provided inputset contains too few inputs "
f"Warning: Provided inputset contains too few inputs "
f"(it should have had at least {minimum_required_inputset_size} "
f"but it only had {inputset_size})\n"
)

View File

@@ -10,6 +10,7 @@ from numpy.typing import DTypeLike
from ..common.data_types.base import BaseDataType
from ..common.data_types.dtypes_helpers import (
BASE_DATA_TYPES,
find_type_to_hold_both_lossy,
get_base_data_type_for_python_constant_data,
get_base_value_for_python_constant_data,
get_type_constructor_for_python_constant_data,
@@ -116,12 +117,26 @@ def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) ->
"""
base_dtype: BaseDataType
custom_assert(
isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)),
isinstance(
constant_data, (int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
),
f"Unsupported constant data of type {type(constant_data)}",
)
if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)):
native_type = (
float
if constant_data.dtype == numpy.float32 or constant_data.dtype == numpy.float64
else int
)
min_value = native_type(constant_data.min())
max_value = native_type(constant_data.max())
min_value_dtype = get_base_data_type_for_python_constant_data(min_value)
max_value_dtype = get_base_data_type_for_python_constant_data(max_value)
# numpy
base_dtype = convert_numpy_dtype_to_base_data_type(constant_data.dtype)
base_dtype = find_type_to_hold_both_lossy(min_value_dtype, max_value_dtype)
else:
# python
base_dtype = get_base_data_type_for_python_constant_data(constant_data)
@@ -148,7 +163,10 @@ def get_base_value_for_numpy_or_python_constant_data(
"""
constant_data_value: Callable[..., BaseValue]
custom_assert(
isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)),
isinstance(
constant_data,
(int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES),
),
f"Unsupported constant data of type {type(constant_data)}",
)

View File

@@ -2,12 +2,16 @@
from typing import Tuple
import numpy as np
import pytest
from concrete.common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_inputset
from concrete.common.compilation import CompilationConfiguration
from concrete.common.data_types.floats import Float
from concrete.common.data_types.integers import Integer
from concrete.common.values import EncryptedScalar
from concrete.common.data_types.integers import Integer, UnsignedInteger
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
from concrete.numpy.compile import numpy_max_func, numpy_min_func
from concrete.numpy.np_dtypes_helpers import get_base_value_for_numpy_or_python_constant_data
from concrete.numpy.tracing import trace_numpy_function
@@ -280,7 +284,9 @@ def test_eval_op_graph_bounds_on_inputset_multiple_output(
yield (x_gen, y_gen)
_, node_bounds = eval_op_graph_bounds_on_inputset(
op_graph, data_gen(*tuple(range(x[0], x[1] + 1) for x in input_ranges))
op_graph,
data_gen(*tuple(range(x[0], x[1] + 1) for x in input_ranges)),
CompilationConfiguration(),
)
for i, output_node in op_graph.output_nodes.items():
@@ -291,3 +297,137 @@ def test_eval_op_graph_bounds_on_inputset_multiple_output(
for i, output_node in op_graph.output_nodes.items():
assert expected_output_data_type[i] == output_node.outputs[0].data_type
def test_eval_op_graph_bounds_on_non_conformant_inputset_default(capsys):
    """Test function for eval_op_graph_bounds_on_inputset with non conformant inputset"""

    def f(x, y):
        return np.dot(x, y)

    # Hinted parameters: both tensors have 3 elements of 2-bit unsigned integers.
    parameter_hints = {
        "x": EncryptedTensor(UnsignedInteger(2), (3,)),
        "y": ClearTensor(UnsignedInteger(2), (3,)),
    }

    # The first sample has the wrong length (4 instead of 3); by default only the
    # first sample is checked, so only it should be reported.
    non_conformant_inputset = [
        ([2, 1, 3, 1], [1, 2, 1, 1]),
        ([3, 3, 3], [3, 3, 5]),
    ]

    op_graph = trace_numpy_function(f, parameter_hints)
    eval_op_graph_bounds_on_inputset(
        op_graph,
        non_conformant_inputset,
        compilation_configuration=CompilationConfiguration(),
    )

    expected_warnings = (
        "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
        "(expected EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `x` "
        "but got EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
        "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
        "(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
        "but got ClearTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
    )
    assert capsys.readouterr().err == expected_warnings
def test_eval_op_graph_bounds_on_non_conformant_inputset_check_all(capsys):
    """Test function for eval_op_graph_bounds_on_inputset with non conformant inputset, check all"""

    def f(x, y):
        return np.dot(x, y)

    parameter_hints = {
        "x": EncryptedTensor(UnsignedInteger(2), (3,)),
        "y": ClearTensor(UnsignedInteger(2), (3,)),
    }

    # First sample has the wrong length; in the second sample, the value 5 does not
    # fit in the 2 bits hinted for `y`. With check_every_input_in_inputset enabled,
    # both samples should be reported.
    non_conformant_inputset = [
        ([2, 1, 3, 1], [1, 2, 1, 1]),
        ([3, 3, 3], [3, 3, 5]),
    ]

    op_graph = trace_numpy_function(f, parameter_hints)
    eval_op_graph_bounds_on_inputset(
        op_graph,
        non_conformant_inputset,
        compilation_configuration=CompilationConfiguration(check_every_input_in_inputset=True),
    )

    expected_warnings = (
        "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
        "(expected EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `x` "
        "but got EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
        "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
        "(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
        "but got ClearTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
        "Warning: Input #1 (0-indexed) is not coherent with the hinted parameters "
        "(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
        "but got ClearTensor<Integer<unsigned, 3 bits>, shape=(3,)> which is not compatible)\n"
    )
    assert capsys.readouterr().err == expected_warnings
def test_eval_op_graph_bounds_on_conformant_numpy_inputset_check_all(capsys):
    """Test function for eval_op_graph_bounds_on_inputset
    with conformant inputset of numpy arrays, check all"""

    def f(x, y):
        return np.dot(x, y)

    parameter_hints = {
        "x": EncryptedTensor(UnsignedInteger(2), (3,)),
        "y": ClearTensor(UnsignedInteger(2), (3,)),
    }

    # Every sample matches the hinted shape and fits in 2 bits, so even with
    # check_every_input_in_inputset enabled no warning should be emitted.
    conformant_inputset = [
        (np.array([2, 1, 3]), np.array([1, 2, 1])),
        (np.array([3, 3, 3]), np.array([3, 3, 1])),
    ]

    op_graph = trace_numpy_function(f, parameter_hints)
    eval_op_graph_bounds_on_inputset(
        op_graph,
        conformant_inputset,
        compilation_configuration=CompilationConfiguration(check_every_input_in_inputset=True),
        min_func=numpy_min_func,
        max_func=numpy_max_func,
        get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
    )

    assert capsys.readouterr().err == ""
def test_eval_op_graph_bounds_on_non_conformant_numpy_inputset_check_all(capsys):
    """Test function for eval_op_graph_bounds_on_inputset with non conformant inputset, check all"""

    def f(x, y):
        return np.dot(x, y)

    parameter_hints = {
        "x": EncryptedTensor(UnsignedInteger(2), (3,)),
        "y": ClearTensor(UnsignedInteger(2), (3,)),
    }

    # First sample has the wrong length; in the second sample, the value 5 requires
    # 3 bits, exceeding the 2 bits hinted for `y`. Checking every input should
    # therefore produce warnings for both samples.
    non_conformant_inputset = [
        (np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])),
        (np.array([3, 3, 3]), np.array([3, 3, 5])),
    ]

    op_graph = trace_numpy_function(f, parameter_hints)
    eval_op_graph_bounds_on_inputset(
        op_graph,
        non_conformant_inputset,
        compilation_configuration=CompilationConfiguration(check_every_input_in_inputset=True),
        min_func=numpy_min_func,
        max_func=numpy_max_func,
        get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
    )

    expected_warnings = (
        "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
        "(expected EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `x` "
        "but got EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
        "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
        "(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
        "but got ClearTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
        "Warning: Input #1 (0-indexed) is not coherent with the hinted parameters "
        "(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
        "but got ClearTensor<Integer<unsigned, 3 bits>, shape=(3,)> which is not compatible)\n"
    )
    assert capsys.readouterr().err == expected_warnings

View File

@@ -1,5 +1,4 @@
"""Test file for data types helpers"""
import pytest
from concrete.common.data_types.base import BaseDataType

View File

@@ -303,7 +303,7 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str):
iter_i = itertools.product(range(0, max_for_ij + 1), repeat=repeat)
iter_j = itertools.product(range(0, max_for_ij + 1), repeat=repeat)
for prod_i, prod_j in itertools.product(iter_i, iter_j):
yield (prod_i, prod_j)
yield (list(prod_i), list(prod_j))
max_for_ij = 3
assert len(shape) == 1