From 6fe809aeced1e7ff65243aeead6e9c580bb4f2be Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 6 Sep 2021 17:32:10 +0200 Subject: [PATCH] dev: add code to have the proper data type when getting an AF table - update BaseDataType to store the underlying type constructor e.g. int - add helper functions to get the type constructor for constant data - update operator graph to fill the type constructor during bounds update - add loguru as logger - use type constructor in ArbitraryFunction.get_table, log an info if the type_constructor of the input was None and default to int --- concrete/common/data_types/base.py | 8 ++++ concrete/common/data_types/dtypes_helpers.py | 9 +++++ concrete/common/data_types/floats.py | 1 + concrete/common/data_types/integers.py | 1 + concrete/common/operator_graph.py | 40 ++++++++++++++++--- .../common/representation/intermediate.py | 13 +++++- concrete/numpy/compile.py | 9 ++++- concrete/numpy/np_dtypes_helpers.py | 24 ++++++++++- poetry.lock | 38 +++++++++++++++++- pyproject.toml | 1 + tests/numpy/test_np_dtypes_helpers.py | 21 ++++++++++ 11 files changed, 154 insertions(+), 11 deletions(-) diff --git a/concrete/common/data_types/base.py b/concrete/common/data_types/base.py index 834e75dc9..dec328fb3 100644 --- a/concrete/common/data_types/base.py +++ b/concrete/common/data_types/base.py @@ -1,11 +1,19 @@ """File holding code to represent data types in a program.""" from abc import ABC, abstractmethod +from typing import Optional, Type class BaseDataType(ABC): """Base class to represent a data type.""" + # Constructor for the data type represented (for example numpy.int32 for an int32 numpy array) + underlying_type_constructor: Optional[Type] + + def __init__(self) -> None: + super().__init__() + self.underlying_type_constructor = None + @abstractmethod def __eq__(self, o: object) -> bool: """No default implementation.""" diff --git a/concrete/common/data_types/dtypes_helpers.py b/concrete/common/data_types/dtypes_helpers.py index 7234c311f..83b4a5084 100644 --- a/concrete/common/data_types/dtypes_helpers.py +++ b/concrete/common/data_types/dtypes_helpers.py @@ -336,3 +336,12 @@ def get_base_value_for_python_constant_data( """ constant_data_type = get_base_data_type_for_python_constant_data(constant_data) return partial(ScalarValue, data_type=constant_data_type) + + +def get_type_constructor_for_python_constant_data(constant_data: Union[int, float]): + """Get the constructor for the passed python constant data. + + Args: + constant_data (Any): The data for which we want to determine the type constructor. + """ + return type(constant_data) diff --git a/concrete/common/data_types/floats.py b/concrete/common/data_types/floats.py index 9161bb391..63b52b3b3 100644 --- a/concrete/common/data_types/floats.py +++ b/concrete/common/data_types/floats.py @@ -13,6 +13,7 @@ class Float(base.BaseDataType): bit_width: int def __init__(self, bit_width: int) -> None: + super().__init__() assert bit_width in (32, 64), "Only 32 and 64 bits floats are supported" self.bit_width = bit_width diff --git a/concrete/common/data_types/integers.py b/concrete/common/data_types/integers.py index 7ef0674d7..2cbe83560 100644 --- a/concrete/common/data_types/integers.py +++ b/concrete/common/data_types/integers.py @@ -13,6 +13,7 @@ class Integer(base.BaseDataType): is_signed: bool def __init__(self, bit_width: int, is_signed: bool) -> None: + super().__init__() assert bit_width > 0, "bit_width must be > 0" self.bit_width = bit_width self.is_signed = is_signed diff --git a/concrete/common/operator_graph.py b/concrete/common/operator_graph.py index 313e25ec1..b19c42140 100644 --- a/concrete/common/operator_graph.py +++ b/concrete/common/operator_graph.py @@ -1,12 +1,15 @@ """Code to wrap and make manipulating networkx graphs easier.""" from copy import deepcopy -from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type, Union import networkx as nx from .data_types.base import BaseDataType -from .data_types.dtypes_helpers import get_base_data_type_for_python_constant_data +from .data_types.dtypes_helpers import ( + get_base_data_type_for_python_constant_data, + get_type_constructor_for_python_constant_data, +) from .data_types.floats import Float from .data_types.integers import Integer, make_integer_to_hold from .representation import intermediate as ir @@ -124,8 +127,9 @@ class OPGraph: curr_inputs = {} for pred_node in self.graph.pred[node]: edges = self.graph.get_edge_data(pred_node, node) - for edge in edges.values(): - curr_inputs[edge["input_idx"]] = node_results[pred_node] + curr_inputs.update( + {edge["input_idx"]: node_results[pred_node] for edge in edges.values()} + ) node_results[node] = node.evaluate(curr_inputs) else: node_results[node] = node.evaluate({0: inputs[node.program_input_idx]}) @@ -138,6 +142,9 @@ class OPGraph: get_base_data_type_for_constant_data: Callable[ [Any], BaseDataType ] = get_base_data_type_for_python_constant_data, + get_type_constructor_for_constant_data: Callable[ + ..., Type + ] = get_type_constructor_for_python_constant_data, ): """Update values with bounds. @@ -147,10 +154,13 @@ class OPGraph: Args: node_bounds (dict): Dictionary with nodes as keys, holding dicts with a 'min' and 'max' keys. Those bounds will be taken as the data range to be represented, per node. - get_base_data_type_for_constant_data (Callable[ [Type], BaseDataType ], optional): This + get_base_data_type_for_constant_data (Callable[ [Any], BaseDataType ], optional): This is a callback function to convert data encountered during value updates to BaseDataType. This allows to manage data coming from foreign frameworks without specialising OPGraph. Defaults to get_base_data_type_for_python_constant_data. + get_type_constructor_for_constant_data (Callable[ ..., Type ], optional): This is a + callback function to determine the type constructor of the data encountered while + updating the graph bounds. Defaults to get_type_constructor_python_constant_data. """ node: ir.IntermediateNode @@ -164,6 +174,16 @@ class OPGraph: min_data_type = get_base_data_type_for_constant_data(min_bound) max_data_type = get_base_data_type_for_constant_data(max_bound) + min_data_type_constructor = get_type_constructor_for_constant_data(min_bound) + max_data_type_constructor = get_type_constructor_for_constant_data(max_bound) + + assert max_data_type_constructor == min_data_type_constructor, ( + f"Got two different type constructors for min and max bound: " + f"{min_data_type_constructor}, {max_data_type_constructor}" + ) + + data_type_constructor = max_data_type_constructor + if not isinstance(node, ir.Input): for output_value in node.outputs: if isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer): @@ -171,7 +191,15 @@ class OPGraph: (min_bound, max_bound), force_signed=False ) else: + assert isinstance(min_data_type, Float) and isinstance( + max_data_type, Float + ), ( + "min_bound and max_bound have different common types, " + "this should never happen.\n" + f"min_bound: {min_data_type}, max_bound: {max_data_type}" + ) output_value.data_type = Float(64) + output_value.data_type.underlying_type_constructor = data_type_constructor else: # Currently variable inputs are only allowed to be integers assert isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer), ( @@ -181,6 +209,8 @@ class OPGraph: node.inputs[0].data_type = make_integer_to_hold( (min_bound, max_bound), force_signed=False ) + node.inputs[0].data_type.underlying_type_constructor = data_type_constructor + node.outputs[0] = deepcopy(node.inputs[0]) # TODO: #57 manage multiple outputs from a node, probably requires an output_idx when diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index 5f4ba2d23..ed6325510 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -4,6 +4,8 @@ from abc import ABC, abstractmethod from copy import deepcopy from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type +from loguru import logger + from ..data_types.base import BaseDataType from ..data_types.dtypes_helpers import ( get_base_value_for_python_constant_data, @@ -232,6 +234,8 @@ class ArbitraryFunction(IntermediateNode): def get_table(self) -> List[Any]: """Get the table for the current input value of this ArbitraryFunction. + This function only works if the ArbitraryFunction input value is an unsigned Integer. + Returns: List[Any]: The table. """ @@ -243,11 +247,18 @@ class ArbitraryFunction(IntermediateNode): 0 ].data_type.is_signed, "get_table only works for an unsigned Integer input" + type_constructor = self.inputs[0].data_type.underlying_type_constructor + if type_constructor is None: + logger.info( + f"{self.__class__.__name__} input data type constructor was None, defaulting to int" + ) + type_constructor = int + min_input_range = self.inputs[0].data_type.min_value() max_input_range = self.inputs[0].data_type.max_value() + 1 table = [ - self.evaluate({0: input_value}) + self.evaluate({0: type_constructor(input_value)}) for input_value in range(min_input_range, max_input_range) ] diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index 8ac26c9e5..333e73d10 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -19,7 +19,10 @@ from ..common.optimization.topological import fuse_float_operations from ..common.representation import intermediate as ir from ..common.values import BaseValue from ..numpy.tracing import trace_numpy_function -from .np_dtypes_helpers import get_base_data_type_for_numpy_or_python_constant_data +from .np_dtypes_helpers import ( + get_base_data_type_for_numpy_or_python_constant_data, + get_type_constructor_for_numpy_or_python_constant_data, +) def numpy_max_func(lhs: Any, rhs: Any) -> Any: @@ -115,7 +118,9 @@ def _compile_numpy_function_into_op_graph_internal( # Update the graph accordingly: after that, we have the compilable graph op_graph.update_values_with_bounds( - node_bounds, get_base_data_type_for_numpy_or_python_constant_data + node_bounds, + get_base_data_type_for_numpy_or_python_constant_data, + get_type_constructor_for_numpy_or_python_constant_data, ) # Add the initial graph as an artifact diff --git a/concrete/numpy/np_dtypes_helpers.py b/concrete/numpy/np_dtypes_helpers.py index 69586c1cc..1158925c4 100644 --- a/concrete/numpy/np_dtypes_helpers.py +++ b/concrete/numpy/np_dtypes_helpers.py @@ -2,7 +2,7 @@ from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, Type, Union import numpy from numpy.typing import DTypeLike @@ -12,6 +12,7 @@ from ..common.data_types.dtypes_helpers import ( BASE_DATA_TYPES, get_base_data_type_for_python_constant_data, get_base_value_for_python_constant_data, + get_type_constructor_for_python_constant_data, ) from ..common.data_types.floats import Float from ..common.data_types.integers import Integer @@ -193,3 +194,24 @@ def get_numpy_function_output_dtype( numpy.seterr(**old_numpy_err_settings) return [output.dtype for output in outputs] + + +def get_type_constructor_for_numpy_or_python_constant_data(constant_data: Any): + """Get the constructor for the numpy scalar underlying dtype or python dtype. + + Args: + constant_data (Any): The data for which we want to determine the type constructor. + """ + + assert isinstance( + constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES) + ), f"Unsupported constant data of type {type(constant_data)}" + + scalar_constructor: Type + + if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)): + scalar_constructor = constant_data.dtype.type + else: + scalar_constructor = get_type_constructor_for_python_constant_data(constant_data) + + return scalar_constructor diff --git a/poetry.lock b/poetry.lock index f0b5b37ec..bf25a86d8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -179,7 +179,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "colorama" version = "0.4.4" description = "Cross-platform colored terminal text." -category = "dev" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" @@ -571,6 +571,21 @@ category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +[[package]] +name = "loguru" +version = "0.5.3" +description = "Python logging made (stupidly) simple" +category = "main" +optional = false +python-versions = ">=3.5" + +[package.dependencies] +colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} +win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} + +[package.extras] +dev = ["codecov (>=2.0.15)", "colorama (>=0.3.4)", "flake8 (>=3.7.7)", "tox (>=3.9.0)", "tox-travis (>=0.12)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "Sphinx (>=2.2.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "black (>=19.10b0)", "isort (>=5.1.1)"] + [[package]] name = "markdown-it-py" version = "1.1.0" @@ -1504,6 +1519,17 @@ python-versions = "*" [package.dependencies] notebook = ">=4.4.1" +[[package]] +name = "win32-setctime" +version = "1.0.3" +description = "A small Python utility to set file creation time on Windows" +category = "main" +optional = false +python-versions = ">=3.5" + +[package.extras] +dev = ["pytest (>=4.6.2)", "black (>=19.3b0)"] + [[package]] name = "wrapt" version = "1.12.1" @@ -1515,7 +1541,7 @@ python-versions = "*" [metadata] lock-version = "1.1" python-versions = ">=3.8,<3.9" -content-hash = "cb1c1db2c4f94ed4984e565d1b5c0633bcf58a57cd667ff09a18e01e73382aa4" +content-hash = "8a3be3fe122eddfb9a28a4f789c4581311ada706406277b5e84beb431369163b" [metadata.files] alabaster = [ @@ -1871,6 +1897,10 @@ lazy-object-proxy = [ {file = "lazy_object_proxy-1.6.0-cp39-cp39-win32.whl", hash = "sha256:1fee665d2638491f4d6e55bd483e15ef21f6c8c2095f235fef72601021e64f61"}, {file = "lazy_object_proxy-1.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:f5144c75445ae3ca2057faac03fda5a902eff196702b0a24daf1d6ce0650514b"}, ] +loguru = [ + {file = "loguru-0.5.3-py3-none-any.whl", hash = "sha256:f8087ac396b5ee5f67c963b495d615ebbceac2796379599820e324419d53667c"}, + {file = "loguru-0.5.3.tar.gz", hash = "sha256:b28e72ac7a98be3d28ad28570299a393dfcd32e5e3f6a353dec94675767b6319"}, +] markdown-it-py = [ {file = "markdown-it-py-1.1.0.tar.gz", hash = "sha256:36be6bb3ad987bfdb839f5ba78ddf094552ca38ccbd784ae4f74a4e1419fc6e3"}, {file = "markdown_it_py-1.1.0-py3-none-any.whl", hash = "sha256:98080fc0bc34c4f2bcf0846a096a9429acbd9d5d8e67ed34026c03c61c464389"}, @@ -2538,6 +2568,10 @@ widgetsnbextension = [ {file = "widgetsnbextension-3.5.1-py2.py3-none-any.whl", hash = "sha256:bd314f8ceb488571a5ffea6cc5b9fc6cba0adaf88a9d2386b93a489751938bcd"}, {file = "widgetsnbextension-3.5.1.tar.gz", hash = "sha256:079f87d87270bce047512400efd70238820751a11d2d8cb137a5a5bdbaf255c7"}, ] +win32-setctime = [ + {file = "win32_setctime-1.0.3-py3-none-any.whl", hash = "sha256:dc925662de0a6eb987f0b01f599c01a8236cb8c62831c22d9cada09ad958243e"}, + {file = "win32_setctime-1.0.3.tar.gz", hash = "sha256:4e88556c32fdf47f64165a2180ba4552f8bb32c1103a2fafd05723a0bd42bd4b"}, +] wrapt = [ {file = "wrapt-1.12.1.tar.gz", hash = "sha256:b62ffa81fb85f4332a4f609cab4ac40709470da05643a082ec1eb88e6d9b97d7"}, ] diff --git a/pyproject.toml b/pyproject.toml index 4459463e6..cc9d9c71b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ matplotlib = "^3.4.2" numpy = "^1.21.1" pygraphviz = "^1.7" Pillow = "^8.3.1" +loguru = "^0.5.3" [tool.poetry.dev-dependencies] isort = "^5.9.2" diff --git a/tests/numpy/test_np_dtypes_helpers.py b/tests/numpy/test_np_dtypes_helpers.py index 6961c714f..d48180657 100644 --- a/tests/numpy/test_np_dtypes_helpers.py +++ b/tests/numpy/test_np_dtypes_helpers.py @@ -8,6 +8,7 @@ from concrete.common.data_types.integers import Integer from concrete.numpy.np_dtypes_helpers import ( convert_base_data_type_to_numpy_dtype, convert_numpy_dtype_to_base_data_type, + get_type_constructor_for_numpy_or_python_constant_data, ) @@ -55,3 +56,23 @@ def test_convert_numpy_dtype_to_base_data_type(numpy_dtype, expected_common_type def test_convert_common_dtype_to_numpy_dtype(common_dtype, expected_numpy_dtype): """Test function for convert_common_dtype_to_numpy_dtype""" assert expected_numpy_dtype == convert_base_data_type_to_numpy_dtype(common_dtype) + + +@pytest.mark.parametrize( + "constant_data,expected_constructor", + [ + (10, int), + (42.0, float), + (numpy.int32(10), numpy.int32), + (numpy.array([[0, 1], [3, 4]], dtype=numpy.uint64), numpy.uint64), + (numpy.array([[0, 1], [3, 4]], dtype=numpy.float64), numpy.float64), + ], +) +def test_get_type_constructor_for_numpy_or_python_constant_data( + constant_data, expected_constructor +): + """Test function for get_type_constructor_for_numpy_or_python_constant_data""" + + assert expected_constructor == get_type_constructor_for_numpy_or_python_constant_data( + constant_data + )