refactor: move is_equivalent to conftest.py for tests

This commit is contained in:
Arthur Meyre
2021-08-26 11:17:56 +02:00
parent 31259e556c
commit 1b489b281c
3 changed files with 130 additions and 73 deletions

View File

@@ -276,6 +276,11 @@ def test_is_equivalent_to(
node1: ir.IntermediateNode,
node2: ir.IntermediateNode,
expected_result: bool,
test_helpers,
):
"""Test is_equivalent_to methods on IntermediateNodes"""
assert node1.is_equivalent_to(node2) == node2.is_equivalent_to(node1) == expected_result
assert (
test_helpers.nodes_are_equivalent(node1, node2)
== test_helpers.nodes_are_equivalent(node2, node1)
== expected_result
)

View File

@@ -1,19 +1,142 @@
"""PyTest configuration file"""
from typing import Callable, Dict, Type
import networkx as nx
import networkx.algorithms.isomorphism as iso
import pytest
from hdk.common.representation.intermediate import (
ALL_IR_NODES,
Add,
ArbitraryFunction,
Constant,
Dot,
Input,
IntermediateNode,
Mul,
Sub,
)
def _is_equivalent_to_binary_commutative(lhs: IntermediateNode, rhs: object) -> bool:
"""is_equivalent_to for a binary and commutative operation."""
return (
isinstance(rhs, lhs.__class__)
and (lhs.inputs == rhs.inputs or lhs.inputs == rhs.inputs[::-1])
and lhs.outputs == rhs.outputs
)
def _is_equivalent_to_binary_non_commutative(lhs: IntermediateNode, rhs: object) -> bool:
"""is_equivalent_to for a binary and non-commutative operation."""
return (
isinstance(rhs, lhs.__class__) and lhs.inputs == rhs.inputs and lhs.outputs == rhs.outputs
)
def is_equivalent_add(lhs: Add, rhs: object) -> bool:
"""Helper function to check if an Add node is equivalent to an other object."""
return _is_equivalent_to_binary_commutative(lhs, rhs)
def is_equivalent_arbitrary_function(lhs: ArbitraryFunction, rhs: object) -> bool:
"""Helper function to check if an ArbitraryFunction node is equivalent to an other object."""
return (
isinstance(rhs, ArbitraryFunction)
and lhs.op_args == rhs.op_args
and lhs.op_kwargs == rhs.op_kwargs
and lhs.op_name == rhs.op_name
and is_equivalent_intermediate_node(lhs, rhs)
)
def is_equivalent_constant(lhs: Constant, rhs: object) -> bool:
"""Helper function to check if a Constant node is equivalent to an other object."""
return (
isinstance(rhs, Constant)
and lhs.constant_data == rhs.constant_data
and is_equivalent_intermediate_node(lhs, rhs)
)
def is_equivalent_dot(lhs: Dot, rhs: object) -> bool:
"""Helper function to check if a Dot node is equivalent to an other object."""
return (
isinstance(rhs, Dot)
and lhs.evaluation_function == rhs.evaluation_function
and is_equivalent_intermediate_node(lhs, rhs)
)
def is_equivalent_input(lhs: Input, rhs: object) -> bool:
"""Helper function to check if an Input node is equivalent to an other object."""
return (
isinstance(rhs, Input)
and lhs.input_name == rhs.input_name
and lhs.program_input_idx == rhs.program_input_idx
and is_equivalent_intermediate_node(lhs, rhs)
)
def is_equivalent_mul(lhs: Mul, rhs: object) -> bool:
"""Helper function to check if a Mul node is equivalent to an other object."""
return _is_equivalent_to_binary_commutative(lhs, rhs)
def is_equivalent_sub(lhs: Sub, rhs: object) -> bool:
"""Helper function to check if a Sub node is equivalent to an other object."""
return _is_equivalent_to_binary_non_commutative(lhs, rhs)
def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool:
"""Helper function to check if an IntermediateNode node is equivalent to an other object."""
return (
isinstance(rhs, IntermediateNode)
and lhs.inputs == rhs.inputs
and lhs.outputs == rhs.outputs
)
EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
Add: is_equivalent_add,
ArbitraryFunction: is_equivalent_arbitrary_function,
Constant: is_equivalent_constant,
Dot: is_equivalent_dot,
Input: is_equivalent_input,
Mul: is_equivalent_mul,
Sub: is_equivalent_sub,
}
_missing_nodes_in_mapping = ALL_IR_NODES - EQUIVALENT_TEST_FUNC.keys()
assert len(_missing_nodes_in_mapping) == 0, (
f"Missing IR node in EQUIVALENT_TEST_FUNC : "
f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}"
)
del _missing_nodes_in_mapping
class TestHelpers:
"""Class allowing to pass helper functions to tests"""
@staticmethod
def nodes_are_equivalent(lhs, rhs) -> bool:
"""Helper function for tests to check if two nodes are equivalent."""
equivalent_func = EQUIVALENT_TEST_FUNC.get(type(lhs), None)
if equivalent_func is not None:
return equivalent_func(lhs, rhs)
# This is a default for the test_conftest.py that should remain separate from the package
# nodes is_equivalent_* functions
return lhs.is_equivalent_to(rhs)
@staticmethod
def digraphs_are_equivalent(reference: nx.MultiDiGraph, to_compare: nx.MultiDiGraph):
"""Check that two digraphs are equivalent without modifications"""
# edge_match is a copy of node_match
edge_matcher = iso.categorical_multiedge_match("input_idx", None)
node_matcher = iso.generic_node_match(
"_test_content", None, lambda lhs, rhs: lhs.is_equivalent_to(rhs)
"_test_content", None, TestHelpers.nodes_are_equivalent
)
# Set the _test_content for each node in the graphs