mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: add support for numpy.transpose compilation
This commit is contained in:
@@ -126,6 +126,8 @@ class IntermediateNodeConverter:
|
||||
result = self.convert_sum()
|
||||
elif self.node.op_name == "concat":
|
||||
result = self.convert_concat()
|
||||
elif self.node.op_name == "transpose":
|
||||
result = self.convert_transpose()
|
||||
else:
|
||||
result = self.convert_generic_function(additional_conversion_info)
|
||||
|
||||
@@ -855,3 +857,17 @@ class IntermediateNodeConverter:
|
||||
ArrayAttr.get([IntegerAttr.get(IntegerType.get_signless(64), axis) for axis in axes]),
|
||||
BoolAttr.get(keep_dims),
|
||||
).result
|
||||
|
||||
def convert_transpose(self) -> OpResult:
|
||||
"""Convert a Transpose node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
assert_true(len(self.node.inputs) == 1)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
|
||||
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
preds = self.preds
|
||||
return fhelinalg.TransposeOp(resulting_type, *preds).result
|
||||
|
||||
@@ -89,7 +89,7 @@ def check_node_compatibility_with_mlir(
|
||||
== 1
|
||||
)
|
||||
else:
|
||||
if node.op_name not in ["flatten", "reshape", "sum", "concat"]:
|
||||
if node.op_name not in ["flatten", "reshape", "sum", "concat", "transpose"]:
|
||||
return f"{node.op_name} is not supported for the time being"
|
||||
|
||||
elif isinstance(node, intermediate.Dot): # constraints for dot product
|
||||
|
||||
@@ -2079,24 +2079,6 @@ return %9
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.transpose(x),
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
|
||||
[numpy.random.randint(0, 2 ** 3, size=(3, 2)) for i in range(10)],
|
||||
RuntimeError,
|
||||
(
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
|
||||
%1 = transpose(%0) # EncryptedTensor<uint3, shape=(2, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ transpose is not supported for the time being
|
||||
return %1
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.ravel(x),
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
|
||||
|
||||
Reference in New Issue
Block a user