mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor(tracing): preparatory work for tensor table generation
- removed underlying_type_constructor from BaseDataType, as it was scalar-specific, and moved it to values (as underlying_constructor)
- updated inputset_eval to keep a sample of intermediate node values
- this makes it possible to get the proper value constructor to use in UnivariateFunction get_table and to have tensors as inputs
This commit is contained in:
@@ -135,7 +135,7 @@ def eval_op_graph_bounds_on_inputset(
|
||||
Returns:
|
||||
Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: number of inputs in the inputset and
|
||||
a dict containing the bounds for each node from op_graph, stored with the node
|
||||
as key and a dict with keys "min" and "max" as value.
|
||||
as key and a dict with keys "min", "max" and "sample" as value.
|
||||
"""
|
||||
|
||||
def check_inputset_input_len_is_valid(data_to_check):
|
||||
@@ -178,8 +178,12 @@ def eval_op_graph_bounds_on_inputset(
|
||||
|
||||
# We evaluate the min and max func to be able to resolve the tensors min and max rather than
|
||||
# having the tensor itself as the stored min and max values.
|
||||
node_bounds = {
|
||||
node: {"min": min_func(value, value), "max": max_func(value, value)}
|
||||
node_bounds_and_samples = {
|
||||
node: {
|
||||
"min": min_func(value, value),
|
||||
"max": max_func(value, value),
|
||||
"sample": value,
|
||||
}
|
||||
for node, value in first_output.items()
|
||||
}
|
||||
|
||||
@@ -200,7 +204,11 @@ def eval_op_graph_bounds_on_inputset(
|
||||
|
||||
current_output = op_graph.evaluate(current_input_data)
|
||||
for node, value in current_output.items():
|
||||
node_bounds[node]["min"] = min_func(node_bounds[node]["min"], value)
|
||||
node_bounds[node]["max"] = max_func(node_bounds[node]["max"], value)
|
||||
node_bounds_and_samples[node]["min"] = min_func(
|
||||
node_bounds_and_samples[node]["min"], value
|
||||
)
|
||||
node_bounds_and_samples[node]["max"] = max_func(
|
||||
node_bounds_and_samples[node]["max"], value
|
||||
)
|
||||
|
||||
return inputset_size, node_bounds
|
||||
return inputset_size, node_bounds_and_samples
|
||||
|
||||
@@ -1,19 +1,11 @@
|
||||
"""File holding code to represent data types in a program."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Type
|
||||
|
||||
|
||||
class BaseDataType(ABC):
|
||||
"""Base class to represent a data type."""
|
||||
|
||||
# Constructor for the data type represented (for example numpy.int32 for an int32 numpy array)
|
||||
underlying_type_constructor: Optional[Type]
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.underlying_type_constructor = None
|
||||
|
||||
@abstractmethod
|
||||
def __eq__(self, o: object) -> bool:
|
||||
"""No default implementation."""
|
||||
|
||||
@@ -312,7 +312,7 @@ def get_base_value_for_python_constant_data(
|
||||
return partial(TensorValue, dtype=constant_data_type, shape=())
|
||||
|
||||
|
||||
def get_type_constructor_for_python_constant_data(constant_data: Union[int, float]):
|
||||
def get_constructor_for_python_constant_data(constant_data: Union[int, float]):
|
||||
"""Get the constructor for the passed python constant data.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
"""Code to wrap and make manipulating networkx graphs easier."""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from .data_types.base import BaseDataType
|
||||
from .data_types.dtypes_helpers import (
|
||||
get_base_data_type_for_python_constant_data,
|
||||
get_type_constructor_for_python_constant_data,
|
||||
get_constructor_for_python_constant_data,
|
||||
)
|
||||
from .data_types.floats import Float
|
||||
from .data_types.integers import Integer, make_integer_to_hold
|
||||
@@ -171,15 +171,15 @@ class OPGraph:
|
||||
|
||||
return node_results
|
||||
|
||||
def update_values_with_bounds(
|
||||
def update_values_with_bounds_and_samples(
|
||||
self,
|
||||
node_bounds: dict,
|
||||
node_bounds_and_samples: dict,
|
||||
get_base_data_type_for_constant_data: Callable[
|
||||
[Any], BaseDataType
|
||||
] = get_base_data_type_for_python_constant_data,
|
||||
get_type_constructor_for_constant_data: Callable[
|
||||
..., Type
|
||||
] = get_type_constructor_for_python_constant_data,
|
||||
get_constructor_for_constant_data: Callable[
|
||||
..., Callable
|
||||
] = get_constructor_for_python_constant_data,
|
||||
):
|
||||
"""Update values with bounds.
|
||||
|
||||
@@ -187,40 +187,44 @@ class OPGraph:
|
||||
and passed in nodes_bounds
|
||||
|
||||
Args:
|
||||
node_bounds (dict): Dictionary with nodes as keys, holding dicts with a 'min' and 'max'
|
||||
keys. Those bounds will be taken as the data range to be represented, per node.
|
||||
node_bounds_and_samples (dict): Dictionary with nodes as keys, holding dicts with a
|
||||
'min', 'max' and 'sample' keys. Those bounds will be taken as the data range to be
|
||||
represented, per node. The sample allows to determine the data constructors to
|
||||
prepare the UnivariateFunction nodes for table generation.
|
||||
get_base_data_type_for_constant_data (Callable[ [Any], BaseDataType ], optional): This
|
||||
is a callback function to convert data encountered during value updates to
|
||||
BaseDataType. This allows to manage data coming from foreign frameworks without
|
||||
specialising OPGraph. Defaults to get_base_data_type_for_python_constant_data.
|
||||
get_type_constructor_for_constant_data (Callable[ ..., Type ], optional): This is a
|
||||
get_constructor_for_constant_data (Callable[ ..., Callable ], optional): This is a
|
||||
callback function to determine the type constructor of the data encountered while
|
||||
updating the graph bounds. Defaults to get_type_constructor_python_constant_data.
|
||||
updating the graph bounds. Defaults to get_constructor_for_python_constant_data.
|
||||
"""
|
||||
node: IntermediateNode
|
||||
|
||||
for node in self.graph.nodes():
|
||||
current_node_bounds = node_bounds[node]
|
||||
min_bound, max_bound = (
|
||||
current_node_bounds["min"],
|
||||
current_node_bounds["max"],
|
||||
current_node_bounds_and_samples = node_bounds_and_samples[node]
|
||||
min_bound, max_bound, sample = (
|
||||
current_node_bounds_and_samples["min"],
|
||||
current_node_bounds_and_samples["max"],
|
||||
current_node_bounds_and_samples["sample"],
|
||||
)
|
||||
|
||||
min_data_type = get_base_data_type_for_constant_data(min_bound)
|
||||
max_data_type = get_base_data_type_for_constant_data(max_bound)
|
||||
|
||||
min_data_type_constructor = get_type_constructor_for_constant_data(min_bound)
|
||||
max_data_type_constructor = get_type_constructor_for_constant_data(max_bound)
|
||||
# This is a sanity check
|
||||
min_value_constructor = get_constructor_for_constant_data(min_bound)
|
||||
max_value_constructor = get_constructor_for_constant_data(max_bound)
|
||||
|
||||
assert_true(
|
||||
max_data_type_constructor == min_data_type_constructor,
|
||||
max_value_constructor == min_value_constructor,
|
||||
(
|
||||
f"Got two different type constructors for min and max bound: "
|
||||
f"{min_data_type_constructor}, {max_data_type_constructor}"
|
||||
f"{min_value_constructor}, {max_value_constructor}"
|
||||
),
|
||||
)
|
||||
|
||||
data_type_constructor = max_data_type_constructor
|
||||
value_constructor = get_constructor_for_constant_data(sample)
|
||||
|
||||
if not isinstance(node, Input):
|
||||
for output_value in node.outputs:
|
||||
@@ -238,7 +242,7 @@ class OPGraph:
|
||||
),
|
||||
)
|
||||
output_value.dtype = Float(64)
|
||||
output_value.dtype.underlying_type_constructor = data_type_constructor
|
||||
output_value.underlying_constructor = value_constructor
|
||||
else:
|
||||
# Currently variable inputs are only allowed to be integers
|
||||
assert_true(
|
||||
@@ -252,7 +256,7 @@ class OPGraph:
|
||||
node.inputs[0].dtype = make_integer_to_hold(
|
||||
(min_bound, max_bound), force_signed=False
|
||||
)
|
||||
node.inputs[0].dtype.underlying_type_constructor = data_type_constructor
|
||||
node.inputs[0].underlying_constructor = value_constructor
|
||||
|
||||
node.outputs[0] = deepcopy(node.inputs[0])
|
||||
|
||||
|
||||
@@ -250,26 +250,25 @@ class UnivariateFunction(IntermediateNode):
|
||||
Returns:
|
||||
List[Any]: The table.
|
||||
"""
|
||||
input_dtype = self.inputs[0].dtype
|
||||
# Check the input is an unsigned integer to be able to build a table
|
||||
assert isinstance(
|
||||
self.inputs[0].dtype, Integer
|
||||
input_dtype, Integer
|
||||
), "get_table only works for an unsigned Integer input"
|
||||
assert not self.inputs[
|
||||
0
|
||||
].dtype.is_signed, "get_table only works for an unsigned Integer input"
|
||||
assert not input_dtype.is_signed, "get_table only works for an unsigned Integer input"
|
||||
|
||||
type_constructor = self.inputs[0].dtype.underlying_type_constructor
|
||||
if type_constructor is None:
|
||||
input_value_constructor = self.inputs[0].underlying_constructor
|
||||
if input_value_constructor is None:
|
||||
logger.info(
|
||||
f"{self.__class__.__name__} input data type constructor was None, defaulting to int"
|
||||
)
|
||||
type_constructor = int
|
||||
input_value_constructor = int
|
||||
|
||||
min_input_range = self.inputs[0].dtype.min_value()
|
||||
max_input_range = self.inputs[0].dtype.max_value() + 1
|
||||
min_input_range = input_dtype.min_value()
|
||||
max_input_range = input_dtype.max_value() + 1
|
||||
|
||||
table = [
|
||||
self.evaluate({0: type_constructor(input_value)})
|
||||
self.evaluate({0: input_value_constructor(input_value)})
|
||||
for input_value in range(min_input_range, max_input_range)
|
||||
]
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Optional
|
||||
|
||||
from ..data_types.base import BaseDataType
|
||||
|
||||
@@ -11,10 +12,12 @@ class BaseValue(ABC):
|
||||
|
||||
dtype: BaseDataType
|
||||
_is_encrypted: bool
|
||||
underlying_constructor: Optional[Callable]
|
||||
|
||||
def __init__(self, dtype: BaseDataType, is_encrypted: bool) -> None:
|
||||
self.dtype = deepcopy(dtype)
|
||||
self._is_encrypted = is_encrypted
|
||||
self.underlying_constructor = None
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return str(self)
|
||||
|
||||
@@ -26,7 +26,7 @@ from ..numpy.tracing import trace_numpy_function
|
||||
from .np_dtypes_helpers import (
|
||||
get_base_data_type_for_numpy_or_python_constant_data,
|
||||
get_base_value_for_numpy_or_python_constant_data,
|
||||
get_type_constructor_for_numpy_or_python_constant_data,
|
||||
get_constructor_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
)
|
||||
|
||||
# Find bounds with the inputset
|
||||
inputset_size, node_bounds = eval_op_graph_bounds_on_inputset(
|
||||
inputset_size, node_bounds_and_samples = eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
compilation_configuration=compilation_configuration,
|
||||
@@ -149,13 +149,13 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
sys.stderr.write(f"Warning: {message}")
|
||||
|
||||
# Add the bounds as an artifact
|
||||
compilation_artifacts.add_final_operation_graph_bounds(node_bounds)
|
||||
compilation_artifacts.add_final_operation_graph_bounds(node_bounds_and_samples)
|
||||
|
||||
# Update the graph accordingly: after that, we have the compilable graph
|
||||
op_graph.update_values_with_bounds(
|
||||
node_bounds,
|
||||
op_graph.update_values_with_bounds_and_samples(
|
||||
node_bounds_and_samples,
|
||||
get_base_data_type_for_numpy_or_python_constant_data,
|
||||
get_type_constructor_for_numpy_or_python_constant_data,
|
||||
get_constructor_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
# Add the initial graph as an artifact
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Type, Union
|
||||
from typing import Any, Callable, Dict, List, Union
|
||||
|
||||
import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
@@ -13,7 +13,7 @@ from ..common.data_types.dtypes_helpers import (
|
||||
find_type_to_hold_both_lossy,
|
||||
get_base_data_type_for_python_constant_data,
|
||||
get_base_value_for_python_constant_data,
|
||||
get_type_constructor_for_python_constant_data,
|
||||
get_constructor_for_python_constant_data,
|
||||
)
|
||||
from ..common.data_types.floats import Float
|
||||
from ..common.data_types.integers import Integer
|
||||
@@ -224,8 +224,8 @@ def get_numpy_function_output_dtype(
|
||||
return [output.dtype for output in outputs]
|
||||
|
||||
|
||||
def get_type_constructor_for_numpy_or_python_constant_data(constant_data: Any):
|
||||
"""Get the constructor for the numpy scalar underlying dtype or python dtype.
|
||||
def get_constructor_for_numpy_or_python_constant_data(constant_data: Any):
|
||||
"""Get the constructor for the numpy constant data or python dtype.
|
||||
|
||||
Args:
|
||||
constant_data (Any): The data for which we want to determine the type constructor.
|
||||
@@ -236,11 +236,8 @@ def get_type_constructor_for_numpy_or_python_constant_data(constant_data: Any):
|
||||
f"Unsupported constant data of type {type(constant_data)}",
|
||||
)
|
||||
|
||||
scalar_constructor: Type
|
||||
|
||||
if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)):
|
||||
scalar_constructor = constant_data.dtype.type
|
||||
else:
|
||||
scalar_constructor = get_type_constructor_for_python_constant_data(constant_data)
|
||||
|
||||
return scalar_constructor
|
||||
if isinstance(constant_data, numpy.ndarray):
|
||||
return lambda x: numpy.full(constant_data.shape, x, dtype=constant_data.dtype)
|
||||
return constant_data.dtype.type
|
||||
return get_constructor_for_python_constant_data(constant_data)
|
||||
|
||||
@@ -283,17 +283,17 @@ def test_eval_op_graph_bounds_on_inputset_multiple_output(
|
||||
for y_gen in range_y:
|
||||
yield (x_gen, y_gen)
|
||||
|
||||
_, node_bounds = eval_op_graph_bounds_on_inputset(
|
||||
_, node_bounds_and_samples = eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
data_gen(*tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
CompilationConfiguration(),
|
||||
)
|
||||
|
||||
for i, output_node in op_graph.output_nodes.items():
|
||||
output_node_bounds = node_bounds[output_node]
|
||||
output_node_bounds = node_bounds_and_samples[output_node]
|
||||
assert (output_node_bounds["min"], output_node_bounds["max"]) == expected_output_bounds[i]
|
||||
|
||||
op_graph.update_values_with_bounds(node_bounds)
|
||||
op_graph.update_values_with_bounds_and_samples(node_bounds_and_samples)
|
||||
|
||||
for i, output_node in op_graph.output_nodes.items():
|
||||
assert expected_output_data_type[i] == output_node.outputs[0].dtype
|
||||
|
||||
@@ -215,6 +215,14 @@ class TestHelpers:
|
||||
|
||||
return graphs_are_isomorphic
|
||||
|
||||
@staticmethod
|
||||
def python_functions_are_equal_or_equivalent(lhs, rhs):
|
||||
"""Helper function to check if two functions are equal or their code are equivalent.
|
||||
|
||||
This is not perfect, but will be good enough for tests.
|
||||
"""
|
||||
return python_functions_are_equal_or_equivalent(lhs, rhs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_helpers():
|
||||
|
||||
@@ -330,7 +330,7 @@ def test_fail_compile(function, input_ranges, list_of_arg_names):
|
||||
}
|
||||
|
||||
with pytest.raises(RuntimeError, match=".*isn't supported for MLIR lowering.*"):
|
||||
compile_numpy_function_into_op_graph(
|
||||
compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
|
||||
@@ -9,7 +9,7 @@ from concrete.numpy.np_dtypes_helpers import (
|
||||
convert_base_data_type_to_numpy_dtype,
|
||||
convert_numpy_dtype_to_base_data_type,
|
||||
get_base_value_for_numpy_or_python_constant_data,
|
||||
get_type_constructor_for_numpy_or_python_constant_data,
|
||||
get_constructor_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
|
||||
@@ -65,18 +65,31 @@ def test_convert_common_dtype_to_numpy_dtype(common_dtype, expected_numpy_dtype)
|
||||
(10, int),
|
||||
(42.0, float),
|
||||
(numpy.int32(10), numpy.int32),
|
||||
(numpy.array([[0, 1], [3, 4]], dtype=numpy.uint64), numpy.uint64),
|
||||
(numpy.array([[0, 1], [3, 4]], dtype=numpy.float64), numpy.float64),
|
||||
],
|
||||
)
|
||||
def test_get_type_constructor_for_numpy_or_python_constant_data(
|
||||
constant_data, expected_constructor
|
||||
):
|
||||
"""Test function for get_type_constructor_for_numpy_or_python_constant_data"""
|
||||
def test_get_constructor_for_numpy_or_python_constant_data(constant_data, expected_constructor):
|
||||
"""Test function for get_constructor_for_numpy_or_python_constant_data"""
|
||||
|
||||
assert expected_constructor == get_type_constructor_for_numpy_or_python_constant_data(
|
||||
constant_data
|
||||
)
|
||||
assert expected_constructor == get_constructor_for_numpy_or_python_constant_data(constant_data)
|
||||
|
||||
|
||||
def test_get_constructor_for_numpy_arrays(test_helpers):
|
||||
"""Test function for get_constructor_for_numpy_or_python_constant_data for numpy arrays."""
|
||||
|
||||
arrays = [
|
||||
numpy.array([[0, 1], [3, 4]], dtype=numpy.uint64),
|
||||
numpy.array([[0, 1], [3, 4]], dtype=numpy.float64),
|
||||
]
|
||||
|
||||
def get_expected_constructor(array: numpy.ndarray):
|
||||
return lambda x: numpy.full(array.shape, x, dtype=array.dtype)
|
||||
|
||||
expected_constructors = [get_expected_constructor(array) for array in arrays]
|
||||
|
||||
for array, expected_constructor in zip(arrays, expected_constructors):
|
||||
assert test_helpers.python_functions_are_equal_or_equivalent(
|
||||
expected_constructor, get_constructor_for_numpy_or_python_constant_data(array)
|
||||
)
|
||||
|
||||
|
||||
def test_get_base_value_for_numpy_or_python_constant_data_with_list():
|
||||
|
||||
Reference in New Issue
Block a user