mirror of https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: allow tracing __eq__ while keeping hashing when not tracing
closes #936
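In a nutshell, a minimal standalone sketch of the mechanism (illustrative names, not the actual concrete API): a class-level _is_tracing flag makes __eq__ build a traced operation while tracing and fall back to identity comparison otherwise, while an identity-based __hash__ keeps tracers hashable in both modes.

    class Tracer:
        _is_tracing = False  # class-level switch; the real code toggles it via set_is_tracing

        def __init__(self, label):
            self.label = label

        def __hash__(self):
            return id(self)  # identity-based, so tracers stay usable as dict/set keys

        def __eq__(self, other):
            if self._is_tracing:
                # while tracing, "==" yields a new tracer standing for the comparison
                return Tracer(f"({self.label} == {other})")
            # outside tracing, behave like the default __eq__ (identity)
            return self is other

    x = Tracer("x")
    assert x in {x}              # hashing and identity equality work while not tracing
    assert (x == 11) is False

    Tracer._is_tracing = True    # the commit adds a context manager to do this safely
    node = x == 11
    Tracer._is_tracing = False
    assert isinstance(node, Tracer) and node.label == "(x == 11)"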
@@ -23,6 +23,10 @@ from ..values import BaseValue, TensorValue
 class BaseTracer(ABC):
     """Base class for implementing tracers."""

+    # This variable changes the behavior of __eq__ so that it can be traced but still allows
+    # hashing BaseTracers when not tracing.
+    _is_tracing: bool = False
+
     inputs: List["BaseTracer"]
     traced_computation: IntermediateNode
     output_idx: int
@@ -63,6 +67,15 @@ class BaseTracer(ABC):
             BaseTracer: The BaseTracer for that constant.
         """

+    @classmethod
+    def set_is_tracing(cls, is_tracing: bool) -> None:
+        """Set whether we are in a tracing context to change __eq__ behavior.
+
+        Args:
+            is_tracing (bool): boolean to use to set whether we are tracing
+        """
+        cls._is_tracing = is_tracing
+
     @classmethod
     def _get_mix_values_func(cls):
         return cls._mix_values_func
@@ -193,6 +206,9 @@ class BaseTracer(ABC):

         return result_tracer

+    def __hash__(self) -> int:
+        return id(self)
+
     def __add__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
         if not self._supports_other_operand(other):
             return NotImplemented
@@ -268,6 +284,18 @@ class BaseTracer(ABC):
             self, other, lambda x, y: x <= y, "le"
         )

+    def __eq__(self, other: Union["BaseTracer", Any]):
+        # x == cst
+        # Return the tracer if we are tracing, else return the result of the default __eq__,
+        # which allows hashing outside of tracing.
+        return (
+            self._helper_for_binary_functions_with_one_cst_input(
+                self, other, lambda x, y: x == y, "eq"
+            )
+            if self._is_tracing
+            else self is other
+        )
+
     def __ne__(self, other: Union["BaseTracer", Any]):
         # x != cst
         return self._helper_for_binary_functions_with_one_cst_input(
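Why the diff also adds an explicit __hash__ just above: in Python, a class that defines __eq__ without also defining __hash__ gets __hash__ set to None, so its instances become unhashable; returning id(self) restores identity-based hashing. A standalone illustration (toy classes, not taken from the codebase):

    class DefinesOnlyEq:
        def __eq__(self, other):
            return self is other
        # __hash__ is implicitly set to None because __eq__ is defined

    class DefinesEqAndHash:
        def __eq__(self, other):
            return self is other

        def __hash__(self):
            return id(self)  # identity-based, like object.__hash__

    try:
        {DefinesOnlyEq()}             # building a set needs hashing
    except TypeError as exc:
        print(exc)                    # unhashable type: 'DefinesOnlyEq'

    print(len({DefinesEqAndHash()}))  # 1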
@@ -1,7 +1,8 @@
 """Helper functions for tracing."""
 import collections
+from contextlib import contextmanager
 from inspect import signature
-from typing import Callable, Dict, Iterable, OrderedDict, Set, Type
+from typing import Callable, Dict, Iterable, List, OrderedDict, Set, Type

 import networkx as nx
 from networkx.algorithms.dag import is_directed_acyclic_graph
@@ -141,3 +142,21 @@ def create_graph_from_output_tracers(
     assert_true(len(unique_edges) == number_of_edges)

     return graph
+
+
+@contextmanager
+def tracing_context(tracer_classes: List[Type[BaseTracer]]):
+    """Set tracer classes in tracing mode.
+
+    Args:
+        tracer_classes (List[Type[BaseTracer]]): The list of tracers for which we should enable
+            tracing.
+    """
+
+    try:
+        for tracer_class in tracer_classes:
+            tracer_class.set_is_tracing(True)
+        yield
+    finally:
+        for tracer_class in tracer_classes:
+            tracer_class.set_is_tracing(False)
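A quick usage sketch of the helper above, showing what the try/finally buys: the tracing flag is cleared even if the traced function raises. FakeTracer here is a stand-in with the same set_is_tracing API, not one of the real tracer classes.

    from contextlib import contextmanager

    class FakeTracer:
        _is_tracing = False

        @classmethod
        def set_is_tracing(cls, is_tracing):
            cls._is_tracing = is_tracing

    @contextmanager
    def tracing_context(tracer_classes):
        # same shape as the helper added in this hunk
        try:
            for tracer_class in tracer_classes:
                tracer_class.set_is_tracing(True)
            yield
        finally:
            for tracer_class in tracer_classes:
                tracer_class.set_is_tracing(False)

    try:
        with tracing_context([FakeTracer]):
            assert FakeTracer._is_tracing
            raise RuntimeError("tracing blew up")
    except RuntimeError:
        pass

    assert not FakeTracer._is_tracing  # reset despite the exception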
@@ -11,6 +11,7 @@ from ..common.debugging.custom_assert import assert_true
 from ..common.operator_graph import OPGraph
 from ..common.representation.intermediate import Constant, Dot, GenericFunction, MatMul
 from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
+from ..common.tracing.tracing_helpers import tracing_context
 from ..common.values import BaseValue, TensorValue
 from .np_dtypes_helpers import (
     SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
@@ -697,7 +698,9 @@ def trace_numpy_function(

     # We could easily create a graph of NPTracer, but we may end up with dead nodes starting from
     # the inputs that's why we create the graph starting from the outputs
-    output_tracers = function_to_trace(**input_tracers)
+    with tracing_context([NPTracer]):
+        output_tracers = function_to_trace(**input_tracers)

     if isinstance(output_tracers, NPTracer):
         output_tracers = (output_tracers,)
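Note on scope: tracing is enabled only around the call to function_to_trace, so the NPTracer outputs returned from it go back to identity-based __eq__ immediately afterwards; that is presumably what keeps them usable as hashable graph nodes when create_graph_from_output_tracers later builds the networkx graph.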
@@ -776,17 +776,16 @@ def test_tracing_numpy_calls(
                 )
             ],
         ),
-        # FIXME: coming soon, #936
-        # pytest.param(
-        #     lambda x: x == 11,
-        #     [
-        #         (
-        #             EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
-        #             numpy.arange(15).reshape(3, 5),
-        #             numpy.arange(15).reshape(3, 5) == 11,
-        #         )
-        #     ],
-        # ),
+        pytest.param(
+            lambda x: x == 11,
+            [
+                (
+                    EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
+                    numpy.arange(15).reshape(3, 5),
+                    numpy.arange(15).reshape(3, 5) == 11,
+                )
+            ],
+        ),
         pytest.param(
             lambda x: x != 12,
             [
@@ -839,17 +838,16 @@ def test_tracing_numpy_calls(
                 )
             ],
         ),
-        # FIXME: coming soon, #936
-        # pytest.param(
-        #     lambda x: 11 == x,
-        #     [
-        #         (
-        #             EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
-        #             numpy.arange(15).reshape(3, 5),
-        #             11 == numpy.arange(15).reshape(3, 5),
-        #         )
-        #     ],
-        # ),
+        pytest.param(
+            lambda x: 11 == x,
+            [
+                (
+                    EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
+                    numpy.arange(15).reshape(3, 5),
+                    11 == numpy.arange(15).reshape(3, 5),
+                )
+            ],
+        ),
         pytest.param(
             lambda x: 12 != x,
             [