mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
committed by
Benoit Chevallier
parent
39c16038c7
commit
086dba4194
@@ -1,7 +1,7 @@
|
||||
"""numpy tracing utilities."""
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional, Union, cast
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
@@ -69,7 +69,15 @@ class NPTracer(BaseTracer):
|
||||
(len(kwargs) == 0),
|
||||
f"**kwargs are currently not supported for numpy functions, func: {func}",
|
||||
)
|
||||
sanitized_args = [self._sanitize(arg) for arg in args]
|
||||
|
||||
# Fixme: Special case to be removed once #772 is done
|
||||
if func is not numpy.reshape:
|
||||
sanitized_args = [self._sanitize(arg) for arg in args]
|
||||
else:
|
||||
# In numpy.reshape, the second argument is the new shape
|
||||
sanitized_args = [self._sanitize(args[0]), args[1]]
|
||||
return tracing_func(self, sanitized_args[0], sanitized_args[1], **kwargs)
|
||||
|
||||
return tracing_func(self, *sanitized_args, **kwargs)
|
||||
|
||||
def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer":
|
||||
@@ -267,7 +275,7 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def transpose(self, *args: "NPTracer", **_kwargs) -> "NPTracer":
|
||||
def transpose(self, *args: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Trace numpy.transpose.
|
||||
|
||||
Returns:
|
||||
@@ -285,7 +293,7 @@ class NPTracer(BaseTracer):
|
||||
arbitrary_func=numpy.transpose,
|
||||
output_dtype=first_arg_output.dtype,
|
||||
output_shape=first_arg_output.shape[::-1],
|
||||
op_kwargs=deepcopy(_kwargs),
|
||||
op_kwargs=deepcopy(kwargs),
|
||||
op_name="np.transpose",
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
@@ -295,7 +303,7 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def ravel(self, *args: "NPTracer", **_kwargs) -> "NPTracer":
|
||||
def ravel(self, *args: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Trace numpy.ravel.
|
||||
|
||||
Returns:
|
||||
@@ -313,7 +321,7 @@ class NPTracer(BaseTracer):
|
||||
arbitrary_func=numpy.ravel,
|
||||
output_dtype=first_arg_output.dtype,
|
||||
output_shape=(numpy.product(first_arg_output.shape),),
|
||||
op_kwargs=deepcopy(_kwargs),
|
||||
op_kwargs=deepcopy(kwargs),
|
||||
op_name="np.ravel",
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
@@ -323,6 +331,54 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def reshape(self, arg0: "NPTracer", arg1: Tuple[Any, ...], **kwargs) -> "NPTracer":
|
||||
"""Trace numpy.reshape.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
|
||||
# FIXME: #772, restore reshape(self, *args, **kwargs) signature when possible, with mypy
|
||||
# types
|
||||
|
||||
# FIXME: #772, restore
|
||||
# assert_true((num_args := len(args)) == 2, f"reshape expect 2 input got {num_args}")
|
||||
# when possible
|
||||
|
||||
assert_true((num_kwargs := len(kwargs)) == 0, f"reshape expect 0 kwargs got {num_kwargs}")
|
||||
|
||||
first_arg_output = arg0.output
|
||||
assert_true(isinstance(first_arg_output, TensorValue))
|
||||
first_arg_output = cast(TensorValue, first_arg_output)
|
||||
|
||||
newshape = deepcopy(arg1)
|
||||
|
||||
if isinstance(newshape, int):
|
||||
# Make numpy.reshape(x, (170)) and numpy.reshape(x, 170) work, while classical form is
|
||||
# numpy.reshape(x, (170,))
|
||||
newshape = (newshape,)
|
||||
|
||||
# Check shape compatibility
|
||||
assert_true(
|
||||
numpy.product(newshape) == numpy.product(first_arg_output.shape),
|
||||
f"shapes are not compatible (old shape {first_arg_output.shape}, new shape {newshape})",
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
input_base_value=first_arg_output,
|
||||
arbitrary_func=numpy.reshape,
|
||||
output_dtype=first_arg_output.dtype,
|
||||
output_shape=newshape,
|
||||
op_kwargs={"newshape": newshape},
|
||||
op_name="np.reshape",
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
[arg0],
|
||||
traced_computation=traced_computation,
|
||||
output_idx=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def __getitem__(self, item):
|
||||
if isinstance(item, tuple):
|
||||
item = tuple(process_indexing_element(indexing_element) for indexing_element in item)
|
||||
@@ -436,6 +492,7 @@ class NPTracer(BaseTracer):
|
||||
FUNC_ROUTING: Dict[Callable, Callable] = {
|
||||
numpy.dot: dot,
|
||||
numpy.transpose: transpose,
|
||||
numpy.reshape: reshape,
|
||||
numpy.ravel: ravel,
|
||||
}
|
||||
|
||||
|
||||
@@ -197,6 +197,17 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
|
||||
numpy.arange(15),
|
||||
id="GenericFunction, x ravel",
|
||||
),
|
||||
pytest.param(
|
||||
ir.GenericFunction(
|
||||
EncryptedTensor(Integer(32, False), shape=(3, 5)),
|
||||
lambda x: numpy.reshape(x, (5, 3)),
|
||||
Integer(32, False),
|
||||
output_shape=(5, 3),
|
||||
),
|
||||
[numpy.arange(15).reshape(3, 5)],
|
||||
numpy.arange(15).reshape(5, 3),
|
||||
id="GenericFunction, x reshape",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_evaluate(
|
||||
|
||||
@@ -215,6 +215,7 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
|
||||
)
|
||||
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,params,ref_graph_str",
|
||||
[
|
||||
@@ -223,14 +224,55 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)),
|
||||
},
|
||||
"%0 = x\n%1 = np.transpose(%0)\nreturn(%1)\n",
|
||||
"""
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 5)>
|
||||
%1 = np.transpose(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(5, 3)>
|
||||
return(%1)
|
||||
""".lstrip(), # noqa: E501
|
||||
),
|
||||
(
|
||||
lambda x: numpy.ravel(x),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)),
|
||||
},
|
||||
"%0 = x\n%1 = np.ravel(%0)\nreturn(%1)\n",
|
||||
"""
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 5)>
|
||||
%1 = np.ravel(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(15,)>
|
||||
return(%1)
|
||||
""".lstrip(), # noqa: E501
|
||||
),
|
||||
(
|
||||
lambda x: numpy.reshape(x, (5, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)),
|
||||
},
|
||||
"""
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 5)>
|
||||
%1 = np.reshape(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(5, 3)>
|
||||
return(%1)
|
||||
""".lstrip(), # noqa: E501
|
||||
),
|
||||
(
|
||||
lambda x: numpy.reshape(x, (170,)),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(17, 10)),
|
||||
},
|
||||
"""
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(17, 10)>
|
||||
%1 = np.reshape(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(170,)>
|
||||
return(%1)
|
||||
""".lstrip(), # noqa: E501
|
||||
),
|
||||
(
|
||||
lambda x: numpy.reshape(x, (170)),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(17, 10)),
|
||||
},
|
||||
"""
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(17, 10)>
|
||||
%1 = np.reshape(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(170,)>
|
||||
return(%1)
|
||||
""".lstrip(), # noqa: E501
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -240,7 +282,7 @@ def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph)
|
||||
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
@@ -249,6 +291,9 @@ def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_
|
||||
)
|
||||
|
||||
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
|
||||
# Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b
|
||||
# returning 23b), since they are replaced later by the real bitwidths computed on the
|
||||
# inputset
|
||||
|
||||
@@ -636,6 +636,13 @@ def test_nptracer_unsupported_operands(operation, tracer):
|
||||
(numpy.arange(4).reshape(2, 2), numpy.array([0, 1, 2, 3])),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: numpy.reshape(x, (5, 3)) + 42,
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
[
|
||||
(numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(5, 3)),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tracing_generic_function(function_to_trace, input_value, input_and_expected_output_tuples):
|
||||
@@ -649,3 +656,22 @@ def test_tracing_generic_function(function_to_trace, input_value, input_and_expe
|
||||
evaluated_output = node_results[output_node]
|
||||
assert isinstance(evaluated_output, type(expected_output))
|
||||
assert numpy.array_equal(expected_output, evaluated_output)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,params",
|
||||
[
|
||||
(
|
||||
lambda x: numpy.reshape(x, (5, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(7, 5)),
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_errors_with_generic_function(lambda_f, params):
|
||||
"Test some errors with generic function"
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)
|
||||
|
||||
Reference in New Issue
Block a user