feat(bounds): add function to update OPGraph IR nodes in and output values

- this allows to have tighter data types by sticking to the smallest types
able to represent the ranges passed as argument
- update test_dataset_eval to check the output Value's data_type is updated
This commit is contained in:
Arthur Meyre
2021-08-02 10:27:30 +02:00
parent b1a3b28a20
commit a158b09f44
2 changed files with 62 additions and 2 deletions

View File

@@ -1,9 +1,11 @@
"""Code to wrap and make manipulating networkx graphs easier"""
from copy import deepcopy
from typing import Any, Dict, Iterable, Mapping
import networkx as nx
from .data_types.integers import make_integer_to_hold_ints
from .representation import intermediate as ir
from .tracing import BaseTracer
from .tracing.tracing_helpers import create_graph_from_output_tracers
@@ -57,3 +59,40 @@ class OPGraph:
node_results[node] = node.evaluate({0: inputs[node.program_input_idx]})
return node_results
def update_values_with_bounds(self, node_bounds: dict):
"""Update nodes inputs and outputs values with data types able to hold data ranges measured
and passed in nodes_bounds
Args:
node_bounds (dict): Dictionary with nodes as keys, holding dicts with a 'min' and 'max'
keys. Those bounds will be taken as the data range to be represented, per node.
"""
node: ir.IntermediateNode
for node in self.graph.nodes():
current_node_bounds = node_bounds[node]
min_bound, max_bound = current_node_bounds["min"], current_node_bounds["max"]
if not isinstance(node, ir.Input):
for output_value in node.outputs:
output_value.data_type = make_integer_to_hold_ints(
(min_bound, max_bound), force_signed=False
)
else:
node.inputs[0].data_type = make_integer_to_hold_ints(
(min_bound, max_bound), force_signed=False
)
node.outputs[0] = deepcopy(node.inputs[0])
# TODO: #57 manage multiple outputs from a node, probably requires an output_idx when
# adding an edge
assert len(node.outputs) == 1
successors = self.graph.succ[node]
for succ in successors:
edge_data = self.graph.get_edge_data(node, succ)
for edge in edge_data.values():
input_idx = edge["input_idx"]
succ.inputs[input_idx] = deepcopy(node.outputs[0])

View File

@@ -9,77 +9,93 @@ from hdk.hnumpy.tracing import trace_numpy_function
@pytest.mark.parametrize(
"function,input_ranges,expected_output_bounds",
"function,input_ranges,expected_output_bounds,expected_output_data_type",
[
pytest.param(
lambda x, y: x + y,
((-10, 10), (-10, 10)),
(-20, 20),
Integer(6, is_signed=True),
id="x + y, (-10, 10), (-10, 10), (-20, 20)",
),
pytest.param(
lambda x, y: x + y,
((-10, 2), (-4, 5)),
(-14, 7),
Integer(5, is_signed=True),
id="x + y, (-10, 2), (-4, 5), (-14, 9)",
),
pytest.param(
lambda x, y: x - y,
((-10, 10), (-10, 10)),
(-20, 20),
Integer(6, is_signed=True),
id="x - y, (-10, 10), (-10, 10), (-20, 20)",
),
pytest.param(
lambda x, y: x - y,
((-10, 2), (-4, 5)),
(-15, 6),
Integer(5, is_signed=True),
id="x - y, (-10, 2), (-4, 5), (-15, 6)",
),
pytest.param(
lambda x, y: x * y,
((-10, 10), (-10, 10)),
(-100, 100),
Integer(8, is_signed=True),
id="x * y, (-10, 10), (-10, 10), (-100, 100)",
),
pytest.param(
lambda x, y: x * y,
((-10, 2), (-4, 5)),
(-50, 40),
Integer(7, is_signed=True),
id="x * y, (-10, 2), (-4, 5), (-50, 40)",
),
pytest.param(
lambda x, y: x + x + y,
((-10, 10), (-10, 10)),
(-30, 30),
Integer(6, is_signed=True),
id="x + x + y, (-10, 10), (-10, 10), (-30, 30)",
),
pytest.param(
lambda x, y: x - x + y,
((-10, 10), (-10, 10)),
(-10, 10),
Integer(5, is_signed=True),
id="x - x + y, (-10, 10), (-10, 10), (-10, 10)",
),
pytest.param(
lambda x, y: x - x + y,
((-10, 2), (-4, 5)),
(-4, 5),
Integer(4, is_signed=True),
id="x - x + y, (-10, 2), (-4, 5), (-4, 5)",
),
pytest.param(
lambda x, y: x * y - x,
((-10, 10), (-10, 10)),
(-110, 110),
Integer(8, is_signed=True),
id="x * y - x, (-10, 10), (-10, 10), (-110, 110)",
),
pytest.param(
lambda x, y: x * y - x,
((-10, 2), (-4, 5)),
(-40, 50),
Integer(7, is_signed=True),
id="x * y - x, (-10, 2), (-4, 5), (-40, 50),",
),
],
)
def test_eval_op_graph_bounds_on_dataset(function, input_ranges, expected_output_bounds):
def test_eval_op_graph_bounds_on_dataset(
function,
input_ranges,
expected_output_bounds,
expected_output_data_type: Integer,
):
"""Test function for eval_op_graph_bounds_on_dataset"""
op_graph = trace_numpy_function(
@@ -99,3 +115,8 @@ def test_eval_op_graph_bounds_on_dataset(function, input_ranges, expected_output
output_node_bounds = node_bounds[output_node]
assert (output_node_bounds["min"], output_node_bounds["max"]) == expected_output_bounds
assert EncryptedValue(Integer(64, True)) == output_node.outputs[0]
op_graph.update_values_with_bounds(node_bounds)
assert expected_output_data_type == output_node.outputs[0].data_type