mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
dev(ir): add the ArbitraryFunction ir node
This commit is contained in:
@@ -2,14 +2,15 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
|
||||
from ..data_types import BaseValue
|
||||
from ..data_types.base import BaseDataType
|
||||
from ..data_types.dtypes_helpers import mix_values_determine_holding_dtype
|
||||
from ..data_types.floats import Float
|
||||
from ..data_types.integers import Integer, get_bits_to_represent_int
|
||||
from ..data_types.scalars import Scalars
|
||||
from ..data_types.values import ClearValue
|
||||
from ..data_types.values import ClearValue, EncryptedValue
|
||||
|
||||
|
||||
class IntermediateNode(ABC):
|
||||
@@ -17,8 +18,8 @@ class IntermediateNode(ABC):
|
||||
|
||||
inputs: List[BaseValue]
|
||||
outputs: List[BaseValue]
|
||||
op_args: Optional[Tuple[Any, ...]]
|
||||
op_kwargs: Optional[Dict[str, Any]]
|
||||
op_args: Tuple[Any, ...]
|
||||
op_kwargs: Dict[str, Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -28,8 +29,8 @@ class IntermediateNode(ABC):
|
||||
) -> None:
|
||||
self.inputs = list(inputs)
|
||||
assert all(isinstance(x, BaseValue) for x in self.inputs)
|
||||
self.op_args = op_args
|
||||
self.op_kwargs = op_kwargs
|
||||
self.op_args = deepcopy(op_args) if op_args is not None else ()
|
||||
self.op_kwargs = deepcopy(op_kwargs) if op_kwargs is not None else {}
|
||||
|
||||
def _init_binary(
|
||||
self,
|
||||
@@ -167,3 +168,33 @@ class ConstantInput(IntermediateNode):
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
return self.constant_data
|
||||
|
||||
|
||||
class ArbitraryFunction(IntermediateNode):
|
||||
"""Node representing a univariate arbitrary function, e.g. sin(x)"""
|
||||
|
||||
# The arbitrary_func is not optional but mypy has a long standing bug and is not able to
|
||||
# understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
|
||||
arbitrary_func: Optional[Callable]
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
def __init__(
|
||||
self,
|
||||
input_base_value: BaseValue,
|
||||
arbitrary_func: Callable,
|
||||
output_dtype: BaseDataType,
|
||||
op_args: Optional[Tuple[Any, ...]] = None,
|
||||
op_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__([input_base_value], op_args=op_args, op_kwargs=op_kwargs)
|
||||
assert len(self.inputs) == 1
|
||||
self.arbitrary_func = arbitrary_func
|
||||
# TLU/PBS has an encrypted output
|
||||
self.outputs = [EncryptedValue(output_dtype)]
|
||||
|
||||
# pylint: enable=too-many-arguments
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
# This is the continuation of the mypy bug workaround
|
||||
assert self.arbitrary_func is not None
|
||||
return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs)
|
||||
|
||||
@@ -31,6 +31,47 @@ from hdk.common.representation import intermediate as ir
|
||||
pytest.param(ir.Input(ClearValue(Integer(32, True)), "in", 0), [42], 42, id="Input"),
|
||||
pytest.param(ir.ConstantInput(42), None, 42, id="ConstantInput"),
|
||||
pytest.param(ir.ConstantInput(-42), None, -42, id="ConstantInput"),
|
||||
pytest.param(
|
||||
ir.ArbitraryFunction(
|
||||
EncryptedValue(Integer(7, False)), lambda x: x + 3, Integer(7, False)
|
||||
),
|
||||
[10],
|
||||
13,
|
||||
id="ArbitraryFunction, x + 3",
|
||||
),
|
||||
pytest.param(
|
||||
ir.ArbitraryFunction(
|
||||
EncryptedValue(Integer(7, False)),
|
||||
lambda x, y: x + y,
|
||||
Integer(7, False),
|
||||
op_kwargs={"y": 3},
|
||||
),
|
||||
[10],
|
||||
13,
|
||||
id="ArbitraryFunction, (x, y) -> x + y, where y is constant == 3",
|
||||
),
|
||||
pytest.param(
|
||||
ir.ArbitraryFunction(
|
||||
EncryptedValue(Integer(7, False)),
|
||||
lambda x, y: y[x],
|
||||
Integer(7, False),
|
||||
op_kwargs={"y": (1, 2, 3, 4)},
|
||||
),
|
||||
[2],
|
||||
3,
|
||||
id="ArbitraryFunction, (x, y) -> y[x], where y is constant == (1, 2, 3, 4)",
|
||||
),
|
||||
pytest.param(
|
||||
ir.ArbitraryFunction(
|
||||
EncryptedValue(Integer(7, False)),
|
||||
lambda x, y: y[3],
|
||||
Integer(7, False),
|
||||
op_kwargs={"y": (1, 2, 3, 4)},
|
||||
),
|
||||
[2],
|
||||
4,
|
||||
id="ArbitraryFunction, x, y -> y[3], where y is constant == (1, 2, 3, 4)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_evaluate(
|
||||
|
||||
Reference in New Issue
Block a user