refactor: refactor BaseValue and mix_values_determine_holding_dtype

- add _is_encrypted to BaseValue
- remove EncryptedValue and ClearValue classes
- add a ScalarValue class
- add two helpers EncryptedValue and ClearValue which create a ScalarValue
either encrypted or not when passed a data_type
- rename to mix_scalar_values_determine_holding_dtype
- change typing
This commit is contained in:
Arthur Meyre
2021-08-19 14:40:50 +02:00
parent 371ecae801
commit 0ff3ae4795
4 changed files with 63 additions and 27 deletions

View File

@@ -6,7 +6,7 @@ from typing import cast
from .base import BaseDataType
from .floats import Float
from .integers import Integer
from .values import BaseValue, ClearValue, EncryptedValue
from .values import BaseValue, ClearValue, EncryptedValue, ScalarValue
INTEGER_TYPES = (Integer,)
FLOAT_TYPES = (Float,)
@@ -22,8 +22,10 @@ def value_is_encrypted_integer(value_to_check: BaseValue) -> bool:
Returns:
bool: True if the passed value_to_check is an encrypted value of type Integer
"""
return isinstance(value_to_check, EncryptedValue) and isinstance(
value_to_check.data_type, INTEGER_TYPES
return (
isinstance(value_to_check, BaseValue)
and value_to_check.is_encrypted
and isinstance(value_to_check.data_type, INTEGER_TYPES)
)
@@ -51,8 +53,10 @@ def value_is_clear_integer(value_to_check: BaseValue) -> bool:
Returns:
bool: True if the passed value_to_check is a clear value of type Integer
"""
return isinstance(value_to_check, ClearValue) and isinstance(
value_to_check.data_type, INTEGER_TYPES
return (
isinstance(value_to_check, BaseValue)
and value_to_check.is_clear
and isinstance(value_to_check.data_type, INTEGER_TYPES)
)
@@ -65,7 +69,9 @@ def value_is_integer(value_to_check: BaseValue) -> bool:
Returns:
bool: True if the passed value_to_check is a value of type Integer
"""
return isinstance(value_to_check.data_type, INTEGER_TYPES)
return isinstance(value_to_check, BaseValue) and isinstance(
value_to_check.data_type, INTEGER_TYPES
)
def find_type_to_hold_both_lossy(
@@ -127,26 +133,29 @@ def find_type_to_hold_both_lossy(
return type_to_return
def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> BaseValue:
def mix_scalar_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> ScalarValue:
"""Return mixed value with data type able to hold both value1 and value2 dtypes.
Returns a Value that would result from computation on both value1 and value2 while
Returns a ScalarValue that would result from computation on both value1 and value2 while
determining the data type able to hold both value1 and value2 data type (this can be lossy
with floats)
with floats).
Args:
value1 (BaseValue): first value to mix
value2 (BaseValue): second value to mix
value1 (BaseValue): first ScalarValue to mix.
value2 (BaseValue): second ScalarValue to mix.
Returns:
BaseValue: The resulting mixed value with data type able to hold both value1 and value2
dtypes
ScalarValue: The resulting mixed BaseValue with data type able to hold both value1 and
value2 dtypes.
"""
assert isinstance(value1, ScalarValue), f"Unsupported value1: {value1}, expected ScalarValue"
assert isinstance(value2, ScalarValue), f"Unsupported value2: {value2}, expected ScalarValue"
holding_type = find_type_to_hold_both_lossy(value1.data_type, value2.data_type)
mixed_value: ScalarValue
mixed_value: BaseValue
if isinstance(value1, EncryptedValue) or isinstance(value2, EncryptedValue):
if value1.is_encrypted or value2.is_encrypted:
mixed_value = EncryptedValue(holding_type)
else:
mixed_value = ClearValue(holding_type)

View File

@@ -1,6 +1,7 @@
"""File holding classes representing values used by an FHE program."""
from abc import ABC
from abc import ABC, abstractmethod
from functools import partial
from . import base
@@ -9,20 +10,46 @@ class BaseValue(ABC):
"""Abstract base class to represent any kind of value in a program."""
data_type: base.BaseDataType
_is_encrypted: bool
def __init__(self, data_type: base.BaseDataType) -> None:
def __init__(self, data_type: base.BaseDataType, is_encrypted: bool) -> None:
self.data_type = data_type
self._is_encrypted = is_encrypted
def __repr__(self) -> str: # pragma: no cover
return f"{self.__class__.__name__}<{self.data_type!r}>"
encrypted_str = "Encrypted" if self._is_encrypted else "Clear"
return f"{encrypted_str}{self.__class__.__name__}<{self.data_type!r}>"
@abstractmethod
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.data_type == other.data_type
@property
def is_encrypted(self) -> bool:
"""Whether Value is encrypted or not.
class ClearValue(BaseValue):
"""Class representing a clear/plaintext value (constant or not)."""
Returns:
bool: True if encrypted False otherwise
"""
return self._is_encrypted
@property
def is_clear(self) -> bool:
"""Whether Value is clear or not.
Returns:
bool: True if clear False otherwise
"""
return not self._is_encrypted
class EncryptedValue(BaseValue):
"""Class representing an encrypted value (constant or not)."""
class ScalarValue(BaseValue):
"""Class representing a scalar value."""
def __eq__(self, other: object) -> bool:
return BaseValue.__eq__(self, other)
ClearValue = partial(ScalarValue, is_encrypted=False)
EncryptedValue = partial(ScalarValue, is_encrypted=True)

View File

@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
from ..data_types import BaseValue
from ..data_types.base import BaseDataType
from ..data_types.dtypes_helpers import mix_values_determine_holding_dtype
from ..data_types.dtypes_helpers import mix_scalar_values_determine_holding_dtype
from ..data_types.floats import Float
from ..data_types.integers import Integer, get_bits_to_represent_int
from ..data_types.scalars import Scalars
@@ -35,7 +35,7 @@ class IntermediateNode(ABC):
assert len(self.inputs) == 2
self.outputs = [mix_values_determine_holding_dtype(self.inputs[0], self.inputs[1])]
self.outputs = [mix_scalar_values_determine_holding_dtype(self.inputs[0], self.inputs[1])]
def _is_equivalent_to_binary_commutative(self, other: object) -> bool:
"""is_equivalent_to for a binary and commutative operation."""

View File

@@ -5,7 +5,7 @@ import pytest
from hdk.common.data_types.base import BaseDataType
from hdk.common.data_types.dtypes_helpers import (
find_type_to_hold_both_lossy,
mix_values_determine_holding_dtype,
mix_scalar_values_determine_holding_dtype,
value_is_encrypted_integer,
value_is_encrypted_unsigned_integer,
)
@@ -167,4 +167,4 @@ def test_mix_data_types(
def test_mix_values(value1: BaseValue, value2: BaseValue, expected_mixed_value: BaseValue):
"""Test mix_values helper"""
assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2)
assert expected_mixed_value == mix_scalar_values_determine_holding_dtype(value1, value2)