mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: implement checking coherence between inputset and parameters
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
"""Module for shared data structures and code."""
|
||||
from . import compilation, data_types, debugging, representation
|
||||
from . import compilation, data_types, debugging, representation, values
|
||||
from .common_helpers import check_op_graph_is_integer_program, is_a_power_of_2
|
||||
|
||||
@@ -1,17 +1,105 @@
|
||||
"""Code to evaluate the IR graph on inputsets."""
|
||||
|
||||
import sys
|
||||
from typing import Any, Callable, Dict, Iterable, Tuple
|
||||
|
||||
from ..compilation import CompilationConfiguration
|
||||
from ..data_types.dtypes_helpers import (
|
||||
get_base_value_for_python_constant_data,
|
||||
is_data_type_compatible_with,
|
||||
)
|
||||
from ..debugging import custom_assert
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import IntermediateNode
|
||||
|
||||
|
||||
def _check_input_coherency(
|
||||
input_to_check: Dict[str, Any],
|
||||
parameters: Dict[str, Any],
|
||||
get_base_value_for_constant_data_func: Callable[[Any], Any],
|
||||
):
|
||||
"""Check whether `input_to_check` is coherent with `parameters`.
|
||||
|
||||
This function works by iterating over each constant of the input,
|
||||
determining base value of the constant using `get_base_value_for_constant_data_func` and
|
||||
checking if the base value of the contant is compatible with the base value of the parameter.
|
||||
|
||||
Args:
|
||||
input_to_check (Dict[str, Any]): input to check coherency of
|
||||
parameters (Dict[str, Any]): parameters and their expected base values
|
||||
get_base_value_for_constant_data_func (Callable[[Any], Any]):
|
||||
function to get the base value of python objects.
|
||||
|
||||
Returns:
|
||||
List[str]: List of warnings about the coherency
|
||||
"""
|
||||
|
||||
warnings = []
|
||||
for parameter_name, value in input_to_check.items():
|
||||
parameter_base_value = parameters[parameter_name]
|
||||
|
||||
base_value_class = get_base_value_for_constant_data_func(value)
|
||||
base_value = base_value_class(is_encrypted=parameter_base_value.is_encrypted)
|
||||
|
||||
if base_value.shape != parameter_base_value.shape or not is_data_type_compatible_with(
|
||||
base_value.data_type, parameter_base_value.data_type
|
||||
):
|
||||
warnings.append(
|
||||
f"expected {str(parameter_base_value)} "
|
||||
f"for parameter `{parameter_name}` "
|
||||
f"but got {str(base_value)} "
|
||||
f"which is not compatible"
|
||||
)
|
||||
return warnings
|
||||
|
||||
|
||||
def _print_input_coherency_warnings(
|
||||
current_input_index: int,
|
||||
current_input_data: Dict[int, Any],
|
||||
parameters: Dict[str, Any],
|
||||
parameter_index_to_parameter_name: Dict[int, str],
|
||||
get_base_value_for_constant_data_func: Callable[[Any], Any],
|
||||
):
|
||||
"""Print coherency warning for `input_to_check` against `parameters`.
|
||||
|
||||
Args:
|
||||
current_input_index (int): index of the current input on the inputset
|
||||
current_input_data (Dict[int, Any]): input to print coherency warnings of
|
||||
parameters (Dict[str, Any]): parameters and their expected base values
|
||||
parameter_index_to_parameter_name (Dict[int, str]):
|
||||
dict to get parameter names from parameter indices
|
||||
get_base_value_for_constant_data_func (Callable[[Any], Any]):
|
||||
function to get the base value of python objects.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
current_input_named_data = {
|
||||
parameter_index_to_parameter_name[index]: data for index, data in current_input_data.items()
|
||||
}
|
||||
|
||||
problems = _check_input_coherency(
|
||||
current_input_named_data,
|
||||
parameters,
|
||||
get_base_value_for_constant_data_func,
|
||||
)
|
||||
for problem in problems:
|
||||
sys.stderr.write(
|
||||
f"Warning: Input #{current_input_index} (0-indexed) "
|
||||
f"is not coherent with the hinted parameters ({problem})\n",
|
||||
)
|
||||
|
||||
|
||||
def eval_op_graph_bounds_on_inputset(
|
||||
op_graph: OPGraph,
|
||||
inputset: Iterable[Tuple[Any, ...]],
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
min_func: Callable[[Any, Any], Any] = min,
|
||||
max_func: Callable[[Any, Any], Any] = max,
|
||||
get_base_value_for_constant_data_func: Callable[
|
||||
[Any], Any
|
||||
] = get_base_value_for_python_constant_data,
|
||||
) -> Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]:
|
||||
"""Evaluate the bounds with a inputset.
|
||||
|
||||
@@ -23,12 +111,16 @@ def eval_op_graph_bounds_on_inputset(
|
||||
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
|
||||
needs to be an iterable on tuples which are of the same length than the number of
|
||||
parameters in the function, and in the same order than these same parameters
|
||||
compilation_configuration (CompilationConfiguration): Configuration object to use
|
||||
during determining input checking strategy
|
||||
min_func (Callable[[Any, Any], Any], optional): custom function to compute a scalar minimum
|
||||
between two values that can be encountered during evaluation (for e.g. numpy or torch
|
||||
tensors). Defaults to min.
|
||||
max_func (Callable[[Any, Any], Any], optional): custom function to compute a scalar maximum
|
||||
between two values that can be encountered during evaluation (for e.g. numpy or torch
|
||||
tensors). Defaults to max.
|
||||
get_base_value_for_constant_data_func (Callable[[Any], Any], optional): custom function
|
||||
to compute the base value of a python object.
|
||||
|
||||
Returns:
|
||||
Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: number of inputs in the inputset and
|
||||
@@ -49,14 +141,29 @@ def eval_op_graph_bounds_on_inputset(
|
||||
# TODO: do we want to check coherence between the input data type and the corresponding Input ir
|
||||
# node expected data type ? Not considering bit_width as they may not make sense at this stage
|
||||
|
||||
inputset_size = 0
|
||||
inputset_iterator = iter(inputset)
|
||||
parameter_index_to_parameter_name = {
|
||||
index: input_node.input_name for index, input_node in op_graph.input_nodes.items()
|
||||
}
|
||||
parameters = {
|
||||
input_node.input_name: input_node.inputs[0] for input_node in op_graph.input_nodes.values()
|
||||
}
|
||||
|
||||
first_input_data = dict(enumerate(next(inputset_iterator)))
|
||||
inputset_iterator = iter(inputset)
|
||||
inputset_size = 0
|
||||
|
||||
current_input_data = dict(enumerate(next(inputset_iterator)))
|
||||
inputset_size += 1
|
||||
|
||||
check_inputset_input_len_is_valid(first_input_data.values())
|
||||
first_output = op_graph.evaluate(first_input_data)
|
||||
check_inputset_input_len_is_valid(current_input_data.values())
|
||||
_print_input_coherency_warnings(
|
||||
inputset_size - 1,
|
||||
current_input_data,
|
||||
parameters,
|
||||
parameter_index_to_parameter_name,
|
||||
get_base_value_for_constant_data_func,
|
||||
)
|
||||
|
||||
first_output = op_graph.evaluate(current_input_data)
|
||||
|
||||
# We evaluate the min and max func to be able to resolve the tensors min and max rather than
|
||||
# having the tensor itself as the stored min and max values.
|
||||
@@ -68,7 +175,17 @@ def eval_op_graph_bounds_on_inputset(
|
||||
for input_data in inputset_iterator:
|
||||
inputset_size += 1
|
||||
current_input_data = dict(enumerate(input_data))
|
||||
|
||||
check_inputset_input_len_is_valid(current_input_data.values())
|
||||
if compilation_configuration.check_every_input_in_inputset:
|
||||
_print_input_coherency_warnings(
|
||||
inputset_size - 1,
|
||||
current_input_data,
|
||||
parameters,
|
||||
parameter_index_to_parameter_name,
|
||||
get_base_value_for_constant_data_func,
|
||||
)
|
||||
|
||||
current_output = op_graph.evaluate(current_input_data)
|
||||
for node, value in current_output.items():
|
||||
node_bounds[node]["min"] = min_func(node_bounds[node]["min"], value)
|
||||
|
||||
@@ -6,11 +6,14 @@ class CompilationConfiguration:
|
||||
|
||||
dump_artifacts_on_unexpected_failures: bool
|
||||
enable_topological_optimizations: bool
|
||||
check_every_input_in_inputset: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dump_artifacts_on_unexpected_failures: bool = True,
|
||||
enable_topological_optimizations: bool = True,
|
||||
check_every_input_in_inputset: bool = False,
|
||||
):
|
||||
self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures
|
||||
self.enable_topological_optimizations = enable_topological_optimizations
|
||||
self.check_every_input_in_inputset = check_every_input_in_inputset
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Callable, Union, cast
|
||||
from typing import Callable, List, Union, cast
|
||||
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..values import (
|
||||
@@ -306,7 +306,9 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) ->
|
||||
)
|
||||
|
||||
|
||||
def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType:
|
||||
def get_base_data_type_for_python_constant_data(
|
||||
constant_data: Union[int, float, List[int], List[float]]
|
||||
) -> BaseDataType:
|
||||
"""Determine the BaseDataType to hold the input constant data.
|
||||
|
||||
Args:
|
||||
@@ -318,10 +320,17 @@ def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]
|
||||
"""
|
||||
constant_data_type: BaseDataType
|
||||
custom_assert(
|
||||
isinstance(constant_data, (int, float)),
|
||||
isinstance(constant_data, (int, float, list)),
|
||||
f"Unsupported constant data of type {type(constant_data)}",
|
||||
)
|
||||
if isinstance(constant_data, int):
|
||||
|
||||
if isinstance(constant_data, list):
|
||||
custom_assert(len(constant_data) > 0, "Data type of empty list cannot be detected")
|
||||
constant_data_type = get_base_data_type_for_python_constant_data(constant_data[0])
|
||||
for value in constant_data:
|
||||
other_data_type = get_base_data_type_for_python_constant_data(value)
|
||||
constant_data_type = find_type_to_hold_both_lossy(constant_data_type, other_data_type)
|
||||
elif isinstance(constant_data, int):
|
||||
is_signed = constant_data < 0
|
||||
constant_data_type = Integer(
|
||||
get_bits_to_represent_value_as_integer(constant_data, is_signed), is_signed
|
||||
@@ -332,22 +341,29 @@ def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]
|
||||
|
||||
|
||||
def get_base_value_for_python_constant_data(
|
||||
constant_data: Union[int, float]
|
||||
) -> Callable[..., ScalarValue]:
|
||||
"""Wrap the BaseDataType to hold the input constant data in a ScalarValue partial.
|
||||
constant_data: Union[int, float, List[int], List[float]]
|
||||
) -> Callable[..., BaseValue]:
|
||||
"""Wrap the BaseDataType to hold the input constant data in BaseValue partial.
|
||||
|
||||
The returned object can then be instantiated as an Encrypted or Clear version of the ScalarValue
|
||||
by calling it with the proper arguments forwarded to the ScalarValue `__init__` function
|
||||
The returned object can then be instantiated as an Encrypted or Clear version
|
||||
by calling it with the proper arguments forwarded to the BaseValue `__init__` function
|
||||
|
||||
Args:
|
||||
constant_data (Union[int, float]): The constant data for which to determine the
|
||||
corresponding ScalarValue and BaseDataType.
|
||||
constant_data (Union[int, float, List[int], List[float]]): The constant data
|
||||
for which to determine the corresponding Value.
|
||||
|
||||
Returns:
|
||||
Callable[..., ScalarValue]: A partial object that will return the proper ScalarValue when
|
||||
called with `encrypted` as keyword argument (forwarded to the ScalarValue `__init__`
|
||||
Callable[..., BaseValue]: A partial object that will return the proper BaseValue when
|
||||
called with `is_encrypted` as keyword argument (forwarded to the BaseValue `__init__`
|
||||
method).
|
||||
"""
|
||||
|
||||
if isinstance(constant_data, list):
|
||||
assert len(constant_data) > 0
|
||||
constant_shape = (len(constant_data),)
|
||||
constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
|
||||
return partial(TensorValue, data_type=constant_data_type, shape=constant_shape)
|
||||
|
||||
constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
|
||||
return partial(ScalarValue, data_type=constant_data_type)
|
||||
|
||||
@@ -359,3 +375,28 @@ def get_type_constructor_for_python_constant_data(constant_data: Union[int, floa
|
||||
constant_data (Any): The data for which we want to determine the type constructor.
|
||||
"""
|
||||
return type(constant_data)
|
||||
|
||||
|
||||
def is_data_type_compatible_with(
|
||||
dtype: BaseDataType,
|
||||
other: BaseDataType,
|
||||
) -> bool:
|
||||
"""Determine whether dtype is compatible with other.
|
||||
|
||||
`dtype` being compatible with `other` means `other` can hold every value of `dtype`
|
||||
(e.g., uint2 is compatible with uint4 and int4)
|
||||
(e.g., int2 is compatible with int4 but not with uint4)
|
||||
|
||||
Note that this function is not symetric.
|
||||
(e.g., uint2 is compatible with uint4, but uint4 is not compatible with uint2)
|
||||
|
||||
Args:
|
||||
dtype (BaseDataType): dtype to check compatiblity
|
||||
other (BaseDataType): dtype to check compatiblity against
|
||||
|
||||
Returns:
|
||||
bool: Whether the dtype is compatible with other or not
|
||||
"""
|
||||
|
||||
combination = find_type_to_hold_both_lossy(dtype, other)
|
||||
return other == combination
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Module that defines the scalar values in a program."""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from ..data_types.base import BaseDataType
|
||||
from .base import BaseValue
|
||||
|
||||
@@ -14,6 +16,15 @@ class ScalarValue(BaseValue):
|
||||
encrypted_str = "Encrypted" if self._is_encrypted else "Clear"
|
||||
return f"{encrypted_str}Scalar<{self.data_type!r}>"
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
"""Return the ScalarValue shape property.
|
||||
|
||||
Returns:
|
||||
Tuple[int, ...]: The ScalarValue shape which is `()`.
|
||||
"""
|
||||
return ()
|
||||
|
||||
|
||||
def make_clear_scalar(data_type: BaseDataType) -> ScalarValue:
|
||||
"""Create a clear ScalarValue.
|
||||
|
||||
@@ -24,6 +24,7 @@ from ..common.values import BaseValue
|
||||
from ..numpy.tracing import trace_numpy_function
|
||||
from .np_dtypes_helpers import (
|
||||
get_base_data_type_for_numpy_or_python_constant_data,
|
||||
get_base_value_for_numpy_or_python_constant_data,
|
||||
get_type_constructor_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
@@ -112,8 +113,10 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
inputset_size, node_bounds = eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
compilation_configuration=compilation_configuration,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
# Check inputset size
|
||||
@@ -134,7 +137,7 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
minimum_required_inputset_size = min(inputset_size_upper_limit, 10)
|
||||
if inputset_size < minimum_required_inputset_size:
|
||||
sys.stderr.write(
|
||||
f"Provided inputset contains too few inputs "
|
||||
f"Warning: Provided inputset contains too few inputs "
|
||||
f"(it should have had at least {minimum_required_inputset_size} "
|
||||
f"but it only had {inputset_size})\n"
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from numpy.typing import DTypeLike
|
||||
from ..common.data_types.base import BaseDataType
|
||||
from ..common.data_types.dtypes_helpers import (
|
||||
BASE_DATA_TYPES,
|
||||
find_type_to_hold_both_lossy,
|
||||
get_base_data_type_for_python_constant_data,
|
||||
get_base_value_for_python_constant_data,
|
||||
get_type_constructor_for_python_constant_data,
|
||||
@@ -116,12 +117,26 @@ def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) ->
|
||||
"""
|
||||
base_dtype: BaseDataType
|
||||
custom_assert(
|
||||
isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)),
|
||||
isinstance(
|
||||
constant_data, (int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
|
||||
),
|
||||
f"Unsupported constant data of type {type(constant_data)}",
|
||||
)
|
||||
if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)):
|
||||
native_type = (
|
||||
float
|
||||
if constant_data.dtype == numpy.float32 or constant_data.dtype == numpy.float64
|
||||
else int
|
||||
)
|
||||
|
||||
min_value = native_type(constant_data.min())
|
||||
max_value = native_type(constant_data.max())
|
||||
|
||||
min_value_dtype = get_base_data_type_for_python_constant_data(min_value)
|
||||
max_value_dtype = get_base_data_type_for_python_constant_data(max_value)
|
||||
|
||||
# numpy
|
||||
base_dtype = convert_numpy_dtype_to_base_data_type(constant_data.dtype)
|
||||
base_dtype = find_type_to_hold_both_lossy(min_value_dtype, max_value_dtype)
|
||||
else:
|
||||
# python
|
||||
base_dtype = get_base_data_type_for_python_constant_data(constant_data)
|
||||
@@ -148,7 +163,10 @@ def get_base_value_for_numpy_or_python_constant_data(
|
||||
"""
|
||||
constant_data_value: Callable[..., BaseValue]
|
||||
custom_assert(
|
||||
isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)),
|
||||
isinstance(
|
||||
constant_data,
|
||||
(int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES),
|
||||
),
|
||||
f"Unsupported constant data of type {type(constant_data)}",
|
||||
)
|
||||
|
||||
|
||||
@@ -2,12 +2,16 @@
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from concrete.common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_inputset
|
||||
from concrete.common.compilation import CompilationConfiguration
|
||||
from concrete.common.data_types.floats import Float
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.common.data_types.integers import Integer, UnsignedInteger
|
||||
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy.compile import numpy_max_func, numpy_min_func
|
||||
from concrete.numpy.np_dtypes_helpers import get_base_value_for_numpy_or_python_constant_data
|
||||
from concrete.numpy.tracing import trace_numpy_function
|
||||
|
||||
|
||||
@@ -280,7 +284,9 @@ def test_eval_op_graph_bounds_on_inputset_multiple_output(
|
||||
yield (x_gen, y_gen)
|
||||
|
||||
_, node_bounds = eval_op_graph_bounds_on_inputset(
|
||||
op_graph, data_gen(*tuple(range(x[0], x[1] + 1) for x in input_ranges))
|
||||
op_graph,
|
||||
data_gen(*tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
CompilationConfiguration(),
|
||||
)
|
||||
|
||||
for i, output_node in op_graph.output_nodes.items():
|
||||
@@ -291,3 +297,137 @@ def test_eval_op_graph_bounds_on_inputset_multiple_output(
|
||||
|
||||
for i, output_node in op_graph.output_nodes.items():
|
||||
assert expected_output_data_type[i] == output_node.outputs[0].data_type
|
||||
|
||||
|
||||
def test_eval_op_graph_bounds_on_non_conformant_inputset_default(capsys):
|
||||
"""Test function for eval_op_graph_bounds_on_inputset with non conformant inputset"""
|
||||
|
||||
def f(x, y):
|
||||
return np.dot(x, y)
|
||||
|
||||
x = EncryptedTensor(UnsignedInteger(2), (3,))
|
||||
y = ClearTensor(UnsignedInteger(2), (3,))
|
||||
|
||||
inputset = [
|
||||
([2, 1, 3, 1], [1, 2, 1, 1]),
|
||||
([3, 3, 3], [3, 3, 5]),
|
||||
]
|
||||
|
||||
op_graph = trace_numpy_function(f, {"x": x, "y": y})
|
||||
|
||||
configuration = CompilationConfiguration()
|
||||
eval_op_graph_bounds_on_inputset(op_graph, inputset, compilation_configuration=configuration)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert (
|
||||
captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `x` "
|
||||
"but got EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
|
||||
"Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
|
||||
)
|
||||
|
||||
|
||||
def test_eval_op_graph_bounds_on_non_conformant_inputset_check_all(capsys):
|
||||
"""Test function for eval_op_graph_bounds_on_inputset with non conformant inputset, check all"""
|
||||
|
||||
def f(x, y):
|
||||
return np.dot(x, y)
|
||||
|
||||
x = EncryptedTensor(UnsignedInteger(2), (3,))
|
||||
y = ClearTensor(UnsignedInteger(2), (3,))
|
||||
|
||||
inputset = [
|
||||
([2, 1, 3, 1], [1, 2, 1, 1]),
|
||||
([3, 3, 3], [3, 3, 5]),
|
||||
]
|
||||
|
||||
op_graph = trace_numpy_function(f, {"x": x, "y": y})
|
||||
|
||||
configuration = CompilationConfiguration(check_every_input_in_inputset=True)
|
||||
eval_op_graph_bounds_on_inputset(op_graph, inputset, compilation_configuration=configuration)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert (
|
||||
captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `x` "
|
||||
"but got EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
|
||||
"Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
|
||||
"Warning: Input #1 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<Integer<unsigned, 3 bits>, shape=(3,)> which is not compatible)\n"
|
||||
)
|
||||
|
||||
|
||||
def test_eval_op_graph_bounds_on_conformant_numpy_inputset_check_all(capsys):
|
||||
"""Test function for eval_op_graph_bounds_on_inputset
|
||||
with conformant inputset of numpy arrays, check all"""
|
||||
|
||||
def f(x, y):
|
||||
return np.dot(x, y)
|
||||
|
||||
x = EncryptedTensor(UnsignedInteger(2), (3,))
|
||||
y = ClearTensor(UnsignedInteger(2), (3,))
|
||||
|
||||
inputset = [
|
||||
(np.array([2, 1, 3]), np.array([1, 2, 1])),
|
||||
(np.array([3, 3, 3]), np.array([3, 3, 1])),
|
||||
]
|
||||
|
||||
op_graph = trace_numpy_function(f, {"x": x, "y": y})
|
||||
|
||||
configuration = CompilationConfiguration(check_every_input_in_inputset=True)
|
||||
eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
compilation_configuration=configuration,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.err == ""
|
||||
|
||||
|
||||
def test_eval_op_graph_bounds_on_non_conformant_numpy_inputset_check_all(capsys):
|
||||
"""Test function for eval_op_graph_bounds_on_inputset with non conformant inputset, check all"""
|
||||
|
||||
def f(x, y):
|
||||
return np.dot(x, y)
|
||||
|
||||
x = EncryptedTensor(UnsignedInteger(2), (3,))
|
||||
y = ClearTensor(UnsignedInteger(2), (3,))
|
||||
|
||||
inputset = [
|
||||
(np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])),
|
||||
(np.array([3, 3, 3]), np.array([3, 3, 5])),
|
||||
]
|
||||
|
||||
op_graph = trace_numpy_function(f, {"x": x, "y": y})
|
||||
|
||||
configuration = CompilationConfiguration(check_every_input_in_inputset=True)
|
||||
eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
compilation_configuration=configuration,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert (
|
||||
captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `x` "
|
||||
"but got EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
|
||||
"Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
|
||||
"Warning: Input #1 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<Integer<unsigned, 3 bits>, shape=(3,)> which is not compatible)\n"
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Test file for data types helpers"""
|
||||
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.base import BaseDataType
|
||||
|
||||
@@ -303,7 +303,7 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str):
|
||||
iter_i = itertools.product(range(0, max_for_ij + 1), repeat=repeat)
|
||||
iter_j = itertools.product(range(0, max_for_ij + 1), repeat=repeat)
|
||||
for prod_i, prod_j in itertools.product(iter_i, iter_j):
|
||||
yield (prod_i, prod_j)
|
||||
yield (list(prod_i), list(prod_j))
|
||||
|
||||
max_for_ij = 3
|
||||
assert len(shape) == 1
|
||||
|
||||
Reference in New Issue
Block a user