mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
dev(dtypes): add functions to mix values and find dtype to hold two inputs
- allows to have a generic way of propagating types instead of deepcopying input values and types in IR nodes - supports only Integers for now - opset checks will not be performed in those functions to keep knowledge required on the opset to the MLIR conversion step or extra check steps - use the mix_values_determine_holding_dtype in intermediate nodes where appropriate
This commit is contained in:
@@ -2,31 +2,33 @@
|
||||
|
||||
from typing import cast
|
||||
|
||||
from . import integers, values
|
||||
from .base import BaseDataType
|
||||
from .integers import Integer
|
||||
from .values import BaseValue, ClearValue, EncryptedValue
|
||||
|
||||
INTEGER_TYPES = set([integers.Integer])
|
||||
INTEGER_TYPES = set([Integer])
|
||||
|
||||
|
||||
def value_is_encrypted_integer(value_to_check: values.BaseValue) -> bool:
|
||||
def value_is_encrypted_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Helper function to check that a value is an encrypted_integer
|
||||
|
||||
Args:
|
||||
value_to_check (values.BaseValue): The value to check
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is an encrypted value of type Integer
|
||||
"""
|
||||
return (
|
||||
isinstance(value_to_check, values.EncryptedValue)
|
||||
isinstance(value_to_check, EncryptedValue)
|
||||
and type(value_to_check.data_type) in INTEGER_TYPES
|
||||
)
|
||||
|
||||
|
||||
def value_is_encrypted_unsigned_integer(value_to_check: values.BaseValue) -> bool:
|
||||
def value_is_encrypted_unsigned_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Helper function to check that a value is an encrypted_integer
|
||||
|
||||
Args:
|
||||
value_to_check (values.BaseValue): The value to check
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is an encrypted value of type Integer
|
||||
@@ -34,5 +36,81 @@ def value_is_encrypted_unsigned_integer(value_to_check: values.BaseValue) -> boo
|
||||
|
||||
return (
|
||||
value_is_encrypted_integer(value_to_check)
|
||||
and not cast(integers.Integer, value_to_check.data_type).is_signed
|
||||
and not cast(Integer, value_to_check.data_type).is_signed
|
||||
)
|
||||
|
||||
|
||||
def find_type_to_hold_both_lossy(
|
||||
dtype1: BaseDataType,
|
||||
dtype2: BaseDataType,
|
||||
) -> BaseDataType:
|
||||
"""Determine the type that can represent both dtype1 and dtype2 separately, this is lossy with
|
||||
floating point types
|
||||
|
||||
Args:
|
||||
dtype1 (BaseDataType): first dtype to hold
|
||||
dtype2 (BaseDataType): second dtype to hold
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Raised if one of the two input dtypes is not an Integer as they are the
|
||||
only type supported for now
|
||||
|
||||
Returns:
|
||||
BaseDataType: The dtype able to hold (potentially lossy) dtype1 and dtype2
|
||||
"""
|
||||
if isinstance(dtype1, Integer) and isinstance(dtype2, Integer):
|
||||
d1_signed = dtype1.is_signed
|
||||
d2_signed = dtype2.is_signed
|
||||
max_bits = max(dtype1.bit_width, dtype2.bit_width)
|
||||
|
||||
holding_integer: BaseDataType
|
||||
|
||||
if d1_signed and d2_signed:
|
||||
holding_integer = Integer(max_bits, is_signed=True)
|
||||
elif not d1_signed and not d2_signed:
|
||||
holding_integer = Integer(max_bits, is_signed=False)
|
||||
elif d1_signed and not d2_signed:
|
||||
# 2 is unsigned, if it has the bigger bit_width, we need a signed integer that can hold
|
||||
# it, so add 1 bit of sign to its bit_width
|
||||
if dtype2.bit_width >= dtype1.bit_width:
|
||||
new_bit_width = dtype2.bit_width + 1
|
||||
holding_integer = Integer(new_bit_width, is_signed=True)
|
||||
else:
|
||||
holding_integer = Integer(dtype1.bit_width, is_signed=True)
|
||||
elif not d1_signed and d2_signed:
|
||||
# Same as above, with 1 and 2 switched around
|
||||
if dtype1.bit_width >= dtype2.bit_width:
|
||||
new_bit_width = dtype1.bit_width + 1
|
||||
holding_integer = Integer(new_bit_width, is_signed=True)
|
||||
else:
|
||||
holding_integer = Integer(dtype2.bit_width, is_signed=True)
|
||||
|
||||
return holding_integer
|
||||
|
||||
raise NotImplementedError("For now only Integers are supported by find_type_to_hold_both_lossy")
|
||||
|
||||
|
||||
def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> BaseValue:
|
||||
"""Returns a Value that would result from computation on both value1 and value2 while
|
||||
determining the data type able to hold both value1 and value2 data type (this can be lossy
|
||||
with floats)
|
||||
|
||||
Args:
|
||||
value1 (BaseValue): first value to mix
|
||||
value2 (BaseValue): second value to mix
|
||||
|
||||
Returns:
|
||||
BaseValue: The resulting mixed value with data type able to hold both value1 and value2
|
||||
dtypes
|
||||
"""
|
||||
|
||||
holding_type = find_type_to_hold_both_lossy(value1.data_type, value2.data_type)
|
||||
|
||||
mixed_value: BaseValue
|
||||
|
||||
if isinstance(value1, EncryptedValue) or isinstance(value2, EncryptedValue):
|
||||
mixed_value = EncryptedValue(holding_type)
|
||||
else:
|
||||
mixed_value = ClearValue(holding_type)
|
||||
|
||||
return mixed_value
|
||||
|
||||
@@ -5,6 +5,7 @@ from copy import deepcopy
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from ..data_types import BaseValue
|
||||
from ..data_types.dtypes_helpers import mix_values_determine_holding_dtype
|
||||
|
||||
|
||||
class IntermediateNode(ABC):
|
||||
@@ -39,9 +40,7 @@ class IntermediateNode(ABC):
|
||||
|
||||
assert len(self.inputs) == 2
|
||||
|
||||
# For now copy the first input type for the output type
|
||||
# We don't perform checks or enforce consistency here for now, so this is OK
|
||||
self.outputs = [deepcopy(self.inputs[0])]
|
||||
self.outputs = [mix_values_determine_holding_dtype(self.inputs[0], self.inputs[1])]
|
||||
|
||||
def _is_equivalent_to_binary_commutative(self, other: object) -> bool:
|
||||
return (
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from hdk.common.data_types.base import BaseDataType
|
||||
from hdk.common.data_types.dtypes_helpers import (
|
||||
find_type_to_hold_both_lossy,
|
||||
mix_values_determine_holding_dtype,
|
||||
value_is_encrypted_integer,
|
||||
value_is_encrypted_unsigned_integer,
|
||||
)
|
||||
@@ -53,3 +56,85 @@ def test_value_is_encrypted_integer(value: BaseValue, expected_result: bool):
|
||||
def test_value_is_encrypted_unsigned_integer(value: BaseValue, expected_result: bool):
|
||||
"""Test value_is_encrypted_unsigned_integer helper"""
|
||||
assert value_is_encrypted_unsigned_integer(value) == expected_result
|
||||
|
||||
|
||||
class UnsupportedDataType(BaseDataType):
|
||||
"""Test helper class to represent an UnsupportedDataType"""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dtype1,dtype2,expected_mixed_dtype",
|
||||
[
|
||||
pytest.param(Integer(6, True), Integer(6, True), Integer(6, True), id="int6, int6, int6"),
|
||||
pytest.param(
|
||||
Integer(6, False), Integer(6, False), Integer(6, False), id="uint6, uint6, uint6"
|
||||
),
|
||||
pytest.param(Integer(6, True), Integer(6, False), Integer(7, True), id="int6, uint6, int7"),
|
||||
pytest.param(Integer(6, False), Integer(6, True), Integer(7, True), id="uint6, int6, int7"),
|
||||
pytest.param(Integer(6, True), Integer(5, False), Integer(6, True), id="int6, uint5, int6"),
|
||||
pytest.param(Integer(5, False), Integer(6, True), Integer(6, True), id="uint5, int6, int6"),
|
||||
pytest.param(
|
||||
UnsupportedDataType(),
|
||||
UnsupportedDataType(),
|
||||
None,
|
||||
id="unsupported, unsupported, xfail",
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
pytest.param(
|
||||
Integer(6, True),
|
||||
UnsupportedDataType(),
|
||||
None,
|
||||
id="int6, unsupported, xfail",
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
pytest.param(
|
||||
UnsupportedDataType(),
|
||||
Integer(6, True),
|
||||
None,
|
||||
id="unsupported, int6, xfail",
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mix_data_types(
|
||||
dtype1: BaseDataType,
|
||||
dtype2: BaseDataType,
|
||||
expected_mixed_dtype: BaseDataType,
|
||||
):
|
||||
"""Test find_type_to_hold_both_lossy helper"""
|
||||
assert expected_mixed_dtype == find_type_to_hold_both_lossy(dtype1, dtype2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value1,value2,expected_mixed_value",
|
||||
[
|
||||
pytest.param(
|
||||
EncryptedValue(Integer(7, False)),
|
||||
EncryptedValue(Integer(7, False)),
|
||||
EncryptedValue(Integer(7, False)),
|
||||
id="euint7, euint7, euint7",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedValue(Integer(7, False)),
|
||||
ClearValue(Integer(7, False)),
|
||||
EncryptedValue(Integer(7, False)),
|
||||
id="euint7, cuint7, euint7",
|
||||
),
|
||||
pytest.param(
|
||||
ClearValue(Integer(7, False)),
|
||||
EncryptedValue(Integer(7, False)),
|
||||
EncryptedValue(Integer(7, False)),
|
||||
id="cuint7, euint7, euint7",
|
||||
),
|
||||
pytest.param(
|
||||
ClearValue(Integer(7, False)),
|
||||
ClearValue(Integer(7, False)),
|
||||
ClearValue(Integer(7, False)),
|
||||
id="cuint7, cuint7, cuint7",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mix_values(value1: BaseValue, value2: BaseValue, expected_mixed_value: BaseValue):
|
||||
"""Test mix_values helper"""
|
||||
|
||||
assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2)
|
||||
|
||||
Reference in New Issue
Block a user