feat: add support for np.squeeze

This commit is contained in:
Umut
2023-01-23 16:03:10 +01:00
parent eb11f3119d
commit 8564930fc8
5 changed files with 65 additions and 0 deletions

View File

@@ -151,6 +151,9 @@ class GraphConverter:
if not inputs[0].is_encrypted:
return "only encrypted reshape is supported"
elif name == "squeeze":
assert_that(len(inputs) == 1)
elif name == "subtract":
assert_that(len(inputs) == 2)

View File

@@ -170,6 +170,7 @@ class NodeConverter:
"negative": self._convert_neg,
"ones": self._convert_ones,
"reshape": self._convert_reshape,
"squeeze": self._convert_squeeze,
"subtract": self._convert_sub,
"sum": self._convert_sum,
"transpose": self._convert_transpose,
@@ -884,6 +885,27 @@ class NodeConverter:
),
).result
def _convert_squeeze(self) -> OpResult:
"""
Convert "squeeze" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
# because of the tracing logic, we have the correct output shape
# if the output shape is (), it means (1, 1, ..., 1, 1) is squeezed
# and the result is a scalar, so we need to do indexing, not reshape
if self.node.output.shape == ():
assert all(size == 1 for size in self.node.inputs[0].shape)
self.node.properties["kwargs"]["index"] = (0,) * self.node.inputs[0].ndim
return self._convert_static_indexing()
# otherwise, a simple reshape would work as we already have the correct shape
return self._convert_reshape()
def _convert_sub(self) -> OpResult:
"""
Convert "subtract" node to its corresponding MLIR representation.

View File

@@ -381,6 +381,7 @@ class Node:
"negative",
"ones",
"reshape",
"squeeze",
"subtract",
"sum",
"transpose",

View File

@@ -287,6 +287,7 @@ class Tracer:
np.spacing,
np.sqrt,
np.square,
np.squeeze,
np.subtract,
np.sum,
np.tan,
@@ -320,6 +321,9 @@ class Tracer:
np.round_: {
"decimals",
},
np.squeeze: {
"axis",
},
np.sum: {
"axis",
"keepdims",

View File

@@ -642,6 +642,41 @@ def deterministic_unary_function(x):
},
id="x ** 3",
),
pytest.param(
lambda x: np.squeeze(x),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 2, 1, 3, 1, 4)},
},
id="np.squeeze(x)",
),
pytest.param(
lambda x: np.squeeze(x, axis=2),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 2, 1, 3, 1, 4)},
},
id="np.squeeze(x, axis=2)",
),
pytest.param(
lambda x: np.squeeze(x, axis=(0, 4)),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 2, 1, 3, 1, 4)},
},
id="np.squeeze(x, axis=(0, 4))",
),
pytest.param(
lambda x: np.squeeze(x),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 1, 1)},
},
id="np.squeeze(x) where x.shape == (1, 1, 1)",
),
pytest.param(
lambda x: np.squeeze(x, axis=1),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 1, 1)},
},
id="np.squeeze(x, axis=1) where x.shape == (1, 1, 1)",
),
],
)
def test_others(function, parameters, helpers):