mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(ir): make is_equivalent_to abstract
- this allows to make sure each node has a proper implementation - move op_args and op_kwargs in ArbitraryFunction only - update BaseTracer accordingly
This commit is contained in:
@@ -18,30 +18,20 @@ class IntermediateNode(ABC):
|
||||
|
||||
inputs: List[BaseValue]
|
||||
outputs: List[BaseValue]
|
||||
op_args: Tuple[Any, ...]
|
||||
op_kwargs: Dict[str, Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs: Iterable[BaseValue],
|
||||
op_args: Optional[Tuple[Any, ...]] = None,
|
||||
op_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.inputs = list(inputs)
|
||||
assert all(isinstance(x, BaseValue) for x in self.inputs)
|
||||
self.op_args = deepcopy(op_args) if op_args is not None else ()
|
||||
self.op_kwargs = deepcopy(op_kwargs) if op_kwargs is not None else {}
|
||||
|
||||
def _init_binary(
|
||||
self,
|
||||
inputs: Iterable[BaseValue],
|
||||
op_args: Optional[Tuple[Any, ...]] = None,
|
||||
op_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
assert op_args is None, f"Expected op_args to be None, got {op_args}"
|
||||
assert op_kwargs is None, f"Expected op_kwargs to be None, got {op_kwargs}"
|
||||
|
||||
IntermediateNode.__init__(self, inputs, op_args=op_args, op_kwargs=op_kwargs)
|
||||
IntermediateNode.__init__(self, inputs)
|
||||
|
||||
assert len(self.inputs) == 2
|
||||
|
||||
@@ -61,6 +51,7 @@ class IntermediateNode(ABC):
|
||||
and self.outputs == other.outputs
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def is_equivalent_to(self, other: object) -> bool:
|
||||
"""Overriding __eq__ has unwanted side effects, this provides the same facility without
|
||||
disrupting expected behavior too much
|
||||
@@ -72,11 +63,9 @@ class IntermediateNode(ABC):
|
||||
bool: True if the other object is equivalent
|
||||
"""
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
isinstance(other, IntermediateNode)
|
||||
and self.inputs == other.inputs
|
||||
and self.outputs == other.outputs
|
||||
and self.op_args == other.op_args
|
||||
and self.op_kwargs == other.op_kwargs
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@@ -142,6 +131,14 @@ class Input(IntermediateNode):
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
return inputs[0]
|
||||
|
||||
def is_equivalent_to(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, Input)
|
||||
and self.input_name == other.input_name
|
||||
and self.program_input_idx == other.program_input_idx
|
||||
and super().is_equivalent_to(other)
|
||||
)
|
||||
|
||||
|
||||
class ConstantInput(IntermediateNode):
|
||||
"""Node representing a constant of the program"""
|
||||
@@ -169,6 +166,13 @@ class ConstantInput(IntermediateNode):
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
return self.constant_data
|
||||
|
||||
def is_equivalent_to(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, ConstantInput)
|
||||
and self.constant_data == other.constant_data
|
||||
and super().is_equivalent_to(other)
|
||||
)
|
||||
|
||||
|
||||
class ArbitraryFunction(IntermediateNode):
|
||||
"""Node representing a univariate arbitrary function, e.g. sin(x)"""
|
||||
@@ -176,6 +180,8 @@ class ArbitraryFunction(IntermediateNode):
|
||||
# The arbitrary_func is not optional but mypy has a long standing bug and is not able to
|
||||
# understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
|
||||
arbitrary_func: Optional[Callable]
|
||||
op_args: Tuple[Any, ...]
|
||||
op_kwargs: Dict[str, Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -185,9 +191,11 @@ class ArbitraryFunction(IntermediateNode):
|
||||
op_args: Optional[Tuple[Any, ...]] = None,
|
||||
op_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__([input_base_value], op_args=op_args, op_kwargs=op_kwargs)
|
||||
super().__init__([input_base_value])
|
||||
assert len(self.inputs) == 1
|
||||
self.arbitrary_func = arbitrary_func
|
||||
self.op_args = deepcopy(op_args) if op_args is not None else ()
|
||||
self.op_kwargs = deepcopy(op_kwargs) if op_kwargs is not None else {}
|
||||
# TLU/PBS has an encrypted output
|
||||
self.outputs = [EncryptedValue(output_dtype)]
|
||||
|
||||
@@ -195,3 +203,13 @@ class ArbitraryFunction(IntermediateNode):
|
||||
# This is the continuation of the mypy bug workaround
|
||||
assert self.arbitrary_func is not None
|
||||
return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs)
|
||||
|
||||
def is_equivalent_to(self, other: object) -> bool:
|
||||
# FIXME: comparing self.arbitrary_func to other.arbitrary_func will not work
|
||||
# Only evaluating over the same set of inputs and comparing will help
|
||||
return (
|
||||
isinstance(other, ArbitraryFunction)
|
||||
and self.op_args == other.op_args
|
||||
and self.op_kwargs == other.op_kwargs
|
||||
and super().is_equivalent_to(other)
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""This file holds the code that can be shared between tracers"""
|
||||
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import List, Tuple, Type, Union
|
||||
|
||||
from ..data_types import BaseValue
|
||||
from ..data_types.scalars import Scalars
|
||||
@@ -29,8 +29,6 @@ class BaseTracer(ABC):
|
||||
self,
|
||||
inputs: List[Union["BaseTracer", Scalars]],
|
||||
computation_to_trace: Type[ir.IntermediateNode],
|
||||
op_args: Optional[Tuple[Any, ...]] = None,
|
||||
op_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple["BaseTracer", ...]:
|
||||
"""Helper functions to instantiate all output BaseTracer for a given computation
|
||||
|
||||
@@ -38,9 +36,6 @@ class BaseTracer(ABC):
|
||||
inputs (List[BaseTracer]): Previous BaseTracer used as inputs for a new node
|
||||
computation_to_trace (Type[ir.IntermediateNode]): The IntermediateNode class
|
||||
to instantiate for the computation being traced
|
||||
op_args: *args coming from the call being traced
|
||||
op_kwargs: **kwargs coming from the call being traced
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[BaseTracer, ...]: A tuple containing an BaseTracer per output function
|
||||
@@ -56,8 +51,6 @@ class BaseTracer(ABC):
|
||||
|
||||
traced_computation = computation_to_trace(
|
||||
(x.output for x in sanitized_inputs),
|
||||
op_args=op_args,
|
||||
op_kwargs=op_kwargs,
|
||||
)
|
||||
|
||||
output_tracers = tuple(
|
||||
|
||||
@@ -81,3 +81,122 @@ def test_evaluate(
|
||||
):
|
||||
"""Test evaluate methods on IntermediateNodes"""
|
||||
assert node.evaluate(input_data) == expected_result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"node1,node2,expected_result",
|
||||
[
|
||||
(
|
||||
ir.Add([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
|
||||
ir.Add([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.Add([EncryptedValue(Integer(16, False)), EncryptedValue(Integer(32, False))]),
|
||||
ir.Add([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(16, False))]),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.Add([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
|
||||
ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
|
||||
ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(16, False))]),
|
||||
ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(16, False))]),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(16, False))]),
|
||||
ir.Sub([EncryptedValue(Integer(16, False)), EncryptedValue(Integer(32, False))]),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.Mul([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
|
||||
ir.Mul([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.Mul([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
|
||||
ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.Input(EncryptedValue(Integer(32, False)), "x", 0),
|
||||
ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.Input(EncryptedValue(Integer(32, False)), "x", 0),
|
||||
ir.Input(EncryptedValue(Integer(32, False)), "x", 0),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.Input(EncryptedValue(Integer(32, False)), "x", 0),
|
||||
ir.Input(EncryptedValue(Integer(32, False)), "y", 0),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.Input(EncryptedValue(Integer(32, False)), "x", 0),
|
||||
ir.Input(EncryptedValue(Integer(32, False)), "x", 1),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.Input(EncryptedValue(Integer(32, False)), "x", 0),
|
||||
ir.Input(EncryptedValue(Integer(8, False)), "x", 0),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.ConstantInput(10),
|
||||
ir.ConstantInput(10),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.ConstantInput(10),
|
||||
ir.Input(EncryptedValue(Integer(8, False)), "x", 0),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.ConstantInput(10),
|
||||
ir.ConstantInput(10.0),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.ArbitraryFunction(EncryptedValue(Integer(8, False)), lambda x: x, Integer(8, False)),
|
||||
ir.ArbitraryFunction(EncryptedValue(Integer(8, False)), lambda x: x, Integer(8, False)),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.ArbitraryFunction(
|
||||
EncryptedValue(Integer(8, False)),
|
||||
lambda x: x,
|
||||
Integer(8, False),
|
||||
op_args=(1, 2, 3),
|
||||
),
|
||||
ir.ArbitraryFunction(EncryptedValue(Integer(8, False)), lambda x: x, Integer(8, False)),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.ArbitraryFunction(
|
||||
EncryptedValue(Integer(8, False)),
|
||||
lambda x: x,
|
||||
Integer(8, False),
|
||||
op_kwargs={"tuple": (1, 2, 3)},
|
||||
),
|
||||
ir.ArbitraryFunction(EncryptedValue(Integer(8, False)), lambda x: x, Integer(8, False)),
|
||||
False,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_is_equivalent_to(
|
||||
node1: ir.IntermediateNode,
|
||||
node2: ir.IntermediateNode,
|
||||
expected_result: bool,
|
||||
):
|
||||
"""Test is_equivalent_to methods on IntermediateNodes"""
|
||||
assert node1.is_equivalent_to(node2) == node2.is_equivalent_to(node1) == expected_result
|
||||
|
||||
Reference in New Issue
Block a user