mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(float-fusing): fuse float parts of an OPGraph during compilation
- this allows to be compatible with the current compiler and squash float domains into a single int to int ArbitraryFunction
This commit is contained in:
1
hdk/common/optimization/__init__.py
Normal file
1
hdk/common/optimization/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Module holding various optimization/simplification code."""
|
||||
240
hdk/common/optimization/topological.py
Normal file
240
hdk/common/optimization/topological.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""File holding topological optimization/simplification code."""
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..data_types.floats import Float
|
||||
from ..data_types.integers import Integer
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate as ir
|
||||
|
||||
|
||||
def fuse_float_operations(op_graph: OPGraph):
|
||||
"""Finds and fuses float domains into single Integer to Integer ArbitraryFunction.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The OPGraph to simplify
|
||||
"""
|
||||
|
||||
nx_graph = op_graph.graph
|
||||
processed_terminal_nodes: Set[ir.IntermediateNode] = set()
|
||||
while True:
|
||||
float_subgraph_search_result = find_float_subgraph_with_unique_terminal_node(
|
||||
nx_graph, processed_terminal_nodes
|
||||
)
|
||||
if float_subgraph_search_result is None:
|
||||
break
|
||||
|
||||
float_subgraph_start_nodes, terminal_node, subgraph_all_nodes = float_subgraph_search_result
|
||||
processed_terminal_nodes.add(terminal_node)
|
||||
|
||||
subgraph_conversion_result = convert_float_subgraph_to_fused_node(
|
||||
op_graph,
|
||||
float_subgraph_start_nodes,
|
||||
terminal_node,
|
||||
subgraph_all_nodes,
|
||||
)
|
||||
|
||||
# Not a subgraph we can handle, continue
|
||||
if subgraph_conversion_result is None:
|
||||
continue
|
||||
|
||||
fused_node, node_before_subgraph = subgraph_conversion_result
|
||||
|
||||
nx_graph.add_node(fused_node, content=fused_node)
|
||||
|
||||
if terminal_node in op_graph.output_nodes.values():
|
||||
# Output value replace it
|
||||
# As the graph changes recreate the output_node_to_idx dict
|
||||
output_node_to_idx: Dict[ir.IntermediateNode, List[int]] = {
|
||||
out_node: [] for out_node in op_graph.output_nodes.values()
|
||||
}
|
||||
for output_idx, output_node in op_graph.output_nodes.items():
|
||||
output_node_to_idx[output_node].append(output_idx)
|
||||
|
||||
for output_idx in output_node_to_idx.get(terminal_node, []):
|
||||
op_graph.output_nodes[output_idx] = fused_node
|
||||
|
||||
# Disconnect after terminal node and connect fused node instead
|
||||
terminal_node_succ = list(nx_graph.successors(terminal_node))
|
||||
for succ in terminal_node_succ:
|
||||
succ_edge_data = deepcopy(nx_graph.get_edge_data(terminal_node, succ))
|
||||
for edge_key, edge_data in succ_edge_data.items():
|
||||
nx_graph.remove_edge(terminal_node, succ, key=edge_key)
|
||||
nx_graph.add_edge(fused_node, succ, key=edge_key, **edge_data)
|
||||
|
||||
# Connect the node feeding the subgraph contained in fused_node
|
||||
nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0)
|
||||
|
||||
op_graph.prune_nodes()
|
||||
|
||||
|
||||
def convert_float_subgraph_to_fused_node(
|
||||
op_graph: OPGraph,
|
||||
float_subgraph_start_nodes: Set[ir.IntermediateNode],
|
||||
terminal_node: ir.IntermediateNode,
|
||||
subgraph_all_nodes: Set[ir.IntermediateNode],
|
||||
) -> Optional[Tuple[ir.ArbitraryFunction, ir.IntermediateNode]]:
|
||||
"""Converts a float subgraph to an equivalent fused ArbitraryFunction node.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The OPGraph the float subgraph is part of.
|
||||
float_subgraph_start_nodes (Set[ir.IntermediateNode]): The nodes starting the float subgraph
|
||||
in `op_graph`.
|
||||
terminal_node (ir.IntermediateNode): The node ending the float subgraph.
|
||||
subgraph_all_nodes (Set[ir.IntermediateNode]): All the nodes in the float subgraph.
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[ir.ArbitraryFunction, ir.IntermediateNode]]: None if the float subgraph
|
||||
cannot be fused, otherwise returns a tuple containing the fused node and the node whose
|
||||
output must be plugged as the input to the subgraph.
|
||||
"""
|
||||
|
||||
if not subgraph_has_unique_variable_input(float_subgraph_start_nodes):
|
||||
return None
|
||||
|
||||
# Only one variable input node, find which node feeds its input
|
||||
non_constant_input_nodes = [
|
||||
node for node in float_subgraph_start_nodes if not isinstance(node, ir.ConstantInput)
|
||||
]
|
||||
assert len(non_constant_input_nodes) == 1
|
||||
|
||||
current_subgraph_variable_input = non_constant_input_nodes[0]
|
||||
new_input_value = deepcopy(current_subgraph_variable_input.outputs[0])
|
||||
|
||||
nx_graph = op_graph.graph
|
||||
|
||||
nodes_after_input_set = subgraph_all_nodes.intersection(
|
||||
nx_graph.succ[current_subgraph_variable_input]
|
||||
)
|
||||
|
||||
float_subgraph = nx.MultiDiGraph(nx_graph.subgraph(subgraph_all_nodes))
|
||||
|
||||
new_subgraph_variable_input = ir.Input(new_input_value, "float_subgraph_input", 0)
|
||||
float_subgraph.add_node(new_subgraph_variable_input)
|
||||
|
||||
for node_after_input in nodes_after_input_set:
|
||||
# Connect the new input to our subgraph
|
||||
edge_data_input_to_subgraph = deepcopy(
|
||||
float_subgraph.get_edge_data(
|
||||
current_subgraph_variable_input,
|
||||
node_after_input,
|
||||
)
|
||||
)
|
||||
for edge_key, edge_data in edge_data_input_to_subgraph.items():
|
||||
float_subgraph.remove_edge(
|
||||
current_subgraph_variable_input, node_after_input, key=edge_key
|
||||
)
|
||||
float_subgraph.add_edge(
|
||||
new_subgraph_variable_input,
|
||||
node_after_input,
|
||||
key=edge_key,
|
||||
**edge_data,
|
||||
)
|
||||
|
||||
float_op_subgraph = OPGraph.from_graph(
|
||||
float_subgraph,
|
||||
[new_subgraph_variable_input],
|
||||
[terminal_node],
|
||||
)
|
||||
|
||||
# Create fused_node
|
||||
fused_node = ir.ArbitraryFunction(
|
||||
deepcopy(new_subgraph_variable_input.inputs[0]),
|
||||
lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate({0: x})[
|
||||
terminal_node
|
||||
],
|
||||
deepcopy(terminal_node.outputs[0].data_type),
|
||||
op_kwargs={
|
||||
"float_op_subgraph": float_op_subgraph,
|
||||
"terminal_node": terminal_node,
|
||||
},
|
||||
op_name="Subgraph",
|
||||
)
|
||||
|
||||
return (
|
||||
fused_node,
|
||||
current_subgraph_variable_input,
|
||||
)
|
||||
|
||||
|
||||
def find_float_subgraph_with_unique_terminal_node(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
processed_terminal_nodes: Set[ir.IntermediateNode],
|
||||
) -> Optional[Tuple[Set[ir.IntermediateNode], ir.IntermediateNode, Set[ir.IntermediateNode]]]:
|
||||
"""Find a subgraph of the graph with float computations.
|
||||
|
||||
The subgraph has a single terminal node with a single Integer output and has a single variable
|
||||
predecessor node with a single Integer output.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph): The networkx graph to search in.
|
||||
processed_terminal_nodes (Set[ir.IntermediateNode]): The set of terminal nodes for which
|
||||
subgraphs have already been searched, those will be skipped.
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[Set[ir.IntermediateNode], ir.IntermediateNode, Set[ir.IntermediateNode]]]:
|
||||
None if there are no float subgraphs to process in `nx_graph`. Otherwise returns a tuple
|
||||
containing the set of nodes beginning a float subgraph, the terminal node of the
|
||||
subgraph and the set of all the nodes in the subgraph.
|
||||
"""
|
||||
|
||||
def is_float_to_single_int_node(node: ir.IntermediateNode) -> bool:
|
||||
return (
|
||||
any(isinstance(input_.data_type, Float) for input_ in node.inputs)
|
||||
and len(node.outputs) == 1
|
||||
and isinstance(node.outputs[0].data_type, Integer)
|
||||
)
|
||||
|
||||
def single_int_output_node(node: ir.IntermediateNode) -> bool:
|
||||
return len(node.outputs) == 1 and isinstance(node.outputs[0].data_type, Integer)
|
||||
|
||||
float_subgraphs_terminal_nodes = (
|
||||
node
|
||||
for node in nx_graph.nodes()
|
||||
if is_float_to_single_int_node(node) and node not in processed_terminal_nodes
|
||||
)
|
||||
|
||||
terminal_node: ir.IntermediateNode
|
||||
|
||||
try:
|
||||
terminal_node = next(float_subgraphs_terminal_nodes)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
# Use dict as ordered set
|
||||
current_nodes = {terminal_node: None}
|
||||
float_subgraph_start_nodes: Set[ir.IntermediateNode] = set()
|
||||
subgraph_all_nodes: Set[ir.IntermediateNode] = set()
|
||||
while current_nodes:
|
||||
next_nodes: Dict[ir.IntermediateNode, None] = dict()
|
||||
for node in current_nodes:
|
||||
subgraph_all_nodes.add(node)
|
||||
predecessors = nx_graph.pred[node]
|
||||
for pred in predecessors:
|
||||
if single_int_output_node(pred):
|
||||
# Limit of subgraph, record that and record the node as we won't visit it
|
||||
float_subgraph_start_nodes.add(pred)
|
||||
subgraph_all_nodes.add(pred)
|
||||
else:
|
||||
next_nodes.update({pred: None})
|
||||
current_nodes = next_nodes
|
||||
|
||||
return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes
|
||||
|
||||
|
||||
def subgraph_has_unique_variable_input(
|
||||
float_subgraph_start_nodes: Set[ir.IntermediateNode],
|
||||
) -> bool:
|
||||
"""Check that only one of the nodes starting the subgraph is variable.
|
||||
|
||||
Args:
|
||||
float_subgraph_start_nodes (Set[ir.IntermediateNode]): The nodes starting the subgraph.
|
||||
|
||||
Returns:
|
||||
bool: True if only one of the nodes is not an ir.ConstantInput
|
||||
"""
|
||||
# Only one input to the subgraph where computations are done in floats is variable, this
|
||||
# is the only case we can manage with ArbitraryFunction fusing
|
||||
return sum(not isinstance(node, ir.ConstantInput) for node in float_subgraph_start_nodes) == 1
|
||||
@@ -1,8 +1,9 @@
|
||||
"""hnumpy compilation function."""
|
||||
|
||||
from typing import Any, Callable, Dict, Iterator, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
from ..common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset
|
||||
from ..common.common_helpers import check_op_graph_is_integer_program
|
||||
from ..common.compilation import CompilationArtifacts
|
||||
from ..common.data_types import BaseValue
|
||||
from ..common.mlir.utils import (
|
||||
@@ -10,6 +11,8 @@ from ..common.mlir.utils import (
|
||||
update_bit_width_for_mlir,
|
||||
)
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.optimization.topological import fuse_float_operations
|
||||
from ..common.representation import intermediate as ir
|
||||
from ..hnumpy.tracing import trace_numpy_function
|
||||
|
||||
|
||||
@@ -38,6 +41,20 @@ def compile_numpy_function(
|
||||
# Trace
|
||||
op_graph = trace_numpy_function(function_to_trace, function_parameters)
|
||||
|
||||
# Fuse float operations to have int to int ArbitraryFunction
|
||||
if not check_op_graph_is_integer_program(op_graph):
|
||||
fuse_float_operations(op_graph)
|
||||
|
||||
# TODO: To be removed once we support more than integers
|
||||
offending_non_integer_nodes: List[ir.IntermediateNode] = []
|
||||
op_grap_is_int_prog = check_op_graph_is_integer_program(op_graph, offending_non_integer_nodes)
|
||||
if not op_grap_is_int_prog:
|
||||
raise ValueError(
|
||||
f"{function_to_trace.__name__} cannot be compiled as it has nodes with either float "
|
||||
f"inputs or outputs.\nOffending nodes : "
|
||||
f"{', '.join(str(node) for node in offending_non_integer_nodes)}"
|
||||
)
|
||||
|
||||
# Find bounds with the dataset
|
||||
node_bounds = eval_op_graph_bounds_on_dataset(op_graph, dataset)
|
||||
|
||||
|
||||
107
tests/common/optimization/test_float_fusing.py
Normal file
107
tests/common/optimization/test_float_fusing.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Test file for float subgraph fusing"""
|
||||
|
||||
from inspect import signature
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import EncryptedValue
|
||||
from hdk.common.optimization.topological import fuse_float_operations
|
||||
from hdk.hnumpy.tracing import trace_numpy_function
|
||||
|
||||
|
||||
def no_fuse(x):
|
||||
"""No fuse"""
|
||||
return x + 2
|
||||
|
||||
|
||||
def no_fuse_unhandled(x, y):
|
||||
"""No fuse unhandled"""
|
||||
x_1 = x + 0.7
|
||||
y_1 = y + 1.3
|
||||
intermediate = x_1 + y_1
|
||||
return intermediate.astype(numpy.int32)
|
||||
|
||||
|
||||
def simple_fuse_not_output(x):
|
||||
"""Simple fuse not output"""
|
||||
intermediate = x.astype(numpy.float64)
|
||||
intermediate = intermediate.astype(numpy.int32)
|
||||
return intermediate + 2
|
||||
|
||||
|
||||
def simple_fuse_output(x):
|
||||
"""Simple fuse output"""
|
||||
return x.astype(numpy.float64).astype(numpy.int32)
|
||||
|
||||
|
||||
def complex_fuse_indirect_input(x, y):
|
||||
"""Complex fuse"""
|
||||
intermediate = x + y
|
||||
intermediate = intermediate + 2
|
||||
intermediate = intermediate.astype(numpy.float32)
|
||||
intermediate = intermediate.astype(numpy.int32)
|
||||
x_p_1 = intermediate + 1.5
|
||||
x_p_2 = intermediate + 2.7
|
||||
x_p_3 = numpy.rint(x_p_1 + x_p_2)
|
||||
return (
|
||||
x_p_3.astype(numpy.int32),
|
||||
x_p_2.astype(numpy.int32),
|
||||
(x_p_2 + 3).astype(numpy.int32),
|
||||
x_p_3.astype(numpy.int32) + 67,
|
||||
y,
|
||||
(y + 4.7).astype(numpy.int32) + 3,
|
||||
)
|
||||
|
||||
|
||||
def complex_fuse_direct_input(x, y):
|
||||
"""Complex fuse"""
|
||||
x_p_1 = x + 1.5
|
||||
x_p_2 = x + 2.7
|
||||
x_p_3 = numpy.rint(x_p_1 + x_p_2)
|
||||
return (
|
||||
x_p_3.astype(numpy.int32),
|
||||
x_p_2.astype(numpy.int32),
|
||||
(x_p_2 + 3).astype(numpy.int32),
|
||||
x_p_3.astype(numpy.int32) + 67,
|
||||
y,
|
||||
(y + 4.7).astype(numpy.int32) + 3,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,fused",
|
||||
[
|
||||
pytest.param(no_fuse, False, id="no_fuse"),
|
||||
pytest.param(no_fuse_unhandled, False, id="no_fuse_unhandled"),
|
||||
pytest.param(simple_fuse_not_output, True, id="no_fuse"),
|
||||
pytest.param(simple_fuse_output, True, id="no_fuse"),
|
||||
pytest.param(complex_fuse_indirect_input, True, id="complex_fuse_indirect_input"),
|
||||
pytest.param(complex_fuse_direct_input, True, id="complex_fuse_direct_input"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("input_", [0, 2, 42, 44])
|
||||
def test_fuse_float_operations(function_to_trace, fused, input_):
|
||||
"""Test function for fuse_float_operations"""
|
||||
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
function_to_trace,
|
||||
{param_name: EncryptedValue(Integer(32, True)) for param_name in params_names},
|
||||
)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
fused_num_nodes = len(op_graph.graph)
|
||||
|
||||
if fused:
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
else:
|
||||
assert fused_num_nodes == orig_num_nodes
|
||||
|
||||
input_ = numpy.int32(input_)
|
||||
|
||||
num_params = len(params_names)
|
||||
inputs = (input_,) * num_params
|
||||
assert function_to_trace(*inputs) == op_graph(*inputs)
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Test file for hnumpy compilation functions"""
|
||||
import itertools
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from hdk.common.data_types.integers import Integer
|
||||
@@ -10,6 +11,14 @@ from hdk.common.extensions.table import LookupTable
|
||||
from hdk.hnumpy.compile import compile_numpy_function
|
||||
|
||||
|
||||
def no_fuse_unhandled(x, y):
|
||||
"""No fuse unhandled"""
|
||||
x_intermediate = x + 2.8
|
||||
y_intermediate = y + 9.3
|
||||
intermediate = x_intermediate + y_intermediate
|
||||
return intermediate.astype(numpy.int32)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
@@ -21,6 +30,12 @@ from hdk.hnumpy.compile import compile_numpy_function
|
||||
((4, 8), (3, 4), (0, 4)),
|
||||
["x", "y", "z"],
|
||||
),
|
||||
pytest.param(
|
||||
no_fuse_unhandled,
|
||||
((-2, 2), (-2, 2)),
|
||||
["x", "y"],
|
||||
marks=pytest.mark.xfail(raises=ValueError),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_names):
|
||||
|
||||
Reference in New Issue
Block a user