feat: create torch-like APIs part 1
- work on generating OPGraph with a torch-like API refs #233
@@ -1,7 +1,8 @@
 """Code to evaluate the IR graph on inputsets."""

 import sys
-from typing import Any, Callable, Dict, Iterable, Tuple, Union
+from functools import partial
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

 from ..compilation import CompilationConfiguration
 from ..data_types.dtypes_helpers import (
@@ -110,6 +111,7 @@ def eval_op_graph_bounds_on_inputset(
     get_base_value_for_constant_data_func: Callable[
         [Any], Any
     ] = get_base_value_for_python_constant_data,
+    prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None,
 ) -> Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]:
     """Evaluate the bounds with an inputset.

@@ -132,6 +134,8 @@ def eval_op_graph_bounds_on_inputset(
             tensors). Defaults to max.
         get_base_value_for_constant_data_func (Callable[[Any], Any], optional): custom function
             to compute the base value of a python object.
+        prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional):
+            Bounds and samples from a previous run. Defaults to None.

     Returns:
         Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: number of inputs in the inputset and
@@ -187,13 +191,38 @@ def eval_op_graph_bounds_on_inputset(

     first_output = op_graph.evaluate(current_input_data)

+    prev_node_bounds_and_samples = (
+        {} if prev_node_bounds_and_samples is None else prev_node_bounds_and_samples
+    )
+
+    def get_previous_value_for_key_or_default_for_dict(
+        dict_: Dict[IntermediateNode, Dict[str, Any]],
+        node: IntermediateNode,
+        key: str,
+        default: Any,
+    ) -> Any:
+        return_value = default
+
+        previous_value_dict = dict_.get(node, None)
+
+        if previous_value_dict is not None:
+            return_value = previous_value_dict.get(key, default)
+
+        return return_value
+
+    get_previous_value_for_key_or_default = partial(
+        get_previous_value_for_key_or_default_for_dict, prev_node_bounds_and_samples
+    )
+
     # We evaluate the min and max func to be able to resolve the tensors min and max rather than
     # having the tensor itself as the stored min and max values.
+    # As we don't know the integrity of prev_node_bounds_and_samples we make sure we can
+    # populate the new node_bounds_and_samples
     node_bounds_and_samples = {
         node: {
-            "min": min_func(value, value),
-            "max": max_func(value, value),
-            "sample": value,
+            "min": min_func(value, get_previous_value_for_key_or_default(node, "min", value)),
+            "max": max_func(value, get_previous_value_for_key_or_default(node, "max", value)),
+            "sample": get_previous_value_for_key_or_default(node, "sample", value),
         }
         for node, value in first_output.items()
     }
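To make the merge semantics above concrete: a value from the current evaluation can only widen previously recorded bounds, and a previously stored sample wins over the fresh one. A minimal self-contained sketch (plain Python with hypothetical node names and values, not the library's types):

prev = {"node_a": {"min": -3, "max": 5, "sample": 2}}

def get_prev(dict_, node, key, default):
    # Mirror of get_previous_value_for_key_or_default_for_dict above
    previous = dict_.get(node, None)
    return default if previous is None else previous.get(key, default)

first_output = {"node_a": 7, "node_b": 1}
merged = {
    node: {
        "min": min(value, get_prev(prev, node, "min", value)),
        "max": max(value, get_prev(prev, node, "max", value)),
        "sample": get_prev(prev, node, "sample", value),
    }
    for node, value in first_output.items()
}
assert merged["node_a"] == {"min": -3, "max": 7, "sample": 2}
assert merged["node_b"] == {"min": 1, "max": 1, "sample": 1}  # no previous data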
@@ -26,3 +26,6 @@ class CompilationConfiguration:
         self.treat_warnings_as_errors = treat_warnings_as_errors
         self.enable_unsafe_features = enable_unsafe_features
         self.random_inputset_samples = random_inputset_samples
+
+    def __eq__(self, other) -> bool:
+        return isinstance(other, CompilationConfiguration) and self.__dict__ == other.__dict__
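This value-based __eq__ is what lets code compare a defensive deepcopy of a configuration against the original, as the new NPFHECompiler tests do. A small sketch, assuming the no-argument constructor seen elsewhere in this diff:

from copy import deepcopy

config = CompilationConfiguration()
copied = deepcopy(config)
assert copied == config       # equal attribute by attribute via __dict__
assert copied is not config   # distinct objects, so identity comparison would fail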
@@ -15,4 +15,5 @@ from ..common.extensions.multi_table import MultiLookupTable
 from ..common.extensions.table import LookupTable
 from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue
 from .compile import compile_numpy_function, compile_numpy_function_into_op_graph_and_measure_bounds
+from .np_fhe_compiler import NPFHECompiler
 from .tracing import trace_numpy_function
@@ -22,7 +22,7 @@ from ..common.mlir.utils import (
 )
 from ..common.operator_graph import OPGraph
 from ..common.optimization.topological import fuse_float_operations
-from ..common.representation.intermediate import Add, Constant, GenericFunction
+from ..common.representation.intermediate import Add, Constant, GenericFunction, IntermediateNode
 from ..common.values import BaseValue, ClearScalar
 from ..numpy.tracing import trace_numpy_function
 from .np_dtypes_helpers import (
@@ -60,6 +60,102 @@ def numpy_min_func(lhs: Any, rhs: Any) -> Any:
     return numpy.minimum(lhs, rhs).min()


+def sanitize_compilation_configuration_and_artifacts(
+    compilation_configuration: Optional[CompilationConfiguration] = None,
+    compilation_artifacts: Optional[CompilationArtifacts] = None,
+) -> Tuple[CompilationConfiguration, CompilationArtifacts]:
+    """Return the proper compilation configuration and artifacts.
+
+    Default values are returned if None is passed for each argument.
+
+    Args:
+        compilation_configuration (Optional[CompilationConfiguration], optional): the compilation
+            configuration to sanitize. Defaults to None.
+        compilation_artifacts (Optional[CompilationArtifacts], optional): the compilation artifacts
+            to sanitize. Defaults to None.
+
+    Returns:
+        Tuple[CompilationConfiguration, CompilationArtifacts]: the tuple of sanitized configuration
+            and artifacts.
+    """
+    # Create default configuration if custom configuration is not specified
+    compilation_configuration = (
+        CompilationConfiguration()
+        if compilation_configuration is None
+        else compilation_configuration
+    )
+
+    # Create temporary artifacts if custom artifacts is not specified (in case of exceptions)
+    if compilation_artifacts is None:
+        compilation_artifacts = CompilationArtifacts()
+
+    return compilation_configuration, compilation_artifacts
+
+
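A quick behavior sketch of the helper above: None arguments become fresh defaults, while caller-provided objects pass through untouched (no copy is made here):

config, artifacts = sanitize_compilation_configuration_and_artifacts(None, None)
# config is a fresh CompilationConfiguration, artifacts a fresh CompilationArtifacts

custom_config = CompilationConfiguration()
same_config, _ = sanitize_compilation_configuration_and_artifacts(custom_config)
assert same_config is custom_config  # passed through, not copied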
+def get_inputset_to_use(
+    function_parameters: Dict[str, BaseValue],
+    inputset: Union[Iterable[Tuple[Any, ...]], str],
+    compilation_configuration: CompilationConfiguration,
+) -> Iterable[Tuple[Any, ...]]:
+    """Get the proper inputset to use for compilation.
+
+    Args:
+        function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
+            function is, e.g. an EncryptedScalar holding a 7-bit unsigned Integer
+        inputset (Union[Iterable[Tuple[Any, ...]], str]): The inputset over which op_graph is
+            evaluated. It needs to be an iterable of tuples of the same length as the number of
+            parameters in the function, and in the same order as those parameters
+        compilation_configuration (CompilationConfiguration): Configuration object to use during
+            compilation
+
+    Returns:
+        Iterable[Tuple[Any, ...]]: the inputset to use.
+    """
+    # Generate random inputset if it is requested and available
+    if isinstance(inputset, str):
+        _check_special_inputset_availability(inputset, compilation_configuration)
+        return _generate_random_inputset(function_parameters, compilation_configuration)
+    return inputset
+
+
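The dispatch here is deliberately thin; the sketch below captures it with stand-in callables for the private _check_special_inputset_availability and _generate_random_inputset helpers (which this diff does not define):

def get_inputset_to_use_sketch(inputset, check_available, generate):
    # A string such as "random" is a request for a generated inputset;
    # anything else is assumed to already be an iterable of tuples.
    if isinstance(inputset, str):
        check_available(inputset)
        return generate()
    return inputset

assert get_inputset_to_use_sketch([(1,), (2,)], lambda s: None, list) == [(1,), (2,)]
assert get_inputset_to_use_sketch("random", lambda s: None, lambda: [(0,)]) == [(0,)]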
+def run_compilation_function_with_error_management(
+    compilation_function: Callable,
+    compilation_configuration: CompilationConfiguration,
+    compilation_artifacts: CompilationArtifacts,
+) -> Any:
+    """Call compilation_function() and manage exceptions that may occur.
+
+    Args:
+        compilation_function (Callable): the compilation function to call.
+        compilation_configuration (CompilationConfiguration): the current compilation configuration.
+        compilation_artifacts (CompilationArtifacts): the current compilation artifacts.
+
+    Returns:
+        Any: returns the result of the call to compilation_function
+    """
+
+    # Try to compile the function and save partial artifacts on failure
+    try:
+        # Use context manager to restore numpy error handling
+        with numpy.errstate(**numpy.geterr()):
+            return compilation_function()
+    except Exception:  # pragma: no cover
+        # This branch is reserved for unexpected issues and hence it shouldn't be tested.
+        # If it could be tested, we would have fixed the underlying issue.

+        # We need to export all the information we have about the compilation
+        # If the user wants them to be exported
+
+        if compilation_configuration.dump_artifacts_on_unexpected_failures:
+            compilation_artifacts.export()
+
+            traceback_path = compilation_artifacts.output_directory.joinpath("traceback.txt")
+            with open(traceback_path, "w", encoding="utf-8") as f:
+                f.write(traceback.format_exc())
+
+        raise
+
+
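Stripped of the concrete-specific types, the function above is a generic "dump diagnostics, then re-raise" wrapper. A sketch of the pattern with hypothetical callables:

import traceback

def run_with_error_management(fn, dump_on_failure, export_artifacts):
    """Run fn(); on unexpected failure optionally export what we have, then re-raise."""
    try:
        return fn()
    except Exception:
        if dump_on_failure:
            export_artifacts(traceback.format_exc())  # e.g. write a traceback.txt
        raise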
 def _compile_numpy_function_into_op_graph_internal(
     function_to_compile: Callable,
     function_parameters: Dict[str, BaseValue],
@@ -122,13 +218,60 @@ def _compile_numpy_function_into_op_graph_internal(
     return op_graph


+def compile_numpy_function_into_op_graph(
+    function_to_compile: Callable,
+    function_parameters: Dict[str, BaseValue],
+    compilation_configuration: Optional[CompilationConfiguration] = None,
+    compilation_artifacts: Optional[CompilationArtifacts] = None,
+) -> OPGraph:
+    """Compile a function into an OPGraph.
+
+    Args:
+        function_to_compile (Callable): The function to compile
+        function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
+            function is, e.g. an EncryptedScalar holding a 7-bit unsigned Integer
+        compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
+            during compilation
+        compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
+            during compilation
+
+    Returns:
+        OPGraph: the compiled function as an operation graph
+    """
+
+    (
+        compilation_configuration,
+        compilation_artifacts,
+    ) = sanitize_compilation_configuration_and_artifacts(
+        compilation_configuration, compilation_artifacts
+    )
+
+    def compilation_function():
+        return _compile_numpy_function_into_op_graph_internal(
+            function_to_compile,
+            function_parameters,
+            compilation_configuration,
+            compilation_artifacts,
+        )
+
+    result = run_compilation_function_with_error_management(
+        compilation_function, compilation_configuration, compilation_artifacts
+    )
+
+    # for mypy
+    assert isinstance(result, OPGraph)
+    return result
+
+
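A usage sketch assembled from other pieces of this diff: NPFHECompiler infers each parameter's BaseValue from a sample via get_base_value_for_numpy_or_python_constant_data, and the same trick works when calling this function directly (the lambda and sample value are illustrative):

sample_x = 5  # any representative sample for parameter "x"
x_value = get_base_value_for_numpy_or_python_constant_data(sample_x)(is_encrypted=True)

op_graph = compile_numpy_function_into_op_graph(
    lambda x: x + 42,
    {"x": x_value},
)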
 def _measure_op_graph_bounds_and_update_internal(
     op_graph: OPGraph,
     function_parameters: Dict[str, BaseValue],
     inputset: Iterable[Tuple[Any, ...]],
     compilation_configuration: CompilationConfiguration,
     compilation_artifacts: CompilationArtifacts,
-) -> None:
+    prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None,
+    warn_on_inputset_length: bool = True,
+) -> Dict[IntermediateNode, Dict[str, Any]]:
     """Measure the intermediate values and update the OPGraph accordingly for the given inputset.

     Args:
@@ -142,11 +285,21 @@ def _measure_op_graph_bounds_and_update_internal(
             during compilation
         compilation_artifacts (CompilationArtifacts): Artifacts object to fill
            during compilation
+        prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional):
+            Bounds and samples from a previous run. Defaults to None.
+        warn_on_inputset_length (bool, optional): Set to True to get a warning if the inputset is
+            not long enough. Defaults to True.
+
+    Raises:
+        ValueError: Raises an error if the inputset is too small and the compilation configuration
+            treats warnings as errors.
+
+    Returns:
+        Dict[IntermediateNode, Dict[str, Any]]: a dict containing the bounds for each node from
+            op_graph, stored with the node as key and a dict with keys "min", "max" and "sample" as
+            value.
     """

     # Find bounds with the inputset
     inputset_size, node_bounds_and_samples = eval_op_graph_bounds_on_inputset(
         op_graph,
@@ -155,35 +308,37 @@ def _measure_op_graph_bounds_and_update_internal(
         min_func=numpy_min_func,
         max_func=numpy_max_func,
         get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
+        prev_node_bounds_and_samples=prev_node_bounds_and_samples,
     )

-    # Check inputset size
-    inputset_size_upper_limit = 1
+    if warn_on_inputset_length:
+        # Check inputset size
+        inputset_size_upper_limit = 1

-    # this loop will determine the number of possible inputs of the function
-    # if a function have a single 3-bit input, for example, `inputset_size_upper_limit` will be 8
-    for parameter_value in function_parameters.values():
-        if isinstance(parameter_value.dtype, Integer):
-            # multiple parameter bit-widths are multiplied as they can be combined into an input
-            inputset_size_upper_limit *= 2 ** parameter_value.dtype.bit_width
+        # this loop will determine the number of possible inputs of the function
+        # if a function has a single 3-bit input, for example, inputset_size_upper_limit will be 8
+        for parameter_value in function_parameters.values():
+            if isinstance(parameter_value.dtype, Integer):
+                # multiple parameter bit-widths are multiplied as they can be combined into an input
+                inputset_size_upper_limit *= 2 ** parameter_value.dtype.bit_width

-            # if the upper limit of the inputset size goes above 10,
-            # break the loop as we will require at least 10 inputs in this case
-            if inputset_size_upper_limit > 10:
-                break
+                # if the upper limit of the inputset size goes above 10,
+                # break the loop as we will require at least 10 inputs in this case
+                if inputset_size_upper_limit > 10:
+                    break

-    minimum_required_inputset_size = min(inputset_size_upper_limit, 10)
-    if inputset_size < minimum_required_inputset_size:
-        message = (
-            f"Provided inputset contains too few inputs "
-            f"(it should have had at least {minimum_required_inputset_size} "
-            f"but it only had {inputset_size})\n"
-        )
+        minimum_required_inputset_size = min(inputset_size_upper_limit, 10)
+        if inputset_size < minimum_required_inputset_size:
+            message = (
+                f"Provided inputset contains too few inputs "
+                f"(it should have had at least {minimum_required_inputset_size} "
+                f"but it only had {inputset_size})\n"
+            )

-        if compilation_configuration.treat_warnings_as_errors:
-            raise ValueError(message)
+            if compilation_configuration.treat_warnings_as_errors:
+                raise ValueError(message)

-        sys.stderr.write(f"Warning: {message}")
+            sys.stderr.write(f"Warning: {message}")

     # Add the bounds as an artifact
     compilation_artifacts.add_final_operation_graph_bounds(node_bounds_and_samples)
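The size check reduces to a small arithmetic rule: multiply 2 ** bit_width over the integer parameters, stop once the product exceeds 10, and require min(product, 10) inputs. In isolation:

def minimum_required_inputset_size(bit_widths):
    upper_limit = 1
    for bit_width in bit_widths:
        upper_limit *= 2 ** bit_width
        if upper_limit > 10:
            break
    return min(upper_limit, 10)

assert minimum_required_inputset_size([3]) == 8       # a single 3-bit input: 8 possible values
assert minimum_required_inputset_size([3, 3]) == 10   # 8 * 8 = 64 possibilities, capped at 10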
@@ -195,6 +350,74 @@ def _measure_op_graph_bounds_and_update_internal(
         get_constructor_for_numpy_or_python_constant_data,
     )

+    return node_bounds_and_samples
+
+
+def measure_op_graph_bounds_and_update(
+    op_graph: OPGraph,
+    function_parameters: Dict[str, BaseValue],
+    inputset: Union[Iterable[Tuple[Any, ...]], str],
+    compilation_configuration: Optional[CompilationConfiguration] = None,
+    compilation_artifacts: Optional[CompilationArtifacts] = None,
+    prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None,
+    warn_on_inputset_length: bool = True,
+) -> Dict[IntermediateNode, Dict[str, Any]]:
+    """Measure the intermediate values and update the OPGraph accordingly for the given inputset.
+
+    Args:
+        op_graph (OPGraph): the OPGraph for which to measure bounds and update node values.
+        function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
+            function is, e.g. an EncryptedScalar holding a 7-bit unsigned Integer
+        inputset (Union[Iterable[Tuple[Any, ...]], str]): The inputset over which op_graph is
+            evaluated. It needs to be an iterable of tuples of the same length as the number of
+            parameters in the function, and in the same order as those parameters
+        compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
+            during compilation
+        compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
+            during compilation
+        prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional):
+            Bounds and samples from a previous run. Defaults to None.
+        warn_on_inputset_length (bool, optional): Set to True to get a warning if the inputset is
+            not long enough. Defaults to True.
+
+    Raises:
+        ValueError: Raises an error if the inputset is too small and the compilation configuration
+            treats warnings as errors.
+
+    Returns:
+        Dict[IntermediateNode, Dict[str, Any]]: a dict containing the bounds for each node from
+            op_graph, stored with the node as key and a dict with keys "min", "max" and "sample" as
+            value.
+    """
+
+    (
+        compilation_configuration,
+        compilation_artifacts,
+    ) = sanitize_compilation_configuration_and_artifacts(
+        compilation_configuration, compilation_artifacts
+    )
+
+    inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration)
+
+    def compilation_function():
+        return _measure_op_graph_bounds_and_update_internal(
+            op_graph,
+            function_parameters,
+            inputset,
+            compilation_configuration,
+            compilation_artifacts,
+            prev_node_bounds_and_samples,
+            warn_on_inputset_length,
+        )
+
+    result = run_compilation_function_with_error_management(
+        compilation_function, compilation_configuration, compilation_artifacts
+    )
+
+    # for mypy
+    assert isinstance(result, dict)
+    return result
+
+
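Feeding the returned dict back through prev_node_bounds_and_samples allows measuring in increments, which is exactly how NPFHECompiler uses this function below. A call-pattern sketch (chunks stands for any hypothetical split of a larger inputset):

bounds = None
for chunk in chunks:
    bounds = measure_op_graph_bounds_and_update(
        op_graph,
        function_parameters,
        chunk,
        prev_node_bounds_and_samples=bounds,
        warn_on_inputset_length=False,  # each chunk is deliberately short
    )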
 def _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
     function_to_compile: Callable,
@@ -269,48 +492,31 @@ def compile_numpy_function_into_op_graph_and_measure_bounds(
         OPGraph: the compiled function as an operation graph
     """

-    # Create default configuration if custom configuration is not specified
-    compilation_configuration = (
-        CompilationConfiguration()
-        if compilation_configuration is None
-        else compilation_configuration
+    (
+        compilation_configuration,
+        compilation_artifacts,
+    ) = sanitize_compilation_configuration_and_artifacts(
+        compilation_configuration, compilation_artifacts
     )

-    # Create temporary artifacts if custom artifacts is not specified (in case of exceptions)
-    if compilation_artifacts is None:
-        compilation_artifacts = CompilationArtifacts()
+    inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration)

-    # Generate random inputset if it is requested and available
-    if isinstance(inputset, str):
-        _check_special_inputset_availability(inputset, compilation_configuration)
-        inputset = _generate_random_inputset(function_parameters, compilation_configuration)
+    def compilation_function():
+        return _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
+            function_to_compile,
+            function_parameters,
+            inputset,
+            compilation_configuration,
+            compilation_artifacts,
+        )

-    # Try to compile the function and save partial artifacts on failure
-    try:
-        # Use context manager to restore numpy error handling
-        with numpy.errstate(**numpy.geterr()):
-            return _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
-                function_to_compile,
-                function_parameters,
-                inputset,
-                compilation_configuration,
-                compilation_artifacts,
-            )
-    except Exception:  # pragma: no cover
-        # This branch is reserved for unexpected issues and hence it shouldn't be tested.
-        # If it could be tested, we would have fixed the underlying issue.
+    result = run_compilation_function_with_error_management(
+        compilation_function, compilation_configuration, compilation_artifacts
+    )

-        # We need to export all the information we have about the compilation
-        # If the user wants them to be exported
-
-        if compilation_configuration.dump_artifacts_on_unexpected_failures:
-            compilation_artifacts.export()
-
-            traceback_path = compilation_artifacts.output_directory.joinpath("traceback.txt")
-            with open(traceback_path, "w", encoding="utf-8") as f:
-                f.write(traceback.format_exc())
-
-        raise
+    # for mypy
+    assert isinstance(result, OPGraph)
+    return result


 # HACK
@@ -492,46 +698,29 @@ def compile_numpy_function(
         CompilerEngine: engine to run and debug the compiled graph
     """

-    # Create default configuration if custom configuration is not specified
-    compilation_configuration = (
-        CompilationConfiguration()
-        if compilation_configuration is None
-        else compilation_configuration
+    (
+        compilation_configuration,
+        compilation_artifacts,
+    ) = sanitize_compilation_configuration_and_artifacts(
+        compilation_configuration, compilation_artifacts
     )

-    # Create temporary artifacts if custom artifacts is not specified (in case of exceptions)
-    if compilation_artifacts is None:
-        compilation_artifacts = CompilationArtifacts()
+    inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration)

-    # Generate random inputset if it is requested and available
-    if isinstance(inputset, str):
-        _check_special_inputset_availability(inputset, compilation_configuration)
-        inputset = _generate_random_inputset(function_parameters, compilation_configuration)
+    def compilation_function():
+        return _compile_numpy_function_internal(
+            function_to_compile,
+            function_parameters,
+            inputset,
+            compilation_configuration,
+            compilation_artifacts,
+            show_mlir,
+        )

-    # Try to compile the function and save partial artifacts on failure
-    try:
-        # Use context manager to restore numpy error handling
-        with numpy.errstate(**numpy.geterr()):
-            return _compile_numpy_function_internal(
-                function_to_compile,
-                function_parameters,
-                inputset,
-                compilation_configuration,
-                compilation_artifacts,
-                show_mlir,
-            )
-    except Exception:  # pragma: no cover
-        # This branch is reserved for unexpected issues and hence it shouldn't be tested.
-        # If it could be tested, we would have fixed the underlying issue.
+    result = run_compilation_function_with_error_management(
+        compilation_function, compilation_configuration, compilation_artifacts
+    )

-        # We need to export all the information we have about the compilation
-        # If the user wants them to be exported
-
-        if compilation_configuration.dump_artifacts_on_unexpected_failures:
-            compilation_artifacts.export()
-
-            traceback_path = compilation_artifacts.output_directory.joinpath("traceback.txt")
-            with open(traceback_path, "w", encoding="utf-8") as f:
-                f.write(traceback.format_exc())
-
-        raise
+    # for mypy
+    assert isinstance(result, FHECircuit)
+    return result
concrete/numpy/np_fhe_compiler.py (new file, 199 lines)
@@ -0,0 +1,199 @@
"""Module to hold a user friendly class to compile programs."""
|
||||
|
||||
from copy import deepcopy
|
||||
from enum import Enum, unique
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.data_types import Integer
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation.intermediate import IntermediateNode
|
||||
from ..common.values import BaseValue
|
||||
from .compile import compile_numpy_function_into_op_graph, measure_op_graph_bounds_and_update
|
||||
from .np_dtypes_helpers import get_base_value_for_numpy_or_python_constant_data
|
||||
|
||||
|
||||
@unique
|
||||
class EncryptedStatus(str, Enum):
|
||||
"""Enum to validate GenericFunction op_kind."""
|
||||
|
||||
CLEAR = "clear"
|
||||
ENCRYPTED = "encrypted"
|
||||
|
||||
|
||||
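Deriving from both str and Enum means user-supplied strings are validated simply by construction, which __init__ below relies on:

assert EncryptedStatus("encrypted") is EncryptedStatus.ENCRYPTED
assert EncryptedStatus("CLEAR".lower()) is EncryptedStatus.CLEAR
# EncryptedStatus("banana") would raise ValueError, rejecting invalid statuses early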
class NPFHECompiler:
    """Class to ease the conversion of a numpy program to its FHE equivalent."""

    INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE: int = 128

    # _function_to_compile is not optional but mypy has a long-standing bug and is not able to
    # understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
    _function_to_compile: Optional[Callable]
    _function_parameters_encrypted_status: Dict[str, bool]
    _current_inputset: List[Union[Any, Tuple]]
    _op_graph: Optional[OPGraph]
    _nodes_and_bounds: Dict[IntermediateNode, Dict[str, Any]]

    _compilation_configuration: CompilationConfiguration

    compilation_artifacts: CompilationArtifacts

    def __init__(
        self,
        function_to_compile: Callable,
        function_parameters_encrypted_status: Dict[str, Union[str, EncryptedStatus]],
        compilation_configuration: Optional[CompilationConfiguration] = None,
        compilation_artifacts: Optional[CompilationArtifacts] = None,
    ) -> None:
        self._function_to_compile = function_to_compile
        self._function_parameters_encrypted_status = {
            param_name: EncryptedStatus(status.lower()) == EncryptedStatus.ENCRYPTED
            for param_name, status in function_parameters_encrypted_status.items()
        }

        self._current_inputset = []
        self._op_graph = None
        self._nodes_and_bounds = {}

        self._compilation_configuration = (
            deepcopy(compilation_configuration)
            if compilation_configuration is not None
            else CompilationConfiguration()
        )
        self.compilation_artifacts = (
            compilation_artifacts if compilation_artifacts is not None else CompilationArtifacts()
        )

    @property
    def function_to_compile(self) -> Callable:
        """Get the function to compile.

        Returns:
            Callable: the function to compile.
        """
        # Continuation of mypy bug
        assert self._function_to_compile is not None
        return self._function_to_compile

    @property
    def op_graph(self) -> Optional[OPGraph]:
        """Return a copy of the OPGraph.

        Returns:
            Optional[OPGraph]: the held OPGraph or None
        """
        # To keep consistency with what the user expects, we make sure to evaluate on the remaining
        # inputset values if any before giving a copy of the OPGraph we trace
        self._eval_on_current_inputset()
        return deepcopy(self._op_graph)

    @property
    def compilation_configuration(self) -> Optional[CompilationConfiguration]:
        """Get a copy of the compilation configuration.

        Returns:
            Optional[CompilationConfiguration]: copy of the current compilation configuration.
        """
        return deepcopy(self._compilation_configuration)

    def __call__(self, *args: Any) -> Any:
        """Evaluate the OPGraph corresponding to the function being compiled and return result.

        Returns:
            Any: the result of the OPGraph evaluation.
        """
        self._current_inputset.append(deepcopy(args))

        inferred_args = {
            param_name: get_base_value_for_numpy_or_python_constant_data(val)(
                is_encrypted=is_encrypted
            )
            for (param_name, is_encrypted), val in zip(
                self._function_parameters_encrypted_status.items(), args
            )
        }

        if len(self._current_inputset) >= self.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE:
            self._eval_on_current_inputset()

        self._trace_op_graph_if_needed(inferred_args)

        # For mypy
        assert self._op_graph is not None
        return self._op_graph(*args)

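Putting __call__ and the auto-flush constant together: the compiler behaves like the wrapped function while quietly recording an inputset, re-measuring bounds every 128 calls. A usage sketch mirroring the tests added at the end of this commit:

compiler = NPFHECompiler(lambda x: x + 1, {"x": "encrypted"})

for i in range(10):
    assert compiler(i) == i + 1   # evaluated eagerly; i is recorded in the inputset

op_graph = compiler.op_graph      # flushes pending inputs and returns a deepcopy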
    def eval_on_inputset(self, inputset: Iterable[Union[Any, Tuple]]) -> None:
        """Evaluate the underlying function on an inputset in one go, populating OPGraph and bounds.

        Args:
            inputset (Iterable[Union[Any, Tuple]]): The inputset on which the function should be
                evaluated.
        """
        inputset_as_list = list(inputset)
        if len(inputset_as_list) == 0:
            return

        inferred_args = {
            param_name: get_base_value_for_numpy_or_python_constant_data(val)(
                is_encrypted=is_encrypted
            )
            for (param_name, is_encrypted), val in zip(
                self._function_parameters_encrypted_status.items(), inputset_as_list[0]
            )
        }

        self._trace_op_graph_if_needed(inferred_args)

        # For mypy
        assert self._op_graph is not None

        self._patch_op_graph_input_to_accept_any_integer_input()

        self._nodes_and_bounds = measure_op_graph_bounds_and_update(
            self._op_graph,
            inferred_args,
            inputset_as_list,
            self._compilation_configuration,
            self.compilation_artifacts,
            self._nodes_and_bounds,
            False,
        )

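eval_on_inputset is the bulk counterpart of __call__: trace once if needed, then measure bounds over the whole iterable in a single pass, merging with any bounds recorded so far:

compiler = NPFHECompiler(lambda x: x * 2, {"x": "encrypted"})
compiler.eval_on_inputset((i,) for i in range(64))  # one trace, one measurement pass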
    def _eval_on_current_inputset(self) -> None:
        """Evaluate the OPGraph on _current_inputset."""
        self.eval_on_inputset(self._current_inputset)
        self._current_inputset.clear()

    def _needs_tracing(self) -> bool:
        """Return whether we need to trace the function and populate the OPGraph."""
        return self._op_graph is None

    def _trace_op_graph_if_needed(self, inferred_args: Dict[str, BaseValue]) -> None:
        """Populate _op_graph with the OPGraph for _function_to_compile."""
        if not self._needs_tracing():
            return

        self._op_graph = compile_numpy_function_into_op_graph(
            self.function_to_compile,
            inferred_args,
            self._compilation_configuration,
            self.compilation_artifacts,
        )

    def _patch_op_graph_input_to_accept_any_integer_input(self) -> None:
        """Patch inputs as we don't know what data we expect."""

        # Can only do that if the OPGraph was created, hence the test.
        if self._needs_tracing():
            return

        # For mypy
        assert self._op_graph is not None

        # Cheat on Input nodes to avoid issues during inputset eval as we do not know in advance
        # what the final bit width for the inputs should be
        for node in self._op_graph.input_nodes.values():
            for input_ in node.inputs:
                if isinstance(dtype := (input_.dtype), Integer):
                    dtype.bit_width = 128
                    dtype.is_signed = True
tests/numpy/test_compile_user_friendly_api.py (new file, 215 lines)
@@ -0,0 +1,215 @@
"""Test file for user-friendly numpy compilation functions"""
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.debugging import format_operation_graph
|
||||
from concrete.numpy.np_fhe_compiler import NPFHECompiler
|
||||
|
||||
|
||||
def complicated_topology(x, y):
|
||||
"""Mix x in an intricated way."""
|
||||
intermediate = x + y
|
||||
x_p_1 = intermediate + 1
|
||||
x_p_2 = intermediate + 2
|
||||
x_p_3 = x_p_1 + x_p_2
|
||||
return (
|
||||
x_p_3.astype(numpy.int32),
|
||||
x_p_2.astype(numpy.int32),
|
||||
(x_p_2 + 3).astype(numpy.int32),
|
||||
x_p_3.astype(numpy.int32) + 67,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input_shape", [(), (3, 1, 2)])
def test_np_fhe_compiler(input_shape, default_compilation_configuration):
    """Test NPFHECompiler in two subtests."""
    subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configuration)
    subtest_np_fhe_compiler_2_inputs(input_shape, default_compilation_configuration)


def subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configuration):
    """Subtest for NPFHECompiler on a one-input function."""

    compiler = NPFHECompiler(
        lambda x: complicated_topology(x, 0),
        {"x": "encrypted"},
        default_compilation_configuration,
    )

    # For coverage when the OPGraph is not yet traced
    compiler._patch_op_graph_input_to_accept_any_integer_input()  # pylint: disable=protected-access

    assert compiler.compilation_configuration == default_compilation_configuration
    assert compiler.compilation_configuration is not default_compilation_configuration

    for i in numpy.arange(5):
        i = numpy.ones(input_shape, dtype=numpy.int64) * i
        assert numpy.array_equal(compiler(i), complicated_topology(i, 0))

    # For coverage, check that we flush the inputset when we query the OPGraph
    current_op_graph = compiler.op_graph
    assert current_op_graph is not compiler.op_graph
    assert len(compiler._current_inputset) == 0  # pylint: disable=protected-access
    # For coverage, cover the case where the current inputset is empty
    compiler._eval_on_current_inputset()  # pylint: disable=protected-access

    # Continue a bit more
    for i in numpy.arange(5, 10):
        i = numpy.ones(input_shape, dtype=numpy.int64) * i
        assert numpy.array_equal(compiler(i), complicated_topology(i, 0))

    if input_shape == ():
        assert (
            (got := format_operation_graph(compiler.op_graph))
            == """ %0 = 67                         # ClearScalar<uint7>
 %1 = 2                          # ClearScalar<uint2>
 %2 = 3                          # ClearScalar<uint2>
 %3 = 1                          # ClearScalar<uint1>
 %4 = x                          # EncryptedScalar<uint4>
 %5 = 0                          # ClearScalar<uint1>
 %6 = add(%4, %5)                # EncryptedScalar<uint4>
 %7 = add(%6, %1)                # EncryptedScalar<uint4>
 %8 = add(%6, %3)                # EncryptedScalar<uint4>
 %9 = astype(%7, dtype=int32)    # EncryptedScalar<uint4>
%10 = add(%7, %2)                # EncryptedScalar<uint4>
%11 = add(%8, %7)                # EncryptedScalar<uint5>
%12 = astype(%10, dtype=int32)   # EncryptedScalar<uint4>
%13 = astype(%11, dtype=int32)   # EncryptedScalar<uint5>
%14 = astype(%11, dtype=int32)   # EncryptedScalar<uint5>
%15 = add(%14, %0)               # EncryptedScalar<uint7>
(%13, %9, %12, %15)"""
        ), got
    else:
        assert (
            (got := format_operation_graph(compiler.op_graph))
            == """ %0 = 67                         # ClearScalar<uint7>
 %1 = 2                          # ClearScalar<uint2>
 %2 = 3                          # ClearScalar<uint2>
 %3 = 1                          # ClearScalar<uint1>
 %4 = x                          # EncryptedTensor<uint4, shape=(3, 1, 2)>
 %5 = 0                          # ClearScalar<uint1>
 %6 = add(%4, %5)                # EncryptedTensor<uint4, shape=(3, 1, 2)>
 %7 = add(%6, %1)                # EncryptedTensor<uint4, shape=(3, 1, 2)>
 %8 = add(%6, %3)                # EncryptedTensor<uint4, shape=(3, 1, 2)>
 %9 = astype(%7, dtype=int32)    # EncryptedTensor<uint4, shape=(3, 1, 2)>
%10 = add(%7, %2)                # EncryptedTensor<uint4, shape=(3, 1, 2)>
%11 = add(%8, %7)                # EncryptedTensor<uint5, shape=(3, 1, 2)>
%12 = astype(%10, dtype=int32)   # EncryptedTensor<uint4, shape=(3, 1, 2)>
%13 = astype(%11, dtype=int32)   # EncryptedTensor<uint5, shape=(3, 1, 2)>
%14 = astype(%11, dtype=int32)   # EncryptedTensor<uint5, shape=(3, 1, 2)>
%15 = add(%14, %0)               # EncryptedTensor<uint7, shape=(3, 1, 2)>
(%13, %9, %12, %15)"""
        ), got


def subtest_np_fhe_compiler_2_inputs(input_shape, default_compilation_configuration):
    """Subtest for NPFHECompiler on a two-input function."""

    compiler = NPFHECompiler(
        complicated_topology,
        {"x": "encrypted", "y": "clear"},
        default_compilation_configuration,
    )

    # For coverage when the OPGraph is not yet traced
    compiler._patch_op_graph_input_to_accept_any_integer_input()  # pylint: disable=protected-access

    assert compiler.compilation_configuration == default_compilation_configuration
    assert compiler.compilation_configuration is not default_compilation_configuration

    for i, j in zip(numpy.arange(5), numpy.arange(5, 10)):
        i = numpy.ones(input_shape, dtype=numpy.int64) * i
        j = numpy.ones(input_shape, dtype=numpy.int64) * j
        assert numpy.array_equal(compiler(i, j), complicated_topology(i, j))

    # For coverage, check that we flush the inputset when we query the OPGraph
    current_op_graph = compiler.op_graph
    assert current_op_graph is not compiler.op_graph
    assert len(compiler._current_inputset) == 0  # pylint: disable=protected-access
    # For coverage, cover the case where the current inputset is empty
    compiler._eval_on_current_inputset()  # pylint: disable=protected-access

    # Continue a bit more
    for i, j in zip(numpy.arange(5, 10), numpy.arange(5)):
        i = numpy.ones(input_shape, dtype=numpy.int64) * i
        j = numpy.ones(input_shape, dtype=numpy.int64) * j
        assert numpy.array_equal(compiler(i, j), complicated_topology(i, j))

    if input_shape == ():
        assert (
            (got := format_operation_graph(compiler.op_graph))
            == """ %0 = 67                         # ClearScalar<uint7>
 %1 = 2                          # ClearScalar<uint2>
 %2 = 3                          # ClearScalar<uint2>
 %3 = 1                          # ClearScalar<uint1>
 %4 = x                          # EncryptedScalar<uint4>
 %5 = y                          # ClearScalar<uint4>
 %6 = add(%4, %5)                # EncryptedScalar<uint4>
 %7 = add(%6, %1)                # EncryptedScalar<uint4>
 %8 = add(%6, %3)                # EncryptedScalar<uint4>
 %9 = astype(%7, dtype=int32)    # EncryptedScalar<uint4>
%10 = add(%7, %2)                # EncryptedScalar<uint5>
%11 = add(%8, %7)                # EncryptedScalar<uint5>
%12 = astype(%10, dtype=int32)   # EncryptedScalar<uint5>
%13 = astype(%11, dtype=int32)   # EncryptedScalar<uint5>
%14 = astype(%11, dtype=int32)   # EncryptedScalar<uint5>
%15 = add(%14, %0)               # EncryptedScalar<uint7>
(%13, %9, %12, %15)"""
        ), got
    else:
        assert (
            (got := format_operation_graph(compiler.op_graph))
            == """ %0 = 67                         # ClearScalar<uint7>
 %1 = 2                          # ClearScalar<uint2>
 %2 = 3                          # ClearScalar<uint2>
 %3 = 1                          # ClearScalar<uint1>
 %4 = x                          # EncryptedTensor<uint4, shape=(3, 1, 2)>
 %5 = y                          # ClearTensor<uint4, shape=(3, 1, 2)>
 %6 = add(%4, %5)                # EncryptedTensor<uint4, shape=(3, 1, 2)>
 %7 = add(%6, %1)                # EncryptedTensor<uint4, shape=(3, 1, 2)>
 %8 = add(%6, %3)                # EncryptedTensor<uint4, shape=(3, 1, 2)>
 %9 = astype(%7, dtype=int32)    # EncryptedTensor<uint4, shape=(3, 1, 2)>
%10 = add(%7, %2)                # EncryptedTensor<uint5, shape=(3, 1, 2)>
%11 = add(%8, %7)                # EncryptedTensor<uint5, shape=(3, 1, 2)>
%12 = astype(%10, dtype=int32)   # EncryptedTensor<uint5, shape=(3, 1, 2)>
%13 = astype(%11, dtype=int32)   # EncryptedTensor<uint5, shape=(3, 1, 2)>
%14 = astype(%11, dtype=int32)   # EncryptedTensor<uint5, shape=(3, 1, 2)>
%15 = add(%14, %0)               # EncryptedTensor<uint7, shape=(3, 1, 2)>
(%13, %9, %12, %15)"""
        ), got


def remaining_inputset_size(inputset_len):
    """Small helper to generate the test cases below for the remaining inputset length."""
    return inputset_len % NPFHECompiler.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE


@pytest.mark.parametrize(
    "inputset_len, expected_remaining_inputset_len",
    [
        (42, remaining_inputset_size(42)),
        (128, remaining_inputset_size(128)),
        (234, remaining_inputset_size(234)),
    ],
)
def test_np_fhe_compiler_auto_flush(
    inputset_len,
    expected_remaining_inputset_len,
    default_compilation_configuration,
):
    """Test the auto flush of NPFHECompiler once the inputset reaches 128 elements."""
    compiler = NPFHECompiler(
        lambda x: x // 2,
        {"x": "encrypted"},
        default_compilation_configuration,
    )

    for i in numpy.arange(inputset_len):
        assert numpy.array_equal(compiler(i), i // 2)

    # Check the inputset was properly flushed
    assert (
        len(compiler._current_inputset)  # pylint: disable=protected-access
        == expected_remaining_inputset_len
    )