feat: implement dtypes module

This commit is contained in:
Umut
2022-04-04 13:27:09 +02:00
parent 9bf9b3c743
commit 904b283df7
5 changed files with 280 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
"""
Declaration of `concrete.numpy.dtypes` namespace.
"""
from .base import BaseDataType
from .float import Float
from .integer import Integer, SignedInteger, UnsignedInteger

View File

@@ -0,0 +1,17 @@
"""
Declaration of `BaseDataType` abstract class.
"""
from abc import ABC, abstractmethod
class BaseDataType(ABC):
"""BaseDataType abstract class, to form a basis for data types."""
@abstractmethod
def __eq__(self, other: object) -> bool:
pass # pragma: no cover
@abstractmethod
def __str__(self) -> str:
pass # pragma: no cover

View File

@@ -0,0 +1,30 @@
"""
Declaration of `Float` class.
"""
from .base import BaseDataType
class Float(BaseDataType):
"""
Float class, to represent floating point numbers.
"""
bit_width: int
def __init__(self, bit_width: int):
super().__init__()
if bit_width not in [16, 32, 64]:
raise ValueError(
f"Float({repr(bit_width)}) is not supported "
f"(bit width must be one of 16, 32 or 64)"
)
self.bit_width = bit_width
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.bit_width == other.bit_width
def __str__(self) -> str:
return f"float{self.bit_width}"

View File

@@ -0,0 +1,152 @@
"""
Declaration of `Integer` class.
"""
from functools import partial
from typing import Any
import numpy as np
from .base import BaseDataType
class Integer(BaseDataType):
"""
Integer class, to represent integers.
"""
is_signed: bool
bit_width: int
@staticmethod
def that_can_represent(value: Any, force_signed: bool = False) -> "Integer":
"""
Get the minimal `Integer` that can represent `value`.
Args:
value (Any):
value that needs to be represented
force_signed (bool, default = False):
whether to force signed integers or not
Returns:
Integer:
minimal `Integer` that can represent `value`
Raises:
ValueError:
if `value` cannot be represented by `Integer`
"""
lower_bound: int
upper_bound: int
if isinstance(value, list):
try:
value = np.array(value)
except Exception: # pylint: disable=broad-except
# here we try our best to convert the list to np.ndarray
# if it fails we raise the exception at the else branch below
pass
if isinstance(value, (int, np.integer)):
lower_bound = int(value)
upper_bound = int(value)
elif isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.integer):
lower_bound = int(value.min())
upper_bound = int(value.max())
else:
raise ValueError(f"Integer cannot represent {repr(value)}")
def bits_to_represent_int(value: int, force_signed: bool) -> int:
bits: int
if value == 0:
return 1
if value < 0:
bits = int(np.ceil(np.log2(abs(value)))) + 1
else:
bits = int(np.ceil(np.log2(value + 1)))
if force_signed:
bits += 1
return bits
is_signed = force_signed or lower_bound < 0
bit_width = (
bits_to_represent_int(lower_bound, is_signed)
if lower_bound == upper_bound
else max(
bits_to_represent_int(lower_bound, is_signed),
bits_to_represent_int(upper_bound, is_signed),
)
)
return Integer(is_signed, bit_width)
def __init__(self, is_signed: bool, bit_width: int):
super().__init__()
if not isinstance(bit_width, int) or bit_width <= 0:
integer_str = "SignedInteger" if is_signed else "UnsignedInteger"
raise ValueError(
f"{integer_str}({repr(bit_width)}) is not supported "
f"(bit width must be a positive integer)"
)
self.is_signed = is_signed
self.bit_width = bit_width
def __eq__(self, other: Any) -> bool:
return (
isinstance(other, self.__class__)
and self.is_signed == other.is_signed
and self.bit_width == other.bit_width
)
def __str__(self) -> str:
return f"{('int' if self.is_signed else 'uint')}{self.bit_width}"
def min(self) -> int:
"""
Get the minumum value that can be represented by the `Integer`.
Returns:
int:
minumum value that can be represented by the `Integer`
"""
return 0 if not self.is_signed else -(2 ** (self.bit_width - 1))
def max(self) -> int:
"""
Get the maximum value that can be represented by the `Integer`.
Returns:
int:
maximum value that can be represented by the `Integer`
"""
return (2 ** self.bit_width) - 1 if not self.is_signed else (2 ** (self.bit_width - 1)) - 1
def can_represent(self, value: int) -> bool:
"""
Get whether `value` can be represented by the `Integer` or not.
Args:
value (int):
value to check representability
Returns:
bool:
True if `value` is representable by the `integer`, False otherwise
"""
return self.min() <= value <= self.max()
SignedInteger = partial(Integer, True)
UnsignedInteger = partial(Integer, False)

View File

@@ -0,0 +1,74 @@
"""
Declaration of various functions and constants related to data types.
"""
from typing import List
from ..internal.utils import assert_that
from .base import BaseDataType
from .float import Float
from .integer import Integer, SignedInteger, UnsignedInteger
def combine_dtypes(dtypes: List[BaseDataType]) -> BaseDataType:
"""
Get the 'BaseDataType' that can represent a set of 'BaseDataType's.
Args:
dtypes (List[BaseDataType]):
dtypes to combine
Returns:
BaseDataType:
dtype that can hold all the given dtypes (potentially lossy)
"""
assert_that(len(dtypes) != 0)
assert_that(all(isinstance(dtype, (Integer, Float)) for dtype in dtypes))
def combine_2_dtypes(dtype1: BaseDataType, dtype2: BaseDataType) -> BaseDataType:
result: BaseDataType = dtype1
if isinstance(dtype1, Integer) and isinstance(dtype2, Integer):
max_bits = max(dtype1.bit_width, dtype2.bit_width)
if dtype1.is_signed and dtype2.is_signed:
result = SignedInteger(max_bits)
elif not dtype1.is_signed and not dtype2.is_signed:
result = UnsignedInteger(max_bits)
elif dtype1.is_signed and not dtype2.is_signed:
# if dtype2 has the bigger bit_width,
# we need a signed integer that can hold
# it, so add 1 bit of sign to its bit_width
if dtype2.bit_width >= dtype1.bit_width:
new_bit_width = dtype2.bit_width + 1
result = SignedInteger(new_bit_width)
else:
result = SignedInteger(dtype1.bit_width)
elif not dtype1.is_signed and dtype2.is_signed:
# Same as above, with dtype1 and dtype2 switched around
if dtype1.bit_width >= dtype2.bit_width:
new_bit_width = dtype1.bit_width + 1
result = SignedInteger(new_bit_width)
else:
result = SignedInteger(dtype2.bit_width)
elif isinstance(dtype1, Float) and isinstance(dtype2, Float):
max_bits = max(dtype1.bit_width, dtype2.bit_width)
result = Float(max_bits)
elif isinstance(dtype1, Float):
result = dtype1
elif isinstance(dtype2, Float):
result = dtype2
return result
result = dtypes[0]
for other in dtypes[1:]:
result = combine_2_dtypes(result, other)
return result