mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat: support axes kwarg for np.transpose
This commit is contained in:
@@ -1061,9 +1061,17 @@ class NodeConverter:
|
||||
"""
|
||||
|
||||
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
|
||||
preds = self.preds
|
||||
pred = self.preds[0]
|
||||
|
||||
return fhelinalg.TransposeOp(resulting_type, *preds).result
|
||||
axes = self.node.properties["kwargs"].get("axes", [])
|
||||
|
||||
return fhelinalg.TransposeOp(
|
||||
resulting_type,
|
||||
pred,
|
||||
axes=ArrayAttr.get(
|
||||
[IntegerAttr.get(IntegerType.get_signless(64), axis) for axis in axes]
|
||||
),
|
||||
).result
|
||||
|
||||
def _convert_zeros(self) -> OpResult:
|
||||
"""
|
||||
|
||||
@@ -313,6 +313,9 @@ class Tracer:
|
||||
"axis",
|
||||
"keepdims",
|
||||
},
|
||||
np.transpose: {
|
||||
"axes",
|
||||
},
|
||||
np.zeros_like: {
|
||||
"dtype",
|
||||
},
|
||||
|
||||
@@ -29,6 +29,36 @@ import concrete.numpy as cnp
|
||||
"x": {"shape": (3, 2), "range": [0, 10], "status": "encrypted"},
|
||||
},
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.transpose((1, 0, 2)),
|
||||
{
|
||||
"x": {"shape": (2, 3, 4), "range": [0, 10], "status": "encrypted"},
|
||||
},
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.transpose((1, 2, 0)),
|
||||
{
|
||||
"x": {"shape": (2, 3, 4), "range": [0, 10], "status": "encrypted"},
|
||||
},
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.transpose((0, 2, 1)),
|
||||
{
|
||||
"x": {"shape": (2, 3, 4), "range": [0, 10], "status": "encrypted"},
|
||||
},
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.transpose((2, 0, 1)),
|
||||
{
|
||||
"x": {"shape": (2, 3, 4), "range": [0, 10], "status": "encrypted"},
|
||||
},
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.transpose(x, (3, 0, 2, 1)),
|
||||
{
|
||||
"x": {"shape": (2, 3, 4, 5), "range": [0, 10], "status": "encrypted"},
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_transpose(function, parameters, helpers):
|
||||
|
||||
@@ -27,16 +27,10 @@ from concrete.numpy.values import EncryptedTensor
|
||||
"Function 'np.sum' is not supported with kwarg 'initial'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.transpose(x, (1, 0, 2)),
|
||||
lambda x: np.absolute(x, where=False),
|
||||
{"x": EncryptedTensor(UnsignedInteger(7), shape=(1, 2, 3))},
|
||||
RuntimeError,
|
||||
"Function 'np.transpose' is not supported with kwarg 'axes'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.transpose((1, 0, 2)),
|
||||
{"x": EncryptedTensor(UnsignedInteger(7), shape=(1, 2, 3))},
|
||||
RuntimeError,
|
||||
"Function 'np.transpose' is not supported with kwarg 'axes'",
|
||||
"Function 'np.absolute' is not supported with kwarg 'where'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.multiply.outer(x, [1, 2, 3]),
|
||||
|
||||
Reference in New Issue
Block a user