mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
285 lines
8.2 KiB
Python
285 lines
8.2 KiB
Python
"""File containing code to represent source programs operations."""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from copy import deepcopy
|
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type
|
|
|
|
from ..data_types.base import BaseDataType
|
|
from ..data_types.dtypes_helpers import (
|
|
get_base_value_for_python_constant_data,
|
|
mix_scalar_values_determine_holding_dtype,
|
|
)
|
|
from ..values import BaseValue, ClearValue, EncryptedValue, TensorValue
|
|
|
|
# Name of the keyword argument through which binary node constructors
# (IntermediateNode._init_binary) receive the function that combines their two
# input values into the output value.
IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func"

# Registry of IR node classes: IntermediateNode.__init_subclass__ adds every
# subclass to this set as soon as the subclass is defined.
ALL_IR_NODES: Set[Type] = set()
|
|
|
|
|
|
class IntermediateNode(ABC):
    """Abstract Base Class to derive from to represent source program operations.

    Every subclass is recorded in ``ALL_IR_NODES`` at class-definition time.
    """

    inputs: List[BaseValue]  # values consumed by the node
    outputs: List[BaseValue]  # values produced by the node
    _n_in: int  # how many inputs are required to evaluate the node; set by subclasses

    def __init__(
        self,
        inputs: Iterable[BaseValue],
        **_kwargs,  # accepted and ignored so arbitrary arguments can be fed to IntermediateNodes
    ) -> None:
        """Store the node inputs as a list and check each one is a BaseValue.

        Args:
            inputs (Iterable[BaseValue]): the values consumed by the node.
        """
        self.inputs = list(inputs)
        assert all(isinstance(value, BaseValue) for value in self.inputs)

    def __init_subclass__(cls, **kwargs):
        # Register every IR node class as it gets defined.
        super().__init_subclass__(**kwargs)
        ALL_IR_NODES.add(cls)

    def _init_binary(
        self,
        inputs: Iterable[BaseValue],
        mix_values_func: Callable[..., BaseValue] = mix_scalar_values_determine_holding_dtype,
        **_kwargs,  # Required to conform to __init__ typing
    ) -> None:
        """__init__ shared by binary operations, i.e. nodes with exactly two inputs.

        Subclasses assign this function directly as their ``__init__``.

        Args:
            inputs (Iterable[BaseValue]): the two operand values.
            mix_values_func (Callable[..., BaseValue]): called with both operand values to
                build the single output value.
        """
        # Explicit base call: this function is re-bound as a subclass __init__,
        # so zero-argument super() would not resolve to IntermediateNode.__init__.
        IntermediateNode.__init__(self, inputs)

        assert len(self.inputs) == 2

        lhs, rhs = self.inputs[0], self.inputs[1]
        self.outputs = [mix_values_func(lhs, rhs)]

    @abstractmethod
    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Simulate the output of the represented computation for the given inputs.

        Args:
            inputs (Dict[int, Any]): the input values, keyed by input index

        Returns:
            Any: the result of the computation
        """

    @classmethod
    def n_in(cls) -> int:
        """Give the number of inputs the node expects.

        Returns:
            int: The number of inputs of the node.
        """
        return cls._n_in

    @classmethod
    def requires_mix_values_func(cls) -> bool:
        """Tell whether building this node class needs a mix_values_func.

        Returns:
            bool: True if __init__ expects a mix_values_func argument, i.e. when the node
                has more than one input.
        """
        input_count = cls.n_in()
        return input_count > 1

    @abstractmethod
    def label(self) -> str:
        """Give the text used to label the node.

        Returns:
            str: the label of the node
        """
|
|
|
|
|
|
class Add(IntermediateNode):
    """Node representing the addition of two values."""

    _n_in: int = 2

    # Binary node: the output value is built from both inputs by mix_values_func.
    __init__ = IntermediateNode._init_binary

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Return the input at index 0 plus the input at index 1."""
        lhs, rhs = inputs[0], inputs[1]
        return lhs + rhs

    def label(self) -> str:
        """Return the ``+`` symbol used to label the node."""
        return "+"
