doc: explain what are datasets in eval_op_graph_bounds_on_dataset

This commit is contained in:
Benoit Chevallier-Mames
2021-08-06 11:23:01 +02:00
committed by Benoit Chevallier
parent 36e30e81b4
commit 055298daf8

View File

@@ -1,23 +1,25 @@
"""Code to evaluate the IR graph on datasets"""
from typing import Iterator
from typing import Any, Iterator, Tuple
from ..operator_graph import OPGraph
def eval_op_graph_bounds_on_dataset(op_graph: OPGraph, data_generator: Iterator):
def eval_op_graph_bounds_on_dataset(op_graph: OPGraph, dataset: Iterator[Tuple[Any, ...]]):
"""Evaluate the bounds for all output values of the operators in the graph op_graph over data
coming from the data_generator
coming from the dataset
Args:
op_graph (OPGraph): The graph for which we want to determine the bounds
data_generator (Iterator): The dataset over which op_graph is evaluated
dataset (Iterator[Tuple[Any, ...]]): The dataset over which op_graph is evaluated. It
needs to be an iterator on tuples which are of the same length than the number of
parameters in the function, and in the same order than these same parameters
Returns:
Dict: dict containing the bounds for each node from op_graph, stored with the node as key
and a dict with keys "min" and "max" as value
"""
first_input_data = dict(enumerate(next(data_generator)))
first_input_data = dict(enumerate(next(dataset)))
# Check the dataset is well-formed
assert len(first_input_data) == len(op_graph.input_nodes), (
@@ -33,7 +35,7 @@ def eval_op_graph_bounds_on_dataset(op_graph: OPGraph, data_generator: Iterator)
for node in op_graph.graph.nodes()
}
for input_data in data_generator:
for input_data in dataset:
next_input_data = dict(enumerate(input_data))