feat: clear error message when a constant shape is preventing float fusing

- happens when a constant is bigger than the variable input of the
  subgraph; supporting that case would require special handling that is
  hard to write and needs a concatenation/stacking operator, see #587

closes #645
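
For illustration (not part of the commit): a minimal numpy sketch of the shape pattern that triggers this no-fuse condition. Broadcasting a (10, 10) variable input against a (3, 10, 10) constant makes the subgraph's output bigger than its input, so the subgraph cannot be collapsed into a single table lookup on that input:

    import numpy

    x = numpy.zeros((10, 10))        # variable subgraph input, size 100
    c = numpy.ones((3, 10, 10))      # float constant, size 300
    out = x + c                      # broadcasting replicates x along the new axis
    assert out.shape == (3, 10, 10)  # the output outgrows the variable input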
Author: Arthur Meyre
Date:   2021-11-04 15:15:49 +01:00
Parent: fae89bd452
Commit: b716ddcf76

2 changed files with 39 additions and 12 deletions


@@ -1,5 +1,4 @@
 """File holding topological optimization/simplification code."""
-import itertools
 from collections import defaultdict
 from copy import deepcopy
 from typing import DefaultDict, Dict, List, Optional, Set, Tuple, cast
@@ -349,18 +348,25 @@ def subgraph_nodes_and_values_allow_fusing(
         variable_input_node_output.size,
         variable_input_node_output.shape,
     )
-    max_inputs_size = max(
-        itertools.chain(
-            (variable_input_node_output_size,),
-            (
-                cast(TensorValue, constant_input_node.outputs[0]).size
-                for constant_input_node in subgraph_all_nodes
-                if isinstance(constant_input_node, Constant)
-            ),
-        )
-    )
-
-    if variable_input_node_output_size < max_inputs_size:
+    constant_nodes_with_bigger_size_than_variable_input = [
+        constant_input_node
+        for constant_input_node in subgraph_all_nodes
+        if isinstance(constant_input_node, Constant)
+        and cast(TensorValue, constant_input_node.outputs[0]).size
+        > variable_input_node_output_size
+    ]
+
+    for bigger_constant_node in constant_nodes_with_bigger_size_than_variable_input:
+        bigger_constant_node_shape = cast(TensorValue, bigger_constant_node.outputs[0]).shape
+        node_with_issues_for_fusing[bigger_constant_node].append(
+            f"this constant node has a bigger shape {bigger_constant_node_shape} "
+            f"than the subgraph's input: {variable_input_node_output_shape}"
+        )
+
+    if len(constant_nodes_with_bigger_size_than_variable_input) > 0:
+        node_with_issues_for_fusing[variable_input_node].append(
+            f"input node with shape {variable_input_node_output_shape}"
+        )
         return False
 
     # Now that we know the variable input node has the biggest size we can check shapes are
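
For illustration (not part of the commit), a rough standalone sketch of the check added above, using plain numpy arrays in place of the project's Constant/TensorValue nodes; constants_bigger_than_input is a hypothetical name:

    import numpy

    def constants_bigger_than_input(variable_input, constants):
        # Flag every constant whose element count exceeds the
        # variable input's element count, mirroring the size test above.
        return [c for c in constants if c.size > variable_input.size]

    # A (3, 10, 10) constant (size 300) dwarfs a (10, 10) input (size 100),
    # so fusing would need the concatenation/stacking operator tracked in #587.
    too_big = constants_bigger_than_input(numpy.zeros((10, 10)), [numpy.ones((3, 10, 10))])
    assert len(too_big) == 1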


@@ -27,6 +27,12 @@ def no_fuse_unhandled(x, y):
     return intermediate.astype(numpy.int32)
 
 
+def no_fuse_big_constant_3_10_10(x):
+    """Pass an array x with size < 100 to trigger a no fuse condition."""
+    x = x.astype(numpy.float64)
+    return (x + numpy.ones((3, 10, 10))).astype(numpy.int32)
+
+
 def no_fuse_dot(x):
     """No fuse dot"""
     return numpy.dot(x, numpy.full((10,), 1.33, dtype=numpy.float64)).astype(numpy.int32)
@@ -220,6 +226,21 @@ return(%3)""", # noqa: E501 # pylint: disable=line-too-long
 return(%3)""", # noqa: E501 # pylint: disable=line-too-long
             id="no_fuse_explicitely_reshape",
         ),
+        pytest.param(
+            no_fuse_big_constant_3_10_10,
+            False,
+            {"x": EncryptedTensor(Integer(32, True), (10, 10))},
+            """The following subgraph is not fusable:
+%0 = Constant([[[1. 1. 1 ... . 1. 1.]]])          # ClearTensor<Float<64 bits>, shape=(3, 10, 10)>
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this constant node has a bigger shape (3, 10, 10) than the subgraph's input: (10, 10)
+%1 = x                                            # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10, 10)
+%2 = astype(float64)(%1)                          # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
+%3 = Add(%2, %0)                                  # EncryptedTensor<Float<64 bits>, shape=(3, 10, 10)>
+%4 = astype(int32)(%3)                            # EncryptedTensor<Integer<signed, 32 bits>, shape=(3, 10, 10)>
+return(%4)""", # noqa: E501 # pylint: disable=line-too-long
+            id="no_fuse_big_constant_3_10_10",
+        ),
         pytest.param(
             simple_fuse_not_output,
             True,