test: testing eval_op_graph_bounds_on_dataset with multiple outputs

refs #74
This commit is contained in:
Benoit Chevallier-Mames
2021-08-04 15:49:36 +02:00
committed by Benoit Chevallier
parent 73f21c79a6
commit 6a80b065fc

View File

@@ -1,5 +1,7 @@
"""Test file for bounds evaluation with a dataset"""
from typing import Tuple
import pytest
from hdk.common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset
@@ -168,6 +170,39 @@ def test_eval_op_graph_bounds_on_dataset(
):
"""Test function for eval_op_graph_bounds_on_dataset"""
test_eval_op_graph_bounds_on_dataset_multiple_output(
function,
input_ranges,
(expected_output_bounds,),
(expected_output_data_type,),
)
@pytest.mark.parametrize(
"function,input_ranges,expected_output_bounds,expected_output_data_type",
[
pytest.param(
lambda x, y: (x + 1, y + 10),
((-1, 1), (3, 4)),
((0, 2), (13, 14)),
(Integer(2, is_signed=False), Integer(4, is_signed=False)),
),
pytest.param(
lambda x, y: (x + y + 1, x * y + 42),
((-1, 1), (3, 4)),
((3, 6), (38, 46)),
(Integer(3, is_signed=False), Integer(6, is_signed=False)),
),
],
)
def test_eval_op_graph_bounds_on_dataset_multiple_output(
function,
input_ranges,
expected_output_bounds,
expected_output_data_type: Tuple[Integer],
):
"""Test function for eval_op_graph_bounds_on_dataset"""
op_graph = trace_numpy_function(
function, {"x": EncryptedValue(Integer(64, True)), "y": EncryptedValue(Integer(64, True))}
)
@@ -181,12 +216,14 @@ def test_eval_op_graph_bounds_on_dataset(
op_graph, data_gen(*tuple(map(lambda x: range(x[0], x[1] + 1), input_ranges)))
)
output_node = op_graph.output_nodes[0]
output_node_bounds = node_bounds[output_node]
for i, output_node in op_graph.output_nodes.items():
output_node_bounds = node_bounds[output_node]
assert (output_node_bounds["min"], output_node_bounds["max"]) == expected_output_bounds
assert (output_node_bounds["min"], output_node_bounds["max"]) == expected_output_bounds[i]
assert EncryptedValue(Integer(64, True)) == output_node.outputs[0]
assert EncryptedValue(Integer(64, True)) == output_node.outputs[0]
op_graph.update_values_with_bounds(node_bounds)
assert expected_output_data_type == output_node.outputs[0].data_type
for i, output_node in op_graph.output_nodes.items():
assert expected_output_data_type[i] == output_node.outputs[0].data_type