From b716ddcf7613ee0034aedf2aad20f0b632ab8d0d Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Thu, 4 Nov 2021 15:15:49 +0100 Subject: [PATCH] feat: clear error message when a constant shape is preventing float fusing - happens when a constant is bigger than the variable input of the subgraph, would require special handling, hard to write and requires a concatenation/stacking operator see #587 closes #645 --- concrete/common/optimization/topological.py | 30 +++++++++++-------- .../common/optimization/test_float_fusing.py | 21 +++++++++++++ 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py index 26c7a86c3..391d19390 100644 --- a/concrete/common/optimization/topological.py +++ b/concrete/common/optimization/topological.py @@ -1,5 +1,4 @@ """File holding topological optimization/simplification code.""" -import itertools from collections import defaultdict from copy import deepcopy from typing import DefaultDict, Dict, List, Optional, Set, Tuple, cast @@ -349,18 +348,25 @@ def subgraph_nodes_and_values_allow_fusing( variable_input_node_output.size, variable_input_node_output.shape, ) - max_inputs_size = max( - itertools.chain( - (variable_input_node_output_size,), - ( - cast(TensorValue, constant_input_node.outputs[0]).size - for constant_input_node in subgraph_all_nodes - if isinstance(constant_input_node, Constant) - ), - ) - ) - if variable_input_node_output_size < max_inputs_size: + constant_nodes_with_bigger_size_than_variable_input = [ + constant_input_node + for constant_input_node in subgraph_all_nodes + if isinstance(constant_input_node, Constant) + and cast(TensorValue, constant_input_node.outputs[0]).size > variable_input_node_output_size + ] + + for bigger_constant_node in constant_nodes_with_bigger_size_than_variable_input: + bigger_constant_node_shape = cast(TensorValue, bigger_constant_node.outputs[0]).shape + node_with_issues_for_fusing[bigger_constant_node].append( + f"this constant node has a bigger shape {bigger_constant_node_shape} " + f"than the subgraph's input: {variable_input_node_output_shape}" + ) + + if len(constant_nodes_with_bigger_size_than_variable_input) > 0: + node_with_issues_for_fusing[variable_input_node].append( + f"input node with shape {variable_input_node_output_shape}" + ) return False # Now that we know the variable input node has the biggest size we can check shapes are diff --git a/tests/common/optimization/test_float_fusing.py b/tests/common/optimization/test_float_fusing.py index 42ecd5757..6e05474f5 100644 --- a/tests/common/optimization/test_float_fusing.py +++ b/tests/common/optimization/test_float_fusing.py @@ -27,6 +27,12 @@ def no_fuse_unhandled(x, y): return intermediate.astype(numpy.int32) +def no_fuse_big_constant_3_10_10(x): + """Pass an array x with size < 100 to trigger a no fuse condition.""" + x = x.astype(numpy.float64) + return (x + numpy.ones((3, 10, 10))).astype(numpy.int32) + + def no_fuse_dot(x): """No fuse dot""" return numpy.dot(x, numpy.full((10,), 1.33, dtype=numpy.float64)).astype(numpy.int32) @@ -220,6 +226,21 @@ return(%3)""", # noqa: E501 # pylint: disable=line-too-long return(%3)""", # noqa: E501 # pylint: disable=line-too-long id="no_fuse_explicitely_reshape", ), + pytest.param( + no_fuse_big_constant_3_10_10, + False, + {"x": EncryptedTensor(Integer(32, True), (10, 10))}, + """The following subgraph is not fusable: +%0 = Constant([[[1. 1. 1 ... . 1. 1.]]]) # ClearTensor, shape=(3, 10, 10)> +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this constant node has a bigger shape (3, 10, 10) than the subgraph's input: (10, 10) +%1 = x # EncryptedTensor, shape=(10, 10)> +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10, 10) +%2 = astype(float64)(%1) # EncryptedTensor, shape=(10, 10)> +%3 = Add(%2, %0) # EncryptedTensor, shape=(3, 10, 10)> +%4 = astype(int32)(%3) # EncryptedTensor, shape=(3, 10, 10)> +return(%4)""", # noqa: E501 # pylint: disable=line-too-long + id="no_fuse_big_constant_3_10_10", + ), pytest.param( simple_fuse_not_output, True,