refactor: rename ArbitraryFunction to UnivariateFunction

- the naming has always been confusing and recent changes to the code make
this rename necessary for things to be clearer
This commit is contained in:
Arthur Meyre
2021-10-11 09:53:03 +02:00
parent 44016cc80c
commit 00916bcfdb
15 changed files with 72 additions and 69 deletions

View File

@@ -14,12 +14,12 @@ from ..operator_graph import OPGraph
from ..representation.intermediate import (
ALL_IR_NODES,
Add,
ArbitraryFunction,
Constant,
Dot,
Input,
Mul,
Sub,
UnivariateFunction,
)
IR_NODE_COLOR_MAPPING = {
@@ -28,9 +28,9 @@ IR_NODE_COLOR_MAPPING = {
Add: "red",
Sub: "yellow",
Mul: "green",
ArbitraryFunction: "orange",
UnivariateFunction: "orange",
Dot: "purple",
"ArbitraryFunction": "orange",
"UnivariateFunction": "orange",
"TLU": "grey",
"output": "magenta",
}
@@ -71,7 +71,7 @@ def draw_graph(
value_to_return = IR_NODE_COLOR_MAPPING[type(node)]
if node in output_nodes:
value_to_return = IR_NODE_COLOR_MAPPING["output"]
elif isinstance(node, ArbitraryFunction):
elif isinstance(node, UnivariateFunction):
value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return)
return value_to_return

View File

@@ -6,7 +6,7 @@ import networkx as nx
from ..debugging.custom_assert import custom_assert
from ..operator_graph import OPGraph
from ..representation.intermediate import ArbitraryFunction, Constant, Input
from ..representation.intermediate import Constant, Input, UnivariateFunction
def output_data_type_to_string(node):
@@ -61,7 +61,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
base_name = node.__class__.__name__
if isinstance(node, ArbitraryFunction):
if isinstance(node, UnivariateFunction):
base_name = node.op_name
what_to_print = base_name + "("

View File

@@ -6,7 +6,7 @@ from typing import Iterable, Tuple, Union
from ..common_helpers import is_a_power_of_2
from ..data_types.base import BaseDataType
from ..data_types.integers import make_integer_to_hold
from ..representation.intermediate import ArbitraryFunction
from ..representation.intermediate import UnivariateFunction
from ..tracing.base_tracer import BaseTracer
@@ -32,10 +32,10 @@ class LookupTable:
def __getitem__(self, key: Union[int, BaseTracer]):
# if a tracer is used for indexing,
# we need to create an `ArbitraryFunction` node
# we need to create an `UnivariateFunction` node
# because the result will be determined during the runtime
if isinstance(key, BaseTracer):
traced_computation = ArbitraryFunction(
traced_computation = UnivariateFunction(
input_base_value=key.output,
arbitrary_func=LookupTable._checked_indexing,
output_dtype=self.output_dtype,

View File

@@ -22,7 +22,7 @@ from ..data_types.dtypes_helpers import (
)
from ..data_types.integers import Integer
from ..debugging.custom_assert import custom_assert
from ..representation.intermediate import Add, ArbitraryFunction, Constant, Dot, Mul, Sub
from ..representation.intermediate import Add, Constant, Dot, Mul, Sub, UnivariateFunction
from ..values import TensorValue
@@ -165,7 +165,7 @@ def constant(node, _, __, ctx):
def apply_lut(node, preds, ir_to_mlir_node, ctx):
"""Convert an arbitrary function intermediate node."""
"""Convert a UnivariateFunction intermediate node."""
custom_assert(len(node.inputs) == 1, "LUT should have a single input")
custom_assert(len(node.outputs) == 1, "LUT should have a single output")
if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]):
@@ -224,7 +224,7 @@ V0_OPSET_CONVERSION_FUNCTIONS = {
Sub: sub,
Mul: mul,
Constant: constant,
ArbitraryFunction: apply_lut,
UnivariateFunction: apply_lut,
Dot: dot,
}

View File

@@ -11,7 +11,7 @@ from ..data_types.dtypes_helpers import (
)
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import ArbitraryFunction
from ..representation.intermediate import UnivariateFunction
# TODO: should come from compiler, through an API, #402
ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7
@@ -81,7 +81,7 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
# TODO: remove this workaround, which was for #279, once the compiler can handle
# smaller tables, #412
has_a_table = any(isinstance(node, ArbitraryFunction) for node in op_graph.graph.nodes)
has_a_table = any(isinstance(node, UnivariateFunction) for node in op_graph.graph.nodes)
if has_a_table:
max_bit_width = ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB
@@ -104,7 +104,7 @@ def extend_direct_lookup_tables(op_graph: OPGraph):
op_graph: graph to update lookup tables for
"""
for node in op_graph.graph.nodes:
if isinstance(node, ArbitraryFunction) and node.op_name == "TLU":
if isinstance(node, UnivariateFunction) and node.op_name == "TLU":
table = node.op_kwargs["table"]
bit_width = cast(Integer, node.inputs[0].dtype).bit_width
expected_length = 2 ** bit_width

View File

@@ -10,7 +10,7 @@ from ..data_types.floats import Float
from ..data_types.integers import Integer
from ..debugging.custom_assert import assert_true, custom_assert
from ..operator_graph import OPGraph
from ..representation.intermediate import ArbitraryFunction, Constant, Input, IntermediateNode
from ..representation.intermediate import Constant, Input, IntermediateNode, UnivariateFunction
from ..values import TensorValue
@@ -18,7 +18,7 @@ def fuse_float_operations(
op_graph: OPGraph,
compilation_artifacts: Optional[CompilationArtifacts] = None,
):
"""Find and fuse float domains into single Integer to Integer ArbitraryFunction.
"""Find and fuse float domains into single Integer to Integer UnivariateFunction.
Args:
op_graph (OPGraph): The OPGraph to simplify
@@ -92,8 +92,8 @@ def convert_float_subgraph_to_fused_node(
float_subgraph_start_nodes: Set[IntermediateNode],
terminal_node: IntermediateNode,
subgraph_all_nodes: Set[IntermediateNode],
) -> Optional[Tuple[ArbitraryFunction, IntermediateNode]]:
"""Convert a float subgraph to an equivalent fused ArbitraryFunction node.
) -> Optional[Tuple[UnivariateFunction, IntermediateNode]]:
"""Convert a float subgraph to an equivalent fused UnivariateFunction node.
Args:
op_graph (OPGraph): The OPGraph the float subgraph is part of.
@@ -103,7 +103,7 @@ def convert_float_subgraph_to_fused_node(
subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph.
Returns:
Optional[Tuple[ArbitraryFunction, IntermediateNode]]: None if the float subgraph
Optional[Tuple[UnivariateFunction, IntermediateNode]]: None if the float subgraph
cannot be fused, otherwise returns a tuple containing the fused node and the node whose
output must be plugged as the input to the subgraph.
"""
@@ -161,7 +161,7 @@ def convert_float_subgraph_to_fused_node(
)
# Create fused_node
fused_node = ArbitraryFunction(
fused_node = UnivariateFunction(
deepcopy(new_subgraph_variable_input.inputs[0]),
lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate({0: x})[
terminal_node
@@ -251,7 +251,7 @@ def subgraph_values_allow_fusing(
):
"""Check if a subgraph's values are compatible with fusing.
A fused subgraph for example only works on an input tensor if the resulting ArbitraryFunction
A fused subgraph for example only works on an input tensor if the resulting UnivariateFunction
can be applied per cell, hence shuffling or tensor shape changes make fusing impossible.
Args:
@@ -273,12 +273,12 @@ def subgraph_values_allow_fusing(
f"only works for subgraphs with 1 variable input node, got {num_variable_input_nodes}",
)
# Some ArbitraryFunction nodes have baked constants that need to be taken into account for the
# Some UnivariateFunction nodes have baked constants that need to be taken into account for the
# max size computation
baked_constants_ir_nodes = [
baked_constant_base_value
for node in subgraph_all_nodes
if isinstance(node, ArbitraryFunction)
if isinstance(node, UnivariateFunction)
if (baked_constant_base_value := node.op_attributes.get("baked_constant_ir_node", None))
is not None
]
@@ -297,7 +297,7 @@ def subgraph_values_allow_fusing(
# A cheap check is that the variable input node must have the biggest size, i.e. have the most
# elements, meaning all constants will broadcast to its shape. This is because the
# ArbitraryFunction input and output must have the same shape so that it can be applied to each
# UnivariateFunction input and output must have the same shape so that it can be applied to each
# of the input tensor cells.
# There *may* be a way to manage the other case by simulating the broadcast of the smaller input
# array and then concatenating/stacking the results. This is not currently doable as we don't
@@ -343,5 +343,5 @@ def subgraph_has_unique_variable_input(
bool: True if only one of the nodes is not an Constant
"""
# Only one input to the subgraph where computations are done in floats is variable, this
# is the only case we can manage with ArbitraryFunction fusing
# is the only case we can manage with UnivariateFunction fusing
return sum(not isinstance(node, Constant) for node in float_subgraph_start_nodes) == 1

View File

@@ -196,11 +196,14 @@ class Constant(IntermediateNode):
return str(self.constant_data)
class ArbitraryFunction(IntermediateNode):
"""Node representing a univariate arbitrary function, e.g. sin(x)."""
class UnivariateFunction(IntermediateNode):
"""Node representing an univariate arbitrary function, e.g. sin(x)."""
# The arbitrary_func is not optional but mypy has a long standing bug and is not able to
# understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
# arbitrary_func can take more than one argument but during evaluation the input variable will
# be the first argument passed to it. You can add other constant arguments needed for the proper
# execution of the function through op_args and op_kwargs.
arbitrary_func: Optional[Callable]
op_name: str
op_args: Tuple[Any, ...]
@@ -240,9 +243,9 @@ class ArbitraryFunction(IntermediateNode):
return self.op_name
def get_table(self) -> List[Any]:
"""Get the table for the current input value of this ArbitraryFunction.
"""Get the table for the current input value of this UnivariateFunction.
This function only works if the ArbitraryFunction input value is an unsigned Integer.
This function only works if the UnivariateFunction input value is an unsigned Integer.
Returns:
List[Any]: The table.
@@ -290,7 +293,7 @@ class Dot(IntermediateNode):
"""Return the node representing a dot product."""
_n_in: int = 2
# Optional, same issue as in ArbitraryFunction for mypy
# Optional, same issue as in UnivariateFunction for mypy
evaluation_function: Optional[Callable[[Any, Any], Any]]
# Allows to use specialized implementations from e.g. numpy

View File

@@ -96,7 +96,7 @@ def _compile_numpy_function_into_op_graph_internal(
# Apply topological optimizations if they are enabled
if compilation_configuration.enable_topological_optimizations:
# Fuse float operations to have int to int ArbitraryFunction
# Fuse float operations to have int to int UnivariateFunction
if not check_op_graph_is_integer_program(op_graph):
fuse_float_operations(op_graph, compilation_artifacts)

View File

@@ -9,7 +9,7 @@ from numpy.typing import DTypeLike
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
from ..common.debugging.custom_assert import assert_true, custom_assert
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import ArbitraryFunction, Constant, Dot
from ..common.representation.intermediate import Constant, Dot, UnivariateFunction
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
from ..common.values import BaseValue
from .np_dtypes_helpers import (
@@ -87,7 +87,7 @@ class NPTracer(BaseTracer):
normalized_numpy_dtype = numpy.dtype(numpy_dtype)
output_dtype = convert_numpy_dtype_to_base_data_type(numpy_dtype)
traced_computation = ArbitraryFunction(
traced_computation = UnivariateFunction(
input_base_value=self.output,
arbitrary_func=normalized_numpy_dtype.type,
output_dtype=output_dtype,
@@ -154,7 +154,7 @@ class NPTracer(BaseTracer):
common_output_dtypes = cls._manage_dtypes(unary_operator, *input_tracers)
custom_assert(len(common_output_dtypes) == 1)
traced_computation = ArbitraryFunction(
traced_computation = UnivariateFunction(
input_base_value=input_tracers[0].output,
arbitrary_func=unary_operator,
output_dtype=common_output_dtypes[0],
@@ -218,7 +218,7 @@ class NPTracer(BaseTracer):
"in_which_input_is_constant": in_which_input_is_constant,
}
traced_computation = ArbitraryFunction(
traced_computation = UnivariateFunction(
input_base_value=input_tracers[in_which_input_is_variable].output,
arbitrary_func=arbitrary_func,
output_dtype=common_output_dtypes[0],

View File

@@ -31,7 +31,7 @@ The float subgraph that was detected:
![](../../_static/float_fusing_example/subgraph.png)
The simplified graph of operations with the float subgraph condensed in an `ArbitraryFunction` node:
The simplified graph of operations with the float subgraph condensed in an `UnivariateFunction` node:
![](../../_static/float_fusing_example/after.png)
@@ -39,7 +39,7 @@ The simplified graph of operations with the float subgraph condensed in an `Arbi
The first step consists in detecting where we go from floating point computation back to integers. This allows to identify the potential terminal node of the float subgraph we are going to fuse.
From the terminal node, we go back up through the nodes until we find nodes that go from integers to floats. If we can guarantee the identified float subgraph has a single variable integer input then we can replace it by an equivalent ArbitraryFunction node.
From the terminal node, we go back up through the nodes until we find nodes that go from integers to floats. If we can guarantee the identified float subgraph has a single variable integer input then we can replace it by an equivalent UnivariateFunction node.
An example of a non fusable computation with that technique is:

View File

@@ -83,7 +83,7 @@ Traceback (most recent call last):
File "/src/concrete/numpy/compile.py", line 103, in _compile_numpy_function_into_op_graph_internal
raise ValueError(
ValueError: <lambda> cannot be compiled as it has nodes with either float inputs or outputs.
Offending nodes : <concrete.common.representation.intermediate.ArbitraryFunction object at 0x7f6689fd37f0>
Offending nodes : <concrete.common.representation.intermediate.UnivariateFunction object at 0x7f6689fd37f0>
```
## Manual export

View File

@@ -55,8 +55,8 @@ def test_lookup_table_encrypted_lookup(test_helpers):
ref_graph.add_node(input_x)
# pylint: disable=protected-access
# Need access to _checked_indexing to have is_equivalent_to work for ir.ArbitraryFunction
output_arbitrary_function = ir.ArbitraryFunction(
# Need access to _checked_indexing to have is_equivalent_to work for ir.UnivariateFunction
output_arbitrary_function = ir.UnivariateFunction(
input_base_value=x,
arbitrary_func=LookupTable._checked_indexing,
output_dtype=table.output_dtype,
@@ -68,7 +68,7 @@ def test_lookup_table_encrypted_lookup(test_helpers):
ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0)
# TODO: discuss if this check is enough as == is not overloaded properly for ArbitraryFunction
# TODO: discuss if this check is enough as == is not overloaded properly for UnivariateFunction
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
@@ -95,8 +95,8 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
ref_graph.add_node(input_x)
# pylint: disable=protected-access
# Need access to _checked_indexing to have is_equivalent_to work for ir.ArbitraryFunction
intermediate_arbitrary_function = ir.ArbitraryFunction(
# Need access to _checked_indexing to have is_equivalent_to work for ir.UnivariateFunction
intermediate_arbitrary_function = ir.UnivariateFunction(
input_base_value=x,
arbitrary_func=LookupTable._checked_indexing,
output_dtype=table.output_dtype,
@@ -117,5 +117,5 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
ref_graph.add_edge(intermediate_arbitrary_function, output_add, input_idx=0)
ref_graph.add_edge(constant_3, output_add, input_idx=1)
# TODO: discuss if this check is enough as == is not overloaded properly for ArbitraryFunction
# TODO: discuss if this check is enough as == is not overloaded properly for UnivariateFunction
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)

View File

@@ -34,15 +34,15 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
pytest.param(ir.Constant(42), None, 42, id="Constant"),
pytest.param(ir.Constant(-42), None, -42, id="Constant"),
pytest.param(
ir.ArbitraryFunction(
ir.UnivariateFunction(
EncryptedScalar(Integer(7, False)), lambda x: x + 3, Integer(7, False)
),
[10],
13,
id="ArbitraryFunction, x + 3",
id="UnivariateFunction, x + 3",
),
pytest.param(
ir.ArbitraryFunction(
ir.UnivariateFunction(
EncryptedScalar(Integer(7, False)),
lambda x, y: x + y,
Integer(7, False),
@@ -50,10 +50,10 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
),
[10],
13,
id="ArbitraryFunction, (x, y) -> x + y, where y is constant == 3",
id="UnivariateFunction, (x, y) -> x + y, where y is constant == 3",
),
pytest.param(
ir.ArbitraryFunction(
ir.UnivariateFunction(
EncryptedScalar(Integer(7, False)),
lambda x, y: y[x],
Integer(7, False),
@@ -61,10 +61,10 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
),
[2],
3,
id="ArbitraryFunction, (x, y) -> y[x], where y is constant == (1, 2, 3, 4)",
id="UnivariateFunction, (x, y) -> y[x], where y is constant == (1, 2, 3, 4)",
),
pytest.param(
ir.ArbitraryFunction(
ir.UnivariateFunction(
EncryptedScalar(Integer(7, False)),
lambda x, y: y[3],
Integer(7, False),
@@ -72,7 +72,7 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
),
[2],
4,
id="ArbitraryFunction, x, y -> y[3], where y is constant == (1, 2, 3, 4)",
id="UnivariateFunction, x, y -> y[3], where y is constant == (1, 2, 3, 4)",
),
pytest.param(
ir.Dot(
@@ -209,34 +209,34 @@ def test_evaluate(
False,
),
(
ir.ArbitraryFunction(
ir.UnivariateFunction(
EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False)
),
ir.ArbitraryFunction(
ir.UnivariateFunction(
EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False)
),
True,
),
(
ir.ArbitraryFunction(
ir.UnivariateFunction(
EncryptedScalar(Integer(8, False)),
lambda x: x,
Integer(8, False),
op_args=(1, 2, 3),
),
ir.ArbitraryFunction(
ir.UnivariateFunction(
EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False)
),
False,
),
(
ir.ArbitraryFunction(
ir.UnivariateFunction(
EncryptedScalar(Integer(8, False)),
lambda x: x,
Integer(8, False),
op_kwargs={"tuple": (1, 2, 3)},
),
ir.ArbitraryFunction(
ir.UnivariateFunction(
EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False)
),
False,

View File

@@ -9,13 +9,13 @@ import pytest
from concrete.common.representation.intermediate import (
ALL_IR_NODES,
Add,
ArbitraryFunction,
Constant,
Dot,
Input,
IntermediateNode,
Mul,
Sub,
UnivariateFunction,
)
@@ -66,10 +66,10 @@ def python_functions_are_equal_or_equivalent(lhs: object, rhs: object) -> bool:
return False
def is_equivalent_arbitrary_function(lhs: ArbitraryFunction, rhs: object) -> bool:
"""Helper function to check if an ArbitraryFunction node is equivalent to an other object."""
def is_equivalent_arbitrary_function(lhs: UnivariateFunction, rhs: object) -> bool:
"""Helper function to check if an UnivariateFunction node is equivalent to an other object."""
return (
isinstance(rhs, ArbitraryFunction)
isinstance(rhs, UnivariateFunction)
and python_functions_are_equal_or_equivalent(lhs.arbitrary_func, rhs.arbitrary_func)
and lhs.op_args == rhs.op_args
and lhs.op_kwargs == rhs.op_kwargs
@@ -127,7 +127,7 @@ def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool:
EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
Add: is_equivalent_add,
ArbitraryFunction: is_equivalent_arbitrary_function,
UnivariateFunction: is_equivalent_arbitrary_function,
Constant: is_equivalent_constant,
Dot: is_equivalent_dot,
Input: is_equivalent_input,

View File

@@ -386,24 +386,24 @@ def test_tracing_astype(
[
pytest.param(
{"x": EncryptedScalar(Integer(7, is_signed=False))},
ir.ArbitraryFunction,
ir.UnivariateFunction,
),
pytest.param(
{"x": EncryptedScalar(Integer(32, is_signed=True))},
ir.ArbitraryFunction,
ir.UnivariateFunction,
),
pytest.param(
{"x": EncryptedScalar(Integer(64, is_signed=True))},
ir.ArbitraryFunction,
ir.UnivariateFunction,
),
pytest.param(
{"x": EncryptedScalar(Integer(128, is_signed=True))},
ir.ArbitraryFunction,
ir.UnivariateFunction,
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
),
pytest.param(
{"x": EncryptedScalar(Float(64))},
ir.ArbitraryFunction,
ir.UnivariateFunction,
),
],
)