feat: create torch-like APIs part 1

- work on generating OPGraph with a torch-like API

refs #233
Arthur Meyre
2021-11-18 13:35:45 +01:00
parent ac74e94e13
commit 13b9ff96f0
6 changed files with 739 additions and 103 deletions

View File

@@ -1,7 +1,8 @@
"""Code to evaluate the IR graph on inputsets."""
import sys
-from typing import Any, Callable, Dict, Iterable, Tuple, Union
+from functools import partial
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
from ..compilation import CompilationConfiguration
from ..data_types.dtypes_helpers import (
@@ -110,6 +111,7 @@ def eval_op_graph_bounds_on_inputset(
get_base_value_for_constant_data_func: Callable[
[Any], Any
] = get_base_value_for_python_constant_data,
prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None,
) -> Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]:
"""Evaluate the bounds with a inputset.
@@ -132,6 +134,8 @@ def eval_op_graph_bounds_on_inputset(
tensors). Defaults to max.
get_base_value_for_constant_data_func (Callable[[Any], Any], optional): custom function
to compute the base value of a python object.
prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional):
Bounds and samples from a previous run. Defaults to None.
Returns:
Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: number of inputs in the inputset and
@@ -187,13 +191,38 @@ def eval_op_graph_bounds_on_inputset(
first_output = op_graph.evaluate(current_input_data)
prev_node_bounds_and_samples = (
{} if prev_node_bounds_and_samples is None else prev_node_bounds_and_samples
)
def get_previous_value_for_key_or_default_for_dict(
dict_: Dict[IntermediateNode, Dict[str, Any]],
node: IntermediateNode,
key: str,
default: Any,
) -> Any:
return_value = default
previous_value_dict = dict_.get(node, None)
if previous_value_dict is not None:
return_value = previous_value_dict.get(key, default)
return return_value
get_previous_value_for_key_or_default = partial(
get_previous_value_for_key_or_default_for_dict, prev_node_bounds_and_samples
)
# We evaluate min_func and max_func so that a tensor's min and max are resolved to scalars
# rather than storing the tensor itself as the min and max values.
# As we don't know the integrity of prev_node_bounds_and_samples, we make sure we can
# populate the new node_bounds_and_samples either way
    node_bounds_and_samples = {
        node: {
-            "min": min_func(value, value),
-            "max": max_func(value, value),
-            "sample": value,
+            "min": min_func(value, get_previous_value_for_key_or_default(node, "min", value)),
+            "max": max_func(value, get_previous_value_for_key_or_default(node, "max", value)),
+            "sample": get_previous_value_for_key_or_default(node, "sample", value),
        }
        for node, value in first_output.items()
    }
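For intuition, here is a minimal sketch of the merging behaviour with plain Python min/max (the real code goes through min_func/max_func so that tensor bounds resolve to scalars):

prev = {"min": -3, "max": 5, "sample": 2}  # bounds recorded by a previous run
value = 7                                  # value observed for the node on this run
merged = {
    "min": min(value, prev.get("min", value)),  # -3: the previous minimum survives
    "max": max(value, prev.get("max", value)),  # 7: the new value widens the maximum
    "sample": prev.get("sample", value),        # 2: the previously stored sample is kept
}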

View File

@@ -26,3 +26,6 @@ class CompilationConfiguration:
self.treat_warnings_as_errors = treat_warnings_as_errors
self.enable_unsafe_features = enable_unsafe_features
self.random_inputset_samples = random_inputset_samples
def __eq__(self, other) -> bool:
return isinstance(other, CompilationConfiguration) and self.__dict__ == other.__dict__
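Since __eq__ now compares the instance dictionaries, two configurations with the same attribute values compare equal regardless of identity; a quick sketch:

a = CompilationConfiguration()
b = CompilationConfiguration()
assert a == b and a is not b  # equal attribute values, distinct objects
assert a != object()          # non-configurations never compare equal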

View File

@@ -15,4 +15,5 @@ from ..common.extensions.multi_table import MultiLookupTable
from ..common.extensions.table import LookupTable
from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue
from .compile import compile_numpy_function, compile_numpy_function_into_op_graph_and_measure_bounds
from .np_fhe_compiler import NPFHECompiler
from .tracing import trace_numpy_function

View File

@@ -22,7 +22,7 @@ from ..common.mlir.utils import (
)
from ..common.operator_graph import OPGraph
from ..common.optimization.topological import fuse_float_operations
from ..common.representation.intermediate import Add, Constant, GenericFunction
from ..common.representation.intermediate import Add, Constant, GenericFunction, IntermediateNode
from ..common.values import BaseValue, ClearScalar
from ..numpy.tracing import trace_numpy_function
from .np_dtypes_helpers import (
@@ -60,6 +60,102 @@ def numpy_min_func(lhs: Any, rhs: Any) -> Any:
return numpy.minimum(lhs, rhs).min()
def sanitize_compilation_configuration_and_artifacts(
compilation_configuration: Optional[CompilationConfiguration] = None,
compilation_artifacts: Optional[CompilationArtifacts] = None,
) -> Tuple[CompilationConfiguration, CompilationArtifacts]:
"""Return the proper compilation configuration and artifacts.
Default values are returned for any argument that is None.
Args:
compilation_configuration (Optional[CompilationConfiguration], optional): the compilation
configuration to sanitize. Defaults to None.
compilation_artifacts (Optional[CompilationArtifacts], optional): the compilation artifacts
to sanitize. Defaults to None.
Returns:
Tuple[CompilationConfiguration, CompilationArtifacts]: the tuple of sanitized configuration
and artifacts.
"""
# Create default configuration if custom configuration is not specified
compilation_configuration = (
CompilationConfiguration()
if compilation_configuration is None
else compilation_configuration
)
# Create temporary artifacts if custom artifacts are not specified (in case of exceptions)
if compilation_artifacts is None:
compilation_artifacts = CompilationArtifacts()
return compilation_configuration, compilation_artifacts
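A small sketch of the contract: None arguments are replaced with fresh defaults, while caller-provided objects pass through unchanged:

config, artifacts = sanitize_compilation_configuration_and_artifacts(None, None)
assert isinstance(config, CompilationConfiguration)
assert isinstance(artifacts, CompilationArtifacts)

my_config = CompilationConfiguration()
same_config, _ = sanitize_compilation_configuration_and_artifacts(my_config, None)
assert same_config is my_config  # existing objects are returned as-is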
def get_inputset_to_use(
function_parameters: Dict[str, BaseValue],
inputset: Union[Iterable[Tuple[Any, ...]], str],
compilation_configuration: CompilationConfiguration,
) -> Iterable[Tuple[Any, ...]]:
"""Get the proper inputset to use for compilation.
Args:
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
function is, e.g. an EncryptedScalar holding a 7-bit unsigned Integer
inputset (Union[Iterable[Tuple[Any, ...]], str]): The inputset over which op_graph is
evaluated. It needs to be an iterable of tuples of the same length as the number of
parameters of the function, and in the same order as those parameters
compilation_configuration (CompilationConfiguration): Configuration object to use during
compilation
Returns:
Iterable[Tuple[Any, ...]]: the inputset to use.
"""
# Generate random inputset if it is requested and available
if isinstance(inputset, str):
_check_special_inputset_availability(inputset, compilation_configuration)
return _generate_random_inputset(function_parameters, compilation_configuration)
return inputset
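A sketch of the pass-through path (the string path depends on what _check_special_inputset_availability accepts, so it is not shown here):

config = CompilationConfiguration()
explicit_inputset = [(0,), (1,), (2,)]
# anything that is not a str is returned unchanged
assert get_inputset_to_use({}, explicit_inputset, config) is explicit_inputset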
def run_compilation_function_with_error_management(
compilation_function: Callable,
compilation_configuration: CompilationConfiguration,
compilation_artifacts: CompilationArtifacts,
) -> Any:
"""Call compilation_function() and manage exceptions that may occur.
Args:
compilation_function (Callable): the compilation function to call.
compilation_configuration (CompilationConfiguration): the current compilation configuration.
compilation_artifacts (CompilationArtifacts): the current compilation artifacts.
Returns:
Any: returns the result of the call to compilation_function
"""
# Try to compile the function and save partial artifacts on failure
try:
# Use context manager to restore numpy error handling
with numpy.errstate(**numpy.geterr()):
return compilation_function()
except Exception: # pragma: no cover
# This branch is reserved for unexpected issues and hence it shouldn't be tested.
# If it could be tested, we would have fixed the underlying issue.
# We need to export all the information we have about the compilation,
# if the user wants it to be exported
if compilation_configuration.dump_artifacts_on_unexpected_failures:
compilation_artifacts.export()
traceback_path = compilation_artifacts.output_directory.joinpath("traceback.txt")
with open(traceback_path, "w", encoding="utf-8") as f:
f.write(traceback.format_exc())
raise
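Callers wrap their work in a zero-argument closure so this helper can run it under numpy.errstate and dump artifacts if something unexpected escapes; a minimal sketch:

config, artifacts = sanitize_compilation_configuration_and_artifacts(None, None)

def compilation_function():
    return 21 * 2  # stands in for the real compilation work

result = run_compilation_function_with_error_management(
    compilation_function, config, artifacts
)
assert result == 42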
def _compile_numpy_function_into_op_graph_internal(
function_to_compile: Callable,
function_parameters: Dict[str, BaseValue],
@@ -122,13 +218,60 @@ def _compile_numpy_function_into_op_graph_internal(
return op_graph
def compile_numpy_function_into_op_graph(
function_to_compile: Callable,
function_parameters: Dict[str, BaseValue],
compilation_configuration: Optional[CompilationConfiguration] = None,
compilation_artifacts: Optional[CompilationArtifacts] = None,
) -> OPGraph:
"""Compile a function into an OPGraph.
Args:
function_to_compile (Callable): The function to compile
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
function is, e.g. an EncryptedScalar holding a 7-bit unsigned Integer
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
during compilation
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
during compilation
Returns:
OPGraph: compiled function into a graph
"""
(
compilation_configuration,
compilation_artifacts,
) = sanitize_compilation_configuration_and_artifacts(
compilation_configuration, compilation_artifacts
)
def compilation_function():
return _compile_numpy_function_into_op_graph_internal(
function_to_compile,
function_parameters,
compilation_configuration,
compilation_artifacts,
)
result = run_compilation_function_with_error_management(
compilation_function, compilation_configuration, compilation_artifacts
)
# for mypy
assert isinstance(result, OPGraph)
return result
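Hypothetical usage, assuming Integer takes a bit width and an is_signed flag, as the docstrings in this diff suggest:

from concrete.common.data_types import Integer
from concrete.common.values import EncryptedScalar

def f(x):
    return x + 42

op_graph = compile_numpy_function_into_op_graph(
    f,
    {"x": EncryptedScalar(Integer(7, is_signed=False))},
)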
def _measure_op_graph_bounds_and_update_internal(
op_graph: OPGraph,
function_parameters: Dict[str, BaseValue],
inputset: Iterable[Tuple[Any, ...]],
compilation_configuration: CompilationConfiguration,
compilation_artifacts: CompilationArtifacts,
-) -> None:
+    prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None,
+    warn_on_inputset_length: bool = True,
+) -> Dict[IntermediateNode, Dict[str, Any]]:
"""Measure the intermediate values and update the OPGraph accordingly for the given inputset.
Args:
@@ -142,11 +285,21 @@ def _measure_op_graph_bounds_and_update_internal(
during compilation
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
during compilation
prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional):
Bounds and samples from a previous run. Defaults to None.
warn_on_inputset_length (bool, optional): Set to True to get a warning if inputset is not
long enough. Defaults to True.
Raises:
ValueError: Raises an error if the inputset is too small and the compilation configuration
treats warnings as errors.
Returns:
Dict[IntermediateNode, Dict[str, Any]]: a dict containing the bounds for each node from
op_graph, stored with the node as key and a dict with keys "min", "max" and "sample" as
value.
"""
# Find bounds with the inputset
inputset_size, node_bounds_and_samples = eval_op_graph_bounds_on_inputset(
op_graph,
@@ -155,35 +308,37 @@ def _measure_op_graph_bounds_and_update_internal(
min_func=numpy_min_func,
max_func=numpy_max_func,
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
prev_node_bounds_and_samples=prev_node_bounds_and_samples,
)
-    # Check inputset size
-    inputset_size_upper_limit = 1
+    if warn_on_inputset_length:
+        # Check inputset size
+        inputset_size_upper_limit = 1

-    # this loop will determine the number of possible inputs of the function
-    # if a function have a single 3-bit input, for example, inputset_size_upper_limit will be 8
-    for parameter_value in function_parameters.values():
-        if isinstance(parameter_value.dtype, Integer):
-            # multiple parameter bit-widths are multiplied as they can be combined into an input
-            inputset_size_upper_limit *= 2 ** parameter_value.dtype.bit_width
+        # this loop will determine the number of possible inputs of the function
+        # if a function has a single 3-bit input, for example, `inputset_size_upper_limit` will be 8
+        for parameter_value in function_parameters.values():
+            if isinstance(parameter_value.dtype, Integer):
+                # multiple parameter bit-widths are multiplied as they can be combined into an input
+                inputset_size_upper_limit *= 2 ** parameter_value.dtype.bit_width

-            # if the upper limit of the inputset size goes above 10,
-            # break the loop as we will require at least 10 inputs in this case
-            if inputset_size_upper_limit > 10:
-                break
+                # if the upper limit of the inputset size goes above 10,
+                # break the loop as we will require at least 10 inputs in this case
+                if inputset_size_upper_limit > 10:
+                    break

-    minimum_required_inputset_size = min(inputset_size_upper_limit, 10)
-    if inputset_size < minimum_required_inputset_size:
-        message = (
-            f"Provided inputset contains too few inputs "
-            f"(it should have had at least {minimum_required_inputset_size} "
-            f"but it only had {inputset_size})\n"
-        )
+        minimum_required_inputset_size = min(inputset_size_upper_limit, 10)
+        if inputset_size < minimum_required_inputset_size:
+            message = (
+                f"Provided inputset contains too few inputs "
+                f"(it should have had at least {minimum_required_inputset_size} "
+                f"but it only had {inputset_size})\n"
+            )

-        if compilation_configuration.treat_warnings_as_errors:
-            raise ValueError(message)
+            if compilation_configuration.treat_warnings_as_errors:
+                raise ValueError(message)

-        sys.stderr.write(f"Warning: {message}")
+            sys.stderr.write(f"Warning: {message}")
# Add the bounds as an artifact
compilation_artifacts.add_final_operation_graph_bounds(node_bounds_and_samples)
@@ -195,6 +350,74 @@ def _measure_op_graph_bounds_and_update_internal(
get_constructor_for_numpy_or_python_constant_data,
)
return node_bounds_and_samples
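A worked instance of the size check above, for a hypothetical function with a 3-bit and a 2-bit integer input:

inputset_size_upper_limit = 2 ** 3 * 2 ** 2  # 32 possible input combinations
minimum_required_inputset_size = min(inputset_size_upper_limit, 10)  # -> 10
# fewer than 10 inputset entries triggers the warning, or a ValueError when
# treat_warnings_as_errors is set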
def measure_op_graph_bounds_and_update(
op_graph: OPGraph,
function_parameters: Dict[str, BaseValue],
inputset: Union[Iterable[Tuple[Any, ...]], str],
compilation_configuration: Optional[CompilationConfiguration] = None,
compilation_artifacts: Optional[CompilationArtifacts] = None,
prev_node_bounds_and_samples: Optional[Dict[IntermediateNode, Dict[str, Any]]] = None,
warn_on_inputset_length: bool = True,
) -> Dict[IntermediateNode, Dict[str, Any]]:
"""Measure the intermediate values and update the OPGraph accordingly for the given inputset.
Args:
op_graph (OPGraph): the OPGraph for which to measure bounds and update node values.
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
function is, e.g. an EncryptedScalar holding a 7-bit unsigned Integer
inputset (Union[Iterable[Tuple[Any, ...]], str]): The inputset over which op_graph is
evaluated. It needs to be an iterable of tuples of the same length as the number of
parameters of the function, and in the same order as those parameters
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
during compilation
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
during compilation
prev_node_bounds_and_samples (Optional[Dict[IntermediateNode, Dict[str, Any]]], optional):
Bounds and samples from a previous run. Defaults to None.
warn_on_inputset_length (bool, optional): Set to True to get a warning if inputset is not
long enough. Defaults to True.
Raises:
ValueError: Raises an error if the inputset is too small and the compilation configuration
treats warnings as errors.
Returns:
Dict[IntermediateNode, Dict[str, Any]]: a dict containing the bounds for each node from
op_graph, stored with the node as key and a dict with keys "min", "max" and "sample" as
value.
"""
(
compilation_configuration,
compilation_artifacts,
) = sanitize_compilation_configuration_and_artifacts(
compilation_configuration, compilation_artifacts
)
inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration)
def compilation_function():
return _measure_op_graph_bounds_and_update_internal(
op_graph,
function_parameters,
inputset,
compilation_configuration,
compilation_artifacts,
prev_node_bounds_and_samples,
warn_on_inputset_length,
)
result = run_compilation_function_with_error_management(
compilation_function, compilation_configuration, compilation_artifacts
)
# for mypy
assert isinstance(result, dict)
return result
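A sketch of the incremental workflow this enables (op_graph, function_parameters and the two inputset chunks are placeholders); this is the mechanism NPFHECompiler builds on below:

bounds = measure_op_graph_bounds_and_update(
    op_graph,
    function_parameters,
    first_inputset_chunk,
)

# later, refine with more data without losing what was already observed
bounds = measure_op_graph_bounds_and_update(
    op_graph,
    function_parameters,
    second_inputset_chunk,
    prev_node_bounds_and_samples=bounds,
    warn_on_inputset_length=False,
)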
def _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
function_to_compile: Callable,
@@ -269,48 +492,31 @@ def compile_numpy_function_into_op_graph_and_measure_bounds(
OPGraph: compiled function into a graph
"""
-    # Create default configuration if custom configuration is not specified
-    compilation_configuration = (
-        CompilationConfiguration()
-        if compilation_configuration is None
-        else compilation_configuration
+    (
+        compilation_configuration,
+        compilation_artifacts,
+    ) = sanitize_compilation_configuration_and_artifacts(
+        compilation_configuration, compilation_artifacts
    )

-    # Create temporary artifacts if custom artifacts is not specified (in case of exceptions)
-    if compilation_artifacts is None:
-        compilation_artifacts = CompilationArtifacts()
+    inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration)

-    # Generate random inputset if it is requested and available
-    if isinstance(inputset, str):
-        _check_special_inputset_availability(inputset, compilation_configuration)
-        inputset = _generate_random_inputset(function_parameters, compilation_configuration)
+    def compilation_function():
+        return _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
+            function_to_compile,
+            function_parameters,
+            inputset,
+            compilation_configuration,
+            compilation_artifacts,
+        )

-    # Try to compile the function and save partial artifacts on failure
-    try:
-        # Use context manager to restore numpy error handling
-        with numpy.errstate(**numpy.geterr()):
-            return _compile_numpy_function_into_op_graph_and_measure_bounds_internal(
-                function_to_compile,
-                function_parameters,
-                inputset,
-                compilation_configuration,
-                compilation_artifacts,
-            )
-    except Exception:  # pragma: no cover
-        # This branch is reserved for unexpected issues and hence it shouldn't be tested.
-        # If it could be tested, we would have fixed the underlying issue.
+    result = run_compilation_function_with_error_management(
+        compilation_function, compilation_configuration, compilation_artifacts
+    )

-        # We need to export all the information we have about the compilation
-        # If the user wants them to be exported
-        if compilation_configuration.dump_artifacts_on_unexpected_failures:
-            compilation_artifacts.export()
-            traceback_path = compilation_artifacts.output_directory.joinpath("traceback.txt")
-            with open(traceback_path, "w", encoding="utf-8") as f:
-                f.write(traceback.format_exc())
-        raise
+    # for mypy
+    assert isinstance(result, OPGraph)
+    return result
# HACK
@@ -492,46 +698,29 @@ def compile_numpy_function(
CompilerEngine: engine to run and debug the compiled graph
"""
-    # Create default configuration if custom configuration is not specified
-    compilation_configuration = (
-        CompilationConfiguration()
-        if compilation_configuration is None
-        else compilation_configuration
+    (
+        compilation_configuration,
+        compilation_artifacts,
+    ) = sanitize_compilation_configuration_and_artifacts(
+        compilation_configuration, compilation_artifacts
    )

-    # Create temporary artifacts if custom artifacts is not specified (in case of exceptions)
-    if compilation_artifacts is None:
-        compilation_artifacts = CompilationArtifacts()
+    inputset = get_inputset_to_use(function_parameters, inputset, compilation_configuration)

-    # Generate random inputset if it is requested and available
-    if isinstance(inputset, str):
-        _check_special_inputset_availability(inputset, compilation_configuration)
-        inputset = _generate_random_inputset(function_parameters, compilation_configuration)
+    def compilation_function():
+        return _compile_numpy_function_internal(
+            function_to_compile,
+            function_parameters,
+            inputset,
+            compilation_configuration,
+            compilation_artifacts,
+            show_mlir,
+        )

-    # Try to compile the function and save partial artifacts on failure
-    try:
-        # Use context manager to restore numpy error handling
-        with numpy.errstate(**numpy.geterr()):
-            return _compile_numpy_function_internal(
-                function_to_compile,
-                function_parameters,
-                inputset,
-                compilation_configuration,
-                compilation_artifacts,
-                show_mlir,
-            )
-    except Exception:  # pragma: no cover
-        # This branch is reserved for unexpected issues and hence it shouldn't be tested.
-        # If it could be tested, we would have fixed the underlying issue.
+    result = run_compilation_function_with_error_management(
+        compilation_function, compilation_configuration, compilation_artifacts
+    )

-        # We need to export all the information we have about the compilation
-        # If the user wants them to be exported
-        if compilation_configuration.dump_artifacts_on_unexpected_failures:
-            compilation_artifacts.export()
-            traceback_path = compilation_artifacts.output_directory.joinpath("traceback.txt")
-            with open(traceback_path, "w", encoding="utf-8") as f:
-                f.write(traceback.format_exc())
-        raise
+    # for mypy
+    assert isinstance(result, FHECircuit)
+    return result

View File

@@ -0,0 +1,199 @@
"""Module to hold a user friendly class to compile programs."""
from copy import deepcopy
from enum import Enum, unique
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..common.data_types import Integer
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import IntermediateNode
from ..common.values import BaseValue
from .compile import compile_numpy_function_into_op_graph, measure_op_graph_bounds_and_update
from .np_dtypes_helpers import get_base_value_for_numpy_or_python_constant_data
@unique
class EncryptedStatus(str, Enum):
"""Enum to validate GenericFunction op_kind."""
CLEAR = "clear"
ENCRYPTED = "encrypted"
class NPFHECompiler:
"""Class to ease the conversion of a numpy program to its FHE equivalent."""
INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE: int = 128
# _function_to_compile is not optional but mypy has a long-standing bug and is not able to
# understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
_function_to_compile: Optional[Callable]
_function_parameters_encrypted_status: Dict[str, bool]
_current_inputset: List[Union[Any, Tuple]]
_op_graph: Optional[OPGraph]
_nodes_and_bounds: Dict[IntermediateNode, Dict[str, Any]]
_compilation_configuration: CompilationConfiguration
compilation_artifacts: CompilationArtifacts
def __init__(
self,
function_to_compile: Callable,
function_parameters_encrypted_status: Dict[str, Union[str, EncryptedStatus]],
compilation_configuration: Optional[CompilationConfiguration] = None,
compilation_artifacts: Optional[CompilationArtifacts] = None,
) -> None:
self._function_to_compile = function_to_compile
self._function_parameters_encrypted_status = {
param_name: EncryptedStatus(status.lower()) == EncryptedStatus.ENCRYPTED
for param_name, status in function_parameters_encrypted_status.items()
}
self._current_inputset = []
self._op_graph = None
self._nodes_and_bounds = {}
self._compilation_configuration = (
deepcopy(compilation_configuration)
if compilation_configuration is not None
else CompilationConfiguration()
)
self.compilation_artifacts = (
compilation_artifacts if compilation_artifacts is not None else CompilationArtifacts()
)
@property
def function_to_compile(self) -> Callable:
"""Get the function to compile.
Returns:
Callable: the function to compile.
"""
# Continuation of mypy bug
assert self._function_to_compile is not None
return self._function_to_compile
@property
def op_graph(self) -> Optional[OPGraph]:
"""Return a copy of the OPGraph.
Returns:
Optional[OPGraph]: the held OPGraph or None
"""
# To keep consistency with what the user expects, we make sure to evaluate on the remaining
# inputset values, if any, before returning a copy of the traced OPGraph
self._eval_on_current_inputset()
return deepcopy(self._op_graph)
@property
def compilation_configuration(self) -> Optional[CompilationConfiguration]:
"""Get a copy of the compilation configuration.
Returns:
Optional[CompilationConfiguration]: copy of the current compilation configuration.
"""
return deepcopy(self._compilation_configuration)
def __call__(self, *args: Any) -> Any:
"""Evaluate the OPGraph corresponding to the function being compiled and return result.
Returns:
Any: the result of the OPGraph evaluation.
"""
self._current_inputset.append(deepcopy(args))
inferred_args = {
param_name: get_base_value_for_numpy_or_python_constant_data(val)(
is_encrypted=is_encrypted
)
for (param_name, is_encrypted), val in zip(
self._function_parameters_encrypted_status.items(), args
)
}
if len(self._current_inputset) >= self.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE:
self._eval_on_current_inputset()
self._trace_op_graph_if_needed(inferred_args)
# For mypy
assert self._op_graph is not None
return self._op_graph(*args)
def eval_on_inputset(self, inputset: Iterable[Union[Any, Tuple]]) -> None:
"""Evaluate the underlying function on an inputset in one go, populates OPGraph and bounds.
Args:
inputset (Iterable[Union[Any, Tuple]]): The inputset on which the function should be
evaluated.
"""
inputset_as_list = list(inputset)
if len(inputset_as_list) == 0:
return
inferred_args = {
param_name: get_base_value_for_numpy_or_python_constant_data(val)(
is_encrypted=is_encrypted
)
for (param_name, is_encrypted), val in zip(
self._function_parameters_encrypted_status.items(), inputset_as_list[0]
)
}
self._trace_op_graph_if_needed(inferred_args)
# For mypy
assert self._op_graph is not None
self._patch_op_graph_input_to_accept_any_integer_input()
self._nodes_and_bounds = measure_op_graph_bounds_and_update(
self._op_graph,
inferred_args,
inputset_as_list,
self._compilation_configuration,
self.compilation_artifacts,
self._nodes_and_bounds,
False,
)
def _eval_on_current_inputset(self) -> None:
"""Evaluate OPGraph on _current_inputset."""
self.eval_on_inputset(self._current_inputset)
self._current_inputset.clear()
def _needs_tracing(self) -> bool:
"""Return whether we need to trace the function and populate the OPGraph."""
return self._op_graph is None
def _trace_op_graph_if_needed(self, inferred_args: Dict[str, BaseValue]) -> None:
"""Populate _op_graph with the OPGraph for _function_to_compile."""
if not self._needs_tracing():
return
self._op_graph = compile_numpy_function_into_op_graph(
self.function_to_compile,
inferred_args,
self._compilation_configuration,
self.compilation_artifacts,
)
def _patch_op_graph_input_to_accept_any_integer_input(self) -> None:
"""Patch inputs as we don't know what data we expect."""
# Can only do that if the OPGraph was created, hence the check.
if self._needs_tracing():
return
# For mypy
assert self._op_graph is not None
# Cheat on Input nodes to avoid issues during inputset eval as we do not know in advance
# what the final bit width for the inputs should be
for node in self._op_graph.input_nodes.values():
for input_ in node.inputs:
if isinstance(dtype := (input_.dtype), Integer):
dtype.bit_width = 128
dtype.is_signed = True
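End to end, the class is meant to be used roughly as follows (a sketch distilled from the tests below):

import numpy

compiler = NPFHECompiler(
    lambda x: x + 42,
    {"x": "encrypted"},
)

# each call evaluates the function and records the arguments in the inputset
for i in numpy.arange(10):
    assert compiler(i) == i + 42

# an explicit inputset can also be fed in bulk; bounds are updated incrementally
compiler.eval_on_inputset([(i,) for i in range(10, 20)])

# reading op_graph flushes any pending recorded inputs and returns a copy
op_graph = compiler.op_graph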

View File

@@ -0,0 +1,215 @@
"""Test file for user-friendly numpy compilation functions"""
import numpy
import pytest
from concrete.common.debugging import format_operation_graph
from concrete.numpy.np_fhe_compiler import NPFHECompiler
def complicated_topology(x, y):
"""Mix x in an intricated way."""
intermediate = x + y
x_p_1 = intermediate + 1
x_p_2 = intermediate + 2
x_p_3 = x_p_1 + x_p_2
return (
x_p_3.astype(numpy.int32),
x_p_2.astype(numpy.int32),
(x_p_2 + 3).astype(numpy.int32),
x_p_3.astype(numpy.int32) + 67,
)
@pytest.mark.parametrize("input_shape", [(), (3, 1, 2)])
def test_np_fhe_compiler(input_shape, default_compilation_configuration):
"""Test NPFHECompiler in two subtests."""
subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configuration)
subtest_np_fhe_compiler_2_inputs(input_shape, default_compilation_configuration)
def subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configuration):
"""test for NPFHECompiler on one input function"""
compiler = NPFHECompiler(
lambda x: complicated_topology(x, 0),
{"x": "encrypted"},
default_compilation_configuration,
)
# For coverage when the OPGraph is not yet traced
compiler._patch_op_graph_input_to_accept_any_integer_input() # pylint: disable=protected-access
assert compiler.compilation_configuration == default_compilation_configuration
assert compiler.compilation_configuration is not default_compilation_configuration
for i in numpy.arange(5):
i = numpy.ones(input_shape, dtype=numpy.int64) * i
assert numpy.array_equal(compiler(i), complicated_topology(i, 0))
# For coverage, check that we flush the inputset when we query the OPGraph
current_op_graph = compiler.op_graph
assert current_op_graph is not compiler.op_graph
assert len(compiler._current_inputset) == 0 # pylint: disable=protected-access
# For coverage, cover case where the current inputset is empty
compiler._eval_on_current_inputset() # pylint: disable=protected-access
# Continue a bit more
for i in numpy.arange(5, 10):
i = numpy.ones(input_shape, dtype=numpy.int64) * i
assert numpy.array_equal(compiler(i), complicated_topology(i, 0))
if input_shape == ():
assert (
(got := format_operation_graph(compiler.op_graph))
== """ %0 = 67 # ClearScalar<uint7>
%1 = 2 # ClearScalar<uint2>
%2 = 3 # ClearScalar<uint2>
%3 = 1 # ClearScalar<uint1>
%4 = x # EncryptedScalar<uint4>
%5 = 0 # ClearScalar<uint1>
%6 = add(%4, %5) # EncryptedScalar<uint4>
%7 = add(%6, %1) # EncryptedScalar<uint4>
%8 = add(%6, %3) # EncryptedScalar<uint4>
%9 = astype(%7, dtype=int32) # EncryptedScalar<uint4>
%10 = add(%7, %2) # EncryptedScalar<uint4>
%11 = add(%8, %7) # EncryptedScalar<uint5>
%12 = astype(%10, dtype=int32) # EncryptedScalar<uint4>
%13 = astype(%11, dtype=int32) # EncryptedScalar<uint5>
%14 = astype(%11, dtype=int32) # EncryptedScalar<uint5>
%15 = add(%14, %0) # EncryptedScalar<uint7>
(%13, %9, %12, %15)"""
), got
else:
assert (
(got := format_operation_graph(compiler.op_graph))
== """ %0 = 67 # ClearScalar<uint7>
%1 = 2 # ClearScalar<uint2>
%2 = 3 # ClearScalar<uint2>
%3 = 1 # ClearScalar<uint1>
%4 = x # EncryptedTensor<uint4, shape=(3, 1, 2)>
%5 = 0 # ClearScalar<uint1>
%6 = add(%4, %5) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%7 = add(%6, %1) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%8 = add(%6, %3) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%9 = astype(%7, dtype=int32) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%10 = add(%7, %2) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%11 = add(%8, %7) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%12 = astype(%10, dtype=int32) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%13 = astype(%11, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%14 = astype(%11, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%15 = add(%14, %0) # EncryptedTensor<uint7, shape=(3, 1, 2)>
(%13, %9, %12, %15)"""
), got
def subtest_np_fhe_compiler_2_inputs(input_shape, default_compilation_configuration):
"""test for NPFHECompiler on two inputs function"""
compiler = NPFHECompiler(
complicated_topology,
{"x": "encrypted", "y": "clear"},
default_compilation_configuration,
)
# For coverage when the OPGraph is not yet traced
compiler._patch_op_graph_input_to_accept_any_integer_input() # pylint: disable=protected-access
assert compiler.compilation_configuration == default_compilation_configuration
assert compiler.compilation_configuration is not default_compilation_configuration
for i, j in zip(numpy.arange(5), numpy.arange(5, 10)):
i = numpy.ones(input_shape, dtype=numpy.int64) * i
j = numpy.ones(input_shape, dtype=numpy.int64) * j
assert numpy.array_equal(compiler(i, j), complicated_topology(i, j))
# For coverage, check that we flush the inputset when we query the OPGraph
current_op_graph = compiler.op_graph
assert current_op_graph is not compiler.op_graph
assert len(compiler._current_inputset) == 0 # pylint: disable=protected-access
# For coverage, cover case where the current inputset is empty
compiler._eval_on_current_inputset() # pylint: disable=protected-access
# Continue a bit more
for i, j in zip(numpy.arange(5, 10), numpy.arange(5)):
i = numpy.ones(input_shape, dtype=numpy.int64) * i
j = numpy.ones(input_shape, dtype=numpy.int64) * j
assert numpy.array_equal(compiler(i, j), complicated_topology(i, j))
if input_shape == ():
assert (
(got := format_operation_graph(compiler.op_graph))
== """ %0 = 67 # ClearScalar<uint7>
%1 = 2 # ClearScalar<uint2>
%2 = 3 # ClearScalar<uint2>
%3 = 1 # ClearScalar<uint1>
%4 = x # EncryptedScalar<uint4>
%5 = y # ClearScalar<uint4>
%6 = add(%4, %5) # EncryptedScalar<uint4>
%7 = add(%6, %1) # EncryptedScalar<uint4>
%8 = add(%6, %3) # EncryptedScalar<uint4>
%9 = astype(%7, dtype=int32) # EncryptedScalar<uint4>
%10 = add(%7, %2) # EncryptedScalar<uint5>
%11 = add(%8, %7) # EncryptedScalar<uint5>
%12 = astype(%10, dtype=int32) # EncryptedScalar<uint5>
%13 = astype(%11, dtype=int32) # EncryptedScalar<uint5>
%14 = astype(%11, dtype=int32) # EncryptedScalar<uint5>
%15 = add(%14, %0) # EncryptedScalar<uint7>
(%13, %9, %12, %15)"""
), got
else:
assert (
(got := format_operation_graph(compiler.op_graph))
== """ %0 = 67 # ClearScalar<uint7>
%1 = 2 # ClearScalar<uint2>
%2 = 3 # ClearScalar<uint2>
%3 = 1 # ClearScalar<uint1>
%4 = x # EncryptedTensor<uint4, shape=(3, 1, 2)>
%5 = y # ClearTensor<uint4, shape=(3, 1, 2)>
%6 = add(%4, %5) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%7 = add(%6, %1) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%8 = add(%6, %3) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%9 = astype(%7, dtype=int32) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%10 = add(%7, %2) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%11 = add(%8, %7) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%12 = astype(%10, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%13 = astype(%11, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%14 = astype(%11, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%15 = add(%14, %0) # EncryptedTensor<uint7, shape=(3, 1, 2)>
(%13, %9, %12, %15)"""
), got
def remaining_inputset_size(inputset_len):
"""Small function to generate test cases below for remaining inputset length."""
return inputset_len % NPFHECompiler.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE
@pytest.mark.parametrize(
"inputset_len, expected_remaining_inputset_len",
[
(42, remaining_inputset_size(42)),
(128, remaining_inputset_size(128)),
(234, remaining_inputset_size(234)),
],
)
def test_np_fhe_compiler_auto_flush(
inputset_len,
expected_remaining_inputset_len,
default_compilation_configuration,
):
"""Test the auto flush of NPFHECompiler once the inputset is 128 elements."""
compiler = NPFHECompiler(
lambda x: x // 2,
{"x": "encrypted"},
default_compilation_configuration,
)
for i in numpy.arange(inputset_len):
assert numpy.array_equal(compiler(i), i // 2)
# Check the inputset was properly flushed
assert (
len(compiler._current_inputset) # pylint: disable=protected-access
== expected_remaining_inputset_len
)