mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: implement values module
This commit is contained in:
7
concrete/numpy/values/__init__.py
Normal file
7
concrete/numpy/values/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Declaration of `concrete.numpy.values` namespace.
|
||||
"""
|
||||
|
||||
from .scalar import ClearScalar, EncryptedScalar
|
||||
from .tensor import ClearTensor, EncryptedTensor
|
||||
from .value import Value
|
||||
44
concrete/numpy/values/scalar.py
Normal file
44
concrete/numpy/values/scalar.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
Declaration of `ClearScalar` and `EncryptedScalar` wrappers.
|
||||
"""
|
||||
|
||||
from ..dtypes import BaseDataType
|
||||
from .value import Value
|
||||
|
||||
|
||||
def clear_scalar_builder(dtype: BaseDataType) -> Value:
|
||||
"""
|
||||
Build a clear scalar value.
|
||||
|
||||
Args:
|
||||
dtype (BaseDataType):
|
||||
dtype of the value
|
||||
|
||||
Returns:
|
||||
Value:
|
||||
clear scalar value with given dtype
|
||||
"""
|
||||
|
||||
return Value(dtype=dtype, shape=(), is_encrypted=False)
|
||||
|
||||
|
||||
ClearScalar = clear_scalar_builder
|
||||
|
||||
|
||||
def encrypted_scalar_builder(dtype: BaseDataType) -> Value:
|
||||
"""
|
||||
Build an encrypted scalar value.
|
||||
|
||||
Args:
|
||||
dtype (BaseDataType):
|
||||
dtype of the value
|
||||
|
||||
Returns:
|
||||
Value:
|
||||
encrypted scalar value with given dtype
|
||||
"""
|
||||
|
||||
return Value(dtype=dtype, shape=(), is_encrypted=True)
|
||||
|
||||
|
||||
EncryptedScalar = encrypted_scalar_builder
|
||||
52
concrete/numpy/values/tensor.py
Normal file
52
concrete/numpy/values/tensor.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Declaration of `ClearTensor` and `EncryptedTensor` wrappers.
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from ..dtypes import BaseDataType
|
||||
from .value import Value
|
||||
|
||||
|
||||
def clear_tensor_builder(dtype: BaseDataType, shape: Tuple[int, ...]) -> Value:
|
||||
"""
|
||||
Build a clear tensor value.
|
||||
|
||||
Args:
|
||||
dtype (BaseDataType):
|
||||
dtype of the value
|
||||
|
||||
shape (Tuple[int, ...]):
|
||||
shape of the value
|
||||
|
||||
Returns:
|
||||
Value:
|
||||
clear tensor value with given dtype and shape
|
||||
"""
|
||||
|
||||
return Value(dtype=dtype, shape=shape, is_encrypted=False)
|
||||
|
||||
|
||||
ClearTensor = clear_tensor_builder
|
||||
|
||||
|
||||
def encrypted_tensor_builder(dtype: BaseDataType, shape: Tuple[int, ...]) -> Value:
|
||||
"""
|
||||
Build an encrypted tensor value.
|
||||
|
||||
Args:
|
||||
dtype (BaseDataType):
|
||||
dtype of the value
|
||||
|
||||
shape (Tuple[int, ...]):
|
||||
shape of the value
|
||||
|
||||
Returns:
|
||||
Value:
|
||||
encrypted tensor value with given dtype and shape
|
||||
"""
|
||||
|
||||
return Value(dtype=dtype, shape=shape, is_encrypted=True)
|
||||
|
||||
|
||||
EncryptedTensor = encrypted_tensor_builder
|
||||
162
concrete/numpy/values/value.py
Normal file
162
concrete/numpy/values/value.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Declaration of `Value` class.
|
||||
"""
|
||||
|
||||
from math import prod
|
||||
from typing import Any, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..dtypes import BaseDataType, Float, Integer, UnsignedInteger
|
||||
|
||||
|
||||
class Value:
|
||||
"""
|
||||
Value class, to combine data type, shape, and encryption status into a single object.
|
||||
"""
|
||||
|
||||
dtype: BaseDataType
|
||||
shape: Tuple[int, ...]
|
||||
is_encrypted: bool
|
||||
|
||||
@staticmethod
|
||||
def of(value: Any, is_encrypted: bool = False) -> "Value": # pylint: disable=invalid-name
|
||||
"""
|
||||
Get the `Value` that can represent `value`.
|
||||
|
||||
Args:
|
||||
value (Any):
|
||||
value that needs to be represented
|
||||
|
||||
is_encrypted (bool, default = False):
|
||||
whether the resulting `Value` is encrypted or not
|
||||
|
||||
Returns:
|
||||
Value:
|
||||
`Value` that can represent `value`
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
if `value` cannot be represented by `Value`
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-branches,too-many-return-statements
|
||||
|
||||
if isinstance(value, (bool, np.bool_)):
|
||||
return Value(dtype=UnsignedInteger(1), shape=(), is_encrypted=is_encrypted)
|
||||
|
||||
if isinstance(value, (int, np.integer)):
|
||||
return Value(
|
||||
dtype=Integer.that_can_represent(value),
|
||||
shape=(),
|
||||
is_encrypted=is_encrypted,
|
||||
)
|
||||
|
||||
if isinstance(value, (float, np.float64)):
|
||||
return Value(dtype=Float(64), shape=(), is_encrypted=is_encrypted)
|
||||
|
||||
if isinstance(value, np.float32):
|
||||
return Value(dtype=Float(32), shape=(), is_encrypted=is_encrypted)
|
||||
|
||||
if isinstance(value, np.float16):
|
||||
return Value(dtype=Float(16), shape=(), is_encrypted=is_encrypted)
|
||||
|
||||
if isinstance(value, list):
|
||||
try:
|
||||
value = np.array(value)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# here we try our best to convert the list to np.ndarray
|
||||
# if it fails we raise the exception at the end of the function
|
||||
pass
|
||||
|
||||
if isinstance(value, np.ndarray):
|
||||
|
||||
if np.issubdtype(value.dtype, np.bool_):
|
||||
return Value(dtype=UnsignedInteger(1), shape=value.shape, is_encrypted=is_encrypted)
|
||||
|
||||
if np.issubdtype(value.dtype, np.integer):
|
||||
return Value(
|
||||
dtype=Integer.that_can_represent(value),
|
||||
shape=value.shape,
|
||||
is_encrypted=is_encrypted,
|
||||
)
|
||||
|
||||
if np.issubdtype(value.dtype, np.float64):
|
||||
return Value(dtype=Float(64), shape=value.shape, is_encrypted=is_encrypted)
|
||||
|
||||
if np.issubdtype(value.dtype, np.float32):
|
||||
return Value(dtype=Float(32), shape=value.shape, is_encrypted=is_encrypted)
|
||||
|
||||
if np.issubdtype(value.dtype, np.float16):
|
||||
return Value(dtype=Float(16), shape=value.shape, is_encrypted=is_encrypted)
|
||||
|
||||
raise ValueError(f"Value cannot represent {repr(value)}")
|
||||
|
||||
# pylint: enable=too-many-branches,too-many-return-statements
|
||||
|
||||
def __init__(self, dtype: BaseDataType, shape: Tuple[int, ...], is_encrypted: bool):
|
||||
self.dtype = dtype
|
||||
self.shape = shape
|
||||
self.is_encrypted = is_encrypted
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, Value)
|
||||
and self.dtype == other.dtype
|
||||
and self.shape == other.shape
|
||||
and self.is_encrypted == other.is_encrypted
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
encrypted_or_clear_str = "Encrypted" if self.is_encrypted else "Clear"
|
||||
scalar_or_tensor_str = "Scalar" if self.is_scalar else "Tensor"
|
||||
shape_str = f", shape={self.shape}" if not self.is_scalar else ""
|
||||
return f"{encrypted_or_clear_str}{scalar_or_tensor_str}<{str(self.dtype)}{shape_str}>"
|
||||
|
||||
@property
|
||||
def is_clear(self) -> bool:
|
||||
"""
|
||||
Get whether the value is clear or not.
|
||||
|
||||
Returns:
|
||||
bool:
|
||||
True if value is not encrypted, False otherwise
|
||||
"""
|
||||
|
||||
return not self.is_encrypted
|
||||
|
||||
@property
|
||||
def is_scalar(self) -> bool:
|
||||
"""
|
||||
Get whether the value is scalar or not.
|
||||
|
||||
Returns:
|
||||
bool:
|
||||
True if shape of the value is (), False otherwise
|
||||
"""
|
||||
|
||||
return self.shape == ()
|
||||
|
||||
@property
|
||||
def ndim(self) -> int:
|
||||
"""
|
||||
Get number of dimensions of the value.
|
||||
|
||||
Returns:
|
||||
int:
|
||||
number of dimensions of the value
|
||||
"""
|
||||
|
||||
return len(self.shape)
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""
|
||||
Get number of elements in the value.
|
||||
|
||||
Returns:
|
||||
int:
|
||||
number of elements in the value
|
||||
"""
|
||||
|
||||
return int(prod(self.shape))
|
||||
Reference in New Issue
Block a user