dev(floats): update function to mix types to support floats

This commit is contained in:
Arthur Meyre
2021-08-05 15:32:57 +02:00
parent 0c275c5f43
commit c9c7acf616
2 changed files with 53 additions and 11 deletions

View File

@@ -1,12 +1,16 @@
"""File to hold helper functions for data types related stuff"""
from copy import deepcopy
from typing import cast
from .base import BaseDataType
from .floats import Float
from .integers import Integer
from .values import BaseValue, ClearValue, EncryptedValue
INTEGER_TYPES = (Integer,)
FLOAT_TYPES = (Float,)
SUPPORTED_TYPES = INTEGER_TYPES + FLOAT_TYPES
def value_is_encrypted_integer(value_to_check: BaseValue) -> bool:
@@ -57,36 +61,44 @@ def find_type_to_hold_both_lossy(
Returns:
BaseDataType: The dtype able to hold (potentially lossy) dtype1 and dtype2
"""
assert isinstance(dtype1, SUPPORTED_TYPES), f"Unsupported dtype1: {type(dtype1)}"
assert isinstance(dtype2, SUPPORTED_TYPES), f"Unsupported dtype2: {type(dtype2)}"
type_to_return: BaseDataType
if isinstance(dtype1, Integer) and isinstance(dtype2, Integer):
d1_signed = dtype1.is_signed
d2_signed = dtype2.is_signed
max_bits = max(dtype1.bit_width, dtype2.bit_width)
holding_integer: BaseDataType
if d1_signed and d2_signed:
holding_integer = Integer(max_bits, is_signed=True)
type_to_return = Integer(max_bits, is_signed=True)
elif not d1_signed and not d2_signed:
holding_integer = Integer(max_bits, is_signed=False)
type_to_return = Integer(max_bits, is_signed=False)
elif d1_signed and not d2_signed:
# 2 is unsigned, if it has the bigger bit_width, we need a signed integer that can hold
# it, so add 1 bit of sign to its bit_width
if dtype2.bit_width >= dtype1.bit_width:
new_bit_width = dtype2.bit_width + 1
holding_integer = Integer(new_bit_width, is_signed=True)
type_to_return = Integer(new_bit_width, is_signed=True)
else:
holding_integer = Integer(dtype1.bit_width, is_signed=True)
type_to_return = Integer(dtype1.bit_width, is_signed=True)
elif not d1_signed and d2_signed:
# Same as above, with 1 and 2 switched around
if dtype1.bit_width >= dtype2.bit_width:
new_bit_width = dtype1.bit_width + 1
holding_integer = Integer(new_bit_width, is_signed=True)
type_to_return = Integer(new_bit_width, is_signed=True)
else:
holding_integer = Integer(dtype2.bit_width, is_signed=True)
type_to_return = Integer(dtype2.bit_width, is_signed=True)
elif isinstance(dtype1, Float) and isinstance(dtype2, Float):
max_bits = max(dtype1.bit_width, dtype2.bit_width)
type_to_return = Float(max_bits)
elif isinstance(dtype1, Float):
type_to_return = deepcopy(dtype1)
elif isinstance(dtype2, Float):
type_to_return = deepcopy(dtype2)
return holding_integer
raise NotImplementedError("For now only Integers are supported by find_type_to_hold_both_lossy")
return type_to_return
def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> BaseValue:

View File

@@ -9,6 +9,7 @@ from hdk.common.data_types.dtypes_helpers import (
value_is_encrypted_integer,
value_is_encrypted_unsigned_integer,
)
from hdk.common.data_types.floats import Float
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import BaseValue, ClearValue, EncryptedValue
@@ -73,6 +74,16 @@ class UnsupportedDataType(BaseDataType):
pytest.param(Integer(6, False), Integer(6, True), Integer(7, True), id="uint6, int6, int7"),
pytest.param(Integer(6, True), Integer(5, False), Integer(6, True), id="int6, uint5, int6"),
pytest.param(Integer(5, False), Integer(6, True), Integer(6, True), id="uint5, int6, int6"),
pytest.param(Integer(32, True), Float(32), Float(32), id="int32, float32, float32"),
pytest.param(Integer(64, True), Float(32), Float(32), id="int64, float32, float32"),
pytest.param(Integer(64, True), Float(64), Float(64), id="int64, float64, float64"),
pytest.param(Integer(32, True), Float(64), Float(64), id="int32, float64, float64"),
pytest.param(Float(64), Integer(32, True), Float(64), id="float64, int32, float64"),
pytest.param(Float(64), Integer(7, False), Float(64), id="float64, uint7, float64"),
pytest.param(Float(32), Float(32), Float(32), id="float32, float32, float32"),
pytest.param(Float(32), Float(64), Float(64), id="float32, float64, float64"),
pytest.param(Float(64), Float(32), Float(64), id="float64, float32, float64"),
pytest.param(Float(64), Float(64), Float(64), id="float64, float64, float64"),
pytest.param(
UnsupportedDataType(),
UnsupportedDataType(),
@@ -94,6 +105,13 @@ class UnsupportedDataType(BaseDataType):
id="unsupported, int6, xfail",
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
UnsupportedDataType(),
Float(32),
None,
id="unsupported, float32, xfail",
marks=pytest.mark.xfail(strict=True),
),
],
)
def test_mix_data_types(
@@ -132,6 +150,18 @@ def test_mix_data_types(
ClearValue(Integer(7, False)),
id="cuint7, cuint7, cuint7",
),
pytest.param(
ClearValue(Float(32)),
ClearValue(Float(32)),
ClearValue(Float(32)),
id="cfloat32, cfloat32, cfloat32",
),
pytest.param(
EncryptedValue(Float(32)),
ClearValue(Float(32)),
EncryptedValue(Float(32)),
id="efloat32, cfloat32, efloat32",
),
],
)
def test_mix_values(value1: BaseValue, value2: BaseValue, expected_mixed_value: BaseValue):