mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor: rename ArbitraryFunction to UnivariateFunction
- the naming has always been confusing and recent changes to the code make this rename necessary for things to be clearer
This commit is contained in:
@@ -14,12 +14,12 @@ from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import (
|
||||
ALL_IR_NODES,
|
||||
Add,
|
||||
ArbitraryFunction,
|
||||
Constant,
|
||||
Dot,
|
||||
Input,
|
||||
Mul,
|
||||
Sub,
|
||||
UnivariateFunction,
|
||||
)
|
||||
|
||||
IR_NODE_COLOR_MAPPING = {
|
||||
@@ -28,9 +28,9 @@ IR_NODE_COLOR_MAPPING = {
|
||||
Add: "red",
|
||||
Sub: "yellow",
|
||||
Mul: "green",
|
||||
ArbitraryFunction: "orange",
|
||||
UnivariateFunction: "orange",
|
||||
Dot: "purple",
|
||||
"ArbitraryFunction": "orange",
|
||||
"UnivariateFunction": "orange",
|
||||
"TLU": "grey",
|
||||
"output": "magenta",
|
||||
}
|
||||
@@ -71,7 +71,7 @@ def draw_graph(
|
||||
value_to_return = IR_NODE_COLOR_MAPPING[type(node)]
|
||||
if node in output_nodes:
|
||||
value_to_return = IR_NODE_COLOR_MAPPING["output"]
|
||||
elif isinstance(node, ArbitraryFunction):
|
||||
elif isinstance(node, UnivariateFunction):
|
||||
value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return)
|
||||
return value_to_return
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import networkx as nx
|
||||
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import ArbitraryFunction, Constant, Input
|
||||
from ..representation.intermediate import Constant, Input, UnivariateFunction
|
||||
|
||||
|
||||
def output_data_type_to_string(node):
|
||||
@@ -61,7 +61,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
|
||||
base_name = node.__class__.__name__
|
||||
|
||||
if isinstance(node, ArbitraryFunction):
|
||||
if isinstance(node, UnivariateFunction):
|
||||
base_name = node.op_name
|
||||
|
||||
what_to_print = base_name + "("
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Iterable, Tuple, Union
|
||||
from ..common_helpers import is_a_power_of_2
|
||||
from ..data_types.base import BaseDataType
|
||||
from ..data_types.integers import make_integer_to_hold
|
||||
from ..representation.intermediate import ArbitraryFunction
|
||||
from ..representation.intermediate import UnivariateFunction
|
||||
from ..tracing.base_tracer import BaseTracer
|
||||
|
||||
|
||||
@@ -32,10 +32,10 @@ class LookupTable:
|
||||
|
||||
def __getitem__(self, key: Union[int, BaseTracer]):
|
||||
# if a tracer is used for indexing,
|
||||
# we need to create an `ArbitraryFunction` node
|
||||
# we need to create an `UnivariateFunction` node
|
||||
# because the result will be determined during the runtime
|
||||
if isinstance(key, BaseTracer):
|
||||
traced_computation = ArbitraryFunction(
|
||||
traced_computation = UnivariateFunction(
|
||||
input_base_value=key.output,
|
||||
arbitrary_func=LookupTable._checked_indexing,
|
||||
output_dtype=self.output_dtype,
|
||||
|
||||
@@ -22,7 +22,7 @@ from ..data_types.dtypes_helpers import (
|
||||
)
|
||||
from ..data_types.integers import Integer
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..representation.intermediate import Add, ArbitraryFunction, Constant, Dot, Mul, Sub
|
||||
from ..representation.intermediate import Add, Constant, Dot, Mul, Sub, UnivariateFunction
|
||||
from ..values import TensorValue
|
||||
|
||||
|
||||
@@ -165,7 +165,7 @@ def constant(node, _, __, ctx):
|
||||
|
||||
|
||||
def apply_lut(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Convert an arbitrary function intermediate node."""
|
||||
"""Convert a UnivariateFunction intermediate node."""
|
||||
custom_assert(len(node.inputs) == 1, "LUT should have a single input")
|
||||
custom_assert(len(node.outputs) == 1, "LUT should have a single output")
|
||||
if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]):
|
||||
@@ -224,7 +224,7 @@ V0_OPSET_CONVERSION_FUNCTIONS = {
|
||||
Sub: sub,
|
||||
Mul: mul,
|
||||
Constant: constant,
|
||||
ArbitraryFunction: apply_lut,
|
||||
UnivariateFunction: apply_lut,
|
||||
Dot: dot,
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from ..data_types.dtypes_helpers import (
|
||||
)
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import ArbitraryFunction
|
||||
from ..representation.intermediate import UnivariateFunction
|
||||
|
||||
# TODO: should come from compiler, through an API, #402
|
||||
ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7
|
||||
@@ -81,7 +81,7 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
|
||||
|
||||
# TODO: remove this workaround, which was for #279, once the compiler can handle
|
||||
# smaller tables, #412
|
||||
has_a_table = any(isinstance(node, ArbitraryFunction) for node in op_graph.graph.nodes)
|
||||
has_a_table = any(isinstance(node, UnivariateFunction) for node in op_graph.graph.nodes)
|
||||
|
||||
if has_a_table:
|
||||
max_bit_width = ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB
|
||||
@@ -104,7 +104,7 @@ def extend_direct_lookup_tables(op_graph: OPGraph):
|
||||
op_graph: graph to update lookup tables for
|
||||
"""
|
||||
for node in op_graph.graph.nodes:
|
||||
if isinstance(node, ArbitraryFunction) and node.op_name == "TLU":
|
||||
if isinstance(node, UnivariateFunction) and node.op_name == "TLU":
|
||||
table = node.op_kwargs["table"]
|
||||
bit_width = cast(Integer, node.inputs[0].dtype).bit_width
|
||||
expected_length = 2 ** bit_width
|
||||
|
||||
@@ -10,7 +10,7 @@ from ..data_types.floats import Float
|
||||
from ..data_types.integers import Integer
|
||||
from ..debugging.custom_assert import assert_true, custom_assert
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import ArbitraryFunction, Constant, Input, IntermediateNode
|
||||
from ..representation.intermediate import Constant, Input, IntermediateNode, UnivariateFunction
|
||||
from ..values import TensorValue
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ def fuse_float_operations(
|
||||
op_graph: OPGraph,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
):
|
||||
"""Find and fuse float domains into single Integer to Integer ArbitraryFunction.
|
||||
"""Find and fuse float domains into single Integer to Integer UnivariateFunction.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The OPGraph to simplify
|
||||
@@ -92,8 +92,8 @@ def convert_float_subgraph_to_fused_node(
|
||||
float_subgraph_start_nodes: Set[IntermediateNode],
|
||||
terminal_node: IntermediateNode,
|
||||
subgraph_all_nodes: Set[IntermediateNode],
|
||||
) -> Optional[Tuple[ArbitraryFunction, IntermediateNode]]:
|
||||
"""Convert a float subgraph to an equivalent fused ArbitraryFunction node.
|
||||
) -> Optional[Tuple[UnivariateFunction, IntermediateNode]]:
|
||||
"""Convert a float subgraph to an equivalent fused UnivariateFunction node.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The OPGraph the float subgraph is part of.
|
||||
@@ -103,7 +103,7 @@ def convert_float_subgraph_to_fused_node(
|
||||
subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph.
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[ArbitraryFunction, IntermediateNode]]: None if the float subgraph
|
||||
Optional[Tuple[UnivariateFunction, IntermediateNode]]: None if the float subgraph
|
||||
cannot be fused, otherwise returns a tuple containing the fused node and the node whose
|
||||
output must be plugged as the input to the subgraph.
|
||||
"""
|
||||
@@ -161,7 +161,7 @@ def convert_float_subgraph_to_fused_node(
|
||||
)
|
||||
|
||||
# Create fused_node
|
||||
fused_node = ArbitraryFunction(
|
||||
fused_node = UnivariateFunction(
|
||||
deepcopy(new_subgraph_variable_input.inputs[0]),
|
||||
lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate({0: x})[
|
||||
terminal_node
|
||||
@@ -251,7 +251,7 @@ def subgraph_values_allow_fusing(
|
||||
):
|
||||
"""Check if a subgraph's values are compatible with fusing.
|
||||
|
||||
A fused subgraph for example only works on an input tensor if the resulting ArbitraryFunction
|
||||
A fused subgraph for example only works on an input tensor if the resulting UnivariateFunction
|
||||
can be applied per cell, hence shuffling or tensor shape changes make fusing impossible.
|
||||
|
||||
Args:
|
||||
@@ -273,12 +273,12 @@ def subgraph_values_allow_fusing(
|
||||
f"only works for subgraphs with 1 variable input node, got {num_variable_input_nodes}",
|
||||
)
|
||||
|
||||
# Some ArbitraryFunction nodes have baked constants that need to be taken into account for the
|
||||
# Some UnivariateFunction nodes have baked constants that need to be taken into account for the
|
||||
# max size computation
|
||||
baked_constants_ir_nodes = [
|
||||
baked_constant_base_value
|
||||
for node in subgraph_all_nodes
|
||||
if isinstance(node, ArbitraryFunction)
|
||||
if isinstance(node, UnivariateFunction)
|
||||
if (baked_constant_base_value := node.op_attributes.get("baked_constant_ir_node", None))
|
||||
is not None
|
||||
]
|
||||
@@ -297,7 +297,7 @@ def subgraph_values_allow_fusing(
|
||||
|
||||
# A cheap check is that the variable input node must have the biggest size, i.e. have the most
|
||||
# elements, meaning all constants will broadcast to its shape. This is because the
|
||||
# ArbitraryFunction input and output must have the same shape so that it can be applied to each
|
||||
# UnivariateFunction input and output must have the same shape so that it can be applied to each
|
||||
# of the input tensor cells.
|
||||
# There *may* be a way to manage the other case by simulating the broadcast of the smaller input
|
||||
# array and then concatenating/stacking the results. This is not currently doable as we don't
|
||||
@@ -343,5 +343,5 @@ def subgraph_has_unique_variable_input(
|
||||
bool: True if only one of the nodes is not an Constant
|
||||
"""
|
||||
# Only one input to the subgraph where computations are done in floats is variable, this
|
||||
# is the only case we can manage with ArbitraryFunction fusing
|
||||
# is the only case we can manage with UnivariateFunction fusing
|
||||
return sum(not isinstance(node, Constant) for node in float_subgraph_start_nodes) == 1
|
||||
|
||||
@@ -196,11 +196,14 @@ class Constant(IntermediateNode):
|
||||
return str(self.constant_data)
|
||||
|
||||
|
||||
class ArbitraryFunction(IntermediateNode):
|
||||
"""Node representing a univariate arbitrary function, e.g. sin(x)."""
|
||||
class UnivariateFunction(IntermediateNode):
|
||||
"""Node representing an univariate arbitrary function, e.g. sin(x)."""
|
||||
|
||||
# The arbitrary_func is not optional but mypy has a long standing bug and is not able to
|
||||
# understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
|
||||
# arbitrary_func can take more than one argument but during evaluation the input variable will
|
||||
# be the first argument passed to it. You can add other constant arguments needed for the proper
|
||||
# execution of the function through op_args and op_kwargs.
|
||||
arbitrary_func: Optional[Callable]
|
||||
op_name: str
|
||||
op_args: Tuple[Any, ...]
|
||||
@@ -240,9 +243,9 @@ class ArbitraryFunction(IntermediateNode):
|
||||
return self.op_name
|
||||
|
||||
def get_table(self) -> List[Any]:
|
||||
"""Get the table for the current input value of this ArbitraryFunction.
|
||||
"""Get the table for the current input value of this UnivariateFunction.
|
||||
|
||||
This function only works if the ArbitraryFunction input value is an unsigned Integer.
|
||||
This function only works if the UnivariateFunction input value is an unsigned Integer.
|
||||
|
||||
Returns:
|
||||
List[Any]: The table.
|
||||
@@ -290,7 +293,7 @@ class Dot(IntermediateNode):
|
||||
"""Return the node representing a dot product."""
|
||||
|
||||
_n_in: int = 2
|
||||
# Optional, same issue as in ArbitraryFunction for mypy
|
||||
# Optional, same issue as in UnivariateFunction for mypy
|
||||
evaluation_function: Optional[Callable[[Any, Any], Any]]
|
||||
# Allows to use specialized implementations from e.g. numpy
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
|
||||
# Apply topological optimizations if they are enabled
|
||||
if compilation_configuration.enable_topological_optimizations:
|
||||
# Fuse float operations to have int to int ArbitraryFunction
|
||||
# Fuse float operations to have int to int UnivariateFunction
|
||||
if not check_op_graph_is_integer_program(op_graph):
|
||||
fuse_float_operations(op_graph, compilation_artifacts)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from numpy.typing import DTypeLike
|
||||
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
|
||||
from ..common.debugging.custom_assert import assert_true, custom_assert
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation.intermediate import ArbitraryFunction, Constant, Dot
|
||||
from ..common.representation.intermediate import Constant, Dot, UnivariateFunction
|
||||
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
|
||||
from ..common.values import BaseValue
|
||||
from .np_dtypes_helpers import (
|
||||
@@ -87,7 +87,7 @@ class NPTracer(BaseTracer):
|
||||
|
||||
normalized_numpy_dtype = numpy.dtype(numpy_dtype)
|
||||
output_dtype = convert_numpy_dtype_to_base_data_type(numpy_dtype)
|
||||
traced_computation = ArbitraryFunction(
|
||||
traced_computation = UnivariateFunction(
|
||||
input_base_value=self.output,
|
||||
arbitrary_func=normalized_numpy_dtype.type,
|
||||
output_dtype=output_dtype,
|
||||
@@ -154,7 +154,7 @@ class NPTracer(BaseTracer):
|
||||
common_output_dtypes = cls._manage_dtypes(unary_operator, *input_tracers)
|
||||
custom_assert(len(common_output_dtypes) == 1)
|
||||
|
||||
traced_computation = ArbitraryFunction(
|
||||
traced_computation = UnivariateFunction(
|
||||
input_base_value=input_tracers[0].output,
|
||||
arbitrary_func=unary_operator,
|
||||
output_dtype=common_output_dtypes[0],
|
||||
@@ -218,7 +218,7 @@ class NPTracer(BaseTracer):
|
||||
"in_which_input_is_constant": in_which_input_is_constant,
|
||||
}
|
||||
|
||||
traced_computation = ArbitraryFunction(
|
||||
traced_computation = UnivariateFunction(
|
||||
input_base_value=input_tracers[in_which_input_is_variable].output,
|
||||
arbitrary_func=arbitrary_func,
|
||||
output_dtype=common_output_dtypes[0],
|
||||
|
||||
@@ -31,7 +31,7 @@ The float subgraph that was detected:
|
||||
|
||||

|
||||
|
||||
The simplified graph of operations with the float subgraph condensed in an `ArbitraryFunction` node:
|
||||
The simplified graph of operations with the float subgraph condensed in an `UnivariateFunction` node:
|
||||
|
||||

|
||||
|
||||
@@ -39,7 +39,7 @@ The simplified graph of operations with the float subgraph condensed in an `Arbi
|
||||
|
||||
The first step consists in detecting where we go from floating point computation back to integers. This allows to identify the potential terminal node of the float subgraph we are going to fuse.
|
||||
|
||||
From the terminal node, we go back up through the nodes until we find nodes that go from integers to floats. If we can guarantee the identified float subgraph has a single variable integer input then we can replace it by an equivalent ArbitraryFunction node.
|
||||
From the terminal node, we go back up through the nodes until we find nodes that go from integers to floats. If we can guarantee the identified float subgraph has a single variable integer input then we can replace it by an equivalent UnivariateFunction node.
|
||||
|
||||
An example of a non fusable computation with that technique is:
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ Traceback (most recent call last):
|
||||
File "/src/concrete/numpy/compile.py", line 103, in _compile_numpy_function_into_op_graph_internal
|
||||
raise ValueError(
|
||||
ValueError: <lambda> cannot be compiled as it has nodes with either float inputs or outputs.
|
||||
Offending nodes : <concrete.common.representation.intermediate.ArbitraryFunction object at 0x7f6689fd37f0>
|
||||
Offending nodes : <concrete.common.representation.intermediate.UnivariateFunction object at 0x7f6689fd37f0>
|
||||
```
|
||||
|
||||
## Manual export
|
||||
|
||||
@@ -55,8 +55,8 @@ def test_lookup_table_encrypted_lookup(test_helpers):
|
||||
ref_graph.add_node(input_x)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
# Need access to _checked_indexing to have is_equivalent_to work for ir.ArbitraryFunction
|
||||
output_arbitrary_function = ir.ArbitraryFunction(
|
||||
# Need access to _checked_indexing to have is_equivalent_to work for ir.UnivariateFunction
|
||||
output_arbitrary_function = ir.UnivariateFunction(
|
||||
input_base_value=x,
|
||||
arbitrary_func=LookupTable._checked_indexing,
|
||||
output_dtype=table.output_dtype,
|
||||
@@ -68,7 +68,7 @@ def test_lookup_table_encrypted_lookup(test_helpers):
|
||||
|
||||
ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0)
|
||||
|
||||
# TODO: discuss if this check is enough as == is not overloaded properly for ArbitraryFunction
|
||||
# TODO: discuss if this check is enough as == is not overloaded properly for UnivariateFunction
|
||||
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
|
||||
|
||||
|
||||
@@ -95,8 +95,8 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
|
||||
ref_graph.add_node(input_x)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
# Need access to _checked_indexing to have is_equivalent_to work for ir.ArbitraryFunction
|
||||
intermediate_arbitrary_function = ir.ArbitraryFunction(
|
||||
# Need access to _checked_indexing to have is_equivalent_to work for ir.UnivariateFunction
|
||||
intermediate_arbitrary_function = ir.UnivariateFunction(
|
||||
input_base_value=x,
|
||||
arbitrary_func=LookupTable._checked_indexing,
|
||||
output_dtype=table.output_dtype,
|
||||
@@ -117,5 +117,5 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
|
||||
ref_graph.add_edge(intermediate_arbitrary_function, output_add, input_idx=0)
|
||||
ref_graph.add_edge(constant_3, output_add, input_idx=1)
|
||||
|
||||
# TODO: discuss if this check is enough as == is not overloaded properly for ArbitraryFunction
|
||||
# TODO: discuss if this check is enough as == is not overloaded properly for UnivariateFunction
|
||||
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
|
||||
|
||||
@@ -34,15 +34,15 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
|
||||
pytest.param(ir.Constant(42), None, 42, id="Constant"),
|
||||
pytest.param(ir.Constant(-42), None, -42, id="Constant"),
|
||||
pytest.param(
|
||||
ir.ArbitraryFunction(
|
||||
ir.UnivariateFunction(
|
||||
EncryptedScalar(Integer(7, False)), lambda x: x + 3, Integer(7, False)
|
||||
),
|
||||
[10],
|
||||
13,
|
||||
id="ArbitraryFunction, x + 3",
|
||||
id="UnivariateFunction, x + 3",
|
||||
),
|
||||
pytest.param(
|
||||
ir.ArbitraryFunction(
|
||||
ir.UnivariateFunction(
|
||||
EncryptedScalar(Integer(7, False)),
|
||||
lambda x, y: x + y,
|
||||
Integer(7, False),
|
||||
@@ -50,10 +50,10 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
|
||||
),
|
||||
[10],
|
||||
13,
|
||||
id="ArbitraryFunction, (x, y) -> x + y, where y is constant == 3",
|
||||
id="UnivariateFunction, (x, y) -> x + y, where y is constant == 3",
|
||||
),
|
||||
pytest.param(
|
||||
ir.ArbitraryFunction(
|
||||
ir.UnivariateFunction(
|
||||
EncryptedScalar(Integer(7, False)),
|
||||
lambda x, y: y[x],
|
||||
Integer(7, False),
|
||||
@@ -61,10 +61,10 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
|
||||
),
|
||||
[2],
|
||||
3,
|
||||
id="ArbitraryFunction, (x, y) -> y[x], where y is constant == (1, 2, 3, 4)",
|
||||
id="UnivariateFunction, (x, y) -> y[x], where y is constant == (1, 2, 3, 4)",
|
||||
),
|
||||
pytest.param(
|
||||
ir.ArbitraryFunction(
|
||||
ir.UnivariateFunction(
|
||||
EncryptedScalar(Integer(7, False)),
|
||||
lambda x, y: y[3],
|
||||
Integer(7, False),
|
||||
@@ -72,7 +72,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
|
||||
),
|
||||
[2],
|
||||
4,
|
||||
id="ArbitraryFunction, x, y -> y[3], where y is constant == (1, 2, 3, 4)",
|
||||
id="UnivariateFunction, x, y -> y[3], where y is constant == (1, 2, 3, 4)",
|
||||
),
|
||||
pytest.param(
|
||||
ir.Dot(
|
||||
@@ -209,34 +209,34 @@ def test_evaluate(
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.ArbitraryFunction(
|
||||
ir.UnivariateFunction(
|
||||
EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False)
|
||||
),
|
||||
ir.ArbitraryFunction(
|
||||
ir.UnivariateFunction(
|
||||
EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False)
|
||||
),
|
||||
True,
|
||||
),
|
||||
(
|
||||
ir.ArbitraryFunction(
|
||||
ir.UnivariateFunction(
|
||||
EncryptedScalar(Integer(8, False)),
|
||||
lambda x: x,
|
||||
Integer(8, False),
|
||||
op_args=(1, 2, 3),
|
||||
),
|
||||
ir.ArbitraryFunction(
|
||||
ir.UnivariateFunction(
|
||||
EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False)
|
||||
),
|
||||
False,
|
||||
),
|
||||
(
|
||||
ir.ArbitraryFunction(
|
||||
ir.UnivariateFunction(
|
||||
EncryptedScalar(Integer(8, False)),
|
||||
lambda x: x,
|
||||
Integer(8, False),
|
||||
op_kwargs={"tuple": (1, 2, 3)},
|
||||
),
|
||||
ir.ArbitraryFunction(
|
||||
ir.UnivariateFunction(
|
||||
EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False)
|
||||
),
|
||||
False,
|
||||
|
||||
@@ -9,13 +9,13 @@ import pytest
|
||||
from concrete.common.representation.intermediate import (
|
||||
ALL_IR_NODES,
|
||||
Add,
|
||||
ArbitraryFunction,
|
||||
Constant,
|
||||
Dot,
|
||||
Input,
|
||||
IntermediateNode,
|
||||
Mul,
|
||||
Sub,
|
||||
UnivariateFunction,
|
||||
)
|
||||
|
||||
|
||||
@@ -66,10 +66,10 @@ def python_functions_are_equal_or_equivalent(lhs: object, rhs: object) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def is_equivalent_arbitrary_function(lhs: ArbitraryFunction, rhs: object) -> bool:
|
||||
"""Helper function to check if an ArbitraryFunction node is equivalent to an other object."""
|
||||
def is_equivalent_arbitrary_function(lhs: UnivariateFunction, rhs: object) -> bool:
|
||||
"""Helper function to check if an UnivariateFunction node is equivalent to an other object."""
|
||||
return (
|
||||
isinstance(rhs, ArbitraryFunction)
|
||||
isinstance(rhs, UnivariateFunction)
|
||||
and python_functions_are_equal_or_equivalent(lhs.arbitrary_func, rhs.arbitrary_func)
|
||||
and lhs.op_args == rhs.op_args
|
||||
and lhs.op_kwargs == rhs.op_kwargs
|
||||
@@ -127,7 +127,7 @@ def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool:
|
||||
|
||||
EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
|
||||
Add: is_equivalent_add,
|
||||
ArbitraryFunction: is_equivalent_arbitrary_function,
|
||||
UnivariateFunction: is_equivalent_arbitrary_function,
|
||||
Constant: is_equivalent_constant,
|
||||
Dot: is_equivalent_dot,
|
||||
Input: is_equivalent_input,
|
||||
|
||||
@@ -386,24 +386,24 @@ def test_tracing_astype(
|
||||
[
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(7, is_signed=False))},
|
||||
ir.ArbitraryFunction,
|
||||
ir.UnivariateFunction,
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(32, is_signed=True))},
|
||||
ir.ArbitraryFunction,
|
||||
ir.UnivariateFunction,
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(64, is_signed=True))},
|
||||
ir.ArbitraryFunction,
|
||||
ir.UnivariateFunction,
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(128, is_signed=True))},
|
||||
ir.ArbitraryFunction,
|
||||
ir.UnivariateFunction,
|
||||
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Float(64))},
|
||||
ir.ArbitraryFunction,
|
||||
ir.UnivariateFunction,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user