diff --git a/concrete/common/mlir/node_converter.py b/concrete/common/mlir/node_converter.py index e6f5462e0..52e280254 100644 --- a/concrete/common/mlir/node_converter.py +++ b/concrete/common/mlir/node_converter.py @@ -126,6 +126,8 @@ class IntermediateNodeConverter: result = self.convert_sum() elif self.node.op_name == "concat": result = self.convert_concat() + elif self.node.op_name == "transpose": + result = self.convert_transpose() else: result = self.convert_generic_function(additional_conversion_info) @@ -855,3 +857,17 @@ class IntermediateNodeConverter: ArrayAttr.get([IntegerAttr.get(IntegerType.get_signless(64), axis) for axis in axes]), BoolAttr.get(keep_dims), ).result + + def convert_transpose(self) -> OpResult: + """Convert a Transpose node to its corresponding MLIR representation. + + Returns: + str: textual MLIR representation corresponding to self.node + """ + + assert_true(len(self.node.inputs) == 1) + assert_true(len(self.node.outputs) == 1) + + resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) + preds = self.preds + return fhelinalg.TransposeOp(resulting_type, *preds).result diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index 98731edbd..e8412de8a 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -89,7 +89,7 @@ def check_node_compatibility_with_mlir( == 1 ) else: - if node.op_name not in ["flatten", "reshape", "sum", "concat"]: + if node.op_name not in ["flatten", "reshape", "sum", "concat", "transpose"]: return f"{node.op_name} is not supported for the time being" elif isinstance(node, intermediate.Dot): # constraints for dot product diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 0bc93878d..941766c05 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -2079,24 +2079,6 @@ return %9 """.strip() # noqa: E501 ), ), - pytest.param( - lambda x: numpy.transpose(x), - {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))}, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for i in range(10)], - RuntimeError, - ( - """ - -function you are trying to compile isn't supported for MLIR lowering - -%0 = x # EncryptedTensor -%1 = transpose(%0) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ transpose is not supported for the time being -return %1 - - """.strip() # noqa: E501 - ), - ), pytest.param( lambda x: numpy.ravel(x), {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},