From c9c7acf616684d04aa74a4cd0de1d2f01acf4296 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Thu, 5 Aug 2021 15:32:57 +0200 Subject: [PATCH] dev(floats): update function to mix types to support floats --- hdk/common/data_types/dtypes_helpers.py | 34 +++++++++++++------ .../common/data_types/test_dtypes_helpers.py | 30 ++++++++++++++++ 2 files changed, 53 insertions(+), 11 deletions(-) diff --git a/hdk/common/data_types/dtypes_helpers.py b/hdk/common/data_types/dtypes_helpers.py index a7036b950..1c93737e6 100644 --- a/hdk/common/data_types/dtypes_helpers.py +++ b/hdk/common/data_types/dtypes_helpers.py @@ -1,12 +1,16 @@ """File to hold helper functions for data types related stuff""" +from copy import deepcopy from typing import cast from .base import BaseDataType +from .floats import Float from .integers import Integer from .values import BaseValue, ClearValue, EncryptedValue INTEGER_TYPES = (Integer,) +FLOAT_TYPES = (Float,) +SUPPORTED_TYPES = INTEGER_TYPES + FLOAT_TYPES def value_is_encrypted_integer(value_to_check: BaseValue) -> bool: @@ -57,36 +61,44 @@ def find_type_to_hold_both_lossy( Returns: BaseDataType: The dtype able to hold (potentially lossy) dtype1 and dtype2 """ + assert isinstance(dtype1, SUPPORTED_TYPES), f"Unsupported dtype1: {type(dtype1)}" + assert isinstance(dtype2, SUPPORTED_TYPES), f"Unsupported dtype2: {type(dtype2)}" + + type_to_return: BaseDataType + if isinstance(dtype1, Integer) and isinstance(dtype2, Integer): d1_signed = dtype1.is_signed d2_signed = dtype2.is_signed max_bits = max(dtype1.bit_width, dtype2.bit_width) - holding_integer: BaseDataType - if d1_signed and d2_signed: - holding_integer = Integer(max_bits, is_signed=True) + type_to_return = Integer(max_bits, is_signed=True) elif not d1_signed and not d2_signed: - holding_integer = Integer(max_bits, is_signed=False) + type_to_return = Integer(max_bits, is_signed=False) elif d1_signed and not d2_signed: # 2 is unsigned, if it has the bigger bit_width, we need a signed integer that can hold # it, so add 1 bit of sign to its bit_width if dtype2.bit_width >= dtype1.bit_width: new_bit_width = dtype2.bit_width + 1 - holding_integer = Integer(new_bit_width, is_signed=True) + type_to_return = Integer(new_bit_width, is_signed=True) else: - holding_integer = Integer(dtype1.bit_width, is_signed=True) + type_to_return = Integer(dtype1.bit_width, is_signed=True) elif not d1_signed and d2_signed: # Same as above, with 1 and 2 switched around if dtype1.bit_width >= dtype2.bit_width: new_bit_width = dtype1.bit_width + 1 - holding_integer = Integer(new_bit_width, is_signed=True) + type_to_return = Integer(new_bit_width, is_signed=True) else: - holding_integer = Integer(dtype2.bit_width, is_signed=True) + type_to_return = Integer(dtype2.bit_width, is_signed=True) + elif isinstance(dtype1, Float) and isinstance(dtype2, Float): + max_bits = max(dtype1.bit_width, dtype2.bit_width) + type_to_return = Float(max_bits) + elif isinstance(dtype1, Float): + type_to_return = deepcopy(dtype1) + elif isinstance(dtype2, Float): + type_to_return = deepcopy(dtype2) - return holding_integer - - raise NotImplementedError("For now only Integers are supported by find_type_to_hold_both_lossy") + return type_to_return def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> BaseValue: diff --git a/tests/common/data_types/test_dtypes_helpers.py b/tests/common/data_types/test_dtypes_helpers.py index d9fc88099..c66fbae88 100644 --- a/tests/common/data_types/test_dtypes_helpers.py +++ b/tests/common/data_types/test_dtypes_helpers.py @@ -9,6 +9,7 @@ from hdk.common.data_types.dtypes_helpers import ( value_is_encrypted_integer, value_is_encrypted_unsigned_integer, ) +from hdk.common.data_types.floats import Float from hdk.common.data_types.integers import Integer from hdk.common.data_types.values import BaseValue, ClearValue, EncryptedValue @@ -73,6 +74,16 @@ class UnsupportedDataType(BaseDataType): pytest.param(Integer(6, False), Integer(6, True), Integer(7, True), id="uint6, int6, int7"), pytest.param(Integer(6, True), Integer(5, False), Integer(6, True), id="int6, uint5, int6"), pytest.param(Integer(5, False), Integer(6, True), Integer(6, True), id="uint5, int6, int6"), + pytest.param(Integer(32, True), Float(32), Float(32), id="int32, float32, float32"), + pytest.param(Integer(64, True), Float(32), Float(32), id="int64, float32, float32"), + pytest.param(Integer(64, True), Float(64), Float(64), id="int64, float64, float64"), + pytest.param(Integer(32, True), Float(64), Float(64), id="int32, float64, float64"), + pytest.param(Float(64), Integer(32, True), Float(64), id="float64, int32, float64"), + pytest.param(Float(64), Integer(7, False), Float(64), id="float64, uint7, float64"), + pytest.param(Float(32), Float(32), Float(32), id="float32, float32, float32"), + pytest.param(Float(32), Float(64), Float(64), id="float32, float64, float64"), + pytest.param(Float(64), Float(32), Float(64), id="float64, float32, float64"), + pytest.param(Float(64), Float(64), Float(64), id="float64, float64, float64"), pytest.param( UnsupportedDataType(), UnsupportedDataType(), @@ -94,6 +105,13 @@ class UnsupportedDataType(BaseDataType): id="unsupported, int6, xfail", marks=pytest.mark.xfail(strict=True), ), + pytest.param( + UnsupportedDataType(), + Float(32), + None, + id="unsupported, float32, xfail", + marks=pytest.mark.xfail(strict=True), + ), ], ) def test_mix_data_types( @@ -132,6 +150,18 @@ def test_mix_data_types( ClearValue(Integer(7, False)), id="cuint7, cuint7, cuint7", ), + pytest.param( + ClearValue(Float(32)), + ClearValue(Float(32)), + ClearValue(Float(32)), + id="cfloat32, cfloat32, cfloat32", + ), + pytest.param( + EncryptedValue(Float(32)), + ClearValue(Float(32)), + EncryptedValue(Float(32)), + id="efloat32, cfloat32, efloat32", + ), ], ) def test_mix_values(value1: BaseValue, value2: BaseValue, expected_mixed_value: BaseValue):