refactor(tracing): preparatory work for tensor table generation

- removed underlying_type_constructor from BaseDataType, as it was
scalar-specific, and moved it to the values (as underlying_constructor)
- updated inputset_eval to keep a sample of intermediate node values
- this makes it possible to get the proper value constructor to use in
UnivariateFunction get_table and to have tensors as inputs
This commit is contained in:
Arthur Meyre
2021-10-13 14:50:37 +02:00
parent 67a9bf12ca
commit 95c48a419c
12 changed files with 102 additions and 78 deletions

View File

@@ -135,7 +135,7 @@ def eval_op_graph_bounds_on_inputset(
Returns:
Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: number of inputs in the inputset and
a dict containing the bounds for each node from op_graph, stored with the node
as key and a dict with keys "min" and "max" as value.
as key and a dict with keys "min", "max" and "sample" as value.
"""
def check_inputset_input_len_is_valid(data_to_check):
@@ -178,8 +178,12 @@ def eval_op_graph_bounds_on_inputset(
# We evaluate the min and max func to be able to resolve the tensors min and max rather than
# having the tensor itself as the stored min and max values.
node_bounds = {
node: {"min": min_func(value, value), "max": max_func(value, value)}
node_bounds_and_samples = {
node: {
"min": min_func(value, value),
"max": max_func(value, value),
"sample": value,
}
for node, value in first_output.items()
}
@@ -200,7 +204,11 @@ def eval_op_graph_bounds_on_inputset(
current_output = op_graph.evaluate(current_input_data)
for node, value in current_output.items():
node_bounds[node]["min"] = min_func(node_bounds[node]["min"], value)
node_bounds[node]["max"] = max_func(node_bounds[node]["max"], value)
node_bounds_and_samples[node]["min"] = min_func(
node_bounds_and_samples[node]["min"], value
)
node_bounds_and_samples[node]["max"] = max_func(
node_bounds_and_samples[node]["max"], value
)
return inputset_size, node_bounds
return inputset_size, node_bounds_and_samples

View File

@@ -1,19 +1,11 @@
"""File holding code to represent data types in a program."""
from abc import ABC, abstractmethod
from typing import Optional, Type
class BaseDataType(ABC):
"""Base class to represent a data type."""
# Constructor for the data type represented (for example numpy.int32 for an int32 numpy array)
underlying_type_constructor: Optional[Type]
def __init__(self) -> None:
super().__init__()
self.underlying_type_constructor = None
@abstractmethod
def __eq__(self, o: object) -> bool:
"""No default implementation."""

View File

@@ -312,7 +312,7 @@ def get_base_value_for_python_constant_data(
return partial(TensorValue, dtype=constant_data_type, shape=())
def get_type_constructor_for_python_constant_data(constant_data: Union[int, float]):
def get_constructor_for_python_constant_data(constant_data: Union[int, float]):
"""Get the constructor for the passed python constant data.
Args:

View File

@@ -1,14 +1,14 @@
"""Code to wrap and make manipulating networkx graphs easier."""
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Union
import networkx as nx
from .data_types.base import BaseDataType
from .data_types.dtypes_helpers import (
get_base_data_type_for_python_constant_data,
get_type_constructor_for_python_constant_data,
get_constructor_for_python_constant_data,
)
from .data_types.floats import Float
from .data_types.integers import Integer, make_integer_to_hold
@@ -171,15 +171,15 @@ class OPGraph:
return node_results
def update_values_with_bounds(
def update_values_with_bounds_and_samples(
self,
node_bounds: dict,
node_bounds_and_samples: dict,
get_base_data_type_for_constant_data: Callable[
[Any], BaseDataType
] = get_base_data_type_for_python_constant_data,
get_type_constructor_for_constant_data: Callable[
..., Type
] = get_type_constructor_for_python_constant_data,
get_constructor_for_constant_data: Callable[
..., Callable
] = get_constructor_for_python_constant_data,
):
"""Update values with bounds.
@@ -187,40 +187,44 @@ class OPGraph:
and passed in nodes_bounds
Args:
node_bounds (dict): Dictionary with nodes as keys, holding dicts with a 'min' and 'max'
keys. Those bounds will be taken as the data range to be represented, per node.
node_bounds_and_samples (dict): Dictionary with nodes as keys, holding dicts with
'min', 'max' and 'sample' keys. The bounds will be taken as the data range to be
represented, per node. The sample is used to determine the value constructors that
prepare the UnivariateFunction nodes for table generation.
get_base_data_type_for_constant_data (Callable[ [Any], BaseDataType ], optional): This
is a callback function to convert data encountered during value updates to
BaseDataType. This allows to manage data coming from foreign frameworks without
specialising OPGraph. Defaults to get_base_data_type_for_python_constant_data.
get_type_constructor_for_constant_data (Callable[ ..., Type ], optional): This is a
get_constructor_for_constant_data (Callable[ ..., Callable ], optional): This is a
callback function to determine the type constructor of the data encountered while
updating the graph bounds. Defaults to get_type_constructor_python_constant_data.
updating the graph bounds. Defaults to get_constructor_for_python_constant_data.
"""
node: IntermediateNode
for node in self.graph.nodes():
current_node_bounds = node_bounds[node]
min_bound, max_bound = (
current_node_bounds["min"],
current_node_bounds["max"],
current_node_bounds_and_samples = node_bounds_and_samples[node]
min_bound, max_bound, sample = (
current_node_bounds_and_samples["min"],
current_node_bounds_and_samples["max"],
current_node_bounds_and_samples["sample"],
)
min_data_type = get_base_data_type_for_constant_data(min_bound)
max_data_type = get_base_data_type_for_constant_data(max_bound)
min_data_type_constructor = get_type_constructor_for_constant_data(min_bound)
max_data_type_constructor = get_type_constructor_for_constant_data(max_bound)
# This is a sanity check
min_value_constructor = get_constructor_for_constant_data(min_bound)
max_value_constructor = get_constructor_for_constant_data(max_bound)
assert_true(
max_data_type_constructor == min_data_type_constructor,
max_value_constructor == min_value_constructor,
(
f"Got two different type constructors for min and max bound: "
f"{min_data_type_constructor}, {max_data_type_constructor}"
f"{min_value_constructor}, {max_value_constructor}"
),
)
data_type_constructor = max_data_type_constructor
value_constructor = get_constructor_for_constant_data(sample)
if not isinstance(node, Input):
for output_value in node.outputs:
@@ -238,7 +242,7 @@ class OPGraph:
),
)
output_value.dtype = Float(64)
output_value.dtype.underlying_type_constructor = data_type_constructor
output_value.underlying_constructor = value_constructor
else:
# Currently variable inputs are only allowed to be integers
assert_true(
@@ -252,7 +256,7 @@ class OPGraph:
node.inputs[0].dtype = make_integer_to_hold(
(min_bound, max_bound), force_signed=False
)
node.inputs[0].dtype.underlying_type_constructor = data_type_constructor
node.inputs[0].underlying_constructor = value_constructor
node.outputs[0] = deepcopy(node.inputs[0])

View File

@@ -250,26 +250,25 @@ class UnivariateFunction(IntermediateNode):
Returns:
List[Any]: The table.
"""
input_dtype = self.inputs[0].dtype
# Check the input is an unsigned integer to be able to build a table
assert isinstance(
self.inputs[0].dtype, Integer
input_dtype, Integer
), "get_table only works for an unsigned Integer input"
assert not self.inputs[
0
].dtype.is_signed, "get_table only works for an unsigned Integer input"
assert not input_dtype.is_signed, "get_table only works for an unsigned Integer input"
type_constructor = self.inputs[0].dtype.underlying_type_constructor
if type_constructor is None:
input_value_constructor = self.inputs[0].underlying_constructor
if input_value_constructor is None:
logger.info(
f"{self.__class__.__name__} input data type constructor was None, defaulting to int"
)
type_constructor = int
input_value_constructor = int
min_input_range = self.inputs[0].dtype.min_value()
max_input_range = self.inputs[0].dtype.max_value() + 1
min_input_range = input_dtype.min_value()
max_input_range = input_dtype.max_value() + 1
table = [
self.evaluate({0: type_constructor(input_value)})
self.evaluate({0: input_value_constructor(input_value)})
for input_value in range(min_input_range, max_input_range)
]

View File

@@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Callable, Optional
from ..data_types.base import BaseDataType
@@ -11,10 +12,12 @@ class BaseValue(ABC):
dtype: BaseDataType
_is_encrypted: bool
underlying_constructor: Optional[Callable]
def __init__(self, dtype: BaseDataType, is_encrypted: bool) -> None:
    """Initialize the value from its data type and encryption status.

    Args:
        dtype (BaseDataType): data type of the value. A deep copy is stored, so the
            object passed in is not shared with this value.
        is_encrypted (bool): stored as-is in the private _is_encrypted attribute.
    """
    self.dtype = deepcopy(dtype)
    self._is_encrypted = is_encrypted
    # No constructor is known at creation time; readers must handle the None case.
    self.underlying_constructor = None
def __repr__(self) -> str: # pragma: no cover
    """Return the same text as str(self)."""
    return str(self)

View File

@@ -26,7 +26,7 @@ from ..numpy.tracing import trace_numpy_function
from .np_dtypes_helpers import (
get_base_data_type_for_numpy_or_python_constant_data,
get_base_value_for_numpy_or_python_constant_data,
get_type_constructor_for_numpy_or_python_constant_data,
get_constructor_for_numpy_or_python_constant_data,
)
@@ -111,7 +111,7 @@ def _compile_numpy_function_into_op_graph_internal(
)
# Find bounds with the inputset
inputset_size, node_bounds = eval_op_graph_bounds_on_inputset(
inputset_size, node_bounds_and_samples = eval_op_graph_bounds_on_inputset(
op_graph,
inputset,
compilation_configuration=compilation_configuration,
@@ -149,13 +149,13 @@ def _compile_numpy_function_into_op_graph_internal(
sys.stderr.write(f"Warning: {message}")
# Add the bounds as an artifact
compilation_artifacts.add_final_operation_graph_bounds(node_bounds)
compilation_artifacts.add_final_operation_graph_bounds(node_bounds_and_samples)
# Update the graph accordingly: after that, we have the compilable graph
op_graph.update_values_with_bounds(
node_bounds,
op_graph.update_values_with_bounds_and_samples(
node_bounds_and_samples,
get_base_data_type_for_numpy_or_python_constant_data,
get_type_constructor_for_numpy_or_python_constant_data,
get_constructor_for_numpy_or_python_constant_data,
)
# Add the initial graph as an artifact

View File

@@ -2,7 +2,7 @@
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Type, Union
from typing import Any, Callable, Dict, List, Union
import numpy
from numpy.typing import DTypeLike
@@ -13,7 +13,7 @@ from ..common.data_types.dtypes_helpers import (
find_type_to_hold_both_lossy,
get_base_data_type_for_python_constant_data,
get_base_value_for_python_constant_data,
get_type_constructor_for_python_constant_data,
get_constructor_for_python_constant_data,
)
from ..common.data_types.floats import Float
from ..common.data_types.integers import Integer
@@ -224,8 +224,8 @@ def get_numpy_function_output_dtype(
return [output.dtype for output in outputs]
def get_type_constructor_for_numpy_or_python_constant_data(constant_data: Any):
"""Get the constructor for the numpy scalar underlying dtype or python dtype.
def get_constructor_for_numpy_or_python_constant_data(constant_data: Any):
"""Get the constructor for the numpy constant data or python dtype.
Args:
constant_data (Any): The data for which we want to determine the type constructor.
@@ -236,11 +236,8 @@ def get_type_constructor_for_numpy_or_python_constant_data(constant_data: Any):
f"Unsupported constant data of type {type(constant_data)}",
)
scalar_constructor: Type
if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)):
scalar_constructor = constant_data.dtype.type
else:
scalar_constructor = get_type_constructor_for_python_constant_data(constant_data)
return scalar_constructor
if isinstance(constant_data, numpy.ndarray):
return lambda x: numpy.full(constant_data.shape, x, dtype=constant_data.dtype)
return constant_data.dtype.type
return get_constructor_for_python_constant_data(constant_data)

View File

@@ -283,17 +283,17 @@ def test_eval_op_graph_bounds_on_inputset_multiple_output(
for y_gen in range_y:
yield (x_gen, y_gen)
_, node_bounds = eval_op_graph_bounds_on_inputset(
_, node_bounds_and_samples = eval_op_graph_bounds_on_inputset(
op_graph,
data_gen(*tuple(range(x[0], x[1] + 1) for x in input_ranges)),
CompilationConfiguration(),
)
for i, output_node in op_graph.output_nodes.items():
output_node_bounds = node_bounds[output_node]
output_node_bounds = node_bounds_and_samples[output_node]
assert (output_node_bounds["min"], output_node_bounds["max"]) == expected_output_bounds[i]
op_graph.update_values_with_bounds(node_bounds)
op_graph.update_values_with_bounds_and_samples(node_bounds_and_samples)
for i, output_node in op_graph.output_nodes.items():
assert expected_output_data_type[i] == output_node.outputs[0].dtype

View File

@@ -215,6 +215,14 @@ class TestHelpers:
return graphs_are_isomorphic
@staticmethod
def python_functions_are_equal_or_equivalent(lhs, rhs):
    """Helper function to check if two functions are equal or their code is equivalent.

    This is not perfect, but will be good enough for tests.

    Args:
        lhs: first function to compare.
        rhs: second function to compare.

    Returns:
        The result of the same-named module-level helper, to which the comparison is
        delegated unchanged.
    """
    # Inside a method body this name is looked up in the module scope, not the class
    # scope, so the call below targets the module-level helper and does not recurse.
    # NOTE(review): that helper is expected to be defined or imported at module level;
    # it is not visible in this excerpt.
    return python_functions_are_equal_or_equivalent(lhs, rhs)
@pytest.fixture
def test_helpers():

View File

@@ -330,7 +330,7 @@ def test_fail_compile(function, input_ranges, list_of_arg_names):
}
with pytest.raises(RuntimeError, match=".*isn't supported for MLIR lowering.*"):
compile_numpy_function_into_op_graph(
compile_numpy_function(
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),

View File

@@ -9,7 +9,7 @@ from concrete.numpy.np_dtypes_helpers import (
convert_base_data_type_to_numpy_dtype,
convert_numpy_dtype_to_base_data_type,
get_base_value_for_numpy_or_python_constant_data,
get_type_constructor_for_numpy_or_python_constant_data,
get_constructor_for_numpy_or_python_constant_data,
)
@@ -65,18 +65,31 @@ def test_convert_common_dtype_to_numpy_dtype(common_dtype, expected_numpy_dtype)
(10, int),
(42.0, float),
(numpy.int32(10), numpy.int32),
(numpy.array([[0, 1], [3, 4]], dtype=numpy.uint64), numpy.uint64),
(numpy.array([[0, 1], [3, 4]], dtype=numpy.float64), numpy.float64),
],
)
def test_get_type_constructor_for_numpy_or_python_constant_data(
constant_data, expected_constructor
):
"""Test function for get_type_constructor_for_numpy_or_python_constant_data"""
def test_get_constructor_for_numpy_or_python_constant_data(constant_data, expected_constructor):
"""Test function for get_constructor_for_numpy_or_python_constant_data"""
assert expected_constructor == get_type_constructor_for_numpy_or_python_constant_data(
constant_data
)
assert expected_constructor == get_constructor_for_numpy_or_python_constant_data(constant_data)
def test_get_constructor_for_numpy_arrays(test_helpers):
    """Test function for get_constructor_for_numpy_or_python_constant_data for numpy arrays.

    For each array, the returned constructor is compared to a reference lambda through
    test_helpers.python_functions_are_equal_or_equivalent rather than with a direct ==.
    """

    arrays = [
        numpy.array([[0, 1], [3, 4]], dtype=numpy.uint64),
        numpy.array([[0, 1], [3, 4]], dtype=numpy.float64),
    ]

    # Factory function so that each reference lambda closes over its own array.
    def get_expected_constructor(array: numpy.ndarray):
        return lambda x: numpy.full(array.shape, x, dtype=array.dtype)

    expected_constructors = [get_expected_constructor(array) for array in arrays]

    # NOTE(review): presumably the helper compares the functions' code, so the reference
    # lambda above should stay in sync with the one built in np_dtypes_helpers - confirm.
    for array, expected_constructor in zip(arrays, expected_constructors):
        assert test_helpers.python_functions_are_equal_or_equivalent(
            expected_constructor, get_constructor_for_numpy_or_python_constant_data(array)
        )
def test_get_base_value_for_numpy_or_python_constant_data_with_list():