mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
dev(floats): update function to mix types to support floats
This commit is contained in:
@@ -9,6 +9,7 @@ from hdk.common.data_types.dtypes_helpers import (
|
||||
value_is_encrypted_integer,
|
||||
value_is_encrypted_unsigned_integer,
|
||||
)
|
||||
from hdk.common.data_types.floats import Float
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import BaseValue, ClearValue, EncryptedValue
|
||||
|
||||
@@ -73,6 +74,16 @@ class UnsupportedDataType(BaseDataType):
|
||||
pytest.param(Integer(6, False), Integer(6, True), Integer(7, True), id="uint6, int6, int7"),
|
||||
pytest.param(Integer(6, True), Integer(5, False), Integer(6, True), id="int6, uint5, int6"),
|
||||
pytest.param(Integer(5, False), Integer(6, True), Integer(6, True), id="uint5, int6, int6"),
|
||||
pytest.param(Integer(32, True), Float(32), Float(32), id="int32, float32, float32"),
|
||||
pytest.param(Integer(64, True), Float(32), Float(32), id="int64, float32, float32"),
|
||||
pytest.param(Integer(64, True), Float(64), Float(64), id="int64, float64, float64"),
|
||||
pytest.param(Integer(32, True), Float(64), Float(64), id="int32, float64, float64"),
|
||||
pytest.param(Float(64), Integer(32, True), Float(64), id="float64, int32, float64"),
|
||||
pytest.param(Float(64), Integer(7, False), Float(64), id="float64, uint7, float64"),
|
||||
pytest.param(Float(32), Float(32), Float(32), id="float32, float32, float32"),
|
||||
pytest.param(Float(32), Float(64), Float(64), id="float32, float64, float64"),
|
||||
pytest.param(Float(64), Float(32), Float(64), id="float64, float32, float64"),
|
||||
pytest.param(Float(64), Float(64), Float(64), id="float64, float64, float64"),
|
||||
pytest.param(
|
||||
UnsupportedDataType(),
|
||||
UnsupportedDataType(),
|
||||
@@ -94,6 +105,13 @@ class UnsupportedDataType(BaseDataType):
|
||||
id="unsupported, int6, xfail",
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
pytest.param(
|
||||
UnsupportedDataType(),
|
||||
Float(32),
|
||||
None,
|
||||
id="unsupported, float32, xfail",
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mix_data_types(
|
||||
@@ -132,6 +150,18 @@ def test_mix_data_types(
|
||||
ClearValue(Integer(7, False)),
|
||||
id="cuint7, cuint7, cuint7",
|
||||
),
|
||||
pytest.param(
|
||||
ClearValue(Float(32)),
|
||||
ClearValue(Float(32)),
|
||||
ClearValue(Float(32)),
|
||||
id="cfloat32, cfloat32, cfloat32",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedValue(Float(32)),
|
||||
ClearValue(Float(32)),
|
||||
EncryptedValue(Float(32)),
|
||||
id="efloat32, cfloat32, efloat32",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mix_values(value1: BaseValue, value2: BaseValue, expected_mixed_value: BaseValue):
|
||||
|
||||
Reference in New Issue
Block a user