diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index d5660a205..4b07e042d 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -2,14 +2,15 @@ from abc import ABC, abstractmethod from copy import deepcopy -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple from ..data_types import BaseValue +from ..data_types.base import BaseDataType from ..data_types.dtypes_helpers import mix_values_determine_holding_dtype from ..data_types.floats import Float from ..data_types.integers import Integer, get_bits_to_represent_int from ..data_types.scalars import Scalars -from ..data_types.values import ClearValue +from ..data_types.values import ClearValue, EncryptedValue class IntermediateNode(ABC): @@ -17,8 +18,8 @@ class IntermediateNode(ABC): inputs: List[BaseValue] outputs: List[BaseValue] - op_args: Optional[Tuple[Any, ...]] - op_kwargs: Optional[Dict[str, Any]] + op_args: Tuple[Any, ...] + op_kwargs: Dict[str, Any] def __init__( self, @@ -28,8 +29,8 @@ class IntermediateNode(ABC): ) -> None: self.inputs = list(inputs) assert all(isinstance(x, BaseValue) for x in self.inputs) - self.op_args = op_args - self.op_kwargs = op_kwargs + self.op_args = deepcopy(op_args) if op_args is not None else () + self.op_kwargs = deepcopy(op_kwargs) if op_kwargs is not None else {} def _init_binary( self, @@ -167,3 +168,33 @@ class ConstantInput(IntermediateNode): def evaluate(self, inputs: Mapping[int, Any]) -> Any: return self.constant_data + + +class ArbitraryFunction(IntermediateNode): + """Node representing a univariate arbitrary function, e.g. sin(x)""" + + # The arbitrary_func is not optional but mypy has a long standing bug and is not able to + # understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623 + arbitrary_func: Optional[Callable] + + # pylint: disable=too-many-arguments + def __init__( + self, + input_base_value: BaseValue, + arbitrary_func: Callable, + output_dtype: BaseDataType, + op_args: Optional[Tuple[Any, ...]] = None, + op_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__([input_base_value], op_args=op_args, op_kwargs=op_kwargs) + assert len(self.inputs) == 1 + self.arbitrary_func = arbitrary_func + # TLU/PBS has an encrypted output + self.outputs = [EncryptedValue(output_dtype)] + + # pylint: enable=too-many-arguments + + def evaluate(self, inputs: Mapping[int, Any]) -> Any: + # This is the continuation of the mypy bug workaround + assert self.arbitrary_func is not None + return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs) diff --git a/tests/common/representation/test_intermediate.py b/tests/common/representation/test_intermediate.py index 2c4a665ec..530742be2 100644 --- a/tests/common/representation/test_intermediate.py +++ b/tests/common/representation/test_intermediate.py @@ -31,6 +31,47 @@ from hdk.common.representation import intermediate as ir pytest.param(ir.Input(ClearValue(Integer(32, True)), "in", 0), [42], 42, id="Input"), pytest.param(ir.ConstantInput(42), None, 42, id="ConstantInput"), pytest.param(ir.ConstantInput(-42), None, -42, id="ConstantInput"), + pytest.param( + ir.ArbitraryFunction( + EncryptedValue(Integer(7, False)), lambda x: x + 3, Integer(7, False) + ), + [10], + 13, + id="ArbitraryFunction, x + 3", + ), + pytest.param( + ir.ArbitraryFunction( + EncryptedValue(Integer(7, False)), + lambda x, y: x + y, + Integer(7, False), + op_kwargs={"y": 3}, + ), + [10], + 13, + id="ArbitraryFunction, (x, y) -> x + y, where y is constant == 3", + ), + pytest.param( + ir.ArbitraryFunction( + EncryptedValue(Integer(7, False)), + lambda x, y: y[x], + Integer(7, False), + op_kwargs={"y": (1, 2, 3, 4)}, + ), + [2], + 3, + id="ArbitraryFunction, (x, y) -> y[x], where y is constant == (1, 2, 3, 4)", + ), + pytest.param( + ir.ArbitraryFunction( + EncryptedValue(Integer(7, False)), + lambda x, y: y[3], + Integer(7, False), + op_kwargs={"y": (1, 2, 3, 4)}, + ), + [2], + 4, + id="ArbitraryFunction, x, y -> y[3], where y is constant == (1, 2, 3, 4)", + ), ], ) def test_evaluate(