From 0061e01d62db5bc75998cb85ab1a3084cbdd1853 Mon Sep 17 00:00:00 2001 From: Umut Date: Wed, 22 Sep 2021 12:10:11 +0300 Subject: [PATCH] feat: implement checking coherence between inputset and parameters --- concrete/common/__init__.py | 2 +- .../bounds_measurement/inputset_eval.py | 127 ++++++++++++++- concrete/common/compilation/configuration.py | 3 + concrete/common/data_types/dtypes_helpers.py | 67 ++++++-- concrete/common/values/scalars.py | 11 ++ concrete/numpy/compile.py | 5 +- concrete/numpy/np_dtypes_helpers.py | 24 ++- .../bounds_measurement/test_inputset_eval.py | 146 +++++++++++++++++- .../common/data_types/test_dtypes_helpers.py | 1 - tests/numpy/test_compile.py | 2 +- 10 files changed, 360 insertions(+), 28 deletions(-) diff --git a/concrete/common/__init__.py b/concrete/common/__init__.py index 0dcf40159..ce8e30ea3 100644 --- a/concrete/common/__init__.py +++ b/concrete/common/__init__.py @@ -1,3 +1,3 @@ """Module for shared data structures and code.""" -from . import compilation, data_types, debugging, representation +from . import compilation, data_types, debugging, representation, values from .common_helpers import check_op_graph_is_integer_program, is_a_power_of_2 diff --git a/concrete/common/bounds_measurement/inputset_eval.py b/concrete/common/bounds_measurement/inputset_eval.py index f285de2be..4be083f5f 100644 --- a/concrete/common/bounds_measurement/inputset_eval.py +++ b/concrete/common/bounds_measurement/inputset_eval.py @@ -1,17 +1,105 @@ """Code to evaluate the IR graph on inputsets.""" +import sys from typing import Any, Callable, Dict, Iterable, Tuple +from ..compilation import CompilationConfiguration +from ..data_types.dtypes_helpers import ( + get_base_value_for_python_constant_data, + is_data_type_compatible_with, +) from ..debugging import custom_assert from ..operator_graph import OPGraph from ..representation.intermediate import IntermediateNode +def _check_input_coherency( + input_to_check: Dict[str, Any], + parameters: Dict[str, Any], + get_base_value_for_constant_data_func: Callable[[Any], Any], +): + """Check whether `input_to_check` is coherent with `parameters`. + + This function works by iterating over each constant of the input, + determining base value of the constant using `get_base_value_for_constant_data_func` and + checking if the base value of the contant is compatible with the base value of the parameter. + + Args: + input_to_check (Dict[str, Any]): input to check coherency of + parameters (Dict[str, Any]): parameters and their expected base values + get_base_value_for_constant_data_func (Callable[[Any], Any]): + function to get the base value of python objects. + + Returns: + List[str]: List of warnings about the coherency + """ + + warnings = [] + for parameter_name, value in input_to_check.items(): + parameter_base_value = parameters[parameter_name] + + base_value_class = get_base_value_for_constant_data_func(value) + base_value = base_value_class(is_encrypted=parameter_base_value.is_encrypted) + + if base_value.shape != parameter_base_value.shape or not is_data_type_compatible_with( + base_value.data_type, parameter_base_value.data_type + ): + warnings.append( + f"expected {str(parameter_base_value)} " + f"for parameter `{parameter_name}` " + f"but got {str(base_value)} " + f"which is not compatible" + ) + return warnings + + +def _print_input_coherency_warnings( + current_input_index: int, + current_input_data: Dict[int, Any], + parameters: Dict[str, Any], + parameter_index_to_parameter_name: Dict[int, str], + get_base_value_for_constant_data_func: Callable[[Any], Any], +): + """Print coherency warning for `input_to_check` against `parameters`. + + Args: + current_input_index (int): index of the current input on the inputset + current_input_data (Dict[int, Any]): input to print coherency warnings of + parameters (Dict[str, Any]): parameters and their expected base values + parameter_index_to_parameter_name (Dict[int, str]): + dict to get parameter names from parameter indices + get_base_value_for_constant_data_func (Callable[[Any], Any]): + function to get the base value of python objects. + + Returns: + None + """ + + current_input_named_data = { + parameter_index_to_parameter_name[index]: data for index, data in current_input_data.items() + } + + problems = _check_input_coherency( + current_input_named_data, + parameters, + get_base_value_for_constant_data_func, + ) + for problem in problems: + sys.stderr.write( + f"Warning: Input #{current_input_index} (0-indexed) " + f"is not coherent with the hinted parameters ({problem})\n", + ) + + def eval_op_graph_bounds_on_inputset( op_graph: OPGraph, inputset: Iterable[Tuple[Any, ...]], + compilation_configuration: CompilationConfiguration, min_func: Callable[[Any, Any], Any] = min, max_func: Callable[[Any, Any], Any] = max, + get_base_value_for_constant_data_func: Callable[ + [Any], Any + ] = get_base_value_for_python_constant_data, ) -> Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: """Evaluate the bounds with a inputset. @@ -23,12 +111,16 @@ def eval_op_graph_bounds_on_inputset( inputset (Iterable[Tuple[Any, ...]]): The inputset over which op_graph is evaluated. It needs to be an iterable on tuples which are of the same length than the number of parameters in the function, and in the same order than these same parameters + compilation_configuration (CompilationConfiguration): Configuration object to use + during determining input checking strategy min_func (Callable[[Any, Any], Any], optional): custom function to compute a scalar minimum between two values that can be encountered during evaluation (for e.g. numpy or torch tensors). Defaults to min. max_func (Callable[[Any, Any], Any], optional): custom function to compute a scalar maximum between two values that can be encountered during evaluation (for e.g. numpy or torch tensors). Defaults to max. + get_base_value_for_constant_data_func (Callable[[Any], Any], optional): custom function + to compute the base value of a python object. Returns: Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: number of inputs in the inputset and @@ -49,14 +141,29 @@ def eval_op_graph_bounds_on_inputset( # TODO: do we want to check coherence between the input data type and the corresponding Input ir # node expected data type ? Not considering bit_width as they may not make sense at this stage - inputset_size = 0 - inputset_iterator = iter(inputset) + parameter_index_to_parameter_name = { + index: input_node.input_name for index, input_node in op_graph.input_nodes.items() + } + parameters = { + input_node.input_name: input_node.inputs[0] for input_node in op_graph.input_nodes.values() + } - first_input_data = dict(enumerate(next(inputset_iterator))) + inputset_iterator = iter(inputset) + inputset_size = 0 + + current_input_data = dict(enumerate(next(inputset_iterator))) inputset_size += 1 - check_inputset_input_len_is_valid(first_input_data.values()) - first_output = op_graph.evaluate(first_input_data) + check_inputset_input_len_is_valid(current_input_data.values()) + _print_input_coherency_warnings( + inputset_size - 1, + current_input_data, + parameters, + parameter_index_to_parameter_name, + get_base_value_for_constant_data_func, + ) + + first_output = op_graph.evaluate(current_input_data) # We evaluate the min and max func to be able to resolve the tensors min and max rather than # having the tensor itself as the stored min and max values. @@ -68,7 +175,17 @@ def eval_op_graph_bounds_on_inputset( for input_data in inputset_iterator: inputset_size += 1 current_input_data = dict(enumerate(input_data)) + check_inputset_input_len_is_valid(current_input_data.values()) + if compilation_configuration.check_every_input_in_inputset: + _print_input_coherency_warnings( + inputset_size - 1, + current_input_data, + parameters, + parameter_index_to_parameter_name, + get_base_value_for_constant_data_func, + ) + current_output = op_graph.evaluate(current_input_data) for node, value in current_output.items(): node_bounds[node]["min"] = min_func(node_bounds[node]["min"], value) diff --git a/concrete/common/compilation/configuration.py b/concrete/common/compilation/configuration.py index c600698e6..07f909e6d 100644 --- a/concrete/common/compilation/configuration.py +++ b/concrete/common/compilation/configuration.py @@ -6,11 +6,14 @@ class CompilationConfiguration: dump_artifacts_on_unexpected_failures: bool enable_topological_optimizations: bool + check_every_input_in_inputset: bool def __init__( self, dump_artifacts_on_unexpected_failures: bool = True, enable_topological_optimizations: bool = True, + check_every_input_in_inputset: bool = False, ): self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures self.enable_topological_optimizations = enable_topological_optimizations + self.check_every_input_in_inputset = check_every_input_in_inputset diff --git a/concrete/common/data_types/dtypes_helpers.py b/concrete/common/data_types/dtypes_helpers.py index 49c096384..8f5ed10a9 100644 --- a/concrete/common/data_types/dtypes_helpers.py +++ b/concrete/common/data_types/dtypes_helpers.py @@ -2,7 +2,7 @@ from copy import deepcopy from functools import partial -from typing import Callable, Union, cast +from typing import Callable, List, Union, cast from ..debugging.custom_assert import custom_assert from ..values import ( @@ -306,7 +306,9 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> ) -def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType: +def get_base_data_type_for_python_constant_data( + constant_data: Union[int, float, List[int], List[float]] +) -> BaseDataType: """Determine the BaseDataType to hold the input constant data. Args: @@ -318,10 +320,17 @@ def get_base_data_type_for_python_constant_data(constant_data: Union[int, float] """ constant_data_type: BaseDataType custom_assert( - isinstance(constant_data, (int, float)), + isinstance(constant_data, (int, float, list)), f"Unsupported constant data of type {type(constant_data)}", ) - if isinstance(constant_data, int): + + if isinstance(constant_data, list): + custom_assert(len(constant_data) > 0, "Data type of empty list cannot be detected") + constant_data_type = get_base_data_type_for_python_constant_data(constant_data[0]) + for value in constant_data: + other_data_type = get_base_data_type_for_python_constant_data(value) + constant_data_type = find_type_to_hold_both_lossy(constant_data_type, other_data_type) + elif isinstance(constant_data, int): is_signed = constant_data < 0 constant_data_type = Integer( get_bits_to_represent_value_as_integer(constant_data, is_signed), is_signed @@ -332,22 +341,29 @@ def get_base_data_type_for_python_constant_data(constant_data: Union[int, float] def get_base_value_for_python_constant_data( - constant_data: Union[int, float] -) -> Callable[..., ScalarValue]: - """Wrap the BaseDataType to hold the input constant data in a ScalarValue partial. + constant_data: Union[int, float, List[int], List[float]] +) -> Callable[..., BaseValue]: + """Wrap the BaseDataType to hold the input constant data in BaseValue partial. - The returned object can then be instantiated as an Encrypted or Clear version of the ScalarValue - by calling it with the proper arguments forwarded to the ScalarValue `__init__` function + The returned object can then be instantiated as an Encrypted or Clear version + by calling it with the proper arguments forwarded to the BaseValue `__init__` function Args: - constant_data (Union[int, float]): The constant data for which to determine the - corresponding ScalarValue and BaseDataType. + constant_data (Union[int, float, List[int], List[float]]): The constant data + for which to determine the corresponding Value. Returns: - Callable[..., ScalarValue]: A partial object that will return the proper ScalarValue when - called with `encrypted` as keyword argument (forwarded to the ScalarValue `__init__` + Callable[..., BaseValue]: A partial object that will return the proper BaseValue when + called with `is_encrypted` as keyword argument (forwarded to the BaseValue `__init__` method). """ + + if isinstance(constant_data, list): + assert len(constant_data) > 0 + constant_shape = (len(constant_data),) + constant_data_type = get_base_data_type_for_python_constant_data(constant_data) + return partial(TensorValue, data_type=constant_data_type, shape=constant_shape) + constant_data_type = get_base_data_type_for_python_constant_data(constant_data) return partial(ScalarValue, data_type=constant_data_type) @@ -359,3 +375,28 @@ def get_type_constructor_for_python_constant_data(constant_data: Union[int, floa constant_data (Any): The data for which we want to determine the type constructor. """ return type(constant_data) + + +def is_data_type_compatible_with( + dtype: BaseDataType, + other: BaseDataType, +) -> bool: + """Determine whether dtype is compatible with other. + + `dtype` being compatible with `other` means `other` can hold every value of `dtype` + (e.g., uint2 is compatible with uint4 and int4) + (e.g., int2 is compatible with int4 but not with uint4) + + Note that this function is not symetric. + (e.g., uint2 is compatible with uint4, but uint4 is not compatible with uint2) + + Args: + dtype (BaseDataType): dtype to check compatiblity + other (BaseDataType): dtype to check compatiblity against + + Returns: + bool: Whether the dtype is compatible with other or not + """ + + combination = find_type_to_hold_both_lossy(dtype, other) + return other == combination diff --git a/concrete/common/values/scalars.py b/concrete/common/values/scalars.py index 1b057fa99..e1b541e6d 100644 --- a/concrete/common/values/scalars.py +++ b/concrete/common/values/scalars.py @@ -1,5 +1,7 @@ """Module that defines the scalar values in a program.""" +from typing import Tuple + from ..data_types.base import BaseDataType from .base import BaseValue @@ -14,6 +16,15 @@ class ScalarValue(BaseValue): encrypted_str = "Encrypted" if self._is_encrypted else "Clear" return f"{encrypted_str}Scalar<{self.data_type!r}>" + @property + def shape(self) -> Tuple[int, ...]: + """Return the ScalarValue shape property. + + Returns: + Tuple[int, ...]: The ScalarValue shape which is `()`. + """ + return () + def make_clear_scalar(data_type: BaseDataType) -> ScalarValue: """Create a clear ScalarValue. diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index fd1074083..a00cbc030 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -24,6 +24,7 @@ from ..common.values import BaseValue from ..numpy.tracing import trace_numpy_function from .np_dtypes_helpers import ( get_base_data_type_for_numpy_or_python_constant_data, + get_base_value_for_numpy_or_python_constant_data, get_type_constructor_for_numpy_or_python_constant_data, ) @@ -112,8 +113,10 @@ def _compile_numpy_function_into_op_graph_internal( inputset_size, node_bounds = eval_op_graph_bounds_on_inputset( op_graph, inputset, + compilation_configuration=compilation_configuration, min_func=numpy_min_func, max_func=numpy_max_func, + get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, ) # Check inputset size @@ -134,7 +137,7 @@ def _compile_numpy_function_into_op_graph_internal( minimum_required_inputset_size = min(inputset_size_upper_limit, 10) if inputset_size < minimum_required_inputset_size: sys.stderr.write( - f"Provided inputset contains too few inputs " + f"Warning: Provided inputset contains too few inputs " f"(it should have had at least {minimum_required_inputset_size} " f"but it only had {inputset_size})\n" ) diff --git a/concrete/numpy/np_dtypes_helpers.py b/concrete/numpy/np_dtypes_helpers.py index 2d2c4eab7..cbcdb939d 100644 --- a/concrete/numpy/np_dtypes_helpers.py +++ b/concrete/numpy/np_dtypes_helpers.py @@ -10,6 +10,7 @@ from numpy.typing import DTypeLike from ..common.data_types.base import BaseDataType from ..common.data_types.dtypes_helpers import ( BASE_DATA_TYPES, + find_type_to_hold_both_lossy, get_base_data_type_for_python_constant_data, get_base_value_for_python_constant_data, get_type_constructor_for_python_constant_data, @@ -116,12 +117,26 @@ def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) -> """ base_dtype: BaseDataType custom_assert( - isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)), + isinstance( + constant_data, (int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES) + ), f"Unsupported constant data of type {type(constant_data)}", ) if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)): + native_type = ( + float + if constant_data.dtype == numpy.float32 or constant_data.dtype == numpy.float64 + else int + ) + + min_value = native_type(constant_data.min()) + max_value = native_type(constant_data.max()) + + min_value_dtype = get_base_data_type_for_python_constant_data(min_value) + max_value_dtype = get_base_data_type_for_python_constant_data(max_value) + # numpy - base_dtype = convert_numpy_dtype_to_base_data_type(constant_data.dtype) + base_dtype = find_type_to_hold_both_lossy(min_value_dtype, max_value_dtype) else: # python base_dtype = get_base_data_type_for_python_constant_data(constant_data) @@ -148,7 +163,10 @@ def get_base_value_for_numpy_or_python_constant_data( """ constant_data_value: Callable[..., BaseValue] custom_assert( - isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)), + isinstance( + constant_data, + (int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES), + ), f"Unsupported constant data of type {type(constant_data)}", ) diff --git a/tests/common/bounds_measurement/test_inputset_eval.py b/tests/common/bounds_measurement/test_inputset_eval.py index 3569c8bc5..dc3c3d7dc 100644 --- a/tests/common/bounds_measurement/test_inputset_eval.py +++ b/tests/common/bounds_measurement/test_inputset_eval.py @@ -2,12 +2,16 @@ from typing import Tuple +import numpy as np import pytest from concrete.common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_inputset +from concrete.common.compilation import CompilationConfiguration from concrete.common.data_types.floats import Float -from concrete.common.data_types.integers import Integer -from concrete.common.values import EncryptedScalar +from concrete.common.data_types.integers import Integer, UnsignedInteger +from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor +from concrete.numpy.compile import numpy_max_func, numpy_min_func +from concrete.numpy.np_dtypes_helpers import get_base_value_for_numpy_or_python_constant_data from concrete.numpy.tracing import trace_numpy_function @@ -280,7 +284,9 @@ def test_eval_op_graph_bounds_on_inputset_multiple_output( yield (x_gen, y_gen) _, node_bounds = eval_op_graph_bounds_on_inputset( - op_graph, data_gen(*tuple(range(x[0], x[1] + 1) for x in input_ranges)) + op_graph, + data_gen(*tuple(range(x[0], x[1] + 1) for x in input_ranges)), + CompilationConfiguration(), ) for i, output_node in op_graph.output_nodes.items(): @@ -291,3 +297,137 @@ def test_eval_op_graph_bounds_on_inputset_multiple_output( for i, output_node in op_graph.output_nodes.items(): assert expected_output_data_type[i] == output_node.outputs[0].data_type + + +def test_eval_op_graph_bounds_on_non_conformant_inputset_default(capsys): + """Test function for eval_op_graph_bounds_on_inputset with non conformant inputset""" + + def f(x, y): + return np.dot(x, y) + + x = EncryptedTensor(UnsignedInteger(2), (3,)) + y = ClearTensor(UnsignedInteger(2), (3,)) + + inputset = [ + ([2, 1, 3, 1], [1, 2, 1, 1]), + ([3, 3, 3], [3, 3, 5]), + ] + + op_graph = trace_numpy_function(f, {"x": x, "y": y}) + + configuration = CompilationConfiguration() + eval_op_graph_bounds_on_inputset(op_graph, inputset, compilation_configuration=configuration) + + captured = capsys.readouterr() + assert ( + captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " + "(expected EncryptedTensor, shape=(3,)> for parameter `x` " + "but got EncryptedTensor, shape=(4,)> which is not compatible)\n" + "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " + "(expected ClearTensor, shape=(3,)> for parameter `y` " + "but got ClearTensor, shape=(4,)> which is not compatible)\n" + ) + + +def test_eval_op_graph_bounds_on_non_conformant_inputset_check_all(capsys): + """Test function for eval_op_graph_bounds_on_inputset with non conformant inputset, check all""" + + def f(x, y): + return np.dot(x, y) + + x = EncryptedTensor(UnsignedInteger(2), (3,)) + y = ClearTensor(UnsignedInteger(2), (3,)) + + inputset = [ + ([2, 1, 3, 1], [1, 2, 1, 1]), + ([3, 3, 3], [3, 3, 5]), + ] + + op_graph = trace_numpy_function(f, {"x": x, "y": y}) + + configuration = CompilationConfiguration(check_every_input_in_inputset=True) + eval_op_graph_bounds_on_inputset(op_graph, inputset, compilation_configuration=configuration) + + captured = capsys.readouterr() + assert ( + captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " + "(expected EncryptedTensor, shape=(3,)> for parameter `x` " + "but got EncryptedTensor, shape=(4,)> which is not compatible)\n" + "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " + "(expected ClearTensor, shape=(3,)> for parameter `y` " + "but got ClearTensor, shape=(4,)> which is not compatible)\n" + "Warning: Input #1 (0-indexed) is not coherent with the hinted parameters " + "(expected ClearTensor, shape=(3,)> for parameter `y` " + "but got ClearTensor, shape=(3,)> which is not compatible)\n" + ) + + +def test_eval_op_graph_bounds_on_conformant_numpy_inputset_check_all(capsys): + """Test function for eval_op_graph_bounds_on_inputset + with conformant inputset of numpy arrays, check all""" + + def f(x, y): + return np.dot(x, y) + + x = EncryptedTensor(UnsignedInteger(2), (3,)) + y = ClearTensor(UnsignedInteger(2), (3,)) + + inputset = [ + (np.array([2, 1, 3]), np.array([1, 2, 1])), + (np.array([3, 3, 3]), np.array([3, 3, 1])), + ] + + op_graph = trace_numpy_function(f, {"x": x, "y": y}) + + configuration = CompilationConfiguration(check_every_input_in_inputset=True) + eval_op_graph_bounds_on_inputset( + op_graph, + inputset, + compilation_configuration=configuration, + min_func=numpy_min_func, + max_func=numpy_max_func, + get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, + ) + + captured = capsys.readouterr() + assert captured.err == "" + + +def test_eval_op_graph_bounds_on_non_conformant_numpy_inputset_check_all(capsys): + """Test function for eval_op_graph_bounds_on_inputset with non conformant inputset, check all""" + + def f(x, y): + return np.dot(x, y) + + x = EncryptedTensor(UnsignedInteger(2), (3,)) + y = ClearTensor(UnsignedInteger(2), (3,)) + + inputset = [ + (np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])), + (np.array([3, 3, 3]), np.array([3, 3, 5])), + ] + + op_graph = trace_numpy_function(f, {"x": x, "y": y}) + + configuration = CompilationConfiguration(check_every_input_in_inputset=True) + eval_op_graph_bounds_on_inputset( + op_graph, + inputset, + compilation_configuration=configuration, + min_func=numpy_min_func, + max_func=numpy_max_func, + get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, + ) + + captured = capsys.readouterr() + assert ( + captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " + "(expected EncryptedTensor, shape=(3,)> for parameter `x` " + "but got EncryptedTensor, shape=(4,)> which is not compatible)\n" + "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters " + "(expected ClearTensor, shape=(3,)> for parameter `y` " + "but got ClearTensor, shape=(4,)> which is not compatible)\n" + "Warning: Input #1 (0-indexed) is not coherent with the hinted parameters " + "(expected ClearTensor, shape=(3,)> for parameter `y` " + "but got ClearTensor, shape=(3,)> which is not compatible)\n" + ) diff --git a/tests/common/data_types/test_dtypes_helpers.py b/tests/common/data_types/test_dtypes_helpers.py index 74413738d..5a0d52c9a 100644 --- a/tests/common/data_types/test_dtypes_helpers.py +++ b/tests/common/data_types/test_dtypes_helpers.py @@ -1,5 +1,4 @@ """Test file for data types helpers""" - import pytest from concrete.common.data_types.base import BaseDataType diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 3aacdb737..1ef3cd0c3 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -303,7 +303,7 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str): iter_i = itertools.product(range(0, max_for_ij + 1), repeat=repeat) iter_j = itertools.product(range(0, max_for_ij + 1), repeat=repeat) for prod_i, prod_j in itertools.product(iter_i, iter_j): - yield (prod_i, prod_j) + yield (list(prod_i), list(prod_j)) max_for_ij = 3 assert len(shape) == 1