mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor: move is_equivalent to conftest.py for tests
This commit is contained in:
@@ -276,6 +276,11 @@ def test_is_equivalent_to(
|
||||
node1: ir.IntermediateNode,
|
||||
node2: ir.IntermediateNode,
|
||||
expected_result: bool,
|
||||
test_helpers,
|
||||
):
|
||||
"""Test is_equivalent_to methods on IntermediateNodes"""
|
||||
assert node1.is_equivalent_to(node2) == node2.is_equivalent_to(node1) == expected_result
|
||||
assert (
|
||||
test_helpers.nodes_are_equivalent(node1, node2)
|
||||
== test_helpers.nodes_are_equivalent(node2, node1)
|
||||
== expected_result
|
||||
)
|
||||
|
||||
@@ -1,19 +1,142 @@
|
||||
"""PyTest configuration file"""
|
||||
from typing import Callable, Dict, Type
|
||||
|
||||
import networkx as nx
|
||||
import networkx.algorithms.isomorphism as iso
|
||||
import pytest
|
||||
|
||||
from hdk.common.representation.intermediate import (
|
||||
ALL_IR_NODES,
|
||||
Add,
|
||||
ArbitraryFunction,
|
||||
Constant,
|
||||
Dot,
|
||||
Input,
|
||||
IntermediateNode,
|
||||
Mul,
|
||||
Sub,
|
||||
)
|
||||
|
||||
|
||||
def _is_equivalent_to_binary_commutative(lhs: IntermediateNode, rhs: object) -> bool:
|
||||
"""is_equivalent_to for a binary and commutative operation."""
|
||||
return (
|
||||
isinstance(rhs, lhs.__class__)
|
||||
and (lhs.inputs == rhs.inputs or lhs.inputs == rhs.inputs[::-1])
|
||||
and lhs.outputs == rhs.outputs
|
||||
)
|
||||
|
||||
|
||||
def _is_equivalent_to_binary_non_commutative(lhs: IntermediateNode, rhs: object) -> bool:
|
||||
"""is_equivalent_to for a binary and non-commutative operation."""
|
||||
return (
|
||||
isinstance(rhs, lhs.__class__) and lhs.inputs == rhs.inputs and lhs.outputs == rhs.outputs
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_add(lhs: Add, rhs: object) -> bool:
|
||||
"""Helper function to check if an Add node is equivalent to an other object."""
|
||||
return _is_equivalent_to_binary_commutative(lhs, rhs)
|
||||
|
||||
|
||||
def is_equivalent_arbitrary_function(lhs: ArbitraryFunction, rhs: object) -> bool:
|
||||
"""Helper function to check if an ArbitraryFunction node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, ArbitraryFunction)
|
||||
and lhs.op_args == rhs.op_args
|
||||
and lhs.op_kwargs == rhs.op_kwargs
|
||||
and lhs.op_name == rhs.op_name
|
||||
and is_equivalent_intermediate_node(lhs, rhs)
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_constant(lhs: Constant, rhs: object) -> bool:
|
||||
"""Helper function to check if a Constant node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, Constant)
|
||||
and lhs.constant_data == rhs.constant_data
|
||||
and is_equivalent_intermediate_node(lhs, rhs)
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_dot(lhs: Dot, rhs: object) -> bool:
|
||||
"""Helper function to check if a Dot node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, Dot)
|
||||
and lhs.evaluation_function == rhs.evaluation_function
|
||||
and is_equivalent_intermediate_node(lhs, rhs)
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_input(lhs: Input, rhs: object) -> bool:
|
||||
"""Helper function to check if an Input node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, Input)
|
||||
and lhs.input_name == rhs.input_name
|
||||
and lhs.program_input_idx == rhs.program_input_idx
|
||||
and is_equivalent_intermediate_node(lhs, rhs)
|
||||
)
|
||||
|
||||
|
||||
def is_equivalent_mul(lhs: Mul, rhs: object) -> bool:
|
||||
"""Helper function to check if a Mul node is equivalent to an other object."""
|
||||
return _is_equivalent_to_binary_commutative(lhs, rhs)
|
||||
|
||||
|
||||
def is_equivalent_sub(lhs: Sub, rhs: object) -> bool:
|
||||
"""Helper function to check if a Sub node is equivalent to an other object."""
|
||||
return _is_equivalent_to_binary_non_commutative(lhs, rhs)
|
||||
|
||||
|
||||
def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool:
|
||||
"""Helper function to check if an IntermediateNode node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, IntermediateNode)
|
||||
and lhs.inputs == rhs.inputs
|
||||
and lhs.outputs == rhs.outputs
|
||||
)
|
||||
|
||||
|
||||
EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
|
||||
Add: is_equivalent_add,
|
||||
ArbitraryFunction: is_equivalent_arbitrary_function,
|
||||
Constant: is_equivalent_constant,
|
||||
Dot: is_equivalent_dot,
|
||||
Input: is_equivalent_input,
|
||||
Mul: is_equivalent_mul,
|
||||
Sub: is_equivalent_sub,
|
||||
}
|
||||
|
||||
_missing_nodes_in_mapping = ALL_IR_NODES - EQUIVALENT_TEST_FUNC.keys()
|
||||
assert len(_missing_nodes_in_mapping) == 0, (
|
||||
f"Missing IR node in EQUIVALENT_TEST_FUNC : "
|
||||
f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}"
|
||||
)
|
||||
|
||||
del _missing_nodes_in_mapping
|
||||
|
||||
|
||||
class TestHelpers:
|
||||
"""Class allowing to pass helper functions to tests"""
|
||||
|
||||
@staticmethod
|
||||
def nodes_are_equivalent(lhs, rhs) -> bool:
|
||||
"""Helper function for tests to check if two nodes are equivalent."""
|
||||
equivalent_func = EQUIVALENT_TEST_FUNC.get(type(lhs), None)
|
||||
if equivalent_func is not None:
|
||||
return equivalent_func(lhs, rhs)
|
||||
|
||||
# This is a default for the test_conftest.py that should remain separate from the package
|
||||
# nodes is_equivalent_* functions
|
||||
return lhs.is_equivalent_to(rhs)
|
||||
|
||||
@staticmethod
|
||||
def digraphs_are_equivalent(reference: nx.MultiDiGraph, to_compare: nx.MultiDiGraph):
|
||||
"""Check that two digraphs are equivalent without modifications"""
|
||||
# edge_match is a copy of node_match
|
||||
edge_matcher = iso.categorical_multiedge_match("input_idx", None)
|
||||
node_matcher = iso.generic_node_match(
|
||||
"_test_content", None, lambda lhs, rhs: lhs.is_equivalent_to(rhs)
|
||||
"_test_content", None, TestHelpers.nodes_are_equivalent
|
||||
)
|
||||
|
||||
# Set the _test_content for each node in the graphs
|
||||
|
||||
Reference in New Issue
Block a user