mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
tests: add function for ArbitraryFunction arbitrary_func equivalence
- this is not perfect but pretty close to the best we can do
This commit is contained in:
@@ -54,13 +54,16 @@ def test_lookup_table_encrypted_lookup(test_helpers):
|
||||
input_x = ir.Input(input_value=x, input_name="x", program_input_idx=0)
|
||||
ref_graph.add_node(input_x)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
# Need access to _checked_indexing to have is_equivalent_to work for ir.ArbitraryFunction
|
||||
output_arbitrary_function = ir.ArbitraryFunction(
|
||||
input_base_value=x,
|
||||
arbitrary_func=lambda x, table: table[x],
|
||||
arbitrary_func=LookupTable._checked_indexing,
|
||||
output_dtype=table.output_dtype,
|
||||
op_kwargs={"table": deepcopy(table.table)},
|
||||
op_name="TLU",
|
||||
)
|
||||
# pylint: enable=protected-access
|
||||
ref_graph.add_node(output_arbitrary_function)
|
||||
|
||||
ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0)
|
||||
@@ -91,13 +94,16 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
|
||||
input_x = ir.Input(input_value=x, input_name="x", program_input_idx=0)
|
||||
ref_graph.add_node(input_x)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
# Need access to _checked_indexing to have is_equivalent_to work for ir.ArbitraryFunction
|
||||
intermediate_arbitrary_function = ir.ArbitraryFunction(
|
||||
input_base_value=x,
|
||||
arbitrary_func=lambda x, table: table[x],
|
||||
arbitrary_func=LookupTable._checked_indexing,
|
||||
output_dtype=table.output_dtype,
|
||||
op_kwargs={"table": deepcopy(table.table)},
|
||||
op_name="TLU",
|
||||
)
|
||||
# pylint: enable=protected-access
|
||||
ref_graph.add_node(intermediate_arbitrary_function)
|
||||
|
||||
constant_3 = ir.Constant(3)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""PyTest configuration file"""
|
||||
import operator
|
||||
from typing import Callable, Dict, Type
|
||||
|
||||
import networkx as nx
|
||||
@@ -39,10 +40,37 @@ def is_equivalent_add(lhs: Add, rhs: object) -> bool:
|
||||
return _is_equivalent_to_binary_commutative(lhs, rhs)
|
||||
|
||||
|
||||
# From https://stackoverflow.com/a/28635464
|
||||
_code_and_constants_attr_getter = operator.attrgetter("co_code", "co_consts")
|
||||
|
||||
|
||||
def _code_and_constants(object_):
|
||||
"""Helper function to get python code and constants"""
|
||||
return _code_and_constants_attr_getter(object_.__code__)
|
||||
|
||||
|
||||
def python_functions_are_equal_or_equivalent(lhs: object, rhs: object) -> bool:
|
||||
"""Helper function to check if two functions are equal or their code are equivalent.
|
||||
|
||||
This is not perfect, but will be good enough for tests.
|
||||
"""
|
||||
|
||||
if lhs == rhs:
|
||||
return True
|
||||
|
||||
try:
|
||||
lhs_code_and_constants = _code_and_constants(lhs)
|
||||
rhs_code_and_constants = _code_and_constants(rhs)
|
||||
return lhs_code_and_constants == rhs_code_and_constants
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
||||
def is_equivalent_arbitrary_function(lhs: ArbitraryFunction, rhs: object) -> bool:
|
||||
"""Helper function to check if an ArbitraryFunction node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, ArbitraryFunction)
|
||||
and python_functions_are_equal_or_equivalent(lhs.arbitrary_func, rhs.arbitrary_func)
|
||||
and lhs.op_args == rhs.op_args
|
||||
and lhs.op_kwargs == rhs.op_kwargs
|
||||
and lhs.op_name == rhs.op_name
|
||||
|
||||
Reference in New Issue
Block a user