mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 04:35:03 -05:00
feat: clear error message when a constant shape is preventing float fusing
- happens when a constant is bigger than the variable input of the subgraph, would require special handling, hard to write and requires a concatenation/stacking operator see #587 closes #645
This commit is contained in:
@@ -27,6 +27,12 @@ def no_fuse_unhandled(x, y):
|
||||
return intermediate.astype(numpy.int32)
|
||||
|
||||
|
||||
def no_fuse_big_constant_3_10_10(x):
|
||||
"""Pass an array x with size < 100 to trigger a no fuse condition."""
|
||||
x = x.astype(numpy.float64)
|
||||
return (x + numpy.ones((3, 10, 10))).astype(numpy.int32)
|
||||
|
||||
|
||||
def no_fuse_dot(x):
|
||||
"""No fuse dot"""
|
||||
return numpy.dot(x, numpy.full((10,), 1.33, dtype=numpy.float64)).astype(numpy.int32)
|
||||
@@ -220,6 +226,21 @@ return(%3)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_explicitely_reshape",
|
||||
),
|
||||
pytest.param(
|
||||
no_fuse_big_constant_3_10_10,
|
||||
False,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10, 10))},
|
||||
"""The following subgraph is not fusable:
|
||||
%0 = Constant([[[1. 1. 1 ... . 1. 1.]]]) # ClearTensor<Float<64 bits>, shape=(3, 10, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this constant node has a bigger shape (3, 10, 10) than the subgraph's input: (10, 10)
|
||||
%1 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10, 10)
|
||||
%2 = astype(float64)(%1) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
|
||||
%3 = Add(%2, %0) # EncryptedTensor<Float<64 bits>, shape=(3, 10, 10)>
|
||||
%4 = astype(int32)(%3) # EncryptedTensor<Integer<signed, 32 bits>, shape=(3, 10, 10)>
|
||||
return(%4)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_big_constant_3_10_10",
|
||||
),
|
||||
pytest.param(
|
||||
simple_fuse_not_output,
|
||||
True,
|
||||
|
||||
Reference in New Issue
Block a user