mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: add TensorValue
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from math import prod
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from .base import BaseDataType
|
||||
|
||||
@@ -80,3 +82,95 @@ def make_encrypted_scalar(
|
||||
|
||||
ClearValue = make_clear_scalar
|
||||
EncryptedValue = make_encrypted_scalar
|
||||
|
||||
|
||||
class TensorValue(BaseValue):
|
||||
"""Class representing a tensor value."""
|
||||
|
||||
_shape: Tuple[int, ...]
|
||||
_ndim: int
|
||||
_size: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_type: BaseDataType,
|
||||
is_encrypted: bool,
|
||||
shape: Optional[Tuple[int, ...]] = None,
|
||||
) -> None:
|
||||
super().__init__(data_type, is_encrypted)
|
||||
# Managing tensors as in numpy, no shape or () is treated as a 0-D array of size 1
|
||||
self._shape = shape if shape is not None else ()
|
||||
self._ndim = len(self._shape)
|
||||
self._size = prod(self._shape) if self._shape else 1
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self.shape == other.shape
|
||||
and self.ndim == other.ndim
|
||||
and self.size == other.size
|
||||
and super().__eq__(other)
|
||||
)
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
"""The TensorValue shape property.
|
||||
|
||||
Returns:
|
||||
Tuple[int, ...]: The TensorValue shape.
|
||||
"""
|
||||
return self._shape
|
||||
|
||||
@property
|
||||
def ndim(self) -> int:
|
||||
"""The TensorValue ndim property.
|
||||
|
||||
Returns:
|
||||
int: The TensorValue ndim.
|
||||
"""
|
||||
return self._ndim
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""The TensorValue size property.
|
||||
|
||||
Returns:
|
||||
int: The TensorValue size.
|
||||
"""
|
||||
return self._size
|
||||
|
||||
|
||||
def make_clear_tensor(
|
||||
data_type: BaseDataType,
|
||||
shape: Optional[Tuple[int, ...]] = None,
|
||||
) -> TensorValue:
|
||||
"""Helper to create a clear TensorValue.
|
||||
|
||||
Args:
|
||||
data_type (BaseDataType): The data type for the tensor.
|
||||
shape (Optional[Tuple[int, ...]], optional): The tensor shape. Defaults to None.
|
||||
|
||||
Returns:
|
||||
TensorValue: The corresponding TensorValue.
|
||||
"""
|
||||
return TensorValue(data_type=data_type, is_encrypted=False, shape=shape)
|
||||
|
||||
|
||||
def make_encrypted_tensor(
|
||||
data_type: BaseDataType,
|
||||
shape: Optional[Tuple[int, ...]] = None,
|
||||
) -> TensorValue:
|
||||
"""Helper to create an encrypted TensorValue.
|
||||
|
||||
Args:
|
||||
data_type (BaseDataType): The data type for the tensor.
|
||||
shape (Optional[Tuple[int, ...]], optional): The tensor shape. Defaults to None.
|
||||
|
||||
Returns:
|
||||
TensorValue: The corresponding TensorValue.
|
||||
"""
|
||||
return TensorValue(data_type=data_type, is_encrypted=True, shape=shape)
|
||||
|
||||
|
||||
ClearTensor = make_clear_tensor
|
||||
EncryptedTensor = make_encrypted_tensor
|
||||
|
||||
87
tests/common/data_types/test_values.py
Normal file
87
tests/common/data_types/test_values.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Test file for values related code."""
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from hdk.common.data_types.base import BaseDataType
|
||||
from hdk.common.data_types.floats import Float
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import ClearTensor, EncryptedTensor, TensorValue
|
||||
|
||||
|
||||
class DummyDtype(BaseDataType):
|
||||
"""Dummy Helper Dtype"""
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
return isinstance(o, self.__class__)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tensor_constructor,expected_is_encrypted",
|
||||
[
|
||||
(ClearTensor, False),
|
||||
(partial(TensorValue, is_encrypted=False), False),
|
||||
(EncryptedTensor, True),
|
||||
(partial(TensorValue, is_encrypted=True), True),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"shape,expected_shape,expected_ndim,expected_size",
|
||||
[
|
||||
(None, (), 0, 1),
|
||||
((), (), 0, 1),
|
||||
((3, 256, 256), (3, 256, 256), 3, 196_608),
|
||||
((1920, 1080, 3), (1920, 1080, 3), 3, 6_220_800),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"data_type",
|
||||
[
|
||||
Integer(7, False),
|
||||
Integer(32, True),
|
||||
Integer(32, False),
|
||||
Integer(64, True),
|
||||
Integer(64, False),
|
||||
Float(32),
|
||||
Float(64),
|
||||
],
|
||||
)
|
||||
def test_tensor_value(
|
||||
tensor_constructor: Callable[..., TensorValue],
|
||||
expected_is_encrypted: bool,
|
||||
shape: Optional[Tuple[int, ...]],
|
||||
expected_shape: Tuple[int, ...],
|
||||
expected_ndim: int,
|
||||
expected_size: int,
|
||||
data_type: Union[Integer, Float],
|
||||
):
|
||||
"""Test function for TensorValue"""
|
||||
|
||||
tensor_value = tensor_constructor(data_type=data_type, shape=shape)
|
||||
|
||||
assert expected_is_encrypted == tensor_value.is_encrypted
|
||||
assert expected_shape == tensor_value.shape
|
||||
assert expected_ndim == tensor_value.ndim
|
||||
assert expected_size == tensor_value.size
|
||||
|
||||
assert data_type == tensor_value.data_type
|
||||
|
||||
other_tensor = deepcopy(tensor_value)
|
||||
|
||||
assert other_tensor == tensor_value
|
||||
|
||||
other_tensor_value = deepcopy(other_tensor)
|
||||
other_tensor_value.data_type = DummyDtype()
|
||||
assert other_tensor_value != tensor_value
|
||||
|
||||
other_shape = tuple(val + 1 for val in shape) if shape is not None else ()
|
||||
other_shape += (2,)
|
||||
other_tensor_value = tensor_constructor(data_type=data_type, shape=other_shape)
|
||||
|
||||
assert other_tensor_value.shape != tensor_value.shape
|
||||
assert other_tensor_value.ndim != tensor_value.ndim
|
||||
assert other_tensor_value.size != tensor_value.size
|
||||
assert other_tensor_value != tensor_value
|
||||
Reference in New Issue
Block a user