mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
- the naming has always been confusing and recent changes to the code make this rename necessary for things to be clearer
191 lines
6.2 KiB
Python
191 lines
6.2 KiB
Python
"""PyTest configuration file"""
|
|
import operator
|
|
from typing import Callable, Dict, Type
|
|
|
|
import networkx as nx
|
|
import networkx.algorithms.isomorphism as iso
|
|
import pytest
|
|
|
|
from concrete.common.representation.intermediate import (
|
|
ALL_IR_NODES,
|
|
Add,
|
|
Constant,
|
|
Dot,
|
|
Input,
|
|
IntermediateNode,
|
|
Mul,
|
|
Sub,
|
|
UnivariateFunction,
|
|
)
|
|
|
|
|
|
def _is_equivalent_to_binary_commutative(lhs: IntermediateNode, rhs: object) -> bool:
    """is_equivalent_to for a binary operation whose operands may be swapped.

    ``rhs`` matches when it is an instance of ``lhs``'s class, has equal outputs, and has
    the same inputs either in the same order or in reversed order.
    """
    if not isinstance(rhs, lhs.__class__):
        return False

    # Commutativity: accept the operands as given or swapped.
    candidate_inputs = (rhs.inputs, rhs.inputs[::-1])
    return lhs.inputs in candidate_inputs and lhs.outputs == rhs.outputs
|
|
|
|
|
|
def _is_equivalent_to_binary_non_commutative(lhs: IntermediateNode, rhs: object) -> bool:
    """is_equivalent_to for a binary operation whose operand order matters.

    ``rhs`` matches only when it is an instance of ``lhs``'s class with equal inputs (same
    order) and equal outputs.
    """
    if not isinstance(rhs, lhs.__class__):
        return False
    return lhs.inputs == rhs.inputs and lhs.outputs == rhs.outputs
|
|
|
|
|
|
def is_equivalent_add(lhs: Add, rhs: object) -> bool:
    """Tell whether ``rhs`` is equivalent to the Add node ``lhs``.

    Addition is commutative, so operands given in swapped order also match.
    """
    comparator = _is_equivalent_to_binary_commutative
    return comparator(lhs, rhs)
|
|
|
|
|
|
# Adapted from https://stackoverflow.com/a/28635464
# ``co_names`` is compared in addition to ``co_code`` and ``co_consts``: the bytecode only
# stores *indices* into the names table, so e.g. ``lambda x: abs(x)`` and ``lambda x: len(x)``
# have identical ``co_code``/``co_consts`` and would otherwise be reported as equivalent.
_code_and_constants_attr_getter = operator.attrgetter("co_code", "co_consts", "co_names")


def _code_and_constants(object_):
    """Return the bytecode, constants and referenced names of ``object_``'s code object.

    Raises:
        AttributeError: if ``object_`` has no ``__code__`` attribute (e.g. builtin functions).
    """
    return _code_and_constants_attr_getter(object_.__code__)


def python_functions_are_equal_or_equivalent(lhs: object, rhs: object) -> bool:
    """Helper function to check if two functions are equal or their code is equivalent.

    Two objects are considered equivalent when they compare equal, or when both expose a
    ``__code__`` whose bytecode, constants and referenced names all match.

    This is not perfect (closure contents and default argument values are not compared),
    but is good enough for tests.

    Args:
        lhs (object): first callable (or any object) to compare.
        rhs (object): second callable (or any object) to compare.

    Returns:
        bool: True if ``lhs == rhs`` or their code is equivalent, False otherwise.
    """
    if lhs == rhs:
        return True

    try:
        lhs_code_and_constants = _code_and_constants(lhs)
        rhs_code_and_constants = _code_and_constants(rhs)
    except AttributeError:
        # At least one of the objects has no ``__code__`` (not a plain Python function).
        return False

    return lhs_code_and_constants == rhs_code_and_constants
