mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
test: testing eval_op_graph_bounds_on_dataset with multiple outputs
refs #74
This commit is contained in:
committed by
Benoit Chevallier
parent
73f21c79a6
commit
6a80b065fc
@@ -1,5 +1,7 @@
|
||||
"""Test file for bounds evaluation with a dataset"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from hdk.common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset
|
||||
@@ -168,6 +170,39 @@ def test_eval_op_graph_bounds_on_dataset(
|
||||
):
|
||||
"""Test function for eval_op_graph_bounds_on_dataset"""
|
||||
|
||||
test_eval_op_graph_bounds_on_dataset_multiple_output(
|
||||
function,
|
||||
input_ranges,
|
||||
(expected_output_bounds,),
|
||||
(expected_output_data_type,),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,expected_output_bounds,expected_output_data_type",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: (x + 1, y + 10),
|
||||
((-1, 1), (3, 4)),
|
||||
((0, 2), (13, 14)),
|
||||
(Integer(2, is_signed=False), Integer(4, is_signed=False)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x + y + 1, x * y + 42),
|
||||
((-1, 1), (3, 4)),
|
||||
((3, 6), (38, 46)),
|
||||
(Integer(3, is_signed=False), Integer(6, is_signed=False)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_eval_op_graph_bounds_on_dataset_multiple_output(
|
||||
function,
|
||||
input_ranges,
|
||||
expected_output_bounds,
|
||||
expected_output_data_type: Tuple[Integer],
|
||||
):
|
||||
"""Test function for eval_op_graph_bounds_on_dataset"""
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
function, {"x": EncryptedValue(Integer(64, True)), "y": EncryptedValue(Integer(64, True))}
|
||||
)
|
||||
@@ -181,12 +216,14 @@ def test_eval_op_graph_bounds_on_dataset(
|
||||
op_graph, data_gen(*tuple(map(lambda x: range(x[0], x[1] + 1), input_ranges)))
|
||||
)
|
||||
|
||||
output_node = op_graph.output_nodes[0]
|
||||
output_node_bounds = node_bounds[output_node]
|
||||
for i, output_node in op_graph.output_nodes.items():
|
||||
output_node_bounds = node_bounds[output_node]
|
||||
|
||||
assert (output_node_bounds["min"], output_node_bounds["max"]) == expected_output_bounds
|
||||
assert (output_node_bounds["min"], output_node_bounds["max"]) == expected_output_bounds[i]
|
||||
|
||||
assert EncryptedValue(Integer(64, True)) == output_node.outputs[0]
|
||||
|
||||
assert EncryptedValue(Integer(64, True)) == output_node.outputs[0]
|
||||
op_graph.update_values_with_bounds(node_bounds)
|
||||
|
||||
assert expected_output_data_type == output_node.outputs[0].data_type
|
||||
for i, output_node in op_graph.output_nodes.items():
|
||||
assert expected_output_data_type[i] == output_node.outputs[0].data_type
|
||||
|
||||
Reference in New Issue
Block a user