feat: clear error message when a constant shape is preventing float fusing

- happens when a constant is bigger than the variable input of the
subgraph, would require special handling, hard to write and requires a
concatenation/stacking operator see #587

closes #645
This commit is contained in:
Arthur Meyre
2021-11-04 15:15:49 +01:00
parent fae89bd452
commit b716ddcf76
2 changed files with 39 additions and 12 deletions

View File

@@ -27,6 +27,12 @@ def no_fuse_unhandled(x, y):
return intermediate.astype(numpy.int32)
def no_fuse_big_constant_3_10_10(x):
"""Pass an array x with size < 100 to trigger a no fuse condition."""
x = x.astype(numpy.float64)
return (x + numpy.ones((3, 10, 10))).astype(numpy.int32)
def no_fuse_dot(x):
"""No fuse dot"""
return numpy.dot(x, numpy.full((10,), 1.33, dtype=numpy.float64)).astype(numpy.int32)
@@ -220,6 +226,21 @@ return(%3)""", # noqa: E501 # pylint: disable=line-too-long
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_explicitely_reshape",
),
pytest.param(
no_fuse_big_constant_3_10_10,
False,
{"x": EncryptedTensor(Integer(32, True), (10, 10))},
"""The following subgraph is not fusable:
%0 = Constant([[[1. 1. 1 ... . 1. 1.]]]) # ClearTensor<Float<64 bits>, shape=(3, 10, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this constant node has a bigger shape (3, 10, 10) than the subgraph's input: (10, 10)
%1 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10, 10)
%2 = astype(float64)(%1) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
%3 = Add(%2, %0) # EncryptedTensor<Float<64 bits>, shape=(3, 10, 10)>
%4 = astype(int32)(%3) # EncryptedTensor<Integer<signed, 32 bits>, shape=(3, 10, 10)>
return(%4)""", # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_big_constant_3_10_10",
),
pytest.param(
simple_fuse_not_output,
True,