|
|
|
|
|
|
def is_equivalent_arbitrary_function(lhs: UnivariateFunction, rhs: object) -> bool:
    """Tell whether ``rhs`` is equivalent to the UnivariateFunction node ``lhs``.

    ``rhs`` must be a UnivariateFunction meeting all of the following:

    - its wrapped function is equal to, or has code equivalent to, ``lhs``'s;
    - its ``op_args``, ``op_kwargs`` and ``op_name`` match;
    - it passes ``is_equivalent_intermediate_node``.
    """
    if not isinstance(rhs, UnivariateFunction):
        return False

    if not python_functions_are_equal_or_equivalent(lhs.arbitrary_func, rhs.arbitrary_func):
        return False

    same_op_params = (
        lhs.op_args == rhs.op_args
        and lhs.op_kwargs == rhs.op_kwargs
        and lhs.op_name == rhs.op_name
    )
    return same_op_params and is_equivalent_intermediate_node(lhs, rhs)
|
|
|
|
|
|
def is_equivalent_constant(lhs: Constant, rhs: object) -> bool:
    """Tell whether ``rhs`` is equivalent to the Constant node ``lhs``.

    ``rhs`` must be a Constant holding equal ``constant_data`` and must pass
    ``is_equivalent_intermediate_node``.
    """
    if not isinstance(rhs, Constant):
        return False

    # NOTE(review): if constant_data can be an array type whose ``==`` is elementwise, the
    # truth test below would raise on multi-element data — confirm which types are stored.
    data_matches = lhs.constant_data == rhs.constant_data
    return data_matches and is_equivalent_intermediate_node(lhs, rhs)
|
|
|
|
|
|
def is_equivalent_dot(lhs: Dot, rhs: object) -> bool:
    """Helper function to check if a Dot node is equivalent to another object.

    ``rhs`` must be a Dot node and must pass ``is_equivalent_intermediate_node``.

    The evaluation functions are compared with ``python_functions_are_equal_or_equivalent``,
    the same way UnivariateFunction compares its wrapped function. Two Dot nodes built from
    distinct but code-identical callables (e.g. two identical lambdas) are therefore still
    considered equivalent. Identical function objects still match as before.

    Args:
        lhs (Dot): the reference Dot node.
        rhs (object): the object to compare against.

    Returns:
        bool: True if ``rhs`` is an equivalent Dot node.
    """
    return (
        isinstance(rhs, Dot)
        and python_functions_are_equal_or_equivalent(
            lhs.evaluation_function, rhs.evaluation_function
        )
        and is_equivalent_intermediate_node(lhs, rhs)
    )
|
|
|
|
|
|
def is_equivalent_input(lhs: Input, rhs: object) -> bool:
    """Tell whether ``rhs`` is equivalent to the Input node ``lhs``.

    ``rhs`` must be an Input with the same ``input_name`` and ``program_input_idx``, and must
    pass ``is_equivalent_intermediate_node``.
    """
    if not isinstance(rhs, Input):
        return False

    same_program_input = (
        lhs.input_name == rhs.input_name and lhs.program_input_idx == rhs.program_input_idx
    )
    return same_program_input and is_equivalent_intermediate_node(lhs, rhs)
|
|
|
|
|
|
def is_equivalent_mul(lhs: Mul, rhs: object) -> bool:
    """Tell whether ``rhs`` is equivalent to the Mul node ``lhs``.

    Multiplication is commutative, so operands given in swapped order also match.
    """
    comparator = _is_equivalent_to_binary_commutative
    return comparator(lhs, rhs)
|
|
|
|
|
|
def is_equivalent_sub(lhs: Sub, rhs: object) -> bool:
    """Tell whether ``rhs`` is equivalent to the Sub node ``lhs``.

    Subtraction is not commutative, so the operands must appear in the same order.
    """
    comparator = _is_equivalent_to_binary_non_commutative
    return comparator(lhs, rhs)
|
|
|
|
|
|
def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool:
    """Generic equivalence check for any IntermediateNode.

    ``rhs`` matches when it is an IntermediateNode whose inputs and outputs are equal to those
    of ``lhs``.
    """
    if isinstance(rhs, IntermediateNode):
        return lhs.inputs == rhs.inputs and lhs.outputs == rhs.outputs
    return False
