diff --git a/concrete/common/bounds_measurement/inputset_eval.py b/concrete/common/bounds_measurement/inputset_eval.py index 10d8d91e6..b2af728e4 100644 --- a/concrete/common/bounds_measurement/inputset_eval.py +++ b/concrete/common/bounds_measurement/inputset_eval.py @@ -1,7 +1,7 @@ """Code to evaluate the IR graph on inputsets.""" import sys -from typing import Any, Callable, Dict, Iterable, Tuple +from typing import Any, Callable, Dict, Iterable, Tuple, Union from ..compilation import CompilationConfiguration from ..data_types.dtypes_helpers import ( @@ -103,7 +103,7 @@ def _print_input_coherency_warnings( def eval_op_graph_bounds_on_inputset( op_graph: OPGraph, - inputset: Iterable[Tuple[Any, ...]], + inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]], compilation_configuration: CompilationConfiguration, min_func: Callable[[Any, Any], Any] = min, max_func: Callable[[Any, Any], Any] = max, @@ -118,8 +118,9 @@ def eval_op_graph_bounds_on_inputset( Args: op_graph (OPGraph): The graph for which we want to determine the bounds - inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It - needs to be an iterable on tuples which are of the same length than the number of + inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset over which op_graph + is evaluated. It needs to be an iterable on tuples (can be single values in case the + function has only one argument) which are of the same length than the number of parameters in the function, and in the same order than these same parameters compilation_configuration (CompilationConfiguration): Configuration object to use during determining input checking strategy @@ -138,15 +139,25 @@ def eval_op_graph_bounds_on_inputset( as key and a dict with keys "min", "max" and "sample" as value. """ + num_input_nodes = len(op_graph.input_nodes) + def check_inputset_input_len_is_valid(data_to_check): - assert_true( - len(data_to_check) == len(op_graph.input_nodes), - ( - f"Got input data from inputset of len: {len(data_to_check)}, " - f"function being evaluated has {len(op_graph.input_nodes)} inputs, please make " - f"sure your data generator returns valid tuples of input values" - ), - ) + # Only check if there are more than one input node, otherwise accept the value as the sole + # argument passed to the OPGraph for evaluation + if num_input_nodes > 1: + assert_true( + len(data_to_check) == num_input_nodes, + ( + f"Got input data from inputset of len: {len(data_to_check)}, " + f"function being evaluated has {num_input_nodes} inputs, please make " + f"sure your data generator returns valid tuples of input values" + ), + ) + + def generate_input_values_dict(input_data) -> Dict[int, Any]: + if num_input_nodes > 1: + return dict(enumerate(input_data)) + return dict(enumerate(input_data)) if isinstance(input_data, tuple) else {0: input_data} # TODO: do we want to check coherence between the input data type and the corresponding Input ir # node expected data type ? Not considering bit_width as they may not make sense at this stage @@ -161,7 +172,7 @@ def eval_op_graph_bounds_on_inputset( inputset_iterator = iter(inputset) inputset_size = 0 - current_input_data = dict(enumerate(next(inputset_iterator))) + current_input_data = generate_input_values_dict(next(inputset_iterator)) inputset_size += 1 check_inputset_input_len_is_valid(current_input_data.values()) @@ -189,7 +200,7 @@ def eval_op_graph_bounds_on_inputset( for input_data in inputset_iterator: inputset_size += 1 - current_input_data = dict(enumerate(input_data)) + current_input_data = generate_input_values_dict(input_data) check_inputset_input_len_is_valid(current_input_data.values()) if compilation_configuration.check_every_input_in_inputset: diff --git a/tests/common/bounds_measurement/test_inputset_eval.py b/tests/common/bounds_measurement/test_inputset_eval.py index ddf1fa975..ffef45e41 100644 --- a/tests/common/bounds_measurement/test_inputset_eval.py +++ b/tests/common/bounds_measurement/test_inputset_eval.py @@ -473,3 +473,34 @@ def test_eval_op_graph_bounds_on_non_conformant_inputset_treating_warnings_as_er max_func=numpy_max_func, get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, ) + + +def test_inpuset_eval_1_input(default_compilation_configuration): + """Test case for a function with a single parameter and passing the inputset without tuples.""" + + def f(x): + return x + 42 + + x = EncryptedScalar(UnsignedInteger(4)) + + inputset = range(10) + + op_graph = trace_numpy_function(f, {"x": x}) + + eval_op_graph_bounds_on_inputset( + op_graph, + inputset, + compilation_configuration=default_compilation_configuration, + min_func=numpy_min_func, + max_func=numpy_max_func, + get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, + ) + + input_node = op_graph.input_nodes[0] + + assert input_node.inputs[0] == input_node.outputs[0] + assert input_node.inputs[0] == EncryptedScalar(UnsignedInteger(4)) + + output_node = op_graph.output_nodes[0] + + assert output_node.outputs[0] == EncryptedScalar(UnsignedInteger(6))