mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor: move is_equivalent to conftest.py for tests
This commit is contained in:
@@ -49,41 +49,6 @@ class IntermediateNode(ABC):
|
||||
|
||||
self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])]
|
||||
|
||||
def _is_equivalent_to_binary_commutative(self, other: object) -> bool:
|
||||
"""is_equivalent_to for a binary and commutative operation."""
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and (self.inputs == other.inputs or self.inputs == other.inputs[::-1])
|
||||
and self.outputs == other.outputs
|
||||
)
|
||||
|
||||
def _is_equivalent_to_binary_non_commutative(self, other: object) -> bool:
|
||||
"""is_equivalent_to for a binary and non-commutative operation."""
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self.inputs == other.inputs
|
||||
and self.outputs == other.outputs
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def is_equivalent_to(self, other: object) -> bool:
|
||||
"""Alternative to __eq__ to check equivalence between IntermediateNodes.
|
||||
|
||||
Overriding __eq__ has unwanted side effects, this provides the same facility without
|
||||
disrupting expected behavior too much
|
||||
|
||||
Args:
|
||||
other (object): Other object to check against
|
||||
|
||||
Returns:
|
||||
bool: True if the other object is equivalent
|
||||
"""
|
||||
return (
|
||||
isinstance(other, IntermediateNode)
|
||||
and self.inputs == other.inputs
|
||||
and self.outputs == other.outputs
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
"""Function to simulate what the represented computation would output for the given inputs.
|
||||
@@ -129,7 +94,6 @@ class Add(IntermediateNode):
|
||||
_n_in: int = 2
|
||||
|
||||
__init__ = IntermediateNode._init_binary
|
||||
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] + inputs[1]
|
||||
@@ -144,7 +108,6 @@ class Sub(IntermediateNode):
|
||||
_n_in: int = 2
|
||||
|
||||
__init__ = IntermediateNode._init_binary
|
||||
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_non_commutative
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] - inputs[1]
|
||||
@@ -159,7 +122,6 @@ class Mul(IntermediateNode):
|
||||
_n_in: int = 2
|
||||
|
||||
__init__ = IntermediateNode._init_binary
|
||||
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] * inputs[1]
|
||||
@@ -190,14 +152,6 @@ class Input(IntermediateNode):
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0]
|
||||
|
||||
def is_equivalent_to(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, Input)
|
||||
and self.input_name == other.input_name
|
||||
and self.program_input_idx == other.program_input_idx
|
||||
and super().is_equivalent_to(other)
|
||||
)
|
||||
|
||||
def label(self) -> str:
|
||||
return self.input_name
|
||||
|
||||
@@ -225,13 +179,6 @@ class Constant(IntermediateNode):
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return self.constant_data
|
||||
|
||||
def is_equivalent_to(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, Constant)
|
||||
and self.constant_data == other.constant_data
|
||||
and super().is_equivalent_to(other)
|
||||
)
|
||||
|
||||
@property
|
||||
def constant_data(self) -> Any:
|
||||
"""Returns the constant_data stored in the Constant node.
|
||||
@@ -278,17 +225,6 @@ class ArbitraryFunction(IntermediateNode):
|
||||
assert self.arbitrary_func is not None
|
||||
return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs)
|
||||
|
||||
def is_equivalent_to(self, other: object) -> bool:
|
||||
# FIXME: comparing self.arbitrary_func to other.arbitrary_func will not work
|
||||
# Only evaluating over the same set of inputs and comparing will help
|
||||
return (
|
||||
isinstance(other, ArbitraryFunction)
|
||||
and self.op_args == other.op_args
|
||||
and self.op_kwargs == other.op_kwargs
|
||||
and self.op_name == other.op_name
|
||||
and super().is_equivalent_to(other)
|
||||
)
|
||||
|
||||
def label(self) -> str:
|
||||
return self.op_name
|
||||
|
||||
@@ -344,12 +280,5 @@ class Dot(IntermediateNode):
|
||||
assert self.evaluation_function is not None
|
||||
return self.evaluation_function(inputs[0], inputs[1])
|
||||
|
||||
def is_equivalent_to(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self.evaluation_function == other.evaluation_function
|
||||
and super().is_equivalent_to(other)
|
||||
)
|
||||
|
||||
def label(self) -> str:
|
||||
return "dot"
|
||||
|
||||
@@ -276,6 +276,11 @@ def test_is_equivalent_to(
|
||||
node1: ir.IntermediateNode,
|
||||
node2: ir.IntermediateNode,
|
||||
expected_result: bool,
|
||||
test_helpers,
|
||||
):
|
||||
"""Test is_equivalent_to methods on IntermediateNodes"""
|
||||
assert node1.is_equivalent_to(node2) == node2.is_equivalent_to(node1) == expected_result
|
||||
assert (
|
||||
test_helpers.nodes_are_equivalent(node1, node2)
|
||||
== test_helpers.nodes_are_equivalent(node2, node1)
|
||||
== expected_result
|
||||
)
|
||||
|
||||
@@ -1,19 +1,142 @@
|
||||
"""PyTest configuration file"""
|
||||
from typing import Callable, Dict, Type
|
||||
|
||||
import networkx as nx
|
||||
import networkx.algorithms.isomorphism as iso
|
||||
import pytest
|
||||
|
||||
from hdk.common.representation.intermediate import (
|
||||
ALL_IR_NODES,
|
||||
Add,
|
||||
ArbitraryFunction,
|
||||
Constant,
|
||||
Dot,
|
||||
Input,
|
||||
IntermediateNode,
|
||||
Mul,
|
||||
Sub,
|
||||
)
|
||||
|
||||
|
||||
def _is_equivalent_to_binary_commutative(lhs: IntermediateNode, rhs: object) -> bool:
|
||||
"""is_equivalent_to for a binary and commutative operation."""
|
||||
return (
|
||||
isinstance(rhs, lhs.__class__)
|
||||
and (lhs.inputs == rhs.inputs or lhs.inputs == rhs.inputs[::-1])
|
||||
and lhs.outputs == rhs.outputs
|
||||
)
|
||||
|
||||
|
||||
def _is_equivalent_to_binary_non_commutative(lhs: IntermediateNode, rhs: object) -> bool:
|
||||
"""is_equivalent_to for a binary and non-commutative operation."""
|
||||
return (
|
||||
isinstance(rhs, lhs.__class__) and lhs.inputs == rhs.inputs and lhs.outputs == rhs.outputs
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_add(lhs: Add, rhs: object) -> bool:
|
||||
"""Helper function to check if an Add node is equivalent to an other object."""
|
||||
return _is_equivalent_to_binary_commutative(lhs, rhs)
|
||||
|
||||
|
||||
def is_equivalent_arbitrary_function(lhs: ArbitraryFunction, rhs: object) -> bool:
|
||||
"""Helper function to check if an ArbitraryFunction node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, ArbitraryFunction)
|
||||
and lhs.op_args == rhs.op_args
|
||||
and lhs.op_kwargs == rhs.op_kwargs
|
||||
and lhs.op_name == rhs.op_name
|
||||
and is_equivalent_intermediate_node(lhs, rhs)
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_constant(lhs: Constant, rhs: object) -> bool:
|
||||
"""Helper function to check if a Constant node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, Constant)
|
||||
and lhs.constant_data == rhs.constant_data
|
||||
and is_equivalent_intermediate_node(lhs, rhs)
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_dot(lhs: Dot, rhs: object) -> bool:
|
||||
"""Helper function to check if a Dot node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, Dot)
|
||||
and lhs.evaluation_function == rhs.evaluation_function
|
||||
and is_equivalent_intermediate_node(lhs, rhs)
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_input(lhs: Input, rhs: object) -> bool:
|
||||
"""Helper function to check if an Input node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, Input)
|
||||
and lhs.input_name == rhs.input_name
|
||||
and lhs.program_input_idx == rhs.program_input_idx
|
||||
and is_equivalent_intermediate_node(lhs, rhs)
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_mul(lhs: Mul, rhs: object) -> bool:
|
||||
"""Helper function to check if a Mul node is equivalent to an other object."""
|
||||
return _is_equivalent_to_binary_commutative(lhs, rhs)
|
||||
|
||||
|
||||
def is_equivalent_sub(lhs: Sub, rhs: object) -> bool:
|
||||
"""Helper function to check if a Sub node is equivalent to an other object."""
|
||||
return _is_equivalent_to_binary_non_commutative(lhs, rhs)
|
||||
|
||||
|
||||
def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool:
|
||||
"""Helper function to check if an IntermediateNode node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, IntermediateNode)
|
||||
and lhs.inputs == rhs.inputs
|
||||
and lhs.outputs == rhs.outputs
|
||||
)
|
||||
|
||||
|
||||
EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
|
||||
Add: is_equivalent_add,
|
||||
ArbitraryFunction: is_equivalent_arbitrary_function,
|
||||
Constant: is_equivalent_constant,
|
||||
Dot: is_equivalent_dot,
|
||||
Input: is_equivalent_input,
|
||||
Mul: is_equivalent_mul,
|
||||
Sub: is_equivalent_sub,
|
||||
}
|
||||
|
||||
_missing_nodes_in_mapping = ALL_IR_NODES - EQUIVALENT_TEST_FUNC.keys()
|
||||
assert len(_missing_nodes_in_mapping) == 0, (
|
||||
f"Missing IR node in EQUIVALENT_TEST_FUNC : "
|
||||
f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}"
|
||||
)
|
||||
|
||||
del _missing_nodes_in_mapping
|
||||
|
||||
|
||||
class TestHelpers:
|
||||
"""Class allowing to pass helper functions to tests"""
|
||||
|
||||
@staticmethod
|
||||
def nodes_are_equivalent(lhs, rhs) -> bool:
|
||||
"""Helper function for tests to check if two nodes are equivalent."""
|
||||
equivalent_func = EQUIVALENT_TEST_FUNC.get(type(lhs), None)
|
||||
if equivalent_func is not None:
|
||||
return equivalent_func(lhs, rhs)
|
||||
|
||||
# This is a default for the test_conftest.py that should remain separate from the package
|
||||
# nodes is_equivalent_* functions
|
||||
return lhs.is_equivalent_to(rhs)
|
||||
|
||||
@staticmethod
|
||||
def digraphs_are_equivalent(reference: nx.MultiDiGraph, to_compare: nx.MultiDiGraph):
|
||||
"""Check that two digraphs are equivalent without modifications"""
|
||||
# edge_match is a copy of node_match
|
||||
edge_matcher = iso.categorical_multiedge_match("input_idx", None)
|
||||
node_matcher = iso.generic_node_match(
|
||||
"_test_content", None, lambda lhs, rhs: lhs.is_equivalent_to(rhs)
|
||||
"_test_content", None, TestHelpers.nodes_are_equivalent
|
||||
)
|
||||
|
||||
# Set the _test_content for each node in the graphs
|
||||
|
||||
Reference in New Issue
Block a user