refactor: update OPGraph to be able to update bounds with foreign types

- get BaseDataType for values being checked in update_values_with_bounds
- rename SUPPORTED_TYPES to BASE_DATA_TYPES
This commit is contained in:
Arthur Meyre
2021-08-19 15:11:11 +02:00
parent 60daf31981
commit 4f103e604a
3 changed files with 52 additions and 13 deletions

View File

@@ -1,16 +1,16 @@
"""File to hold helper functions for data types related stuff."""
from copy import deepcopy
from typing import cast
from typing import Union, cast
from .base import BaseDataType
from .floats import Float
from .integers import Integer
from .integers import Integer, get_bits_to_represent_value_as_integer
from .values import BaseValue, ClearValue, EncryptedValue, ScalarValue
INTEGER_TYPES = (Integer,)
FLOAT_TYPES = (Float,)
SUPPORTED_TYPES = INTEGER_TYPES + FLOAT_TYPES
BASE_DATA_TYPES = INTEGER_TYPES + FLOAT_TYPES
def value_is_encrypted_integer(value_to_check: BaseValue) -> bool:
@@ -93,8 +93,8 @@ def find_type_to_hold_both_lossy(
Returns:
BaseDataType: The dtype able to hold (potentially lossy) dtype1 and dtype2
"""
assert isinstance(dtype1, SUPPORTED_TYPES), f"Unsupported dtype1: {type(dtype1)}"
assert isinstance(dtype2, SUPPORTED_TYPES), f"Unsupported dtype2: {type(dtype2)}"
assert isinstance(dtype1, BASE_DATA_TYPES), f"Unsupported dtype1: {type(dtype1)}"
assert isinstance(dtype2, BASE_DATA_TYPES), f"Unsupported dtype2: {type(dtype2)}"
type_to_return: BaseDataType
@@ -161,3 +161,27 @@ def mix_scalar_values_determine_holding_dtype(value1: BaseValue, value2: BaseVal
mixed_value = ClearValue(holding_type)
return mixed_value
def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType:
"""Helper function to determine the BaseDataType to hold the input constant data.
Args:
constant_data (Union[int, float]): The constant data for which to determine the
corresponding BaseDataType.
Returns:
BaseDataType: The corresponding BaseDataType
"""
constant_data_type: BaseDataType
assert isinstance(
constant_data, (int, float)
), f"Unsupported constant data of type {type(constant_data)}"
if isinstance(constant_data, int):
is_signed = constant_data < 0
constant_data_type = Integer(
get_bits_to_represent_value_as_integer(constant_data, is_signed), is_signed
)
elif isinstance(constant_data, float):
constant_data_type = Float(64)
return constant_data_type

View File

@@ -1,12 +1,14 @@
"""Code to wrap and make manipulating networkx graphs easier."""
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Set, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Union
import networkx as nx
from .data_types.base import BaseDataType
from .data_types.dtypes_helpers import get_base_data_type_for_python_constant_data
from .data_types.floats import Float
from .data_types.integers import make_integer_to_hold
from .data_types.integers import Integer, make_integer_to_hold
from .representation import intermediate as ir
from .tracing import BaseTracer
from .tracing.tracing_helpers import create_graph_from_output_tracers
@@ -130,7 +132,13 @@ class OPGraph:
return node_results
def update_values_with_bounds(self, node_bounds: dict):
def update_values_with_bounds(
self,
node_bounds: dict,
get_base_data_type_for_constant_data: Callable[
[Any], BaseDataType
] = get_base_data_type_for_python_constant_data,
):
"""Update values with bounds.
Update nodes inputs and outputs values with data types able to hold data ranges measured
@@ -139,6 +147,10 @@ class OPGraph:
Args:
node_bounds (dict): Dictionary with nodes as keys, holding dicts with a 'min' and 'max'
keys. Those bounds will be taken as the data range to be represented, per node.
get_base_data_type_for_constant_data (Callable[ [Type], BaseDataType ], optional): This
is a callback function to convert data encountered during value updates to
BaseDataType. This allows to manage data coming from foreign frameworks without
specialising OPGraph. Defaults to get_base_data_type_for_python_constant_data.
"""
node: ir.IntermediateNode
@@ -149,9 +161,12 @@ class OPGraph:
current_node_bounds["max"],
)
min_data_type = get_base_data_type_for_constant_data(min_bound)
max_data_type = get_base_data_type_for_constant_data(max_bound)
if not isinstance(node, ir.Input):
for output_value in node.outputs:
if isinstance(min_bound, int) and isinstance(max_bound, int):
if isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer):
output_value.data_type = make_integer_to_hold(
(min_bound, max_bound), force_signed=False
)
@@ -159,8 +174,8 @@ class OPGraph:
output_value.data_type = Float(64)
else:
# Currently variable inputs are only allowed to be integers
assert isinstance(min_bound, int) and isinstance(max_bound, int), (
f"Inputs to a graph should be integers, got bounds that were not float, \n"
assert isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer), (
f"Inputs to a graph should be integers, got bounds that were float, \n"
f"min: {min_bound} ({type(min_bound)}), max: {max_bound} ({type(max_bound)})"
)
node.inputs[0].data_type = make_integer_to_hold(

View File

@@ -7,7 +7,7 @@ import numpy
from numpy.typing import DTypeLike
from ..common.data_types.base import BaseDataType
from ..common.data_types.dtypes_helpers import SUPPORTED_TYPES
from ..common.data_types.dtypes_helpers import BASE_DATA_TYPES
from ..common.data_types.floats import Float
from ..common.data_types.integers import Integer
@@ -62,7 +62,7 @@ def convert_common_dtype_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.dty
numpy.dtype: The resulting numpy.dtype
"""
assert isinstance(
common_dtype, SUPPORTED_TYPES
common_dtype, BASE_DATA_TYPES
), f"Unsupported common_dtype: {type(common_dtype)}"
type_to_return: numpy.dtype