mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor: separate some compilation steps to prepare torch-like API work
refs #233
This commit is contained in:
@@ -6,5 +6,5 @@ from ..common.debugging import draw_graph, format_operation_graph
|
||||
from ..common.extensions.multi_table import MultiLookupTable
|
||||
from ..common.extensions.table import LookupTable
|
||||
from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue
|
||||
from .compile import compile_numpy_function, compile_numpy_function_into_op_graph
|
||||
from .compile import compile_numpy_function, compile_numpy_function_into_op_graph_and_measure_bounds
|
||||
from .tracing import trace_numpy_function
|
||||
|
||||
@@ -61,28 +61,26 @@ def numpy_min_func(lhs: Any, rhs: Any) -> Any:
|
||||
def _compile_numpy_function_into_op_graph_internal(
|
||||
function_to_compile: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
inputset: Iterable[Tuple[Any, ...]],
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
compilation_artifacts: CompilationArtifacts,
|
||||
) -> OPGraph:
|
||||
"""Compile a function into an OPGraph.
|
||||
"""Compile a function into an OPGraph without evaluating the intermediate nodes bounds.
|
||||
|
||||
Args:
|
||||
function_to_compile (Callable): The function to compile
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
|
||||
needs to be an iterable on tuples which are of the same length than the number of
|
||||
parameters in the function, and in the same order than these same parameters
|
||||
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
|
||||
during compilation
|
||||
compilation_configuration (CompilationConfiguration): Configuration object to use
|
||||
during compilation
|
||||
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
|
||||
during compilation
|
||||
|
||||
Returns:
|
||||
OPGraph: compiled function into a graph
|
||||
OPGraph: compiled function into a graph, node values are not representative of the values
|
||||
that can be observed during execution.
|
||||
Use _compile_numpy_function_into_op_graph_and_measure_bounds_internal if you need bounds
|
||||
estimation.
|
||||
"""
|
||||
|
||||
# Check function parameters
|
||||
wrong_inputs = {
|
||||
inp: function_parameters[inp]
|
||||
@@ -119,6 +117,34 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
if not check_op_graph_is_integer_program(op_graph):
|
||||
fuse_float_operations(op_graph, compilation_artifacts)
|
||||
|
||||
return op_graph
|
||||
|
||||
|
||||
def _measure_op_graph_bounds_and_update_internal(
|
||||
op_graph: OPGraph,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
inputset: Iterable[Tuple[Any, ...]],
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
compilation_artifacts: CompilationArtifacts,
|
||||
) -> None:
|
||||
"""Measure the intermediate values and update the OPGraph accordingly for the given inputset.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): the OPGraph for which to measure bounds and update node values.
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
|
||||
needs to be an iterable on tuples which are of the same length than the number of
|
||||
parameters in the function, and in the same order than these same parameters
|
||||
compilation_configuration (CompilationConfiguration): Configuration object to use
|
||||
during compilation
|
||||
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
|
||||
during compilation
|
||||
|
||||
Raises:
|
||||
ValueError: Raises an error if the inputset is too small and the compilation configuration
|
||||
treats warnings as error.
|
||||
"""
|
||||
# Find bounds with the inputset
|
||||
inputset_size, node_bounds_and_samples = eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
@@ -167,13 +193,54 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
get_constructor_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
# Add the initial graph as an artifact
|
||||
|
||||
def _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
|
||||
function_to_compile: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
inputset: Iterable[Tuple[Any, ...]],
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
compilation_artifacts: CompilationArtifacts,
|
||||
) -> OPGraph:
|
||||
"""Compile a function into an OPGraph and evaluate the intermediate nodes bounds.
|
||||
|
||||
Args:
|
||||
function_to_compile (Callable): The function to compile
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
|
||||
needs to be an iterable on tuples which are of the same length than the number of
|
||||
parameters in the function, and in the same order than these same parameters
|
||||
compilation_configuration (CompilationConfiguration): Configuration object to use
|
||||
during compilation
|
||||
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
|
||||
during compilation
|
||||
|
||||
Returns:
|
||||
OPGraph: compiled function into a graph with estimated bounds in node values.
|
||||
"""
|
||||
|
||||
op_graph = _compile_numpy_function_into_op_graph_internal(
|
||||
function_to_compile,
|
||||
function_parameters,
|
||||
compilation_configuration,
|
||||
compilation_artifacts,
|
||||
)
|
||||
|
||||
_measure_op_graph_bounds_and_update_internal(
|
||||
op_graph,
|
||||
function_parameters,
|
||||
inputset,
|
||||
compilation_configuration,
|
||||
compilation_artifacts,
|
||||
)
|
||||
|
||||
# Add the final graph as an artifact
|
||||
compilation_artifacts.add_operation_graph("final", op_graph)
|
||||
|
||||
return op_graph
|
||||
|
||||
|
||||
def compile_numpy_function_into_op_graph(
|
||||
def compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function_to_compile: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
inputset: Union[Iterable[Tuple[Any, ...]], str],
|
||||
@@ -220,7 +287,7 @@ def compile_numpy_function_into_op_graph(
|
||||
try:
|
||||
# Use context manager to restore numpy error handling
|
||||
with numpy.errstate(**numpy.geterr()):
|
||||
return _compile_numpy_function_into_op_graph_internal(
|
||||
return _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
|
||||
function_to_compile,
|
||||
function_parameters,
|
||||
inputset,
|
||||
@@ -300,7 +367,7 @@ def _compile_numpy_function_internal(
|
||||
"""
|
||||
|
||||
# Compile into an OPGraph
|
||||
op_graph = _compile_numpy_function_into_op_graph_internal(
|
||||
op_graph = _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
|
||||
function_to_compile,
|
||||
function_parameters,
|
||||
inputset,
|
||||
|
||||
@@ -8,7 +8,7 @@ import pytest
|
||||
from concrete.common.compilation import CompilationConfiguration
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds
|
||||
|
||||
|
||||
def no_fuse(x):
|
||||
@@ -43,7 +43,7 @@ def test_enable_topological_optimizations(
|
||||
):
|
||||
"""Test function for enable_topological_optimizations flag of compilation configuration"""
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function_to_trace,
|
||||
{
|
||||
param: EncryptedScalar(Integer(32, is_signed=False))
|
||||
@@ -52,7 +52,7 @@ def test_enable_topological_optimizations(
|
||||
[(numpy.array(i),) for i in range(10)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
op_graph_not_optimized = compile_numpy_function_into_op_graph(
|
||||
op_graph_not_optimized = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function_to_trace,
|
||||
{
|
||||
param: EncryptedScalar(Integer(32, is_signed=False))
|
||||
|
||||
@@ -6,7 +6,7 @@ from pathlib import Path
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import draw_graph
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds
|
||||
|
||||
|
||||
def test_draw_graph_with_saving(default_compilation_configuration):
|
||||
@@ -15,7 +15,7 @@ def test_draw_graph_with_saving(default_compilation_configuration):
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(7, True))},
|
||||
[(i,) for i in range(-5, 5)],
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import format_operation_graph
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds
|
||||
|
||||
|
||||
def test_format_operation_graph_with_multiple_edges(default_compilation_configuration):
|
||||
@@ -12,7 +12,7 @@ def test_format_operation_graph_with_multiple_edges(default_compilation_configur
|
||||
def function(x):
|
||||
return x + x
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(4, True))},
|
||||
[(i,) for i in range(0, 10)],
|
||||
@@ -38,7 +38,7 @@ def test_format_operation_graph_with_offending_nodes(default_compilation_configu
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(7, True))},
|
||||
[(i,) for i in range(-5, 5)],
|
||||
|
||||
@@ -14,7 +14,10 @@ from concrete.common.extensions.multi_table import MultiLookupTable
|
||||
from concrete.common.extensions.table import LookupTable
|
||||
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
from concrete.numpy.compile import compile_numpy_function, compile_numpy_function_into_op_graph
|
||||
from concrete.numpy.compile import (
|
||||
compile_numpy_function,
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds,
|
||||
)
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
@@ -624,7 +627,7 @@ def test_compile_function_multiple_outputs(
|
||||
arg_name: EncryptedScalar(Integer(64, True)) for arg_name in list_of_arg_names
|
||||
}
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen_local(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
@@ -1328,7 +1331,7 @@ def test_compile_function_with_direct_tlu(default_compilation_configuration):
|
||||
def function(x):
|
||||
return x + table[x]
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(2, is_signed=False))},
|
||||
[(0,), (1,), (2,), (3,)],
|
||||
@@ -1348,7 +1351,7 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
|
||||
return table[x]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
compile_numpy_function_into_op_graph(
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(3, is_signed=False))},
|
||||
[(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,)],
|
||||
@@ -1570,7 +1573,7 @@ return %7
|
||||
|
||||
def test_small_inputset_no_fail():
|
||||
"""Test function compile_numpy_function_into_op_graph with an unacceptably small inputset"""
|
||||
compile_numpy_function_into_op_graph(
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
lambda x: x + 42,
|
||||
{"x": EncryptedScalar(Integer(5, is_signed=False))},
|
||||
[(0,), (3,)],
|
||||
@@ -1581,7 +1584,7 @@ def test_small_inputset_no_fail():
|
||||
def test_small_inputset_treat_warnings_as_errors():
|
||||
"""Test function compile_numpy_function_into_op_graph with an unacceptably small inputset"""
|
||||
with pytest.raises(ValueError, match=".* inputset contains too few inputs .*"):
|
||||
compile_numpy_function_into_op_graph(
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
lambda x: x + 42,
|
||||
{"x": EncryptedScalar(Integer(5, is_signed=False))},
|
||||
[(0,), (3,)],
|
||||
@@ -1632,7 +1635,7 @@ def test_compile_function_with_dot(
|
||||
assert len(shape) == 1
|
||||
repeat = shape[0]
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function,
|
||||
params,
|
||||
data_gen_local(max_for_ij, repeat),
|
||||
@@ -1729,7 +1732,7 @@ def test_compile_with_random_inputset(default_compilation_configuration):
|
||||
configuration_to_use = deepcopy(default_compilation_configuration)
|
||||
configuration_to_use.enable_unsafe_features = True
|
||||
|
||||
compile_numpy_function_into_op_graph(
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
lambda x: x + 1,
|
||||
{"x": EncryptedScalar(UnsignedInteger(6))},
|
||||
inputset="random",
|
||||
@@ -1748,7 +1751,7 @@ def test_fail_compile_with_random_inputset(default_compilation_configuration):
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
try:
|
||||
compile_numpy_function_into_op_graph(
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
lambda x: x + 1,
|
||||
{"x": EncryptedScalar(UnsignedInteger(3))},
|
||||
inputset="unsupported",
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
|
||||
from concrete.common.data_types import UnsignedInteger
|
||||
from concrete.common.values import EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import compile_numpy_function_into_op_graph
|
||||
from concrete.numpy import compile_numpy_function_into_op_graph_and_measure_bounds
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -347,7 +347,7 @@ def test_constant_indexing(
|
||||
for _ in range(10)
|
||||
]
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function_with_indexing,
|
||||
{"x": input_value},
|
||||
inputset,
|
||||
@@ -470,7 +470,7 @@ def test_invalid_constant_indexing(
|
||||
)
|
||||
for _ in range(10)
|
||||
]
|
||||
compile_numpy_function_into_op_graph(
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function_with_indexing,
|
||||
{"x": input_value},
|
||||
inputset,
|
||||
@@ -525,7 +525,7 @@ def test_constant_indexing_with_numpy_integers(
|
||||
for _ in range(10)
|
||||
]
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
op_graph = compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function_with_indexing,
|
||||
{"x": input_value},
|
||||
inputset,
|
||||
@@ -586,7 +586,7 @@ def test_invalid_constant_indexing_with_numpy_values(
|
||||
)
|
||||
for _ in range(10)
|
||||
]
|
||||
compile_numpy_function_into_op_graph(
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function_with_indexing,
|
||||
{"x": input_value},
|
||||
inputset,
|
||||
|
||||
@@ -42,7 +42,7 @@ def test_generate_deduplicated_tables(
|
||||
function, expected_number_of_tables, default_compilation_configuration
|
||||
):
|
||||
"""Test function for generate_deduplicated_tables"""
|
||||
op_graph = hnp.compile_numpy_function_into_op_graph(
|
||||
op_graph = hnp.compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
function,
|
||||
{"x": hnp.EncryptedTensor(hnp.Integer(7, False), RESNET_BIGGEST_SHAPE)},
|
||||
((i * numpy.ones(RESNET_BIGGEST_SHAPE, dtype=numpy.int32),) for i in range(128)),
|
||||
@@ -69,7 +69,7 @@ def test_deduplicated_tables_correctness(default_compilation_configuration):
|
||||
|
||||
tensor_shape = (2, 2)
|
||||
|
||||
op_graph = hnp.compile_numpy_function_into_op_graph(
|
||||
op_graph = hnp.compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
lambda x: multi_tlu_func(x, numpy.arange(4, dtype=numpy.float64).reshape(tensor_shape)),
|
||||
{"x": hnp.EncryptedTensor(hnp.Integer(2, False), tensor_shape)},
|
||||
((i * numpy.ones(tensor_shape, dtype=numpy.int32),) for i in range(4)),
|
||||
|
||||
Reference in New Issue
Block a user