diff --git a/concrete/common/bounds_measurement/inputset_eval.py b/concrete/common/bounds_measurement/inputset_eval.py index 904c0f009..10d8d91e6 100644 --- a/concrete/common/bounds_measurement/inputset_eval.py +++ b/concrete/common/bounds_measurement/inputset_eval.py @@ -135,7 +135,7 @@ def eval_op_graph_bounds_on_inputset( Returns: Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: number of inputs in the inputset and a dict containing the bounds for each node from op_graph, stored with the node - as key and a dict with keys "min" and "max" as value. + as key and a dict with keys "min", "max" and "sample" as value. """ def check_inputset_input_len_is_valid(data_to_check): @@ -178,8 +178,12 @@ def eval_op_graph_bounds_on_inputset( # We evaluate the min and max func to be able to resolve the tensors min and max rather than # having the tensor itself as the stored min and max values. - node_bounds = { - node: {"min": min_func(value, value), "max": max_func(value, value)} + node_bounds_and_samples = { + node: { + "min": min_func(value, value), + "max": max_func(value, value), + "sample": value, + } for node, value in first_output.items() } @@ -200,7 +204,11 @@ def eval_op_graph_bounds_on_inputset( current_output = op_graph.evaluate(current_input_data) for node, value in current_output.items(): - node_bounds[node]["min"] = min_func(node_bounds[node]["min"], value) - node_bounds[node]["max"] = max_func(node_bounds[node]["max"], value) + node_bounds_and_samples[node]["min"] = min_func( + node_bounds_and_samples[node]["min"], value + ) + node_bounds_and_samples[node]["max"] = max_func( + node_bounds_and_samples[node]["max"], value + ) - return inputset_size, node_bounds + return inputset_size, node_bounds_and_samples diff --git a/concrete/common/data_types/base.py b/concrete/common/data_types/base.py index dec328fb3..834e75dc9 100644 --- a/concrete/common/data_types/base.py +++ b/concrete/common/data_types/base.py @@ -1,19 +1,11 @@ """File holding code to represent data types in a program.""" from abc import ABC, abstractmethod -from typing import Optional, Type class BaseDataType(ABC): """Base class to represent a data type.""" - # Constructor for the data type represented (for example numpy.int32 for an int32 numpy array) - underlying_type_constructor: Optional[Type] - - def __init__(self) -> None: - super().__init__() - self.underlying_type_constructor = None - @abstractmethod def __eq__(self, o: object) -> bool: """No default implementation.""" diff --git a/concrete/common/data_types/dtypes_helpers.py b/concrete/common/data_types/dtypes_helpers.py index 86eb7c24f..b1ae53264 100644 --- a/concrete/common/data_types/dtypes_helpers.py +++ b/concrete/common/data_types/dtypes_helpers.py @@ -312,7 +312,7 @@ def get_base_value_for_python_constant_data( return partial(TensorValue, dtype=constant_data_type, shape=()) -def get_type_constructor_for_python_constant_data(constant_data: Union[int, float]): +def get_constructor_for_python_constant_data(constant_data: Union[int, float]): """Get the constructor for the passed python constant data. Args: diff --git a/concrete/common/operator_graph.py b/concrete/common/operator_graph.py index 92f338984..d7355476b 100644 --- a/concrete/common/operator_graph.py +++ b/concrete/common/operator_graph.py @@ -1,14 +1,14 @@ """Code to wrap and make manipulating networkx graphs easier.""" from copy import deepcopy -from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Union import networkx as nx from .data_types.base import BaseDataType from .data_types.dtypes_helpers import ( get_base_data_type_for_python_constant_data, - get_type_constructor_for_python_constant_data, + get_constructor_for_python_constant_data, ) from .data_types.floats import Float from .data_types.integers import Integer, make_integer_to_hold @@ -171,15 +171,15 @@ class OPGraph: return node_results - def update_values_with_bounds( + def update_values_with_bounds_and_samples( self, - node_bounds: dict, + node_bounds_and_samples: dict, get_base_data_type_for_constant_data: Callable[ [Any], BaseDataType ] = get_base_data_type_for_python_constant_data, - get_type_constructor_for_constant_data: Callable[ - ..., Type - ] = get_type_constructor_for_python_constant_data, + get_constructor_for_constant_data: Callable[ + ..., Callable + ] = get_constructor_for_python_constant_data, ): """Update values with bounds. @@ -187,40 +187,44 @@ class OPGraph: and passed in nodes_bounds Args: - node_bounds (dict): Dictionary with nodes as keys, holding dicts with a 'min' and 'max' - keys. Those bounds will be taken as the data range to be represented, per node. + node_bounds_and_samples (dict): Dictionary with nodes as keys, holding dicts with a + 'min', 'max' and 'sample' keys. Those bounds will be taken as the data range to be + represented, per node. The sample allows to determine the data constructors to + prepare the UnivariateFunction nodes for table generation. get_base_data_type_for_constant_data (Callable[ [Any], BaseDataType ], optional): This is a callback function to convert data encountered during value updates to BaseDataType. This allows to manage data coming from foreign frameworks without specialising OPGraph. Defaults to get_base_data_type_for_python_constant_data. - get_type_constructor_for_constant_data (Callable[ ..., Type ], optional): This is a + get_constructor_for_constant_data (Callable[ ..., Callable ], optional): This is a callback function to determine the type constructor of the data encountered while - updating the graph bounds. Defaults to get_type_constructor_python_constant_data. + updating the graph bounds. Defaults to get_constructor_for_python_constant_data. """ node: IntermediateNode for node in self.graph.nodes(): - current_node_bounds = node_bounds[node] - min_bound, max_bound = ( - current_node_bounds["min"], - current_node_bounds["max"], + current_node_bounds_and_samples = node_bounds_and_samples[node] + min_bound, max_bound, sample = ( + current_node_bounds_and_samples["min"], + current_node_bounds_and_samples["max"], + current_node_bounds_and_samples["sample"], ) min_data_type = get_base_data_type_for_constant_data(min_bound) max_data_type = get_base_data_type_for_constant_data(max_bound) - min_data_type_constructor = get_type_constructor_for_constant_data(min_bound) - max_data_type_constructor = get_type_constructor_for_constant_data(max_bound) + # This is a sanity check + min_value_constructor = get_constructor_for_constant_data(min_bound) + max_value_constructor = get_constructor_for_constant_data(max_bound) assert_true( - max_data_type_constructor == min_data_type_constructor, + max_value_constructor == min_value_constructor, ( f"Got two different type constructors for min and max bound: " - f"{min_data_type_constructor}, {max_data_type_constructor}" + f"{min_value_constructor}, {max_value_constructor}" ), ) - data_type_constructor = max_data_type_constructor + value_constructor = get_constructor_for_constant_data(sample) if not isinstance(node, Input): for output_value in node.outputs: @@ -238,7 +242,7 @@ class OPGraph: ), ) output_value.dtype = Float(64) - output_value.dtype.underlying_type_constructor = data_type_constructor + output_value.underlying_constructor = value_constructor else: # Currently variable inputs are only allowed to be integers assert_true( @@ -252,7 +256,7 @@ class OPGraph: node.inputs[0].dtype = make_integer_to_hold( (min_bound, max_bound), force_signed=False ) - node.inputs[0].dtype.underlying_type_constructor = data_type_constructor + node.inputs[0].underlying_constructor = value_constructor node.outputs[0] = deepcopy(node.inputs[0]) diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index 144076ffc..f694d61de 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -250,26 +250,25 @@ class UnivariateFunction(IntermediateNode): Returns: List[Any]: The table. """ + input_dtype = self.inputs[0].dtype # Check the input is an unsigned integer to be able to build a table assert isinstance( - self.inputs[0].dtype, Integer + input_dtype, Integer ), "get_table only works for an unsigned Integer input" - assert not self.inputs[ - 0 - ].dtype.is_signed, "get_table only works for an unsigned Integer input" + assert not input_dtype.is_signed, "get_table only works for an unsigned Integer input" - type_constructor = self.inputs[0].dtype.underlying_type_constructor - if type_constructor is None: + input_value_constructor = self.inputs[0].underlying_constructor + if input_value_constructor is None: logger.info( f"{self.__class__.__name__} input data type constructor was None, defaulting to int" ) - type_constructor = int + input_value_constructor = int - min_input_range = self.inputs[0].dtype.min_value() - max_input_range = self.inputs[0].dtype.max_value() + 1 + min_input_range = input_dtype.min_value() + max_input_range = input_dtype.max_value() + 1 table = [ - self.evaluate({0: type_constructor(input_value)}) + self.evaluate({0: input_value_constructor(input_value)}) for input_value in range(min_input_range, max_input_range) ] diff --git a/concrete/common/values/base.py b/concrete/common/values/base.py index 25311f66a..b33f026e4 100644 --- a/concrete/common/values/base.py +++ b/concrete/common/values/base.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from copy import deepcopy +from typing import Callable, Optional from ..data_types.base import BaseDataType @@ -11,10 +12,12 @@ class BaseValue(ABC): dtype: BaseDataType _is_encrypted: bool + underlying_constructor: Optional[Callable] def __init__(self, dtype: BaseDataType, is_encrypted: bool) -> None: self.dtype = deepcopy(dtype) self._is_encrypted = is_encrypted + self.underlying_constructor = None def __repr__(self) -> str: # pragma: no cover return str(self) diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index 689b9c5b6..f45b4df33 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -26,7 +26,7 @@ from ..numpy.tracing import trace_numpy_function from .np_dtypes_helpers import ( get_base_data_type_for_numpy_or_python_constant_data, get_base_value_for_numpy_or_python_constant_data, - get_type_constructor_for_numpy_or_python_constant_data, + get_constructor_for_numpy_or_python_constant_data, ) @@ -111,7 +111,7 @@ def _compile_numpy_function_into_op_graph_internal( ) # Find bounds with the inputset - inputset_size, node_bounds = eval_op_graph_bounds_on_inputset( + inputset_size, node_bounds_and_samples = eval_op_graph_bounds_on_inputset( op_graph, inputset, compilation_configuration=compilation_configuration, @@ -149,13 +149,13 @@ def _compile_numpy_function_into_op_graph_internal( sys.stderr.write(f"Warning: {message}") # Add the bounds as an artifact - compilation_artifacts.add_final_operation_graph_bounds(node_bounds) + compilation_artifacts.add_final_operation_graph_bounds(node_bounds_and_samples) # Update the graph accordingly: after that, we have the compilable graph - op_graph.update_values_with_bounds( - node_bounds, + op_graph.update_values_with_bounds_and_samples( + node_bounds_and_samples, get_base_data_type_for_numpy_or_python_constant_data, - get_type_constructor_for_numpy_or_python_constant_data, + get_constructor_for_numpy_or_python_constant_data, ) # Add the initial graph as an artifact diff --git a/concrete/numpy/np_dtypes_helpers.py b/concrete/numpy/np_dtypes_helpers.py index 599886068..24708bb47 100644 --- a/concrete/numpy/np_dtypes_helpers.py +++ b/concrete/numpy/np_dtypes_helpers.py @@ -2,7 +2,7 @@ from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List, Type, Union +from typing import Any, Callable, Dict, List, Union import numpy from numpy.typing import DTypeLike @@ -13,7 +13,7 @@ from ..common.data_types.dtypes_helpers import ( find_type_to_hold_both_lossy, get_base_data_type_for_python_constant_data, get_base_value_for_python_constant_data, - get_type_constructor_for_python_constant_data, + get_constructor_for_python_constant_data, ) from ..common.data_types.floats import Float from ..common.data_types.integers import Integer @@ -224,8 +224,8 @@ def get_numpy_function_output_dtype( return [output.dtype for output in outputs] -def get_type_constructor_for_numpy_or_python_constant_data(constant_data: Any): - """Get the constructor for the numpy scalar underlying dtype or python dtype. +def get_constructor_for_numpy_or_python_constant_data(constant_data: Any): + """Get the constructor for the numpy constant data or python dtype. Args: constant_data (Any): The data for which we want to determine the type constructor. @@ -236,11 +236,8 @@ def get_type_constructor_for_numpy_or_python_constant_data(constant_data: Any): f"Unsupported constant data of type {type(constant_data)}", ) - scalar_constructor: Type - if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)): - scalar_constructor = constant_data.dtype.type - else: - scalar_constructor = get_type_constructor_for_python_constant_data(constant_data) - - return scalar_constructor + if isinstance(constant_data, numpy.ndarray): + return lambda x: numpy.full(constant_data.shape, x, dtype=constant_data.dtype) + return constant_data.dtype.type + return get_constructor_for_python_constant_data(constant_data) diff --git a/tests/common/bounds_measurement/test_inputset_eval.py b/tests/common/bounds_measurement/test_inputset_eval.py index 209471873..085ea865c 100644 --- a/tests/common/bounds_measurement/test_inputset_eval.py +++ b/tests/common/bounds_measurement/test_inputset_eval.py @@ -283,17 +283,17 @@ def test_eval_op_graph_bounds_on_inputset_multiple_output( for y_gen in range_y: yield (x_gen, y_gen) - _, node_bounds = eval_op_graph_bounds_on_inputset( + _, node_bounds_and_samples = eval_op_graph_bounds_on_inputset( op_graph, data_gen(*tuple(range(x[0], x[1] + 1) for x in input_ranges)), CompilationConfiguration(), ) for i, output_node in op_graph.output_nodes.items(): - output_node_bounds = node_bounds[output_node] + output_node_bounds = node_bounds_and_samples[output_node] assert (output_node_bounds["min"], output_node_bounds["max"]) == expected_output_bounds[i] - op_graph.update_values_with_bounds(node_bounds) + op_graph.update_values_with_bounds_and_samples(node_bounds_and_samples) for i, output_node in op_graph.output_nodes.items(): assert expected_output_data_type[i] == output_node.outputs[0].dtype diff --git a/tests/conftest.py b/tests/conftest.py index d3e38870b..3a4dc94a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -215,6 +215,14 @@ class TestHelpers: return graphs_are_isomorphic + @staticmethod + def python_functions_are_equal_or_equivalent(lhs, rhs): + """Helper function to check if two functions are equal or their code are equivalent. + + This is not perfect, but will be good enough for tests. + """ + return python_functions_are_equal_or_equivalent(lhs, rhs) + @pytest.fixture def test_helpers(): diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 97563618e..919505e3d 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -330,7 +330,7 @@ def test_fail_compile(function, input_ranges, list_of_arg_names): } with pytest.raises(RuntimeError, match=".*isn't supported for MLIR lowering.*"): - compile_numpy_function_into_op_graph( + compile_numpy_function( function, function_parameters, data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), diff --git a/tests/numpy/test_np_dtypes_helpers.py b/tests/numpy/test_np_dtypes_helpers.py index a10e2b594..13f8c9ef6 100644 --- a/tests/numpy/test_np_dtypes_helpers.py +++ b/tests/numpy/test_np_dtypes_helpers.py @@ -9,7 +9,7 @@ from concrete.numpy.np_dtypes_helpers import ( convert_base_data_type_to_numpy_dtype, convert_numpy_dtype_to_base_data_type, get_base_value_for_numpy_or_python_constant_data, - get_type_constructor_for_numpy_or_python_constant_data, + get_constructor_for_numpy_or_python_constant_data, ) @@ -65,18 +65,31 @@ def test_convert_common_dtype_to_numpy_dtype(common_dtype, expected_numpy_dtype) (10, int), (42.0, float), (numpy.int32(10), numpy.int32), - (numpy.array([[0, 1], [3, 4]], dtype=numpy.uint64), numpy.uint64), - (numpy.array([[0, 1], [3, 4]], dtype=numpy.float64), numpy.float64), ], ) -def test_get_type_constructor_for_numpy_or_python_constant_data( - constant_data, expected_constructor -): - """Test function for get_type_constructor_for_numpy_or_python_constant_data""" +def test_get_constructor_for_numpy_or_python_constant_data(constant_data, expected_constructor): + """Test function for get_constructor_for_numpy_or_python_constant_data""" - assert expected_constructor == get_type_constructor_for_numpy_or_python_constant_data( - constant_data - ) + assert expected_constructor == get_constructor_for_numpy_or_python_constant_data(constant_data) + + +def test_get_constructor_for_numpy_arrays(test_helpers): + """Test function for get_constructor_for_numpy_or_python_constant_data for numpy arrays.""" + + arrays = [ + numpy.array([[0, 1], [3, 4]], dtype=numpy.uint64), + numpy.array([[0, 1], [3, 4]], dtype=numpy.float64), + ] + + def get_expected_constructor(array: numpy.ndarray): + return lambda x: numpy.full(array.shape, x, dtype=array.dtype) + + expected_constructors = [get_expected_constructor(array) for array in arrays] + + for array, expected_constructor in zip(arrays, expected_constructors): + assert test_helpers.python_functions_are_equal_or_equivalent( + expected_constructor, get_constructor_for_numpy_or_python_constant_data(array) + ) def test_get_base_value_for_numpy_or_python_constant_data_with_list():