dev(floats): update function to mix types to support floats

This commit is contained in:
Arthur Meyre
2021-08-05 15:32:57 +02:00
parent 0c275c5f43
commit c9c7acf616
2 changed files with 53 additions and 11 deletions

View File

@@ -9,6 +9,7 @@ from hdk.common.data_types.dtypes_helpers import (
value_is_encrypted_integer,
value_is_encrypted_unsigned_integer,
)
from hdk.common.data_types.floats import Float
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import BaseValue, ClearValue, EncryptedValue
@@ -73,6 +74,16 @@ class UnsupportedDataType(BaseDataType):
pytest.param(Integer(6, False), Integer(6, True), Integer(7, True), id="uint6, int6, int7"),
pytest.param(Integer(6, True), Integer(5, False), Integer(6, True), id="int6, uint5, int6"),
pytest.param(Integer(5, False), Integer(6, True), Integer(6, True), id="uint5, int6, int6"),
pytest.param(Integer(32, True), Float(32), Float(32), id="int32, float32, float32"),
pytest.param(Integer(64, True), Float(32), Float(32), id="int64, float32, float32"),
pytest.param(Integer(64, True), Float(64), Float(64), id="int64, float64, float64"),
pytest.param(Integer(32, True), Float(64), Float(64), id="int32, float64, float64"),
pytest.param(Float(64), Integer(32, True), Float(64), id="float64, int32, float64"),
pytest.param(Float(64), Integer(7, False), Float(64), id="float64, uint7, float64"),
pytest.param(Float(32), Float(32), Float(32), id="float32, float32, float32"),
pytest.param(Float(32), Float(64), Float(64), id="float32, float64, float64"),
pytest.param(Float(64), Float(32), Float(64), id="float64, float32, float64"),
pytest.param(Float(64), Float(64), Float(64), id="float64, float64, float64"),
pytest.param(
UnsupportedDataType(),
UnsupportedDataType(),
@@ -94,6 +105,13 @@ class UnsupportedDataType(BaseDataType):
id="unsupported, int6, xfail",
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
UnsupportedDataType(),
Float(32),
None,
id="unsupported, float32, xfail",
marks=pytest.mark.xfail(strict=True),
),
],
)
def test_mix_data_types(
@@ -132,6 +150,18 @@ def test_mix_data_types(
ClearValue(Integer(7, False)),
id="cuint7, cuint7, cuint7",
),
pytest.param(
ClearValue(Float(32)),
ClearValue(Float(32)),
ClearValue(Float(32)),
id="cfloat32, cfloat32, cfloat32",
),
pytest.param(
EncryptedValue(Float(32)),
ClearValue(Float(32)),
EncryptedValue(Float(32)),
id="efloat32, cfloat32, efloat32",
),
],
)
def test_mix_values(value1: BaseValue, value2: BaseValue, expected_mixed_value: BaseValue):