|
|
|
|
|
|
# Registry used by TestHelpers.nodes_are_equivalent: maps each IR node class to the helper
# deciding whether a node of that class is equivalent to another object.
EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
    Add: is_equivalent_add,
    UnivariateFunction: is_equivalent_arbitrary_function,
    Constant: is_equivalent_constant,
    Dot: is_equivalent_dot,
    Input: is_equivalent_input,
    Mul: is_equivalent_mul,
    Sub: is_equivalent_sub,
}

# Every IR node class listed in ALL_IR_NODES must have a dedicated helper above. This is an
# explicit ``raise`` rather than an ``assert`` so that the check is not stripped when Python
# runs with ``-O``.
_missing_nodes_in_mapping = ALL_IR_NODES - EQUIVALENT_TEST_FUNC.keys()
if _missing_nodes_in_mapping:
    raise AssertionError(
        f"Missing IR node in EQUIVALENT_TEST_FUNC : "
        f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}"
    )

# Keep the module namespace clean of the temporary.
del _missing_nodes_in_mapping
|
|
|
|
|
|
class TestHelpers:
    """Class allowing to pass helper functions to tests"""

    @staticmethod
    def nodes_are_equivalent(lhs, rhs) -> bool:
        """Helper function for tests to check if two nodes are equivalent.

        Uses the helper registered in EQUIVALENT_TEST_FUNC for ``type(lhs)`` when there is
        one. Otherwise it falls back to ``lhs.is_equivalent_to(rhs)``.

        Args:
            lhs: the reference node.
            rhs: the object to compare against.

        Returns:
            bool: True if ``rhs`` is equivalent to ``lhs``.
        """
        equivalent_func = EQUIVALENT_TEST_FUNC.get(type(lhs), None)
        if equivalent_func is not None:
            return equivalent_func(lhs, rhs)

        # This is a default for the test_conftest.py that should remain separate from the package
        # nodes is_equivalent_* functions
        return lhs.is_equivalent_to(rhs)

    @staticmethod
    def digraphs_are_equivalent(
        reference: nx.MultiDiGraph, to_compare: nx.MultiDiGraph
    ) -> bool:
        """Check that two digraphs are equivalent, without modifying either of them.

        The graphs are equivalent when they are isomorphic under the following matching:

        - matched nodes satisfy ``TestHelpers.nodes_are_equivalent``;
        - matched edges carry the same ``input_idx`` attribute.

        Args:
            reference (nx.MultiDiGraph): the expected graph.
            to_compare (nx.MultiDiGraph): the graph to check against ``reference``.

        Returns:
            bool: True if the graphs are equivalent.
        """
        # Edges are matched on their "input_idx" attribute only.
        edge_matcher = iso.categorical_multiedge_match("input_idx", None)
        # networkx node matchers receive node attribute dicts, not the nodes themselves, so
        # each node object is stored under "_test_content" to be compared by nodes_are_equivalent.
        node_matcher = iso.generic_node_match(
            "_test_content", None, TestHelpers.nodes_are_equivalent
        )

        # Work on copies so the "_test_content" attribute does not leak into the caller's graphs.
        reference_copy = reference.copy()
        to_compare_copy = to_compare.copy()

        # Set the _test_content for each node in the copied graphs
        for graph in (reference_copy, to_compare_copy):
            for node in graph.nodes():
                graph.add_node(node, _test_content=node)

        graphs_are_isomorphic = nx.is_isomorphic(
            reference_copy,
            to_compare_copy,
            node_match=node_matcher,
            edge_match=edge_matcher,
        )

        return graphs_are_isomorphic
|
|
|
|
|
|
@pytest.fixture
def test_helpers():
    """Provide the ``TestHelpers`` class itself, so tests can call its static methods."""
    helpers_class = TestHelpers
    return helpers_class
|