From 371cdd5e6698b94318d7c32756dfcee2699de413 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 9 Aug 2021 10:09:05 +0200 Subject: [PATCH] fix(ir): make is_equivalent_to abstract - this allows to make sure each node has a proper implementation - move op_args and op_kwargs in ArbitraryFunction only - update BaseTracer accordingly --- hdk/common/representation/intermediate.py | 48 ++++--- hdk/common/tracing/base_tracer.py | 9 +- .../representation/test_intermediate.py | 119 ++++++++++++++++++ 3 files changed, 153 insertions(+), 23 deletions(-) diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index 11fdf1a3d..eef62fa8b 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -18,30 +18,20 @@ class IntermediateNode(ABC): inputs: List[BaseValue] outputs: List[BaseValue] - op_args: Tuple[Any, ...] - op_kwargs: Dict[str, Any] def __init__( self, inputs: Iterable[BaseValue], - op_args: Optional[Tuple[Any, ...]] = None, - op_kwargs: Optional[Dict[str, Any]] = None, ) -> None: self.inputs = list(inputs) assert all(isinstance(x, BaseValue) for x in self.inputs) - self.op_args = deepcopy(op_args) if op_args is not None else () - self.op_kwargs = deepcopy(op_kwargs) if op_kwargs is not None else {} def _init_binary( self, inputs: Iterable[BaseValue], - op_args: Optional[Tuple[Any, ...]] = None, - op_kwargs: Optional[Dict[str, Any]] = None, ) -> None: - assert op_args is None, f"Expected op_args to be None, got {op_args}" - assert op_kwargs is None, f"Expected op_kwargs to be None, got {op_kwargs}" - IntermediateNode.__init__(self, inputs, op_args=op_args, op_kwargs=op_kwargs) + IntermediateNode.__init__(self, inputs) assert len(self.inputs) == 2 @@ -61,6 +51,7 @@ class IntermediateNode(ABC): and self.outputs == other.outputs ) + @abstractmethod def is_equivalent_to(self, other: object) -> bool: """Overriding __eq__ has unwanted side effects, this provides the same facility without disrupting expected behavior too much @@ -72,11 +63,9 @@ class IntermediateNode(ABC): bool: True if the other object is equivalent """ return ( - isinstance(other, self.__class__) + isinstance(other, IntermediateNode) and self.inputs == other.inputs and self.outputs == other.outputs - and self.op_args == other.op_args - and self.op_kwargs == other.op_kwargs ) @abstractmethod @@ -142,6 +131,14 @@ class Input(IntermediateNode): def evaluate(self, inputs: Mapping[int, Any]) -> Any: return inputs[0] + def is_equivalent_to(self, other: object) -> bool: + return ( + isinstance(other, Input) + and self.input_name == other.input_name + and self.program_input_idx == other.program_input_idx + and super().is_equivalent_to(other) + ) + class ConstantInput(IntermediateNode): """Node representing a constant of the program""" @@ -169,6 +166,13 @@ class ConstantInput(IntermediateNode): def evaluate(self, inputs: Mapping[int, Any]) -> Any: return self.constant_data + def is_equivalent_to(self, other: object) -> bool: + return ( + isinstance(other, ConstantInput) + and self.constant_data == other.constant_data + and super().is_equivalent_to(other) + ) + class ArbitraryFunction(IntermediateNode): """Node representing a univariate arbitrary function, e.g. sin(x)""" @@ -176,6 +180,8 @@ class ArbitraryFunction(IntermediateNode): # The arbitrary_func is not optional but mypy has a long standing bug and is not able to # understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623 arbitrary_func: Optional[Callable] + op_args: Tuple[Any, ...] + op_kwargs: Dict[str, Any] def __init__( self, @@ -185,9 +191,11 @@ class ArbitraryFunction(IntermediateNode): op_args: Optional[Tuple[Any, ...]] = None, op_kwargs: Optional[Dict[str, Any]] = None, ) -> None: - super().__init__([input_base_value], op_args=op_args, op_kwargs=op_kwargs) + super().__init__([input_base_value]) assert len(self.inputs) == 1 self.arbitrary_func = arbitrary_func + self.op_args = deepcopy(op_args) if op_args is not None else () + self.op_kwargs = deepcopy(op_kwargs) if op_kwargs is not None else {} # TLU/PBS has an encrypted output self.outputs = [EncryptedValue(output_dtype)] @@ -195,3 +203,13 @@ class ArbitraryFunction(IntermediateNode): # This is the continuation of the mypy bug workaround assert self.arbitrary_func is not None return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs) + + def is_equivalent_to(self, other: object) -> bool: + # FIXME: comparing self.arbitrary_func to other.arbitrary_func will not work + # Only evaluating over the same set of inputs and comparing will help + return ( + isinstance(other, ArbitraryFunction) + and self.op_args == other.op_args + and self.op_kwargs == other.op_kwargs + and super().is_equivalent_to(other) + ) diff --git a/hdk/common/tracing/base_tracer.py b/hdk/common/tracing/base_tracer.py index 188f30e16..5d20cfb81 100644 --- a/hdk/common/tracing/base_tracer.py +++ b/hdk/common/tracing/base_tracer.py @@ -1,7 +1,7 @@ """This file holds the code that can be shared between tracers""" from abc import ABC -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import List, Tuple, Type, Union from ..data_types import BaseValue from ..data_types.scalars import Scalars @@ -29,8 +29,6 @@ class BaseTracer(ABC): self, inputs: List[Union["BaseTracer", Scalars]], computation_to_trace: Type[ir.IntermediateNode], - op_args: Optional[Tuple[Any, ...]] = None, - op_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple["BaseTracer", ...]: """Helper functions to instantiate all output BaseTracer for a given computation @@ -38,9 +36,6 @@ class BaseTracer(ABC): inputs (List[BaseTracer]): Previous BaseTracer used as inputs for a new node computation_to_trace (Type[ir.IntermediateNode]): The IntermediateNode class to instantiate for the computation being traced - op_args: *args coming from the call being traced - op_kwargs: **kwargs coming from the call being traced - Returns: Tuple[BaseTracer, ...]: A tuple containing an BaseTracer per output function @@ -56,8 +51,6 @@ class BaseTracer(ABC): traced_computation = computation_to_trace( (x.output for x in sanitized_inputs), - op_args=op_args, - op_kwargs=op_kwargs, ) output_tracers = tuple( diff --git a/tests/common/representation/test_intermediate.py b/tests/common/representation/test_intermediate.py index 530742be2..86c94db65 100644 --- a/tests/common/representation/test_intermediate.py +++ b/tests/common/representation/test_intermediate.py @@ -81,3 +81,122 @@ def test_evaluate( ): """Test evaluate methods on IntermediateNodes""" assert node.evaluate(input_data) == expected_result + + +@pytest.mark.parametrize( + "node1,node2,expected_result", + [ + ( + ir.Add([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]), + ir.Add([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]), + True, + ), + ( + ir.Add([EncryptedValue(Integer(16, False)), EncryptedValue(Integer(32, False))]), + ir.Add([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(16, False))]), + True, + ), + ( + ir.Add([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]), + ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]), + False, + ), + ( + ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]), + ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]), + True, + ), + ( + ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(16, False))]), + ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(16, False))]), + True, + ), + ( + ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(16, False))]), + ir.Sub([EncryptedValue(Integer(16, False)), EncryptedValue(Integer(32, False))]), + False, + ), + ( + ir.Mul([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]), + ir.Mul([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]), + True, + ), + ( + ir.Mul([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]), + ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]), + False, + ), + ( + ir.Input(EncryptedValue(Integer(32, False)), "x", 0), + ir.Sub([EncryptedValue(Integer(32, False)), EncryptedValue(Integer(32, False))]), + False, + ), + ( + ir.Input(EncryptedValue(Integer(32, False)), "x", 0), + ir.Input(EncryptedValue(Integer(32, False)), "x", 0), + True, + ), + ( + ir.Input(EncryptedValue(Integer(32, False)), "x", 0), + ir.Input(EncryptedValue(Integer(32, False)), "y", 0), + False, + ), + ( + ir.Input(EncryptedValue(Integer(32, False)), "x", 0), + ir.Input(EncryptedValue(Integer(32, False)), "x", 1), + False, + ), + ( + ir.Input(EncryptedValue(Integer(32, False)), "x", 0), + ir.Input(EncryptedValue(Integer(8, False)), "x", 0), + False, + ), + ( + ir.ConstantInput(10), + ir.ConstantInput(10), + True, + ), + ( + ir.ConstantInput(10), + ir.Input(EncryptedValue(Integer(8, False)), "x", 0), + False, + ), + ( + ir.ConstantInput(10), + ir.ConstantInput(10.0), + False, + ), + ( + ir.ArbitraryFunction(EncryptedValue(Integer(8, False)), lambda x: x, Integer(8, False)), + ir.ArbitraryFunction(EncryptedValue(Integer(8, False)), lambda x: x, Integer(8, False)), + True, + ), + ( + ir.ArbitraryFunction( + EncryptedValue(Integer(8, False)), + lambda x: x, + Integer(8, False), + op_args=(1, 2, 3), + ), + ir.ArbitraryFunction(EncryptedValue(Integer(8, False)), lambda x: x, Integer(8, False)), + False, + ), + ( + ir.ArbitraryFunction( + EncryptedValue(Integer(8, False)), + lambda x: x, + Integer(8, False), + op_kwargs={"tuple": (1, 2, 3)}, + ), + ir.ArbitraryFunction(EncryptedValue(Integer(8, False)), lambda x: x, Integer(8, False)), + False, + ), + ], +) +def test_is_equivalent_to( + node1: ir.IntermediateNode, + node2: ir.IntermediateNode, + expected_result: bool, +): + """Test is_equivalent_to methods on IntermediateNodes""" + assert node1.is_equivalent_to(node2) == node2.is_equivalent_to(node1) == expected_result