diff --git a/hdk/common/optimization/topological.py b/hdk/common/optimization/topological.py index e374ca57e..cc66bb731 100644 --- a/hdk/common/optimization/topological.py +++ b/hdk/common/optimization/topological.py @@ -4,21 +4,29 @@ from typing import Dict, List, Optional, Set, Tuple import networkx as nx +from ..compilation.artifacts import CompilationArtifacts from ..data_types.floats import Float from ..data_types.integers import Integer from ..operator_graph import OPGraph from ..representation import intermediate as ir -def fuse_float_operations(op_graph: OPGraph): +def fuse_float_operations( + op_graph: OPGraph, + compilation_artifacts: Optional[CompilationArtifacts] = None, +): """Finds and fuses float domains into single Integer to Integer ArbitraryFunction. Args: op_graph (OPGraph): The OPGraph to simplify + compilation_artifacts (Optional[CompilationArtifacts]): The CompilationArtifacts of the + current compilation, this argument is optional as it's not required to execute float + fusing. """ nx_graph = op_graph.graph processed_terminal_nodes: Set[ir.IntermediateNode] = set() + number_of_fuse = 0 while True: float_subgraph_search_result = find_float_subgraph_with_unique_terminal_node( nx_graph, processed_terminal_nodes @@ -68,6 +76,12 @@ def fuse_float_operations(op_graph: OPGraph): nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0) op_graph.prune_nodes() + if compilation_artifacts is not None: + compilation_artifacts.add_operation_graph( + f"after-float-fuse-{number_of_fuse}", op_graph + ) + + number_of_fuse += 1 def convert_float_subgraph_to_fused_node( diff --git a/hdk/hnumpy/compile.py b/hdk/hnumpy/compile.py index 51208b96f..334490cb6 100644 --- a/hdk/hnumpy/compile.py +++ b/hdk/hnumpy/compile.py @@ -90,10 +90,7 @@ def _compile_numpy_function_into_op_graph_internal( if compilation_configuration.enable_topological_optimizations: # Fuse float operations to have int to int ArbitraryFunction if not check_op_graph_is_integer_program(op_graph): - fuse_float_operations(op_graph) - - # Add the fused floats graph as an artifact - compilation_artifacts.add_operation_graph("fused-float-operations", op_graph) + fuse_float_operations(op_graph, compilation_artifacts) # TODO: To be removed once we support more than integers offending_non_integer_nodes: List[ir.IntermediateNode] = []