feat: add TensorValue

This commit is contained in:
Arthur Meyre
2021-08-20 11:46:12 +02:00
parent 985cf973d8
commit 4e658c15cb
2 changed files with 181 additions and 0 deletions

View File

@@ -2,6 +2,8 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from math import prod
from typing import Optional, Tuple
from .base import BaseDataType
@@ -80,3 +82,95 @@ def make_encrypted_scalar(
ClearValue = make_clear_scalar
EncryptedValue = make_encrypted_scalar
class TensorValue(BaseValue):
"""Class representing a tensor value."""
_shape: Tuple[int, ...]
_ndim: int
_size: int
def __init__(
self,
data_type: BaseDataType,
is_encrypted: bool,
shape: Optional[Tuple[int, ...]] = None,
) -> None:
super().__init__(data_type, is_encrypted)
# Managing tensors as in numpy, no shape or () is treated as a 0-D array of size 1
self._shape = shape if shape is not None else ()
self._ndim = len(self._shape)
self._size = prod(self._shape) if self._shape else 1
def __eq__(self, other: object) -> bool:
return (
isinstance(other, self.__class__)
and self.shape == other.shape
and self.ndim == other.ndim
and self.size == other.size
and super().__eq__(other)
)
@property
def shape(self) -> Tuple[int, ...]:
"""The TensorValue shape property.
Returns:
Tuple[int, ...]: The TensorValue shape.
"""
return self._shape
@property
def ndim(self) -> int:
"""The TensorValue ndim property.
Returns:
int: The TensorValue ndim.
"""
return self._ndim
@property
def size(self) -> int:
"""The TensorValue size property.
Returns:
int: The TensorValue size.
"""
return self._size
def make_clear_tensor(
data_type: BaseDataType,
shape: Optional[Tuple[int, ...]] = None,
) -> TensorValue:
"""Helper to create a clear TensorValue.
Args:
data_type (BaseDataType): The data type for the tensor.
shape (Optional[Tuple[int, ...]], optional): The tensor shape. Defaults to None.
Returns:
TensorValue: The corresponding TensorValue.
"""
return TensorValue(data_type=data_type, is_encrypted=False, shape=shape)
def make_encrypted_tensor(
data_type: BaseDataType,
shape: Optional[Tuple[int, ...]] = None,
) -> TensorValue:
"""Helper to create an encrypted TensorValue.
Args:
data_type (BaseDataType): The data type for the tensor.
shape (Optional[Tuple[int, ...]], optional): The tensor shape. Defaults to None.
Returns:
TensorValue: The corresponding TensorValue.
"""
return TensorValue(data_type=data_type, is_encrypted=True, shape=shape)
ClearTensor = make_clear_tensor
EncryptedTensor = make_encrypted_tensor

View File

@@ -0,0 +1,87 @@
"""Test file for values related code."""
from copy import deepcopy
from functools import partial
from typing import Callable, Optional, Tuple, Union
import pytest
from hdk.common.data_types.base import BaseDataType
from hdk.common.data_types.floats import Float
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import ClearTensor, EncryptedTensor, TensorValue
class DummyDtype(BaseDataType):
"""Dummy Helper Dtype"""
def __eq__(self, o: object) -> bool:
return isinstance(o, self.__class__)
@pytest.mark.parametrize(
"tensor_constructor,expected_is_encrypted",
[
(ClearTensor, False),
(partial(TensorValue, is_encrypted=False), False),
(EncryptedTensor, True),
(partial(TensorValue, is_encrypted=True), True),
],
)
@pytest.mark.parametrize(
"shape,expected_shape,expected_ndim,expected_size",
[
(None, (), 0, 1),
((), (), 0, 1),
((3, 256, 256), (3, 256, 256), 3, 196_608),
((1920, 1080, 3), (1920, 1080, 3), 3, 6_220_800),
],
)
@pytest.mark.parametrize(
"data_type",
[
Integer(7, False),
Integer(32, True),
Integer(32, False),
Integer(64, True),
Integer(64, False),
Float(32),
Float(64),
],
)
def test_tensor_value(
tensor_constructor: Callable[..., TensorValue],
expected_is_encrypted: bool,
shape: Optional[Tuple[int, ...]],
expected_shape: Tuple[int, ...],
expected_ndim: int,
expected_size: int,
data_type: Union[Integer, Float],
):
"""Test function for TensorValue"""
tensor_value = tensor_constructor(data_type=data_type, shape=shape)
assert expected_is_encrypted == tensor_value.is_encrypted
assert expected_shape == tensor_value.shape
assert expected_ndim == tensor_value.ndim
assert expected_size == tensor_value.size
assert data_type == tensor_value.data_type
other_tensor = deepcopy(tensor_value)
assert other_tensor == tensor_value
other_tensor_value = deepcopy(other_tensor)
other_tensor_value.data_type = DummyDtype()
assert other_tensor_value != tensor_value
other_shape = tuple(val + 1 for val in shape) if shape is not None else ()
other_shape += (2,)
other_tensor_value = tensor_constructor(data_type=data_type, shape=other_shape)
assert other_tensor_value.shape != tensor_value.shape
assert other_tensor_value.ndim != tensor_value.ndim
assert other_tensor_value.size != tensor_value.size
assert other_tensor_value != tensor_value