feat(bounds): add a way to evaluate an operator graph on a dataset

This commit is contained in:
Arthur Meyre
2021-07-29 15:49:04 +02:00
parent 9b52ea94fb
commit e55284b3ea
3 changed files with 137 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
"""Bounds measurement module"""
from . import dataset_eval

View File

@@ -0,0 +1,34 @@
"""Code to evaluate the IR graph on datasets"""
from typing import Iterator
from ..operator_graph import OPGraph
def eval_op_graph_bounds_on_dataset(op_graph: OPGraph, data_generator: Iterator):
"""Evaluate the bounds for all output values of the operators in the graph op_graph over data
coming from the data_generator
Args:
op_graph (OPGraph): The graph for which we want to determine the bounds
data_generator (Iterator): The dataset over which op_graph is evaluated
Returns:
Dict: dict containing the bounds for each node from op_graph, stored with the node as key
and a dict with keys "min" and "max" as value
"""
first_input_data = dict(enumerate(next(data_generator)))
first_output = op_graph.evaluate(first_input_data)
node_bounds = {
node: {"min": first_output[node], "max": first_output[node]}
for node in op_graph.graph.nodes()
}
for input_data in data_generator:
current_output = op_graph.evaluate(dict(enumerate(input_data)))
for node, value in current_output.items():
node_bounds[node]["min"] = min(node_bounds[node]["min"], value)
node_bounds[node]["max"] = max(node_bounds[node]["max"], value)
return node_bounds

View File

@@ -0,0 +1,101 @@
"""Test file for bounds evaluation with a dataset"""
import pytest
from hdk.common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import EncryptedValue
from hdk.hnumpy.tracing import trace_numpy_function
@pytest.mark.parametrize(
"function,input_ranges,expected_output_bounds",
[
pytest.param(
lambda x, y: x + y,
((-10, 10), (-10, 10)),
(-20, 20),
id="x + y, (-10, 10), (-10, 10), (-20, 20)",
),
pytest.param(
lambda x, y: x + y,
((-10, 2), (-4, 5)),
(-14, 7),
id="x + y, (-10, 2), (-4, 5), (-14, 9)",
),
pytest.param(
lambda x, y: x - y,
((-10, 10), (-10, 10)),
(-20, 20),
id="x - y, (-10, 10), (-10, 10), (-20, 20)",
),
pytest.param(
lambda x, y: x - y,
((-10, 2), (-4, 5)),
(-15, 6),
id="x - y, (-10, 2), (-4, 5), (-15, 6)",
),
pytest.param(
lambda x, y: x * y,
((-10, 10), (-10, 10)),
(-100, 100),
id="x * y, (-10, 10), (-10, 10), (-100, 100)",
),
pytest.param(
lambda x, y: x * y,
((-10, 2), (-4, 5)),
(-50, 40),
id="x * y, (-10, 2), (-4, 5), (-50, 40)",
),
pytest.param(
lambda x, y: x + x + y,
((-10, 10), (-10, 10)),
(-30, 30),
id="x + x + y, (-10, 10), (-10, 10), (-30, 30)",
),
pytest.param(
lambda x, y: x - x + y,
((-10, 10), (-10, 10)),
(-10, 10),
id="x - x + y, (-10, 10), (-10, 10), (-10, 10)",
),
pytest.param(
lambda x, y: x - x + y,
((-10, 2), (-4, 5)),
(-4, 5),
id="x - x + y, (-10, 2), (-4, 5), (-4, 5)",
),
pytest.param(
lambda x, y: x * y - x,
((-10, 10), (-10, 10)),
(-110, 110),
id="x * y - x, (-10, 10), (-10, 10), (-110, 110)",
),
pytest.param(
lambda x, y: x * y - x,
((-10, 2), (-4, 5)),
(-40, 50),
id="x * y - x, (-10, 2), (-4, 5), (-40, 50),",
),
],
)
def test_eval_op_graph_bounds_on_dataset(function, input_ranges, expected_output_bounds):
"""Test function for eval_op_graph_bounds_on_dataset"""
op_graph = trace_numpy_function(
function, {"x": EncryptedValue(Integer(64, True)), "y": EncryptedValue(Integer(64, True))}
)
def data_gen(range_x, range_y):
for x_gen in range_x:
for y_gen in range_y:
yield (x_gen, y_gen)
node_bounds = eval_op_graph_bounds_on_dataset(
op_graph, data_gen(*tuple(map(lambda x: range(x[0], x[1] + 1), input_ranges)))
)
output_node = op_graph.output_nodes[0]
output_node_bounds = node_bounds[output_node]
assert (output_node_bounds["min"], output_node_bounds["max"]) == expected_output_bounds