diff --git a/hdk/common/data_types/dtypes_helpers.py b/hdk/common/data_types/dtypes_helpers.py index 3d289205e..731eb54d1 100644 --- a/hdk/common/data_types/dtypes_helpers.py +++ b/hdk/common/data_types/dtypes_helpers.py @@ -6,7 +6,7 @@ from typing import cast from .base import BaseDataType from .floats import Float from .integers import Integer -from .values import BaseValue, ClearValue, EncryptedValue +from .values import BaseValue, ClearValue, EncryptedValue, ScalarValue INTEGER_TYPES = (Integer,) FLOAT_TYPES = (Float,) @@ -22,8 +22,10 @@ def value_is_encrypted_integer(value_to_check: BaseValue) -> bool: Returns: bool: True if the passed value_to_check is an encrypted value of type Integer """ - return isinstance(value_to_check, EncryptedValue) and isinstance( - value_to_check.data_type, INTEGER_TYPES + return ( + isinstance(value_to_check, BaseValue) + and value_to_check.is_encrypted + and isinstance(value_to_check.data_type, INTEGER_TYPES) ) @@ -51,8 +53,10 @@ def value_is_clear_integer(value_to_check: BaseValue) -> bool: Returns: bool: True if the passed value_to_check is a clear value of type Integer """ - return isinstance(value_to_check, ClearValue) and isinstance( - value_to_check.data_type, INTEGER_TYPES + return ( + isinstance(value_to_check, BaseValue) + and value_to_check.is_clear + and isinstance(value_to_check.data_type, INTEGER_TYPES) ) @@ -65,7 +69,9 @@ def value_is_integer(value_to_check: BaseValue) -> bool: Returns: bool: True if the passed value_to_check is a value of type Integer """ - return isinstance(value_to_check.data_type, INTEGER_TYPES) + return isinstance(value_to_check, BaseValue) and isinstance( + value_to_check.data_type, INTEGER_TYPES + ) def find_type_to_hold_both_lossy( @@ -127,26 +133,29 @@ def find_type_to_hold_both_lossy( return type_to_return -def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> BaseValue: +def mix_scalar_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> ScalarValue: """Return mixed value with data type able to hold both value1 and value2 dtypes. - Returns a Value that would result from computation on both value1 and value2 while + Returns a ScalarValue that would result from computation on both value1 and value2 while determining the data type able to hold both value1 and value2 data type (this can be lossy - with floats) + with floats). Args: - value1 (BaseValue): first value to mix - value2 (BaseValue): second value to mix + value1 (BaseValue): first ScalarValue to mix. + value2 (BaseValue): second ScalarValue to mix. Returns: - BaseValue: The resulting mixed value with data type able to hold both value1 and value2 - dtypes + ScalarValue: The resulting mixed BaseValue with data type able to hold both value1 and + value2 dtypes. """ + + assert isinstance(value1, ScalarValue), f"Unsupported value1: {value1}, expected ScalarValue" + assert isinstance(value2, ScalarValue), f"Unsupported value2: {value2}, expected ScalarValue" + holding_type = find_type_to_hold_both_lossy(value1.data_type, value2.data_type) + mixed_value: ScalarValue - mixed_value: BaseValue - - if isinstance(value1, EncryptedValue) or isinstance(value2, EncryptedValue): + if value1.is_encrypted or value2.is_encrypted: mixed_value = EncryptedValue(holding_type) else: mixed_value = ClearValue(holding_type) diff --git a/hdk/common/data_types/values.py b/hdk/common/data_types/values.py index eda60eeb5..0001b3a83 100644 --- a/hdk/common/data_types/values.py +++ b/hdk/common/data_types/values.py @@ -1,6 +1,7 @@ """File holding classes representing values used by an FHE program.""" -from abc import ABC +from abc import ABC, abstractmethod +from functools import partial from . import base @@ -9,20 +10,46 @@ class BaseValue(ABC): """Abstract base class to represent any kind of value in a program.""" data_type: base.BaseDataType + _is_encrypted: bool - def __init__(self, data_type: base.BaseDataType) -> None: + def __init__(self, data_type: base.BaseDataType, is_encrypted: bool) -> None: self.data_type = data_type + self._is_encrypted = is_encrypted def __repr__(self) -> str: # pragma: no cover - return f"{self.__class__.__name__}<{self.data_type!r}>" + encrypted_str = "Encrypted" if self._is_encrypted else "Clear" + return f"{encrypted_str}{self.__class__.__name__}<{self.data_type!r}>" + @abstractmethod def __eq__(self, other: object) -> bool: return isinstance(other, self.__class__) and self.data_type == other.data_type + @property + def is_encrypted(self) -> bool: + """Whether Value is encrypted or not. -class ClearValue(BaseValue): - """Class representing a clear/plaintext value (constant or not).""" + Returns: + bool: True if encrypted False otherwise + """ + return self._is_encrypted + + @property + def is_clear(self) -> bool: + """Whether Value is clear or not. + + Returns: + bool: True if clear False otherwise + """ + return not self._is_encrypted -class EncryptedValue(BaseValue): - """Class representing an encrypted value (constant or not).""" +class ScalarValue(BaseValue): + """Class representing a scalar value.""" + + def __eq__(self, other: object) -> bool: + return BaseValue.__eq__(self, other) + + +ClearValue = partial(ScalarValue, is_encrypted=False) + +EncryptedValue = partial(ScalarValue, is_encrypted=True) diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index e257556d5..13903685a 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple from ..data_types import BaseValue from ..data_types.base import BaseDataType -from ..data_types.dtypes_helpers import mix_values_determine_holding_dtype +from ..data_types.dtypes_helpers import mix_scalar_values_determine_holding_dtype from ..data_types.floats import Float from ..data_types.integers import Integer, get_bits_to_represent_int from ..data_types.scalars import Scalars @@ -35,7 +35,7 @@ class IntermediateNode(ABC): assert len(self.inputs) == 2 - self.outputs = [mix_values_determine_holding_dtype(self.inputs[0], self.inputs[1])] + self.outputs = [mix_scalar_values_determine_holding_dtype(self.inputs[0], self.inputs[1])] def _is_equivalent_to_binary_commutative(self, other: object) -> bool: """is_equivalent_to for a binary and commutative operation.""" diff --git a/tests/common/data_types/test_dtypes_helpers.py b/tests/common/data_types/test_dtypes_helpers.py index c66fbae88..1a4761330 100644 --- a/tests/common/data_types/test_dtypes_helpers.py +++ b/tests/common/data_types/test_dtypes_helpers.py @@ -5,7 +5,7 @@ import pytest from hdk.common.data_types.base import BaseDataType from hdk.common.data_types.dtypes_helpers import ( find_type_to_hold_both_lossy, - mix_values_determine_holding_dtype, + mix_scalar_values_determine_holding_dtype, value_is_encrypted_integer, value_is_encrypted_unsigned_integer, ) @@ -167,4 +167,4 @@ def test_mix_data_types( def test_mix_values(value1: BaseValue, value2: BaseValue, expected_mixed_value: BaseValue): """Test mix_values helper""" - assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2) + assert expected_mixed_value == mix_scalar_values_determine_holding_dtype(value1, value2)