refactor: do not require an iterator for inputset just iterable

This commit is contained in:
Arthur Meyre
2021-09-15 16:56:18 +02:00
parent 5871d4e187
commit 2296bcd457
2 changed files with 21 additions and 19 deletions

View File

@@ -1,6 +1,6 @@
"""Code to evaluate the IR graph on inputsets."""
from typing import Any, Callable, Dict, Iterator, Tuple
from typing import Any, Callable, Dict, Iterable, Tuple
from ..debugging import custom_assert
from ..operator_graph import OPGraph
@@ -9,7 +9,7 @@ from ..representation.intermediate import IntermediateNode
def eval_op_graph_bounds_on_inputset(
op_graph: OPGraph,
inputset: Iterator[Tuple[Any, ...]],
inputset: Iterable[Tuple[Any, ...]],
min_func: Callable[[Any, Any], Any] = min,
max_func: Callable[[Any, Any], Any] = max,
) -> Dict[IntermediateNode, Dict[str, Any]]:
@@ -20,8 +20,8 @@ def eval_op_graph_bounds_on_inputset(
Args:
op_graph (OPGraph): The graph for which we want to determine the bounds
inputset (Iterator[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
needs to be an iterator on tuples which are of the same length than the number of
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
needs to be an iterable on tuples which are of the same length than the number of
parameters in the function, and in the same order than these same parameters
min_func (Callable[[Any, Any], Any], optional): custom function to compute a scalar minimum
between two values that can be encountered during evaluation (for e.g. numpy or torch
@@ -48,7 +48,9 @@ def eval_op_graph_bounds_on_inputset(
# TODO: do we want to check coherence between the input data type and the corresponding Input ir
# node expected data type ? Not considering bit_width as they may not make sense at this stage
first_input_data = dict(enumerate(next(inputset)))
inputset_iterator = iter(inputset)
first_input_data = dict(enumerate(next(inputset_iterator)))
check_inputset_input_len_is_valid(first_input_data.values())
first_output = op_graph.evaluate(first_input_data)
@@ -59,7 +61,7 @@ def eval_op_graph_bounds_on_inputset(
for node, value in first_output.items()
}
for input_data in inputset:
for input_data in inputset_iterator:
current_input_data = dict(enumerate(input_data))
check_inputset_input_len_is_valid(current_input_data.values())
current_output = op_graph.evaluate(current_input_data)

View File

@@ -1,7 +1,7 @@
"""numpy compilation function."""
import traceback
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import numpy
from zamalang import CompilerEngine
@@ -55,7 +55,7 @@ def numpy_min_func(lhs: Any, rhs: Any) -> Any:
def _compile_numpy_function_into_op_graph_internal(
function_to_compile: Callable,
function_parameters: Dict[str, BaseValue],
inputset: Iterator[Tuple[Any, ...]],
inputset: Iterable[Tuple[Any, ...]],
compilation_configuration: CompilationConfiguration,
compilation_artifacts: CompilationArtifacts,
) -> OPGraph:
@@ -65,8 +65,8 @@ def _compile_numpy_function_into_op_graph_internal(
function_to_compile (Callable): The function to compile
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
inputset (Iterator[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
needs to be an iterator on tuples which are of the same length than the number of
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
needs to be an iterable on tuples which are of the same length than the number of
parameters in the function, and in the same order than these same parameters
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
during compilation
@@ -143,7 +143,7 @@ def _compile_numpy_function_into_op_graph_internal(
def compile_numpy_function_into_op_graph(
function_to_compile: Callable,
function_parameters: Dict[str, BaseValue],
inputset: Iterator[Tuple[Any, ...]],
inputset: Iterable[Tuple[Any, ...]],
compilation_configuration: Optional[CompilationConfiguration] = None,
compilation_artifacts: Optional[CompilationArtifacts] = None,
) -> OPGraph:
@@ -153,8 +153,8 @@ def compile_numpy_function_into_op_graph(
function_to_compile (Callable): The function to compile
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
inputset (Iterator[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
needs to be an iterator on tuples which are of the same length than the number of
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
needs to be an iterable on tuples which are of the same length than the number of
parameters in the function, and in the same order than these same parameters
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
during compilation
@@ -205,7 +205,7 @@ def compile_numpy_function_into_op_graph(
def _compile_numpy_function_internal(
function_to_compile: Callable,
function_parameters: Dict[str, BaseValue],
inputset: Iterator[Tuple[Any, ...]],
inputset: Iterable[Tuple[Any, ...]],
compilation_configuration: CompilationConfiguration,
compilation_artifacts: CompilationArtifacts,
show_mlir: bool,
@@ -216,8 +216,8 @@ def _compile_numpy_function_internal(
function_to_compile (Callable): The function you want to compile
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
inputset (Iterator[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
needs to be an iterator on tuples which are of the same length than the number of
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
needs to be an iterable on tuples which are of the same length than the number of
parameters in the function, and in the same order than these same parameters
compilation_configuration (CompilationConfiguration): Configuration object to use
during compilation
@@ -260,7 +260,7 @@ def _compile_numpy_function_internal(
def compile_numpy_function(
function_to_compile: Callable,
function_parameters: Dict[str, BaseValue],
inputset: Iterator[Tuple[Any, ...]],
inputset: Iterable[Tuple[Any, ...]],
compilation_configuration: Optional[CompilationConfiguration] = None,
compilation_artifacts: Optional[CompilationArtifacts] = None,
show_mlir: bool = False,
@@ -271,8 +271,8 @@ def compile_numpy_function(
function_to_compile (Callable): The function to compile
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
inputset (Iterator[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
needs to be an iterator on tuples which are of the same length than the number of
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
needs to be an iterable on tuples which are of the same length than the number of
parameters in the function, and in the same order than these same parameters
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
during compilation