diff --git a/hdk/common/bounds_measurement/dataset_eval.py b/hdk/common/bounds_measurement/dataset_eval.py index 4abd8d719..8fb8d5df9 100644 --- a/hdk/common/bounds_measurement/dataset_eval.py +++ b/hdk/common/bounds_measurement/dataset_eval.py @@ -19,15 +19,21 @@ def eval_op_graph_bounds_on_dataset(op_graph: OPGraph, dataset: Iterator[Tuple[A Dict: dict containing the bounds for each node from op_graph, stored with the node as key and a dict with keys "min" and "max" as value """ + + def check_dataset_input_is_valid(data_to_check): + assert len(data_to_check) == len(op_graph.input_nodes), ( + f"Got input data from dataset of len: {len(data_to_check)}, " + f"function being evaluated has {len(op_graph.input_nodes)} inputs, please make " + f"sure your data generator returns valid tuples of input values" + ) + # TODO: change this to be more generic and check coherence between the input data type and + # the corresponding Input ir node expected data type + assert all( + isinstance(val, int) for val in data_to_check + ), "For now dataset evaluation only support int as inputs, please check your dataset" + first_input_data = dict(enumerate(next(dataset))) - - # Check the dataset is well-formed - assert len(first_input_data) == len(op_graph.input_nodes), ( - f"Got input data from dataset of len: {len(first_input_data)}, function being evaluated has" - f" only {len(op_graph.input_nodes)} inputs, please make sure your data generator returns" - f" valid tuples of input values" - ) - + check_dataset_input_is_valid(first_input_data.values()) first_output = op_graph.evaluate(first_input_data) node_bounds = { @@ -36,19 +42,9 @@ def eval_op_graph_bounds_on_dataset(op_graph: OPGraph, dataset: Iterator[Tuple[A } for input_data in dataset: - - next_input_data = dict(enumerate(input_data)) - - # Check the dataset is well-formed - assert len(next_input_data) == len(op_graph.input_nodes), ( - f"Got input data from dataset of len: {len(next_input_data)}," - f" function being evaluated has" - f" only {len(op_graph.input_nodes)} inputs, please make sure" - f" your data generator returns" - f" valid tuples of input values" - ) - - current_output = op_graph.evaluate(next_input_data) + current_input_data = dict(enumerate(input_data)) + check_dataset_input_is_valid(current_input_data.values()) + current_output = op_graph.evaluate(current_input_data) for node, value in current_output.items(): node_bounds[node]["min"] = min(node_bounds[node]["min"], value) node_bounds[node]["max"] = max(node_bounds[node]["max"], value) diff --git a/hdk/common/operator_graph.py b/hdk/common/operator_graph.py index 4a93182c1..04655eeba 100644 --- a/hdk/common/operator_graph.py +++ b/hdk/common/operator_graph.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, Mapping import networkx as nx +from .data_types.floats import Float from .data_types.integers import make_integer_to_hold_ints from .representation import intermediate as ir from .tracing import BaseTracer @@ -71,10 +72,18 @@ class OPGraph: if not isinstance(node, ir.Input): for output_value in node.outputs: - output_value.data_type = make_integer_to_hold_ints( - (min_bound, max_bound), force_signed=False - ) + if isinstance(min_bound, int) and isinstance(max_bound, int): + output_value.data_type = make_integer_to_hold_ints( + (min_bound, max_bound), force_signed=False + ) + else: + output_value.data_type = Float(64) else: + # Currently variable inputs are only allowed to be integers + assert isinstance(min_bound, int) and isinstance(max_bound, int), ( + f"Inputs to a graph should be integers, got bounds that were not float, \n" + f"min: {min_bound} ({type(min_bound)}), max: {max_bound} ({type(max_bound)})" + ) node.inputs[0].data_type = make_integer_to_hold_ints( (min_bound, max_bound), force_signed=False ) diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index db78a6e4f..d5660a205 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -6,6 +6,7 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple from ..data_types import BaseValue from ..data_types.dtypes_helpers import mix_values_determine_holding_dtype +from ..data_types.floats import Float from ..data_types.integers import Integer, get_bits_to_represent_int from ..data_types.scalars import Scalars from ..data_types.values import ClearValue @@ -153,12 +154,16 @@ class ConstantInput(IntermediateNode): super().__init__([]) self.constant_data = constant_data - # TODO: manage other cases, we can't call get_bits_to_represent_int - assert isinstance(constant_data, int) - is_signed = constant_data < 0 - self.outputs = [ - ClearValue(Integer(get_bits_to_represent_int(constant_data, is_signed), is_signed)) - ] + assert isinstance( + constant_data, (int, float) + ), "Only int and float are support for constant input" + if isinstance(constant_data, int): + is_signed = constant_data < 0 + self.outputs = [ + ClearValue(Integer(get_bits_to_represent_int(constant_data, is_signed), is_signed)) + ] + elif isinstance(constant_data, float): + self.outputs = [ClearValue(Float(64))] def evaluate(self, inputs: Mapping[int, Any]) -> Any: return self.constant_data diff --git a/tests/common/bounds_measurement/test_dataset_eval.py b/tests/common/bounds_measurement/test_dataset_eval.py index cdf30691a..98f5b2b43 100644 --- a/tests/common/bounds_measurement/test_dataset_eval.py +++ b/tests/common/bounds_measurement/test_dataset_eval.py @@ -5,6 +5,7 @@ from typing import Tuple import pytest from hdk.common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset +from hdk.common.data_types.floats import Float from hdk.common.data_types.integers import Integer from hdk.common.data_types.values import EncryptedValue from hdk.hnumpy.tracing import trace_numpy_function @@ -27,6 +28,13 @@ from hdk.hnumpy.tracing import trace_numpy_function Integer(5, is_signed=True), id="x + y, (-10, 2), (-4, 5), (-14, 7)", ), + pytest.param( + lambda x, y: x + y + 1.7, + ((-10, 2), (-4, 5)), + (-12.3, 8.7), + Float(64), + id="x + y + 1.7, (-10, 2), (-4, 5), (-12.3, 8.7)", + ), pytest.param( lambda x, y: x + y + 1, ((-10, 2), (-4, 5)), @@ -67,21 +75,42 @@ from hdk.hnumpy.tracing import trace_numpy_function ((-10, 2), (-4, 5)), (-57, -36), Integer(7, is_signed=True), - id="x - y, (-10, 2), (-4, 5), (-57, -36)", + id="x - y - 42, (-10, 2), (-4, 5), (-57, -36)", + ), + pytest.param( + lambda x, y: x - y - 41.5, + ((-10, 2), (-4, 5)), + (-56.5, -35.5), + Float(64), + id="x - y - 41.5, (-10, 2), (-4, 5), (-56.5, -35.5)", ), pytest.param( lambda x, y: 3 - x + y, ((-10, 2), (-4, 5)), (-3, 18), Integer(6, is_signed=True), - id="x - y, (-10, 2), (-4, 5), (-3, 18)", + id="3 - x + y, (-10, 2), (-4, 5), (-3, 18)", + ), + pytest.param( + lambda x, y: 2.8 - x + y, + ((-10, 2), (-4, 5)), + (-3.2, 17.8), + Float(64), + id="2.8 - x + y, (-10, 2), (-4, 5), (-3.2, 17.8)", ), pytest.param( lambda x, y: (-13) - x + y, ((-10, 2), (-4, 5)), (-19, 2), Integer(6, is_signed=True), - id="x - y, (-10, 2), (-4, 5), (-16, 2)", + id="(-13) - x + y, (-10, 2), (-4, 5), (-19, 2)", + ), + pytest.param( + lambda x, y: (-13.5) - x + y, + ((-10, 2), (-4, 5)), + (-19.5, 1.5), + Float(64), + id="(-13.5) - x + y, (-10, 2), (-4, 5), (-19.5, 1.5)", ), pytest.param( lambda x, y: x * y, @@ -102,7 +131,14 @@ from hdk.hnumpy.tracing import trace_numpy_function ((-10, 2), (-4, 5)), (-150, 120), Integer(9, is_signed=True), - id="x * y, (-10, 2), (-4, 5), (-150, 120)", + id="(3 * x) * y, (-10, 2), (-4, 5), (-150, 120)", + ), + pytest.param( + lambda x, y: (3.0 * x) * y, + ((-10, 2), (-4, 5)), + (-150.0, 120.0), + Float(64), + id="(3.0 * x) * y, (-10, 2), (-4, 5), (-150.0, 120.0)", ), pytest.param( lambda x, y: (x * 11) * y, @@ -116,7 +152,14 @@ from hdk.hnumpy.tracing import trace_numpy_function ((-10, 2), (-4, 5)), (-440, 550), Integer(11, is_signed=True), - id="x * y, (-10, 2), (-4, 5), (-440, 550)", + id="(x * (-11)) * y, (-10, 2), (-4, 5), (-440, 550)", + ), + pytest.param( + lambda x, y: (x * (-11.0)) * y, + ((-10, 2), (-4, 5)), + (-440.0, 550.0), + Float(64), + id="(x * (-11.0)) * y, (-10, 2), (-4, 5), (-440.0, 550.0)", ), pytest.param( lambda x, y: x + x + y, @@ -187,12 +230,36 @@ def test_eval_op_graph_bounds_on_dataset( ((0, 2), (13, 14)), (Integer(2, is_signed=False), Integer(4, is_signed=False)), ), + pytest.param( + lambda x, y: (x + 1.5, y + 9.6), + ((-1, 1), (3, 4)), + ((0.5, 2.5), (12.6, 13.6)), + (Float(64), Float(64)), + ), pytest.param( lambda x, y: (x + y + 1, x * y + 42), ((-1, 1), (3, 4)), ((3, 6), (38, 46)), (Integer(3, is_signed=False), Integer(6, is_signed=False)), ), + pytest.param( + lambda x, y: (x + y + 0.4, x * y + 41.7), + ((-1, 1), (3, 4)), + ((2.4, 5.4), (37.7, 45.7)), + (Float(64), Float(64)), + ), + pytest.param( + lambda x, y: (x + y + 1, x * y + 41.7), + ((-1, 1), (3, 4)), + ((3, 6), (37.7, 45.7)), + (Integer(3, is_signed=False), Float(64)), + ), + pytest.param( + lambda x, y: (x + y + 0.4, x * y + 42), + ((-1, 1), (3, 4)), + ((2.4, 5.4), (38, 46)), + (Float(64), Integer(6, is_signed=False)), + ), ], ) def test_eval_op_graph_bounds_on_dataset_multiple_output( @@ -218,11 +285,8 @@ def test_eval_op_graph_bounds_on_dataset_multiple_output( for i, output_node in op_graph.output_nodes.items(): output_node_bounds = node_bounds[output_node] - assert (output_node_bounds["min"], output_node_bounds["max"]) == expected_output_bounds[i] - assert EncryptedValue(Integer(64, True)) == output_node.outputs[0] - op_graph.update_values_with_bounds(node_bounds) for i, output_node in op_graph.output_nodes.items():