diff --git a/concrete/numpy/dtypes/__init__.py b/concrete/numpy/dtypes/__init__.py new file mode 100644 index 000000000..7a019414d --- /dev/null +++ b/concrete/numpy/dtypes/__init__.py @@ -0,0 +1,7 @@ +""" +Declaration of `concrete.numpy.dtypes` namespace. +""" + +from .base import BaseDataType +from .float import Float +from .integer import Integer, SignedInteger, UnsignedInteger diff --git a/concrete/numpy/dtypes/base.py b/concrete/numpy/dtypes/base.py new file mode 100644 index 000000000..50521d140 --- /dev/null +++ b/concrete/numpy/dtypes/base.py @@ -0,0 +1,17 @@ +""" +Declaration of `BaseDataType` abstract class. +""" + +from abc import ABC, abstractmethod + + +class BaseDataType(ABC): + """BaseDataType abstract class, to form a basis for data types.""" + + @abstractmethod + def __eq__(self, other: object) -> bool: + pass # pragma: no cover + + @abstractmethod + def __str__(self) -> str: + pass # pragma: no cover diff --git a/concrete/numpy/dtypes/float.py b/concrete/numpy/dtypes/float.py new file mode 100644 index 000000000..f3e35b15b --- /dev/null +++ b/concrete/numpy/dtypes/float.py @@ -0,0 +1,30 @@ +""" +Declaration of `Float` class. +""" + +from .base import BaseDataType + + +class Float(BaseDataType): + """ + Float class, to represent floating point numbers. + """ + + bit_width: int + + def __init__(self, bit_width: int): + super().__init__() + + if bit_width not in [16, 32, 64]: + raise ValueError( + f"Float({repr(bit_width)}) is not supported " + f"(bit width must be one of 16, 32 or 64)" + ) + + self.bit_width = bit_width + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) and self.bit_width == other.bit_width + + def __str__(self) -> str: + return f"float{self.bit_width}" diff --git a/concrete/numpy/dtypes/integer.py b/concrete/numpy/dtypes/integer.py new file mode 100644 index 000000000..732d86101 --- /dev/null +++ b/concrete/numpy/dtypes/integer.py @@ -0,0 +1,152 @@ +""" +Declaration of `Integer` class. +""" + +from functools import partial +from typing import Any + +import numpy as np + +from .base import BaseDataType + + +class Integer(BaseDataType): + """ + Integer class, to represent integers. + """ + + is_signed: bool + bit_width: int + + @staticmethod + def that_can_represent(value: Any, force_signed: bool = False) -> "Integer": + """ + Get the minimal `Integer` that can represent `value`. + + Args: + value (Any): + value that needs to be represented + + force_signed (bool, default = False): + whether to force signed integers or not + + Returns: + Integer: + minimal `Integer` that can represent `value` + + Raises: + ValueError: + if `value` cannot be represented by `Integer` + """ + + lower_bound: int + upper_bound: int + + if isinstance(value, list): + try: + value = np.array(value) + except Exception: # pylint: disable=broad-except + # here we try our best to convert the list to np.ndarray + # if it fails we raise the exception at the else branch below + pass + + if isinstance(value, (int, np.integer)): + lower_bound = int(value) + upper_bound = int(value) + elif isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.integer): + lower_bound = int(value.min()) + upper_bound = int(value.max()) + else: + raise ValueError(f"Integer cannot represent {repr(value)}") + + def bits_to_represent_int(value: int, force_signed: bool) -> int: + bits: int + + if value == 0: + return 1 + + if value < 0: + bits = int(np.ceil(np.log2(abs(value)))) + 1 + else: + bits = int(np.ceil(np.log2(value + 1))) + if force_signed: + bits += 1 + + return bits + + is_signed = force_signed or lower_bound < 0 + bit_width = ( + bits_to_represent_int(lower_bound, is_signed) + if lower_bound == upper_bound + else max( + bits_to_represent_int(lower_bound, is_signed), + bits_to_represent_int(upper_bound, is_signed), + ) + ) + + return Integer(is_signed, bit_width) + + def __init__(self, is_signed: bool, bit_width: int): + super().__init__() + + if not isinstance(bit_width, int) or bit_width <= 0: + integer_str = "SignedInteger" if is_signed else "UnsignedInteger" + raise ValueError( + f"{integer_str}({repr(bit_width)}) is not supported " + f"(bit width must be a positive integer)" + ) + + self.is_signed = is_signed + self.bit_width = bit_width + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, self.__class__) + and self.is_signed == other.is_signed + and self.bit_width == other.bit_width + ) + + def __str__(self) -> str: + return f"{('int' if self.is_signed else 'uint')}{self.bit_width}" + + def min(self) -> int: + """ + Get the minumum value that can be represented by the `Integer`. + + Returns: + int: + minumum value that can be represented by the `Integer` + """ + + return 0 if not self.is_signed else -(2 ** (self.bit_width - 1)) + + def max(self) -> int: + """ + Get the maximum value that can be represented by the `Integer`. + + Returns: + int: + maximum value that can be represented by the `Integer` + """ + + return (2 ** self.bit_width) - 1 if not self.is_signed else (2 ** (self.bit_width - 1)) - 1 + + def can_represent(self, value: int) -> bool: + """ + Get whether `value` can be represented by the `Integer` or not. + + Args: + value (int): + value to check representability + + Returns: + bool: + True if `value` is representable by the `integer`, False otherwise + """ + + return self.min() <= value <= self.max() + + +SignedInteger = partial(Integer, True) + +UnsignedInteger = partial(Integer, False) diff --git a/concrete/numpy/dtypes/utils.py b/concrete/numpy/dtypes/utils.py new file mode 100644 index 000000000..06f584065 --- /dev/null +++ b/concrete/numpy/dtypes/utils.py @@ -0,0 +1,74 @@ +""" +Declaration of various functions and constants related to data types. +""" + +from typing import List + +from ..internal.utils import assert_that +from .base import BaseDataType +from .float import Float +from .integer import Integer, SignedInteger, UnsignedInteger + + +def combine_dtypes(dtypes: List[BaseDataType]) -> BaseDataType: + """ + Get the 'BaseDataType' that can represent a set of 'BaseDataType's. + + Args: + dtypes (List[BaseDataType]): + dtypes to combine + + Returns: + BaseDataType: + dtype that can hold all the given dtypes (potentially lossy) + """ + + assert_that(len(dtypes) != 0) + assert_that(all(isinstance(dtype, (Integer, Float)) for dtype in dtypes)) + + def combine_2_dtypes(dtype1: BaseDataType, dtype2: BaseDataType) -> BaseDataType: + result: BaseDataType = dtype1 + + if isinstance(dtype1, Integer) and isinstance(dtype2, Integer): + max_bits = max(dtype1.bit_width, dtype2.bit_width) + + if dtype1.is_signed and dtype2.is_signed: + result = SignedInteger(max_bits) + + elif not dtype1.is_signed and not dtype2.is_signed: + result = UnsignedInteger(max_bits) + + elif dtype1.is_signed and not dtype2.is_signed: + # if dtype2 has the bigger bit_width, + # we need a signed integer that can hold + # it, so add 1 bit of sign to its bit_width + if dtype2.bit_width >= dtype1.bit_width: + new_bit_width = dtype2.bit_width + 1 + result = SignedInteger(new_bit_width) + else: + result = SignedInteger(dtype1.bit_width) + + elif not dtype1.is_signed and dtype2.is_signed: + # Same as above, with dtype1 and dtype2 switched around + if dtype1.bit_width >= dtype2.bit_width: + new_bit_width = dtype1.bit_width + 1 + result = SignedInteger(new_bit_width) + else: + result = SignedInteger(dtype2.bit_width) + + elif isinstance(dtype1, Float) and isinstance(dtype2, Float): + max_bits = max(dtype1.bit_width, dtype2.bit_width) + result = Float(max_bits) + + elif isinstance(dtype1, Float): + result = dtype1 + + elif isinstance(dtype2, Float): + result = dtype2 + + return result + + result = dtypes[0] + for other in dtypes[1:]: + result = combine_2_dtypes(result, other) + return result