mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: implement dtypes module
This commit is contained in:
7
concrete/numpy/dtypes/__init__.py
Normal file
7
concrete/numpy/dtypes/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Declaration of `concrete.numpy.dtypes` namespace.
|
||||
"""
|
||||
|
||||
from .base import BaseDataType
|
||||
from .float import Float
|
||||
from .integer import Integer, SignedInteger, UnsignedInteger
|
||||
17
concrete/numpy/dtypes/base.py
Normal file
17
concrete/numpy/dtypes/base.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Declaration of `BaseDataType` abstract class.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseDataType(ABC):
|
||||
"""BaseDataType abstract class, to form a basis for data types."""
|
||||
|
||||
@abstractmethod
|
||||
def __eq__(self, other: object) -> bool:
|
||||
pass # pragma: no cover
|
||||
|
||||
@abstractmethod
|
||||
def __str__(self) -> str:
|
||||
pass # pragma: no cover
|
||||
30
concrete/numpy/dtypes/float.py
Normal file
30
concrete/numpy/dtypes/float.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
Declaration of `Float` class.
|
||||
"""
|
||||
|
||||
from .base import BaseDataType
|
||||
|
||||
|
||||
class Float(BaseDataType):
|
||||
"""
|
||||
Float class, to represent floating point numbers.
|
||||
"""
|
||||
|
||||
bit_width: int
|
||||
|
||||
def __init__(self, bit_width: int):
|
||||
super().__init__()
|
||||
|
||||
if bit_width not in [16, 32, 64]:
|
||||
raise ValueError(
|
||||
f"Float({repr(bit_width)}) is not supported "
|
||||
f"(bit width must be one of 16, 32 or 64)"
|
||||
)
|
||||
|
||||
self.bit_width = bit_width
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, self.__class__) and self.bit_width == other.bit_width
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"float{self.bit_width}"
|
||||
152
concrete/numpy/dtypes/integer.py
Normal file
152
concrete/numpy/dtypes/integer.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Declaration of `Integer` class.
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseDataType
|
||||
|
||||
|
||||
class Integer(BaseDataType):
|
||||
"""
|
||||
Integer class, to represent integers.
|
||||
"""
|
||||
|
||||
is_signed: bool
|
||||
bit_width: int
|
||||
|
||||
@staticmethod
|
||||
def that_can_represent(value: Any, force_signed: bool = False) -> "Integer":
|
||||
"""
|
||||
Get the minimal `Integer` that can represent `value`.
|
||||
|
||||
Args:
|
||||
value (Any):
|
||||
value that needs to be represented
|
||||
|
||||
force_signed (bool, default = False):
|
||||
whether to force signed integers or not
|
||||
|
||||
Returns:
|
||||
Integer:
|
||||
minimal `Integer` that can represent `value`
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
if `value` cannot be represented by `Integer`
|
||||
"""
|
||||
|
||||
lower_bound: int
|
||||
upper_bound: int
|
||||
|
||||
if isinstance(value, list):
|
||||
try:
|
||||
value = np.array(value)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# here we try our best to convert the list to np.ndarray
|
||||
# if it fails we raise the exception at the else branch below
|
||||
pass
|
||||
|
||||
if isinstance(value, (int, np.integer)):
|
||||
lower_bound = int(value)
|
||||
upper_bound = int(value)
|
||||
elif isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.integer):
|
||||
lower_bound = int(value.min())
|
||||
upper_bound = int(value.max())
|
||||
else:
|
||||
raise ValueError(f"Integer cannot represent {repr(value)}")
|
||||
|
||||
def bits_to_represent_int(value: int, force_signed: bool) -> int:
|
||||
bits: int
|
||||
|
||||
if value == 0:
|
||||
return 1
|
||||
|
||||
if value < 0:
|
||||
bits = int(np.ceil(np.log2(abs(value)))) + 1
|
||||
else:
|
||||
bits = int(np.ceil(np.log2(value + 1)))
|
||||
if force_signed:
|
||||
bits += 1
|
||||
|
||||
return bits
|
||||
|
||||
is_signed = force_signed or lower_bound < 0
|
||||
bit_width = (
|
||||
bits_to_represent_int(lower_bound, is_signed)
|
||||
if lower_bound == upper_bound
|
||||
else max(
|
||||
bits_to_represent_int(lower_bound, is_signed),
|
||||
bits_to_represent_int(upper_bound, is_signed),
|
||||
)
|
||||
)
|
||||
|
||||
return Integer(is_signed, bit_width)
|
||||
|
||||
def __init__(self, is_signed: bool, bit_width: int):
|
||||
super().__init__()
|
||||
|
||||
if not isinstance(bit_width, int) or bit_width <= 0:
|
||||
integer_str = "SignedInteger" if is_signed else "UnsignedInteger"
|
||||
raise ValueError(
|
||||
f"{integer_str}({repr(bit_width)}) is not supported "
|
||||
f"(bit width must be a positive integer)"
|
||||
)
|
||||
|
||||
self.is_signed = is_signed
|
||||
self.bit_width = bit_width
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self.is_signed == other.is_signed
|
||||
and self.bit_width == other.bit_width
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{('int' if self.is_signed else 'uint')}{self.bit_width}"
|
||||
|
||||
def min(self) -> int:
|
||||
"""
|
||||
Get the minumum value that can be represented by the `Integer`.
|
||||
|
||||
Returns:
|
||||
int:
|
||||
minumum value that can be represented by the `Integer`
|
||||
"""
|
||||
|
||||
return 0 if not self.is_signed else -(2 ** (self.bit_width - 1))
|
||||
|
||||
def max(self) -> int:
|
||||
"""
|
||||
Get the maximum value that can be represented by the `Integer`.
|
||||
|
||||
Returns:
|
||||
int:
|
||||
maximum value that can be represented by the `Integer`
|
||||
"""
|
||||
|
||||
return (2 ** self.bit_width) - 1 if not self.is_signed else (2 ** (self.bit_width - 1)) - 1
|
||||
|
||||
def can_represent(self, value: int) -> bool:
|
||||
"""
|
||||
Get whether `value` can be represented by the `Integer` or not.
|
||||
|
||||
Args:
|
||||
value (int):
|
||||
value to check representability
|
||||
|
||||
Returns:
|
||||
bool:
|
||||
True if `value` is representable by the `integer`, False otherwise
|
||||
"""
|
||||
|
||||
return self.min() <= value <= self.max()
|
||||
|
||||
|
||||
SignedInteger = partial(Integer, True)
|
||||
|
||||
UnsignedInteger = partial(Integer, False)
|
||||
74
concrete/numpy/dtypes/utils.py
Normal file
74
concrete/numpy/dtypes/utils.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
Declaration of various functions and constants related to data types.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from ..internal.utils import assert_that
|
||||
from .base import BaseDataType
|
||||
from .float import Float
|
||||
from .integer import Integer, SignedInteger, UnsignedInteger
|
||||
|
||||
|
||||
def combine_dtypes(dtypes: List[BaseDataType]) -> BaseDataType:
|
||||
"""
|
||||
Get the 'BaseDataType' that can represent a set of 'BaseDataType's.
|
||||
|
||||
Args:
|
||||
dtypes (List[BaseDataType]):
|
||||
dtypes to combine
|
||||
|
||||
Returns:
|
||||
BaseDataType:
|
||||
dtype that can hold all the given dtypes (potentially lossy)
|
||||
"""
|
||||
|
||||
assert_that(len(dtypes) != 0)
|
||||
assert_that(all(isinstance(dtype, (Integer, Float)) for dtype in dtypes))
|
||||
|
||||
def combine_2_dtypes(dtype1: BaseDataType, dtype2: BaseDataType) -> BaseDataType:
|
||||
result: BaseDataType = dtype1
|
||||
|
||||
if isinstance(dtype1, Integer) and isinstance(dtype2, Integer):
|
||||
max_bits = max(dtype1.bit_width, dtype2.bit_width)
|
||||
|
||||
if dtype1.is_signed and dtype2.is_signed:
|
||||
result = SignedInteger(max_bits)
|
||||
|
||||
elif not dtype1.is_signed and not dtype2.is_signed:
|
||||
result = UnsignedInteger(max_bits)
|
||||
|
||||
elif dtype1.is_signed and not dtype2.is_signed:
|
||||
# if dtype2 has the bigger bit_width,
|
||||
# we need a signed integer that can hold
|
||||
# it, so add 1 bit of sign to its bit_width
|
||||
if dtype2.bit_width >= dtype1.bit_width:
|
||||
new_bit_width = dtype2.bit_width + 1
|
||||
result = SignedInteger(new_bit_width)
|
||||
else:
|
||||
result = SignedInteger(dtype1.bit_width)
|
||||
|
||||
elif not dtype1.is_signed and dtype2.is_signed:
|
||||
# Same as above, with dtype1 and dtype2 switched around
|
||||
if dtype1.bit_width >= dtype2.bit_width:
|
||||
new_bit_width = dtype1.bit_width + 1
|
||||
result = SignedInteger(new_bit_width)
|
||||
else:
|
||||
result = SignedInteger(dtype2.bit_width)
|
||||
|
||||
elif isinstance(dtype1, Float) and isinstance(dtype2, Float):
|
||||
max_bits = max(dtype1.bit_width, dtype2.bit_width)
|
||||
result = Float(max_bits)
|
||||
|
||||
elif isinstance(dtype1, Float):
|
||||
result = dtype1
|
||||
|
||||
elif isinstance(dtype2, Float):
|
||||
result = dtype2
|
||||
|
||||
return result
|
||||
|
||||
result = dtypes[0]
|
||||
for other in dtypes[1:]:
|
||||
result = combine_2_dtypes(result, other)
|
||||
return result
|
||||
Reference in New Issue
Block a user