dev: add OPGraphs to compilation artifacts during float fusing

This commit is contained in:
Arthur Meyre
2021-09-01 11:45:51 +02:00
parent 1e8debfb57
commit b582e68cd0
2 changed files with 16 additions and 5 deletions

View File

@@ -4,21 +4,29 @@ from typing import Dict, List, Optional, Set, Tuple
import networkx as nx
from ..compilation.artifacts import CompilationArtifacts
from ..data_types.floats import Float
from ..data_types.integers import Integer
from ..operator_graph import OPGraph
from ..representation import intermediate as ir
def fuse_float_operations(op_graph: OPGraph):
def fuse_float_operations(
op_graph: OPGraph,
compilation_artifacts: Optional[CompilationArtifacts] = None,
):
"""Finds and fuses float domains into single Integer to Integer ArbitraryFunction.
Args:
op_graph (OPGraph): The OPGraph to simplify
compilation_artifacts (Optional[CompilationArtifacts]): The CompilationArtifacts of the
current compilation, this argument is optional as it's not required to execute float
fusing.
"""
nx_graph = op_graph.graph
processed_terminal_nodes: Set[ir.IntermediateNode] = set()
number_of_fuse = 0
while True:
float_subgraph_search_result = find_float_subgraph_with_unique_terminal_node(
nx_graph, processed_terminal_nodes
@@ -68,6 +76,12 @@ def fuse_float_operations(op_graph: OPGraph):
nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0)
op_graph.prune_nodes()
if compilation_artifacts is not None:
compilation_artifacts.add_operation_graph(
f"after-float-fuse-{number_of_fuse}", op_graph
)
number_of_fuse += 1
def convert_float_subgraph_to_fused_node(

View File

@@ -90,10 +90,7 @@ def _compile_numpy_function_into_op_graph_internal(
if compilation_configuration.enable_topological_optimizations:
# Fuse float operations to have int to int ArbitraryFunction
if not check_op_graph_is_integer_program(op_graph):
fuse_float_operations(op_graph)
# Add the fused floats graph as an artifact
compilation_artifacts.add_operation_graph("fused-float-operations", op_graph)
fuse_float_operations(op_graph, compilation_artifacts)
# TODO: To be removed once we support more than integers
offending_non_integer_nodes: List[ir.IntermediateNode] = []