From 00916bcfdbc91e91af649acf2e4e6f00fa572ad9 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 11 Oct 2021 09:53:03 +0200 Subject: [PATCH] refactor: rename ArbitraryFunction to UnivariateFunction - the naming has always been confusing and recent changes to the code make this rename necessary for things to be clearer --- concrete/common/debugging/drawing.py | 8 +++--- concrete/common/debugging/printing.py | 4 +-- concrete/common/extensions/table.py | 6 ++-- concrete/common/mlir/converters.py | 6 ++-- concrete/common/mlir/utils.py | 6 ++-- concrete/common/optimization/topological.py | 22 +++++++-------- .../common/representation/intermediate.py | 13 +++++---- concrete/numpy/compile.py | 2 +- concrete/numpy/tracing.py | 8 +++--- docs/dev/explanation/FLOAT-FUSING.md | 4 +-- docs/user/tutorial/COMPILATION_ARTIFACTS.md | 2 +- tests/common/extensions/test_table.py | 12 ++++---- .../representation/test_intermediate.py | 28 +++++++++---------- tests/conftest.py | 10 +++---- tests/numpy/test_tracing.py | 10 +++---- 15 files changed, 72 insertions(+), 69 deletions(-) diff --git a/concrete/common/debugging/drawing.py b/concrete/common/debugging/drawing.py index c06115996..65ce67628 100644 --- a/concrete/common/debugging/drawing.py +++ b/concrete/common/debugging/drawing.py @@ -14,12 +14,12 @@ from ..operator_graph import OPGraph from ..representation.intermediate import ( ALL_IR_NODES, Add, - ArbitraryFunction, Constant, Dot, Input, Mul, Sub, + UnivariateFunction, ) IR_NODE_COLOR_MAPPING = { @@ -28,9 +28,9 @@ IR_NODE_COLOR_MAPPING = { Add: "red", Sub: "yellow", Mul: "green", - ArbitraryFunction: "orange", + UnivariateFunction: "orange", Dot: "purple", - "ArbitraryFunction": "orange", + "UnivariateFunction": "orange", "TLU": "grey", "output": "magenta", } @@ -71,7 +71,7 @@ def draw_graph( value_to_return = IR_NODE_COLOR_MAPPING[type(node)] if node in output_nodes: value_to_return = IR_NODE_COLOR_MAPPING["output"] - elif isinstance(node, ArbitraryFunction): + elif isinstance(node, UnivariateFunction): value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return) return value_to_return diff --git a/concrete/common/debugging/printing.py b/concrete/common/debugging/printing.py index 0bc093b70..8187139ce 100644 --- a/concrete/common/debugging/printing.py +++ b/concrete/common/debugging/printing.py @@ -6,7 +6,7 @@ import networkx as nx from ..debugging.custom_assert import custom_assert from ..operator_graph import OPGraph -from ..representation.intermediate import ArbitraryFunction, Constant, Input +from ..representation.intermediate import Constant, Input, UnivariateFunction def output_data_type_to_string(node): @@ -61,7 +61,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: base_name = node.__class__.__name__ - if isinstance(node, ArbitraryFunction): + if isinstance(node, UnivariateFunction): base_name = node.op_name what_to_print = base_name + "(" diff --git a/concrete/common/extensions/table.py b/concrete/common/extensions/table.py index 971a4309f..8e882bd52 100644 --- a/concrete/common/extensions/table.py +++ b/concrete/common/extensions/table.py @@ -6,7 +6,7 @@ from typing import Iterable, Tuple, Union from ..common_helpers import is_a_power_of_2 from ..data_types.base import BaseDataType from ..data_types.integers import make_integer_to_hold -from ..representation.intermediate import ArbitraryFunction +from ..representation.intermediate import UnivariateFunction from ..tracing.base_tracer import BaseTracer @@ -32,10 +32,10 @@ class LookupTable: def __getitem__(self, key: Union[int, BaseTracer]): # if a tracer is used for indexing, - # we need to create an `ArbitraryFunction` node + # we need to create an `UnivariateFunction` node # because the result will be determined during the runtime if isinstance(key, BaseTracer): - traced_computation = ArbitraryFunction( + traced_computation = UnivariateFunction( input_base_value=key.output, arbitrary_func=LookupTable._checked_indexing, output_dtype=self.output_dtype, diff --git a/concrete/common/mlir/converters.py b/concrete/common/mlir/converters.py index aae80ad1b..255a6720a 100644 --- a/concrete/common/mlir/converters.py +++ b/concrete/common/mlir/converters.py @@ -22,7 +22,7 @@ from ..data_types.dtypes_helpers import ( ) from ..data_types.integers import Integer from ..debugging.custom_assert import custom_assert -from ..representation.intermediate import Add, ArbitraryFunction, Constant, Dot, Mul, Sub +from ..representation.intermediate import Add, Constant, Dot, Mul, Sub, UnivariateFunction from ..values import TensorValue @@ -165,7 +165,7 @@ def constant(node, _, __, ctx): def apply_lut(node, preds, ir_to_mlir_node, ctx): - """Convert an arbitrary function intermediate node.""" + """Convert a UnivariateFunction intermediate node.""" custom_assert(len(node.inputs) == 1, "LUT should have a single input") custom_assert(len(node.outputs) == 1, "LUT should have a single output") if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]): @@ -224,7 +224,7 @@ V0_OPSET_CONVERSION_FUNCTIONS = { Sub: sub, Mul: mul, Constant: constant, - ArbitraryFunction: apply_lut, + UnivariateFunction: apply_lut, Dot: dot, } diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index 28280a4ad..6c37cddf1 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -11,7 +11,7 @@ from ..data_types.dtypes_helpers import ( ) from ..debugging.custom_assert import assert_true from ..operator_graph import OPGraph -from ..representation.intermediate import ArbitraryFunction +from ..representation.intermediate import UnivariateFunction # TODO: should come from compiler, through an API, #402 ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7 @@ -81,7 +81,7 @@ def update_bit_width_for_mlir(op_graph: OPGraph): # TODO: remove this workaround, which was for #279, once the compiler can handle # smaller tables, #412 - has_a_table = any(isinstance(node, ArbitraryFunction) for node in op_graph.graph.nodes) + has_a_table = any(isinstance(node, UnivariateFunction) for node in op_graph.graph.nodes) if has_a_table: max_bit_width = ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB @@ -104,7 +104,7 @@ def extend_direct_lookup_tables(op_graph: OPGraph): op_graph: graph to update lookup tables for """ for node in op_graph.graph.nodes: - if isinstance(node, ArbitraryFunction) and node.op_name == "TLU": + if isinstance(node, UnivariateFunction) and node.op_name == "TLU": table = node.op_kwargs["table"] bit_width = cast(Integer, node.inputs[0].dtype).bit_width expected_length = 2 ** bit_width diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py index e83af0ae1..b9941cc10 100644 --- a/concrete/common/optimization/topological.py +++ b/concrete/common/optimization/topological.py @@ -10,7 +10,7 @@ from ..data_types.floats import Float from ..data_types.integers import Integer from ..debugging.custom_assert import assert_true, custom_assert from ..operator_graph import OPGraph -from ..representation.intermediate import ArbitraryFunction, Constant, Input, IntermediateNode +from ..representation.intermediate import Constant, Input, IntermediateNode, UnivariateFunction from ..values import TensorValue @@ -18,7 +18,7 @@ def fuse_float_operations( op_graph: OPGraph, compilation_artifacts: Optional[CompilationArtifacts] = None, ): - """Find and fuse float domains into single Integer to Integer ArbitraryFunction. + """Find and fuse float domains into single Integer to Integer UnivariateFunction. Args: op_graph (OPGraph): The OPGraph to simplify @@ -92,8 +92,8 @@ def convert_float_subgraph_to_fused_node( float_subgraph_start_nodes: Set[IntermediateNode], terminal_node: IntermediateNode, subgraph_all_nodes: Set[IntermediateNode], -) -> Optional[Tuple[ArbitraryFunction, IntermediateNode]]: - """Convert a float subgraph to an equivalent fused ArbitraryFunction node. +) -> Optional[Tuple[UnivariateFunction, IntermediateNode]]: + """Convert a float subgraph to an equivalent fused UnivariateFunction node. Args: op_graph (OPGraph): The OPGraph the float subgraph is part of. @@ -103,7 +103,7 @@ def convert_float_subgraph_to_fused_node( subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph. Returns: - Optional[Tuple[ArbitraryFunction, IntermediateNode]]: None if the float subgraph + Optional[Tuple[UnivariateFunction, IntermediateNode]]: None if the float subgraph cannot be fused, otherwise returns a tuple containing the fused node and the node whose output must be plugged as the input to the subgraph. """ @@ -161,7 +161,7 @@ def convert_float_subgraph_to_fused_node( ) # Create fused_node - fused_node = ArbitraryFunction( + fused_node = UnivariateFunction( deepcopy(new_subgraph_variable_input.inputs[0]), lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate({0: x})[ terminal_node @@ -251,7 +251,7 @@ def subgraph_values_allow_fusing( ): """Check if a subgraph's values are compatible with fusing. - A fused subgraph for example only works on an input tensor if the resulting ArbitraryFunction + A fused subgraph for example only works on an input tensor if the resulting UnivariateFunction can be applied per cell, hence shuffling or tensor shape changes make fusing impossible. Args: @@ -273,12 +273,12 @@ def subgraph_values_allow_fusing( f"only works for subgraphs with 1 variable input node, got {num_variable_input_nodes}", ) - # Some ArbitraryFunction nodes have baked constants that need to be taken into account for the + # Some UnivariateFunction nodes have baked constants that need to be taken into account for the # max size computation baked_constants_ir_nodes = [ baked_constant_base_value for node in subgraph_all_nodes - if isinstance(node, ArbitraryFunction) + if isinstance(node, UnivariateFunction) if (baked_constant_base_value := node.op_attributes.get("baked_constant_ir_node", None)) is not None ] @@ -297,7 +297,7 @@ def subgraph_values_allow_fusing( # A cheap check is that the variable input node must have the biggest size, i.e. have the most # elements, meaning all constants will broadcast to its shape. This is because the - # ArbitraryFunction input and output must have the same shape so that it can be applied to each + # UnivariateFunction input and output must have the same shape so that it can be applied to each # of the input tensor cells. # There *may* be a way to manage the other case by simulating the broadcast of the smaller input # array and then concatenating/stacking the results. This is not currently doable as we don't @@ -343,5 +343,5 @@ def subgraph_has_unique_variable_input( bool: True if only one of the nodes is not an Constant """ # Only one input to the subgraph where computations are done in floats is variable, this - # is the only case we can manage with ArbitraryFunction fusing + # is the only case we can manage with UnivariateFunction fusing return sum(not isinstance(node, Constant) for node in float_subgraph_start_nodes) == 1 diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index e5d5328ab..2183a8597 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -196,11 +196,14 @@ class Constant(IntermediateNode): return str(self.constant_data) -class ArbitraryFunction(IntermediateNode): - """Node representing a univariate arbitrary function, e.g. sin(x).""" +class UnivariateFunction(IntermediateNode): + """Node representing an univariate arbitrary function, e.g. sin(x).""" # The arbitrary_func is not optional but mypy has a long standing bug and is not able to # understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623 + # arbitrary_func can take more than one argument but during evaluation the input variable will + # be the first argument passed to it. You can add other constant arguments needed for the proper + # execution of the function through op_args and op_kwargs. arbitrary_func: Optional[Callable] op_name: str op_args: Tuple[Any, ...] @@ -240,9 +243,9 @@ class ArbitraryFunction(IntermediateNode): return self.op_name def get_table(self) -> List[Any]: - """Get the table for the current input value of this ArbitraryFunction. + """Get the table for the current input value of this UnivariateFunction. - This function only works if the ArbitraryFunction input value is an unsigned Integer. + This function only works if the UnivariateFunction input value is an unsigned Integer. Returns: List[Any]: The table. @@ -290,7 +293,7 @@ class Dot(IntermediateNode): """Return the node representing a dot product.""" _n_in: int = 2 - # Optional, same issue as in ArbitraryFunction for mypy + # Optional, same issue as in UnivariateFunction for mypy evaluation_function: Optional[Callable[[Any, Any], Any]] # Allows to use specialized implementations from e.g. numpy diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index bf8419715..689b9c5b6 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -96,7 +96,7 @@ def _compile_numpy_function_into_op_graph_internal( # Apply topological optimizations if they are enabled if compilation_configuration.enable_topological_optimizations: - # Fuse float operations to have int to int ArbitraryFunction + # Fuse float operations to have int to int UnivariateFunction if not check_op_graph_is_integer_program(op_graph): fuse_float_operations(op_graph, compilation_artifacts) diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index 1c62ace0a..c9b5c6966 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -9,7 +9,7 @@ from numpy.typing import DTypeLike from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype from ..common.debugging.custom_assert import assert_true, custom_assert from ..common.operator_graph import OPGraph -from ..common.representation.intermediate import ArbitraryFunction, Constant, Dot +from ..common.representation.intermediate import Constant, Dot, UnivariateFunction from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters from ..common.values import BaseValue from .np_dtypes_helpers import ( @@ -87,7 +87,7 @@ class NPTracer(BaseTracer): normalized_numpy_dtype = numpy.dtype(numpy_dtype) output_dtype = convert_numpy_dtype_to_base_data_type(numpy_dtype) - traced_computation = ArbitraryFunction( + traced_computation = UnivariateFunction( input_base_value=self.output, arbitrary_func=normalized_numpy_dtype.type, output_dtype=output_dtype, @@ -154,7 +154,7 @@ class NPTracer(BaseTracer): common_output_dtypes = cls._manage_dtypes(unary_operator, *input_tracers) custom_assert(len(common_output_dtypes) == 1) - traced_computation = ArbitraryFunction( + traced_computation = UnivariateFunction( input_base_value=input_tracers[0].output, arbitrary_func=unary_operator, output_dtype=common_output_dtypes[0], @@ -218,7 +218,7 @@ class NPTracer(BaseTracer): "in_which_input_is_constant": in_which_input_is_constant, } - traced_computation = ArbitraryFunction( + traced_computation = UnivariateFunction( input_base_value=input_tracers[in_which_input_is_variable].output, arbitrary_func=arbitrary_func, output_dtype=common_output_dtypes[0], diff --git a/docs/dev/explanation/FLOAT-FUSING.md b/docs/dev/explanation/FLOAT-FUSING.md index ac10088d8..9a20ace5f 100644 --- a/docs/dev/explanation/FLOAT-FUSING.md +++ b/docs/dev/explanation/FLOAT-FUSING.md @@ -31,7 +31,7 @@ The float subgraph that was detected: ![](../../_static/float_fusing_example/subgraph.png) -The simplified graph of operations with the float subgraph condensed in an `ArbitraryFunction` node: +The simplified graph of operations with the float subgraph condensed in an `UnivariateFunction` node: ![](../../_static/float_fusing_example/after.png) @@ -39,7 +39,7 @@ The simplified graph of operations with the float subgraph condensed in an `Arbi The first step consists in detecting where we go from floating point computation back to integers. This allows to identify the potential terminal node of the float subgraph we are going to fuse. -From the terminal node, we go back up through the nodes until we find nodes that go from integers to floats. If we can guarantee the identified float subgraph has a single variable integer input then we can replace it by an equivalent ArbitraryFunction node. +From the terminal node, we go back up through the nodes until we find nodes that go from integers to floats. If we can guarantee the identified float subgraph has a single variable integer input then we can replace it by an equivalent UnivariateFunction node. An example of a non fusable computation with that technique is: diff --git a/docs/user/tutorial/COMPILATION_ARTIFACTS.md b/docs/user/tutorial/COMPILATION_ARTIFACTS.md index 8000d63d5..6c3a72100 100644 --- a/docs/user/tutorial/COMPILATION_ARTIFACTS.md +++ b/docs/user/tutorial/COMPILATION_ARTIFACTS.md @@ -83,7 +83,7 @@ Traceback (most recent call last): File "/src/concrete/numpy/compile.py", line 103, in _compile_numpy_function_into_op_graph_internal raise ValueError( ValueError: cannot be compiled as it has nodes with either float inputs or outputs. -Offending nodes : +Offending nodes : ``` ## Manual export diff --git a/tests/common/extensions/test_table.py b/tests/common/extensions/test_table.py index cdd0481a2..d6a95999a 100644 --- a/tests/common/extensions/test_table.py +++ b/tests/common/extensions/test_table.py @@ -55,8 +55,8 @@ def test_lookup_table_encrypted_lookup(test_helpers): ref_graph.add_node(input_x) # pylint: disable=protected-access - # Need access to _checked_indexing to have is_equivalent_to work for ir.ArbitraryFunction - output_arbitrary_function = ir.ArbitraryFunction( + # Need access to _checked_indexing to have is_equivalent_to work for ir.UnivariateFunction + output_arbitrary_function = ir.UnivariateFunction( input_base_value=x, arbitrary_func=LookupTable._checked_indexing, output_dtype=table.output_dtype, @@ -68,7 +68,7 @@ def test_lookup_table_encrypted_lookup(test_helpers): ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0) - # TODO: discuss if this check is enough as == is not overloaded properly for ArbitraryFunction + # TODO: discuss if this check is enough as == is not overloaded properly for UnivariateFunction assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph) @@ -95,8 +95,8 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers): ref_graph.add_node(input_x) # pylint: disable=protected-access - # Need access to _checked_indexing to have is_equivalent_to work for ir.ArbitraryFunction - intermediate_arbitrary_function = ir.ArbitraryFunction( + # Need access to _checked_indexing to have is_equivalent_to work for ir.UnivariateFunction + intermediate_arbitrary_function = ir.UnivariateFunction( input_base_value=x, arbitrary_func=LookupTable._checked_indexing, output_dtype=table.output_dtype, @@ -117,5 +117,5 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers): ref_graph.add_edge(intermediate_arbitrary_function, output_add, input_idx=0) ref_graph.add_edge(constant_3, output_add, input_idx=1) - # TODO: discuss if this check is enough as == is not overloaded properly for ArbitraryFunction + # TODO: discuss if this check is enough as == is not overloaded properly for UnivariateFunction assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph) diff --git a/tests/common/representation/test_intermediate.py b/tests/common/representation/test_intermediate.py index 34251437e..847a0e562 100644 --- a/tests/common/representation/test_intermediate.py +++ b/tests/common/representation/test_intermediate.py @@ -34,15 +34,15 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En pytest.param(ir.Constant(42), None, 42, id="Constant"), pytest.param(ir.Constant(-42), None, -42, id="Constant"), pytest.param( - ir.ArbitraryFunction( + ir.UnivariateFunction( EncryptedScalar(Integer(7, False)), lambda x: x + 3, Integer(7, False) ), [10], 13, - id="ArbitraryFunction, x + 3", + id="UnivariateFunction, x + 3", ), pytest.param( - ir.ArbitraryFunction( + ir.UnivariateFunction( EncryptedScalar(Integer(7, False)), lambda x, y: x + y, Integer(7, False), @@ -50,10 +50,10 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En ), [10], 13, - id="ArbitraryFunction, (x, y) -> x + y, where y is constant == 3", + id="UnivariateFunction, (x, y) -> x + y, where y is constant == 3", ), pytest.param( - ir.ArbitraryFunction( + ir.UnivariateFunction( EncryptedScalar(Integer(7, False)), lambda x, y: y[x], Integer(7, False), @@ -61,10 +61,10 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En ), [2], 3, - id="ArbitraryFunction, (x, y) -> y[x], where y is constant == (1, 2, 3, 4)", + id="UnivariateFunction, (x, y) -> y[x], where y is constant == (1, 2, 3, 4)", ), pytest.param( - ir.ArbitraryFunction( + ir.UnivariateFunction( EncryptedScalar(Integer(7, False)), lambda x, y: y[3], Integer(7, False), @@ -72,7 +72,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En ), [2], 4, - id="ArbitraryFunction, x, y -> y[3], where y is constant == (1, 2, 3, 4)", + id="UnivariateFunction, x, y -> y[3], where y is constant == (1, 2, 3, 4)", ), pytest.param( ir.Dot( @@ -209,34 +209,34 @@ def test_evaluate( False, ), ( - ir.ArbitraryFunction( + ir.UnivariateFunction( EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False) ), - ir.ArbitraryFunction( + ir.UnivariateFunction( EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False) ), True, ), ( - ir.ArbitraryFunction( + ir.UnivariateFunction( EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False), op_args=(1, 2, 3), ), - ir.ArbitraryFunction( + ir.UnivariateFunction( EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False) ), False, ), ( - ir.ArbitraryFunction( + ir.UnivariateFunction( EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False), op_kwargs={"tuple": (1, 2, 3)}, ), - ir.ArbitraryFunction( + ir.UnivariateFunction( EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False) ), False, diff --git a/tests/conftest.py b/tests/conftest.py index 2a0d48873..c5d45b98b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,13 +9,13 @@ import pytest from concrete.common.representation.intermediate import ( ALL_IR_NODES, Add, - ArbitraryFunction, Constant, Dot, Input, IntermediateNode, Mul, Sub, + UnivariateFunction, ) @@ -66,10 +66,10 @@ def python_functions_are_equal_or_equivalent(lhs: object, rhs: object) -> bool: return False -def is_equivalent_arbitrary_function(lhs: ArbitraryFunction, rhs: object) -> bool: - """Helper function to check if an ArbitraryFunction node is equivalent to an other object.""" +def is_equivalent_arbitrary_function(lhs: UnivariateFunction, rhs: object) -> bool: + """Helper function to check if an UnivariateFunction node is equivalent to an other object.""" return ( - isinstance(rhs, ArbitraryFunction) + isinstance(rhs, UnivariateFunction) and python_functions_are_equal_or_equivalent(lhs.arbitrary_func, rhs.arbitrary_func) and lhs.op_args == rhs.op_args and lhs.op_kwargs == rhs.op_kwargs @@ -127,7 +127,7 @@ def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool: EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = { Add: is_equivalent_add, - ArbitraryFunction: is_equivalent_arbitrary_function, + UnivariateFunction: is_equivalent_arbitrary_function, Constant: is_equivalent_constant, Dot: is_equivalent_dot, Input: is_equivalent_input, diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index 01b188878..95e500c34 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -386,24 +386,24 @@ def test_tracing_astype( [ pytest.param( {"x": EncryptedScalar(Integer(7, is_signed=False))}, - ir.ArbitraryFunction, + ir.UnivariateFunction, ), pytest.param( {"x": EncryptedScalar(Integer(32, is_signed=True))}, - ir.ArbitraryFunction, + ir.UnivariateFunction, ), pytest.param( {"x": EncryptedScalar(Integer(64, is_signed=True))}, - ir.ArbitraryFunction, + ir.UnivariateFunction, ), pytest.param( {"x": EncryptedScalar(Integer(128, is_signed=True))}, - ir.ArbitraryFunction, + ir.UnivariateFunction, marks=pytest.mark.xfail(strict=True, raises=NotImplementedError), ), pytest.param( {"x": EncryptedScalar(Float(64))}, - ir.ArbitraryFunction, + ir.UnivariateFunction, ), ], )