|
|
|
|
|
|
class Sub(IntermediateNode):
    """Node representing the subtraction of one value from another."""

    _n_in: int = 2

    # Binary node: the output value is built from both inputs by mix_values_func.
    __init__ = IntermediateNode._init_binary

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Return the input at index 0 minus the input at index 1."""
        lhs, rhs = inputs[0], inputs[1]
        return lhs - rhs

    def label(self) -> str:
        """Return the ``-`` symbol used to label the node."""
        return "-"
|
|
|
|
|
|
class Mul(IntermediateNode):
    """Node representing the multiplication of two values."""

    _n_in: int = 2

    # Binary node: the output value is built from both inputs by mix_values_func.
    __init__ = IntermediateNode._init_binary

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Return the input at index 0 times the input at index 1."""
        lhs, rhs = inputs[0], inputs[1]
        return lhs * rhs

    def label(self) -> str:
        """Return the ``*`` symbol used to label the node."""
        return "*"
|
|
|
|
|
|
class Input(IntermediateNode):
    """Node standing for one input of the program."""

    input_name: str  # name used as the node label
    program_input_idx: int  # presumably the position of this input in the program signature
    _n_in: int = 1

    def __init__(
        self,
        input_value: BaseValue,
        input_name: str,
        program_input_idx: int,
    ) -> None:
        """Build the node from the value describing the program input.

        Args:
            input_value (BaseValue): description of the input value.
            input_name (str): name of the input, also used as the node label.
            program_input_idx (int): index of the input among the program inputs.
        """
        super().__init__([input_value])
        assert len(self.inputs) == 1

        self.input_name = input_name
        self.program_input_idx = program_input_idx

        # The output is a deep copy, so it is a distinct object from the input value.
        self.outputs = [deepcopy(self.inputs[0])]

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Pass the single input through unchanged."""
        passthrough = inputs[0]
        return passthrough

    def label(self) -> str:
        """Label the node with the input name."""
        return self.input_name
|
|
|
|
|
|
class Constant(IntermediateNode):
    """Node holding a constant of the program."""

    _constant_data: Any  # the wrapped constant, exposed through the constant_data property
    _n_in: int = 0

    def __init__(
        self,
        constant_data: Any,
        get_base_value_for_data_func: Callable[
            [Any], Callable[..., BaseValue]
        ] = get_base_value_for_python_constant_data,
    ) -> None:
        """Wrap ``constant_data`` in a node without inputs.

        Args:
            constant_data (Any): the constant to store.
            get_base_value_for_data_func (Callable[[Any], Callable[..., BaseValue]]): maps the
                constant to a factory for its BaseValue; the factory is called with
                ``is_encrypted=False`` to build the node output.
        """
        super().__init__([])

        make_output_value = get_base_value_for_data_func(constant_data)

        self._constant_data = constant_data
        # Constants are never encrypted.
        self.outputs = [make_output_value(is_encrypted=False)]

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Ignore ``inputs`` and hand back the stored constant."""
        return self.constant_data

    @property
    def constant_data(self) -> Any:
        """Read-only access to the data kept by the Constant node.

        Returns:
            Any: The constant data that was stored.
        """
        return self._constant_data

    def label(self) -> str:
        """Label the node with the string form of the constant."""
        return str(self.constant_data)
