mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor: refactor BaseValue and mix_values_determine_holding_dtype
- add _is_encrypted to BaseValue - remove EncryptedValue and ClearValue classes - add a ScalarValue class - add two helpers EncryptedValue and ClearValue which create a ScalarValue either encrypted or not when passed a data_type - rename to mix_scalar_values_determine_holding_dtype - change typing
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import cast
|
||||
from .base import BaseDataType
|
||||
from .floats import Float
|
||||
from .integers import Integer
|
||||
from .values import BaseValue, ClearValue, EncryptedValue
|
||||
from .values import BaseValue, ClearValue, EncryptedValue, ScalarValue
|
||||
|
||||
INTEGER_TYPES = (Integer,)
|
||||
FLOAT_TYPES = (Float,)
|
||||
@@ -22,8 +22,10 @@ def value_is_encrypted_integer(value_to_check: BaseValue) -> bool:
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is an encrypted value of type Integer
|
||||
"""
|
||||
return isinstance(value_to_check, EncryptedValue) and isinstance(
|
||||
value_to_check.data_type, INTEGER_TYPES
|
||||
return (
|
||||
isinstance(value_to_check, BaseValue)
|
||||
and value_to_check.is_encrypted
|
||||
and isinstance(value_to_check.data_type, INTEGER_TYPES)
|
||||
)
|
||||
|
||||
|
||||
@@ -51,8 +53,10 @@ def value_is_clear_integer(value_to_check: BaseValue) -> bool:
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is a clear value of type Integer
|
||||
"""
|
||||
return isinstance(value_to_check, ClearValue) and isinstance(
|
||||
value_to_check.data_type, INTEGER_TYPES
|
||||
return (
|
||||
isinstance(value_to_check, BaseValue)
|
||||
and value_to_check.is_clear
|
||||
and isinstance(value_to_check.data_type, INTEGER_TYPES)
|
||||
)
|
||||
|
||||
|
||||
@@ -65,7 +69,9 @@ def value_is_integer(value_to_check: BaseValue) -> bool:
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is a value of type Integer
|
||||
"""
|
||||
return isinstance(value_to_check.data_type, INTEGER_TYPES)
|
||||
return isinstance(value_to_check, BaseValue) and isinstance(
|
||||
value_to_check.data_type, INTEGER_TYPES
|
||||
)
|
||||
|
||||
|
||||
def find_type_to_hold_both_lossy(
|
||||
@@ -127,26 +133,29 @@ def find_type_to_hold_both_lossy(
|
||||
return type_to_return
|
||||
|
||||
|
||||
def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> BaseValue:
|
||||
def mix_scalar_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> ScalarValue:
|
||||
"""Return mixed value with data type able to hold both value1 and value2 dtypes.
|
||||
|
||||
Returns a Value that would result from computation on both value1 and value2 while
|
||||
Returns a ScalarValue that would result from computation on both value1 and value2 while
|
||||
determining the data type able to hold both value1 and value2 data type (this can be lossy
|
||||
with floats)
|
||||
with floats).
|
||||
|
||||
Args:
|
||||
value1 (BaseValue): first value to mix
|
||||
value2 (BaseValue): second value to mix
|
||||
value1 (BaseValue): first ScalarValue to mix.
|
||||
value2 (BaseValue): second ScalarValue to mix.
|
||||
|
||||
Returns:
|
||||
BaseValue: The resulting mixed value with data type able to hold both value1 and value2
|
||||
dtypes
|
||||
ScalarValue: The resulting mixed BaseValue with data type able to hold both value1 and
|
||||
value2 dtypes.
|
||||
"""
|
||||
|
||||
assert isinstance(value1, ScalarValue), f"Unsupported value1: {value1}, expected ScalarValue"
|
||||
assert isinstance(value2, ScalarValue), f"Unsupported value2: {value2}, expected ScalarValue"
|
||||
|
||||
holding_type = find_type_to_hold_both_lossy(value1.data_type, value2.data_type)
|
||||
mixed_value: ScalarValue
|
||||
|
||||
mixed_value: BaseValue
|
||||
|
||||
if isinstance(value1, EncryptedValue) or isinstance(value2, EncryptedValue):
|
||||
if value1.is_encrypted or value2.is_encrypted:
|
||||
mixed_value = EncryptedValue(holding_type)
|
||||
else:
|
||||
mixed_value = ClearValue(holding_type)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""File holding classes representing values used by an FHE program."""
|
||||
|
||||
from abc import ABC
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
|
||||
from . import base
|
||||
|
||||
@@ -9,20 +10,46 @@ class BaseValue(ABC):
|
||||
"""Abstract base class to represent any kind of value in a program."""
|
||||
|
||||
data_type: base.BaseDataType
|
||||
_is_encrypted: bool
|
||||
|
||||
def __init__(self, data_type: base.BaseDataType) -> None:
|
||||
def __init__(self, data_type: base.BaseDataType, is_encrypted: bool) -> None:
|
||||
self.data_type = data_type
|
||||
self._is_encrypted = is_encrypted
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return f"{self.__class__.__name__}<{self.data_type!r}>"
|
||||
encrypted_str = "Encrypted" if self._is_encrypted else "Clear"
|
||||
return f"{encrypted_str}{self.__class__.__name__}<{self.data_type!r}>"
|
||||
|
||||
@abstractmethod
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, self.__class__) and self.data_type == other.data_type
|
||||
|
||||
@property
|
||||
def is_encrypted(self) -> bool:
|
||||
"""Whether Value is encrypted or not.
|
||||
|
||||
class ClearValue(BaseValue):
|
||||
"""Class representing a clear/plaintext value (constant or not)."""
|
||||
Returns:
|
||||
bool: True if encrypted False otherwise
|
||||
"""
|
||||
return self._is_encrypted
|
||||
|
||||
@property
|
||||
def is_clear(self) -> bool:
|
||||
"""Whether Value is clear or not.
|
||||
|
||||
Returns:
|
||||
bool: True if clear False otherwise
|
||||
"""
|
||||
return not self._is_encrypted
|
||||
|
||||
|
||||
class EncryptedValue(BaseValue):
|
||||
"""Class representing an encrypted value (constant or not)."""
|
||||
class ScalarValue(BaseValue):
|
||||
"""Class representing a scalar value."""
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return BaseValue.__eq__(self, other)
|
||||
|
||||
|
||||
ClearValue = partial(ScalarValue, is_encrypted=False)
|
||||
|
||||
EncryptedValue = partial(ScalarValue, is_encrypted=True)
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
|
||||
from ..data_types import BaseValue
|
||||
from ..data_types.base import BaseDataType
|
||||
from ..data_types.dtypes_helpers import mix_values_determine_holding_dtype
|
||||
from ..data_types.dtypes_helpers import mix_scalar_values_determine_holding_dtype
|
||||
from ..data_types.floats import Float
|
||||
from ..data_types.integers import Integer, get_bits_to_represent_int
|
||||
from ..data_types.scalars import Scalars
|
||||
@@ -35,7 +35,7 @@ class IntermediateNode(ABC):
|
||||
|
||||
assert len(self.inputs) == 2
|
||||
|
||||
self.outputs = [mix_values_determine_holding_dtype(self.inputs[0], self.inputs[1])]
|
||||
self.outputs = [mix_scalar_values_determine_holding_dtype(self.inputs[0], self.inputs[1])]
|
||||
|
||||
def _is_equivalent_to_binary_commutative(self, other: object) -> bool:
|
||||
"""is_equivalent_to for a binary and commutative operation."""
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
from hdk.common.data_types.base import BaseDataType
|
||||
from hdk.common.data_types.dtypes_helpers import (
|
||||
find_type_to_hold_both_lossy,
|
||||
mix_values_determine_holding_dtype,
|
||||
mix_scalar_values_determine_holding_dtype,
|
||||
value_is_encrypted_integer,
|
||||
value_is_encrypted_unsigned_integer,
|
||||
)
|
||||
@@ -167,4 +167,4 @@ def test_mix_data_types(
|
||||
def test_mix_values(value1: BaseValue, value2: BaseValue, expected_mixed_value: BaseValue):
|
||||
"""Test mix_values helper"""
|
||||
|
||||
assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2)
|
||||
assert expected_mixed_value == mix_scalar_values_determine_holding_dtype(value1, value2)
|
||||
|
||||
Reference in New Issue
Block a user