mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: implement transpose
This commit is contained in:
@@ -45,7 +45,7 @@ class GraphConverter:
|
||||
None if node is convertible to MLIR, the reason for inconvertibility otherwise
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-branches,too-many-return-statements
|
||||
# pylint: disable=too-many-branches,too-many-return-statements,too-many-statements
|
||||
|
||||
inputs = node.inputs
|
||||
output = node.output
|
||||
@@ -121,6 +121,11 @@ class GraphConverter:
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted sum is supported"
|
||||
|
||||
elif name == "transpose":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only encrypted transpose is supported"
|
||||
|
||||
else:
|
||||
variable_input_indices = [
|
||||
idx
|
||||
@@ -135,7 +140,7 @@ class GraphConverter:
|
||||
|
||||
return None
|
||||
|
||||
# pylint: enable=too-many-branches,too-many-return-statements
|
||||
# pylint: enable=too-many-branches,too-many-return-statements,too-many-statements
|
||||
|
||||
@staticmethod
|
||||
def _check_graph_convertibility(graph: Graph):
|
||||
|
||||
@@ -166,6 +166,9 @@ class NodeConverter:
|
||||
elif name == "sum":
|
||||
result = self.convert_sum()
|
||||
|
||||
elif name == "transpose":
|
||||
result = self.convert_transpose()
|
||||
|
||||
else:
|
||||
result = self.convert_tlu()
|
||||
|
||||
@@ -798,3 +801,17 @@ class NodeConverter:
|
||||
result = fhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result
|
||||
|
||||
return result
|
||||
|
||||
def convert_transpose(self) -> OpResult:
|
||||
"""
|
||||
Convert "transpose" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
|
||||
preds = self.preds
|
||||
|
||||
return fhelinalg.TransposeOp(resulting_type, *preds).result
|
||||
|
||||
@@ -313,4 +313,5 @@ class Node:
|
||||
"reshape",
|
||||
"subtract",
|
||||
"sum",
|
||||
"transpose",
|
||||
]
|
||||
|
||||
@@ -4,7 +4,7 @@ Declaration of `Tracer` class.
|
||||
|
||||
import inspect
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, List, Set, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@@ -230,6 +230,7 @@ class Tracer:
|
||||
np.sum,
|
||||
np.tan,
|
||||
np.tanh,
|
||||
np.transpose,
|
||||
np.true_divide,
|
||||
np.trunc,
|
||||
}
|
||||
@@ -345,6 +346,10 @@ class Tracer:
|
||||
sanitized_args = [self._sanitize(args[0])]
|
||||
if len(args) > 1:
|
||||
kwargs["newshape"] = args[1]
|
||||
elif func is np.transpose:
|
||||
sanitized_args = [self._sanitize(args[0])]
|
||||
if len(args) > 1:
|
||||
kwargs["axes"] = args[1]
|
||||
else:
|
||||
sanitized_args = [self._sanitize(arg) for arg in args]
|
||||
|
||||
@@ -516,6 +521,16 @@ class Tracer:
|
||||
|
||||
return Tracer._trace_numpy_operation(np.reshape, self, newshape=newshape)
|
||||
|
||||
def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> "Tracer":
|
||||
"""
|
||||
Trace numpy.ndarray.transpose().
|
||||
"""
|
||||
|
||||
if axes is None:
|
||||
return Tracer._trace_numpy_operation(np.transpose, self)
|
||||
|
||||
return Tracer._trace_numpy_operation(np.transpose, self, axes=axes)
|
||||
|
||||
def __getitem__(
|
||||
self,
|
||||
index: Union[int, np.integer, slice, Tuple[Union[int, np.integer, slice], ...]],
|
||||
@@ -583,3 +598,11 @@ class Tracer:
|
||||
"""
|
||||
|
||||
return self.output.size
|
||||
|
||||
@property
|
||||
def T(self) -> "Tracer": # pylint: disable=invalid-name
|
||||
"""
|
||||
Trace numpy.ndarray.T.
|
||||
"""
|
||||
|
||||
return Tracer._trace_numpy_operation(np.transpose, self)
|
||||
|
||||
@@ -89,6 +89,7 @@ List of supported functions:
|
||||
- sum
|
||||
- tan
|
||||
- tanh
|
||||
- transpose
|
||||
- true_divide
|
||||
- trunc
|
||||
<!--- gen_supported_ufuncs.py: inject supported operations [END] -->
|
||||
|
||||
48
tests/execution/test_transpose.py
Normal file
48
tests/execution/test_transpose.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
Tests of execution of transpose operation.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: np.transpose(x),
|
||||
{
|
||||
"x": {"shape": (3, 2), "range": [0, 10], "status": "encrypted"},
|
||||
},
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.transpose(),
|
||||
{
|
||||
"x": {"shape": (3, 2), "range": [0, 10], "status": "encrypted"},
|
||||
},
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.T,
|
||||
{
|
||||
"x": {"shape": (3, 2), "range": [0, 10], "status": "encrypted"},
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_transpose(function, parameters, helpers):
|
||||
"""
|
||||
Test transpose.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
@@ -309,6 +309,22 @@ Function you are trying to compile cannot be converted to MLIR:
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only up to 8-bit integers are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.transpose(x),
|
||||
{"x": "clear"},
|
||||
[np.random.randint(0, 2, size=(3, 2)) for _ in range(100)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(3, 2)>
|
||||
%1 = transpose(%0) # ClearTensor<uint1, shape=(2, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted transpose is supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
],
|
||||
|
||||
@@ -25,6 +25,18 @@ from concrete.numpy.values import EncryptedTensor
|
||||
RuntimeError,
|
||||
"Function 'np.sum' is not supported with kwarg 'initial'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.transpose(x, (1, 0, 2)),
|
||||
{"x": EncryptedTensor(UnsignedInteger(7), shape=(1, 2, 3))},
|
||||
RuntimeError,
|
||||
"Function 'np.transpose' is not supported with kwarg 'axes'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.transpose((1, 0, 2)),
|
||||
{"x": EncryptedTensor(UnsignedInteger(7), shape=(1, 2, 3))},
|
||||
RuntimeError,
|
||||
"Function 'np.transpose' is not supported with kwarg 'axes'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.multiply.outer(x, [1, 2, 3]),
|
||||
{"x": EncryptedTensor(UnsignedInteger(7), shape=(4,))},
|
||||
|
||||
Reference in New Issue
Block a user