diff --git a/hdk/common/data_types/values.py b/hdk/common/data_types/values.py index 0f8e8738d..16c7f409b 100644 --- a/hdk/common/data_types/values.py +++ b/hdk/common/data_types/values.py @@ -2,6 +2,8 @@ from abc import ABC, abstractmethod from copy import deepcopy +from math import prod +from typing import Optional, Tuple from .base import BaseDataType @@ -80,3 +82,95 @@ def make_encrypted_scalar( ClearValue = make_clear_scalar EncryptedValue = make_encrypted_scalar + + +class TensorValue(BaseValue): + """Class representing a tensor value.""" + + _shape: Tuple[int, ...] + _ndim: int + _size: int + + def __init__( + self, + data_type: BaseDataType, + is_encrypted: bool, + shape: Optional[Tuple[int, ...]] = None, + ) -> None: + super().__init__(data_type, is_encrypted) + # Managing tensors as in numpy, no shape or () is treated as a 0-D array of size 1 + self._shape = shape if shape is not None else () + self._ndim = len(self._shape) + self._size = prod(self._shape) if self._shape else 1 + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, self.__class__) + and self.shape == other.shape + and self.ndim == other.ndim + and self.size == other.size + and super().__eq__(other) + ) + + @property + def shape(self) -> Tuple[int, ...]: + """The TensorValue shape property. + + Returns: + Tuple[int, ...]: The TensorValue shape. + """ + return self._shape + + @property + def ndim(self) -> int: + """The TensorValue ndim property. + + Returns: + int: The TensorValue ndim. + """ + return self._ndim + + @property + def size(self) -> int: + """The TensorValue size property. + + Returns: + int: The TensorValue size. + """ + return self._size + + +def make_clear_tensor( + data_type: BaseDataType, + shape: Optional[Tuple[int, ...]] = None, +) -> TensorValue: + """Helper to create a clear TensorValue. + + Args: + data_type (BaseDataType): The data type for the tensor. + shape (Optional[Tuple[int, ...]], optional): The tensor shape. Defaults to None. + + Returns: + TensorValue: The corresponding TensorValue. + """ + return TensorValue(data_type=data_type, is_encrypted=False, shape=shape) + + +def make_encrypted_tensor( + data_type: BaseDataType, + shape: Optional[Tuple[int, ...]] = None, +) -> TensorValue: + """Helper to create an encrypted TensorValue. + + Args: + data_type (BaseDataType): The data type for the tensor. + shape (Optional[Tuple[int, ...]], optional): The tensor shape. Defaults to None. + + Returns: + TensorValue: The corresponding TensorValue. + """ + return TensorValue(data_type=data_type, is_encrypted=True, shape=shape) + + +ClearTensor = make_clear_tensor +EncryptedTensor = make_encrypted_tensor diff --git a/tests/common/data_types/test_values.py b/tests/common/data_types/test_values.py new file mode 100644 index 000000000..3f84a34a6 --- /dev/null +++ b/tests/common/data_types/test_values.py @@ -0,0 +1,87 @@ +"""Test file for values related code.""" + +from copy import deepcopy +from functools import partial +from typing import Callable, Optional, Tuple, Union + +import pytest + +from hdk.common.data_types.base import BaseDataType +from hdk.common.data_types.floats import Float +from hdk.common.data_types.integers import Integer +from hdk.common.data_types.values import ClearTensor, EncryptedTensor, TensorValue + + +class DummyDtype(BaseDataType): + """Dummy Helper Dtype""" + + def __eq__(self, o: object) -> bool: + return isinstance(o, self.__class__) + + +@pytest.mark.parametrize( + "tensor_constructor,expected_is_encrypted", + [ + (ClearTensor, False), + (partial(TensorValue, is_encrypted=False), False), + (EncryptedTensor, True), + (partial(TensorValue, is_encrypted=True), True), + ], +) +@pytest.mark.parametrize( + "shape,expected_shape,expected_ndim,expected_size", + [ + (None, (), 0, 1), + ((), (), 0, 1), + ((3, 256, 256), (3, 256, 256), 3, 196_608), + ((1920, 1080, 3), (1920, 1080, 3), 3, 6_220_800), + ], +) +@pytest.mark.parametrize( + "data_type", + [ + Integer(7, False), + Integer(32, True), + Integer(32, False), + Integer(64, True), + Integer(64, False), + Float(32), + Float(64), + ], +) +def test_tensor_value( + tensor_constructor: Callable[..., TensorValue], + expected_is_encrypted: bool, + shape: Optional[Tuple[int, ...]], + expected_shape: Tuple[int, ...], + expected_ndim: int, + expected_size: int, + data_type: Union[Integer, Float], +): + """Test function for TensorValue""" + + tensor_value = tensor_constructor(data_type=data_type, shape=shape) + + assert expected_is_encrypted == tensor_value.is_encrypted + assert expected_shape == tensor_value.shape + assert expected_ndim == tensor_value.ndim + assert expected_size == tensor_value.size + + assert data_type == tensor_value.data_type + + other_tensor = deepcopy(tensor_value) + + assert other_tensor == tensor_value + + other_tensor_value = deepcopy(other_tensor) + other_tensor_value.data_type = DummyDtype() + assert other_tensor_value != tensor_value + + other_shape = tuple(val + 1 for val in shape) if shape is not None else () + other_shape += (2,) + other_tensor_value = tensor_constructor(data_type=data_type, shape=other_shape) + + assert other_tensor_value.shape != tensor_value.shape + assert other_tensor_value.ndim != tensor_value.ndim + assert other_tensor_value.size != tensor_value.size + assert other_tensor_value != tensor_value