From 1a54bc1f22663537339abf90fe9f6b0000f76e79 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 19 Jul 2021 19:09:15 +0200 Subject: [PATCH] dev(data-types): create Value classes which represent values in a program - value classes have a data_type member to know what they hold - add __repr__ to a few classes to ease readability for debug/print - add helper functions to perform value checks that will be used for tracing to ease readability - add unit tests to get 100% coverage --- hdk/__init__.py | 2 +- hdk/common/data_types/__init__.py | 3 +- hdk/common/data_types/helpers.py | 38 ++++++++++++++++ hdk/common/data_types/integers.py | 4 ++ hdk/common/data_types/values.py | 25 +++++++++++ tests/common/data_types/test_helpers.py | 55 ++++++++++++++++++++++++ tests/common/data_types/test_integers.py | 30 +++++++++++++ tests/common/data_types/test_values.py | 26 +++++++++++ 8 files changed, 181 insertions(+), 2 deletions(-) create mode 100644 hdk/common/data_types/helpers.py create mode 100644 hdk/common/data_types/values.py create mode 100644 tests/common/data_types/test_helpers.py create mode 100644 tests/common/data_types/test_values.py diff --git a/hdk/__init__.py b/hdk/__init__.py index 1df85bfe7..44b5d6169 100644 --- a/hdk/__init__.py +++ b/hdk/__init__.py @@ -1,2 +1,2 @@ """HDK's top import""" -from . import utils +from . import common, utils diff --git a/hdk/common/data_types/__init__.py b/hdk/common/data_types/__init__.py index 256544a3c..1703a0aaf 100644 --- a/hdk/common/data_types/__init__.py +++ b/hdk/common/data_types/__init__.py @@ -1,2 +1,3 @@ """HDK's module for data types code and data structures""" -from . import integers +from . import helpers, integers, values +from .values import BaseValue diff --git a/hdk/common/data_types/helpers.py b/hdk/common/data_types/helpers.py new file mode 100644 index 000000000..892c4be64 --- /dev/null +++ b/hdk/common/data_types/helpers.py @@ -0,0 +1,38 @@ +"""File to hold helper functions for data types related stuff""" + +from typing import cast + +from . import integers, values + +INTEGER_TYPES = set([integers.Integer]) + + +def value_is_encrypted_integer(value_to_check: values.BaseValue) -> bool: + """Helper function to check that a value is an encrypted_integer + + Args: + value_to_check (values.BaseValue): The value to check + + Returns: + bool: True if the passed value_to_check is an encrypted value of type Integer + """ + return ( + isinstance(value_to_check, values.EncryptedValue) + and type(value_to_check.data_type) in INTEGER_TYPES + ) + + +def value_is_encrypted_unsigned_integer(value_to_check: values.BaseValue) -> bool: + """Helper function to check that a value is an encrypted_integer + + Args: + value_to_check (values.BaseValue): The value to check + + Returns: + bool: True if the passed value_to_check is an encrypted value of type Integer + """ + + return ( + value_is_encrypted_integer(value_to_check) + and not cast(integers.Integer, value_to_check.data_type).is_signed + ) diff --git a/hdk/common/data_types/integers.py b/hdk/common/data_types/integers.py index 765b84e8d..d8b431adc 100644 --- a/hdk/common/data_types/integers.py +++ b/hdk/common/data_types/integers.py @@ -13,6 +13,10 @@ class Integer(base.BaseDataType): self.bit_width = bit_width self.is_signed = is_signed + def __repr__(self) -> str: + signed_str = "signed" if self.is_signed else "unsigned" + return f"{self.__class__.__name__}<{signed_str}, {self.bit_width} bits>" + def min_value(self) -> int: """Minimum value representable by the Integer""" if self.is_signed: diff --git a/hdk/common/data_types/values.py b/hdk/common/data_types/values.py new file mode 100644 index 000000000..b00ddf5f5 --- /dev/null +++ b/hdk/common/data_types/values.py @@ -0,0 +1,25 @@ +"""File holding classes representing values used by an FHE program""" + +from abc import ABC + +from . import base + + +class BaseValue(ABC): + """Abstract base class to represent any kind of value in a program""" + + data_type: base.BaseDataType + + def __init__(self, data_type: base.BaseDataType) -> None: + self.data_type = data_type + + def __repr__(self) -> str: + return f"{self.__class__.__name__}<{self.data_type!r}>" + + +class ClearValue(BaseValue): + """Class representing a clear/plaintext value (constant or not)""" + + +class EncryptedValue(BaseValue): + """Class representing an encrypted value (constant or not)""" diff --git a/tests/common/data_types/test_helpers.py b/tests/common/data_types/test_helpers.py new file mode 100644 index 000000000..cb76467eb --- /dev/null +++ b/tests/common/data_types/test_helpers.py @@ -0,0 +1,55 @@ +"""Test file for HDK's common/data_types/helpers.py""" + +import pytest + +from hdk.common.data_types.helpers import ( + value_is_encrypted_integer, + value_is_encrypted_unsigned_integer, +) +from hdk.common.data_types.integers import Integer +from hdk.common.data_types.values import ClearValue, EncryptedValue + + +@pytest.mark.parametrize( + "value,expected_result", + [ + pytest.param( + ClearValue(Integer(8, is_signed=False)), + False, + id="ClearValue 8 bits unsigned Integer", + ), + pytest.param( + EncryptedValue(Integer(8, is_signed=True)), + True, + id="EncryptedValue 8 bits signed Integer", + ), + ], +) +def test_value_is_encrypted_integer(value: Integer, expected_result: bool): + """Test value_is_encrypted_integer helper""" + assert value_is_encrypted_integer(value) == expected_result + + +@pytest.mark.parametrize( + "value,expected_result", + [ + pytest.param( + ClearValue(Integer(8, is_signed=False)), + False, + id="ClearValue 8 bits unsigned Integer", + ), + pytest.param( + EncryptedValue(Integer(8, is_signed=True)), + False, + id="EncryptedValue 8 bits signed Integer", + ), + pytest.param( + EncryptedValue(Integer(8, is_signed=False)), + True, + id="EncryptedValue 8 bits unsigned Integer", + ), + ], +) +def test_value_is_encrypted_unsigned_integer(value: Integer, expected_result: bool): + """Test value_is_encrypted_unsigned_integer helper""" + assert value_is_encrypted_unsigned_integer(value) == expected_result diff --git a/tests/common/data_types/test_integers.py b/tests/common/data_types/test_integers.py index 8e1d12efd..5780cf995 100644 --- a/tests/common/data_types/test_integers.py +++ b/tests/common/data_types/test_integers.py @@ -38,3 +38,33 @@ def test_basic_integers(integer: Integer, expected_min: int, expected_max: int): assert integer.can_represent_value(random.randint(expected_min, expected_max)) assert not integer.can_represent_value(expected_min - 1) assert not integer.can_represent_value(expected_max + 1) + + +@pytest.mark.parametrize( + "integer,expected_repr_str", + [ + pytest.param( + Integer(8, is_signed=False), + "Integer", + id="8 bits unsigned Integer", + ), + pytest.param( + Integer(8, is_signed=True), + "Integer", + id="8 bits signed Integer", + ), + pytest.param( + Integer(32, is_signed=False), + "Integer", + id="32 bits unsigned Integer", + ), + pytest.param( + Integer(32, is_signed=True), + "Integer", + id="32 bits signed Integer", + ), + ], +) +def test_integers_repr(integer: Integer, expected_repr_str: str): + """Test integer repr""" + assert integer.__repr__() == expected_repr_str diff --git a/tests/common/data_types/test_values.py b/tests/common/data_types/test_values.py new file mode 100644 index 000000000..de9803d62 --- /dev/null +++ b/tests/common/data_types/test_values.py @@ -0,0 +1,26 @@ +"""Test file for HDK's common/data_types/values.py""" + +import pytest + +from hdk.common.data_types.integers import Integer +from hdk.common.data_types.values import BaseValue, ClearValue, EncryptedValue + + +@pytest.mark.parametrize( + "value,expected_repr_str", + [ + pytest.param( + ClearValue(Integer(8, is_signed=False)), + "ClearValue>", + id="ClearValue 8 bits unsigned Integer", + ), + pytest.param( + EncryptedValue(Integer(8, is_signed=True)), + "EncryptedValue>", + id="EncryptedValue 8 bits signed Integer", + ), + ], +) +def test_values_repr(value: BaseValue, expected_repr_str: str): + """Test value repr""" + assert value.__repr__() == expected_repr_str