diff --git a/concrete/common/operator_graph.py b/concrete/common/operator_graph.py index d7355476b..92ed97157 100644 --- a/concrete/common/operator_graph.py +++ b/concrete/common/operator_graph.py @@ -31,7 +31,6 @@ class OPGraph: input_nodes: Dict[int, Input], output_nodes: Dict[int, IntermediateNode], ) -> None: - assert_true(len(input_nodes) > 0, "Got a graph without input nodes which is not supported") assert_true( all(isinstance(node, Input) for node in input_nodes.values()), "Got input nodes that were not Input, which is not supported", @@ -47,6 +46,7 @@ class OPGraph: self.prune_nodes() def __call__(self, *args) -> Union[Any, Tuple[Any, ...]]: + assert_true(len(self.input_nodes) > 0, "Cannot evaluate a graph with no input nodes") inputs = dict(enumerate(args)) assert_true( diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py index ca6cdb012..a7e5a0708 100644 --- a/concrete/common/optimization/topological.py +++ b/concrete/common/optimization/topological.py @@ -1,13 +1,16 @@ """File holding topological optimization/simplification code.""" import itertools +from collections import defaultdict from copy import deepcopy -from typing import Dict, List, Optional, Set, Tuple, cast +from typing import DefaultDict, Dict, List, Optional, Set, Tuple, cast import networkx as nx +from loguru import logger from ..compilation.artifacts import CompilationArtifacts from ..data_types.floats import Float from ..data_types.integers import Integer +from ..debugging import get_printable_graph from ..debugging.custom_assert import assert_true from ..operator_graph import OPGraph from ..representation.intermediate import Constant, Input, IntermediateNode, UnivariateFunction @@ -112,11 +115,30 @@ def convert_float_subgraph_to_fused_node( output must be plugged as the input to the subgraph. 
""" - subgraph_can_be_fused = subgraph_has_unique_variable_input( - float_subgraph_start_nodes - ) and subgraph_values_allow_fusing(float_subgraph_start_nodes, subgraph_all_nodes) + node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]] = defaultdict(list) + subgraph_can_be_fused = subgraph_has_unique_variable_input( + float_subgraph_start_nodes, terminal_node, node_with_issues_for_fusing + ) + + if subgraph_can_be_fused: + # subgraph_values_allow_fusing can be called iff the subgraph has a unique variable input + subgraph_can_be_fused = subgraph_values_allow_fusing( + float_subgraph_start_nodes, subgraph_all_nodes, node_with_issues_for_fusing + ) + + # This test is separate from the previous one to only handle printing issues once if not subgraph_can_be_fused: + float_subgraph = nx.MultiDiGraph(op_graph.graph.subgraph(subgraph_all_nodes)) + float_subgraph_as_op_graph = OPGraph.from_graph(float_subgraph, [], [terminal_node]) + + printable_graph = get_printable_graph( + float_subgraph_as_op_graph, + show_data_types=True, + highlighted_nodes=node_with_issues_for_fusing, + ) + message = f"The following subgraph is not fusable:\n{printable_graph}" + logger.warning(message) return None # Only one variable input node, find which node feeds its input @@ -258,7 +280,8 @@ def find_float_subgraph_with_unique_terminal_node( def subgraph_values_allow_fusing( float_subgraph_start_nodes: Set[IntermediateNode], subgraph_all_nodes: Set[IntermediateNode], -): + node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]], +) -> bool: """Check if a subgraph's values are compatible with fusing. A fused subgraph for example only works on an input tensor if the resulting UnivariateFunction @@ -267,6 +290,8 @@ def subgraph_values_allow_fusing( Args: float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the float subgraph. subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph. 
+ node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill + with potential nodes issues preventing fusing. Returns: bool: True if all inputs and outputs of the nodes in the subgraph are compatible with fusing @@ -286,10 +311,10 @@ def subgraph_values_allow_fusing( # Some UnivariateFunction nodes have baked constants that need to be taken into account for the # max size computation baked_constants_ir_nodes = [ - baked_constant_base_value + baked_constant_ir_node for node in subgraph_all_nodes if isinstance(node, UnivariateFunction) - if (baked_constant_base_value := node.op_attributes.get("baked_constant_ir_node", None)) + if (baked_constant_ir_node := node.op_attributes.get("baked_constant_ir_node", None)) is not None ] @@ -332,26 +357,72 @@ def subgraph_values_allow_fusing( non_constant_nodes = (node for node in subgraph_all_nodes if not isinstance(node, Constant)) - return all( - all( - isinstance(output, TensorValue) and output.shape == variable_input_node_output_shape + nodes_with_different_output_shapes = { + node: [ + (output_idx, output.shape) + for output_idx, output in enumerate(node.outputs) + if isinstance(output, TensorValue) and output.shape != variable_input_node_output_shape + ] + for node in non_constant_nodes + if any( + isinstance(output, TensorValue) and output.shape != variable_input_node_output_shape for output in node.outputs ) - for node in non_constant_nodes - ) + } + + for node, node_shape_infos in nodes_with_different_output_shapes.items(): + shape_issue_details = "; ".join( + f"#{output_idx}, {output_shape}" for output_idx, output_shape in node_shape_infos ) + node_with_issues_for_fusing[node].append( + f"output shapes: {shape_issue_details} are not the same as the subgraph's input: " + f"{variable_input_node_output_shape}" + ) + + all_nodes_have_same_shape_as_input = len(nodes_with_different_output_shapes) == 0 + + if not all_nodes_have_same_shape_as_input: + 
node_with_issues_for_fusing[variable_input_node].append( + f"input node with shape {variable_input_node_output_shape}" + ) + + # All non constant node outputs currently need to have the same shape + return all_nodes_have_same_shape_as_input def subgraph_has_unique_variable_input( float_subgraph_start_nodes: Set[IntermediateNode], + terminal_node: IntermediateNode, + node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]], ) -> bool: """Check that only one of the nodes starting the subgraph is variable. Args: float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the subgraph. + terminal_node (IntermediateNode): The node ending the float subgraph. + node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill + with potential nodes issues preventing fusing. Returns: bool: True if only one of the nodes is not an Constant """ - # Only one input to the subgraph where computations are done in floats is variable, this + + variable_inputs_list = [ + node for node in float_subgraph_start_nodes if not isinstance(node, Constant) + ] + variable_inputs_num = len(variable_inputs_list) + + # Only one input to the subgraph where computations are done in floats can be variable, this # is the only case we can manage with UnivariateFunction fusing - return sum(not isinstance(node, Constant) for node in float_subgraph_start_nodes) == 1 + has_unique_variable_input = variable_inputs_num == 1 + + if not has_unique_variable_input: + for node in variable_inputs_list: + node_with_issues_for_fusing[node].append( + f"one of {variable_inputs_num} variable inputs (can only have 1 for fusing)" + ) + node_with_issues_for_fusing[terminal_node].append( + f"cannot fuse here as the subgraph has {variable_inputs_num} variable inputs" + ) + + return has_unique_variable_input diff --git a/tests/common/optimization/test_float_fusing.py b/tests/common/optimization/test_float_fusing.py index e9f13c3eb..cbd2ee0ab 100644 --- 
a/tests/common/optimization/test_float_fusing.py +++ b/tests/common/optimization/test_float_fusing.py @@ -27,6 +27,11 @@ def no_fuse_unhandled(x, y): return intermediate.astype(numpy.int32) +def no_fuse_dot(x): + """No fuse dot""" + return numpy.dot(x, numpy.full((10,), 1.33, dtype=numpy.float64)).astype(numpy.int32) + + def simple_fuse_not_output(x): """Simple fuse not output""" intermediate = x.astype(numpy.float64) @@ -107,34 +112,96 @@ def mix_x_and_y_into_integer_and_call_f(function, x, y): ) +def get_func_params_scalar_int32(func): + """Returns a dict with parameters as scalar int32""" + + return { + param_name: EncryptedScalar(Integer(32, True)) + for param_name in signature(func).parameters.keys() + } + + @pytest.mark.parametrize( - "function_to_trace,fused", + "function_to_trace,fused,params,warning_message", [ - pytest.param(no_fuse, False, id="no_fuse"), - pytest.param(no_fuse_unhandled, False, id="no_fuse_unhandled"), - pytest.param(simple_fuse_not_output, True, id="no_fuse"), - pytest.param(simple_fuse_output, True, id="no_fuse"), + pytest.param(no_fuse, False, get_func_params_scalar_int32(no_fuse), "", id="no_fuse"), + pytest.param( + no_fuse_unhandled, + False, + get_func_params_scalar_int32(no_fuse_unhandled), + """The following subgraph is not fusable: +%0 = x # EncryptedScalar> +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing) +%1 = Constant(0.7) # ClearScalar> +%2 = y # EncryptedScalar> +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing) +%3 = Constant(1.3) # ClearScalar> +%4 = Add(%0, %1) # EncryptedScalar> +%5 = Add(%2, %3) # EncryptedScalar> +%6 = Add(%4, %5) # EncryptedScalar> +%7 = astype(int32)(%6) # EncryptedScalar> +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot fuse here as the subgraph 
has 2 variable inputs +return(%7)""", # noqa: E501 # pylint: disable=line-too-long + id="no_fuse_unhandled", + ), + pytest.param( + no_fuse_dot, + False, + {"x": EncryptedTensor(Integer(32, True), (10,))}, + """The following subgraph is not fusable: +%0 = x # EncryptedTensor, shape=(10,)> +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10,) +%1 = Constant([1.33 1.33 ... 1.33 1.33]) # ClearTensor, shape=(10,)> +%2 = Dot(%0, %1) # EncryptedScalar> +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,) +%3 = astype(int32)(%2) # EncryptedScalar> +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,) +return(%3)""", # noqa: E501 # pylint: disable=line-too-long + id="no_fuse_dot", + ), + pytest.param( + simple_fuse_not_output, + True, + get_func_params_scalar_int32(simple_fuse_not_output), + None, + id="simple_fuse_not_output", + ), + pytest.param( + simple_fuse_output, + True, + get_func_params_scalar_int32(simple_fuse_output), + None, + id="simple_fuse_output", + ), pytest.param( lambda x, y: mix_x_and_y_intricately_and_call_f(numpy.rint, x, y), True, + get_func_params_scalar_int32(lambda x, y: None), + None, id="mix_x_and_y_intricately_and_call_f_with_rint", ), pytest.param( lambda x, y: mix_x_and_y_and_call_f(numpy.rint, x, y), True, + get_func_params_scalar_int32(lambda x, y: None), + None, id="mix_x_and_y_and_call_f_with_rint", ), ], ) -@pytest.mark.parametrize("input_", [0, 2, 42, 44]) -def test_fuse_float_operations(function_to_trace, fused, input_): +def test_fuse_float_operations( + function_to_trace, + fused, + params, + warning_message, + capfd, + remove_color_codes, +): """Test function for fuse_float_operations""" - params_names = 
signature(function_to_trace).parameters.keys() - op_graph = trace_numpy_function( function_to_trace, - {param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names}, + params, ) orig_num_nodes = len(op_graph.graph) fuse_float_operations(op_graph) @@ -144,12 +211,19 @@ def test_fuse_float_operations(function_to_trace, fused, input_): assert fused_num_nodes < orig_num_nodes else: assert fused_num_nodes == orig_num_nodes + captured = capfd.readouterr() + assert warning_message in remove_color_codes(captured.err) - input_ = numpy.int32(input_) + for input_ in [0, 2, 42, 44]: + inputs = () + for param_input_value in params.values(): + if param_input_value.is_scalar: + input_ = numpy.int32(input_) + else: + input_ = numpy.full(param_input_value.shape, input_, dtype=numpy.int32) + inputs += (input_,) - num_params = len(params_names) - inputs = (input_,) * num_params - assert function_to_trace(*inputs) == op_graph(*inputs) + assert numpy.array_equal(function_to_trace(*inputs), op_graph(*inputs)) def subtest_tensor_no_fuse(fun, tensor_shape): diff --git a/tests/conftest.py b/tests/conftest.py index e947978f6..fda1dc7bc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ """PyTest configuration file""" import json import operator +import re from pathlib import Path from typing import Callable, Dict, Type @@ -264,3 +265,12 @@ def default_compilation_configuration(): dump_artifacts_on_unexpected_failures=False, treat_warnings_as_errors=True, ) + + +REMOVE_COLOR_CODES_RE = re.compile(r"\x1b[^m]*m") + + +@pytest.fixture +def remove_color_codes(): + """Return the re object to remove color codes""" + return lambda x: REMOVE_COLOR_CODES_RE.sub("", x)