feat(tracing): support x.reshape as well as np.reshape(x, )

closes #701
This commit is contained in:
Benoit Chevallier-Mames
2021-11-02 18:03:41 +01:00
committed by Benoit Chevallier
parent b1df5c0fbe
commit 50a6b06c37
3 changed files with 93 additions and 6 deletions

View File

@@ -251,7 +251,7 @@ class NPTracer(BaseTracer):
)
return output_tracer
def dot(self, *args: "NPTracer", **_kwargs) -> "NPTracer":
def numpy_dot(self, *args: "NPTracer", **_kwargs) -> "NPTracer":
"""Trace numpy.dot.
Returns:
@@ -275,7 +275,27 @@ class NPTracer(BaseTracer):
)
return output_tracer
def dot(self, *args: "NPTracer", **kwargs) -> "NPTracer":
"""Trace x.dot.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
assert len(args) == 1
arg0 = self._sanitize(args[0])
assert_true(isinstance(arg0, NPTracer))
arg0 = cast(NPTracer, arg0)
return self.numpy_dot(self, arg0, **kwargs)
def transpose(self, *args: "NPTracer", **kwargs) -> "NPTracer":
"""Trace x.transpose.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self.numpy_transpose(self, *args, **kwargs)
def numpy_transpose(self, *args: "NPTracer", **kwargs) -> "NPTracer":
"""Trace numpy.transpose.
Returns:
@@ -304,6 +324,14 @@ class NPTracer(BaseTracer):
return output_tracer
def ravel(self, *args: "NPTracer", **kwargs) -> "NPTracer":
"""Trace x.ravel.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self.numpy_ravel(self, *args, **kwargs)
def numpy_ravel(self, *args: "NPTracer", **kwargs) -> "NPTracer":
"""Trace numpy.ravel.
Returns:
@@ -331,7 +359,15 @@ class NPTracer(BaseTracer):
)
return output_tracer
def reshape(self, arg0: "NPTracer", arg1: Tuple[Any, ...], **kwargs) -> "NPTracer":
def reshape(self, arg1: Tuple[Any, ...], **kwargs) -> "NPTracer":
"""Trace x.reshape.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self.numpy_reshape(self, arg1, **kwargs)
def numpy_reshape(self, arg0: "NPTracer", arg1: Tuple[Any, ...], **kwargs) -> "NPTracer":
"""Trace numpy.reshape.
Returns:
@@ -388,6 +424,11 @@ class NPTracer(BaseTracer):
return BaseTracer.__getitem__(self, item)
def __matmul__(self, other):
"""Trace numpy.matmul."""
return self.__array_ufunc__(numpy.matmul, "__call__", self, other)
def matmul(self, other):
"""Trace x.matmul."""
return self.__array_ufunc__(numpy.matmul, "__call__", self, other)
# Supported functions are either univariate or bivariate for which one of the two
@@ -490,10 +531,10 @@ class NPTracer(BaseTracer):
UFUNC_ROUTING: Dict[numpy.ufunc, Callable] = {}
FUNC_ROUTING: Dict[Callable, Callable] = {
numpy.dot: dot,
numpy.transpose: transpose,
numpy.reshape: reshape,
numpy.ravel: ravel,
numpy.dot: numpy_dot,
numpy.transpose: numpy_transpose,
numpy.reshape: numpy_reshape,
numpy.ravel: numpy_ravel,
}

View File

@@ -938,6 +938,20 @@ return(%7)
"return(%2)\n"
),
),
pytest.param(
lambda x: x.matmul(numpy.ones(shape=(2, 3), dtype=numpy.uint32)),
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>\n" # noqa: E501
"%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor<Integer<unsigned, 1 bits>, shape=(2, 3)>\n" # noqa: E501
"%2 = MatMul(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(3, 3)>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501
"return(%2)\n"
),
),
pytest.param(
multi_lut,
{"x": EncryptedTensor(UnsignedInteger(2), shape=(3, 2))},

View File

@@ -557,6 +557,14 @@ def test_trace_numpy_ufuncs_no_kwargs_no_extra_args():
ir.Dot,
EncryptedScalar(Integer(64, True)),
),
pytest.param(
lambda x: x.dot(numpy.array([1, 2, 3, 4, 5], dtype=numpy.int64)),
{
"x": EncryptedTensor(Integer(64, is_signed=True), shape=(5,)),
},
ir.Dot,
EncryptedScalar(Integer(64, True)),
),
],
)
def test_trace_numpy_dot(function_to_trace, inputs, expected_output_node, expected_output_value):
@@ -613,6 +621,7 @@ def test_nptracer_unsupported_operands(operation, tracer):
@pytest.mark.parametrize(
"function_to_trace,input_value,input_and_expected_output_tuples",
[
# Indirect calls, like numpy.function(x, ...)
(
lambda x: numpy.transpose(x),
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
@@ -643,6 +652,29 @@ def test_nptracer_unsupported_operands(operation, tracer):
(numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(5, 3)),
],
),
# Direct calls, like x.function(...)
(
lambda x: x.transpose() + 42,
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
[
(numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(3, 5).transpose()),
],
),
(
lambda x: x.ravel(),
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
[
(numpy.arange(4), numpy.array([0, 1, 2, 3])),
(numpy.arange(4).reshape(2, 2), numpy.array([0, 1, 2, 3])),
],
),
(
lambda x: x.reshape((5, 3)) + 42,
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
[
(numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(5, 3)),
],
),
],
)
def test_tracing_generic_function(function_to_trace, input_value, input_and_expected_output_tuples):