From a158b09f44d24974cec74d3b3746e982b16b20bc Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 2 Aug 2021 10:27:30 +0200 Subject: [PATCH] feat(bounds): add function to update OPGraph IR nodes in and output values - this allows to have tighter data types by sticking to the smallest types able to represent the ranges passed as argument - update test_dataset_eval to check the output Value's data_type is updated --- hdk/common/operator_graph.py | 39 +++++++++++++++++++ .../bounds_measurement/test_dataset_eval.py | 25 +++++++++++- 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/hdk/common/operator_graph.py b/hdk/common/operator_graph.py index ccb7500a4..b1ad61619 100644 --- a/hdk/common/operator_graph.py +++ b/hdk/common/operator_graph.py @@ -1,9 +1,11 @@ """Code to wrap and make manipulating networkx graphs easier""" +from copy import deepcopy from typing import Any, Dict, Iterable, Mapping import networkx as nx +from .data_types.integers import make_integer_to_hold_ints from .representation import intermediate as ir from .tracing import BaseTracer from .tracing.tracing_helpers import create_graph_from_output_tracers @@ -57,3 +59,40 @@ class OPGraph: node_results[node] = node.evaluate({0: inputs[node.program_input_idx]}) return node_results + + def update_values_with_bounds(self, node_bounds: dict): + """Update nodes inputs and outputs values with data types able to hold data ranges measured + and passed in nodes_bounds + + Args: + node_bounds (dict): Dictionary with nodes as keys, holding dicts with a 'min' and 'max' + keys. Those bounds will be taken as the data range to be represented, per node. + """ + + node: ir.IntermediateNode + + for node in self.graph.nodes(): + current_node_bounds = node_bounds[node] + min_bound, max_bound = current_node_bounds["min"], current_node_bounds["max"] + + if not isinstance(node, ir.Input): + for output_value in node.outputs: + output_value.data_type = make_integer_to_hold_ints( + (min_bound, max_bound), force_signed=False + ) + else: + node.inputs[0].data_type = make_integer_to_hold_ints( + (min_bound, max_bound), force_signed=False + ) + node.outputs[0] = deepcopy(node.inputs[0]) + + # TODO: #57 manage multiple outputs from a node, probably requires an output_idx when + # adding an edge + assert len(node.outputs) == 1 + + successors = self.graph.succ[node] + for succ in successors: + edge_data = self.graph.get_edge_data(node, succ) + for edge in edge_data.values(): + input_idx = edge["input_idx"] + succ.inputs[input_idx] = deepcopy(node.outputs[0]) diff --git a/tests/common/bounds_measurement/test_dataset_eval.py b/tests/common/bounds_measurement/test_dataset_eval.py index 7c82bf6e7..398f190cf 100644 --- a/tests/common/bounds_measurement/test_dataset_eval.py +++ b/tests/common/bounds_measurement/test_dataset_eval.py @@ -9,77 +9,93 @@ from hdk.hnumpy.tracing import trace_numpy_function @pytest.mark.parametrize( - "function,input_ranges,expected_output_bounds", + "function,input_ranges,expected_output_bounds,expected_output_data_type", [ pytest.param( lambda x, y: x + y, ((-10, 10), (-10, 10)), (-20, 20), + Integer(6, is_signed=True), id="x + y, (-10, 10), (-10, 10), (-20, 20)", ), pytest.param( lambda x, y: x + y, ((-10, 2), (-4, 5)), (-14, 7), + Integer(5, is_signed=True), id="x + y, (-10, 2), (-4, 5), (-14, 9)", ), pytest.param( lambda x, y: x - y, ((-10, 10), (-10, 10)), (-20, 20), + Integer(6, is_signed=True), id="x - y, (-10, 10), (-10, 10), (-20, 20)", ), pytest.param( lambda x, y: x - y, ((-10, 2), (-4, 5)), (-15, 6), + Integer(5, is_signed=True), id="x - y, (-10, 2), (-4, 5), (-15, 6)", ), pytest.param( lambda x, y: x * y, ((-10, 10), (-10, 10)), (-100, 100), + Integer(8, is_signed=True), id="x * y, (-10, 10), (-10, 10), (-100, 100)", ), pytest.param( lambda x, y: x * y, ((-10, 2), (-4, 5)), (-50, 40), + Integer(7, is_signed=True), id="x * y, (-10, 2), (-4, 5), (-50, 40)", ), pytest.param( lambda x, y: x + x + y, ((-10, 10), (-10, 10)), (-30, 30), + Integer(6, is_signed=True), id="x + x + y, (-10, 10), (-10, 10), (-30, 30)", ), pytest.param( lambda x, y: x - x + y, ((-10, 10), (-10, 10)), (-10, 10), + Integer(5, is_signed=True), id="x - x + y, (-10, 10), (-10, 10), (-10, 10)", ), pytest.param( lambda x, y: x - x + y, ((-10, 2), (-4, 5)), (-4, 5), + Integer(4, is_signed=True), id="x - x + y, (-10, 2), (-4, 5), (-4, 5)", ), pytest.param( lambda x, y: x * y - x, ((-10, 10), (-10, 10)), (-110, 110), + Integer(8, is_signed=True), id="x * y - x, (-10, 10), (-10, 10), (-110, 110)", ), pytest.param( lambda x, y: x * y - x, ((-10, 2), (-4, 5)), (-40, 50), + Integer(7, is_signed=True), id="x * y - x, (-10, 2), (-4, 5), (-40, 50),", ), ], ) -def test_eval_op_graph_bounds_on_dataset(function, input_ranges, expected_output_bounds): +def test_eval_op_graph_bounds_on_dataset( + function, + input_ranges, + expected_output_bounds, + expected_output_data_type: Integer, +): """Test function for eval_op_graph_bounds_on_dataset""" op_graph = trace_numpy_function( @@ -99,3 +115,8 @@ def test_eval_op_graph_bounds_on_dataset(function, input_ranges, expected_output output_node_bounds = node_bounds[output_node] assert (output_node_bounds["min"], output_node_bounds["max"]) == expected_output_bounds + + assert EncryptedValue(Integer(64, True)) == output_node.outputs[0] + op_graph.update_values_with_bounds(node_bounds) + + assert expected_output_data_type == output_node.outputs[0].data_type