diff --git a/hdk/common/data_types/integers.py b/hdk/common/data_types/integers.py index dab19380a..7e5b0d79b 100644 --- a/hdk/common/data_types/integers.py +++ b/hdk/common/data_types/integers.py @@ -1,7 +1,7 @@ """This file holds the definitions for integer types.""" import math -from typing import Iterable +from typing import Any, Iterable from . import base @@ -84,36 +84,35 @@ def create_unsigned_integer(bit_width: int) -> Integer: UnsignedInteger = create_unsigned_integer -def make_integer_to_hold_ints(values: Iterable[int], force_signed: bool) -> Integer: +def make_integer_to_hold(values: Iterable[Any], force_signed: bool) -> Integer: """Returns an Integer able to hold all values, it is possible to force the Integer to be signed. Args: - values (Iterable[int]): The values to hold + values (Iterable[Any]): The values to hold force_signed (bool): Set to True to force the result to be a signed Integer Returns: Integer: The Integer able to hold values """ - assert all(isinstance(x, int) for x in values) min_value = min(values) max_value = max(values) make_signed_integer = force_signed or min_value < 0 num_bits = max( - get_bits_to_represent_int(min_value, make_signed_integer), - get_bits_to_represent_int(max_value, make_signed_integer), + get_bits_to_represent_value_as_integer(min_value, make_signed_integer), + get_bits_to_represent_value_as_integer(max_value, make_signed_integer), ) return Integer(num_bits, is_signed=make_signed_integer) -def get_bits_to_represent_int(value: int, force_signed: bool) -> int: - """Returns how many bits are required to represent a single int. +def get_bits_to_represent_value_as_integer(value: Any, force_signed: bool) -> int: + """Returns how many bits are required to represent a numerical Value. Args: - value (int): The int for which we want to know how many bits are required - force_signed (bool): Set to True to force the result to be a signed Integer + value (Any): The value for which we want to know how many bits are required. + force_signed (bool): Set to True to force the result to be a signed integer. Returns: int: required amount of bits diff --git a/hdk/common/extensions/table.py b/hdk/common/extensions/table.py index 74845799d..5326a8f6c 100644 --- a/hdk/common/extensions/table.py +++ b/hdk/common/extensions/table.py @@ -5,7 +5,7 @@ from typing import Iterable, Tuple, Union from ..common_helpers import is_a_power_of_2 from ..data_types.base import BaseDataType -from ..data_types.integers import make_integer_to_hold_ints +from ..data_types.integers import make_integer_to_hold from ..representation import intermediate as ir from ..tracing.base_tracer import BaseTracer @@ -28,7 +28,7 @@ class LookupTable: ) self.table = table - self.output_dtype = make_integer_to_hold_ints(table, force_signed=False) + self.output_dtype = make_integer_to_hold(table, force_signed=False) def __getitem__(self, key: Union[int, BaseTracer]): # if a tracer is used for indexing, diff --git a/hdk/common/operator_graph.py b/hdk/common/operator_graph.py index eecc4ac85..e0015f5e8 100644 --- a/hdk/common/operator_graph.py +++ b/hdk/common/operator_graph.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Iterable, List, Set, Tuple, Union import networkx as nx from .data_types.floats import Float -from .data_types.integers import make_integer_to_hold_ints +from .data_types.integers import make_integer_to_hold from .representation import intermediate as ir from .tracing import BaseTracer from .tracing.tracing_helpers import create_graph_from_output_tracers @@ -152,7 +152,7 @@ class OPGraph: if not isinstance(node, ir.Input): for output_value in node.outputs: if isinstance(min_bound, int) and isinstance(max_bound, int): - output_value.data_type = make_integer_to_hold_ints( + output_value.data_type = make_integer_to_hold( (min_bound, max_bound), force_signed=False ) else: @@ -163,7 +163,7 @@ class OPGraph: f"Inputs to a graph should be integers, got bounds that were not float, \n" f"min: {min_bound} ({type(min_bound)}), max: {max_bound} ({type(max_bound)})" ) - node.inputs[0].data_type = make_integer_to_hold_ints( + node.inputs[0].data_type = make_integer_to_hold( (min_bound, max_bound), force_signed=False ) node.outputs[0] = deepcopy(node.inputs[0]) diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index 13903685a..fecc54d58 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -8,7 +8,7 @@ from ..data_types import BaseValue from ..data_types.base import BaseDataType from ..data_types.dtypes_helpers import mix_scalar_values_determine_holding_dtype from ..data_types.floats import Float -from ..data_types.integers import Integer, get_bits_to_represent_int +from ..data_types.integers import Integer, get_bits_to_represent_value_as_integer from ..data_types.scalars import Scalars from ..data_types.values import ClearValue, EncryptedValue @@ -162,7 +162,12 @@ class ConstantInput(IntermediateNode): if isinstance(constant_data, int): is_signed = constant_data < 0 self.outputs = [ - ClearValue(Integer(get_bits_to_represent_int(constant_data, is_signed), is_signed)) + ClearValue( + Integer( + get_bits_to_represent_value_as_integer(constant_data, is_signed), + is_signed, + ) + ) ] elif isinstance(constant_data, float): self.outputs = [ClearValue(Float(64))] diff --git a/tests/common/data_types/test_integers.py b/tests/common/data_types/test_integers.py index c4ed8f1fc..7d1f70a2b 100644 --- a/tests/common/data_types/test_integers.py +++ b/tests/common/data_types/test_integers.py @@ -8,7 +8,7 @@ from hdk.common.data_types.integers import ( Integer, SignedInteger, UnsignedInteger, - make_integer_to_hold_ints, + make_integer_to_hold, ) @@ -109,4 +109,4 @@ def test_integers_repr(integer: Integer, expected_repr_str: str): ) def test_make_integer_to_hold(values, force_signed, expected_result): """Test make_integer_to_hold""" - assert expected_result == make_integer_to_hold_ints(values, force_signed) + assert expected_result == make_integer_to_hold(values, force_signed)