From bc145e21e148710895939f73ac1c2088d4d83f3d Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Wed, 17 Nov 2021 18:18:14 +0100 Subject: [PATCH] feat: allow tracing __eq__ while keeping hashing when not tracing closes #936 --- concrete/common/tracing/base_tracer.py | 28 +++++++++++++++ concrete/common/tracing/tracing_helpers.py | 21 ++++++++++- concrete/numpy/tracing.py | 5 ++- tests/numpy/test_tracing.py | 42 +++++++++++----------- 4 files changed, 72 insertions(+), 24 deletions(-) diff --git a/concrete/common/tracing/base_tracer.py b/concrete/common/tracing/base_tracer.py index 8a37b3530..6bfcd410f 100644 --- a/concrete/common/tracing/base_tracer.py +++ b/concrete/common/tracing/base_tracer.py @@ -23,6 +23,10 @@ from ..values import BaseValue, TensorValue class BaseTracer(ABC): """Base class for implementing tracers.""" + # this variable changes the behavior of __eq__ so that it can be traced but still allows to hash + # BaseTracers when not tracing. + _is_tracing: bool = False + inputs: List["BaseTracer"] traced_computation: IntermediateNode output_idx: int @@ -63,6 +67,15 @@ class BaseTracer(ABC): BaseTracer: The BaseTracer for that constant. """ + @classmethod + def set_is_tracing(cls, is_tracing: bool) -> None: + """Set whether we are in a tracing context to change __eq__ behavior. + + Args: + is_tracing (bool): boolean to use to set whether we are tracing + """ + cls._is_tracing = is_tracing + @classmethod def _get_mix_values_func(cls): return cls._mix_values_func @@ -193,6 +206,9 @@ class BaseTracer(ABC): return result_tracer + def __hash__(self) -> int: + return id(self) + def __add__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": if not self._supports_other_operand(other): return NotImplemented @@ -268,6 +284,18 @@ class BaseTracer(ABC): self, other, lambda x, y: x <= y, "le" ) + def __eq__(self, other: Union["BaseTracer", Any]): + # x == cst + # Return the tracer if we are tracing, else return the result of the default __eq__ function + # allows to have hash capabilities outside of tracing + return ( + self._helper_for_binary_functions_with_one_cst_input( + self, other, lambda x, y: x == y, "eq" + ) + if self._is_tracing + else self is other + ) + def __ne__(self, other: Union["BaseTracer", Any]): # x != cst return self._helper_for_binary_functions_with_one_cst_input( diff --git a/concrete/common/tracing/tracing_helpers.py b/concrete/common/tracing/tracing_helpers.py index 8d114ed35..85b5492a6 100644 --- a/concrete/common/tracing/tracing_helpers.py +++ b/concrete/common/tracing/tracing_helpers.py @@ -1,7 +1,8 @@ """Helper functions for tracing.""" import collections +from contextlib import contextmanager from inspect import signature -from typing import Callable, Dict, Iterable, OrderedDict, Set, Type +from typing import Callable, Dict, Iterable, List, OrderedDict, Set, Type import networkx as nx from networkx.algorithms.dag import is_directed_acyclic_graph @@ -141,3 +142,21 @@ def create_graph_from_output_tracers( assert_true(len(unique_edges) == number_of_edges) return graph + + +@contextmanager +def tracing_context(tracer_classes: List[Type[BaseTracer]]): + """Set tracer classes in tracing mode. + + Args: + tracer_classes (List[Type[BaseTracer]]): The list of tracers for which we should enable + tracing. + """ + + try: + for tracer_class in tracer_classes: + tracer_class.set_is_tracing(True) + yield + finally: + for tracer_class in tracer_classes: + tracer_class.set_is_tracing(False) diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index 697f07815..069e082e4 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -11,6 +11,7 @@ from ..common.debugging.custom_assert import assert_true from ..common.operator_graph import OPGraph from ..common.representation.intermediate import Constant, Dot, GenericFunction, MatMul from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters +from ..common.tracing.tracing_helpers import tracing_context from ..common.values import BaseValue, TensorValue from .np_dtypes_helpers import ( SUPPORTED_NUMPY_DTYPES_CLASS_TYPES, @@ -697,7 +698,9 @@ def trace_numpy_function( # We could easily create a graph of NPTracer, but we may end up with dead nodes starting from # the inputs that's why we create the graph starting from the outputs - output_tracers = function_to_trace(**input_tracers) + with tracing_context([NPTracer]): + output_tracers = function_to_trace(**input_tracers) + if isinstance(output_tracers, NPTracer): output_tracers = (output_tracers,) diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index a578c7906..df2daaa12 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -776,17 +776,16 @@ def test_tracing_numpy_calls( ) ], ), - # FIXME: coming soon, #936 - # pytest.param( - # lambda x: x == 11, - # [ - # ( - # EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - # numpy.arange(15).reshape(3, 5), - # numpy.arange(15).reshape(3, 5) == 11, - # ) - # ], - # ), + pytest.param( + lambda x: x == 11, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.arange(15).reshape(3, 5) == 11, + ) + ], + ), pytest.param( lambda x: x != 12, [ @@ -839,17 +838,16 @@ def test_tracing_numpy_calls( ) ], ), - # FIXME: coming soon, #936 - # pytest.param( - # lambda x: 11 == x, - # [ - # ( - # EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - # numpy.arange(15).reshape(3, 5), - # 11 == numpy.arange(15).reshape(3, 5), - # ) - # ], - # ), + pytest.param( + lambda x: 11 == x, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + 11 == numpy.arange(15).reshape(3, 5), + ) + ], + ), pytest.param( lambda x: 12 != x, [