From 8a27525a64a3f6de66e4dece8972b981aa6a69e9 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Thu, 18 Nov 2021 11:18:17 +0100 Subject: [PATCH] refactor: separate some compilation steps to prepare torch-like API work refs #233 --- concrete/numpy/__init__.py | 2 +- concrete/numpy/compile.py | 93 ++++++++++++++++--- .../common/compilation/test_configuration.py | 6 +- tests/common/debugging/test_drawing.py | 4 +- tests/common/debugging/test_formatting.py | 6 +- tests/numpy/test_compile.py | 21 +++-- tests/numpy/test_compile_constant_indexing.py | 10 +- tests/numpy/test_np_mlir_converter.py | 4 +- 8 files changed, 108 insertions(+), 38 deletions(-) diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py index 19dc1cb7b..d7c693afa 100644 --- a/concrete/numpy/__init__.py +++ b/concrete/numpy/__init__.py @@ -6,5 +6,5 @@ from ..common.debugging import draw_graph, format_operation_graph from ..common.extensions.multi_table import MultiLookupTable from ..common.extensions.table import LookupTable from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue -from .compile import compile_numpy_function, compile_numpy_function_into_op_graph +from .compile import compile_numpy_function, compile_numpy_function_into_op_graph_and_measure_bounds from .tracing import trace_numpy_function diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index d08531cdc..eaeea692a 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -61,28 +61,26 @@ def numpy_min_func(lhs: Any, rhs: Any) -> Any: def _compile_numpy_function_into_op_graph_internal( function_to_compile: Callable, function_parameters: Dict[str, BaseValue], - inputset: Iterable[Tuple[Any, ...]], compilation_configuration: CompilationConfiguration, compilation_artifacts: CompilationArtifacts, ) -> OPGraph: - """Compile a function into an OPGraph. + """Compile a function into an OPGraph without evaluating the intermediate nodes bounds. Args: function_to_compile (Callable): The function to compile function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the function is e.g. an EncryptedScalar holding a 7bits unsigned Integer - inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It - needs to be an iterable on tuples which are of the same length than the number of - parameters in the function, and in the same order than these same parameters - compilation_artifacts (CompilationArtifacts): Artifacts object to fill - during compilation compilation_configuration (CompilationConfiguration): Configuration object to use during compilation + compilation_artifacts (CompilationArtifacts): Artifacts object to fill + during compilation Returns: - OPGraph: compiled function into a graph + OPGraph: compiled function into a graph, node values are not representative of the values + that can be observed during execution. + Use _compile_numpy_function_into_op_graph_and_measure_bounds_internal if you need bounds + estimation. """ - # Check function parameters wrong_inputs = { inp: function_parameters[inp] @@ -119,6 +117,34 @@ def _compile_numpy_function_into_op_graph_internal( if not check_op_graph_is_integer_program(op_graph): fuse_float_operations(op_graph, compilation_artifacts) + return op_graph + + +def _measure_op_graph_bounds_and_update_internal( + op_graph: OPGraph, + function_parameters: Dict[str, BaseValue], + inputset: Iterable[Tuple[Any, ...]], + compilation_configuration: CompilationConfiguration, + compilation_artifacts: CompilationArtifacts, +) -> None: + """Measure the intermediate values and update the OPGraph accordingly for the given inputset. + + Args: + op_graph (OPGraph): the OPGraph for which to measure bounds and update node values. + function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the + function is e.g. an EncryptedScalar holding a 7bits unsigned Integer + inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It + needs to be an iterable on tuples which are of the same length than the number of + parameters in the function, and in the same order than these same parameters + compilation_configuration (CompilationConfiguration): Configuration object to use + during compilation + compilation_artifacts (CompilationArtifacts): Artifacts object to fill + during compilation + + Raises: + ValueError: Raises an error if the inputset is too small and the compilation configuration + treats warnings as error. + """ # Find bounds with the inputset inputset_size, node_bounds_and_samples = eval_op_graph_bounds_on_inputset( op_graph, @@ -167,13 +193,54 @@ def _compile_numpy_function_into_op_graph_internal( get_constructor_for_numpy_or_python_constant_data, ) - # Add the initial graph as an artifact + +def _compile_numpy_function_into_op_graph_and_measure_bounds_internal( + function_to_compile: Callable, + function_parameters: Dict[str, BaseValue], + inputset: Iterable[Tuple[Any, ...]], + compilation_configuration: CompilationConfiguration, + compilation_artifacts: CompilationArtifacts, +) -> OPGraph: + """Compile a function into an OPGraph and evaluate the intermediate nodes bounds. + + Args: + function_to_compile (Callable): The function to compile + function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the + function is e.g. an EncryptedScalar holding a 7bits unsigned Integer + inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It + needs to be an iterable on tuples which are of the same length than the number of + parameters in the function, and in the same order than these same parameters + compilation_configuration (CompilationConfiguration): Configuration object to use + during compilation + compilation_artifacts (CompilationArtifacts): Artifacts object to fill + during compilation + + Returns: + OPGraph: compiled function into a graph with estimated bounds in node values. + """ + + op_graph = _compile_numpy_function_into_op_graph_internal( + function_to_compile, + function_parameters, + compilation_configuration, + compilation_artifacts, + ) + + _measure_op_graph_bounds_and_update_internal( + op_graph, + function_parameters, + inputset, + compilation_configuration, + compilation_artifacts, + ) + + # Add the final graph as an artifact compilation_artifacts.add_operation_graph("final", op_graph) return op_graph -def compile_numpy_function_into_op_graph( +def compile_numpy_function_into_op_graph_and_measure_bounds( function_to_compile: Callable, function_parameters: Dict[str, BaseValue], inputset: Union[Iterable[Tuple[Any, ...]], str], @@ -220,7 +287,7 @@ def compile_numpy_function_into_op_graph( try: # Use context manager to restore numpy error handling with numpy.errstate(**numpy.geterr()): - return _compile_numpy_function_into_op_graph_internal( + return _compile_numpy_function_into_op_graph_and_measure_bounds_internal( function_to_compile, function_parameters, inputset, @@ -300,7 +367,7 @@ def _compile_numpy_function_internal( """ # Compile into an OPGraph - op_graph = _compile_numpy_function_into_op_graph_internal( + op_graph = _compile_numpy_function_into_op_graph_and_measure_bounds_internal( function_to_compile, function_parameters, inputset, diff --git a/tests/common/compilation/test_configuration.py b/tests/common/compilation/test_configuration.py index 9e8c669f9..c05ff053b 100644 --- a/tests/common/compilation/test_configuration.py +++ b/tests/common/compilation/test_configuration.py @@ -8,7 +8,7 @@ import pytest from concrete.common.compilation import CompilationConfiguration from concrete.common.data_types.integers import Integer from concrete.common.values import EncryptedScalar -from concrete.numpy.compile import compile_numpy_function_into_op_graph +from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds def no_fuse(x): @@ -43,7 +43,7 @@ def test_enable_topological_optimizations( ): """Test function for enable_topological_optimizations flag of compilation configuration""" - op_graph = compile_numpy_function_into_op_graph( + op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( function_to_trace, { param: EncryptedScalar(Integer(32, is_signed=False)) @@ -52,7 +52,7 @@ def test_enable_topological_optimizations( [(numpy.array(i),) for i in range(10)], default_compilation_configuration, ) - op_graph_not_optimized = compile_numpy_function_into_op_graph( + op_graph_not_optimized = compile_numpy_function_into_op_graph_and_measure_bounds( function_to_trace, { param: EncryptedScalar(Integer(32, is_signed=False)) diff --git a/tests/common/debugging/test_drawing.py b/tests/common/debugging/test_drawing.py index eb8d3d16f..bbb1d950e 100644 --- a/tests/common/debugging/test_drawing.py +++ b/tests/common/debugging/test_drawing.py @@ -6,7 +6,7 @@ from pathlib import Path from concrete.common.data_types.integers import Integer from concrete.common.debugging import draw_graph from concrete.common.values import EncryptedScalar -from concrete.numpy.compile import compile_numpy_function_into_op_graph +from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds def test_draw_graph_with_saving(default_compilation_configuration): @@ -15,7 +15,7 @@ def test_draw_graph_with_saving(default_compilation_configuration): def function(x): return x + 42 - op_graph = compile_numpy_function_into_op_graph( + op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( function, {"x": EncryptedScalar(Integer(7, True))}, [(i,) for i in range(-5, 5)], diff --git a/tests/common/debugging/test_formatting.py b/tests/common/debugging/test_formatting.py index 9ef7bf105..4e3a27b73 100644 --- a/tests/common/debugging/test_formatting.py +++ b/tests/common/debugging/test_formatting.py @@ -3,7 +3,7 @@ from concrete.common.data_types.integers import Integer from concrete.common.debugging import format_operation_graph from concrete.common.values import EncryptedScalar -from concrete.numpy.compile import compile_numpy_function_into_op_graph +from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds def test_format_operation_graph_with_multiple_edges(default_compilation_configuration): @@ -12,7 +12,7 @@ def test_format_operation_graph_with_multiple_edges(default_compilation_configur def function(x): return x + x - op_graph = compile_numpy_function_into_op_graph( + op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( function, {"x": EncryptedScalar(Integer(4, True))}, [(i,) for i in range(0, 10)], @@ -38,7 +38,7 @@ def test_format_operation_graph_with_offending_nodes(default_compilation_configu def function(x): return x + 42 - op_graph = compile_numpy_function_into_op_graph( + op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( function, {"x": EncryptedScalar(Integer(7, True))}, [(i,) for i in range(-5, 5)], diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index c64c32613..951b9bd65 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -14,7 +14,10 @@ from concrete.common.extensions.multi_table import MultiLookupTable from concrete.common.extensions.table import LookupTable from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor from concrete.numpy import tracing -from concrete.numpy.compile import compile_numpy_function, compile_numpy_function_into_op_graph +from concrete.numpy.compile import ( + compile_numpy_function, + compile_numpy_function_into_op_graph_and_measure_bounds, +) # pylint: disable=too-many-lines @@ -624,7 +627,7 @@ def test_compile_function_multiple_outputs( arg_name: EncryptedScalar(Integer(64, True)) for arg_name in list_of_arg_names } - op_graph = compile_numpy_function_into_op_graph( + op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( function, function_parameters, data_gen_local(tuple(range(x[0], x[1] + 1) for x in input_ranges)), @@ -1328,7 +1331,7 @@ def test_compile_function_with_direct_tlu(default_compilation_configuration): def function(x): return x + table[x] - op_graph = compile_numpy_function_into_op_graph( + op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( function, {"x": EncryptedScalar(Integer(2, is_signed=False))}, [(0,), (1,), (2,), (3,)], @@ -1348,7 +1351,7 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura return table[x] with pytest.raises(ValueError): - compile_numpy_function_into_op_graph( + compile_numpy_function_into_op_graph_and_measure_bounds( function, {"x": EncryptedScalar(Integer(3, is_signed=False))}, [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,)], @@ -1570,7 +1573,7 @@ return %7 def test_small_inputset_no_fail(): """Test function compile_numpy_function_into_op_graph with an unacceptably small inputset""" - compile_numpy_function_into_op_graph( + compile_numpy_function_into_op_graph_and_measure_bounds( lambda x: x + 42, {"x": EncryptedScalar(Integer(5, is_signed=False))}, [(0,), (3,)], @@ -1581,7 +1584,7 @@ def test_small_inputset_no_fail(): def test_small_inputset_treat_warnings_as_errors(): """Test function compile_numpy_function_into_op_graph with an unacceptably small inputset""" with pytest.raises(ValueError, match=".* inputset contains too few inputs .*"): - compile_numpy_function_into_op_graph( + compile_numpy_function_into_op_graph_and_measure_bounds( lambda x: x + 42, {"x": EncryptedScalar(Integer(5, is_signed=False))}, [(0,), (3,)], @@ -1632,7 +1635,7 @@ def test_compile_function_with_dot( assert len(shape) == 1 repeat = shape[0] - op_graph = compile_numpy_function_into_op_graph( + op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( function, params, data_gen_local(max_for_ij, repeat), @@ -1729,7 +1732,7 @@ def test_compile_with_random_inputset(default_compilation_configuration): configuration_to_use = deepcopy(default_compilation_configuration) configuration_to_use.enable_unsafe_features = True - compile_numpy_function_into_op_graph( + compile_numpy_function_into_op_graph_and_measure_bounds( lambda x: x + 1, {"x": EncryptedScalar(UnsignedInteger(6))}, inputset="random", @@ -1748,7 +1751,7 @@ def test_fail_compile_with_random_inputset(default_compilation_configuration): with pytest.raises(ValueError): try: - compile_numpy_function_into_op_graph( + compile_numpy_function_into_op_graph_and_measure_bounds( lambda x: x + 1, {"x": EncryptedScalar(UnsignedInteger(3))}, inputset="unsupported", diff --git a/tests/numpy/test_compile_constant_indexing.py b/tests/numpy/test_compile_constant_indexing.py index 197a0b1f0..acf8b3788 100644 --- a/tests/numpy/test_compile_constant_indexing.py +++ b/tests/numpy/test_compile_constant_indexing.py @@ -5,7 +5,7 @@ import pytest from concrete.common.data_types import UnsignedInteger from concrete.common.values import EncryptedScalar, EncryptedTensor -from concrete.numpy import compile_numpy_function_into_op_graph +from concrete.numpy import compile_numpy_function_into_op_graph_and_measure_bounds @pytest.mark.parametrize( @@ -347,7 +347,7 @@ def test_constant_indexing( for _ in range(10) ] - op_graph = compile_numpy_function_into_op_graph( + op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( function_with_indexing, {"x": input_value}, inputset, @@ -470,7 +470,7 @@ def test_invalid_constant_indexing( ) for _ in range(10) ] - compile_numpy_function_into_op_graph( + compile_numpy_function_into_op_graph_and_measure_bounds( function_with_indexing, {"x": input_value}, inputset, @@ -525,7 +525,7 @@ def test_constant_indexing_with_numpy_integers( for _ in range(10) ] - op_graph = compile_numpy_function_into_op_graph( + op_graph = compile_numpy_function_into_op_graph_and_measure_bounds( function_with_indexing, {"x": input_value}, inputset, @@ -586,7 +586,7 @@ def test_invalid_constant_indexing_with_numpy_values( ) for _ in range(10) ] - compile_numpy_function_into_op_graph( + compile_numpy_function_into_op_graph_and_measure_bounds( function_with_indexing, {"x": input_value}, inputset, diff --git a/tests/numpy/test_np_mlir_converter.py b/tests/numpy/test_np_mlir_converter.py index c986def80..7502fa0cb 100644 --- a/tests/numpy/test_np_mlir_converter.py +++ b/tests/numpy/test_np_mlir_converter.py @@ -42,7 +42,7 @@ def test_generate_deduplicated_tables( function, expected_number_of_tables, default_compilation_configuration ): """Test function for generate_deduplicated_tables""" - op_graph = hnp.compile_numpy_function_into_op_graph( + op_graph = hnp.compile_numpy_function_into_op_graph_and_measure_bounds( function, {"x": hnp.EncryptedTensor(hnp.Integer(7, False), RESNET_BIGGEST_SHAPE)}, ((i * numpy.ones(RESNET_BIGGEST_SHAPE, dtype=numpy.int32),) for i in range(128)), @@ -69,7 +69,7 @@ def test_deduplicated_tables_correctness(default_compilation_configuration): tensor_shape = (2, 2) - op_graph = hnp.compile_numpy_function_into_op_graph( + op_graph = hnp.compile_numpy_function_into_op_graph_and_measure_bounds( lambda x: multi_tlu_func(x, numpy.arange(4, dtype=numpy.float64).reshape(tensor_shape)), {"x": hnp.EncryptedTensor(hnp.Integer(2, False), tensor_shape)}, ((i * numpy.ones(tensor_shape, dtype=numpy.int32),) for i in range(4)),