diff --git a/hdk/common/data_types/dtypes_helpers.py b/hdk/common/data_types/dtypes_helpers.py index 731eb54d1..6521ba12e 100644 --- a/hdk/common/data_types/dtypes_helpers.py +++ b/hdk/common/data_types/dtypes_helpers.py @@ -1,16 +1,16 @@ """File to hold helper functions for data types related stuff.""" from copy import deepcopy -from typing import cast +from typing import Union, cast from .base import BaseDataType from .floats import Float -from .integers import Integer +from .integers import Integer, get_bits_to_represent_value_as_integer from .values import BaseValue, ClearValue, EncryptedValue, ScalarValue INTEGER_TYPES = (Integer,) FLOAT_TYPES = (Float,) -SUPPORTED_TYPES = INTEGER_TYPES + FLOAT_TYPES +BASE_DATA_TYPES = INTEGER_TYPES + FLOAT_TYPES def value_is_encrypted_integer(value_to_check: BaseValue) -> bool: @@ -93,8 +93,8 @@ def find_type_to_hold_both_lossy( Returns: BaseDataType: The dtype able to hold (potentially lossy) dtype1 and dtype2 """ - assert isinstance(dtype1, SUPPORTED_TYPES), f"Unsupported dtype1: {type(dtype1)}" - assert isinstance(dtype2, SUPPORTED_TYPES), f"Unsupported dtype2: {type(dtype2)}" + assert isinstance(dtype1, BASE_DATA_TYPES), f"Unsupported dtype1: {type(dtype1)}" + assert isinstance(dtype2, BASE_DATA_TYPES), f"Unsupported dtype2: {type(dtype2)}" type_to_return: BaseDataType @@ -161,3 +161,27 @@ def mix_scalar_values_determine_holding_dtype(value1: BaseValue, value2: BaseVal mixed_value = ClearValue(holding_type) return mixed_value + + +def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType: + """Helper function to determine the BaseDataType to hold the input constant data. + + Args: + constant_data (Union[int, float]): The constant data for which to determine the + corresponding BaseDataType. + + Returns: + BaseDataType: The corresponding BaseDataType + """ + constant_data_type: BaseDataType + assert isinstance( + constant_data, (int, float) + ), f"Unsupported constant data of type {type(constant_data)}" + if isinstance(constant_data, int): + is_signed = constant_data < 0 + constant_data_type = Integer( + get_bits_to_represent_value_as_integer(constant_data, is_signed), is_signed + ) + elif isinstance(constant_data, float): + constant_data_type = Float(64) + return constant_data_type diff --git a/hdk/common/operator_graph.py b/hdk/common/operator_graph.py index e0015f5e8..d42a15419 100644 --- a/hdk/common/operator_graph.py +++ b/hdk/common/operator_graph.py @@ -1,12 +1,14 @@ """Code to wrap and make manipulating networkx graphs easier.""" from copy import deepcopy -from typing import Any, Dict, Iterable, List, Set, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Union import networkx as nx +from .data_types.base import BaseDataType +from .data_types.dtypes_helpers import get_base_data_type_for_python_constant_data from .data_types.floats import Float -from .data_types.integers import make_integer_to_hold +from .data_types.integers import Integer, make_integer_to_hold from .representation import intermediate as ir from .tracing import BaseTracer from .tracing.tracing_helpers import create_graph_from_output_tracers @@ -130,7 +132,13 @@ class OPGraph: return node_results - def update_values_with_bounds(self, node_bounds: dict): + def update_values_with_bounds( + self, + node_bounds: dict, + get_base_data_type_for_constant_data: Callable[ + [Any], BaseDataType + ] = get_base_data_type_for_python_constant_data, + ): """Update values with bounds. Update nodes inputs and outputs values with data types able to hold data ranges measured @@ -139,6 +147,10 @@ class OPGraph: Args: node_bounds (dict): Dictionary with nodes as keys, holding dicts with a 'min' and 'max' keys. Those bounds will be taken as the data range to be represented, per node. + get_base_data_type_for_constant_data (Callable[ [Type], BaseDataType ], optional): This + is a callback function to convert data encountered during value updates to + BaseDataType. This allows to manage data coming from foreign frameworks without + specialising OPGraph. Defaults to get_base_data_type_for_python_constant_data. """ node: ir.IntermediateNode @@ -149,9 +161,12 @@ class OPGraph: current_node_bounds["max"], ) + min_data_type = get_base_data_type_for_constant_data(min_bound) + max_data_type = get_base_data_type_for_constant_data(max_bound) + if not isinstance(node, ir.Input): for output_value in node.outputs: - if isinstance(min_bound, int) and isinstance(max_bound, int): + if isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer): output_value.data_type = make_integer_to_hold( (min_bound, max_bound), force_signed=False ) @@ -159,8 +174,8 @@ class OPGraph: output_value.data_type = Float(64) else: # Currently variable inputs are only allowed to be integers - assert isinstance(min_bound, int) and isinstance(max_bound, int), ( - f"Inputs to a graph should be integers, got bounds that were not float, \n" + assert isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer), ( + f"Inputs to a graph should be integers, got bounds that were float, \n" f"min: {min_bound} ({type(min_bound)}), max: {max_bound} ({type(max_bound)})" ) node.inputs[0].data_type = make_integer_to_hold( diff --git a/hdk/hnumpy/np_dtypes_helpers.py b/hdk/hnumpy/np_dtypes_helpers.py index 0830841d3..1de2d6c4b 100644 --- a/hdk/hnumpy/np_dtypes_helpers.py +++ b/hdk/hnumpy/np_dtypes_helpers.py @@ -7,7 +7,7 @@ import numpy from numpy.typing import DTypeLike from ..common.data_types.base import BaseDataType -from ..common.data_types.dtypes_helpers import SUPPORTED_TYPES +from ..common.data_types.dtypes_helpers import BASE_DATA_TYPES from ..common.data_types.floats import Float from ..common.data_types.integers import Integer @@ -62,7 +62,7 @@ def convert_common_dtype_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.dty numpy.dtype: The resulting numpy.dtype """ assert isinstance( - common_dtype, SUPPORTED_TYPES + common_dtype, BASE_DATA_TYPES ), f"Unsupported common_dtype: {type(common_dtype)}" type_to_return: numpy.dtype