From 8123a5ef4521f02fc0ed230a5f3f84b0c910b421 Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Tue, 26 Oct 2021 12:11:18 +0200 Subject: [PATCH] feat: implement a generic node for functions which change shape and implement np.transpose with it and implement np.ravel with it refs #745 --- concrete/common/debugging/drawing.py | 3 + concrete/common/debugging/printing.py | 3 +- concrete/common/mlir/utils.py | 3 + .../common/representation/intermediate.py | 49 +++++++++++++ concrete/numpy/tracing.py | 72 +++++++++++++++++-- .../representation/test_intermediate.py | 24 ++++++- tests/conftest.py | 2 + tests/numpy/test_debugging.py | 34 +++++++++ tests/numpy/test_tracing.py | 41 +++++++++++ 9 files changed, 225 insertions(+), 6 deletions(-) diff --git a/concrete/common/debugging/drawing.py b/concrete/common/debugging/drawing.py index 664c38348..40e91f4e1 100644 --- a/concrete/common/debugging/drawing.py +++ b/concrete/common/debugging/drawing.py @@ -16,6 +16,7 @@ from ..representation.intermediate import ( Add, Constant, Dot, + GenericFunction, IndexConstant, Input, MatMul, @@ -31,10 +32,12 @@ IR_NODE_COLOR_MAPPING = { Sub: "yellow", Mul: "green", UnivariateFunction: "orange", + GenericFunction: "orange", IndexConstant: "black", Dot: "purple", MatMul: "brown", "UnivariateFunction": "orange", + "GenericFunction": "orange", "TLU": "grey", "output": "magenta", } diff --git a/concrete/common/debugging/printing.py b/concrete/common/debugging/printing.py index 13c334dcf..83947423e 100644 --- a/concrete/common/debugging/printing.py +++ b/concrete/common/debugging/printing.py @@ -8,6 +8,7 @@ from ..debugging.custom_assert import assert_true from ..operator_graph import OPGraph from ..representation.intermediate import ( Constant, + GenericFunction, IndexConstant, Input, IntermediateNode, @@ -91,7 +92,7 @@ def get_printable_graph( base_name = node.__class__.__name__ - if isinstance(node, UnivariateFunction): + if isinstance(node, (UnivariateFunction, GenericFunction)): base_name = node.op_name what_to_print = base_name + "(" diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index a40de7557..20ea6d236 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -69,6 +69,9 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool) if not value_is_scalar(inputs[0]) or not value_is_unsigned_integer(inputs[0]): return "only unsigned integer scalar lookup tables are supported" + elif isinstance(node, intermediate.GenericFunction): # constraints for generic functions + return f"{node.op_name} is not supported for the time being" # pragma: no cover + elif isinstance(node, intermediate.Dot): # constraints for dot product assert_true(len(inputs) == 2) if not value_is_unsigned_integer(inputs[0]) or not value_is_unsigned_integer(inputs[1]): diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index e4493ce74..949bef2d1 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -470,3 +470,52 @@ class MatMul(IntermediateNode): def label(self) -> str: return "@" + + +class GenericFunction(IntermediateNode): + """Return the node representing a generic function.""" + + # The arbitrary_func is not optional but mypy has a long standing bug and is not able to + # understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623 + # arbitrary_func can take more than one argument but during evaluation the input variable will + # be the first argument passed to it. You can add other constant arguments needed for the proper + # execution of the function through op_args and op_kwargs. + arbitrary_func: Optional[Callable] + op_name: str + op_args: Tuple[Any, ...] + op_kwargs: Dict[str, Any] + op_attributes: Dict[str, Any] + _n_in: int = 1 + + def __init__( + self, + input_base_value: TensorValue, + arbitrary_func: Callable, + output_dtype: BaseDataType, + output_shape: Tuple, + op_name: Optional[str] = None, + op_args: Optional[Tuple[Any, ...]] = None, + op_kwargs: Optional[Dict[str, Any]] = None, + op_attributes: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__([input_base_value]) + assert_true(len(self.inputs) == 1) + self.arbitrary_func = arbitrary_func + self.op_args = op_args if op_args is not None else () + self.op_kwargs = op_kwargs if op_kwargs is not None else {} + self.op_attributes = op_attributes if op_attributes is not None else {} + + self.outputs = [ + EncryptedTensor(output_dtype, output_shape) + if self.inputs[0].is_encrypted + else ClearTensor(output_dtype, output_shape) + ] + self.op_name = op_name if op_name is not None else self.__class__.__name__ + + def evaluate(self, inputs: Dict[int, Any]) -> Any: + # This is the continuation of the mypy bug workaround + assert self.arbitrary_func is not None + return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs) + + def label(self) -> str: + return self.op_name diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index 8b5488d20..bec4b7f57 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -1,17 +1,23 @@ """numpy tracing utilities.""" from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union, cast import numpy from numpy.typing import DTypeLike from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype -from ..common.debugging.custom_assert import assert_true +from ..common.debugging.custom_assert import assert_false, assert_true from ..common.operator_graph import OPGraph -from ..common.representation.intermediate import Constant, Dot, MatMul, UnivariateFunction +from ..common.representation.intermediate import ( + Constant, + Dot, + GenericFunction, + MatMul, + UnivariateFunction, +) from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters -from ..common.values import BaseValue +from ..common.values import BaseValue, TensorValue from .np_dtypes_helpers import ( SUPPORTED_NUMPY_DTYPES_CLASS_TYPES, convert_numpy_dtype_to_base_data_type, @@ -261,6 +267,62 @@ class NPTracer(BaseTracer): ) return output_tracer + def transpose(self, *args: "NPTracer", **_kwargs) -> "NPTracer": + """Trace numpy.transpose. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + assert_true((num_args := len(args)) == 1, f"transpose expect 1 input got {num_args}") + + first_arg_output = args[0].output + assert_true(isinstance(first_arg_output, TensorValue)) + first_arg_output = cast(TensorValue, first_arg_output) + assert_false(first_arg_output.is_scalar) + + traced_computation = GenericFunction( + input_base_value=first_arg_output, + arbitrary_func=numpy.transpose, + output_dtype=first_arg_output.dtype, + output_shape=first_arg_output.shape[::-1], + op_kwargs=deepcopy(_kwargs), + op_name="np.transpose", + ) + output_tracer = self.__class__( + args, + traced_computation=traced_computation, + output_idx=0, + ) + return output_tracer + + def ravel(self, *args: "NPTracer", **_kwargs) -> "NPTracer": + """Trace numpy.ravel. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + assert_true((num_args := len(args)) == 1, f"ravel expect 1 input got {num_args}") + + first_arg_output = args[0].output + assert_true(isinstance(first_arg_output, TensorValue)) + first_arg_output = cast(TensorValue, first_arg_output) + assert_false(first_arg_output.is_scalar) + + traced_computation = GenericFunction( + input_base_value=first_arg_output, + arbitrary_func=numpy.ravel, + output_dtype=first_arg_output.dtype, + output_shape=(numpy.product(first_arg_output.shape),), + op_kwargs=deepcopy(_kwargs), + op_name="np.ravel", + ) + output_tracer = self.__class__( + args, + traced_computation=traced_computation, + output_idx=0, + ) + return output_tracer + def __getitem__(self, item): if isinstance(item, tuple): item = tuple(process_indexing_element(indexing_element) for indexing_element in item) @@ -373,6 +435,8 @@ class NPTracer(BaseTracer): FUNC_ROUTING: Dict[Callable, Callable] = { numpy.dot: dot, + numpy.transpose: transpose, + numpy.ravel: ravel, } diff --git a/tests/common/representation/test_intermediate.py b/tests/common/representation/test_intermediate.py index 00e5905ef..ab53b0a4e 100644 --- a/tests/common/representation/test_intermediate.py +++ b/tests/common/representation/test_intermediate.py @@ -175,6 +175,28 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En numpy.array([[9, 12, 15], [19, 26, 33], [29, 40, 51]]), id="MatMul, numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)", ), + pytest.param( + ir.GenericFunction( + EncryptedTensor(Integer(32, False), shape=(3, 5)), + lambda x: numpy.transpose(x), + Integer(32, False), + output_shape=(5, 3), + ), + [numpy.arange(15).reshape(3, 5)], + numpy.array([[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8, 13], [4, 9, 14]]), + id="GenericFunction, x transpose", + ), + pytest.param( + ir.GenericFunction( + EncryptedTensor(Integer(32, False), shape=(3, 5)), + lambda x: numpy.ravel(x), + Integer(32, False), + output_shape=(5, 3), + ), + [numpy.arange(15).reshape(3, 5)], + numpy.arange(15), + id="GenericFunction, x ravel", + ), ], ) def test_evaluate( @@ -184,7 +206,7 @@ def test_evaluate( ): """Test evaluate methods on IntermediateNodes""" if isinstance(expected_result, numpy.ndarray): - assert (node.evaluate(input_data) == expected_result).all() + assert numpy.array_equal(node.evaluate(input_data), expected_result) else: assert node.evaluate(input_data) == expected_result diff --git a/tests/conftest.py b/tests/conftest.py index fda1dc7bc..ae71b2ea8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,7 @@ from concrete.common.representation.intermediate import ( Add, Constant, Dot, + GenericFunction, IndexConstant, Input, IntermediateNode, @@ -186,6 +187,7 @@ def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool: EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = { Add: is_equivalent_add, UnivariateFunction: is_equivalent_arbitrary_function, + GenericFunction: is_equivalent_arbitrary_function, Constant: is_equivalent_constant, Dot: is_equivalent_dot, IndexConstant: is_equivalent_index_constant, diff --git a/tests/numpy/test_debugging.py b/tests/numpy/test_debugging.py index 254a7fac2..fa9968680 100644 --- a/tests/numpy/test_debugging.py +++ b/tests/numpy/test_debugging.py @@ -215,6 +215,40 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str): ) +@pytest.mark.parametrize( + "lambda_f,params,ref_graph_str", + [ + ( + lambda x: numpy.transpose(x), + { + "x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)), + }, + "%0 = x\n%1 = np.transpose(%0)\nreturn(%1)\n", + ), + ( + lambda x: numpy.ravel(x), + { + "x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)), + }, + "%0 = x\n%1 = np.ravel(%0)\nreturn(%1)\n", + ), + ], +) +def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_str): + "Test get_printable_graph and draw_graph on graphs with generic function" + graph = tracing.trace_numpy_function(lambda_f, params) + + draw_graph(graph, show=False) + + str_of_the_graph = get_printable_graph(graph) + + assert str_of_the_graph == ref_graph_str, ( + f"\n==================\nGot \n{str_of_the_graph}" + f"==================\nExpected \n{ref_graph_str}" + f"==================\n" + ) + + # Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b # returning 23b), since they are replaced later by the real bitwidths computed on the # inputset diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index 441f1ed22..ffb0a9bbc 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -608,3 +608,44 @@ def test_nptracer_unsupported_operands(operation, tracer): """Test cases where NPTracer cannot be used with other operands.""" with pytest.raises(TypeError): tracer = operation(tracer) + + +@pytest.mark.parametrize( + "function_to_trace,input_value,input_and_expected_output_tuples", + [ + ( + lambda x: numpy.transpose(x), + EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)), + [ + (numpy.arange(4).reshape(2, 2), numpy.array([[0, 2], [1, 3]])), + (numpy.arange(4, 8).reshape(2, 2), numpy.array([[4, 6], [5, 7]])), + ], + ), + ( + lambda x: numpy.transpose(x) + 42, + EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), + [ + (numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(3, 5).transpose()), + ], + ), + ( + lambda x: numpy.ravel(x), + EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)), + [ + (numpy.arange(4), numpy.array([0, 1, 2, 3])), + (numpy.arange(4).reshape(2, 2), numpy.array([0, 1, 2, 3])), + ], + ), + ], +) +def test_tracing_generic_function(function_to_trace, input_value, input_and_expected_output_tuples): + """Test function for managed by GenericFunction node""" + for input_, expected_output in input_and_expected_output_tuples: + + op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value}) + output_node = op_graph.output_nodes[0] + + node_results = op_graph.evaluate({0: input_}) + evaluated_output = node_results[output_node] + assert isinstance(evaluated_output, type(expected_output)) + assert numpy.array_equal(expected_output, evaluated_output)