From 4e40982f5a74d942e52a85ad79ff4bbdf3647fea Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 16 Aug 2021 14:04:49 +0200 Subject: [PATCH] feat(float-fusing): fuse float parts of an OPGraph during compilation - this allows to be compatible with the current compiler and squash float domains into a single int to int ArbitraryFunction --- hdk/common/optimization/__init__.py | 1 + hdk/common/optimization/topological.py | 240 ++++++++++++++++++ hdk/hnumpy/compile.py | 19 +- .../common/optimization/test_float_fusing.py | 107 ++++++++ tests/hnumpy/test_compile.py | 15 ++ 5 files changed, 381 insertions(+), 1 deletion(-) create mode 100644 hdk/common/optimization/__init__.py create mode 100644 hdk/common/optimization/topological.py create mode 100644 tests/common/optimization/test_float_fusing.py diff --git a/hdk/common/optimization/__init__.py b/hdk/common/optimization/__init__.py new file mode 100644 index 000000000..f4180d6bb --- /dev/null +++ b/hdk/common/optimization/__init__.py @@ -0,0 +1 @@ +"""Module holding various optimization/simplification code.""" diff --git a/hdk/common/optimization/topological.py b/hdk/common/optimization/topological.py new file mode 100644 index 000000000..3cc249bc9 --- /dev/null +++ b/hdk/common/optimization/topological.py @@ -0,0 +1,240 @@ +"""File holding topological optimization/simplification code.""" +from copy import deepcopy +from typing import Dict, List, Optional, Set, Tuple + +import networkx as nx + +from ..data_types.floats import Float +from ..data_types.integers import Integer +from ..operator_graph import OPGraph +from ..representation import intermediate as ir + + +def fuse_float_operations(op_graph: OPGraph): + """Finds and fuses float domains into single Integer to Integer ArbitraryFunction. + + Args: + op_graph (OPGraph): The OPGraph to simplify + """ + + nx_graph = op_graph.graph + processed_terminal_nodes: Set[ir.IntermediateNode] = set() + while True: + float_subgraph_search_result = find_float_subgraph_with_unique_terminal_node( + nx_graph, processed_terminal_nodes + ) + if float_subgraph_search_result is None: + break + + float_subgraph_start_nodes, terminal_node, subgraph_all_nodes = float_subgraph_search_result + processed_terminal_nodes.add(terminal_node) + + subgraph_conversion_result = convert_float_subgraph_to_fused_node( + op_graph, + float_subgraph_start_nodes, + terminal_node, + subgraph_all_nodes, + ) + + # Not a subgraph we can handle, continue + if subgraph_conversion_result is None: + continue + + fused_node, node_before_subgraph = subgraph_conversion_result + + nx_graph.add_node(fused_node, content=fused_node) + + if terminal_node in op_graph.output_nodes.values(): + # Output value replace it + # As the graph changes recreate the output_node_to_idx dict + output_node_to_idx: Dict[ir.IntermediateNode, List[int]] = { + out_node: [] for out_node in op_graph.output_nodes.values() + } + for output_idx, output_node in op_graph.output_nodes.items(): + output_node_to_idx[output_node].append(output_idx) + + for output_idx in output_node_to_idx.get(terminal_node, []): + op_graph.output_nodes[output_idx] = fused_node + + # Disconnect after terminal node and connect fused node instead + terminal_node_succ = list(nx_graph.successors(terminal_node)) + for succ in terminal_node_succ: + succ_edge_data = deepcopy(nx_graph.get_edge_data(terminal_node, succ)) + for edge_key, edge_data in succ_edge_data.items(): + nx_graph.remove_edge(terminal_node, succ, key=edge_key) + nx_graph.add_edge(fused_node, succ, key=edge_key, **edge_data) + + # Connect the node feeding the subgraph contained in fused_node + nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0) + + op_graph.prune_nodes() + + +def convert_float_subgraph_to_fused_node( + op_graph: OPGraph, + float_subgraph_start_nodes: Set[ir.IntermediateNode], + terminal_node: ir.IntermediateNode, + subgraph_all_nodes: Set[ir.IntermediateNode], +) -> Optional[Tuple[ir.ArbitraryFunction, ir.IntermediateNode]]: + """Converts a float subgraph to an equivalent fused ArbitraryFunction node. + + Args: + op_graph (OPGraph): The OPGraph the float subgraph is part of. + float_subgraph_start_nodes (Set[ir.IntermediateNode]): The nodes starting the float subgraph + in `op_graph`. + terminal_node (ir.IntermediateNode): The node ending the float subgraph. + subgraph_all_nodes (Set[ir.IntermediateNode]): All the nodes in the float subgraph. + + Returns: + Optional[Tuple[ir.ArbitraryFunction, ir.IntermediateNode]]: None if the float subgraph + cannot be fused, otherwise returns a tuple containing the fused node and the node whose + output must be plugged as the input to the subgraph. + """ + + if not subgraph_has_unique_variable_input(float_subgraph_start_nodes): + return None + + # Only one variable input node, find which node feeds its input + non_constant_input_nodes = [ + node for node in float_subgraph_start_nodes if not isinstance(node, ir.ConstantInput) + ] + assert len(non_constant_input_nodes) == 1 + + current_subgraph_variable_input = non_constant_input_nodes[0] + new_input_value = deepcopy(current_subgraph_variable_input.outputs[0]) + + nx_graph = op_graph.graph + + nodes_after_input_set = subgraph_all_nodes.intersection( + nx_graph.succ[current_subgraph_variable_input] + ) + + float_subgraph = nx.MultiDiGraph(nx_graph.subgraph(subgraph_all_nodes)) + + new_subgraph_variable_input = ir.Input(new_input_value, "float_subgraph_input", 0) + float_subgraph.add_node(new_subgraph_variable_input) + + for node_after_input in nodes_after_input_set: + # Connect the new input to our subgraph + edge_data_input_to_subgraph = deepcopy( + float_subgraph.get_edge_data( + current_subgraph_variable_input, + node_after_input, + ) + ) + for edge_key, edge_data in edge_data_input_to_subgraph.items(): + float_subgraph.remove_edge( + current_subgraph_variable_input, node_after_input, key=edge_key + ) + float_subgraph.add_edge( + new_subgraph_variable_input, + node_after_input, + key=edge_key, + **edge_data, + ) + + float_op_subgraph = OPGraph.from_graph( + float_subgraph, + [new_subgraph_variable_input], + [terminal_node], + ) + + # Create fused_node + fused_node = ir.ArbitraryFunction( + deepcopy(new_subgraph_variable_input.inputs[0]), + lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate({0: x})[ + terminal_node + ], + deepcopy(terminal_node.outputs[0].data_type), + op_kwargs={ + "float_op_subgraph": float_op_subgraph, + "terminal_node": terminal_node, + }, + op_name="Subgraph", + ) + + return ( + fused_node, + current_subgraph_variable_input, + ) + + +def find_float_subgraph_with_unique_terminal_node( + nx_graph: nx.MultiDiGraph, + processed_terminal_nodes: Set[ir.IntermediateNode], +) -> Optional[Tuple[Set[ir.IntermediateNode], ir.IntermediateNode, Set[ir.IntermediateNode]]]: + """Find a subgraph of the graph with float computations. + + The subgraph has a single terminal node with a single Integer output and has a single variable + predecessor node with a single Integer output. + + Args: + nx_graph (nx.MultiDiGraph): The networkx graph to search in. + processed_terminal_nodes (Set[ir.IntermediateNode]): The set of terminal nodes for which + subgraphs have already been searched, those will be skipped. + + Returns: + Optional[Tuple[Set[ir.IntermediateNode], ir.IntermediateNode, Set[ir.IntermediateNode]]]: + None if there are no float subgraphs to process in `nx_graph`. Otherwise returns a tuple + containing the set of nodes beginning a float subgraph, the terminal node of the + subgraph and the set of all the nodes in the subgraph. + """ + + def is_float_to_single_int_node(node: ir.IntermediateNode) -> bool: + return ( + any(isinstance(input_.data_type, Float) for input_ in node.inputs) + and len(node.outputs) == 1 + and isinstance(node.outputs[0].data_type, Integer) + ) + + def single_int_output_node(node: ir.IntermediateNode) -> bool: + return len(node.outputs) == 1 and isinstance(node.outputs[0].data_type, Integer) + + float_subgraphs_terminal_nodes = ( + node + for node in nx_graph.nodes() + if is_float_to_single_int_node(node) and node not in processed_terminal_nodes + ) + + terminal_node: ir.IntermediateNode + + try: + terminal_node = next(float_subgraphs_terminal_nodes) + except StopIteration: + return None + + # Use dict as ordered set + current_nodes = {terminal_node: None} + float_subgraph_start_nodes: Set[ir.IntermediateNode] = set() + subgraph_all_nodes: Set[ir.IntermediateNode] = set() + while current_nodes: + next_nodes: Dict[ir.IntermediateNode, None] = dict() + for node in current_nodes: + subgraph_all_nodes.add(node) + predecessors = nx_graph.pred[node] + for pred in predecessors: + if single_int_output_node(pred): + # Limit of subgraph, record that and record the node as we won't visit it + float_subgraph_start_nodes.add(pred) + subgraph_all_nodes.add(pred) + else: + next_nodes.update({pred: None}) + current_nodes = next_nodes + + return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes + + +def subgraph_has_unique_variable_input( + float_subgraph_start_nodes: Set[ir.IntermediateNode], +) -> bool: + """Check that only one of the nodes starting the subgraph is variable. + + Args: + float_subgraph_start_nodes (Set[ir.IntermediateNode]): The nodes starting the subgraph. + + Returns: + bool: True if only one of the nodes is not an ir.ConstantInput + """ + # Only one input to the subgraph where computations are done in floats is variable, this + # is the only case we can manage with ArbitraryFunction fusing + return sum(not isinstance(node, ir.ConstantInput) for node in float_subgraph_start_nodes) == 1 diff --git a/hdk/hnumpy/compile.py b/hdk/hnumpy/compile.py index 89eb79713..153cc076c 100644 --- a/hdk/hnumpy/compile.py +++ b/hdk/hnumpy/compile.py @@ -1,8 +1,9 @@ """hnumpy compilation function.""" -from typing import Any, Callable, Dict, Iterator, Optional, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from ..common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset +from ..common.common_helpers import check_op_graph_is_integer_program from ..common.compilation import CompilationArtifacts from ..common.data_types import BaseValue from ..common.mlir.utils import ( @@ -10,6 +11,8 @@ from ..common.mlir.utils import ( update_bit_width_for_mlir, ) from ..common.operator_graph import OPGraph +from ..common.optimization.topological import fuse_float_operations +from ..common.representation import intermediate as ir from ..hnumpy.tracing import trace_numpy_function @@ -38,6 +41,20 @@ def compile_numpy_function( # Trace op_graph = trace_numpy_function(function_to_trace, function_parameters) + # Fuse float operations to have int to int ArbitraryFunction + if not check_op_graph_is_integer_program(op_graph): + fuse_float_operations(op_graph) + + # TODO: To be removed once we support more than integers + offending_non_integer_nodes: List[ir.IntermediateNode] = [] + op_grap_is_int_prog = check_op_graph_is_integer_program(op_graph, offending_non_integer_nodes) + if not op_grap_is_int_prog: + raise ValueError( + f"{function_to_trace.__name__} cannot be compiled as it has nodes with either float " + f"inputs or outputs.\nOffending nodes : " + f"{', '.join(str(node) for node in offending_non_integer_nodes)}" + ) + # Find bounds with the dataset node_bounds = eval_op_graph_bounds_on_dataset(op_graph, dataset) diff --git a/tests/common/optimization/test_float_fusing.py b/tests/common/optimization/test_float_fusing.py new file mode 100644 index 000000000..046cfdbdb --- /dev/null +++ b/tests/common/optimization/test_float_fusing.py @@ -0,0 +1,107 @@ +"""Test file for float subgraph fusing""" + +from inspect import signature + +import numpy +import pytest + +from hdk.common.data_types.integers import Integer +from hdk.common.data_types.values import EncryptedValue +from hdk.common.optimization.topological import fuse_float_operations +from hdk.hnumpy.tracing import trace_numpy_function + + +def no_fuse(x): + """No fuse""" + return x + 2 + + +def no_fuse_unhandled(x, y): + """No fuse unhandled""" + x_1 = x + 0.7 + y_1 = y + 1.3 + intermediate = x_1 + y_1 + return intermediate.astype(numpy.int32) + + +def simple_fuse_not_output(x): + """Simple fuse not output""" + intermediate = x.astype(numpy.float64) + intermediate = intermediate.astype(numpy.int32) + return intermediate + 2 + + +def simple_fuse_output(x): + """Simple fuse output""" + return x.astype(numpy.float64).astype(numpy.int32) + + +def complex_fuse_indirect_input(x, y): + """Complex fuse""" + intermediate = x + y + intermediate = intermediate + 2 + intermediate = intermediate.astype(numpy.float32) + intermediate = intermediate.astype(numpy.int32) + x_p_1 = intermediate + 1.5 + x_p_2 = intermediate + 2.7 + x_p_3 = numpy.rint(x_p_1 + x_p_2) + return ( + x_p_3.astype(numpy.int32), + x_p_2.astype(numpy.int32), + (x_p_2 + 3).astype(numpy.int32), + x_p_3.astype(numpy.int32) + 67, + y, + (y + 4.7).astype(numpy.int32) + 3, + ) + + +def complex_fuse_direct_input(x, y): + """Complex fuse""" + x_p_1 = x + 1.5 + x_p_2 = x + 2.7 + x_p_3 = numpy.rint(x_p_1 + x_p_2) + return ( + x_p_3.astype(numpy.int32), + x_p_2.astype(numpy.int32), + (x_p_2 + 3).astype(numpy.int32), + x_p_3.astype(numpy.int32) + 67, + y, + (y + 4.7).astype(numpy.int32) + 3, + ) + + +@pytest.mark.parametrize( + "function_to_trace,fused", + [ + pytest.param(no_fuse, False, id="no_fuse"), + pytest.param(no_fuse_unhandled, False, id="no_fuse_unhandled"), + pytest.param(simple_fuse_not_output, True, id="no_fuse"), + pytest.param(simple_fuse_output, True, id="no_fuse"), + pytest.param(complex_fuse_indirect_input, True, id="complex_fuse_indirect_input"), + pytest.param(complex_fuse_direct_input, True, id="complex_fuse_direct_input"), + ], +) +@pytest.mark.parametrize("input_", [0, 2, 42, 44]) +def test_fuse_float_operations(function_to_trace, fused, input_): + """Test function for fuse_float_operations""" + + params_names = signature(function_to_trace).parameters.keys() + + op_graph = trace_numpy_function( + function_to_trace, + {param_name: EncryptedValue(Integer(32, True)) for param_name in params_names}, + ) + orig_num_nodes = len(op_graph.graph) + fuse_float_operations(op_graph) + fused_num_nodes = len(op_graph.graph) + + if fused: + assert fused_num_nodes < orig_num_nodes + else: + assert fused_num_nodes == orig_num_nodes + + input_ = numpy.int32(input_) + + num_params = len(params_names) + inputs = (input_,) * num_params + assert function_to_trace(*inputs) == op_graph(*inputs) diff --git a/tests/hnumpy/test_compile.py b/tests/hnumpy/test_compile.py index 098967130..e5a077980 100644 --- a/tests/hnumpy/test_compile.py +++ b/tests/hnumpy/test_compile.py @@ -1,6 +1,7 @@ """Test file for hnumpy compilation functions""" import itertools +import numpy import pytest from hdk.common.data_types.integers import Integer @@ -10,6 +11,14 @@ from hdk.common.extensions.table import LookupTable from hdk.hnumpy.compile import compile_numpy_function +def no_fuse_unhandled(x, y): + """No fuse unhandled""" + x_intermediate = x + 2.8 + y_intermediate = y + 9.3 + intermediate = x_intermediate + y_intermediate + return intermediate.astype(numpy.int32) + + @pytest.mark.parametrize( "function,input_ranges,list_of_arg_names", [ @@ -21,6 +30,12 @@ from hdk.hnumpy.compile import compile_numpy_function ((4, 8), (3, 4), (0, 4)), ["x", "y", "z"], ), + pytest.param( + no_fuse_unhandled, + ((-2, 2), (-2, 2)), + ["x", "y"], + marks=pytest.mark.xfail(raises=ValueError), + ), ], ) def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_names):