mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor: change IR and BaseTracer to support any input mixing function
- prepares support for tensor types - introduce _n_in and n_in() and requires_mix_values_func() for IntermediateNode to know if they require a function to mix input values and determine output value - update BaseTracer to pass the class _mix_values_func to IntermediateNodes that need it - update NPTracer to stay consistent with current behavior regarding values mixing
This commit is contained in:
@@ -11,16 +11,20 @@ from ..data_types.dtypes_helpers import (
|
||||
mix_scalar_values_determine_holding_dtype,
|
||||
)
|
||||
|
||||
IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func"
|
||||
|
||||
|
||||
class IntermediateNode(ABC):
|
||||
"""Abstract Base Class to derive from to represent source program operations."""
|
||||
|
||||
inputs: List[BaseValue]
|
||||
outputs: List[BaseValue]
|
||||
_n_in: int # _n_in indicates how many inputs are required to evaluate the IntermediateNode
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs: Iterable[BaseValue],
|
||||
**_kwargs, # This is to be able to feed arbitrary arguments to IntermediateNodes
|
||||
) -> None:
|
||||
self.inputs = list(inputs)
|
||||
assert all(isinstance(x, BaseValue) for x in self.inputs)
|
||||
@@ -28,13 +32,15 @@ class IntermediateNode(ABC):
|
||||
def _init_binary(
|
||||
self,
|
||||
inputs: Iterable[BaseValue],
|
||||
mix_values_func: Callable[..., BaseValue] = mix_scalar_values_determine_holding_dtype,
|
||||
**_kwargs, # Required to conform to __init__ typing
|
||||
) -> None:
|
||||
"""__init__ for a binary operation, ie two inputs."""
|
||||
IntermediateNode.__init__(self, inputs)
|
||||
|
||||
assert len(self.inputs) == 2
|
||||
|
||||
self.outputs = [mix_scalar_values_determine_holding_dtype(self.inputs[0], self.inputs[1])]
|
||||
self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])]
|
||||
|
||||
def _is_equivalent_to_binary_commutative(self, other: object) -> bool:
|
||||
"""is_equivalent_to for a binary and commutative operation."""
|
||||
@@ -82,10 +88,30 @@ class IntermediateNode(ABC):
|
||||
Any: the result of the computation
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def n_in(cls) -> int:
|
||||
"""Returns how many inputs the node has.
|
||||
|
||||
Returns:
|
||||
int: The number of inputs of the node.
|
||||
"""
|
||||
return cls._n_in
|
||||
|
||||
@classmethod
|
||||
def requires_mix_values_func(cls) -> bool:
|
||||
"""Function to determine whether the Class requires a mix_values_func to be built.
|
||||
|
||||
Returns:
|
||||
bool: True if __init__ expects a mix_values_func argument.
|
||||
"""
|
||||
return cls.n_in() > 1
|
||||
|
||||
|
||||
class Add(IntermediateNode):
|
||||
"""Addition between two values."""
|
||||
|
||||
_n_in: int = 2
|
||||
|
||||
__init__ = IntermediateNode._init_binary
|
||||
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative
|
||||
|
||||
@@ -96,6 +122,8 @@ class Add(IntermediateNode):
|
||||
class Sub(IntermediateNode):
|
||||
"""Subtraction between two values."""
|
||||
|
||||
_n_in: int = 2
|
||||
|
||||
__init__ = IntermediateNode._init_binary
|
||||
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_non_commutative
|
||||
|
||||
@@ -106,6 +134,8 @@ class Sub(IntermediateNode):
|
||||
class Mul(IntermediateNode):
|
||||
"""Multiplication between two values."""
|
||||
|
||||
_n_in: int = 2
|
||||
|
||||
__init__ = IntermediateNode._init_binary
|
||||
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative
|
||||
|
||||
@@ -118,6 +148,7 @@ class Input(IntermediateNode):
|
||||
|
||||
input_name: str
|
||||
program_input_idx: int
|
||||
_n_in: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -147,6 +178,7 @@ class ConstantInput(IntermediateNode):
|
||||
"""Node representing a constant of the program."""
|
||||
|
||||
_constant_data: Any
|
||||
_n_in: int = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -191,6 +223,7 @@ class ArbitraryFunction(IntermediateNode):
|
||||
op_args: Tuple[Any, ...]
|
||||
op_kwargs: Dict[str, Any]
|
||||
op_name: str
|
||||
_n_in: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""This file holds the code that can be shared between tracers."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Iterable, List, Tuple, Type, Union
|
||||
from typing import Any, Callable, Iterable, List, Tuple, Type, Union
|
||||
|
||||
from ..data_types import BaseValue
|
||||
from ..representation import intermediate as ir
|
||||
from ..representation.intermediate import IR_MIX_VALUES_FUNC_ARG_NAME
|
||||
|
||||
|
||||
class BaseTracer(ABC):
|
||||
@@ -13,6 +14,7 @@ class BaseTracer(ABC):
|
||||
inputs: List["BaseTracer"]
|
||||
traced_computation: ir.IntermediateNode
|
||||
output: BaseValue
|
||||
_mix_values_func: Callable[..., BaseValue]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -47,6 +49,10 @@ class BaseTracer(ABC):
|
||||
BaseTracer: The BaseTracer for that constant.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _get_mix_values_func(cls):
|
||||
return cls._mix_values_func
|
||||
|
||||
def instantiate_output_tracers(
|
||||
self,
|
||||
inputs: Iterable[Union["BaseTracer", Any]],
|
||||
@@ -71,8 +77,15 @@ class BaseTracer(ABC):
|
||||
|
||||
sanitized_inputs = [sanitize(inp) for inp in inputs]
|
||||
|
||||
additional_parameters = (
|
||||
{IR_MIX_VALUES_FUNC_ARG_NAME: self._get_mix_values_func()}
|
||||
if computation_to_trace.requires_mix_values_func()
|
||||
else {}
|
||||
)
|
||||
|
||||
traced_computation = computation_to_trace(
|
||||
(x.output for x in sanitized_inputs),
|
||||
**additional_parameters,
|
||||
)
|
||||
|
||||
output_tracers = tuple(
|
||||
|
||||
@@ -7,6 +7,7 @@ import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
from ..common.data_types import BaseValue
|
||||
from ..common.data_types.dtypes_helpers import mix_scalar_values_determine_holding_dtype
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation.intermediate import ArbitraryFunction, ConstantInput
|
||||
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
|
||||
@@ -30,6 +31,8 @@ NPConstantInput = partial(
|
||||
class NPTracer(BaseTracer):
|
||||
"""Tracer class for numpy operations."""
|
||||
|
||||
_mix_values_func: Callable[..., BaseValue] = mix_scalar_values_determine_holding_dtype
|
||||
|
||||
def __array_ufunc__(self, ufunc, method, *input_tracers, **kwargs):
|
||||
"""Catch calls to numpy ufunc and routes them to tracing functions if supported.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user