|
|
|
|
|
|
class ArbitraryFunction(IntermediateNode):
    """Node representing a univariate arbitrary function, e.g. sin(x)."""

    # The arbitrary_func is not optional but mypy has a long standing bug and is not able to
    # understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
    arbitrary_func: Optional[Callable]
    op_args: Tuple[Any, ...]  # extra positional args passed after the node input
    op_kwargs: Dict[str, Any]  # extra keyword args passed to arbitrary_func
    op_name: str  # label of the node
    _n_in: int = 1

    def __init__(
        self,
        input_base_value: BaseValue,
        arbitrary_func: Callable,
        output_dtype: BaseDataType,
        op_name: Optional[str] = None,
        op_args: Optional[Tuple[Any, ...]] = None,
        op_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Build the node around ``arbitrary_func``.

        Args:
            input_base_value (BaseValue): the single input value; its class and encryption
                status are reused for the output value.
            arbitrary_func (Callable): the function applied by evaluate.
            output_dtype (BaseDataType): data type of the output value.
            op_name (Optional[str]): label of the node, defaults to the class name.
            op_args (Optional[Tuple[Any, ...]]): extra positional arguments, defaults to ().
            op_kwargs (Optional[Dict[str, Any]]): extra keyword arguments, defaults to {}.
        """
        super().__init__([input_base_value])
        assert len(self.inputs) == 1

        self.arbitrary_func = arbitrary_func
        self.op_args = () if op_args is None else op_args
        self.op_kwargs = {} if op_kwargs is None else op_kwargs

        output_value_class = input_base_value.__class__
        self.outputs = [output_value_class(output_dtype, input_base_value.is_encrypted)]

        self.op_name = self.__class__.__name__ if op_name is None else op_name

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Apply ``arbitrary_func`` to the input at index 0 with the stored extra arguments."""
        func = self.arbitrary_func
        # This is the continuation of the mypy bug workaround
        assert func is not None
        return func(inputs[0], *self.op_args, **self.op_kwargs)

    def label(self) -> str:
        """Label the node with its operation name."""
        return self.op_name
|
|
|
|
|
|
def default_dot_evaluation_function(lhs: Any, rhs: Any) -> Any:
    """Default python dot implementation for 1D iterable arrays.

    Multiplies the vectors element-wise and adds the products with ``sum`` (so two
    empty vectors give ``0``).

    Args:
        lhs (Any): lhs vector of the dot.
        rhs (Any): rhs vector of the dot.

    Returns:
        Any: the result of the dot operation.

    Raises:
        ValueError: if lhs and rhs do not have the same number of elements.
    """
    from itertools import zip_longest  # pylint: disable=import-outside-toplevel

    # Unique filler for the shorter vector: plain zip() would silently truncate and
    # return a wrong dot product when the lengths differ.
    missing = object()

    def _same_length_pairs():
        # Yield element pairs lazily so any iterable (not only sized arrays) is accepted.
        for lhs_item, rhs_item in zip_longest(lhs, rhs, fillvalue=missing):
            if lhs_item is missing or rhs_item is missing:
                raise ValueError("Dot requires two vectors with the same number of elements")
            yield lhs_item, rhs_item

    return sum(lhs_item * rhs_item for lhs_item, rhs_item in _same_length_pairs())
|
|
|
|
|
|
class Dot(IntermediateNode):
    """Node representing a dot product between two 1D tensors."""

    _n_in: int = 2

    # Function used by evaluate; allows to use specialized implementations from e.g. numpy.
    # Optional, same issue as in ArbitraryFunction for mypy
    evaluation_function: Optional[Callable[[Any, Any], Any]]

    def __init__(
        self,
        inputs: Iterable[BaseValue],
        output_dtype: BaseDataType,
        delegate_evaluation_function: Optional[
            Callable[[Any, Any], Any]
        ] = default_dot_evaluation_function,
    ) -> None:
        """Build the dot node.

        Args:
            inputs (Iterable[BaseValue]): the two operands, each a TensorValue with ndim == 1.
            output_dtype (BaseDataType): data type of the scalar output.
            delegate_evaluation_function (Optional[Callable[[Any, Any], Any]]): function
                computing the dot product in evaluate. None selects
                default_dot_evaluation_function.
        """
        super().__init__(inputs)
        assert len(self.inputs) == 2

        assert all(
            isinstance(input_value, TensorValue) and input_value.ndim == 1
            for input_value in self.inputs
        ), f"Dot only supports two vectors ({TensorValue.__name__} with ndim == 1)"

        # The result is encrypted as soon as one of the operands is.
        output_scalar_value = (
            EncryptedValue
            if (self.inputs[0].is_encrypted or self.inputs[1].is_encrypted)
            else ClearValue
        )

        self.outputs = [output_scalar_value(output_dtype)]

        # The signature allows None; storing it would only fail later, inside evaluate,
        # so fall back to the default implementation instead.
        self.evaluation_function = (
            delegate_evaluation_function
            if delegate_evaluation_function is not None
            else default_dot_evaluation_function
        )

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Compute the dot product of the inputs at index 0 and 1."""
        # This is the continuation of the mypy bug workaround
        assert self.evaluation_function is not None
        return self.evaluation_function(inputs[0], inputs[1])

    def label(self) -> str:
        """Label the node as ``dot``."""
        return "dot"
|