mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor: do not require an iterator for inputset just iterable
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""Code to evaluate the IR graph on inputsets."""
|
||||
|
||||
from typing import Any, Callable, Dict, Iterator, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, Tuple
|
||||
|
||||
from ..debugging import custom_assert
|
||||
from ..operator_graph import OPGraph
|
||||
@@ -9,7 +9,7 @@ from ..representation.intermediate import IntermediateNode
|
||||
|
||||
def eval_op_graph_bounds_on_inputset(
|
||||
op_graph: OPGraph,
|
||||
inputset: Iterator[Tuple[Any, ...]],
|
||||
inputset: Iterable[Tuple[Any, ...]],
|
||||
min_func: Callable[[Any, Any], Any] = min,
|
||||
max_func: Callable[[Any, Any], Any] = max,
|
||||
) -> Dict[IntermediateNode, Dict[str, Any]]:
|
||||
@@ -20,8 +20,8 @@ def eval_op_graph_bounds_on_inputset(
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The graph for which we want to determine the bounds
|
||||
inputset (Iterator[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
|
||||
needs to be an iterator on tuples which are of the same length than the number of
|
||||
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
|
||||
needs to be an iterable on tuples which are of the same length than the number of
|
||||
parameters in the function, and in the same order than these same parameters
|
||||
min_func (Callable[[Any, Any], Any], optional): custom function to compute a scalar minimum
|
||||
between two values that can be encountered during evaluation (for e.g. numpy or torch
|
||||
@@ -48,7 +48,9 @@ def eval_op_graph_bounds_on_inputset(
|
||||
# TODO: do we want to check coherence between the input data type and the corresponding Input ir
|
||||
# node expected data type ? Not considering bit_width as they may not make sense at this stage
|
||||
|
||||
first_input_data = dict(enumerate(next(inputset)))
|
||||
inputset_iterator = iter(inputset)
|
||||
|
||||
first_input_data = dict(enumerate(next(inputset_iterator)))
|
||||
check_inputset_input_len_is_valid(first_input_data.values())
|
||||
first_output = op_graph.evaluate(first_input_data)
|
||||
|
||||
@@ -59,7 +61,7 @@ def eval_op_graph_bounds_on_inputset(
|
||||
for node, value in first_output.items()
|
||||
}
|
||||
|
||||
for input_data in inputset:
|
||||
for input_data in inputset_iterator:
|
||||
current_input_data = dict(enumerate(input_data))
|
||||
check_inputset_input_len_is_valid(current_input_data.values())
|
||||
current_output = op_graph.evaluate(current_input_data)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""numpy compilation function."""
|
||||
|
||||
import traceback
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import numpy
|
||||
from zamalang import CompilerEngine
|
||||
@@ -55,7 +55,7 @@ def numpy_min_func(lhs: Any, rhs: Any) -> Any:
|
||||
def _compile_numpy_function_into_op_graph_internal(
|
||||
function_to_compile: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
inputset: Iterator[Tuple[Any, ...]],
|
||||
inputset: Iterable[Tuple[Any, ...]],
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
compilation_artifacts: CompilationArtifacts,
|
||||
) -> OPGraph:
|
||||
@@ -65,8 +65,8 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
function_to_compile (Callable): The function to compile
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
inputset (Iterator[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
|
||||
needs to be an iterator on tuples which are of the same length than the number of
|
||||
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
|
||||
needs to be an iterable on tuples which are of the same length than the number of
|
||||
parameters in the function, and in the same order than these same parameters
|
||||
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
|
||||
during compilation
|
||||
@@ -143,7 +143,7 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
def compile_numpy_function_into_op_graph(
|
||||
function_to_compile: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
inputset: Iterator[Tuple[Any, ...]],
|
||||
inputset: Iterable[Tuple[Any, ...]],
|
||||
compilation_configuration: Optional[CompilationConfiguration] = None,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
) -> OPGraph:
|
||||
@@ -153,8 +153,8 @@ def compile_numpy_function_into_op_graph(
|
||||
function_to_compile (Callable): The function to compile
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
inputset (Iterator[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
|
||||
needs to be an iterator on tuples which are of the same length than the number of
|
||||
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
|
||||
needs to be an iterable on tuples which are of the same length than the number of
|
||||
parameters in the function, and in the same order than these same parameters
|
||||
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
|
||||
during compilation
|
||||
@@ -205,7 +205,7 @@ def compile_numpy_function_into_op_graph(
|
||||
def _compile_numpy_function_internal(
|
||||
function_to_compile: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
inputset: Iterator[Tuple[Any, ...]],
|
||||
inputset: Iterable[Tuple[Any, ...]],
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
compilation_artifacts: CompilationArtifacts,
|
||||
show_mlir: bool,
|
||||
@@ -216,8 +216,8 @@ def _compile_numpy_function_internal(
|
||||
function_to_compile (Callable): The function you want to compile
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
inputset (Iterator[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
|
||||
needs to be an iterator on tuples which are of the same length than the number of
|
||||
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
|
||||
needs to be an iterable on tuples which are of the same length than the number of
|
||||
parameters in the function, and in the same order than these same parameters
|
||||
compilation_configuration (CompilationConfiguration): Configuration object to use
|
||||
during compilation
|
||||
@@ -260,7 +260,7 @@ def _compile_numpy_function_internal(
|
||||
def compile_numpy_function(
|
||||
function_to_compile: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
inputset: Iterator[Tuple[Any, ...]],
|
||||
inputset: Iterable[Tuple[Any, ...]],
|
||||
compilation_configuration: Optional[CompilationConfiguration] = None,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
show_mlir: bool = False,
|
||||
@@ -271,8 +271,8 @@ def compile_numpy_function(
|
||||
function_to_compile (Callable): The function to compile
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedScalar holding a 7bits unsigned Integer
|
||||
inputset (Iterator[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
|
||||
needs to be an iterator on tuples which are of the same length than the number of
|
||||
inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It
|
||||
needs to be an iterable on tuples which are of the same length than the number of
|
||||
parameters in the function, and in the same order than these same parameters
|
||||
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
|
||||
during compilation
|
||||
|
||||
Reference in New Issue
Block a user