mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: clear error message when a constant shape is preventing float fusing
- happens when a constant is bigger than the variable input of the subgraph, would require special handling, hard to write and requires a concatenation/stacking operator see #587 closes #645
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
"""File holding topological optimization/simplification code."""
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from typing import DefaultDict, Dict, List, Optional, Set, Tuple, cast
|
||||
@@ -349,18 +348,25 @@ def subgraph_nodes_and_values_allow_fusing(
|
||||
variable_input_node_output.size,
|
||||
variable_input_node_output.shape,
|
||||
)
|
||||
max_inputs_size = max(
|
||||
itertools.chain(
|
||||
(variable_input_node_output_size,),
|
||||
(
|
||||
cast(TensorValue, constant_input_node.outputs[0]).size
|
||||
for constant_input_node in subgraph_all_nodes
|
||||
if isinstance(constant_input_node, Constant)
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if variable_input_node_output_size < max_inputs_size:
|
||||
constant_nodes_with_bigger_size_than_variable_input = [
|
||||
constant_input_node
|
||||
for constant_input_node in subgraph_all_nodes
|
||||
if isinstance(constant_input_node, Constant)
|
||||
and cast(TensorValue, constant_input_node.outputs[0]).size > variable_input_node_output_size
|
||||
]
|
||||
|
||||
for bigger_constant_node in constant_nodes_with_bigger_size_than_variable_input:
|
||||
bigger_constant_node_shape = cast(TensorValue, bigger_constant_node.outputs[0]).shape
|
||||
node_with_issues_for_fusing[bigger_constant_node].append(
|
||||
f"this constant node has a bigger shape {bigger_constant_node_shape} "
|
||||
f"than the subgraph's input: {variable_input_node_output_shape}"
|
||||
)
|
||||
|
||||
if len(constant_nodes_with_bigger_size_than_variable_input) > 0:
|
||||
node_with_issues_for_fusing[variable_input_node].append(
|
||||
f"input node with shape {variable_input_node_output_shape}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Now that we know the variable input node has the biggest size we can check shapes are
|
||||
|
||||
@@ -27,6 +27,12 @@ def no_fuse_unhandled(x, y):
|
||||
return intermediate.astype(numpy.int32)
|
||||
|
||||
|
||||
def no_fuse_big_constant_3_10_10(x):
|
||||
"""Pass an array x with size < 100 to trigger a no fuse condition."""
|
||||
x = x.astype(numpy.float64)
|
||||
return (x + numpy.ones((3, 10, 10))).astype(numpy.int32)
|
||||
|
||||
|
||||
def no_fuse_dot(x):
|
||||
"""No fuse dot"""
|
||||
return numpy.dot(x, numpy.full((10,), 1.33, dtype=numpy.float64)).astype(numpy.int32)
|
||||
@@ -220,6 +226,21 @@ return(%3)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_explicitely_reshape",
|
||||
),
|
||||
pytest.param(
|
||||
no_fuse_big_constant_3_10_10,
|
||||
False,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10, 10))},
|
||||
"""The following subgraph is not fusable:
|
||||
%0 = Constant([[[1. 1. 1 ... . 1. 1.]]]) # ClearTensor<Float<64 bits>, shape=(3, 10, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this constant node has a bigger shape (3, 10, 10) than the subgraph's input: (10, 10)
|
||||
%1 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10, 10)
|
||||
%2 = astype(float64)(%1) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
|
||||
%3 = Add(%2, %0) # EncryptedTensor<Float<64 bits>, shape=(3, 10, 10)>
|
||||
%4 = astype(int32)(%3) # EncryptedTensor<Integer<signed, 32 bits>, shape=(3, 10, 10)>
|
||||
return(%4)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_big_constant_3_10_10",
|
||||
),
|
||||
pytest.param(
|
||||
simple_fuse_not_output,
|
||||
True,
|
||||
|
||||
Reference in New Issue
Block a user