mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refacto: change the way mixing values is handled
- use an intermediate function that checks which mix function to use - update function used in hnumpy tracing
This commit is contained in:
@@ -4,7 +4,15 @@ from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Callable, Union, cast
|
||||
|
||||
from ..values import BaseValue, ClearValue, EncryptedValue, ScalarValue
|
||||
from ..values import (
|
||||
BaseValue,
|
||||
ClearTensor,
|
||||
ClearValue,
|
||||
EncryptedTensor,
|
||||
EncryptedValue,
|
||||
ScalarValue,
|
||||
TensorValue,
|
||||
)
|
||||
from .base import BaseDataType
|
||||
from .floats import Float
|
||||
from .integers import Integer, get_bits_to_represent_value_as_integer
|
||||
@@ -134,19 +142,22 @@ def find_type_to_hold_both_lossy(
|
||||
return type_to_return
|
||||
|
||||
|
||||
def mix_scalar_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> ScalarValue:
|
||||
"""Return mixed value with data type able to hold both value1 and value2 dtypes.
|
||||
def mix_scalar_values_determine_holding_dtype(
|
||||
value1: ScalarValue,
|
||||
value2: ScalarValue,
|
||||
) -> ScalarValue:
|
||||
"""Return mixed ScalarValue with data type able to hold both value1 and value2 dtypes.
|
||||
|
||||
Returns a ScalarValue that would result from computation on both value1 and value2 while
|
||||
determining the data type able to hold both value1 and value2 data type (this can be lossy
|
||||
with floats).
|
||||
|
||||
Args:
|
||||
value1 (BaseValue): first ScalarValue to mix.
|
||||
value2 (BaseValue): second ScalarValue to mix.
|
||||
value1 (ScalarValue): first ScalarValue to mix.
|
||||
value2 (ScalarValue): second ScalarValue to mix.
|
||||
|
||||
Returns:
|
||||
ScalarValue: The resulting mixed BaseValue with data type able to hold both value1 and
|
||||
ScalarValue: The resulting mixed ScalarValue with data type able to hold both value1 and
|
||||
value2 dtypes.
|
||||
"""
|
||||
|
||||
@@ -164,6 +175,77 @@ def mix_scalar_values_determine_holding_dtype(value1: BaseValue, value2: BaseVal
|
||||
return mixed_value
|
||||
|
||||
|
||||
def mix_tensor_values_determine_holding_dtype(
|
||||
value1: TensorValue,
|
||||
value2: TensorValue,
|
||||
) -> TensorValue:
|
||||
"""Return mixed TensorValue with data type able to hold both value1 and value2 dtypes.
|
||||
|
||||
Returns a TensorValue that would result from computation on both value1 and value2 while
|
||||
determining the data type able to hold both value1 and value2 data type (this can be lossy
|
||||
with floats).
|
||||
|
||||
Args:
|
||||
value1 (TensorValue): first TensorValue to mix.
|
||||
value2 (TensorValue): second TensorValue to mix.
|
||||
|
||||
Returns:
|
||||
TensorValue: The resulting mixed TensorValue with data type able to hold both value1 and
|
||||
value2 dtypes.
|
||||
"""
|
||||
|
||||
assert isinstance(value1, TensorValue), f"Unsupported value1: {value1}, expected TensorValue"
|
||||
assert isinstance(value2, TensorValue), f"Unsupported value2: {value2}, expected TensorValue"
|
||||
|
||||
assert value1.shape == value2.shape, (
|
||||
f"Tensors have different shapes which is not supported.\n"
|
||||
f"value1: {value1.shape}, value2: {value2.shape}"
|
||||
)
|
||||
|
||||
holding_type = find_type_to_hold_both_lossy(value1.data_type, value2.data_type)
|
||||
shape = value1.shape
|
||||
|
||||
if value1.is_encrypted or value2.is_encrypted:
|
||||
mixed_value = EncryptedTensor(data_type=holding_type, shape=shape)
|
||||
else:
|
||||
mixed_value = ClearTensor(data_type=holding_type, shape=shape)
|
||||
|
||||
return mixed_value
|
||||
|
||||
|
||||
def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> BaseValue:
|
||||
"""Return mixed BaseValue with data type able to hold both value1 and value2 dtypes.
|
||||
|
||||
Returns a BaseValue that would result from computation on both value1 and value2 while
|
||||
determining the data type able to hold both value1 and value2 data type (this can be lossy
|
||||
with floats). Supports only mixing instances from the same class.
|
||||
|
||||
Args:
|
||||
value1 (BaseValue): first BaseValue to mix.
|
||||
value2 (BaseValue): second BaseValue to mix.
|
||||
|
||||
Raises:
|
||||
ValueError: raised if the BaseValue is not one of (ScalarValue, TensorValue)
|
||||
|
||||
Returns:
|
||||
BaseValue: The resulting mixed BaseValue with data type able to hold both value1 and value2
|
||||
dtypes.
|
||||
"""
|
||||
|
||||
assert (
|
||||
value1.__class__ == value2.__class__
|
||||
), f"Cannot mix values of different types: value 1:{type(value1)}, value2: {type(value2)}"
|
||||
|
||||
if isinstance(value1, ScalarValue) and isinstance(value2, ScalarValue):
|
||||
return mix_scalar_values_determine_holding_dtype(value1, value2)
|
||||
if isinstance(value1, TensorValue) and isinstance(value2, TensorValue):
|
||||
return mix_tensor_values_determine_holding_dtype(value1, value2)
|
||||
|
||||
raise ValueError(
|
||||
f"{mix_values_determine_holding_dtype.__name__} does not support value {type(value1)}"
|
||||
)
|
||||
|
||||
|
||||
def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType:
|
||||
"""Helper function to determine the BaseDataType to hold the input constant data.
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict
|
||||
import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
from ..common.data_types.dtypes_helpers import mix_scalar_values_determine_holding_dtype
|
||||
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation.intermediate import ArbitraryFunction, Constant
|
||||
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
|
||||
@@ -31,7 +31,7 @@ NPConstant = partial(
|
||||
class NPTracer(BaseTracer):
|
||||
"""Tracer class for numpy operations."""
|
||||
|
||||
_mix_values_func: Callable[..., BaseValue] = mix_scalar_values_determine_holding_dtype
|
||||
_mix_values_func: Callable[..., BaseValue] = mix_values_determine_holding_dtype
|
||||
|
||||
def __array_ufunc__(self, ufunc, method, *input_tracers, **kwargs):
|
||||
"""Catch calls to numpy ufunc and routes them to tracing functions if supported.
|
||||
|
||||
@@ -5,13 +5,19 @@ import pytest
|
||||
from hdk.common.data_types.base import BaseDataType
|
||||
from hdk.common.data_types.dtypes_helpers import (
|
||||
find_type_to_hold_both_lossy,
|
||||
mix_scalar_values_determine_holding_dtype,
|
||||
mix_values_determine_holding_dtype,
|
||||
value_is_encrypted_integer,
|
||||
value_is_encrypted_unsigned_integer,
|
||||
)
|
||||
from hdk.common.data_types.floats import Float
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.values import BaseValue, ClearValue, EncryptedValue
|
||||
from hdk.common.values import (
|
||||
BaseValue,
|
||||
ClearTensor,
|
||||
ClearValue,
|
||||
EncryptedTensor,
|
||||
EncryptedValue,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -167,7 +173,67 @@ def test_mix_data_types(
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mix_values(value1: BaseValue, value2: BaseValue, expected_mixed_value: BaseValue):
|
||||
"""Test mix_values helper"""
|
||||
def test_mix_scalar_values(value1, value2, expected_mixed_value):
|
||||
"""Test mix_values_determine_holding_dtype helper with scalars"""
|
||||
|
||||
assert expected_mixed_value == mix_scalar_values_determine_holding_dtype(value1, value2)
|
||||
assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value1,value2,expected_mixed_value",
|
||||
[
|
||||
pytest.param(
|
||||
EncryptedTensor(Integer(7, False), (1, 2, 3)),
|
||||
EncryptedTensor(Integer(7, False), (1, 2, 3)),
|
||||
EncryptedTensor(Integer(7, False), (1, 2, 3)),
|
||||
),
|
||||
pytest.param(
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
EncryptedTensor(Integer(7, False), (1, 2, 3)),
|
||||
EncryptedTensor(Integer(7, False), (1, 2, 3)),
|
||||
),
|
||||
pytest.param(
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
),
|
||||
pytest.param(
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
),
|
||||
pytest.param(
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
EncryptedValue(Integer(7, False)),
|
||||
None,
|
||||
marks=pytest.mark.xfail(raises=AssertionError),
|
||||
),
|
||||
pytest.param(
|
||||
ClearTensor(Integer(7, False), (1, 2, 3)),
|
||||
ClearTensor(Integer(7, False), (3, 2, 1)),
|
||||
None,
|
||||
marks=pytest.mark.xfail(raises=AssertionError),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mix_tensor_values(value1, value2, expected_mixed_value):
|
||||
"""Test mix_values_determine_holding_dtype helper with tensors"""
|
||||
|
||||
assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2)
|
||||
|
||||
|
||||
class DummyValue(BaseValue):
|
||||
"""DummyValue"""
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return BaseValue.__eq__(self, other)
|
||||
|
||||
|
||||
def test_fail_mix_values_determine_holding_dtype():
|
||||
"""Test function for failure case of mix_values_determine_holding_dtype"""
|
||||
|
||||
with pytest.raises(ValueError, match=r".* does not support value .*"):
|
||||
mix_values_determine_holding_dtype(
|
||||
DummyValue(Integer(32, True), True),
|
||||
DummyValue(Integer(32, True), True),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user