mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: add support for np.squeeze
This commit is contained in:
@@ -151,6 +151,9 @@ class GraphConverter:
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted reshape is supported"
|
||||
|
||||
elif name == "squeeze":
|
||||
assert_that(len(inputs) == 1)
|
||||
|
||||
elif name == "subtract":
|
||||
assert_that(len(inputs) == 2)
|
||||
|
||||
|
||||
@@ -170,6 +170,7 @@ class NodeConverter:
|
||||
"negative": self._convert_neg,
|
||||
"ones": self._convert_ones,
|
||||
"reshape": self._convert_reshape,
|
||||
"squeeze": self._convert_squeeze,
|
||||
"subtract": self._convert_sub,
|
||||
"sum": self._convert_sum,
|
||||
"transpose": self._convert_transpose,
|
||||
@@ -884,6 +885,27 @@ class NodeConverter:
|
||||
),
|
||||
).result
|
||||
|
||||
def _convert_squeeze(self) -> OpResult:
|
||||
"""
|
||||
Convert "squeeze" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
# because of the tracing logic, we have the correct output shape
|
||||
|
||||
# if the output shape is (), it means (1, 1, ..., 1, 1) is squeezed
|
||||
# and the result is a scalar, so we need to do indexing, not reshape
|
||||
if self.node.output.shape == ():
|
||||
assert all(size == 1 for size in self.node.inputs[0].shape)
|
||||
self.node.properties["kwargs"]["index"] = (0,) * self.node.inputs[0].ndim
|
||||
return self._convert_static_indexing()
|
||||
|
||||
# otherwise, a simple reshape would work as we already have the correct shape
|
||||
return self._convert_reshape()
|
||||
|
||||
def _convert_sub(self) -> OpResult:
|
||||
"""
|
||||
Convert "subtract" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -381,6 +381,7 @@ class Node:
|
||||
"negative",
|
||||
"ones",
|
||||
"reshape",
|
||||
"squeeze",
|
||||
"subtract",
|
||||
"sum",
|
||||
"transpose",
|
||||
|
||||
@@ -287,6 +287,7 @@ class Tracer:
|
||||
np.spacing,
|
||||
np.sqrt,
|
||||
np.square,
|
||||
np.squeeze,
|
||||
np.subtract,
|
||||
np.sum,
|
||||
np.tan,
|
||||
@@ -320,6 +321,9 @@ class Tracer:
|
||||
np.round_: {
|
||||
"decimals",
|
||||
},
|
||||
np.squeeze: {
|
||||
"axis",
|
||||
},
|
||||
np.sum: {
|
||||
"axis",
|
||||
"keepdims",
|
||||
|
||||
@@ -642,6 +642,41 @@ def deterministic_unary_function(x):
|
||||
},
|
||||
id="x ** 3",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.squeeze(x),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 2, 1, 3, 1, 4)},
|
||||
},
|
||||
id="np.squeeze(x)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.squeeze(x, axis=2),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 2, 1, 3, 1, 4)},
|
||||
},
|
||||
id="np.squeeze(x, axis=2)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.squeeze(x, axis=(0, 4)),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 2, 1, 3, 1, 4)},
|
||||
},
|
||||
id="np.squeeze(x, axis=(0, 4))",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.squeeze(x),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 1, 1)},
|
||||
},
|
||||
id="np.squeeze(x) where x.shape == (1, 1, 1)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.squeeze(x, axis=1),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 1, 1)},
|
||||
},
|
||||
id="np.squeeze(x, axis=1) where x.shape == (1, 1, 1)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_others(function, parameters, helpers):
|
||||
|
||||
Reference in New Issue
Block a user