feat: management of reshape

refs #615
closes #786
This commit is contained in:
Benoit Chevallier-Mames
2021-10-29 17:53:30 +02:00
committed by Benoit Chevallier
parent 39c16038c7
commit 086dba4194
4 changed files with 148 additions and 9 deletions

View File

@@ -1,7 +1,7 @@
"""numpy tracing utilities."""
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union, cast
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
import numpy
from numpy.typing import DTypeLike
@@ -69,7 +69,15 @@ class NPTracer(BaseTracer):
(len(kwargs) == 0),
f"**kwargs are currently not supported for numpy functions, func: {func}",
)
sanitized_args = [self._sanitize(arg) for arg in args]
# Fixme: Special case to be removed once #772 is done
if func is not numpy.reshape:
sanitized_args = [self._sanitize(arg) for arg in args]
else:
# In numpy.reshape, the second argument is the new shape
sanitized_args = [self._sanitize(args[0]), args[1]]
return tracing_func(self, sanitized_args[0], sanitized_args[1], **kwargs)
return tracing_func(self, *sanitized_args, **kwargs)
def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer":
@@ -267,7 +275,7 @@ class NPTracer(BaseTracer):
)
return output_tracer
def transpose(self, *args: "NPTracer", **_kwargs) -> "NPTracer":
def transpose(self, *args: "NPTracer", **kwargs) -> "NPTracer":
"""Trace numpy.transpose.
Returns:
@@ -285,7 +293,7 @@ class NPTracer(BaseTracer):
arbitrary_func=numpy.transpose,
output_dtype=first_arg_output.dtype,
output_shape=first_arg_output.shape[::-1],
op_kwargs=deepcopy(_kwargs),
op_kwargs=deepcopy(kwargs),
op_name="np.transpose",
)
output_tracer = self.__class__(
@@ -295,7 +303,7 @@ class NPTracer(BaseTracer):
)
return output_tracer
def ravel(self, *args: "NPTracer", **_kwargs) -> "NPTracer":
def ravel(self, *args: "NPTracer", **kwargs) -> "NPTracer":
"""Trace numpy.ravel.
Returns:
@@ -313,7 +321,7 @@ class NPTracer(BaseTracer):
arbitrary_func=numpy.ravel,
output_dtype=first_arg_output.dtype,
output_shape=(numpy.product(first_arg_output.shape),),
op_kwargs=deepcopy(_kwargs),
op_kwargs=deepcopy(kwargs),
op_name="np.ravel",
)
output_tracer = self.__class__(
@@ -323,6 +331,54 @@ class NPTracer(BaseTracer):
)
return output_tracer
def reshape(self, arg0: "NPTracer", arg1: Tuple[Any, ...], **kwargs) -> "NPTracer":
"""Trace numpy.reshape.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
# FIXME: #772, restore reshape(self, *args, **kwargs) signature when possible, with mypy
# types
# FIXME: #772, restore
# assert_true((num_args := len(args)) == 2, f"reshape expect 2 input got {num_args}")
# when possible
assert_true((num_kwargs := len(kwargs)) == 0, f"reshape expect 0 kwargs got {num_kwargs}")
first_arg_output = arg0.output
assert_true(isinstance(first_arg_output, TensorValue))
first_arg_output = cast(TensorValue, first_arg_output)
newshape = deepcopy(arg1)
if isinstance(newshape, int):
# Make numpy.reshape(x, (170)) and numpy.reshape(x, 170) work, while classical form is
# numpy.reshape(x, (170,))
newshape = (newshape,)
# Check shape compatibility
assert_true(
numpy.product(newshape) == numpy.product(first_arg_output.shape),
f"shapes are not compatible (old shape {first_arg_output.shape}, new shape {newshape})",
)
traced_computation = GenericFunction(
input_base_value=first_arg_output,
arbitrary_func=numpy.reshape,
output_dtype=first_arg_output.dtype,
output_shape=newshape,
op_kwargs={"newshape": newshape},
op_name="np.reshape",
)
output_tracer = self.__class__(
[arg0],
traced_computation=traced_computation,
output_idx=0,
)
return output_tracer
def __getitem__(self, item):
if isinstance(item, tuple):
item = tuple(process_indexing_element(indexing_element) for indexing_element in item)
@@ -436,6 +492,7 @@ class NPTracer(BaseTracer):
FUNC_ROUTING: Dict[Callable, Callable] = {
numpy.dot: dot,
numpy.transpose: transpose,
numpy.reshape: reshape,
numpy.ravel: ravel,
}

View File

@@ -197,6 +197,17 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
numpy.arange(15),
id="GenericFunction, x ravel",
),
pytest.param(
ir.GenericFunction(
EncryptedTensor(Integer(32, False), shape=(3, 5)),
lambda x: numpy.reshape(x, (5, 3)),
Integer(32, False),
output_shape=(5, 3),
),
[numpy.arange(15).reshape(3, 5)],
numpy.arange(15).reshape(5, 3),
id="GenericFunction, x reshape",
),
],
)
def test_evaluate(

View File

@@ -215,6 +215,7 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
)
# pylint: disable=line-too-long
@pytest.mark.parametrize(
"lambda_f,params,ref_graph_str",
[
@@ -223,14 +224,55 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
{
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)),
},
"%0 = x\n%1 = np.transpose(%0)\nreturn(%1)\n",
"""
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 5)>
%1 = np.transpose(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(5, 3)>
return(%1)
""".lstrip(), # noqa: E501
),
(
lambda x: numpy.ravel(x),
{
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)),
},
"%0 = x\n%1 = np.ravel(%0)\nreturn(%1)\n",
"""
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 5)>
%1 = np.ravel(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(15,)>
return(%1)
""".lstrip(), # noqa: E501
),
(
lambda x: numpy.reshape(x, (5, 3)),
{
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)),
},
"""
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 5)>
%1 = np.reshape(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(5, 3)>
return(%1)
""".lstrip(), # noqa: E501
),
(
lambda x: numpy.reshape(x, (170,)),
{
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(17, 10)),
},
"""
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(17, 10)>
%1 = np.reshape(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(170,)>
return(%1)
""".lstrip(), # noqa: E501
),
(
lambda x: numpy.reshape(x, (170)),
{
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(17, 10)),
},
"""
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(17, 10)>
%1 = np.reshape(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(170,)>
return(%1)
""".lstrip(), # noqa: E501
),
],
)
@@ -240,7 +282,7 @@ def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_
draw_graph(graph, show=False)
str_of_the_graph = get_printable_graph(graph)
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot \n{str_of_the_graph}"
@@ -249,6 +291,9 @@ def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_
)
# pylint: enable=line-too-long
# Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b
# returning 23b), since they are replaced later by the real bitwidths computed on the
# inputset

View File

@@ -636,6 +636,13 @@ def test_nptracer_unsupported_operands(operation, tracer):
(numpy.arange(4).reshape(2, 2), numpy.array([0, 1, 2, 3])),
],
),
(
lambda x: numpy.reshape(x, (5, 3)) + 42,
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
[
(numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(5, 3)),
],
),
],
)
def test_tracing_generic_function(function_to_trace, input_value, input_and_expected_output_tuples):
@@ -649,3 +656,22 @@ def test_tracing_generic_function(function_to_trace, input_value, input_and_expe
evaluated_output = node_results[output_node]
assert isinstance(evaluated_output, type(expected_output))
assert numpy.array_equal(expected_output, evaluated_output)
@pytest.mark.parametrize(
"lambda_f,params",
[
(
lambda x: numpy.reshape(x, (5, 3)),
{
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(7, 5)),
},
),
],
)
def test_errors_with_generic_function(lambda_f, params):
"Test some errors with generic function"
with pytest.raises(AssertionError) as excinfo:
tracing.trace_numpy_function(lambda_f, params)
assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)