mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
dev(floats): update function to mix types to support floats
This commit is contained in:
@@ -1,12 +1,16 @@
|
||||
"""File to hold helper functions for data types related stuff"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import cast
|
||||
|
||||
from .base import BaseDataType
|
||||
from .floats import Float
|
||||
from .integers import Integer
|
||||
from .values import BaseValue, ClearValue, EncryptedValue
|
||||
|
||||
INTEGER_TYPES = (Integer,)
|
||||
FLOAT_TYPES = (Float,)
|
||||
SUPPORTED_TYPES = INTEGER_TYPES + FLOAT_TYPES
|
||||
|
||||
|
||||
def value_is_encrypted_integer(value_to_check: BaseValue) -> bool:
|
||||
@@ -57,36 +61,44 @@ def find_type_to_hold_both_lossy(
|
||||
Returns:
|
||||
BaseDataType: The dtype able to hold (potentially lossy) dtype1 and dtype2
|
||||
"""
|
||||
assert isinstance(dtype1, SUPPORTED_TYPES), f"Unsupported dtype1: {type(dtype1)}"
|
||||
assert isinstance(dtype2, SUPPORTED_TYPES), f"Unsupported dtype2: {type(dtype2)}"
|
||||
|
||||
type_to_return: BaseDataType
|
||||
|
||||
if isinstance(dtype1, Integer) and isinstance(dtype2, Integer):
|
||||
d1_signed = dtype1.is_signed
|
||||
d2_signed = dtype2.is_signed
|
||||
max_bits = max(dtype1.bit_width, dtype2.bit_width)
|
||||
|
||||
holding_integer: BaseDataType
|
||||
|
||||
if d1_signed and d2_signed:
|
||||
holding_integer = Integer(max_bits, is_signed=True)
|
||||
type_to_return = Integer(max_bits, is_signed=True)
|
||||
elif not d1_signed and not d2_signed:
|
||||
holding_integer = Integer(max_bits, is_signed=False)
|
||||
type_to_return = Integer(max_bits, is_signed=False)
|
||||
elif d1_signed and not d2_signed:
|
||||
# 2 is unsigned, if it has the bigger bit_width, we need a signed integer that can hold
|
||||
# it, so add 1 bit of sign to its bit_width
|
||||
if dtype2.bit_width >= dtype1.bit_width:
|
||||
new_bit_width = dtype2.bit_width + 1
|
||||
holding_integer = Integer(new_bit_width, is_signed=True)
|
||||
type_to_return = Integer(new_bit_width, is_signed=True)
|
||||
else:
|
||||
holding_integer = Integer(dtype1.bit_width, is_signed=True)
|
||||
type_to_return = Integer(dtype1.bit_width, is_signed=True)
|
||||
elif not d1_signed and d2_signed:
|
||||
# Same as above, with 1 and 2 switched around
|
||||
if dtype1.bit_width >= dtype2.bit_width:
|
||||
new_bit_width = dtype1.bit_width + 1
|
||||
holding_integer = Integer(new_bit_width, is_signed=True)
|
||||
type_to_return = Integer(new_bit_width, is_signed=True)
|
||||
else:
|
||||
holding_integer = Integer(dtype2.bit_width, is_signed=True)
|
||||
type_to_return = Integer(dtype2.bit_width, is_signed=True)
|
||||
elif isinstance(dtype1, Float) and isinstance(dtype2, Float):
|
||||
max_bits = max(dtype1.bit_width, dtype2.bit_width)
|
||||
type_to_return = Float(max_bits)
|
||||
elif isinstance(dtype1, Float):
|
||||
type_to_return = deepcopy(dtype1)
|
||||
elif isinstance(dtype2, Float):
|
||||
type_to_return = deepcopy(dtype2)
|
||||
|
||||
return holding_integer
|
||||
|
||||
raise NotImplementedError("For now only Integers are supported by find_type_to_hold_both_lossy")
|
||||
return type_to_return
|
||||
|
||||
|
||||
def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> BaseValue:
|
||||
|
||||
@@ -9,6 +9,7 @@ from hdk.common.data_types.dtypes_helpers import (
|
||||
value_is_encrypted_integer,
|
||||
value_is_encrypted_unsigned_integer,
|
||||
)
|
||||
from hdk.common.data_types.floats import Float
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import BaseValue, ClearValue, EncryptedValue
|
||||
|
||||
@@ -73,6 +74,16 @@ class UnsupportedDataType(BaseDataType):
|
||||
pytest.param(Integer(6, False), Integer(6, True), Integer(7, True), id="uint6, int6, int7"),
|
||||
pytest.param(Integer(6, True), Integer(5, False), Integer(6, True), id="int6, uint5, int6"),
|
||||
pytest.param(Integer(5, False), Integer(6, True), Integer(6, True), id="uint5, int6, int6"),
|
||||
pytest.param(Integer(32, True), Float(32), Float(32), id="int32, float32, float32"),
|
||||
pytest.param(Integer(64, True), Float(32), Float(32), id="int64, float32, float32"),
|
||||
pytest.param(Integer(64, True), Float(64), Float(64), id="int64, float64, float64"),
|
||||
pytest.param(Integer(32, True), Float(64), Float(64), id="int32, float64, float64"),
|
||||
pytest.param(Float(64), Integer(32, True), Float(64), id="float64, int32, float64"),
|
||||
pytest.param(Float(64), Integer(7, False), Float(64), id="float64, uint7, float64"),
|
||||
pytest.param(Float(32), Float(32), Float(32), id="float32, float32, float32"),
|
||||
pytest.param(Float(32), Float(64), Float(64), id="float32, float64, float64"),
|
||||
pytest.param(Float(64), Float(32), Float(64), id="float64, float32, float64"),
|
||||
pytest.param(Float(64), Float(64), Float(64), id="float64, float64, float64"),
|
||||
pytest.param(
|
||||
UnsupportedDataType(),
|
||||
UnsupportedDataType(),
|
||||
@@ -94,6 +105,13 @@ class UnsupportedDataType(BaseDataType):
|
||||
id="unsupported, int6, xfail",
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
pytest.param(
|
||||
UnsupportedDataType(),
|
||||
Float(32),
|
||||
None,
|
||||
id="unsupported, float32, xfail",
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mix_data_types(
|
||||
@@ -132,6 +150,18 @@ def test_mix_data_types(
|
||||
ClearValue(Integer(7, False)),
|
||||
id="cuint7, cuint7, cuint7",
|
||||
),
|
||||
pytest.param(
|
||||
ClearValue(Float(32)),
|
||||
ClearValue(Float(32)),
|
||||
ClearValue(Float(32)),
|
||||
id="cfloat32, cfloat32, cfloat32",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedValue(Float(32)),
|
||||
ClearValue(Float(32)),
|
||||
EncryptedValue(Float(32)),
|
||||
id="efloat32, cfloat32, efloat32",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mix_values(value1: BaseValue, value2: BaseValue, expected_mixed_value: BaseValue):
|
||||
|
||||
Reference in New Issue
Block a user