mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
dev: add OPGraphs to compilation artifacts during float fusing
This commit is contained in:
@@ -4,21 +4,29 @@ from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..compilation.artifacts import CompilationArtifacts
|
||||
from ..data_types.floats import Float
|
||||
from ..data_types.integers import Integer
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate as ir
|
||||
|
||||
|
||||
def fuse_float_operations(op_graph: OPGraph):
|
||||
def fuse_float_operations(
|
||||
op_graph: OPGraph,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
):
|
||||
"""Finds and fuses float domains into single Integer to Integer ArbitraryFunction.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The OPGraph to simplify
|
||||
compilation_artifacts (Optional[CompilationArtifacts]): The CompilationArtifacts of the
|
||||
current compilation, this argument is optional as it's not required to execute float
|
||||
fusing.
|
||||
"""
|
||||
|
||||
nx_graph = op_graph.graph
|
||||
processed_terminal_nodes: Set[ir.IntermediateNode] = set()
|
||||
number_of_fuse = 0
|
||||
while True:
|
||||
float_subgraph_search_result = find_float_subgraph_with_unique_terminal_node(
|
||||
nx_graph, processed_terminal_nodes
|
||||
@@ -68,6 +76,12 @@ def fuse_float_operations(op_graph: OPGraph):
|
||||
nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0)
|
||||
|
||||
op_graph.prune_nodes()
|
||||
if compilation_artifacts is not None:
|
||||
compilation_artifacts.add_operation_graph(
|
||||
f"after-float-fuse-{number_of_fuse}", op_graph
|
||||
)
|
||||
|
||||
number_of_fuse += 1
|
||||
|
||||
|
||||
def convert_float_subgraph_to_fused_node(
|
||||
|
||||
@@ -90,10 +90,7 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
if compilation_configuration.enable_topological_optimizations:
|
||||
# Fuse float operations to have int to int ArbitraryFunction
|
||||
if not check_op_graph_is_integer_program(op_graph):
|
||||
fuse_float_operations(op_graph)
|
||||
|
||||
# Add the fused floats graph as an artifact
|
||||
compilation_artifacts.add_operation_graph("fused-float-operations", op_graph)
|
||||
fuse_float_operations(op_graph, compilation_artifacts)
|
||||
|
||||
# TODO: To be removed once we support more than integers
|
||||
offending_non_integer_nodes: List[ir.IntermediateNode] = []
|
||||
|
||||
Reference in New Issue
Block a user