feat: support axes kwarg for np.transpose

This commit is contained in:
Umut
2022-10-18 10:19:46 +02:00
parent 821c61e1d1
commit a2624086a2
4 changed files with 45 additions and 10 deletions

View File

@@ -1061,9 +1061,17 @@ class NodeConverter:
"""
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
preds = self.preds
pred = self.preds[0]
return fhelinalg.TransposeOp(resulting_type, *preds).result
axes = self.node.properties["kwargs"].get("axes", [])
return fhelinalg.TransposeOp(
resulting_type,
pred,
axes=ArrayAttr.get(
[IntegerAttr.get(IntegerType.get_signless(64), axis) for axis in axes]
),
).result
def _convert_zeros(self) -> OpResult:
"""

View File

@@ -313,6 +313,9 @@ class Tracer:
"axis",
"keepdims",
},
np.transpose: {
"axes",
},
np.zeros_like: {
"dtype",
},

View File

@@ -29,6 +29,36 @@ import concrete.numpy as cnp
"x": {"shape": (3, 2), "range": [0, 10], "status": "encrypted"},
},
),
pytest.param(
lambda x: x.transpose((1, 0, 2)),
{
"x": {"shape": (2, 3, 4), "range": [0, 10], "status": "encrypted"},
},
),
pytest.param(
lambda x: x.transpose((1, 2, 0)),
{
"x": {"shape": (2, 3, 4), "range": [0, 10], "status": "encrypted"},
},
),
pytest.param(
lambda x: x.transpose((0, 2, 1)),
{
"x": {"shape": (2, 3, 4), "range": [0, 10], "status": "encrypted"},
},
),
pytest.param(
lambda x: x.transpose((2, 0, 1)),
{
"x": {"shape": (2, 3, 4), "range": [0, 10], "status": "encrypted"},
},
),
pytest.param(
lambda x: np.transpose(x, (3, 0, 2, 1)),
{
"x": {"shape": (2, 3, 4, 5), "range": [0, 10], "status": "encrypted"},
},
),
],
)
def test_transpose(function, parameters, helpers):

View File

@@ -27,16 +27,10 @@ from concrete.numpy.values import EncryptedTensor
"Function 'np.sum' is not supported with kwarg 'initial'",
),
pytest.param(
lambda x: np.transpose(x, (1, 0, 2)),
lambda x: np.absolute(x, where=False),
{"x": EncryptedTensor(UnsignedInteger(7), shape=(1, 2, 3))},
RuntimeError,
"Function 'np.transpose' is not supported with kwarg 'axes'",
),
pytest.param(
lambda x: x.transpose((1, 0, 2)),
{"x": EncryptedTensor(UnsignedInteger(7), shape=(1, 2, 3))},
RuntimeError,
"Function 'np.transpose' is not supported with kwarg 'axes'",
"Function 'np.absolute' is not supported with kwarg 'where'",
),
pytest.param(
lambda x: np.multiply.outer(x, [1, 2, 3]),