refactor(optimization): refactor code to prepare new topologies fusing

refs #499
This commit is contained in:
Arthur Meyre
2021-11-17 11:53:35 +01:00
parent a712b0573c
commit b9c52947f6

View File

@@ -211,6 +211,57 @@ def convert_float_subgraph_to_fused_node(
)
def is_single_int_output_node(node: IntermediateNode) -> bool:
"""Check if a node has a single output and that output is an integer.
Args:
node (IntermediateNode): the node to check.
Returns:
bool: returns True if the node has a single integer output, False otherwise.
"""
return len(node.outputs) == 1 and isinstance(node.outputs[0].dtype, Integer)
def find_closest_single_int_output_nodes(
nx_graph: nx.MultiDiGraph,
start_nodes: List[IntermediateNode],
subgraph_all_nodes: Set[IntermediateNode],
) -> Tuple[Dict[IntermediateNode, None], Set[IntermediateNode]]:
"""Find in nx_graph the closest upstream single integer output nodes to some start nodes.
Args:
nx_graph (nx.MultiDiGraph): the networkx graph to search in.
start_nodes (List[IntermediateNode]): the nodes from which to start the search.
subgraph_all_nodes (Set[IntermediateNode]): a set that will be updated with all the nodes
visited during the search.
Returns:
Tuple[Dict[IntermediateNode, None], Set[IntermediateNode]]: returns the dict used as an
ordered set containing the found single output nodes and the updated set of the visited
nodes during the search.
"""
# Use dict as ordered set
current_nodes = {start_node: None for start_node in start_nodes}
closest_single_int_output_nodes: Dict[IntermediateNode, None] = {}
while current_nodes:
next_nodes: Dict[IntermediateNode, None] = {}
for node in current_nodes:
subgraph_all_nodes.add(node)
predecessors = nx_graph.pred[node]
for pred in predecessors:
if is_single_int_output_node(pred):
# Limit of subgraph, record that and record the node as we won't visit it
closest_single_int_output_nodes[pred] = None
subgraph_all_nodes.add(pred)
else:
next_nodes.update({pred: None})
current_nodes = next_nodes
return closest_single_int_output_nodes, subgraph_all_nodes
def find_float_subgraph_with_unique_terminal_node(
nx_graph: nx.MultiDiGraph,
processed_terminal_nodes: Set[IntermediateNode],
@@ -239,9 +290,6 @@ def find_float_subgraph_with_unique_terminal_node(
and isinstance(node.outputs[0].dtype, Integer)
)
def single_int_output_node(node: IntermediateNode) -> bool:
return len(node.outputs) == 1 and isinstance(node.outputs[0].dtype, Integer)
float_subgraphs_terminal_nodes = (
node
for node in nx_graph.nodes()
@@ -255,25 +303,15 @@ def find_float_subgraph_with_unique_terminal_node(
except StopIteration:
return None
# Use dict as ordered set
current_nodes = {terminal_node: None}
float_subgraph_start_nodes: Set[IntermediateNode] = set()
subgraph_all_nodes: Set[IntermediateNode] = set()
while current_nodes:
next_nodes: Dict[IntermediateNode, None] = {}
for node in current_nodes:
subgraph_all_nodes.add(node)
predecessors = nx_graph.pred[node]
for pred in predecessors:
if single_int_output_node(pred):
# Limit of subgraph, record that and record the node as we won't visit it
float_subgraph_start_nodes.add(pred)
subgraph_all_nodes.add(pred)
else:
next_nodes.update({pred: None})
current_nodes = next_nodes
return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes
float_subgraph_start_nodes, subgraph_all_nodes = find_closest_single_int_output_nodes(
nx_graph,
[terminal_node],
subgraph_all_nodes,
)
return set(float_subgraph_start_nodes.keys()), terminal_node, subgraph_all_nodes
def subgraph_nodes_and_values_allow_fusing(