feat: allow tracing __eq__ while keeping hashing when not tracing

closes #936
This commit is contained in:
Arthur Meyre
2021-11-17 18:18:14 +01:00
parent c978107124
commit bc145e21e1
4 changed files with 72 additions and 24 deletions

View File

@@ -23,6 +23,10 @@ from ..values import BaseValue, TensorValue
class BaseTracer(ABC):
"""Base class for implementing tracers."""
# this variable changes the behavior of __eq__ so that it can be traced but still allows to hash
# BaseTracers when not tracing.
_is_tracing: bool = False
inputs: List["BaseTracer"]
traced_computation: IntermediateNode
output_idx: int
@@ -63,6 +67,15 @@ class BaseTracer(ABC):
BaseTracer: The BaseTracer for that constant.
"""
@classmethod
def set_is_tracing(cls, is_tracing: bool) -> None:
"""Set whether we are in a tracing context to change __eq__ behavior.
Args:
is_tracing (bool): boolean to use to set whether we are tracing
"""
cls._is_tracing = is_tracing
@classmethod
def _get_mix_values_func(cls):
return cls._mix_values_func
@@ -193,6 +206,9 @@ class BaseTracer(ABC):
return result_tracer
def __hash__(self) -> int:
return id(self)
def __add__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
if not self._supports_other_operand(other):
return NotImplemented
@@ -268,6 +284,18 @@ class BaseTracer(ABC):
self, other, lambda x, y: x <= y, "le"
)
def __eq__(self, other: Union["BaseTracer", Any]):
# x == cst
# Return the tracer if we are tracing, else return the result of the default __eq__ function
# allows to have hash capabilities outside of tracing
return (
self._helper_for_binary_functions_with_one_cst_input(
self, other, lambda x, y: x == y, "eq"
)
if self._is_tracing
else self is other
)
def __ne__(self, other: Union["BaseTracer", Any]):
# x != cst
return self._helper_for_binary_functions_with_one_cst_input(

View File

@@ -1,7 +1,8 @@
"""Helper functions for tracing."""
import collections
from contextlib import contextmanager
from inspect import signature
from typing import Callable, Dict, Iterable, OrderedDict, Set, Type
from typing import Callable, Dict, Iterable, List, OrderedDict, Set, Type
import networkx as nx
from networkx.algorithms.dag import is_directed_acyclic_graph
@@ -141,3 +142,21 @@ def create_graph_from_output_tracers(
assert_true(len(unique_edges) == number_of_edges)
return graph
@contextmanager
def tracing_context(tracer_classes: List[Type[BaseTracer]]):
"""Set tracer classes in tracing mode.
Args:
tracer_classes (List[Type[BaseTracer]]): The list of tracers for which we should enable
tracing.
"""
try:
for tracer_class in tracer_classes:
tracer_class.set_is_tracing(True)
yield
finally:
for tracer_class in tracer_classes:
tracer_class.set_is_tracing(False)

View File

@@ -11,6 +11,7 @@ from ..common.debugging.custom_assert import assert_true
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import Constant, Dot, GenericFunction, MatMul
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
from ..common.tracing.tracing_helpers import tracing_context
from ..common.values import BaseValue, TensorValue
from .np_dtypes_helpers import (
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
@@ -697,7 +698,9 @@ def trace_numpy_function(
# We could easily create a graph of NPTracer, but we may end up with dead nodes starting from
# the inputs that's why we create the graph starting from the outputs
output_tracers = function_to_trace(**input_tracers)
with tracing_context([NPTracer]):
output_tracers = function_to_trace(**input_tracers)
if isinstance(output_tracers, NPTracer):
output_tracers = (output_tracers,)

View File

@@ -776,17 +776,16 @@ def test_tracing_numpy_calls(
)
],
),
# FIXME: coming soon, #936
# pytest.param(
# lambda x: x == 11,
# [
# (
# EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
# numpy.arange(15).reshape(3, 5),
# numpy.arange(15).reshape(3, 5) == 11,
# )
# ],
# ),
pytest.param(
lambda x: x == 11,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
numpy.arange(15).reshape(3, 5) == 11,
)
],
),
pytest.param(
lambda x: x != 12,
[
@@ -839,17 +838,16 @@ def test_tracing_numpy_calls(
)
],
),
# FIXME: coming soon, #936
# pytest.param(
# lambda x: 11 == x,
# [
# (
# EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
# numpy.arange(15).reshape(3, 5),
# 11 == numpy.arange(15).reshape(3, 5),
# )
# ],
# ),
pytest.param(
lambda x: 11 == x,
[
(
EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
numpy.arange(15).reshape(3, 5),
11 == numpy.arange(15).reshape(3, 5),
)
],
),
pytest.param(
lambda x: 12 != x,
[