feat: add support for numpy.transpose compilation

This commit is contained in:
youben11
2022-02-07 10:43:04 +01:00
committed by Umut
parent 90f1a0b470
commit a2955d29ea
3 changed files with 17 additions and 19 deletions

View File

@@ -126,6 +126,8 @@ class IntermediateNodeConverter:
result = self.convert_sum()
elif self.node.op_name == "concat":
result = self.convert_concat()
elif self.node.op_name == "transpose":
result = self.convert_transpose()
else:
result = self.convert_generic_function(additional_conversion_info)
@@ -855,3 +857,17 @@ class IntermediateNodeConverter:
ArrayAttr.get([IntegerAttr.get(IntegerType.get_signless(64), axis) for axis in axes]),
BoolAttr.get(keep_dims),
).result
def convert_transpose(self) -> OpResult:
"""Convert a Transpose node to its corresponding MLIR representation.
Returns:
str: textual MLIR representation corresponding to self.node
"""
assert_true(len(self.node.inputs) == 1)
assert_true(len(self.node.outputs) == 1)
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
preds = self.preds
return fhelinalg.TransposeOp(resulting_type, *preds).result

View File

@@ -89,7 +89,7 @@ def check_node_compatibility_with_mlir(
== 1
)
else:
if node.op_name not in ["flatten", "reshape", "sum", "concat"]:
if node.op_name not in ["flatten", "reshape", "sum", "concat", "transpose"]:
return f"{node.op_name} is not supported for the time being"
elif isinstance(node, intermediate.Dot): # constraints for dot product

View File

@@ -2079,24 +2079,6 @@ return %9
""".strip() # noqa: E501
),
),
pytest.param(
lambda x: numpy.transpose(x),
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
[numpy.random.randint(0, 2 ** 3, size=(3, 2)) for i in range(10)],
RuntimeError,
(
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
%1 = transpose(%0) # EncryptedTensor<uint3, shape=(2, 3)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ transpose is not supported for the time being
return %1
""".strip() # noqa: E501
),
),
pytest.param(
lambda x: numpy.ravel(x),
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},