fix(ir): make is_equivalent_to abstract

- this allows to make sure each node has a proper implementation
- move op_args and op_kwargs in ArbitraryFunction only
- update BaseTracer accordingly
This commit is contained in:
Arthur Meyre
2021-08-09 10:09:05 +02:00
parent d09f1b90a6
commit 371cdd5e66
3 changed files with 153 additions and 23 deletions

View File

@@ -18,30 +18,20 @@ class IntermediateNode(ABC):
inputs: List[BaseValue]
outputs: List[BaseValue]
op_args: Tuple[Any, ...]
op_kwargs: Dict[str, Any]
def __init__(
self,
inputs: Iterable[BaseValue],
op_args: Optional[Tuple[Any, ...]] = None,
op_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
self.inputs = list(inputs)
assert all(isinstance(x, BaseValue) for x in self.inputs)
self.op_args = deepcopy(op_args) if op_args is not None else ()
self.op_kwargs = deepcopy(op_kwargs) if op_kwargs is not None else {}
def _init_binary(
self,
inputs: Iterable[BaseValue],
op_args: Optional[Tuple[Any, ...]] = None,
op_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
assert op_args is None, f"Expected op_args to be None, got {op_args}"
assert op_kwargs is None, f"Expected op_kwargs to be None, got {op_kwargs}"
IntermediateNode.__init__(self, inputs, op_args=op_args, op_kwargs=op_kwargs)
IntermediateNode.__init__(self, inputs)
assert len(self.inputs) == 2
@@ -61,6 +51,7 @@ class IntermediateNode(ABC):
and self.outputs == other.outputs
)
@abstractmethod
def is_equivalent_to(self, other: object) -> bool:
"""Overriding __eq__ has unwanted side effects, this provides the same facility without
disrupting expected behavior too much
@@ -72,11 +63,9 @@ class IntermediateNode(ABC):
bool: True if the other object is equivalent
"""
return (
isinstance(other, self.__class__)
isinstance(other, IntermediateNode)
and self.inputs == other.inputs
and self.outputs == other.outputs
and self.op_args == other.op_args
and self.op_kwargs == other.op_kwargs
)
@abstractmethod
@@ -142,6 +131,14 @@ class Input(IntermediateNode):
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
return inputs[0]
def is_equivalent_to(self, other: object) -> bool:
return (
isinstance(other, Input)
and self.input_name == other.input_name
and self.program_input_idx == other.program_input_idx
and super().is_equivalent_to(other)
)
class ConstantInput(IntermediateNode):
"""Node representing a constant of the program"""
@@ -169,6 +166,13 @@ class ConstantInput(IntermediateNode):
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
return self.constant_data
def is_equivalent_to(self, other: object) -> bool:
return (
isinstance(other, ConstantInput)
and self.constant_data == other.constant_data
and super().is_equivalent_to(other)
)
class ArbitraryFunction(IntermediateNode):
"""Node representing a univariate arbitrary function, e.g. sin(x)"""
@@ -176,6 +180,8 @@ class ArbitraryFunction(IntermediateNode):
# The arbitrary_func is not optional but mypy has a long standing bug and is not able to
# understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
arbitrary_func: Optional[Callable]
op_args: Tuple[Any, ...]
op_kwargs: Dict[str, Any]
def __init__(
self,
@@ -185,9 +191,11 @@ class ArbitraryFunction(IntermediateNode):
op_args: Optional[Tuple[Any, ...]] = None,
op_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__([input_base_value], op_args=op_args, op_kwargs=op_kwargs)
super().__init__([input_base_value])
assert len(self.inputs) == 1
self.arbitrary_func = arbitrary_func
self.op_args = deepcopy(op_args) if op_args is not None else ()
self.op_kwargs = deepcopy(op_kwargs) if op_kwargs is not None else {}
# TLU/PBS has an encrypted output
self.outputs = [EncryptedValue(output_dtype)]
@@ -195,3 +203,13 @@ class ArbitraryFunction(IntermediateNode):
# This is the continuation of the mypy bug workaround
assert self.arbitrary_func is not None
return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs)
def is_equivalent_to(self, other: object) -> bool:
# FIXME: comparing self.arbitrary_func to other.arbitrary_func will not work
# Only evaluating over the same set of inputs and comparing will help
return (
isinstance(other, ArbitraryFunction)
and self.op_args == other.op_args
and self.op_kwargs == other.op_kwargs
and super().is_equivalent_to(other)
)

View File

@@ -1,7 +1,7 @@
"""This file holds the code that can be shared between tracers"""
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import List, Tuple, Type, Union
from ..data_types import BaseValue
from ..data_types.scalars import Scalars
@@ -29,8 +29,6 @@ class BaseTracer(ABC):
self,
inputs: List[Union["BaseTracer", Scalars]],
computation_to_trace: Type[ir.IntermediateNode],
op_args: Optional[Tuple[Any, ...]] = None,
op_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple["BaseTracer", ...]:
"""Helper functions to instantiate all output BaseTracer for a given computation
@@ -38,9 +36,6 @@ class BaseTracer(ABC):
inputs (List[BaseTracer]): Previous BaseTracer used as inputs for a new node
computation_to_trace (Type[ir.IntermediateNode]): The IntermediateNode class
to instantiate for the computation being traced
op_args: *args coming from the call being traced
op_kwargs: **kwargs coming from the call being traced
Returns:
Tuple[BaseTracer, ...]: A tuple containing an BaseTracer per output function
@@ -56,8 +51,6 @@ class BaseTracer(ABC):
traced_computation = computation_to_trace(
(x.output for x in sanitized_inputs),
op_args=op_args,
op_kwargs=op_kwargs,
)
output_tracers = tuple(

View File

@@ -81,3 +81,122 @@ def test_evaluate(
):
"""Test evaluate methods on IntermediateNodes"""
assert node.evaluate(input_data) == expected_result
@pytest.mark.parametrize(
"node1,node2,expected_result",
[
(
ir.Add([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
ir.Add([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
True,
),
(
ir.Add([EncryptedValue(Integer(16, False)), EncryptedValue(Integer(32, False))]),
ir.Add([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(16, False))]),
True,
),
(
ir.Add([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
False,
),
(
ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
True,
),
(
ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(16, False))]),
ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(16, False))]),
True,
),
(
ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(16, False))]),
ir.Sub([EncryptedValue(Integer(16, False)), EncryptedValue(Integer(32, False))]),
False,
),
(
ir.Mul([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
ir.Mul([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
True,
),
(
ir.Mul([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
False,
),
(
ir.Input(EncryptedValue(Integer(32, False)), "x", 0),
ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]),
False,
),
(
ir.Input(EncryptedValue(Integer(32, False)), "x", 0),
ir.Input(EncryptedValue(Integer(32, False)), "x", 0),
True,
),
(
ir.Input(EncryptedValue(Integer(32, False)), "x", 0),
ir.Input(EncryptedValue(Integer(32, False)), "y", 0),
False,
),
(
ir.Input(EncryptedValue(Integer(32, False)), "x", 0),
ir.Input(EncryptedValue(Integer(32, False)), "x", 1),
False,
),
(
ir.Input(EncryptedValue(Integer(32, False)), "x", 0),
ir.Input(EncryptedValue(Integer(8, False)), "x", 0),
False,
),
(
ir.ConstantInput(10),
ir.ConstantInput(10),
True,
),
(
ir.ConstantInput(10),
ir.Input(EncryptedValue(Integer(8, False)), "x", 0),
False,
),
(
ir.ConstantInput(10),
ir.ConstantInput(10.0),
False,
),
(
ir.ArbitraryFunction(EncryptedValue(Integer(8, False)), lambda x: x, Integer(8, False)),
ir.ArbitraryFunction(EncryptedValue(Integer(8, False)), lambda x: x, Integer(8, False)),
True,
),
(
ir.ArbitraryFunction(
EncryptedValue(Integer(8, False)),
lambda x: x,
Integer(8, False),
op_args=(1, 2, 3),
),
ir.ArbitraryFunction(EncryptedValue(Integer(8, False)), lambda x: x, Integer(8, False)),
False,
),
(
ir.ArbitraryFunction(
EncryptedValue(Integer(8, False)),
lambda x: x,
Integer(8, False),
op_kwargs={"tuple": (1, 2, 3)},
),
ir.ArbitraryFunction(EncryptedValue(Integer(8, False)), lambda x: x, Integer(8, False)),
False,
),
],
)
def test_is_equivalent_to(
node1: ir.IntermediateNode,
node2: ir.IntermediateNode,
expected_result: bool,
):
"""Test is_equivalent_to methods on IntermediateNodes"""
assert node1.is_equivalent_to(node2) == node2.is_equivalent_to(node1) == expected_result