diff --git a/hdk/common/bounds_measurement/dataset_eval.py b/hdk/common/bounds_measurement/dataset_eval.py index 46e588821..4abd8d719 100644 --- a/hdk/common/bounds_measurement/dataset_eval.py +++ b/hdk/common/bounds_measurement/dataset_eval.py @@ -1,23 +1,25 @@ """Code to evaluate the IR graph on datasets""" -from typing import Iterator +from typing import Any, Iterator, Tuple from ..operator_graph import OPGraph -def eval_op_graph_bounds_on_dataset(op_graph: OPGraph, data_generator: Iterator): +def eval_op_graph_bounds_on_dataset(op_graph: OPGraph, dataset: Iterator[Tuple[Any, ...]]): """Evaluate the bounds for all output values of the operators in the graph op_graph over data - coming from the data_generator + coming from the dataset Args: op_graph (OPGraph): The graph for which we want to determine the bounds - data_generator (Iterator): The dataset over which op_graph is evaluated + dataset (Iterator[Tuple[Any, ...]]): The dataset over which op_graph is evaluated. It + needs to be an iterator on tuples which are of the same length than the number of + parameters in the function, and in the same order than these same parameters Returns: Dict: dict containing the bounds for each node from op_graph, stored with the node as key and a dict with keys "min" and "max" as value """ - first_input_data = dict(enumerate(next(data_generator))) + first_input_data = dict(enumerate(next(dataset))) # Check the dataset is well-formed assert len(first_input_data) == len(op_graph.input_nodes), ( @@ -33,7 +35,7 @@ def eval_op_graph_bounds_on_dataset(op_graph: OPGraph, data_generator: Iterator) for node in op_graph.graph.nodes() } - for input_data in data_generator: + for input_data in dataset: next_input_data = dict(enumerate(input_data))