mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: emit loguru warning with reason for subgraph not fusing
- catches cases with more than one variable input - catches cases where the shapes are not the same in intermediate nodes refs #645
This commit is contained in:
@@ -31,7 +31,6 @@ class OPGraph:
|
||||
input_nodes: Dict[int, Input],
|
||||
output_nodes: Dict[int, IntermediateNode],
|
||||
) -> None:
|
||||
assert_true(len(input_nodes) > 0, "Got a graph without input nodes which is not supported")
|
||||
assert_true(
|
||||
all(isinstance(node, Input) for node in input_nodes.values()),
|
||||
"Got input nodes that were not Input, which is not supported",
|
||||
@@ -47,6 +46,7 @@ class OPGraph:
|
||||
self.prune_nodes()
|
||||
|
||||
def __call__(self, *args) -> Union[Any, Tuple[Any, ...]]:
|
||||
assert_true(len(self.input_nodes) > 0, "Cannot evaluate a graph with no input nodes")
|
||||
inputs = dict(enumerate(args))
|
||||
|
||||
assert_true(
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
"""File holding topological optimization/simplification code."""
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Optional, Set, Tuple, cast
|
||||
from typing import DefaultDict, Dict, List, Optional, Set, Tuple, cast
|
||||
|
||||
import networkx as nx
|
||||
from loguru import logger
|
||||
|
||||
from ..compilation.artifacts import CompilationArtifacts
|
||||
from ..data_types.floats import Float
|
||||
from ..data_types.integers import Integer
|
||||
from ..debugging import get_printable_graph
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import Constant, Input, IntermediateNode, UnivariateFunction
|
||||
@@ -112,11 +115,30 @@ def convert_float_subgraph_to_fused_node(
|
||||
output must be plugged as the input to the subgraph.
|
||||
"""
|
||||
|
||||
subgraph_can_be_fused = subgraph_has_unique_variable_input(
|
||||
float_subgraph_start_nodes
|
||||
) and subgraph_values_allow_fusing(float_subgraph_start_nodes, subgraph_all_nodes)
|
||||
node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]] = defaultdict(list)
|
||||
|
||||
subgraph_can_be_fused = subgraph_has_unique_variable_input(
|
||||
float_subgraph_start_nodes, terminal_node, node_with_issues_for_fusing
|
||||
)
|
||||
|
||||
if subgraph_can_be_fused:
|
||||
# subgraph_values_allow_fusing can be called iff the subgraph has a unique variable input
|
||||
subgraph_can_be_fused = subgraph_values_allow_fusing(
|
||||
float_subgraph_start_nodes, subgraph_all_nodes, node_with_issues_for_fusing
|
||||
)
|
||||
|
||||
# This test is separate from the previous one to only handle printing issues once
|
||||
if not subgraph_can_be_fused:
|
||||
float_subgraph = nx.MultiDiGraph(op_graph.graph.subgraph(subgraph_all_nodes))
|
||||
float_subgraph_as_op_graph = OPGraph.from_graph(float_subgraph, [], [terminal_node])
|
||||
|
||||
printable_graph = get_printable_graph(
|
||||
float_subgraph_as_op_graph,
|
||||
show_data_types=True,
|
||||
highlighted_nodes=node_with_issues_for_fusing,
|
||||
)
|
||||
message = f"The following subgraph is not fusable:\n{printable_graph}"
|
||||
logger.warning(message)
|
||||
return None
|
||||
|
||||
# Only one variable input node, find which node feeds its input
|
||||
@@ -258,7 +280,8 @@ def find_float_subgraph_with_unique_terminal_node(
|
||||
def subgraph_values_allow_fusing(
|
||||
float_subgraph_start_nodes: Set[IntermediateNode],
|
||||
subgraph_all_nodes: Set[IntermediateNode],
|
||||
):
|
||||
node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]],
|
||||
) -> bool:
|
||||
"""Check if a subgraph's values are compatible with fusing.
|
||||
|
||||
A fused subgraph for example only works on an input tensor if the resulting UnivariateFunction
|
||||
@@ -267,6 +290,8 @@ def subgraph_values_allow_fusing(
|
||||
Args:
|
||||
float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the float subgraph.
|
||||
subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph.
|
||||
node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill
|
||||
with potential nodes issues preventing fusing.
|
||||
|
||||
Returns:
|
||||
bool: True if all inputs and outputs of the nodes in the subgraph are compatible with fusing
|
||||
@@ -286,10 +311,10 @@ def subgraph_values_allow_fusing(
|
||||
# Some UnivariateFunction nodes have baked constants that need to be taken into account for the
|
||||
# max size computation
|
||||
baked_constants_ir_nodes = [
|
||||
baked_constant_base_value
|
||||
baked_constant_ir_node
|
||||
for node in subgraph_all_nodes
|
||||
if isinstance(node, UnivariateFunction)
|
||||
if (baked_constant_base_value := node.op_attributes.get("baked_constant_ir_node", None))
|
||||
if (baked_constant_ir_node := node.op_attributes.get("baked_constant_ir_node", None))
|
||||
is not None
|
||||
]
|
||||
|
||||
@@ -332,26 +357,72 @@ def subgraph_values_allow_fusing(
|
||||
|
||||
non_constant_nodes = (node for node in subgraph_all_nodes if not isinstance(node, Constant))
|
||||
|
||||
return all(
|
||||
all(
|
||||
isinstance(output, TensorValue) and output.shape == variable_input_node_output_shape
|
||||
nodes_with_different_output_shapes = {
|
||||
node: [
|
||||
(output_idx, output.shape)
|
||||
for output_idx, output in enumerate(node.outputs)
|
||||
if isinstance(output, TensorValue) and output.shape != variable_input_node
|
||||
]
|
||||
for node in non_constant_nodes
|
||||
if any(
|
||||
isinstance(output, TensorValue) and output.shape != variable_input_node_output_shape
|
||||
for output in node.outputs
|
||||
)
|
||||
for node in non_constant_nodes
|
||||
)
|
||||
}
|
||||
|
||||
for node, node_shape_infos in nodes_with_different_output_shapes.items():
|
||||
shape_issue_details = "; ".join(
|
||||
f"#{output_idx}, {output_shape}" for output_idx, output_shape in node_shape_infos
|
||||
)
|
||||
node_with_issues_for_fusing[node].append(
|
||||
f"output shapes: {shape_issue_details} are not the same as the subgraph's input: "
|
||||
f"{variable_input_node_output_shape}"
|
||||
)
|
||||
|
||||
all_nodes_have_same_shape_as_input = len(nodes_with_different_output_shapes) == 0
|
||||
|
||||
if not all_nodes_have_same_shape_as_input:
|
||||
node_with_issues_for_fusing[variable_input_node].append(
|
||||
f"input node with shape {variable_input_node_output_shape}"
|
||||
)
|
||||
|
||||
# All non constant node outputs currently need to have the same shape
|
||||
return all_nodes_have_same_shape_as_input
|
||||
|
||||
|
||||
def subgraph_has_unique_variable_input(
|
||||
float_subgraph_start_nodes: Set[IntermediateNode],
|
||||
terminal_node: IntermediateNode,
|
||||
node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]],
|
||||
) -> bool:
|
||||
"""Check that only one of the nodes starting the subgraph is variable.
|
||||
|
||||
Args:
|
||||
float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the subgraph.
|
||||
terminal_node (IntermediateNode): The node ending the float subgraph.
|
||||
node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill
|
||||
with potential nodes issues preventing fusing.
|
||||
|
||||
Returns:
|
||||
bool: True if only one of the nodes is not an Constant
|
||||
"""
|
||||
# Only one input to the subgraph where computations are done in floats is variable, this
|
||||
|
||||
variable_inputs_list = [
|
||||
node for node in float_subgraph_start_nodes if not isinstance(node, Constant)
|
||||
]
|
||||
variable_inputs_num = len(variable_inputs_list)
|
||||
|
||||
# Only one input to the subgraph where computations are done in floats can be variable, this
|
||||
# is the only case we can manage with UnivariateFunction fusing
|
||||
return sum(not isinstance(node, Constant) for node in float_subgraph_start_nodes) == 1
|
||||
has_unique_variable_input = variable_inputs_num == 1
|
||||
|
||||
if not has_unique_variable_input:
|
||||
for node in variable_inputs_list:
|
||||
node_with_issues_for_fusing[node].append(
|
||||
f"one of {variable_inputs_num} variable inputs (can only have 1 for fusing)"
|
||||
)
|
||||
node_with_issues_for_fusing[terminal_node].append(
|
||||
f"cannot fuse here as the subgraph has {variable_inputs_num} variable inputs"
|
||||
)
|
||||
|
||||
return has_unique_variable_input
|
||||
|
||||
@@ -27,6 +27,11 @@ def no_fuse_unhandled(x, y):
|
||||
return intermediate.astype(numpy.int32)
|
||||
|
||||
|
||||
def no_fuse_dot(x):
|
||||
"""No fuse dot"""
|
||||
return numpy.dot(x, numpy.full((10,), 1.33, dtype=numpy.float64)).astype(numpy.int32)
|
||||
|
||||
|
||||
def simple_fuse_not_output(x):
|
||||
"""Simple fuse not output"""
|
||||
intermediate = x.astype(numpy.float64)
|
||||
@@ -107,34 +112,96 @@ def mix_x_and_y_into_integer_and_call_f(function, x, y):
|
||||
)
|
||||
|
||||
|
||||
def get_func_params_scalar_int32(func):
|
||||
"""Returns a dict with parameters as scalar int32"""
|
||||
|
||||
return {
|
||||
param_name: EncryptedScalar(Integer(32, True))
|
||||
for param_name in signature(func).parameters.keys()
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,fused",
|
||||
"function_to_trace,fused,params,warning_message",
|
||||
[
|
||||
pytest.param(no_fuse, False, id="no_fuse"),
|
||||
pytest.param(no_fuse_unhandled, False, id="no_fuse_unhandled"),
|
||||
pytest.param(simple_fuse_not_output, True, id="no_fuse"),
|
||||
pytest.param(simple_fuse_output, True, id="no_fuse"),
|
||||
pytest.param(no_fuse, False, get_func_params_scalar_int32(no_fuse), "", id="no_fuse"),
|
||||
pytest.param(
|
||||
no_fuse_unhandled,
|
||||
False,
|
||||
get_func_params_scalar_int32(no_fuse_unhandled),
|
||||
"""The following subgraph is not fusable:
|
||||
%0 = x # EncryptedScalar<Integer<signed, 32 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
|
||||
%1 = Constant(0.7) # ClearScalar<Float<64 bits>>
|
||||
%2 = y # EncryptedScalar<Integer<signed, 32 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
|
||||
%3 = Constant(1.3) # ClearScalar<Float<64 bits>>
|
||||
%4 = Add(%0, %1) # EncryptedScalar<Float<64 bits>>
|
||||
%5 = Add(%2, %3) # EncryptedScalar<Float<64 bits>>
|
||||
%6 = Add(%4, %5) # EncryptedScalar<Float<64 bits>>
|
||||
%7 = astype(int32)(%6) # EncryptedScalar<Integer<signed, 32 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot fuse here as the subgraph has 2 variable inputs
|
||||
return(%7)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_unhandled",
|
||||
),
|
||||
pytest.param(
|
||||
no_fuse_dot,
|
||||
False,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10,))},
|
||||
"""The following subgraph is not fusable:
|
||||
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10,)
|
||||
%1 = Constant([1.33 1.33 ... 1.33 1.33]) # ClearTensor<Float<64 bits>, shape=(10,)>
|
||||
%2 = Dot(%0, %1) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,)
|
||||
%3 = astype(int32)(%2) # EncryptedScalar<Integer<signed, 32 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,)
|
||||
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_dot",
|
||||
),
|
||||
pytest.param(
|
||||
simple_fuse_not_output,
|
||||
True,
|
||||
get_func_params_scalar_int32(simple_fuse_not_output),
|
||||
None,
|
||||
id="simple_fuse_not_output",
|
||||
),
|
||||
pytest.param(
|
||||
simple_fuse_output,
|
||||
True,
|
||||
get_func_params_scalar_int32(simple_fuse_output),
|
||||
None,
|
||||
id="simple_fuse_output",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: mix_x_and_y_intricately_and_call_f(numpy.rint, x, y),
|
||||
True,
|
||||
get_func_params_scalar_int32(lambda x, y: None),
|
||||
None,
|
||||
id="mix_x_and_y_intricately_and_call_f_with_rint",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: mix_x_and_y_and_call_f(numpy.rint, x, y),
|
||||
True,
|
||||
get_func_params_scalar_int32(lambda x, y: None),
|
||||
None,
|
||||
id="mix_x_and_y_and_call_f_with_rint",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("input_", [0, 2, 42, 44])
|
||||
def test_fuse_float_operations(function_to_trace, fused, input_):
|
||||
def test_fuse_float_operations(
|
||||
function_to_trace,
|
||||
fused,
|
||||
params,
|
||||
warning_message,
|
||||
capfd,
|
||||
remove_color_codes,
|
||||
):
|
||||
"""Test function for fuse_float_operations"""
|
||||
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
function_to_trace,
|
||||
{param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
|
||||
params,
|
||||
)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
@@ -144,12 +211,19 @@ def test_fuse_float_operations(function_to_trace, fused, input_):
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
else:
|
||||
assert fused_num_nodes == orig_num_nodes
|
||||
captured = capfd.readouterr()
|
||||
assert warning_message in remove_color_codes(captured.err)
|
||||
|
||||
input_ = numpy.int32(input_)
|
||||
for input_ in [0, 2, 42, 44]:
|
||||
inputs = ()
|
||||
for param_input_value in params.values():
|
||||
if param_input_value.is_scalar:
|
||||
input_ = numpy.int32(input_)
|
||||
else:
|
||||
input_ = numpy.full(param_input_value.shape, input_, dtype=numpy.int32)
|
||||
inputs += (input_,)
|
||||
|
||||
num_params = len(params_names)
|
||||
inputs = (input_,) * num_params
|
||||
assert function_to_trace(*inputs) == op_graph(*inputs)
|
||||
assert numpy.array_equal(function_to_trace(*inputs), op_graph(*inputs))
|
||||
|
||||
|
||||
def subtest_tensor_no_fuse(fun, tensor_shape):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""PyTest configuration file"""
|
||||
import json
|
||||
import operator
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Type
|
||||
|
||||
@@ -264,3 +265,12 @@ def default_compilation_configuration():
|
||||
dump_artifacts_on_unexpected_failures=False,
|
||||
treat_warnings_as_errors=True,
|
||||
)
|
||||
|
||||
|
||||
REMOVE_COLOR_CODES_RE = re.compile(r"\x1b[^m]*m")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def remove_color_codes():
|
||||
"""Return the re object to remove color codes"""
|
||||
return lambda x: REMOVE_COLOR_CODES_RE.sub("", x)
|
||||
|
||||
Reference in New Issue
Block a user