From 1b489b281cfe14a4d7fe495bbc8db45a4fffd04e Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Thu, 26 Aug 2021 11:17:56 +0200 Subject: [PATCH] refactor: move is_equivalent to conftest.py for tests --- hdk/common/representation/intermediate.py | 71 ---------- .../representation/test_intermediate.py | 7 +- tests/conftest.py | 125 +++++++++++++++++- 3 files changed, 130 insertions(+), 73 deletions(-) diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index d0cfef740..bbe6530c5 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -49,41 +49,6 @@ class IntermediateNode(ABC): self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])] - def _is_equivalent_to_binary_commutative(self, other: object) -> bool: - """is_equivalent_to for a binary and commutative operation.""" - return ( - isinstance(other, self.__class__) - and (self.inputs == other.inputs or self.inputs == other.inputs[::-1]) - and self.outputs == other.outputs - ) - - def _is_equivalent_to_binary_non_commutative(self, other: object) -> bool: - """is_equivalent_to for a binary and non-commutative operation.""" - return ( - isinstance(other, self.__class__) - and self.inputs == other.inputs - and self.outputs == other.outputs - ) - - @abstractmethod - def is_equivalent_to(self, other: object) -> bool: - """Alternative to __eq__ to check equivalence between IntermediateNodes. - - Overriding __eq__ has unwanted side effects, this provides the same facility without - disrupting expected behavior too much - - Args: - other (object): Other object to check against - - Returns: - bool: True if the other object is equivalent - """ - return ( - isinstance(other, IntermediateNode) - and self.inputs == other.inputs - and self.outputs == other.outputs - ) - @abstractmethod def evaluate(self, inputs: Dict[int, Any]) -> Any: """Function to simulate what the represented computation would output for the given inputs. @@ -129,7 +94,6 @@ class Add(IntermediateNode): _n_in: int = 2 __init__ = IntermediateNode._init_binary - is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative def evaluate(self, inputs: Dict[int, Any]) -> Any: return inputs[0] + inputs[1] @@ -144,7 +108,6 @@ class Sub(IntermediateNode): _n_in: int = 2 __init__ = IntermediateNode._init_binary - is_equivalent_to = IntermediateNode._is_equivalent_to_binary_non_commutative def evaluate(self, inputs: Dict[int, Any]) -> Any: return inputs[0] - inputs[1] @@ -159,7 +122,6 @@ class Mul(IntermediateNode): _n_in: int = 2 __init__ = IntermediateNode._init_binary - is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative def evaluate(self, inputs: Dict[int, Any]) -> Any: return inputs[0] * inputs[1] @@ -190,14 +152,6 @@ class Input(IntermediateNode): def evaluate(self, inputs: Dict[int, Any]) -> Any: return inputs[0] - def is_equivalent_to(self, other: object) -> bool: - return ( - isinstance(other, Input) - and self.input_name == other.input_name - and self.program_input_idx == other.program_input_idx - and super().is_equivalent_to(other) - ) - def label(self) -> str: return self.input_name @@ -225,13 +179,6 @@ class Constant(IntermediateNode): def evaluate(self, inputs: Dict[int, Any]) -> Any: return self.constant_data - def is_equivalent_to(self, other: object) -> bool: - return ( - isinstance(other, Constant) - and self.constant_data == other.constant_data - and super().is_equivalent_to(other) - ) - @property def constant_data(self) -> Any: """Returns the constant_data stored in the Constant node. @@ -278,17 +225,6 @@ class ArbitraryFunction(IntermediateNode): assert self.arbitrary_func is not None return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs) - def is_equivalent_to(self, other: object) -> bool: - # FIXME: comparing self.arbitrary_func to other.arbitrary_func will not work - # Only evaluating over the same set of inputs and comparing will help - return ( - isinstance(other, ArbitraryFunction) - and self.op_args == other.op_args - and self.op_kwargs == other.op_kwargs - and self.op_name == other.op_name - and super().is_equivalent_to(other) - ) - def label(self) -> str: return self.op_name @@ -344,12 +280,5 @@ class Dot(IntermediateNode): assert self.evaluation_function is not None return self.evaluation_function(inputs[0], inputs[1]) - def is_equivalent_to(self, other: object) -> bool: - return ( - isinstance(other, self.__class__) - and self.evaluation_function == other.evaluation_function - and super().is_equivalent_to(other) - ) - def label(self) -> str: return "dot" diff --git a/tests/common/representation/test_intermediate.py b/tests/common/representation/test_intermediate.py index 9283fabd0..4b08d655c 100644 --- a/tests/common/representation/test_intermediate.py +++ b/tests/common/representation/test_intermediate.py @@ -276,6 +276,11 @@ def test_is_equivalent_to( node1: ir.IntermediateNode, node2: ir.IntermediateNode, expected_result: bool, + test_helpers, ): """Test is_equivalent_to methods on IntermediateNodes""" - assert node1.is_equivalent_to(node2) == node2.is_equivalent_to(node1) == expected_result + assert ( + test_helpers.nodes_are_equivalent(node1, node2) + == test_helpers.nodes_are_equivalent(node2, node1) + == expected_result + ) diff --git a/tests/conftest.py b/tests/conftest.py index f222bd3fa..9454fb98f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,142 @@ """PyTest configuration file""" +from typing import Callable, Dict, Type + import networkx as nx import networkx.algorithms.isomorphism as iso import pytest +from hdk.common.representation.intermediate import ( + ALL_IR_NODES, + Add, + ArbitraryFunction, + Constant, + Dot, + Input, + IntermediateNode, + Mul, + Sub, +) + + +def _is_equivalent_to_binary_commutative(lhs: IntermediateNode, rhs: object) -> bool: + """is_equivalent_to for a binary and commutative operation.""" + return ( + isinstance(rhs, lhs.__class__) + and (lhs.inputs == rhs.inputs or lhs.inputs == rhs.inputs[::-1]) + and lhs.outputs == rhs.outputs + ) + + +def _is_equivalent_to_binary_non_commutative(lhs: IntermediateNode, rhs: object) -> bool: + """is_equivalent_to for a binary and non-commutative operation.""" + return ( + isinstance(rhs, lhs.__class__) and lhs.inputs == rhs.inputs and lhs.outputs == rhs.outputs + ) + + +def is_equivalent_add(lhs: Add, rhs: object) -> bool: + """Helper function to check if an Add node is equivalent to an other object.""" + return _is_equivalent_to_binary_commutative(lhs, rhs) + + +def is_equivalent_arbitrary_function(lhs: ArbitraryFunction, rhs: object) -> bool: + """Helper function to check if an ArbitraryFunction node is equivalent to an other object.""" + return ( + isinstance(rhs, ArbitraryFunction) + and lhs.op_args == rhs.op_args + and lhs.op_kwargs == rhs.op_kwargs + and lhs.op_name == rhs.op_name + and is_equivalent_intermediate_node(lhs, rhs) + ) + + +def is_equivalent_constant(lhs: Constant, rhs: object) -> bool: + """Helper function to check if a Constant node is equivalent to an other object.""" + return ( + isinstance(rhs, Constant) + and lhs.constant_data == rhs.constant_data + and is_equivalent_intermediate_node(lhs, rhs) + ) + + +def is_equivalent_dot(lhs: Dot, rhs: object) -> bool: + """Helper function to check if a Dot node is equivalent to an other object.""" + return ( + isinstance(rhs, Dot) + and lhs.evaluation_function == rhs.evaluation_function + and is_equivalent_intermediate_node(lhs, rhs) + ) + + +def is_equivalent_input(lhs: Input, rhs: object) -> bool: + """Helper function to check if an Input node is equivalent to an other object.""" + return ( + isinstance(rhs, Input) + and lhs.input_name == rhs.input_name + and lhs.program_input_idx == rhs.program_input_idx + and is_equivalent_intermediate_node(lhs, rhs) + ) + + +def is_equivalent_mul(lhs: Mul, rhs: object) -> bool: + """Helper function to check if a Mul node is equivalent to an other object.""" + return _is_equivalent_to_binary_commutative(lhs, rhs) + + +def is_equivalent_sub(lhs: Sub, rhs: object) -> bool: + """Helper function to check if a Sub node is equivalent to an other object.""" + return _is_equivalent_to_binary_non_commutative(lhs, rhs) + + +def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool: + """Helper function to check if an IntermediateNode node is equivalent to an other object.""" + return ( + isinstance(rhs, IntermediateNode) + and lhs.inputs == rhs.inputs + and lhs.outputs == rhs.outputs + ) + + +EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = { + Add: is_equivalent_add, + ArbitraryFunction: is_equivalent_arbitrary_function, + Constant: is_equivalent_constant, + Dot: is_equivalent_dot, + Input: is_equivalent_input, + Mul: is_equivalent_mul, + Sub: is_equivalent_sub, +} + +_missing_nodes_in_mapping = ALL_IR_NODES - EQUIVALENT_TEST_FUNC.keys() +assert len(_missing_nodes_in_mapping) == 0, ( + f"Missing IR node in EQUIVALENT_TEST_FUNC : " + f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}" +) + +del _missing_nodes_in_mapping + class TestHelpers: """Class allowing to pass helper functions to tests""" + @staticmethod + def nodes_are_equivalent(lhs, rhs) -> bool: + """Helper function for tests to check if two nodes are equivalent.""" + equivalent_func = EQUIVALENT_TEST_FUNC.get(type(lhs), None) + if equivalent_func is not None: + return equivalent_func(lhs, rhs) + + # This is a default for the test_conftest.py that should remain separate from the package + # nodes is_equivalent_* functions + return lhs.is_equivalent_to(rhs) + @staticmethod def digraphs_are_equivalent(reference: nx.MultiDiGraph, to_compare: nx.MultiDiGraph): """Check that two digraphs are equivalent without modifications""" # edge_match is a copy of node_match edge_matcher = iso.categorical_multiedge_match("input_idx", None) node_matcher = iso.generic_node_match( - "_test_content", None, lambda lhs, rhs: lhs.is_equivalent_to(rhs) + "_test_content", None, TestHelpers.nodes_are_equivalent ) # Set the _test_content for each node in the graphs