dev(ir): add the ArbitraryFunction ir node

This commit is contained in:
Arthur Meyre
2021-08-06 15:04:34 +02:00
parent 789a976661
commit 5d9259c000
2 changed files with 78 additions and 6 deletions

View File

@@ -2,14 +2,15 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
from ..data_types import BaseValue
from ..data_types.base import BaseDataType
from ..data_types.dtypes_helpers import mix_values_determine_holding_dtype
from ..data_types.floats import Float
from ..data_types.integers import Integer, get_bits_to_represent_int
from ..data_types.scalars import Scalars
from ..data_types.values import ClearValue
from ..data_types.values import ClearValue, EncryptedValue
class IntermediateNode(ABC):
@@ -17,8 +18,8 @@ class IntermediateNode(ABC):
inputs: List[BaseValue]
outputs: List[BaseValue]
op_args: Optional[Tuple[Any, ...]]
op_kwargs: Optional[Dict[str, Any]]
op_args: Tuple[Any, ...]
op_kwargs: Dict[str, Any]
def __init__(
self,
@@ -28,8 +29,8 @@ class IntermediateNode(ABC):
) -> None:
self.inputs = list(inputs)
assert all(isinstance(x, BaseValue) for x in self.inputs)
self.op_args = op_args
self.op_kwargs = op_kwargs
self.op_args = deepcopy(op_args) if op_args is not None else ()
self.op_kwargs = deepcopy(op_kwargs) if op_kwargs is not None else {}
def _init_binary(
self,
@@ -167,3 +168,33 @@ class ConstantInput(IntermediateNode):
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
return self.constant_data
class ArbitraryFunction(IntermediateNode):
"""Node representing a univariate arbitrary function, e.g. sin(x)"""
# The arbitrary_func is not optional but mypy has a long standing bug and is not able to
# understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
arbitrary_func: Optional[Callable]
# pylint: disable=too-many-arguments
def __init__(
self,
input_base_value: BaseValue,
arbitrary_func: Callable,
output_dtype: BaseDataType,
op_args: Optional[Tuple[Any, ...]] = None,
op_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__([input_base_value], op_args=op_args, op_kwargs=op_kwargs)
assert len(self.inputs) == 1
self.arbitrary_func = arbitrary_func
# TLU/PBS has an encrypted output
self.outputs = [EncryptedValue(output_dtype)]
# pylint: enable=too-many-arguments
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
# This is the continuation of the mypy bug workaround
assert self.arbitrary_func is not None
return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs)

View File

@@ -31,6 +31,47 @@ from hdk.common.representation import intermediate as ir
pytest.param(ir.Input(ClearValue(Integer(32, True)), "in", 0), [42], 42, id="Input"),
pytest.param(ir.ConstantInput(42), None, 42, id="ConstantInput"),
pytest.param(ir.ConstantInput(-42), None, -42, id="ConstantInput"),
pytest.param(
ir.ArbitraryFunction(
EncryptedValue(Integer(7, False)), lambda x: x + 3, Integer(7, False)
),
[10],
13,
id="ArbitraryFunction, x + 3",
),
pytest.param(
ir.ArbitraryFunction(
EncryptedValue(Integer(7, False)),
lambda x, y: x + y,
Integer(7, False),
op_kwargs={"y": 3},
),
[10],
13,
id="ArbitraryFunction, (x, y) -> x + y, where y is constant == 3",
),
pytest.param(
ir.ArbitraryFunction(
EncryptedValue(Integer(7, False)),
lambda x, y: y[x],
Integer(7, False),
op_kwargs={"y": (1, 2, 3, 4)},
),
[2],
3,
id="ArbitraryFunction, (x, y) -> y[x], where y is constant == (1, 2, 3, 4)",
),
pytest.param(
ir.ArbitraryFunction(
EncryptedValue(Integer(7, False)),
lambda x, y: y[3],
Integer(7, False),
op_kwargs={"y": (1, 2, 3, 4)},
),
[2],
4,
id="ArbitraryFunction, x, y -> y[3], where y is constant == (1, 2, 3, 4)",
),
],
)
def test_evaluate(