mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: accept non tuple inputs ininputset for 1-parameter functions
closes #952
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Code to evaluate the IR graph on inputsets."""
|
||||
|
||||
import sys
|
||||
from typing import Any, Callable, Dict, Iterable, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, Tuple, Union
|
||||
|
||||
from ..compilation import CompilationConfiguration
|
||||
from ..data_types.dtypes_helpers import (
|
||||
@@ -103,7 +103,7 @@ def _print_input_coherency_warnings(
|
||||
|
||||
def eval_op_graph_bounds_on_inputset(
|
||||
op_graph: OPGraph,
|
||||
inputset: Iterable[Tuple[Any, ...]],
|
||||
inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
min_func: Callable[[Any, Any], Any] = min,
|
||||
max_func: Callable[[Any, Any], Any] = max,
|
||||
@@ -118,8 +118,9 @@ def eval_op_graph_bounds_on_inputset(
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The graph for which we want to determine the bounds
|
||||
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
|
||||
needs to be an iterable on tuples which are of the same length than the number of
|
||||
inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): The inputset over which op_graph
|
||||
is evaluated. It needs to be an iterable on tuples (can be single values in case the
|
||||
function has only one argument) which are of the same length than the number of
|
||||
parameters in the function, and in the same order than these same parameters
|
||||
compilation_configuration (CompilationConfiguration): Configuration object to use
|
||||
during determining input checking strategy
|
||||
@@ -138,15 +139,25 @@ def eval_op_graph_bounds_on_inputset(
|
||||
as key and a dict with keys "min", "max" and "sample" as value.
|
||||
"""
|
||||
|
||||
num_input_nodes = len(op_graph.input_nodes)
|
||||
|
||||
def check_inputset_input_len_is_valid(data_to_check):
|
||||
assert_true(
|
||||
len(data_to_check) == len(op_graph.input_nodes),
|
||||
(
|
||||
f"Got input data from inputset of len: {len(data_to_check)}, "
|
||||
f"function being evaluated has {len(op_graph.input_nodes)} inputs, please make "
|
||||
f"sure your data generator returns valid tuples of input values"
|
||||
),
|
||||
)
|
||||
# Only check if there are more than one input node, otherwise accept the value as the sole
|
||||
# argument passed to the OPGraph for evaluation
|
||||
if num_input_nodes > 1:
|
||||
assert_true(
|
||||
len(data_to_check) == num_input_nodes,
|
||||
(
|
||||
f"Got input data from inputset of len: {len(data_to_check)}, "
|
||||
f"function being evaluated has {num_input_nodes} inputs, please make "
|
||||
f"sure your data generator returns valid tuples of input values"
|
||||
),
|
||||
)
|
||||
|
||||
def generate_input_values_dict(input_data) -> Dict[int, Any]:
|
||||
if num_input_nodes > 1:
|
||||
return dict(enumerate(input_data))
|
||||
return dict(enumerate(input_data)) if isinstance(input_data, tuple) else {0: input_data}
|
||||
|
||||
# TODO: do we want to check coherence between the input data type and the corresponding Input ir
|
||||
# node expected data type ? Not considering bit_width as they may not make sense at this stage
|
||||
@@ -161,7 +172,7 @@ def eval_op_graph_bounds_on_inputset(
|
||||
inputset_iterator = iter(inputset)
|
||||
inputset_size = 0
|
||||
|
||||
current_input_data = dict(enumerate(next(inputset_iterator)))
|
||||
current_input_data = generate_input_values_dict(next(inputset_iterator))
|
||||
inputset_size += 1
|
||||
|
||||
check_inputset_input_len_is_valid(current_input_data.values())
|
||||
@@ -189,7 +200,7 @@ def eval_op_graph_bounds_on_inputset(
|
||||
|
||||
for input_data in inputset_iterator:
|
||||
inputset_size += 1
|
||||
current_input_data = dict(enumerate(input_data))
|
||||
current_input_data = generate_input_values_dict(input_data)
|
||||
|
||||
check_inputset_input_len_is_valid(current_input_data.values())
|
||||
if compilation_configuration.check_every_input_in_inputset:
|
||||
|
||||
@@ -473,3 +473,34 @@ def test_eval_op_graph_bounds_on_non_conformant_inputset_treating_warnings_as_er
|
||||
max_func=numpy_max_func,
|
||||
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
|
||||
def test_inpuset_eval_1_input(default_compilation_configuration):
|
||||
"""Test case for a function with a single parameter and passing the inputset without tuples."""
|
||||
|
||||
def f(x):
|
||||
return x + 42
|
||||
|
||||
x = EncryptedScalar(UnsignedInteger(4))
|
||||
|
||||
inputset = range(10)
|
||||
|
||||
op_graph = trace_numpy_function(f, {"x": x})
|
||||
|
||||
eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
compilation_configuration=default_compilation_configuration,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
input_node = op_graph.input_nodes[0]
|
||||
|
||||
assert input_node.inputs[0] == input_node.outputs[0]
|
||||
assert input_node.inputs[0] == EncryptedScalar(UnsignedInteger(4))
|
||||
|
||||
output_node = op_graph.output_nodes[0]
|
||||
|
||||
assert output_node.outputs[0] == EncryptedScalar(UnsignedInteger(6))
|
||||
|
||||
Reference in New Issue
Block a user