mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
dev: add code to have the proper data type when getting an AF table
- update BaseDataType to store the underlying type constructor e.g. int - add helper functions to get the type constructor for constant data - update operator graph to fill the type constructor during bounds update - add loguru as logger - use type constructor in ArbitraryFunction.get_table, log an info if the type_constructor of the input was None and default to int
This commit is contained in:
@@ -1,11 +1,19 @@
|
||||
"""File holding code to represent data types in a program."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Type
|
||||
|
||||
|
||||
class BaseDataType(ABC):
|
||||
"""Base class to represent a data type."""
|
||||
|
||||
# Constructor for the data type represented (for example numpy.int32 for an int32 numpy array)
|
||||
underlying_type_constructor: Optional[Type]
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.underlying_type_constructor = None
|
||||
|
||||
@abstractmethod
|
||||
def __eq__(self, o: object) -> bool:
|
||||
"""No default implementation."""
|
||||
|
||||
@@ -336,3 +336,12 @@ def get_base_value_for_python_constant_data(
|
||||
"""
|
||||
constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
|
||||
return partial(ScalarValue, data_type=constant_data_type)
|
||||
|
||||
|
||||
def get_type_constructor_for_python_constant_data(constant_data: Union[int, float]):
|
||||
"""Get the constructor for the passed python constant data.
|
||||
|
||||
Args:
|
||||
constant_data (Any): The data for which we want to determine the type constructor.
|
||||
"""
|
||||
return type(constant_data)
|
||||
|
||||
@@ -13,6 +13,7 @@ class Float(base.BaseDataType):
|
||||
bit_width: int
|
||||
|
||||
def __init__(self, bit_width: int) -> None:
|
||||
super().__init__()
|
||||
assert bit_width in (32, 64), "Only 32 and 64 bits floats are supported"
|
||||
self.bit_width = bit_width
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ class Integer(base.BaseDataType):
|
||||
is_signed: bool
|
||||
|
||||
def __init__(self, bit_width: int, is_signed: bool) -> None:
|
||||
super().__init__()
|
||||
assert bit_width > 0, "bit_width must be > 0"
|
||||
self.bit_width = bit_width
|
||||
self.is_signed = is_signed
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
"""Code to wrap and make manipulating networkx graphs easier."""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type, Union
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from .data_types.base import BaseDataType
|
||||
from .data_types.dtypes_helpers import get_base_data_type_for_python_constant_data
|
||||
from .data_types.dtypes_helpers import (
|
||||
get_base_data_type_for_python_constant_data,
|
||||
get_type_constructor_for_python_constant_data,
|
||||
)
|
||||
from .data_types.floats import Float
|
||||
from .data_types.integers import Integer, make_integer_to_hold
|
||||
from .representation import intermediate as ir
|
||||
@@ -124,8 +127,9 @@ class OPGraph:
|
||||
curr_inputs = {}
|
||||
for pred_node in self.graph.pred[node]:
|
||||
edges = self.graph.get_edge_data(pred_node, node)
|
||||
for edge in edges.values():
|
||||
curr_inputs[edge["input_idx"]] = node_results[pred_node]
|
||||
curr_inputs.update(
|
||||
{edge["input_idx"]: node_results[pred_node] for edge in edges.values()}
|
||||
)
|
||||
node_results[node] = node.evaluate(curr_inputs)
|
||||
else:
|
||||
node_results[node] = node.evaluate({0: inputs[node.program_input_idx]})
|
||||
@@ -138,6 +142,9 @@ class OPGraph:
|
||||
get_base_data_type_for_constant_data: Callable[
|
||||
[Any], BaseDataType
|
||||
] = get_base_data_type_for_python_constant_data,
|
||||
get_type_constructor_for_constant_data: Callable[
|
||||
..., Type
|
||||
] = get_type_constructor_for_python_constant_data,
|
||||
):
|
||||
"""Update values with bounds.
|
||||
|
||||
@@ -147,10 +154,13 @@ class OPGraph:
|
||||
Args:
|
||||
node_bounds (dict): Dictionary with nodes as keys, holding dicts with a 'min' and 'max'
|
||||
keys. Those bounds will be taken as the data range to be represented, per node.
|
||||
get_base_data_type_for_constant_data (Callable[ [Type], BaseDataType ], optional): This
|
||||
get_base_data_type_for_constant_data (Callable[ [Any], BaseDataType ], optional): This
|
||||
is a callback function to convert data encountered during value updates to
|
||||
BaseDataType. This allows to manage data coming from foreign frameworks without
|
||||
specialising OPGraph. Defaults to get_base_data_type_for_python_constant_data.
|
||||
get_type_constructor_for_constant_data (Callable[ ..., Type ], optional): This is a
|
||||
callback function to determine the type constructor of the data encountered while
|
||||
updating the graph bounds. Defaults to get_type_constructor_python_constant_data.
|
||||
"""
|
||||
node: ir.IntermediateNode
|
||||
|
||||
@@ -164,6 +174,16 @@ class OPGraph:
|
||||
min_data_type = get_base_data_type_for_constant_data(min_bound)
|
||||
max_data_type = get_base_data_type_for_constant_data(max_bound)
|
||||
|
||||
min_data_type_constructor = get_type_constructor_for_constant_data(min_bound)
|
||||
max_data_type_constructor = get_type_constructor_for_constant_data(max_bound)
|
||||
|
||||
assert max_data_type_constructor == min_data_type_constructor, (
|
||||
f"Got two different type constructors for min and max bound: "
|
||||
f"{min_data_type_constructor}, {max_data_type_constructor}"
|
||||
)
|
||||
|
||||
data_type_constructor = max_data_type_constructor
|
||||
|
||||
if not isinstance(node, ir.Input):
|
||||
for output_value in node.outputs:
|
||||
if isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer):
|
||||
@@ -171,7 +191,15 @@ class OPGraph:
|
||||
(min_bound, max_bound), force_signed=False
|
||||
)
|
||||
else:
|
||||
assert isinstance(min_data_type, Float) and isinstance(
|
||||
max_data_type, Float
|
||||
), (
|
||||
"min_bound and max_bound have different common types, "
|
||||
"this should never happen.\n"
|
||||
f"min_bound: {min_data_type}, max_bound: {max_data_type}"
|
||||
)
|
||||
output_value.data_type = Float(64)
|
||||
output_value.data_type.underlying_type_constructor = data_type_constructor
|
||||
else:
|
||||
# Currently variable inputs are only allowed to be integers
|
||||
assert isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer), (
|
||||
@@ -181,6 +209,8 @@ class OPGraph:
|
||||
node.inputs[0].data_type = make_integer_to_hold(
|
||||
(min_bound, max_bound), force_signed=False
|
||||
)
|
||||
node.inputs[0].data_type.underlying_type_constructor = data_type_constructor
|
||||
|
||||
node.outputs[0] = deepcopy(node.inputs[0])
|
||||
|
||||
# TODO: #57 manage multiple outputs from a node, probably requires an output_idx when
|
||||
|
||||
@@ -4,6 +4,8 @@ from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from ..data_types.base import BaseDataType
|
||||
from ..data_types.dtypes_helpers import (
|
||||
get_base_value_for_python_constant_data,
|
||||
@@ -232,6 +234,8 @@ class ArbitraryFunction(IntermediateNode):
|
||||
def get_table(self) -> List[Any]:
|
||||
"""Get the table for the current input value of this ArbitraryFunction.
|
||||
|
||||
This function only works if the ArbitraryFunction input value is an unsigned Integer.
|
||||
|
||||
Returns:
|
||||
List[Any]: The table.
|
||||
"""
|
||||
@@ -243,11 +247,18 @@ class ArbitraryFunction(IntermediateNode):
|
||||
0
|
||||
].data_type.is_signed, "get_table only works for an unsigned Integer input"
|
||||
|
||||
type_constructor = self.inputs[0].data_type.underlying_type_constructor
|
||||
if type_constructor is None:
|
||||
logger.info(
|
||||
f"{self.__class__.__name__} input data type constructor was None, defaulting to int"
|
||||
)
|
||||
type_constructor = int
|
||||
|
||||
min_input_range = self.inputs[0].data_type.min_value()
|
||||
max_input_range = self.inputs[0].data_type.max_value() + 1
|
||||
|
||||
table = [
|
||||
self.evaluate({0: input_value})
|
||||
self.evaluate({0: type_constructor(input_value)})
|
||||
for input_value in range(min_input_range, max_input_range)
|
||||
]
|
||||
|
||||
|
||||
@@ -19,7 +19,10 @@ from ..common.optimization.topological import fuse_float_operations
|
||||
from ..common.representation import intermediate as ir
|
||||
from ..common.values import BaseValue
|
||||
from ..numpy.tracing import trace_numpy_function
|
||||
from .np_dtypes_helpers import get_base_data_type_for_numpy_or_python_constant_data
|
||||
from .np_dtypes_helpers import (
|
||||
get_base_data_type_for_numpy_or_python_constant_data,
|
||||
get_type_constructor_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
|
||||
def numpy_max_func(lhs: Any, rhs: Any) -> Any:
|
||||
@@ -115,7 +118,9 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
|
||||
# Update the graph accordingly: after that, we have the compilable graph
|
||||
op_graph.update_values_with_bounds(
|
||||
node_bounds, get_base_data_type_for_numpy_or_python_constant_data
|
||||
node_bounds,
|
||||
get_base_data_type_for_numpy_or_python_constant_data,
|
||||
get_type_constructor_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
# Add the initial graph as an artifact
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Union
|
||||
from typing import Any, Callable, Dict, List, Type, Union
|
||||
|
||||
import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
@@ -12,6 +12,7 @@ from ..common.data_types.dtypes_helpers import (
|
||||
BASE_DATA_TYPES,
|
||||
get_base_data_type_for_python_constant_data,
|
||||
get_base_value_for_python_constant_data,
|
||||
get_type_constructor_for_python_constant_data,
|
||||
)
|
||||
from ..common.data_types.floats import Float
|
||||
from ..common.data_types.integers import Integer
|
||||
@@ -193,3 +194,24 @@ def get_numpy_function_output_dtype(
|
||||
numpy.seterr(**old_numpy_err_settings)
|
||||
|
||||
return [output.dtype for output in outputs]
|
||||
|
||||
|
||||
def get_type_constructor_for_numpy_or_python_constant_data(constant_data: Any):
|
||||
"""Get the constructor for the numpy scalar underlying dtype or python dtype.
|
||||
|
||||
Args:
|
||||
constant_data (Any): The data for which we want to determine the type constructor.
|
||||
"""
|
||||
|
||||
assert isinstance(
|
||||
constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
|
||||
), f"Unsupported constant data of type {type(constant_data)}"
|
||||
|
||||
scalar_constructor: Type
|
||||
|
||||
if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)):
|
||||
scalar_constructor = constant_data.dtype.type
|
||||
else:
|
||||
scalar_constructor = get_type_constructor_for_python_constant_data(constant_data)
|
||||
|
||||
return scalar_constructor
|
||||
|
||||
38
poetry.lock
generated
38
poetry.lock
generated
@@ -179,7 +179,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
||||
name = "colorama"
|
||||
version = "0.4.4"
|
||||
description = "Cross-platform colored terminal text."
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||
|
||||
@@ -571,6 +571,21 @@ category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
|
||||
|
||||
[[package]]
|
||||
name = "loguru"
|
||||
version = "0.5.3"
|
||||
description = "Python logging made (stupidly) simple"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
|
||||
[package.dependencies]
|
||||
colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""}
|
||||
win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
|
||||
|
||||
[package.extras]
|
||||
dev = ["codecov (>=2.0.15)", "colorama (>=0.3.4)", "flake8 (>=3.7.7)", "tox (>=3.9.0)", "tox-travis (>=0.12)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "Sphinx (>=2.2.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "black (>=19.10b0)", "isort (>=5.1.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "markdown-it-py"
|
||||
version = "1.1.0"
|
||||
@@ -1504,6 +1519,17 @@ python-versions = "*"
|
||||
[package.dependencies]
|
||||
notebook = ">=4.4.1"
|
||||
|
||||
[[package]]
|
||||
name = "win32-setctime"
|
||||
version = "1.0.3"
|
||||
description = "A small Python utility to set file creation time on Windows"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
|
||||
[package.extras]
|
||||
dev = ["pytest (>=4.6.2)", "black (>=19.3b0)"]
|
||||
|
||||
[[package]]
|
||||
name = "wrapt"
|
||||
version = "1.12.1"
|
||||
@@ -1515,7 +1541,7 @@ python-versions = "*"
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = ">=3.8,<3.9"
|
||||
content-hash = "cb1c1db2c4f94ed4984e565d1b5c0633bcf58a57cd667ff09a18e01e73382aa4"
|
||||
content-hash = "8a3be3fe122eddfb9a28a4f789c4581311ada706406277b5e84beb431369163b"
|
||||
|
||||
[metadata.files]
|
||||
alabaster = [
|
||||
@@ -1871,6 +1897,10 @@ lazy-object-proxy = [
|
||||
{file = "lazy_object_proxy-1.6.0-cp39-cp39-win32.whl", hash = "sha256:1fee665d2638491f4d6e55bd483e15ef21f6c8c2095f235fef72601021e64f61"},
|
||||
{file = "lazy_object_proxy-1.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:f5144c75445ae3ca2057faac03fda5a902eff196702b0a24daf1d6ce0650514b"},
|
||||
]
|
||||
loguru = [
|
||||
{file = "loguru-0.5.3-py3-none-any.whl", hash = "sha256:f8087ac396b5ee5f67c963b495d615ebbceac2796379599820e324419d53667c"},
|
||||
{file = "loguru-0.5.3.tar.gz", hash = "sha256:b28e72ac7a98be3d28ad28570299a393dfcd32e5e3f6a353dec94675767b6319"},
|
||||
]
|
||||
markdown-it-py = [
|
||||
{file = "markdown-it-py-1.1.0.tar.gz", hash = "sha256:36be6bb3ad987bfdb839f5ba78ddf094552ca38ccbd784ae4f74a4e1419fc6e3"},
|
||||
{file = "markdown_it_py-1.1.0-py3-none-any.whl", hash = "sha256:98080fc0bc34c4f2bcf0846a096a9429acbd9d5d8e67ed34026c03c61c464389"},
|
||||
@@ -2538,6 +2568,10 @@ widgetsnbextension = [
|
||||
{file = "widgetsnbextension-3.5.1-py2.py3-none-any.whl", hash = "sha256:bd314f8ceb488571a5ffea6cc5b9fc6cba0adaf88a9d2386b93a489751938bcd"},
|
||||
{file = "widgetsnbextension-3.5.1.tar.gz", hash = "sha256:079f87d87270bce047512400efd70238820751a11d2d8cb137a5a5bdbaf255c7"},
|
||||
]
|
||||
win32-setctime = [
|
||||
{file = "win32_setctime-1.0.3-py3-none-any.whl", hash = "sha256:dc925662de0a6eb987f0b01f599c01a8236cb8c62831c22d9cada09ad958243e"},
|
||||
{file = "win32_setctime-1.0.3.tar.gz", hash = "sha256:4e88556c32fdf47f64165a2180ba4552f8bb32c1103a2fafd05723a0bd42bd4b"},
|
||||
]
|
||||
wrapt = [
|
||||
{file = "wrapt-1.12.1.tar.gz", hash = "sha256:b62ffa81fb85f4332a4f609cab4ac40709470da05643a082ec1eb88e6d9b97d7"},
|
||||
]
|
||||
|
||||
@@ -14,6 +14,7 @@ matplotlib = "^3.4.2"
|
||||
numpy = "^1.21.1"
|
||||
pygraphviz = "^1.7"
|
||||
Pillow = "^8.3.1"
|
||||
loguru = "^0.5.3"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
isort = "^5.9.2"
|
||||
|
||||
@@ -8,6 +8,7 @@ from concrete.common.data_types.integers import Integer
|
||||
from concrete.numpy.np_dtypes_helpers import (
|
||||
convert_base_data_type_to_numpy_dtype,
|
||||
convert_numpy_dtype_to_base_data_type,
|
||||
get_type_constructor_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
|
||||
@@ -55,3 +56,23 @@ def test_convert_numpy_dtype_to_base_data_type(numpy_dtype, expected_common_type
|
||||
def test_convert_common_dtype_to_numpy_dtype(common_dtype, expected_numpy_dtype):
|
||||
"""Test function for convert_common_dtype_to_numpy_dtype"""
|
||||
assert expected_numpy_dtype == convert_base_data_type_to_numpy_dtype(common_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"constant_data,expected_constructor",
|
||||
[
|
||||
(10, int),
|
||||
(42.0, float),
|
||||
(numpy.int32(10), numpy.int32),
|
||||
(numpy.array([[0, 1], [3, 4]], dtype=numpy.uint64), numpy.uint64),
|
||||
(numpy.array([[0, 1], [3, 4]], dtype=numpy.float64), numpy.float64),
|
||||
],
|
||||
)
|
||||
def test_get_type_constructor_for_numpy_or_python_constant_data(
|
||||
constant_data, expected_constructor
|
||||
):
|
||||
"""Test function for get_type_constructor_for_numpy_or_python_constant_data"""
|
||||
|
||||
assert expected_constructor == get_type_constructor_for_numpy_or_python_constant_data(
|
||||
constant_data
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user