feat: emit loguru warning with reason for subgraph not fusing

- catches cases with more than one variable input
- catches cases where the shapes are not the same in intermediate nodes

refs #645
This commit is contained in:
Arthur Meyre
2021-10-27 10:55:58 +02:00
parent 86b6137fcb
commit 212dc36382
4 changed files with 184 additions and 29 deletions

View File

@@ -31,7 +31,6 @@ class OPGraph:
input_nodes: Dict[int, Input],
output_nodes: Dict[int, IntermediateNode],
) -> None:
assert_true(len(input_nodes) > 0, "Got a graph without input nodes which is not supported")
assert_true(
all(isinstance(node, Input) for node in input_nodes.values()),
"Got input nodes that were not Input, which is not supported",
@@ -47,6 +46,7 @@ class OPGraph:
self.prune_nodes()
def __call__(self, *args) -> Union[Any, Tuple[Any, ...]]:
assert_true(len(self.input_nodes) > 0, "Cannot evaluate a graph with no input nodes")
inputs = dict(enumerate(args))
assert_true(

View File

@@ -1,13 +1,16 @@
"""File holding topological optimization/simplification code."""
import itertools
from collections import defaultdict
from copy import deepcopy
from typing import Dict, List, Optional, Set, Tuple, cast
from typing import DefaultDict, Dict, List, Optional, Set, Tuple, cast
import networkx as nx
from loguru import logger
from ..compilation.artifacts import CompilationArtifacts
from ..data_types.floats import Float
from ..data_types.integers import Integer
from ..debugging import get_printable_graph
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import Constant, Input, IntermediateNode, UnivariateFunction
@@ -112,11 +115,30 @@ def convert_float_subgraph_to_fused_node(
output must be plugged as the input to the subgraph.
"""
subgraph_can_be_fused = subgraph_has_unique_variable_input(
float_subgraph_start_nodes
) and subgraph_values_allow_fusing(float_subgraph_start_nodes, subgraph_all_nodes)
node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]] = defaultdict(list)
subgraph_can_be_fused = subgraph_has_unique_variable_input(
float_subgraph_start_nodes, terminal_node, node_with_issues_for_fusing
)
if subgraph_can_be_fused:
# subgraph_values_allow_fusing can be called iff the subgraph has a unique variable input
subgraph_can_be_fused = subgraph_values_allow_fusing(
float_subgraph_start_nodes, subgraph_all_nodes, node_with_issues_for_fusing
)
# This test is separate from the previous one to only handle printing issues once
if not subgraph_can_be_fused:
float_subgraph = nx.MultiDiGraph(op_graph.graph.subgraph(subgraph_all_nodes))
float_subgraph_as_op_graph = OPGraph.from_graph(float_subgraph, [], [terminal_node])
printable_graph = get_printable_graph(
float_subgraph_as_op_graph,
show_data_types=True,
highlighted_nodes=node_with_issues_for_fusing,
)
message = f"The following subgraph is not fusable:\n{printable_graph}"
logger.warning(message)
return None
# Only one variable input node, find which node feeds its input
@@ -258,7 +280,8 @@ def find_float_subgraph_with_unique_terminal_node(
def subgraph_values_allow_fusing(
float_subgraph_start_nodes: Set[IntermediateNode],
subgraph_all_nodes: Set[IntermediateNode],
):
node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]],
) -> bool:
"""Check if a subgraph's values are compatible with fusing.
A fused subgraph for example only works on an input tensor if the resulting UnivariateFunction
@@ -267,6 +290,8 @@ def subgraph_values_allow_fusing(
Args:
float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the float subgraph.
subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph.
node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill
with potential node issues preventing fusing.
Returns:
bool: True if all inputs and outputs of the nodes in the subgraph are compatible with fusing
@@ -286,10 +311,10 @@ def subgraph_values_allow_fusing(
# Some UnivariateFunction nodes have baked constants that need to be taken into account for the
# max size computation
baked_constants_ir_nodes = [
baked_constant_base_value
baked_constant_ir_node
for node in subgraph_all_nodes
if isinstance(node, UnivariateFunction)
if (baked_constant_base_value := node.op_attributes.get("baked_constant_ir_node", None))
if (baked_constant_ir_node := node.op_attributes.get("baked_constant_ir_node", None))
is not None
]
@@ -332,26 +357,72 @@ def subgraph_values_allow_fusing(
non_constant_nodes = (node for node in subgraph_all_nodes if not isinstance(node, Constant))
return all(
all(
isinstance(output, TensorValue) and output.shape == variable_input_node_output_shape
nodes_with_different_output_shapes = {
node: [
(output_idx, output.shape)
for output_idx, output in enumerate(node.outputs)
if isinstance(output, TensorValue) and output.shape != variable_input_node_output_shape
]
for node in non_constant_nodes
if any(
isinstance(output, TensorValue) and output.shape != variable_input_node_output_shape
for output in node.outputs
)
for node in non_constant_nodes
)
}
for node, node_shape_infos in nodes_with_different_output_shapes.items():
shape_issue_details = "; ".join(
f"#{output_idx}, {output_shape}" for output_idx, output_shape in node_shape_infos
)
node_with_issues_for_fusing[node].append(
f"output shapes: {shape_issue_details} are not the same as the subgraph's input: "
f"{variable_input_node_output_shape}"
)
all_nodes_have_same_shape_as_input = len(nodes_with_different_output_shapes) == 0
if not all_nodes_have_same_shape_as_input:
node_with_issues_for_fusing[variable_input_node].append(
f"input node with shape {variable_input_node_output_shape}"
)
# All non constant node outputs currently need to have the same shape
return all_nodes_have_same_shape_as_input
def subgraph_has_unique_variable_input(
    float_subgraph_start_nodes: Set[IntermediateNode],
    terminal_node: IntermediateNode,
    node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]],
) -> bool:
    """Check that only one of the nodes starting the subgraph is variable.

    Args:
        float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the subgraph.
        terminal_node (IntermediateNode): The node ending the float subgraph.
        node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill
            with potential node issues preventing fusing.

    Returns:
        bool: True if only one of the nodes is not a Constant
    """
    variable_inputs_list = [
        node for node in float_subgraph_start_nodes if not isinstance(node, Constant)
    ]
    variable_inputs_num = len(variable_inputs_list)

    # Only one input to the subgraph where computations are done in floats can be variable, this
    # is the only case we can manage with UnivariateFunction fusing
    has_unique_variable_input = variable_inputs_num == 1
    if not has_unique_variable_input:
        # Record the reason on each offending variable input so the printable graph can
        # highlight them for the warning message
        for node in variable_inputs_list:
            node_with_issues_for_fusing[node].append(
                f"one of {variable_inputs_num} variable inputs (can only have 1 for fusing)"
            )
        node_with_issues_for_fusing[terminal_node].append(
            f"cannot fuse here as the subgraph has {variable_inputs_num} variable inputs"
        )
    return has_unique_variable_input

View File

@@ -27,6 +27,11 @@ def no_fuse_unhandled(x, y):
return intermediate.astype(numpy.int32)
def no_fuse_dot(x):
    """No fuse dot"""
    # Dot product against a constant float vector, then cast back to int32
    float_constant = numpy.full((10,), 1.33, dtype=numpy.float64)
    dot_result = numpy.dot(x, float_constant)
    return dot_result.astype(numpy.int32)
def simple_fuse_not_output(x):
"""Simple fuse not output"""
intermediate = x.astype(numpy.float64)
@@ -107,34 +112,96 @@ def mix_x_and_y_into_integer_and_call_f(function, x, y):
)
def get_func_params_scalar_int32(func):
    """Returns a dict with parameters as scalar int32"""
    params = {}
    # Build one encrypted signed 32-bit scalar per parameter of func's signature
    for param_name in signature(func).parameters:
        params[param_name] = EncryptedScalar(Integer(32, True))
    return params
@pytest.mark.parametrize(
"function_to_trace,fused",
"function_to_trace,fused,params,warning_message",
[
pytest.param(no_fuse, False, id="no_fuse"),
pytest.param(no_fuse_unhandled, False, id="no_fuse_unhandled"),
pytest.param(simple_fuse_not_output, True, id="no_fuse"),
pytest.param(simple_fuse_output, True, id="no_fuse"),
pytest.param(no_fuse, False, get_func_params_scalar_int32(no_fuse), "", id="no_fuse"),
pytest.param(
no_fuse_unhandled,
False,
get_func_params_scalar_int32(no_fuse_unhandled),
"""The following subgraph is not fusable:
%0 = x # EncryptedScalar<Integer<signed, 32 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
%1 = Constant(0.7) # ClearScalar<Float<64 bits>>
%2 = y # EncryptedScalar<Integer<signed, 32 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
%3 = Constant(1.3) # ClearScalar<Float<64 bits>>
%4 = Add(%0, %1) # EncryptedScalar<Float<64 bits>>
%5 = Add(%2, %3) # EncryptedScalar<Float<64 bits>>
%6 = Add(%4, %5) # EncryptedScalar<Float<64 bits>>
%7 = astype(int32)(%6) # EncryptedScalar<Integer<signed, 32 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot fuse here as the subgraph has 2 variable inputs
return(%7)""", # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_unhandled",
),
pytest.param(
no_fuse_dot,
False,
{"x": EncryptedTensor(Integer(32, True), (10,))},
"""The following subgraph is not fusable:
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10,)
%1 = Constant([1.33 1.33 ... 1.33 1.33]) # ClearTensor<Float<64 bits>, shape=(10,)>
%2 = Dot(%0, %1) # EncryptedScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,)
%3 = astype(int32)(%2) # EncryptedScalar<Integer<signed, 32 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,)
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_dot",
),
pytest.param(
simple_fuse_not_output,
True,
get_func_params_scalar_int32(simple_fuse_not_output),
None,
id="simple_fuse_not_output",
),
pytest.param(
simple_fuse_output,
True,
get_func_params_scalar_int32(simple_fuse_output),
None,
id="simple_fuse_output",
),
pytest.param(
lambda x, y: mix_x_and_y_intricately_and_call_f(numpy.rint, x, y),
True,
get_func_params_scalar_int32(lambda x, y: None),
None,
id="mix_x_and_y_intricately_and_call_f_with_rint",
),
pytest.param(
lambda x, y: mix_x_and_y_and_call_f(numpy.rint, x, y),
True,
get_func_params_scalar_int32(lambda x, y: None),
None,
id="mix_x_and_y_and_call_f_with_rint",
),
],
)
@pytest.mark.parametrize("input_", [0, 2, 42, 44])
def test_fuse_float_operations(function_to_trace, fused, input_):
def test_fuse_float_operations(
function_to_trace,
fused,
params,
warning_message,
capfd,
remove_color_codes,
):
"""Test function for fuse_float_operations"""
params_names = signature(function_to_trace).parameters.keys()
op_graph = trace_numpy_function(
function_to_trace,
{param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
params,
)
orig_num_nodes = len(op_graph.graph)
fuse_float_operations(op_graph)
@@ -144,12 +211,19 @@ def test_fuse_float_operations(function_to_trace, fused, input_):
assert fused_num_nodes < orig_num_nodes
else:
assert fused_num_nodes == orig_num_nodes
captured = capfd.readouterr()
assert warning_message in remove_color_codes(captured.err)
input_ = numpy.int32(input_)
for input_ in [0, 2, 42, 44]:
inputs = ()
for param_input_value in params.values():
if param_input_value.is_scalar:
input_ = numpy.int32(input_)
else:
input_ = numpy.full(param_input_value.shape, input_, dtype=numpy.int32)
inputs += (input_,)
num_params = len(params_names)
inputs = (input_,) * num_params
assert function_to_trace(*inputs) == op_graph(*inputs)
assert numpy.array_equal(function_to_trace(*inputs), op_graph(*inputs))
def subtest_tensor_no_fuse(fun, tensor_shape):

View File

@@ -1,6 +1,7 @@
"""PyTest configuration file"""
import json
import operator
import re
from pathlib import Path
from typing import Callable, Dict, Type
@@ -264,3 +265,12 @@ def default_compilation_configuration():
dump_artifacts_on_unexpected_failures=False,
treat_warnings_as_errors=True,
)
# Matches ANSI escape sequences terminated by 'm' (SGR color codes), e.g. "\x1b[31m"
REMOVE_COLOR_CODES_RE = re.compile(r"\x1b[^m]*m")
@pytest.fixture
def remove_color_codes():
    """Return a function that strips ANSI color codes from a string."""
    return lambda x: REMOVE_COLOR_CODES_RE.sub("", x)