refactor: separate some compilation steps to prepare torch-like API work

refs #233
This commit is contained in:
Arthur Meyre
2021-11-18 11:18:17 +01:00
parent e4063d79da
commit 8a27525a64
8 changed files with 108 additions and 38 deletions

View File

@@ -6,5 +6,5 @@ from ..common.debugging import draw_graph, format_operation_graph
from ..common.extensions.multi_table import MultiLookupTable
from ..common.extensions.table import LookupTable
from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue
from .compile import compile_numpy_function, compile_numpy_function_into_op_graph
from .compile import compile_numpy_function, compile_numpy_function_into_op_graph_and_measure_bounds
from .tracing import trace_numpy_function

View File

@@ -61,28 +61,26 @@ def numpy_min_func(lhs: Any, rhs: Any) -> Any:
def _compile_numpy_function_into_op_graph_internal(
function_to_compile: Callable,
function_parameters: Dict[str, BaseValue],
inputset: Iterable[Tuple[Any, ...]],
compilation_configuration: CompilationConfiguration,
compilation_artifacts: CompilationArtifacts,
) -> OPGraph:
"""Compile a function into an OPGraph.
"""Compile a function into an OPGraph without evaluating the intermediate nodes bounds.
Args:
function_to_compile (Callable): The function to compile
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
needs to be an iterable on tuples which are of the same length than the number of
parameters in the function, and in the same order than these same parameters
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
during compilation
compilation_configuration (CompilationConfiguration): Configuration object to use
during compilation
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
during compilation
Returns:
OPGraph: compiled function into a graph
OPGraph: compiled function into a graph, node values are not representative of the values
that can be observed during execution.
Use _compile_numpy_function_into_op_graph_and_measure_bounds_internal if you need bounds
estimation.
"""
# Check function parameters
wrong_inputs = {
inp: function_parameters[inp]
@@ -119,6 +117,34 @@ def _compile_numpy_function_into_op_graph_internal(
if not check_op_graph_is_integer_program(op_graph):
fuse_float_operations(op_graph, compilation_artifacts)
return op_graph
def _measure_op_graph_bounds_and_update_internal(
op_graph: OPGraph,
function_parameters: Dict[str, BaseValue],
inputset: Iterable[Tuple[Any, ...]],
compilation_configuration: CompilationConfiguration,
compilation_artifacts: CompilationArtifacts,
) -> None:
"""Measure the intermediate values and update the OPGraph accordingly for the given inputset.
Args:
op_graph (OPGraph): the OPGraph for which to measure bounds and update node values.
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
needs to be an iterable on tuples which are of the same length than the number of
parameters in the function, and in the same order than these same parameters
compilation_configuration (CompilationConfiguration): Configuration object to use
during compilation
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
during compilation
Raises:
ValueError: Raises an error if the inputset is too small and the compilation configuration
treats warnings as error.
"""
# Find bounds with the inputset
inputset_size, node_bounds_and_samples = eval_op_graph_bounds_on_inputset(
op_graph,
@@ -167,13 +193,54 @@ def _compile_numpy_function_into_op_graph_internal(
get_constructor_for_numpy_or_python_constant_data,
)
# Add the initial graph as an artifact
def _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
function_to_compile: Callable,
function_parameters: Dict[str, BaseValue],
inputset: Iterable[Tuple[Any, ...]],
compilation_configuration: CompilationConfiguration,
compilation_artifacts: CompilationArtifacts,
) -> OPGraph:
"""Compile a function into an OPGraph and evaluate the intermediate nodes bounds.
Args:
function_to_compile (Callable): The function to compile
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
needs to be an iterable on tuples which are of the same length than the number of
parameters in the function, and in the same order than these same parameters
compilation_configuration (CompilationConfiguration): Configuration object to use
during compilation
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
during compilation
Returns:
OPGraph: compiled function into a graph with estimated bounds in node values.
"""
op_graph = _compile_numpy_function_into_op_graph_internal(
function_to_compile,
function_parameters,
compilation_configuration,
compilation_artifacts,
)
_measure_op_graph_bounds_and_update_internal(
op_graph,
function_parameters,
inputset,
compilation_configuration,
compilation_artifacts,
)
# Add the final graph as an artifact
compilation_artifacts.add_operation_graph("final", op_graph)
return op_graph
def compile_numpy_function_into_op_graph(
def compile_numpy_function_into_op_graph_and_measure_bounds(
function_to_compile: Callable,
function_parameters: Dict[str, BaseValue],
inputset: Union[Iterable[Tuple[Any, ...]], str],
@@ -220,7 +287,7 @@ def compile_numpy_function_into_op_graph(
try:
# Use context manager to restore numpy error handling
with numpy.errstate(**numpy.geterr()):
return _compile_numpy_function_into_op_graph_internal(
return _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
function_to_compile,
function_parameters,
inputset,
@@ -300,7 +367,7 @@ def _compile_numpy_function_internal(
"""
# Compile into an OPGraph
op_graph = _compile_numpy_function_into_op_graph_internal(
op_graph = _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
function_to_compile,
function_parameters,
inputset,

View File

@@ -8,7 +8,7 @@ import pytest
from concrete.common.compilation import CompilationConfiguration
from concrete.common.data_types.integers import Integer
from concrete.common.values import EncryptedScalar
from concrete.numpy.compile import compile_numpy_function_into_op_graph
from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds
def no_fuse(x):
@@ -43,7 +43,7 @@ def test_enable_topological_optimizations(
):
"""Test function for enable_topological_optimizations flag of compilation configuration"""
op_graph = compile_numpy_function_into_op_graph(
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
function_to_trace,
{
param: EncryptedScalar(Integer(32, is_signed=False))
@@ -52,7 +52,7 @@ def test_enable_topological_optimizations(
[(numpy.array(i),) for i in range(10)],
default_compilation_configuration,
)
op_graph_not_optimized = compile_numpy_function_into_op_graph(
op_graph_not_optimized = compile_numpy_function_into_op_graph_and_measure_bounds(
function_to_trace,
{
param: EncryptedScalar(Integer(32, is_signed=False))

View File

@@ -6,7 +6,7 @@ from pathlib import Path
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import draw_graph
from concrete.common.values import EncryptedScalar
from concrete.numpy.compile import compile_numpy_function_into_op_graph
from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds
def test_draw_graph_with_saving(default_compilation_configuration):
@@ -15,7 +15,7 @@ def test_draw_graph_with_saving(default_compilation_configuration):
def function(x):
return x + 42
op_graph = compile_numpy_function_into_op_graph(
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
function,
{"x": EncryptedScalar(Integer(7, True))},
[(i,) for i in range(-5, 5)],

View File

@@ -3,7 +3,7 @@
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import format_operation_graph
from concrete.common.values import EncryptedScalar
from concrete.numpy.compile import compile_numpy_function_into_op_graph
from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds
def test_format_operation_graph_with_multiple_edges(default_compilation_configuration):
@@ -12,7 +12,7 @@ def test_format_operation_graph_with_multiple_edges(default_compilation_configur
def function(x):
return x + x
op_graph = compile_numpy_function_into_op_graph(
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
function,
{"x": EncryptedScalar(Integer(4, True))},
[(i,) for i in range(0, 10)],
@@ -38,7 +38,7 @@ def test_format_operation_graph_with_offending_nodes(default_compilation_configu
def function(x):
return x + 42
op_graph = compile_numpy_function_into_op_graph(
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
function,
{"x": EncryptedScalar(Integer(7, True))},
[(i,) for i in range(-5, 5)],

View File

@@ -14,7 +14,10 @@ from concrete.common.extensions.multi_table import MultiLookupTable
from concrete.common.extensions.table import LookupTable
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
from concrete.numpy import tracing
from concrete.numpy.compile import compile_numpy_function, compile_numpy_function_into_op_graph
from concrete.numpy.compile import (
compile_numpy_function,
compile_numpy_function_into_op_graph_and_measure_bounds,
)
# pylint: disable=too-many-lines
@@ -624,7 +627,7 @@ def test_compile_function_multiple_outputs(
arg_name: EncryptedScalar(Integer(64, True)) for arg_name in list_of_arg_names
}
op_graph = compile_numpy_function_into_op_graph(
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
function,
function_parameters,
data_gen_local(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
@@ -1328,7 +1331,7 @@ def test_compile_function_with_direct_tlu(default_compilation_configuration):
def function(x):
return x + table[x]
op_graph = compile_numpy_function_into_op_graph(
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
function,
{"x": EncryptedScalar(Integer(2, is_signed=False))},
[(0,), (1,), (2,), (3,)],
@@ -1348,7 +1351,7 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
return table[x]
with pytest.raises(ValueError):
compile_numpy_function_into_op_graph(
compile_numpy_function_into_op_graph_and_measure_bounds(
function,
{"x": EncryptedScalar(Integer(3, is_signed=False))},
[(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,)],
@@ -1570,7 +1573,7 @@ return %7
def test_small_inputset_no_fail():
"""Test function compile_numpy_function_into_op_graph with an unacceptably small inputset"""
compile_numpy_function_into_op_graph(
compile_numpy_function_into_op_graph_and_measure_bounds(
lambda x: x + 42,
{"x": EncryptedScalar(Integer(5, is_signed=False))},
[(0,), (3,)],
@@ -1581,7 +1584,7 @@ def test_small_inputset_no_fail():
def test_small_inputset_treat_warnings_as_errors():
"""Test function compile_numpy_function_into_op_graph with an unacceptably small inputset"""
with pytest.raises(ValueError, match=".* inputset contains too few inputs .*"):
compile_numpy_function_into_op_graph(
compile_numpy_function_into_op_graph_and_measure_bounds(
lambda x: x + 42,
{"x": EncryptedScalar(Integer(5, is_signed=False))},
[(0,), (3,)],
@@ -1632,7 +1635,7 @@ def test_compile_function_with_dot(
assert len(shape) == 1
repeat = shape[0]
op_graph = compile_numpy_function_into_op_graph(
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
function,
params,
data_gen_local(max_for_ij, repeat),
@@ -1729,7 +1732,7 @@ def test_compile_with_random_inputset(default_compilation_configuration):
configuration_to_use = deepcopy(default_compilation_configuration)
configuration_to_use.enable_unsafe_features = True
compile_numpy_function_into_op_graph(
compile_numpy_function_into_op_graph_and_measure_bounds(
lambda x: x + 1,
{"x": EncryptedScalar(UnsignedInteger(6))},
inputset="random",
@@ -1748,7 +1751,7 @@ def test_fail_compile_with_random_inputset(default_compilation_configuration):
with pytest.raises(ValueError):
try:
compile_numpy_function_into_op_graph(
compile_numpy_function_into_op_graph_and_measure_bounds(
lambda x: x + 1,
{"x": EncryptedScalar(UnsignedInteger(3))},
inputset="unsupported",

View File

@@ -5,7 +5,7 @@ import pytest
from concrete.common.data_types import UnsignedInteger
from concrete.common.values import EncryptedScalar, EncryptedTensor
from concrete.numpy import compile_numpy_function_into_op_graph
from concrete.numpy import compile_numpy_function_into_op_graph_and_measure_bounds
@pytest.mark.parametrize(
@@ -347,7 +347,7 @@ def test_constant_indexing(
for _ in range(10)
]
op_graph = compile_numpy_function_into_op_graph(
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
function_with_indexing,
{"x": input_value},
inputset,
@@ -470,7 +470,7 @@ def test_invalid_constant_indexing(
)
for _ in range(10)
]
compile_numpy_function_into_op_graph(
compile_numpy_function_into_op_graph_and_measure_bounds(
function_with_indexing,
{"x": input_value},
inputset,
@@ -525,7 +525,7 @@ def test_constant_indexing_with_numpy_integers(
for _ in range(10)
]
op_graph = compile_numpy_function_into_op_graph(
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
function_with_indexing,
{"x": input_value},
inputset,
@@ -586,7 +586,7 @@ def test_invalid_constant_indexing_with_numpy_values(
)
for _ in range(10)
]
compile_numpy_function_into_op_graph(
compile_numpy_function_into_op_graph_and_measure_bounds(
function_with_indexing,
{"x": input_value},
inputset,

View File

@@ -42,7 +42,7 @@ def test_generate_deduplicated_tables(
function, expected_number_of_tables, default_compilation_configuration
):
"""Test function for generate_deduplicated_tables"""
op_graph = hnp.compile_numpy_function_into_op_graph(
op_graph = hnp.compile_numpy_function_into_op_graph_and_measure_bounds(
function,
{"x": hnp.EncryptedTensor(hnp.Integer(7, False), RESNET_BIGGEST_SHAPE)},
((i * numpy.ones(RESNET_BIGGEST_SHAPE, dtype=numpy.int32),) for i in range(128)),
@@ -69,7 +69,7 @@ def test_deduplicated_tables_correctness(default_compilation_configuration):
tensor_shape = (2, 2)
op_graph = hnp.compile_numpy_function_into_op_graph(
op_graph = hnp.compile_numpy_function_into_op_graph_and_measure_bounds(
lambda x: multi_tlu_func(x, numpy.arange(4, dtype=numpy.float64).reshape(tensor_shape)),
{"x": hnp.EncryptedTensor(hnp.Integer(2, False), tensor_shape)},
((i * numpy.ones(tensor_shape, dtype=numpy.int32),) for i in range(4)),