mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
tests: add test helper to compare digraphs
- add test to check that the helper is working
This commit is contained in:
29
tests/conftest.py
Normal file
29
tests/conftest.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""PyTest configuration file"""
|
||||
import networkx as nx
|
||||
import networkx.algorithms.isomorphism as iso
|
||||
import pytest
|
||||
|
||||
|
||||
class TestHelpers:
|
||||
"""Class allowing to pass helper functions to tests"""
|
||||
|
||||
@staticmethod
|
||||
def digraphs_are_equivalent(reference: nx.DiGraph, to_compare: nx.DiGraph):
|
||||
"""Check that two digraphs are equivalent without modifications"""
|
||||
# edge_match is a copy of node_match
|
||||
edge_matcher = iso.categorical_node_match("input_idx", None)
|
||||
node_matcher = iso.categorical_node_match("content", None)
|
||||
graphs_are_isomorphic = nx.is_isomorphic(
|
||||
reference,
|
||||
to_compare,
|
||||
node_match=node_matcher,
|
||||
edge_match=edge_matcher,
|
||||
)
|
||||
|
||||
return graphs_are_isomorphic
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_helpers():
|
||||
"""Fixture to return the static helper class"""
|
||||
return TestHelpers
|
||||
71
tests/helpers/test_conftest.py
Normal file
71
tests/helpers/test_conftest.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Test file for conftest helper functions"""
|
||||
import networkx as nx
|
||||
|
||||
|
||||
def test_digraphs_are_equivalent(test_helpers):
|
||||
"""Function to test digraphs_are_equivalent helper function"""
|
||||
|
||||
class TestNode:
|
||||
"""Dummy test node"""
|
||||
|
||||
computation: str
|
||||
|
||||
def __init__(self, computation: str) -> None:
|
||||
self.computation = computation
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.computation.__hash__()
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return self.computation == other.computation
|
||||
|
||||
g_1 = nx.DiGraph()
|
||||
g_2 = nx.DiGraph()
|
||||
|
||||
t_0 = TestNode("Add")
|
||||
t_1 = TestNode("Mul")
|
||||
t_2 = TestNode("TLU")
|
||||
|
||||
g_1.add_edge(t_0, t_2, input_idx=0)
|
||||
g_1.add_edge(t_1, t_2, input_idx=1)
|
||||
|
||||
# This updates the nodes attributes in the graph
|
||||
for node in g_1:
|
||||
g_1.add_node(node, content=node)
|
||||
|
||||
t0p = TestNode("Add")
|
||||
t1p = TestNode("Mul")
|
||||
t2p = TestNode("TLU")
|
||||
|
||||
g_2.add_edge(t1p, t2p, input_idx=1)
|
||||
g_2.add_edge(t0p, t2p, input_idx=0)
|
||||
|
||||
# This updates the nodes attributes in the graph
|
||||
for node in g_2:
|
||||
g_2.add_node(node, content=node)
|
||||
|
||||
bad_g2 = nx.DiGraph()
|
||||
|
||||
bad_t0 = TestNode("Not Add")
|
||||
|
||||
bad_g2.add_edge(bad_t0, t_2, input_idx=0)
|
||||
bad_g2.add_edge(t_1, t_2, input_idx=1)
|
||||
|
||||
# This updates the nodes attributes in the graph
|
||||
for node in bad_g2:
|
||||
bad_g2.add_node(node, content=node)
|
||||
|
||||
bad_g3 = nx.DiGraph()
|
||||
|
||||
bad_g3.add_edge(t_0, t_2, input_idx=1)
|
||||
bad_g3.add_edge(t_1, t_2, input_idx=0)
|
||||
|
||||
# This updates the nodes attributes in the graph
|
||||
for node in bad_g3:
|
||||
bad_g3.add_node(node, content=node)
|
||||
|
||||
assert test_helpers.digraphs_are_equivalent(g_1, g_2), "Graphs should be equivalent"
|
||||
assert not test_helpers.digraphs_are_equivalent(g_1, bad_g2), "Graphs should not be equivalent"
|
||||
assert not test_helpers.digraphs_are_equivalent(g_2, bad_g2), "Graphs should not be equivalent"
|
||||
assert not test_helpers.digraphs_are_equivalent(g_1, bad_g3), "Graphs should not be equivalent"
|
||||
assert not test_helpers.digraphs_are_equivalent(g_2, bad_g3), "Graphs should not be equivalent"
|
||||
Reference in New Issue
Block a user