diff --git a/hdk/common/data_types/dtypes_helpers.py b/hdk/common/data_types/dtypes_helpers.py index 06516dfcd..06a23431e 100644 --- a/hdk/common/data_types/dtypes_helpers.py +++ b/hdk/common/data_types/dtypes_helpers.py @@ -4,7 +4,15 @@ from copy import deepcopy from functools import partial from typing import Callable, Union, cast -from ..values import BaseValue, ClearValue, EncryptedValue, ScalarValue +from ..values import ( + BaseValue, + ClearTensor, + ClearValue, + EncryptedTensor, + EncryptedValue, + ScalarValue, + TensorValue, +) from .base import BaseDataType from .floats import Float from .integers import Integer, get_bits_to_represent_value_as_integer @@ -134,19 +142,22 @@ def find_type_to_hold_both_lossy( return type_to_return -def mix_scalar_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> ScalarValue: - """Return mixed value with data type able to hold both value1 and value2 dtypes. +def mix_scalar_values_determine_holding_dtype( + value1: ScalarValue, + value2: ScalarValue, +) -> ScalarValue: + """Return mixed ScalarValue with data type able to hold both value1 and value2 dtypes. Returns a ScalarValue that would result from computation on both value1 and value2 while determining the data type able to hold both value1 and value2 data type (this can be lossy with floats). Args: - value1 (BaseValue): first ScalarValue to mix. - value2 (BaseValue): second ScalarValue to mix. + value1 (ScalarValue): first ScalarValue to mix. + value2 (ScalarValue): second ScalarValue to mix. Returns: - ScalarValue: The resulting mixed BaseValue with data type able to hold both value1 and + ScalarValue: The resulting mixed ScalarValue with data type able to hold both value1 and value2 dtypes. """ @@ -164,6 +175,77 @@ def mix_scalar_values_determine_holding_dtype(value1: BaseValue, value2: BaseVal return mixed_value +def mix_tensor_values_determine_holding_dtype( + value1: TensorValue, + value2: TensorValue, +) -> TensorValue: + """Return mixed TensorValue with data type able to hold both value1 and value2 dtypes. + + Returns a TensorValue that would result from computation on both value1 and value2 while + determining the data type able to hold both value1 and value2 data type (this can be lossy + with floats). + + Args: + value1 (TensorValue): first TensorValue to mix. + value2 (TensorValue): second TensorValue to mix. + + Returns: + TensorValue: The resulting mixed TensorValue with data type able to hold both value1 and + value2 dtypes. + """ + + assert isinstance(value1, TensorValue), f"Unsupported value1: {value1}, expected TensorValue" + assert isinstance(value2, TensorValue), f"Unsupported value2: {value2}, expected TensorValue" + + assert value1.shape == value2.shape, ( + f"Tensors have different shapes which is not supported.\n" + f"value1: {value1.shape}, value2: {value2.shape}" + ) + + holding_type = find_type_to_hold_both_lossy(value1.data_type, value2.data_type) + shape = value1.shape + + if value1.is_encrypted or value2.is_encrypted: + mixed_value = EncryptedTensor(data_type=holding_type, shape=shape) + else: + mixed_value = ClearTensor(data_type=holding_type, shape=shape) + + return mixed_value + + +def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> BaseValue: + """Return mixed BaseValue with data type able to hold both value1 and value2 dtypes. + + Returns a BaseValue that would result from computation on both value1 and value2 while + determining the data type able to hold both value1 and value2 data type (this can be lossy + with floats). Supports only mixing instances from the same class. + + Args: + value1 (BaseValue): first BaseValue to mix. + value2 (BaseValue): second BaseValue to mix. + + Raises: + ValueError: raised if the BaseValue is not one of (ScalarValue, TensorValue) + + Returns: + BaseValue: The resulting mixed BaseValue with data type able to hold both value1 and value2 + dtypes. + """ + + assert ( + value1.__class__ == value2.__class__ + ), f"Cannot mix values of different types: value 1:{type(value1)}, value2: {type(value2)}" + + if isinstance(value1, ScalarValue) and isinstance(value2, ScalarValue): + return mix_scalar_values_determine_holding_dtype(value1, value2) + if isinstance(value1, TensorValue) and isinstance(value2, TensorValue): + return mix_tensor_values_determine_holding_dtype(value1, value2) + + raise ValueError( + f"{mix_values_determine_holding_dtype.__name__} does not support value {type(value1)}" + ) + + def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType: """Helper function to determine the BaseDataType to hold the input constant data. diff --git a/hdk/hnumpy/tracing.py b/hdk/hnumpy/tracing.py index 26eadd548..6b4e40e18 100644 --- a/hdk/hnumpy/tracing.py +++ b/hdk/hnumpy/tracing.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Dict import numpy from numpy.typing import DTypeLike -from ..common.data_types.dtypes_helpers import mix_scalar_values_determine_holding_dtype +from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype from ..common.operator_graph import OPGraph from ..common.representation.intermediate import ArbitraryFunction, Constant from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters @@ -31,7 +31,7 @@ NPConstant = partial( class NPTracer(BaseTracer): """Tracer class for numpy operations.""" - _mix_values_func: Callable[..., BaseValue] = mix_scalar_values_determine_holding_dtype + _mix_values_func: Callable[..., BaseValue] = mix_values_determine_holding_dtype def __array_ufunc__(self, ufunc, method, *input_tracers, **kwargs): """Catch calls to numpy ufunc and routes them to tracing functions if supported. diff --git a/tests/common/data_types/test_dtypes_helpers.py b/tests/common/data_types/test_dtypes_helpers.py index 5df0666f8..05a28bddb 100644 --- a/tests/common/data_types/test_dtypes_helpers.py +++ b/tests/common/data_types/test_dtypes_helpers.py @@ -5,13 +5,19 @@ import pytest from hdk.common.data_types.base import BaseDataType from hdk.common.data_types.dtypes_helpers import ( find_type_to_hold_both_lossy, - mix_scalar_values_determine_holding_dtype, + mix_values_determine_holding_dtype, value_is_encrypted_integer, value_is_encrypted_unsigned_integer, ) from hdk.common.data_types.floats import Float from hdk.common.data_types.integers import Integer -from hdk.common.values import BaseValue, ClearValue, EncryptedValue +from hdk.common.values import ( + BaseValue, + ClearTensor, + ClearValue, + EncryptedTensor, + EncryptedValue, +) @pytest.mark.parametrize( @@ -167,7 +173,67 @@ def test_mix_data_types( ), ], ) -def test_mix_values(value1: BaseValue, value2: BaseValue, expected_mixed_value: BaseValue): - """Test mix_values helper""" +def test_mix_scalar_values(value1, value2, expected_mixed_value): + """Test mix_values_determine_holding_dtype helper with scalars""" - assert expected_mixed_value == mix_scalar_values_determine_holding_dtype(value1, value2) + assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2) + + +@pytest.mark.parametrize( + "value1,value2,expected_mixed_value", + [ + pytest.param( + EncryptedTensor(Integer(7, False), (1, 2, 3)), + EncryptedTensor(Integer(7, False), (1, 2, 3)), + EncryptedTensor(Integer(7, False), (1, 2, 3)), + ), + pytest.param( + ClearTensor(Integer(7, False), (1, 2, 3)), + EncryptedTensor(Integer(7, False), (1, 2, 3)), + EncryptedTensor(Integer(7, False), (1, 2, 3)), + ), + pytest.param( + ClearTensor(Integer(7, False), (1, 2, 3)), + ClearTensor(Integer(7, False), (1, 2, 3)), + ClearTensor(Integer(7, False), (1, 2, 3)), + ), + pytest.param( + ClearTensor(Integer(7, False), (1, 2, 3)), + ClearTensor(Integer(7, False), (1, 2, 3)), + ClearTensor(Integer(7, False), (1, 2, 3)), + ), + pytest.param( + ClearTensor(Integer(7, False), (1, 2, 3)), + EncryptedValue(Integer(7, False)), + None, + marks=pytest.mark.xfail(raises=AssertionError), + ), + pytest.param( + ClearTensor(Integer(7, False), (1, 2, 3)), + ClearTensor(Integer(7, False), (3, 2, 1)), + None, + marks=pytest.mark.xfail(raises=AssertionError), + ), + ], +) +def test_mix_tensor_values(value1, value2, expected_mixed_value): + """Test mix_values_determine_holding_dtype helper with tensors""" + + assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2) + + +class DummyValue(BaseValue): + """DummyValue""" + + def __eq__(self, other: object) -> bool: + return BaseValue.__eq__(self, other) + + +def test_fail_mix_values_determine_holding_dtype(): + """Test function for failure case of mix_values_determine_holding_dtype""" + + with pytest.raises(ValueError, match=r".* does not support value .*"): + mix_values_determine_holding_dtype( + DummyValue(Integer(32, True), True), + DummyValue(Integer(32, True), True), + )