refactor: move is_equivalent to conftest.py for tests

This commit is contained in:
Arthur Meyre
2021-08-26 11:17:56 +02:00
parent 31259e556c
commit 1b489b281c
3 changed files with 130 additions and 73 deletions

View File

@@ -49,41 +49,6 @@ class IntermediateNode(ABC):
self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])]
def _is_equivalent_to_binary_commutative(self, other: object) -> bool:
"""is_equivalent_to for a binary and commutative operation."""
return (
isinstance(other, self.__class__)
and (self.inputs == other.inputs or self.inputs == other.inputs[::-1])
and self.outputs == other.outputs
)
def _is_equivalent_to_binary_non_commutative(self, other: object) -> bool:
"""is_equivalent_to for a binary and non-commutative operation."""
return (
isinstance(other, self.__class__)
and self.inputs == other.inputs
and self.outputs == other.outputs
)
@abstractmethod
def is_equivalent_to(self, other: object) -> bool:
"""Alternative to __eq__ to check equivalence between IntermediateNodes.
Overriding __eq__ has unwanted side effects, this provides the same facility without
disrupting expected behavior too much
Args:
other (object): Other object to check against
Returns:
bool: True if the other object is equivalent
"""
return (
isinstance(other, IntermediateNode)
and self.inputs == other.inputs
and self.outputs == other.outputs
)
@abstractmethod
def evaluate(self, inputs: Dict[int, Any]) -> Any:
"""Function to simulate what the represented computation would output for the given inputs.
@@ -129,7 +94,6 @@ class Add(IntermediateNode):
_n_in: int = 2
__init__ = IntermediateNode._init_binary
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0] + inputs[1]
@@ -144,7 +108,6 @@ class Sub(IntermediateNode):
_n_in: int = 2
__init__ = IntermediateNode._init_binary
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_non_commutative
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0] - inputs[1]
@@ -159,7 +122,6 @@ class Mul(IntermediateNode):
_n_in: int = 2
__init__ = IntermediateNode._init_binary
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0] * inputs[1]
@@ -190,14 +152,6 @@ class Input(IntermediateNode):
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0]
def is_equivalent_to(self, other: object) -> bool:
return (
isinstance(other, Input)
and self.input_name == other.input_name
and self.program_input_idx == other.program_input_idx
and super().is_equivalent_to(other)
)
def label(self) -> str:
return self.input_name
@@ -225,13 +179,6 @@ class Constant(IntermediateNode):
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return self.constant_data
def is_equivalent_to(self, other: object) -> bool:
return (
isinstance(other, Constant)
and self.constant_data == other.constant_data
and super().is_equivalent_to(other)
)
@property
def constant_data(self) -> Any:
"""Returns the constant_data stored in the Constant node.
@@ -278,17 +225,6 @@ class ArbitraryFunction(IntermediateNode):
assert self.arbitrary_func is not None
return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs)
def is_equivalent_to(self, other: object) -> bool:
# FIXME: comparing self.arbitrary_func to other.arbitrary_func will not work
# Only evaluating over the same set of inputs and comparing will help
return (
isinstance(other, ArbitraryFunction)
and self.op_args == other.op_args
and self.op_kwargs == other.op_kwargs
and self.op_name == other.op_name
and super().is_equivalent_to(other)
)
def label(self) -> str:
return self.op_name
@@ -344,12 +280,5 @@ class Dot(IntermediateNode):
assert self.evaluation_function is not None
return self.evaluation_function(inputs[0], inputs[1])
def is_equivalent_to(self, other: object) -> bool:
return (
isinstance(other, self.__class__)
and self.evaluation_function == other.evaluation_function
and super().is_equivalent_to(other)
)
def label(self) -> str:
return "dot"

View File

@@ -276,6 +276,11 @@ def test_is_equivalent_to(
node1: ir.IntermediateNode,
node2: ir.IntermediateNode,
expected_result: bool,
test_helpers,
):
"""Test is_equivalent_to methods on IntermediateNodes"""
assert node1.is_equivalent_to(node2) == node2.is_equivalent_to(node1) == expected_result
assert (
test_helpers.nodes_are_equivalent(node1, node2)
== test_helpers.nodes_are_equivalent(node2, node1)
== expected_result
)

View File

