fix: check datasets

closes #91
This commit is contained in:
Benoit Chevallier-Mames
2021-08-06 11:06:18 +02:00
committed by Benoit Chevallier
parent ee832079ba
commit 36e30e81b4

View File

@@ -18,6 +18,14 @@ def eval_op_graph_bounds_on_dataset(op_graph: OPGraph, data_generator: Iterator)
and a dict with keys "min" and "max" as value
"""
first_input_data = dict(enumerate(next(data_generator)))
# Check the dataset is well-formed
assert len(first_input_data) == len(op_graph.input_nodes), (
f"Got input data from dataset of len: {len(first_input_data)}, function being evaluated has"
f" only {len(op_graph.input_nodes)} inputs, please make sure your data generator returns"
f" valid tuples of input values"
)
first_output = op_graph.evaluate(first_input_data)
node_bounds = {
@@ -26,7 +34,19 @@ def eval_op_graph_bounds_on_dataset(op_graph: OPGraph, data_generator: Iterator)
}
for input_data in data_generator:
current_output = op_graph.evaluate(dict(enumerate(input_data)))
next_input_data = dict(enumerate(input_data))
# Check the dataset is well-formed
assert len(next_input_data) == len(op_graph.input_nodes), (
f"Got input data from dataset of len: {len(next_input_data)},"
f" function being evaluated has"
f" only {len(op_graph.input_nodes)} inputs, please make sure"
f" your data generator returns"
f" valid tuples of input values"
)
current_output = op_graph.evaluate(next_input_data)
for node, value in current_output.items():
node_bounds[node]["min"] = min(node_bounds[node]["min"], value)
node_bounds[node]["max"] = max(node_bounds[node]["max"], value)