From 36e30e81b43fb43c8bb7be10dae2b9ec9ea0d525 Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Fri, 6 Aug 2021 11:06:18 +0200 Subject: [PATCH] fix: check datasets closes #91 --- hdk/common/bounds_measurement/dataset_eval.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/hdk/common/bounds_measurement/dataset_eval.py b/hdk/common/bounds_measurement/dataset_eval.py index d30dda021..46e588821 100644 --- a/hdk/common/bounds_measurement/dataset_eval.py +++ b/hdk/common/bounds_measurement/dataset_eval.py @@ -18,6 +18,14 @@ def eval_op_graph_bounds_on_dataset(op_graph: OPGraph, data_generator: Iterator) and a dict with keys "min" and "max" as value """ first_input_data = dict(enumerate(next(data_generator))) + + # Check the dataset is well-formed + assert len(first_input_data) == len(op_graph.input_nodes), ( + f"Got input data from dataset of len: {len(first_input_data)}, function being evaluated has" + f" only {len(op_graph.input_nodes)} inputs, please make sure your data generator returns" + f" valid tuples of input values" + ) + first_output = op_graph.evaluate(first_input_data) node_bounds = { @@ -26,7 +34,19 @@ def eval_op_graph_bounds_on_dataset(op_graph: OPGraph, data_generator: Iterator) } for input_data in data_generator: - current_output = op_graph.evaluate(dict(enumerate(input_data))) + + next_input_data = dict(enumerate(input_data)) + + # Check the dataset is well-formed + assert len(next_input_data) == len(op_graph.input_nodes), ( + f"Got input data from dataset of len: {len(next_input_data)}," + f" function being evaluated has" + f" only {len(op_graph.input_nodes)} inputs, please make sure" + f" your data generator returns" + f" valid tuples of input values" + ) + + current_output = op_graph.evaluate(next_input_data) for node, value in current_output.items(): node_bounds[node]["min"] = min(node_bounds[node]["min"], value) node_bounds[node]["max"] = max(node_bounds[node]["max"], value)