dev(dtypes): add functions to mix values and find dtype to hold two inputs

- allows to have a generic way of propagating types instead of deepcopying
input values and types in IR nodes
- supports only Integers for now
- opset checks will not be performed in those functions to keep knowledge
required on the opset to the MLIR conversion step or extra check steps
- use the mix_values_determine_holding_dtype in intermediate nodes where
appropriate
This commit is contained in:
Arthur Meyre
2021-07-26 14:31:49 +02:00
parent a56a0dbf0c
commit d7c1f42363
3 changed files with 173 additions and 11 deletions

View File

@@ -2,31 +2,33 @@
from typing import cast
from . import integers, values
from .base import BaseDataType
from .integers import Integer
from .values import BaseValue, ClearValue, EncryptedValue
INTEGER_TYPES = set([integers.Integer])
INTEGER_TYPES = set([Integer])
def value_is_encrypted_integer(value_to_check: values.BaseValue) -> bool:
def value_is_encrypted_integer(value_to_check: BaseValue) -> bool:
"""Helper function to check that a value is an encrypted_integer
Args:
value_to_check (values.BaseValue): The value to check
value_to_check (BaseValue): The value to check
Returns:
bool: True if the passed value_to_check is an encrypted value of type Integer
"""
return (
isinstance(value_to_check, values.EncryptedValue)
isinstance(value_to_check, EncryptedValue)
and type(value_to_check.data_type) in INTEGER_TYPES
)
def value_is_encrypted_unsigned_integer(value_to_check: values.BaseValue) -> bool:
def value_is_encrypted_unsigned_integer(value_to_check: BaseValue) -> bool:
"""Helper function to check that a value is an encrypted_integer
Args:
value_to_check (values.BaseValue): The value to check
value_to_check (BaseValue): The value to check
Returns:
bool: True if the passed value_to_check is an encrypted value of type Integer
@@ -34,5 +36,81 @@ def value_is_encrypted_unsigned_integer(value_to_check: values.BaseValue) -> boo
return (
value_is_encrypted_integer(value_to_check)
and not cast(integers.Integer, value_to_check.data_type).is_signed
and not cast(Integer, value_to_check.data_type).is_signed
)
def find_type_to_hold_both_lossy(
dtype1: BaseDataType,
dtype2: BaseDataType,
) -> BaseDataType:
"""Determine the type that can represent both dtype1 and dtype2 separately, this is lossy with
floating point types
Args:
dtype1 (BaseDataType): first dtype to hold
dtype2 (BaseDataType): second dtype to hold
Raises:
NotImplementedError: Raised if one of the two input dtypes is not an Integer as they are the
only type supported for now
Returns:
BaseDataType: The dtype able to hold (potentially lossy) dtype1 and dtype2
"""
if isinstance(dtype1, Integer) and isinstance(dtype2, Integer):
d1_signed = dtype1.is_signed
d2_signed = dtype2.is_signed
max_bits = max(dtype1.bit_width, dtype2.bit_width)
holding_integer: BaseDataType
if d1_signed and d2_signed:
holding_integer = Integer(max_bits, is_signed=True)
elif not d1_signed and not d2_signed:
holding_integer = Integer(max_bits, is_signed=False)
elif d1_signed and not d2_signed:
# 2 is unsigned, if it has the bigger bit_width, we need a signed integer that can hold
# it, so add 1 bit of sign to its bit_width
if dtype2.bit_width >= dtype1.bit_width:
new_bit_width = dtype2.bit_width + 1
holding_integer = Integer(new_bit_width, is_signed=True)
else:
holding_integer = Integer(dtype1.bit_width, is_signed=True)
elif not d1_signed and d2_signed:
# Same as above, with 1 and 2 switched around
if dtype1.bit_width >= dtype2.bit_width:
new_bit_width = dtype1.bit_width + 1
holding_integer = Integer(new_bit_width, is_signed=True)
else:
holding_integer = Integer(dtype2.bit_width, is_signed=True)
return holding_integer
raise NotImplementedError("For now only Integers are supported by find_type_to_hold_both_lossy")
def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> BaseValue:
"""Returns a Value that would result from computation on both value1 and value2 while
determining the data type able to hold both value1 and value2 data type (this can be lossy
with floats)
Args:
value1 (BaseValue): first value to mix
value2 (BaseValue): second value to mix
Returns:
BaseValue: The resulting mixed value with data type able to hold both value1 and value2
dtypes
"""
holding_type = find_type_to_hold_both_lossy(value1.data_type, value2.data_type)
mixed_value: BaseValue
if isinstance(value1, EncryptedValue) or isinstance(value2, EncryptedValue):
mixed_value = EncryptedValue(holding_type)
else:
mixed_value = ClearValue(holding_type)
return mixed_value

View File

@@ -5,6 +5,7 @@ from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Tuple
from ..data_types import BaseValue
from ..data_types.dtypes_helpers import mix_values_determine_holding_dtype
class IntermediateNode(ABC):
@@ -39,9 +40,7 @@ class IntermediateNode(ABC):
assert len(self.inputs) == 2
# For now copy the first input type for the output type
# We don't perform checks or enforce consistency here for now, so this is OK
self.outputs = [deepcopy(self.inputs[0])]
self.outputs = [mix_values_determine_holding_dtype(self.inputs[0], self.inputs[1])]
def _is_equivalent_to_binary_commutative(self, other: object) -> bool:
return (

View File

@@ -2,7 +2,10 @@
import pytest
from hdk.common.data_types.base import BaseDataType
from hdk.common.data_types.dtypes_helpers import (
find_type_to_hold_both_lossy,
mix_values_determine_holding_dtype,
value_is_encrypted_integer,
value_is_encrypted_unsigned_integer,
)
@@ -53,3 +56,85 @@ def test_value_is_encrypted_integer(value: BaseValue, expected_result: bool):
def test_value_is_encrypted_unsigned_integer(value: BaseValue, expected_result: bool):
"""Test value_is_encrypted_unsigned_integer helper"""
assert value_is_encrypted_unsigned_integer(value) == expected_result
class UnsupportedDataType(BaseDataType):
"""Test helper class to represent an UnsupportedDataType"""
@pytest.mark.parametrize(
"dtype1,dtype2,expected_mixed_dtype",
[
pytest.param(Integer(6, True), Integer(6, True), Integer(6, True), id="int6, int6, int6"),
pytest.param(
Integer(6, False), Integer(6, False), Integer(6, False), id="uint6, uint6, uint6"
),
pytest.param(Integer(6, True), Integer(6, False), Integer(7, True), id="int6, uint6, int7"),
pytest.param(Integer(6, False), Integer(6, True), Integer(7, True), id="uint6, int6, int7"),
pytest.param(Integer(6, True), Integer(5, False), Integer(6, True), id="int6, uint5, int6"),
pytest.param(Integer(5, False), Integer(6, True), Integer(6, True), id="uint5, int6, int6"),
pytest.param(
UnsupportedDataType(),
UnsupportedDataType(),
None,
id="unsupported, unsupported, xfail",
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
Integer(6, True),
UnsupportedDataType(),
None,
id="int6, unsupported, xfail",
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
UnsupportedDataType(),
Integer(6, True),
None,
id="unsupported, int6, xfail",
marks=pytest.mark.xfail(strict=True),
),
],
)
def test_mix_data_types(
dtype1: BaseDataType,
dtype2: BaseDataType,
expected_mixed_dtype: BaseDataType,
):
"""Test find_type_to_hold_both_lossy helper"""
assert expected_mixed_dtype == find_type_to_hold_both_lossy(dtype1, dtype2)
@pytest.mark.parametrize(
"value1,value2,expected_mixed_value",
[
pytest.param(
EncryptedValue(Integer(7, False)),
EncryptedValue(Integer(7, False)),
EncryptedValue(Integer(7, False)),
id="euint7, euint7, euint7",
),
pytest.param(
EncryptedValue(Integer(7, False)),
ClearValue(Integer(7, False)),
EncryptedValue(Integer(7, False)),
id="euint7, cuint7, euint7",
),
pytest.param(
ClearValue(Integer(7, False)),
EncryptedValue(Integer(7, False)),
EncryptedValue(Integer(7, False)),
id="cuint7, euint7, euint7",
),
pytest.param(
ClearValue(Integer(7, False)),
ClearValue(Integer(7, False)),
ClearValue(Integer(7, False)),
id="cuint7, cuint7, cuint7",
),
],
)
def test_mix_values(value1: BaseValue, value2: BaseValue, expected_mixed_value: BaseValue):
"""Test mix_values helper"""
assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2)