dev: add code to have the proper data type when getting an AF table

- update BaseDataType to store the underlying type constructor e.g. int
- add helper functions to get the type constructor for constant data
- update operator graph to fill the type constructor during bounds update
- add loguru as logger
- use type constructor in ArbitraryFunction.get_table, log an info if the
type_constructor of the input was None and default to int
This commit is contained in:
Arthur Meyre
2021-09-06 17:32:10 +02:00
parent 269ce01db3
commit 6fe809aece
11 changed files with 154 additions and 11 deletions

View File

@@ -1,11 +1,19 @@
"""File holding code to represent data types in a program."""
from abc import ABC, abstractmethod
from typing import Optional, Type
class BaseDataType(ABC):
"""Base class to represent a data type."""
# Constructor for the data type represented (for example numpy.int32 for an int32 numpy array)
underlying_type_constructor: Optional[Type]
def __init__(self) -> None:
super().__init__()
self.underlying_type_constructor = None
@abstractmethod
def __eq__(self, o: object) -> bool:
"""No default implementation."""

View File

@@ -336,3 +336,12 @@ def get_base_value_for_python_constant_data(
"""
constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
return partial(ScalarValue, data_type=constant_data_type)
def get_type_constructor_for_python_constant_data(constant_data: Union[int, float]):
"""Get the constructor for the passed python constant data.
Args:
constant_data (Any): The data for which we want to determine the type constructor.
"""
return type(constant_data)

View File

@@ -13,6 +13,7 @@ class Float(base.BaseDataType):
bit_width: int
def __init__(self, bit_width: int) -> None:
super().__init__()
assert bit_width in (32, 64), "Only 32 and 64 bits floats are supported"
self.bit_width = bit_width

View File

@@ -13,6 +13,7 @@ class Integer(base.BaseDataType):
is_signed: bool
def __init__(self, bit_width: int, is_signed: bool) -> None:
super().__init__()
assert bit_width > 0, "bit_width must be > 0"
self.bit_width = bit_width
self.is_signed = is_signed

View File

@@ -1,12 +1,15 @@
"""Code to wrap and make manipulating networkx graphs easier."""
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type, Union
import networkx as nx
from .data_types.base import BaseDataType
from .data_types.dtypes_helpers import get_base_data_type_for_python_constant_data
from .data_types.dtypes_helpers import (
get_base_data_type_for_python_constant_data,
get_type_constructor_for_python_constant_data,
)
from .data_types.floats import Float
from .data_types.integers import Integer, make_integer_to_hold
from .representation import intermediate as ir
@@ -124,8 +127,9 @@ class OPGraph:
curr_inputs = {}
for pred_node in self.graph.pred[node]:
edges = self.graph.get_edge_data(pred_node, node)
for edge in edges.values():
curr_inputs[edge["input_idx"]] = node_results[pred_node]
curr_inputs.update(
{edge["input_idx"]: node_results[pred_node] for edge in edges.values()}
)
node_results[node] = node.evaluate(curr_inputs)
else:
node_results[node] = node.evaluate({0: inputs[node.program_input_idx]})
@@ -138,6 +142,9 @@ class OPGraph:
get_base_data_type_for_constant_data: Callable[
[Any], BaseDataType
] = get_base_data_type_for_python_constant_data,
get_type_constructor_for_constant_data: Callable[
..., Type
] = get_type_constructor_for_python_constant_data,
):
"""Update values with bounds.
@@ -147,10 +154,13 @@ class OPGraph:
Args:
node_bounds (dict): Dictionary with nodes as keys, holding dicts with a 'min' and 'max'
keys. Those bounds will be taken as the data range to be represented, per node.
get_base_data_type_for_constant_data (Callable[ [Type], BaseDataType ], optional): This
get_base_data_type_for_constant_data (Callable[ [Any], BaseDataType ], optional): This
is a callback function to convert data encountered during value updates to
BaseDataType. This allows to manage data coming from foreign frameworks without
specialising OPGraph. Defaults to get_base_data_type_for_python_constant_data.
get_type_constructor_for_constant_data (Callable[ ..., Type ], optional): This is a
callback function to determine the type constructor of the data encountered while
updating the graph bounds. Defaults to get_type_constructor_python_constant_data.
"""
node: ir.IntermediateNode
@@ -164,6 +174,16 @@ class OPGraph:
min_data_type = get_base_data_type_for_constant_data(min_bound)
max_data_type = get_base_data_type_for_constant_data(max_bound)
min_data_type_constructor = get_type_constructor_for_constant_data(min_bound)
max_data_type_constructor = get_type_constructor_for_constant_data(max_bound)
assert max_data_type_constructor == min_data_type_constructor, (
f"Got two different type constructors for min and max bound: "
f"{min_data_type_constructor}, {max_data_type_constructor}"
)
data_type_constructor = max_data_type_constructor
if not isinstance(node, ir.Input):
for output_value in node.outputs:
if isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer):
@@ -171,7 +191,15 @@ class OPGraph:
(min_bound, max_bound), force_signed=False
)
else:
assert isinstance(min_data_type, Float) and isinstance(
max_data_type, Float
), (
"min_bound and max_bound have different common types, "
"this should never happen.\n"
f"min_bound: {min_data_type}, max_bound: {max_data_type}"
)
output_value.data_type = Float(64)
output_value.data_type.underlying_type_constructor = data_type_constructor
else:
# Currently variable inputs are only allowed to be integers
assert isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer), (
@@ -181,6 +209,8 @@ class OPGraph:
node.inputs[0].data_type = make_integer_to_hold(
(min_bound, max_bound), force_signed=False
)
node.inputs[0].data_type.underlying_type_constructor = data_type_constructor
node.outputs[0] = deepcopy(node.inputs[0])
# TODO: #57 manage multiple outputs from a node, probably requires an output_idx when

View File

@@ -4,6 +4,8 @@ from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type
from loguru import logger
from ..data_types.base import BaseDataType
from ..data_types.dtypes_helpers import (
get_base_value_for_python_constant_data,
@@ -232,6 +234,8 @@ class ArbitraryFunction(IntermediateNode):
def get_table(self) -> List[Any]:
"""Get the table for the current input value of this ArbitraryFunction.
This function only works if the ArbitraryFunction input value is an unsigned Integer.
Returns:
List[Any]: The table.
"""
@@ -243,11 +247,18 @@ class ArbitraryFunction(IntermediateNode):
0
].data_type.is_signed, "get_table only works for an unsigned Integer input"
type_constructor = self.inputs[0].data_type.underlying_type_constructor
if type_constructor is None:
logger.info(
f"{self.__class__.__name__} input data type constructor was None, defaulting to int"
)
type_constructor = int
min_input_range = self.inputs[0].data_type.min_value()
max_input_range = self.inputs[0].data_type.max_value() + 1
table = [
self.evaluate({0: input_value})
self.evaluate({0: type_constructor(input_value)})
for input_value in range(min_input_range, max_input_range)
]

View File

@@ -19,7 +19,10 @@ from ..common.optimization.topological import fuse_float_operations
from ..common.representation import intermediate as ir
from ..common.values import BaseValue
from ..numpy.tracing import trace_numpy_function
from .np_dtypes_helpers import get_base_data_type_for_numpy_or_python_constant_data
from .np_dtypes_helpers import (
get_base_data_type_for_numpy_or_python_constant_data,
get_type_constructor_for_numpy_or_python_constant_data,
)
def numpy_max_func(lhs: Any, rhs: Any) -> Any:
@@ -115,7 +118,9 @@ def _compile_numpy_function_into_op_graph_internal(
# Update the graph accordingly: after that, we have the compilable graph
op_graph.update_values_with_bounds(
node_bounds, get_base_data_type_for_numpy_or_python_constant_data
node_bounds,
get_base_data_type_for_numpy_or_python_constant_data,
get_type_constructor_for_numpy_or_python_constant_data,
)
# Add the initial graph as an artifact

View File

@@ -2,7 +2,7 @@
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Union
from typing import Any, Callable, Dict, List, Type, Union
import numpy
from numpy.typing import DTypeLike
@@ -12,6 +12,7 @@ from ..common.data_types.dtypes_helpers import (
BASE_DATA_TYPES,
get_base_data_type_for_python_constant_data,
get_base_value_for_python_constant_data,
get_type_constructor_for_python_constant_data,
)
from ..common.data_types.floats import Float
from ..common.data_types.integers import Integer
@@ -193,3 +194,24 @@ def get_numpy_function_output_dtype(
numpy.seterr(**old_numpy_err_settings)
return [output.dtype for output in outputs]
def get_type_constructor_for_numpy_or_python_constant_data(constant_data: Any):
"""Get the constructor for the numpy scalar underlying dtype or python dtype.
Args:
constant_data (Any): The data for which we want to determine the type constructor.
"""
assert isinstance(
constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
), f"Unsupported constant data of type {type(constant_data)}"
scalar_constructor: Type
if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)):
scalar_constructor = constant_data.dtype.type
else:
scalar_constructor = get_type_constructor_for_python_constant_data(constant_data)
return scalar_constructor

38
poetry.lock generated
View File

@@ -179,7 +179,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""}
name = "colorama"
version = "0.4.4"
description = "Cross-platform colored terminal text."
category = "dev"
category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
@@ -571,6 +571,21 @@ category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
[[package]]
name = "loguru"
version = "0.5.3"
description = "Python logging made (stupidly) simple"
category = "main"
optional = false
python-versions = ">=3.5"
[package.dependencies]
colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""}
win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
[package.extras]
dev = ["codecov (>=2.0.15)", "colorama (>=0.3.4)", "flake8 (>=3.7.7)", "tox (>=3.9.0)", "tox-travis (>=0.12)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "Sphinx (>=2.2.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "black (>=19.10b0)", "isort (>=5.1.1)"]
[[package]]
name = "markdown-it-py"
version = "1.1.0"
@@ -1504,6 +1519,17 @@ python-versions = "*"
[package.dependencies]
notebook = ">=4.4.1"
[[package]]
name = "win32-setctime"
version = "1.0.3"
description = "A small Python utility to set file creation time on Windows"
category = "main"
optional = false
python-versions = ">=3.5"
[package.extras]
dev = ["pytest (>=4.6.2)", "black (>=19.3b0)"]
[[package]]
name = "wrapt"
version = "1.12.1"
@@ -1515,7 +1541,7 @@ python-versions = "*"
[metadata]
lock-version = "1.1"
python-versions = ">=3.8,<3.9"
content-hash = "cb1c1db2c4f94ed4984e565d1b5c0633bcf58a57cd667ff09a18e01e73382aa4"
content-hash = "8a3be3fe122eddfb9a28a4f789c4581311ada706406277b5e84beb431369163b"
[metadata.files]
alabaster = [
@@ -1871,6 +1897,10 @@ lazy-object-proxy = [
{file = "lazy_object_proxy-1.6.0-cp39-cp39-win32.whl", hash = "sha256:1fee665d2638491f4d6e55bd483e15ef21f6c8c2095f235fef72601021e64f61"},
{file = "lazy_object_proxy-1.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:f5144c75445ae3ca2057faac03fda5a902eff196702b0a24daf1d6ce0650514b"},
]
loguru = [
{file = "loguru-0.5.3-py3-none-any.whl", hash = "sha256:f8087ac396b5ee5f67c963b495d615ebbceac2796379599820e324419d53667c"},
{file = "loguru-0.5.3.tar.gz", hash = "sha256:b28e72ac7a98be3d28ad28570299a393dfcd32e5e3f6a353dec94675767b6319"},
]
markdown-it-py = [
{file = "markdown-it-py-1.1.0.tar.gz", hash = "sha256:36be6bb3ad987bfdb839f5ba78ddf094552ca38ccbd784ae4f74a4e1419fc6e3"},
{file = "markdown_it_py-1.1.0-py3-none-any.whl", hash = "sha256:98080fc0bc34c4f2bcf0846a096a9429acbd9d5d8e67ed34026c03c61c464389"},
@@ -2538,6 +2568,10 @@ widgetsnbextension = [
{file = "widgetsnbextension-3.5.1-py2.py3-none-any.whl", hash = "sha256:bd314f8ceb488571a5ffea6cc5b9fc6cba0adaf88a9d2386b93a489751938bcd"},
{file = "widgetsnbextension-3.5.1.tar.gz", hash = "sha256:079f87d87270bce047512400efd70238820751a11d2d8cb137a5a5bdbaf255c7"},
]
win32-setctime = [
{file = "win32_setctime-1.0.3-py3-none-any.whl", hash = "sha256:dc925662de0a6eb987f0b01f599c01a8236cb8c62831c22d9cada09ad958243e"},
{file = "win32_setctime-1.0.3.tar.gz", hash = "sha256:4e88556c32fdf47f64165a2180ba4552f8bb32c1103a2fafd05723a0bd42bd4b"},
]
wrapt = [
{file = "wrapt-1.12.1.tar.gz", hash = "sha256:b62ffa81fb85f4332a4f609cab4ac40709470da05643a082ec1eb88e6d9b97d7"},
]

View File

@@ -14,6 +14,7 @@ matplotlib = "^3.4.2"
numpy = "^1.21.1"
pygraphviz = "^1.7"
Pillow = "^8.3.1"
loguru = "^0.5.3"
[tool.poetry.dev-dependencies]
isort = "^5.9.2"

View File

@@ -8,6 +8,7 @@ from concrete.common.data_types.integers import Integer
from concrete.numpy.np_dtypes_helpers import (
convert_base_data_type_to_numpy_dtype,
convert_numpy_dtype_to_base_data_type,
get_type_constructor_for_numpy_or_python_constant_data,
)
@@ -55,3 +56,23 @@ def test_convert_numpy_dtype_to_base_data_type(numpy_dtype, expected_common_type
def test_convert_common_dtype_to_numpy_dtype(common_dtype, expected_numpy_dtype):
"""Test function for convert_common_dtype_to_numpy_dtype"""
assert expected_numpy_dtype == convert_base_data_type_to_numpy_dtype(common_dtype)
@pytest.mark.parametrize(
"constant_data,expected_constructor",
[
(10, int),
(42.0, float),
(numpy.int32(10), numpy.int32),
(numpy.array([[0, 1], [3, 4]], dtype=numpy.uint64), numpy.uint64),
(numpy.array([[0, 1], [3, 4]], dtype=numpy.float64), numpy.float64),
],
)
def test_get_type_constructor_for_numpy_or_python_constant_data(
constant_data, expected_constructor
):
"""Test function for get_type_constructor_for_numpy_or_python_constant_data"""
assert expected_constructor == get_type_constructor_for_numpy_or_python_constant_data(
constant_data
)