@@ -1,19 +1,142 @@
"""PyTest configuration file"""
from typing import Callable, Dict, Type
import networkx as nx
import networkx.algorithms.isomorphism as iso
import pytest
from hdk.common.representation.intermediate import (
ALL_IR_NODES,
Add,
ArbitraryFunction,
Constant,
Dot,
Input,
IntermediateNode,
Mul,
Sub,
)
def _is_equivalent_to_binary_commutative(lhs: IntermediateNode, rhs: object) -> bool:
"""is_equivalent_to for a binary and commutative operation."""
return (
isinstance(rhs, lhs.__class__)
and (lhs.inputs == rhs.inputs or lhs.inputs == rhs.inputs[::-1])
and lhs.outputs == rhs.outputs
)
def _is_equivalent_to_binary_non_commutative(lhs: IntermediateNode, rhs: object) -> bool:
"""is_equivalent_to for a binary and non-commutative operation."""
return (
isinstance(rhs, lhs.__class__) and lhs.inputs == rhs.inputs and lhs.outputs == rhs.outputs
)
def is_equivalent_add(lhs: Add, rhs: object) -> bool:
"""Helper function to check if an Add node is equivalent to an other object."""
return _is_equivalent_to_binary_commutative(lhs, rhs)
def is_equivalent_arbitrary_function(lhs: ArbitraryFunction, rhs: object) -> bool:
"""Helper function to check if an ArbitraryFunction node is equivalent to an other object."""
return (
isinstance(rhs, ArbitraryFunction)
and lhs.op_args == rhs.op_args
and lhs.op_kwargs == rhs.op_kwargs
and lhs.op_name == rhs.op_name
and is_equivalent_intermediate_node(lhs, rhs)
)
def is_equivalent_constant(lhs: Constant, rhs: object) -> bool:
"""Helper function to check if a Constant node is equivalent to an other object."""
return (
isinstance(rhs, Constant)
and lhs.constant_data == rhs.constant_data
and is_equivalent_intermediate_node(lhs, rhs)
)
def is_equivalent_dot(lhs: Dot, rhs: object) -> bool:
"""Helper function to check if a Dot node is equivalent to an other object."""
return (
isinstance(rhs, Dot)
and lhs.evaluation_function == rhs.evaluation_function
and is_equivalent_intermediate_node(lhs, rhs)
)
def is_equivalent_input(lhs: Input, rhs: object) -> bool:
"""Helper function to check if an Input node is equivalent to an other object."""
return (
isinstance(rhs, Input)
and lhs.input_name == rhs.input_name
and lhs.program_input_idx == rhs.program_input_idx
and is_equivalent_intermediate_node(lhs, rhs)
)
def is_equivalent_mul(lhs: Mul, rhs: object) -> bool:
"""Helper function to check if a Mul node is equivalent to an other object."""
return _is_equivalent_to_binary_commutative(lhs, rhs)
def is_equivalent_sub(lhs: Sub, rhs: object) -> bool:
"""Helper function to check if a Sub node is equivalent to an other object."""
return _is_equivalent_to_binary_non_commutative(lhs, rhs)
def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool:
"""Helper function to check if an IntermediateNode node is equivalent to an other object."""
return (
isinstance(rhs, IntermediateNode)
and lhs.inputs == rhs.inputs
and lhs.outputs == rhs.outputs
)
EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
Add: is_equivalent_add,
ArbitraryFunction: is_equivalent_arbitrary_function,
Constant: is_equivalent_constant,
Dot: is_equivalent_dot,
Input: is_equivalent_input,
Mul: is_equivalent_mul,
Sub: is_equivalent_sub,
}
_missing_nodes_in_mapping = ALL_IR_NODES - EQUIVALENT_TEST_FUNC.keys()
assert len(_missing_nodes_in_mapping) == 0, (
f"Missing IR node in EQUIVALENT_TEST_FUNC : "
f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}"
)
del _missing_nodes_in_mapping
class TestHelpers:
"""Class allowing to pass helper functions to tests"""
@staticmethod
def nodes_are_equivalent(lhs, rhs) -> bool:
"""Helper function for tests to check if two nodes are equivalent."""
equivalent_func = EQUIVALENT_TEST_FUNC.get(type(lhs), None)
if equivalent_func is not None:
return equivalent_func(lhs, rhs)
# This is a default for the test_conftest.py that should remain separate from the package
# nodes is_equivalent_* functions
return lhs.is_equivalent_to(rhs)
@staticmethod
def digraphs_are_equivalent(reference: nx.MultiDiGraph, to_compare: nx.MultiDiGraph):
"""Check that two digraphs are equivalent without modifications"""
# edge_match is a copy of node_match
edge_matcher = iso.categorical_multiedge_match("input_idx", None)
node_matcher = iso.generic_node_match(
"_test_content", None, lambda lhs, rhs: lhs.is_equivalent_to(rhs)
"_test_content", None, TestHelpers.nodes_are_equivalent
)
# Set the _test_content for each node in the graphs