refacto: change the way mixing values is handled

- use an intermediate function that checks which mix function to use
- update function used in hnumpy tracing
This commit is contained in:
Arthur Meyre
2021-08-23 14:41:31 +02:00
parent cc1221eac5
commit fc3ae6461c
3 changed files with 161 additions and 13 deletions

View File

@@ -4,7 +4,15 @@ from copy import deepcopy
from functools import partial
from typing import Callable, Union, cast
from ..values import BaseValue, ClearValue, EncryptedValue, ScalarValue
from ..values import (
BaseValue,
ClearTensor,
ClearValue,
EncryptedTensor,
EncryptedValue,
ScalarValue,
TensorValue,
)
from .base import BaseDataType
from .floats import Float
from .integers import Integer, get_bits_to_represent_value_as_integer
@@ -134,19 +142,22 @@ def find_type_to_hold_both_lossy(
return type_to_return
def mix_scalar_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> ScalarValue:
"""Return mixed value with data type able to hold both value1 and value2 dtypes.
def mix_scalar_values_determine_holding_dtype(
value1: ScalarValue,
value2: ScalarValue,
) -> ScalarValue:
"""Return mixed ScalarValue with data type able to hold both value1 and value2 dtypes.
Returns a ScalarValue that would result from computation on both value1 and value2 while
determining the data type able to hold both value1 and value2 data type (this can be lossy
with floats).
Args:
value1 (BaseValue): first ScalarValue to mix.
value2 (BaseValue): second ScalarValue to mix.
value1 (ScalarValue): first ScalarValue to mix.
value2 (ScalarValue): second ScalarValue to mix.
Returns:
ScalarValue: The resulting mixed BaseValue with data type able to hold both value1 and
ScalarValue: The resulting mixed ScalarValue with data type able to hold both value1 and
value2 dtypes.
"""
@@ -164,6 +175,77 @@ def mix_scalar_values_determine_holding_dtype(value1: BaseValue, value2: BaseVal
return mixed_value
def mix_tensor_values_determine_holding_dtype(
value1: TensorValue,
value2: TensorValue,
) -> TensorValue:
"""Return mixed TensorValue with data type able to hold both value1 and value2 dtypes.
Returns a TensorValue that would result from computation on both value1 and value2 while
determining the data type able to hold both value1 and value2 data type (this can be lossy
with floats).
Args:
value1 (TensorValue): first TensorValue to mix.
value2 (TensorValue): second TensorValue to mix.
Returns:
TensorValue: The resulting mixed TensorValue with data type able to hold both value1 and
value2 dtypes.
"""
assert isinstance(value1, TensorValue), f"Unsupported value1: {value1}, expected TensorValue"
assert isinstance(value2, TensorValue), f"Unsupported value2: {value2}, expected TensorValue"
assert value1.shape == value2.shape, (
f"Tensors have different shapes which is not supported.\n"
f"value1: {value1.shape}, value2: {value2.shape}"
)
holding_type = find_type_to_hold_both_lossy(value1.data_type, value2.data_type)
shape = value1.shape
if value1.is_encrypted or value2.is_encrypted:
mixed_value = EncryptedTensor(data_type=holding_type, shape=shape)
else:
mixed_value = ClearTensor(data_type=holding_type, shape=shape)
return mixed_value
def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> BaseValue:
"""Return mixed BaseValue with data type able to hold both value1 and value2 dtypes.
Returns a BaseValue that would result from computation on both value1 and value2 while
determining the data type able to hold both value1 and value2 data type (this can be lossy
with floats). Supports only mixing instances from the same class.
Args:
value1 (BaseValue): first BaseValue to mix.
value2 (BaseValue): second BaseValue to mix.
Raises:
ValueError: raised if the BaseValue is not one of (ScalarValue, TensorValue)
Returns:
BaseValue: The resulting mixed BaseValue with data type able to hold both value1 and value2
dtypes.
"""
assert (
value1.__class__ == value2.__class__
), f"Cannot mix values of different types: value 1:{type(value1)}, value2: {type(value2)}"
if isinstance(value1, ScalarValue) and isinstance(value2, ScalarValue):
return mix_scalar_values_determine_holding_dtype(value1, value2)
if isinstance(value1, TensorValue) and isinstance(value2, TensorValue):
return mix_tensor_values_determine_holding_dtype(value1, value2)
raise ValueError(
f"{mix_values_determine_holding_dtype.__name__} does not support value {type(value1)}"
)
def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType:
"""Helper function to determine the BaseDataType to hold the input constant data.

View File

@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict
import numpy
from numpy.typing import DTypeLike
from ..common.data_types.dtypes_helpers import mix_scalar_values_determine_holding_dtype
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import ArbitraryFunction, Constant
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
@@ -31,7 +31,7 @@ NPConstant = partial(
class NPTracer(BaseTracer):
"""Tracer class for numpy operations."""
_mix_values_func: Callable[..., BaseValue] = mix_scalar_values_determine_holding_dtype
_mix_values_func: Callable[..., BaseValue] = mix_values_determine_holding_dtype
def __array_ufunc__(self, ufunc, method, *input_tracers, **kwargs):
"""Catch calls to numpy ufunc and routes them to tracing functions if supported.

View File

@@ -5,13 +5,19 @@ import pytest
from hdk.common.data_types.base import BaseDataType
from hdk.common.data_types.dtypes_helpers import (
find_type_to_hold_both_lossy,
mix_scalar_values_determine_holding_dtype,
mix_values_determine_holding_dtype,
value_is_encrypted_integer,
value_is_encrypted_unsigned_integer,
)
from hdk.common.data_types.floats import Float
from hdk.common.data_types.integers import Integer
from hdk.common.values import BaseValue, ClearValue, EncryptedValue
from hdk.common.values import (
BaseValue,
ClearTensor,
ClearValue,
EncryptedTensor,
EncryptedValue,
)
@pytest.mark.parametrize(
@@ -167,7 +173,67 @@ def test_mix_data_types(
),
],
)
def test_mix_values(value1: BaseValue, value2: BaseValue, expected_mixed_value: BaseValue):
"""Test mix_values helper"""
def test_mix_scalar_values(value1, value2, expected_mixed_value):
"""Test mix_values_determine_holding_dtype helper with scalars"""
assert expected_mixed_value == mix_scalar_values_determine_holding_dtype(value1, value2)
assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2)
@pytest.mark.parametrize(
"value1,value2,expected_mixed_value",
[
pytest.param(
EncryptedTensor(Integer(7, False), (1, 2, 3)),
EncryptedTensor(Integer(7, False), (1, 2, 3)),
EncryptedTensor(Integer(7, False), (1, 2, 3)),
),
pytest.param(
ClearTensor(Integer(7, False), (1, 2, 3)),
EncryptedTensor(Integer(7, False), (1, 2, 3)),
EncryptedTensor(Integer(7, False), (1, 2, 3)),
),
pytest.param(
ClearTensor(Integer(7, False), (1, 2, 3)),
ClearTensor(Integer(7, False), (1, 2, 3)),
ClearTensor(Integer(7, False), (1, 2, 3)),
),
pytest.param(
ClearTensor(Integer(7, False), (1, 2, 3)),
ClearTensor(Integer(7, False), (1, 2, 3)),
ClearTensor(Integer(7, False), (1, 2, 3)),
),
pytest.param(
ClearTensor(Integer(7, False), (1, 2, 3)),
EncryptedValue(Integer(7, False)),
None,
marks=pytest.mark.xfail(raises=AssertionError),
),
pytest.param(
ClearTensor(Integer(7, False), (1, 2, 3)),
ClearTensor(Integer(7, False), (3, 2, 1)),
None,
marks=pytest.mark.xfail(raises=AssertionError),
),
],
)
def test_mix_tensor_values(value1, value2, expected_mixed_value):
"""Test mix_values_determine_holding_dtype helper with tensors"""
assert expected_mixed_value == mix_values_determine_holding_dtype(value1, value2)
class DummyValue(BaseValue):
"""DummyValue"""
def __eq__(self, other: object) -> bool:
return BaseValue.__eq__(self, other)
def test_fail_mix_values_determine_holding_dtype():
"""Test function for failure case of mix_values_determine_holding_dtype"""
with pytest.raises(ValueError, match=r".* does not support value .*"):
mix_values_determine_holding_dtype(
DummyValue(Integer(32, True), True),
DummyValue(Integer(32, True), True),
)