From 2296bcd45736d959d4884dc7a29a4983f5717c08 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Wed, 15 Sep 2021 16:56:18 +0200 Subject: [PATCH] refactor: do not require an iterator for inputset just iterable --- .../bounds_measurement/inputset_eval.py | 14 +++++----- concrete/numpy/compile.py | 26 +++++++++---------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/concrete/common/bounds_measurement/inputset_eval.py b/concrete/common/bounds_measurement/inputset_eval.py index 2103760af..8f83e02fc 100644 --- a/concrete/common/bounds_measurement/inputset_eval.py +++ b/concrete/common/bounds_measurement/inputset_eval.py @@ -1,6 +1,6 @@ """Code to evaluate the IR graph on inputsets.""" -from typing import Any, Callable, Dict, Iterator, Tuple +from typing import Any, Callable, Dict, Iterable, Tuple from ..debugging import custom_assert from ..operator_graph import OPGraph @@ -9,7 +9,7 @@ from ..representation.intermediate import IntermediateNode def eval_op_graph_bounds_on_inputset( op_graph: OPGraph, - inputset: Iterator[Tuple[Any, ...]], + inputset: Iterable[Tuple[Any, ...]], min_func: Callable[[Any, Any], Any] = min, max_func: Callable[[Any, Any], Any] = max, ) -> Dict[IntermediateNode, Dict[str, Any]]: @@ -20,8 +20,8 @@ def eval_op_graph_bounds_on_inputset( Args: op_graph (OPGraph): The graph for which we want to determine the bounds - inputset (Iterator[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It - needs to be an iterator on tuples which are of the same length than the number of + inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It + needs to be an iterable on tuples which are of the same length than the number of parameters in the function, and in the same order than these same parameters min_func (Callable[[Any, Any], Any], optional): custom function to compute a scalar minimum between two values that can be encountered during evaluation (for e.g. numpy or torch @@ -48,7 +48,9 @@ def eval_op_graph_bounds_on_inputset( # TODO: do we want to check coherence between the input data type and the corresponding Input ir # node expected data type ? Not considering bit_width as they may not make sense at this stage - first_input_data = dict(enumerate(next(inputset))) + inputset_iterator = iter(inputset) + + first_input_data = dict(enumerate(next(inputset_iterator))) check_inputset_input_len_is_valid(first_input_data.values()) first_output = op_graph.evaluate(first_input_data) @@ -59,7 +61,7 @@ def eval_op_graph_bounds_on_inputset( for node, value in first_output.items() } - for input_data in inputset: + for input_data in inputset_iterator: current_input_data = dict(enumerate(input_data)) check_inputset_input_len_is_valid(current_input_data.values()) current_output = op_graph.evaluate(current_input_data) diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index ea3f964a9..03a894464 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -1,7 +1,7 @@ """numpy compilation function.""" import traceback -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple import numpy from zamalang import CompilerEngine @@ -55,7 +55,7 @@ def numpy_min_func(lhs: Any, rhs: Any) -> Any: def _compile_numpy_function_into_op_graph_internal( function_to_compile: Callable, function_parameters: Dict[str, BaseValue], - inputset: Iterator[Tuple[Any, ...]], + inputset: Iterable[Tuple[Any, ...]], compilation_configuration: CompilationConfiguration, compilation_artifacts: CompilationArtifacts, ) -> OPGraph: @@ -65,8 +65,8 @@ def _compile_numpy_function_into_op_graph_internal( function_to_compile (Callable): The function to compile function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the function is e.g. an EncryptedScalar holding a 7bits unsigned Integer - inputset (Iterator[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It - needs to be an iterator on tuples which are of the same length than the number of + inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It + needs to be an iterable on tuples which are of the same length than the number of parameters in the function, and in the same order than these same parameters compilation_artifacts (CompilationArtifacts): Artifacts object to fill during compilation @@ -143,7 +143,7 @@ def _compile_numpy_function_into_op_graph_internal( def compile_numpy_function_into_op_graph( function_to_compile: Callable, function_parameters: Dict[str, BaseValue], - inputset: Iterator[Tuple[Any, ...]], + inputset: Iterable[Tuple[Any, ...]], compilation_configuration: Optional[CompilationConfiguration] = None, compilation_artifacts: Optional[CompilationArtifacts] = None, ) -> OPGraph: @@ -153,8 +153,8 @@ def compile_numpy_function_into_op_graph( function_to_compile (Callable): The function to compile function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the function is e.g. an EncryptedScalar holding a 7bits unsigned Integer - inputset (Iterator[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It - needs to be an iterator on tuples which are of the same length than the number of + inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It + needs to be an iterable on tuples which are of the same length than the number of parameters in the function, and in the same order than these same parameters compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use during compilation @@ -205,7 +205,7 @@ def compile_numpy_function_into_op_graph( def _compile_numpy_function_internal( function_to_compile: Callable, function_parameters: Dict[str, BaseValue], - inputset: Iterator[Tuple[Any, ...]], + inputset: Iterable[Tuple[Any, ...]], compilation_configuration: CompilationConfiguration, compilation_artifacts: CompilationArtifacts, show_mlir: bool, @@ -216,8 +216,8 @@ def _compile_numpy_function_internal( function_to_compile (Callable): The function you want to compile function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the function is e.g. an EncryptedScalar holding a 7bits unsigned Integer - inputset (Iterator[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It - needs to be an iterator on tuples which are of the same length than the number of + inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It + needs to be an iterable on tuples which are of the same length than the number of parameters in the function, and in the same order than these same parameters compilation_configuration (CompilationConfiguration): Configuration object to use during compilation @@ -260,7 +260,7 @@ def _compile_numpy_function_internal( def compile_numpy_function( function_to_compile: Callable, function_parameters: Dict[str, BaseValue], - inputset: Iterator[Tuple[Any, ...]], + inputset: Iterable[Tuple[Any, ...]], compilation_configuration: Optional[CompilationConfiguration] = None, compilation_artifacts: Optional[CompilationArtifacts] = None, show_mlir: bool = False, @@ -271,8 +271,8 @@ def compile_numpy_function( function_to_compile (Callable): The function to compile function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the function is e.g. an EncryptedScalar holding a 7bits unsigned Integer - inputset (Iterator[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It - needs to be an iterator on tuples which are of the same length than the number of + inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It + needs to be an iterable on tuples which are of the same length than the number of parameters in the function, and in the same order than these same parameters compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use during compilation