mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
dev(data-types): create Value classes which represent values in a program
- value classes have a data_type member to know what they hold - add __repr__ to a few classes to ease readability for debug/print - add helper functions to perform value checks that will be used for tracing to ease readability - add unit tests to get 100% coverage
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
"""HDK's top import"""
|
||||
from . import utils
|
||||
from . import common, utils
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
"""HDK's module for data types code and data structures"""
|
||||
from . import integers
|
||||
from . import helpers, integers, values
|
||||
from .values import BaseValue
|
||||
|
||||
38
hdk/common/data_types/helpers.py
Normal file
38
hdk/common/data_types/helpers.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""File to hold helper functions for data types related stuff"""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from . import integers, values
|
||||
|
||||
INTEGER_TYPES = set([integers.Integer])
|
||||
|
||||
|
||||
def value_is_encrypted_integer(value_to_check: values.BaseValue) -> bool:
|
||||
"""Helper function to check that a value is an encrypted_integer
|
||||
|
||||
Args:
|
||||
value_to_check (values.BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is an encrypted value of type Integer
|
||||
"""
|
||||
return (
|
||||
isinstance(value_to_check, values.EncryptedValue)
|
||||
and type(value_to_check.data_type) in INTEGER_TYPES
|
||||
)
|
||||
|
||||
|
||||
def value_is_encrypted_unsigned_integer(value_to_check: values.BaseValue) -> bool:
|
||||
"""Helper function to check that a value is an encrypted_integer
|
||||
|
||||
Args:
|
||||
value_to_check (values.BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is an encrypted value of type Integer
|
||||
"""
|
||||
|
||||
return (
|
||||
value_is_encrypted_integer(value_to_check)
|
||||
and not cast(integers.Integer, value_to_check.data_type).is_signed
|
||||
)
|
||||
@@ -13,6 +13,10 @@ class Integer(base.BaseDataType):
|
||||
self.bit_width = bit_width
|
||||
self.is_signed = is_signed
|
||||
|
||||
def __repr__(self) -> str:
|
||||
signed_str = "signed" if self.is_signed else "unsigned"
|
||||
return f"{self.__class__.__name__}<{signed_str}, {self.bit_width} bits>"
|
||||
|
||||
def min_value(self) -> int:
|
||||
"""Minimum value representable by the Integer"""
|
||||
if self.is_signed:
|
||||
|
||||
25
hdk/common/data_types/values.py
Normal file
25
hdk/common/data_types/values.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""File holding classes representing values used by an FHE program"""
|
||||
|
||||
from abc import ABC
|
||||
|
||||
from . import base
|
||||
|
||||
|
||||
class BaseValue(ABC):
|
||||
"""Abstract base class to represent any kind of value in a program"""
|
||||
|
||||
data_type: base.BaseDataType
|
||||
|
||||
def __init__(self, data_type: base.BaseDataType) -> None:
|
||||
self.data_type = data_type
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}<{self.data_type!r}>"
|
||||
|
||||
|
||||
class ClearValue(BaseValue):
|
||||
"""Class representing a clear/plaintext value (constant or not)"""
|
||||
|
||||
|
||||
class EncryptedValue(BaseValue):
|
||||
"""Class representing an encrypted value (constant or not)"""
|
||||
55
tests/common/data_types/test_helpers.py
Normal file
55
tests/common/data_types/test_helpers.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Test file for HDK's common/data_types/helpers.py"""
|
||||
|
||||
import pytest
|
||||
|
||||
from hdk.common.data_types.helpers import (
|
||||
value_is_encrypted_integer,
|
||||
value_is_encrypted_unsigned_integer,
|
||||
)
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import ClearValue, EncryptedValue
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value,expected_result",
|
||||
[
|
||||
pytest.param(
|
||||
ClearValue(Integer(8, is_signed=False)),
|
||||
False,
|
||||
id="ClearValue 8 bits unsigned Integer",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedValue(Integer(8, is_signed=True)),
|
||||
True,
|
||||
id="EncryptedValue 8 bits signed Integer",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_value_is_encrypted_integer(value: Integer, expected_result: bool):
|
||||
"""Test value_is_encrypted_integer helper"""
|
||||
assert value_is_encrypted_integer(value) == expected_result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value,expected_result",
|
||||
[
|
||||
pytest.param(
|
||||
ClearValue(Integer(8, is_signed=False)),
|
||||
False,
|
||||
id="ClearValue 8 bits unsigned Integer",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedValue(Integer(8, is_signed=True)),
|
||||
False,
|
||||
id="EncryptedValue 8 bits signed Integer",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedValue(Integer(8, is_signed=False)),
|
||||
True,
|
||||
id="EncryptedValue 8 bits unsigned Integer",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_value_is_encrypted_unsigned_integer(value: Integer, expected_result: bool):
|
||||
"""Test value_is_encrypted_unsigned_integer helper"""
|
||||
assert value_is_encrypted_unsigned_integer(value) == expected_result
|
||||
@@ -38,3 +38,33 @@ def test_basic_integers(integer: Integer, expected_min: int, expected_max: int):
|
||||
assert integer.can_represent_value(random.randint(expected_min, expected_max))
|
||||
assert not integer.can_represent_value(expected_min - 1)
|
||||
assert not integer.can_represent_value(expected_max + 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"integer,expected_repr_str",
|
||||
[
|
||||
pytest.param(
|
||||
Integer(8, is_signed=False),
|
||||
"Integer<unsigned, 8 bits>",
|
||||
id="8 bits unsigned Integer",
|
||||
),
|
||||
pytest.param(
|
||||
Integer(8, is_signed=True),
|
||||
"Integer<signed, 8 bits>",
|
||||
id="8 bits signed Integer",
|
||||
),
|
||||
pytest.param(
|
||||
Integer(32, is_signed=False),
|
||||
"Integer<unsigned, 32 bits>",
|
||||
id="32 bits unsigned Integer",
|
||||
),
|
||||
pytest.param(
|
||||
Integer(32, is_signed=True),
|
||||
"Integer<signed, 32 bits>",
|
||||
id="32 bits signed Integer",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_integers_repr(integer: Integer, expected_repr_str: str):
|
||||
"""Test integer repr"""
|
||||
assert integer.__repr__() == expected_repr_str
|
||||
|
||||
26
tests/common/data_types/test_values.py
Normal file
26
tests/common/data_types/test_values.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Test file for HDK's common/data_types/values.py"""
|
||||
|
||||
import pytest
|
||||
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import BaseValue, ClearValue, EncryptedValue
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value,expected_repr_str",
|
||||
[
|
||||
pytest.param(
|
||||
ClearValue(Integer(8, is_signed=False)),
|
||||
"ClearValue<Integer<unsigned, 8 bits>>",
|
||||
id="ClearValue 8 bits unsigned Integer",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedValue(Integer(8, is_signed=True)),
|
||||
"EncryptedValue<Integer<signed, 8 bits>>",
|
||||
id="EncryptedValue 8 bits signed Integer",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_values_repr(value: BaseValue, expected_repr_str: str):
|
||||
"""Test value repr"""
|
||||
assert value.__repr__() == expected_repr_str
|
||||
Reference in New Issue
Block a user