mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor: refactor make_integer_to_hold_ints
- rename make_integer_to_hold_ints to make_integer_to_hold - accept any values as input as we don't know which type this function will be called with - rename get_bits_to_represent_int to get_bits_to_represent_value_as_integer
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""This file holds the definitions for integer types."""
|
||||
|
||||
import math
|
||||
from typing import Iterable
|
||||
from typing import Any, Iterable
|
||||
|
||||
from . import base
|
||||
|
||||
@@ -84,36 +84,35 @@ def create_unsigned_integer(bit_width: int) -> Integer:
|
||||
UnsignedInteger = create_unsigned_integer
|
||||
|
||||
|
||||
def make_integer_to_hold_ints(values: Iterable[int], force_signed: bool) -> Integer:
|
||||
def make_integer_to_hold(values: Iterable[Any], force_signed: bool) -> Integer:
|
||||
"""Returns an Integer able to hold all values, it is possible to force the Integer to be signed.
|
||||
|
||||
Args:
|
||||
values (Iterable[int]): The values to hold
|
||||
values (Iterable[Any]): The values to hold
|
||||
force_signed (bool): Set to True to force the result to be a signed Integer
|
||||
|
||||
Returns:
|
||||
Integer: The Integer able to hold values
|
||||
"""
|
||||
assert all(isinstance(x, int) for x in values)
|
||||
min_value = min(values)
|
||||
max_value = max(values)
|
||||
|
||||
make_signed_integer = force_signed or min_value < 0
|
||||
|
||||
num_bits = max(
|
||||
get_bits_to_represent_int(min_value, make_signed_integer),
|
||||
get_bits_to_represent_int(max_value, make_signed_integer),
|
||||
get_bits_to_represent_value_as_integer(min_value, make_signed_integer),
|
||||
get_bits_to_represent_value_as_integer(max_value, make_signed_integer),
|
||||
)
|
||||
|
||||
return Integer(num_bits, is_signed=make_signed_integer)
|
||||
|
||||
|
||||
def get_bits_to_represent_int(value: int, force_signed: bool) -> int:
|
||||
"""Returns how many bits are required to represent a single int.
|
||||
def get_bits_to_represent_value_as_integer(value: Any, force_signed: bool) -> int:
|
||||
"""Returns how many bits are required to represent a numerical Value.
|
||||
|
||||
Args:
|
||||
value (int): The int for which we want to know how many bits are required
|
||||
force_signed (bool): Set to True to force the result to be a signed Integer
|
||||
value (Any): The value for which we want to know how many bits are required.
|
||||
force_signed (bool): Set to True to force the result to be a signed integer.
|
||||
|
||||
Returns:
|
||||
int: required amount of bits
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Iterable, Tuple, Union
|
||||
|
||||
from ..common_helpers import is_a_power_of_2
|
||||
from ..data_types.base import BaseDataType
|
||||
from ..data_types.integers import make_integer_to_hold_ints
|
||||
from ..data_types.integers import make_integer_to_hold
|
||||
from ..representation import intermediate as ir
|
||||
from ..tracing.base_tracer import BaseTracer
|
||||
|
||||
@@ -28,7 +28,7 @@ class LookupTable:
|
||||
)
|
||||
|
||||
self.table = table
|
||||
self.output_dtype = make_integer_to_hold_ints(table, force_signed=False)
|
||||
self.output_dtype = make_integer_to_hold(table, force_signed=False)
|
||||
|
||||
def __getitem__(self, key: Union[int, BaseTracer]):
|
||||
# if a tracer is used for indexing,
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, Dict, Iterable, List, Set, Tuple, Union
|
||||
import networkx as nx
|
||||
|
||||
from .data_types.floats import Float
|
||||
from .data_types.integers import make_integer_to_hold_ints
|
||||
from .data_types.integers import make_integer_to_hold
|
||||
from .representation import intermediate as ir
|
||||
from .tracing import BaseTracer
|
||||
from .tracing.tracing_helpers import create_graph_from_output_tracers
|
||||
@@ -152,7 +152,7 @@ class OPGraph:
|
||||
if not isinstance(node, ir.Input):
|
||||
for output_value in node.outputs:
|
||||
if isinstance(min_bound, int) and isinstance(max_bound, int):
|
||||
output_value.data_type = make_integer_to_hold_ints(
|
||||
output_value.data_type = make_integer_to_hold(
|
||||
(min_bound, max_bound), force_signed=False
|
||||
)
|
||||
else:
|
||||
@@ -163,7 +163,7 @@ class OPGraph:
|
||||
f"Inputs to a graph should be integers, got bounds that were not float, \n"
|
||||
f"min: {min_bound} ({type(min_bound)}), max: {max_bound} ({type(max_bound)})"
|
||||
)
|
||||
node.inputs[0].data_type = make_integer_to_hold_ints(
|
||||
node.inputs[0].data_type = make_integer_to_hold(
|
||||
(min_bound, max_bound), force_signed=False
|
||||
)
|
||||
node.outputs[0] = deepcopy(node.inputs[0])
|
||||
|
||||
@@ -8,7 +8,7 @@ from ..data_types import BaseValue
|
||||
from ..data_types.base import BaseDataType
|
||||
from ..data_types.dtypes_helpers import mix_scalar_values_determine_holding_dtype
|
||||
from ..data_types.floats import Float
|
||||
from ..data_types.integers import Integer, get_bits_to_represent_int
|
||||
from ..data_types.integers import Integer, get_bits_to_represent_value_as_integer
|
||||
from ..data_types.scalars import Scalars
|
||||
from ..data_types.values import ClearValue, EncryptedValue
|
||||
|
||||
@@ -162,7 +162,12 @@ class ConstantInput(IntermediateNode):
|
||||
if isinstance(constant_data, int):
|
||||
is_signed = constant_data < 0
|
||||
self.outputs = [
|
||||
ClearValue(Integer(get_bits_to_represent_int(constant_data, is_signed), is_signed))
|
||||
ClearValue(
|
||||
Integer(
|
||||
get_bits_to_represent_value_as_integer(constant_data, is_signed),
|
||||
is_signed,
|
||||
)
|
||||
)
|
||||
]
|
||||
elif isinstance(constant_data, float):
|
||||
self.outputs = [ClearValue(Float(64))]
|
||||
|
||||
@@ -8,7 +8,7 @@ from hdk.common.data_types.integers import (
|
||||
Integer,
|
||||
SignedInteger,
|
||||
UnsignedInteger,
|
||||
make_integer_to_hold_ints,
|
||||
make_integer_to_hold,
|
||||
)
|
||||
|
||||
|
||||
@@ -109,4 +109,4 @@ def test_integers_repr(integer: Integer, expected_repr_str: str):
|
||||
)
|
||||
def test_make_integer_to_hold(values, force_signed, expected_result):
|
||||
"""Test make_integer_to_hold"""
|
||||
assert expected_result == make_integer_to_hold_ints(values, force_signed)
|
||||
assert expected_result == make_integer_to_hold(values, force_signed)
|
||||
|
||||
Reference in New Issue
Block a user