feat: implement transpose

This commit is contained in:
Umut
2022-04-04 14:13:03 +02:00
parent 79685ed7dc
commit ba33d42762
8 changed files with 126 additions and 3 deletions

View File

@@ -45,7 +45,7 @@ class GraphConverter:
None if node is convertible to MLIR, the reason for inconvertibility otherwise
"""
# pylint: disable=too-many-branches,too-many-return-statements
# pylint: disable=too-many-branches,too-many-return-statements,too-many-statements
inputs = node.inputs
output = node.output
@@ -121,6 +121,11 @@ class GraphConverter:
if not inputs[0].is_encrypted:
return "only encrypted sum is supported"
elif name == "transpose":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:
return "only encrypted transpose is supported"
else:
variable_input_indices = [
idx
@@ -135,7 +140,7 @@ class GraphConverter:
return None
# pylint: enable=too-many-branches,too-many-return-statements
# pylint: enable=too-many-branches,too-many-return-statements,too-many-statements
@staticmethod
def _check_graph_convertibility(graph: Graph):

View File

@@ -166,6 +166,9 @@ class NodeConverter:
elif name == "sum":
result = self.convert_sum()
elif name == "transpose":
result = self.convert_transpose()
else:
result = self.convert_tlu()
@@ -798,3 +801,17 @@ class NodeConverter:
result = fhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result
return result
def convert_transpose(self) -> OpResult:
"""
Convert "transpose" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
preds = self.preds
return fhelinalg.TransposeOp(resulting_type, *preds).result

View File

@@ -313,4 +313,5 @@ class Node:
"reshape",
"subtract",
"sum",
"transpose",
]

View File

@@ -4,7 +4,7 @@ Declaration of `Tracer` class.
import inspect
from copy import deepcopy
from typing import Any, Callable, Dict, List, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import networkx as nx
import numpy as np
@@ -230,6 +230,7 @@ class Tracer:
np.sum,
np.tan,
np.tanh,
np.transpose,
np.true_divide,
np.trunc,
}
@@ -345,6 +346,10 @@ class Tracer:
sanitized_args = [self._sanitize(args[0])]
if len(args) > 1:
kwargs["newshape"] = args[1]
elif func is np.transpose:
sanitized_args = [self._sanitize(args[0])]
if len(args) > 1:
kwargs["axes"] = args[1]
else:
sanitized_args = [self._sanitize(arg) for arg in args]
@@ -516,6 +521,16 @@ class Tracer:
return Tracer._trace_numpy_operation(np.reshape, self, newshape=newshape)
def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> "Tracer":
"""
Trace numpy.ndarray.transpose().
"""
if axes is None:
return Tracer._trace_numpy_operation(np.transpose, self)
return Tracer._trace_numpy_operation(np.transpose, self, axes=axes)
def __getitem__(
self,
index: Union[int, np.integer, slice, Tuple[Union[int, np.integer, slice], ...]],
@@ -583,3 +598,11 @@ class Tracer:
"""
return self.output.size
@property
def T(self) -> "Tracer": # pylint: disable=invalid-name
"""
Trace numpy.ndarray.T.
"""
return Tracer._trace_numpy_operation(np.transpose, self)

View File

@@ -89,6 +89,7 @@ List of supported functions:
- sum
- tan
- tanh
- transpose
- true_divide
- trunc
<!--- gen_supported_ufuncs.py: inject supported operations [END] -->

View File

@@ -0,0 +1,48 @@
"""
Tests of execution of transpose operation.
"""
import numpy as np
import pytest
import concrete.numpy as cnp
@pytest.mark.parametrize(
"function,parameters",
[
pytest.param(
lambda x: np.transpose(x),
{
"x": {"shape": (3, 2), "range": [0, 10], "status": "encrypted"},
},
),
pytest.param(
lambda x: x.transpose(),
{
"x": {"shape": (3, 2), "range": [0, 10], "status": "encrypted"},
},
),
pytest.param(
lambda x: x.T,
{
"x": {"shape": (3, 2), "range": [0, 10], "status": "encrypted"},
},
),
],
)
def test_transpose(function, parameters, helpers):
"""
Test transpose.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)

View File

@@ -309,6 +309,22 @@ Function you are trying to compile cannot be converted to MLIR:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only up to 8-bit integers are supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x: np.transpose(x),
{"x": "clear"},
[np.random.randint(0, 2, size=(3, 2)) for _ in range(100)],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # ClearTensor<uint1, shape=(3, 2)>
%1 = transpose(%0) # ClearTensor<uint1, shape=(2, 3)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted transpose is supported
return %1
""", # noqa: E501
),
],

View File

@@ -25,6 +25,18 @@ from concrete.numpy.values import EncryptedTensor
RuntimeError,
"Function 'np.sum' is not supported with kwarg 'initial'",
),
pytest.param(
lambda x: np.transpose(x, (1, 0, 2)),
{"x": EncryptedTensor(UnsignedInteger(7), shape=(1, 2, 3))},
RuntimeError,
"Function 'np.transpose' is not supported with kwarg 'axes'",
),
pytest.param(
lambda x: x.transpose((1, 0, 2)),
{"x": EncryptedTensor(UnsignedInteger(7), shape=(1, 2, 3))},
RuntimeError,
"Function 'np.transpose' is not supported with kwarg 'axes'",
),
pytest.param(
lambda x: np.multiply.outer(x, [1, 2, 3]),
{"x": EncryptedTensor(UnsignedInteger(7), shape=(4,))},