diff --git a/tests/common/extensions/test_table.py b/tests/common/extensions/test_table.py index 81e7e90ca..cdd0481a2 100644 --- a/tests/common/extensions/test_table.py +++ b/tests/common/extensions/test_table.py @@ -54,13 +54,16 @@ def test_lookup_table_encrypted_lookup(test_helpers): input_x = ir.Input(input_value=x, input_name="x", program_input_idx=0) ref_graph.add_node(input_x) + # pylint: disable=protected-access + # Need access to _checked_indexing to have is_equivalent_to work for ir.ArbitraryFunction output_arbitrary_function = ir.ArbitraryFunction( input_base_value=x, - arbitrary_func=lambda x, table: table[x], + arbitrary_func=LookupTable._checked_indexing, output_dtype=table.output_dtype, op_kwargs={"table": deepcopy(table.table)}, op_name="TLU", ) + # pylint: enable=protected-access ref_graph.add_node(output_arbitrary_function) ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0) @@ -91,13 +94,16 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers): input_x = ir.Input(input_value=x, input_name="x", program_input_idx=0) ref_graph.add_node(input_x) + # pylint: disable=protected-access + # Need access to _checked_indexing to have is_equivalent_to work for ir.ArbitraryFunction intermediate_arbitrary_function = ir.ArbitraryFunction( input_base_value=x, - arbitrary_func=lambda x, table: table[x], + arbitrary_func=LookupTable._checked_indexing, output_dtype=table.output_dtype, op_kwargs={"table": deepcopy(table.table)}, op_name="TLU", ) + # pylint: enable=protected-access ref_graph.add_node(intermediate_arbitrary_function) constant_3 = ir.Constant(3) diff --git a/tests/conftest.py b/tests/conftest.py index 9df9f064c..09d84a55f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ """PyTest configuration file""" +import operator from typing import Callable, Dict, Type import networkx as nx @@ -39,10 +40,37 @@ def is_equivalent_add(lhs: Add, rhs: object) -> bool: return _is_equivalent_to_binary_commutative(lhs, rhs) +# From https://stackoverflow.com/a/28635464 +_code_and_constants_attr_getter = operator.attrgetter("co_code", "co_consts") + + +def _code_and_constants(object_): + """Helper function to get python code and constants""" + return _code_and_constants_attr_getter(object_.__code__) + + +def python_functions_are_equal_or_equivalent(lhs: object, rhs: object) -> bool: + """Helper function to check if two functions are equal or their code are equivalent. + + This is not perfect, but will be good enough for tests. + """ + + if lhs == rhs: + return True + + try: + lhs_code_and_constants = _code_and_constants(lhs) + rhs_code_and_constants = _code_and_constants(rhs) + return lhs_code_and_constants == rhs_code_and_constants + except AttributeError: + return False + + def is_equivalent_arbitrary_function(lhs: ArbitraryFunction, rhs: object) -> bool: """Helper function to check if an ArbitraryFunction node is equivalent to an other object.""" return ( isinstance(rhs, ArbitraryFunction) + and python_functions_are_equal_or_equivalent(lhs.arbitrary_func, rhs.arbitrary_func) and lhs.op_args == rhs.op_args and lhs.op_kwargs == rhs.op_kwargs and lhs.op_name == rhs.op_name