refactor: refactor make_integer_to_hold_ints

- rename make_integer_to_hold_ints to make_integer_to_hold
- accept any values as input as we don't know which type this function will
be called with
- rename get_bits_to_represent_int to
get_bits_to_represent_value_as_integer
This commit is contained in:
Arthur Meyre
2021-08-19 15:06:09 +02:00
parent 0ff3ae4795
commit 60daf31981
5 changed files with 23 additions and 19 deletions

View File

@@ -1,7 +1,7 @@
"""This file holds the definitions for integer types."""
import math
from typing import Iterable
from typing import Any, Iterable
from . import base
@@ -84,36 +84,35 @@ def create_unsigned_integer(bit_width: int) -> Integer:
UnsignedInteger = create_unsigned_integer
def make_integer_to_hold_ints(values: Iterable[int], force_signed: bool) -> Integer:
def make_integer_to_hold(values: Iterable[Any], force_signed: bool) -> Integer:
"""Returns an Integer able to hold all values, it is possible to force the Integer to be signed.
Args:
values (Iterable[int]): The values to hold
values (Iterable[Any]): The values to hold
force_signed (bool): Set to True to force the result to be a signed Integer
Returns:
Integer: The Integer able to hold values
"""
assert all(isinstance(x, int) for x in values)
min_value = min(values)
max_value = max(values)
make_signed_integer = force_signed or min_value < 0
num_bits = max(
get_bits_to_represent_int(min_value, make_signed_integer),
get_bits_to_represent_int(max_value, make_signed_integer),
get_bits_to_represent_value_as_integer(min_value, make_signed_integer),
get_bits_to_represent_value_as_integer(max_value, make_signed_integer),
)
return Integer(num_bits, is_signed=make_signed_integer)
def get_bits_to_represent_int(value: int, force_signed: bool) -> int:
"""Returns how many bits are required to represent a single int.
def get_bits_to_represent_value_as_integer(value: Any, force_signed: bool) -> int:
"""Returns how many bits are required to represent a numerical Value.
Args:
value (int): The int for which we want to know how many bits are required
force_signed (bool): Set to True to force the result to be a signed Integer
value (Any): The value for which we want to know how many bits are required.
force_signed (bool): Set to True to force the result to be a signed integer.
Returns:
int: required amount of bits

View File

@@ -5,7 +5,7 @@ from typing import Iterable, Tuple, Union
from ..common_helpers import is_a_power_of_2
from ..data_types.base import BaseDataType
from ..data_types.integers import make_integer_to_hold_ints
from ..data_types.integers import make_integer_to_hold
from ..representation import intermediate as ir
from ..tracing.base_tracer import BaseTracer
@@ -28,7 +28,7 @@ class LookupTable:
)
self.table = table
self.output_dtype = make_integer_to_hold_ints(table, force_signed=False)
self.output_dtype = make_integer_to_hold(table, force_signed=False)
def __getitem__(self, key: Union[int, BaseTracer]):
# if a tracer is used for indexing,

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict, Iterable, List, Set, Tuple, Union
import networkx as nx
from .data_types.floats import Float
from .data_types.integers import make_integer_to_hold_ints
from .data_types.integers import make_integer_to_hold
from .representation import intermediate as ir
from .tracing import BaseTracer
from .tracing.tracing_helpers import create_graph_from_output_tracers
@@ -152,7 +152,7 @@ class OPGraph:
if not isinstance(node, ir.Input):
for output_value in node.outputs:
if isinstance(min_bound, int) and isinstance(max_bound, int):
output_value.data_type = make_integer_to_hold_ints(
output_value.data_type = make_integer_to_hold(
(min_bound, max_bound), force_signed=False
)
else:
@@ -163,7 +163,7 @@ class OPGraph:
f"Inputs to a graph should be integers, got bounds that were not float, \n"
f"min: {min_bound} ({type(min_bound)}), max: {max_bound} ({type(max_bound)})"
)
node.inputs[0].data_type = make_integer_to_hold_ints(
node.inputs[0].data_type = make_integer_to_hold(
(min_bound, max_bound), force_signed=False
)
node.outputs[0] = deepcopy(node.inputs[0])

View File

@@ -8,7 +8,7 @@ from ..data_types import BaseValue
from ..data_types.base import BaseDataType
from ..data_types.dtypes_helpers import mix_scalar_values_determine_holding_dtype
from ..data_types.floats import Float
from ..data_types.integers import Integer, get_bits_to_represent_int
from ..data_types.integers import Integer, get_bits_to_represent_value_as_integer
from ..data_types.scalars import Scalars
from ..data_types.values import ClearValue, EncryptedValue
@@ -162,7 +162,12 @@ class ConstantInput(IntermediateNode):
if isinstance(constant_data, int):
is_signed = constant_data < 0
self.outputs = [
ClearValue(Integer(get_bits_to_represent_int(constant_data, is_signed), is_signed))
ClearValue(
Integer(
get_bits_to_represent_value_as_integer(constant_data, is_signed),
is_signed,
)
)
]
elif isinstance(constant_data, float):
self.outputs = [ClearValue(Float(64))]

View File

@@ -8,7 +8,7 @@ from hdk.common.data_types.integers import (
Integer,
SignedInteger,
UnsignedInteger,
make_integer_to_hold_ints,
make_integer_to_hold,
)
@@ -109,4 +109,4 @@ def test_integers_repr(integer: Integer, expected_repr_str: str):
)
def test_make_integer_to_hold(values, force_signed, expected_result):
"""Test make_integer_to_hold"""
assert expected_result == make_integer_to_hold_ints(values, force_signed)
assert expected_result == make_integer_to_hold(values, force_signed)