From ce358ca838af47ae5af0f315cb26bcc47183dcff Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Fri, 16 Jul 2021 17:21:12 +0200 Subject: [PATCH] tests: add test helper to compare digraphs - add test to check that the helper is working --- tests/conftest.py | 29 ++++++++++++++ tests/helpers/test_conftest.py | 71 ++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+) create mode 100644 tests/conftest.py create mode 100644 tests/helpers/test_conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..b68f743fe --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,29 @@ +"""PyTest configuration file""" +import networkx as nx +import networkx.algorithms.isomorphism as iso +import pytest + + +class TestHelpers: + """Class allowing to pass helper functions to tests""" + + @staticmethod + def digraphs_are_equivalent(reference: nx.DiGraph, to_compare: nx.DiGraph): + """Check that two digraphs are equivalent without modifications""" + # edge_match is a copy of node_match + edge_matcher = iso.categorical_node_match("input_idx", None) + node_matcher = iso.categorical_node_match("content", None) + graphs_are_isomorphic = nx.is_isomorphic( + reference, + to_compare, + node_match=node_matcher, + edge_match=edge_matcher, + ) + + return graphs_are_isomorphic + + +@pytest.fixture +def test_helpers(): + """Fixture to return the static helper class""" + return TestHelpers diff --git a/tests/helpers/test_conftest.py b/tests/helpers/test_conftest.py new file mode 100644 index 000000000..d1c5b4cc4 --- /dev/null +++ b/tests/helpers/test_conftest.py @@ -0,0 +1,71 @@ +"""Test file for conftest helper functions""" +import networkx as nx + + +def test_digraphs_are_equivalent(test_helpers): + """Function to test digraphs_are_equivalent helper function""" + + class TestNode: + """Dummy test node""" + + computation: str + + def __init__(self, computation: str) -> None: + self.computation = computation + + def __hash__(self) -> int: + return self.computation.__hash__() + + def __eq__(self, other) -> bool: + return self.computation == other.computation + + g_1 = nx.DiGraph() + g_2 = nx.DiGraph() + + t_0 = TestNode("Add") + t_1 = TestNode("Mul") + t_2 = TestNode("TLU") + + g_1.add_edge(t_0, t_2, input_idx=0) + g_1.add_edge(t_1, t_2, input_idx=1) + + # This updates the nodes attributes in the graph + for node in g_1: + g_1.add_node(node, content=node) + + t0p = TestNode("Add") + t1p = TestNode("Mul") + t2p = TestNode("TLU") + + g_2.add_edge(t1p, t2p, input_idx=1) + g_2.add_edge(t0p, t2p, input_idx=0) + + # This updates the nodes attributes in the graph + for node in g_2: + g_2.add_node(node, content=node) + + bad_g2 = nx.DiGraph() + + bad_t0 = TestNode("Not Add") + + bad_g2.add_edge(bad_t0, t_2, input_idx=0) + bad_g2.add_edge(t_1, t_2, input_idx=1) + + # This updates the nodes attributes in the graph + for node in bad_g2: + bad_g2.add_node(node, content=node) + + bad_g3 = nx.DiGraph() + + bad_g3.add_edge(t_0, t_2, input_idx=1) + bad_g3.add_edge(t_1, t_2, input_idx=0) + + # This updates the nodes attributes in the graph + for node in bad_g3: + bad_g3.add_node(node, content=node) + + assert test_helpers.digraphs_are_equivalent(g_1, g_2), "Graphs should be equivalent" + assert not test_helpers.digraphs_are_equivalent(g_1, bad_g2), "Graphs should not be equivalent" + assert not test_helpers.digraphs_are_equivalent(g_2, bad_g2), "Graphs should not be equivalent" + assert not test_helpers.digraphs_are_equivalent(g_1, bad_g3), "Graphs should not be equivalent" + assert not test_helpers.digraphs_are_equivalent(g_2, bad_g3), "Graphs should not be equivalent"