feat: accept non tuple inputs ininputset for 1-parameter functions

closes #952
This commit is contained in:
Arthur Meyre
2021-11-22 09:36:59 +01:00
parent 8164ce3946
commit a77c369daf
2 changed files with 56 additions and 14 deletions

View File

@@ -1,7 +1,7 @@
"""Code to evaluate the IR graph on inputsets."""
import sys
from typing import Any, Callable, Dict, Iterable, Tuple
from typing import Any, Callable, Dict, Iterable, Tuple, Union
from ..compilation import CompilationConfiguration
from ..data_types.dtypes_helpers import (
@@ -103,7 +103,7 @@ def _print_input_coherency_warnings(
def eval_op_graph_bounds_on_inputset(
op_graph: OPGraph,
inputset: Iterable[Tuple[Any, ...]],
inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
compilation_configuration: CompilationConfiguration,
min_func: Callable[[Any, Any], Any] = min,
max_func: Callable[[Any, Any], Any] = max,
@@ -118,8 +118,9 @@ def eval_op_graph_bounds_on_inputset(
Args:
op_graph (OPGraph): The graph for which we want to determine the bounds
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
needs to be an iterable on tuples which are of the same length than the number of
inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset over which op_graph
is evaluated. It needs to be an iterable on tuples (can be single values in case the
function has only one argument) which are of the same length than the number of
parameters in the function, and in the same order than these same parameters
compilation_configuration (CompilationConfiguration): Configuration object to use
during determining input checking strategy
@@ -138,15 +139,25 @@ def eval_op_graph_bounds_on_inputset(
as key and a dict with keys "min", "max" and "sample" as value.
"""
num_input_nodes = len(op_graph.input_nodes)
def check_inputset_input_len_is_valid(data_to_check):
assert_true(
len(data_to_check) == len(op_graph.input_nodes),
(
f"Got input data from inputset of len: {len(data_to_check)}, "
f"function being evaluated has {len(op_graph.input_nodes)} inputs, please make "
f"sure your data generator returns valid tuples of input values"
),
)
# Only check if there are more than one input node, otherwise accept the value as the sole
# argument passed to the OPGraph for evaluation
if num_input_nodes > 1:
assert_true(
len(data_to_check) == num_input_nodes,
(
f"Got input data from inputset of len: {len(data_to_check)}, "
f"function being evaluated has {num_input_nodes} inputs, please make "
f"sure your data generator returns valid tuples of input values"
),
)
def generate_input_values_dict(input_data) -> Dict[int, Any]:
if num_input_nodes > 1:
return dict(enumerate(input_data))
return dict(enumerate(input_data)) if isinstance(input_data, tuple) else {0: input_data}
# TODO: do we want to check coherence between the input data type and the corresponding Input ir
# node expected data type ? Not considering bit_width as they may not make sense at this stage
@@ -161,7 +172,7 @@ def eval_op_graph_bounds_on_inputset(
inputset_iterator = iter(inputset)
inputset_size = 0
current_input_data = dict(enumerate(next(inputset_iterator)))
current_input_data = generate_input_values_dict(next(inputset_iterator))
inputset_size += 1
check_inputset_input_len_is_valid(current_input_data.values())
@@ -189,7 +200,7 @@ def eval_op_graph_bounds_on_inputset(
for input_data in inputset_iterator:
inputset_size += 1
current_input_data = dict(enumerate(input_data))
current_input_data = generate_input_values_dict(input_data)
check_inputset_input_len_is_valid(current_input_data.values())
if compilation_configuration.check_every_input_in_inputset:

View File

@@ -473,3 +473,34 @@ def test_eval_op_graph_bounds_on_non_conformant_inputset_treating_warnings_as_er
max_func=numpy_max_func,
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
)
def test_inpuset_eval_1_input(default_compilation_configuration):
"""Test case for a function with a single parameter and passing the inputset without tuples."""
def f(x):
return x + 42
x = EncryptedScalar(UnsignedInteger(4))
inputset = range(10)
op_graph = trace_numpy_function(f, {"x": x})
eval_op_graph_bounds_on_inputset(
op_graph,
inputset,
compilation_configuration=default_compilation_configuration,
min_func=numpy_min_func,
max_func=numpy_max_func,
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
)
input_node = op_graph.input_nodes[0]
assert input_node.inputs[0] == input_node.outputs[0]
assert input_node.inputs[0] == EncryptedScalar(UnsignedInteger(4))
output_node = op_graph.output_nodes[0]
assert output_node.outputs[0] == EncryptedScalar(UnsignedInteger(6))