mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor: update OPGraph to be able to update bounds with foreign types
- get BaseDataType for values being checked in update_values_with_bounds - rename SUPPORTED_TYPES to BASE_DATA_TYPES
This commit is contained in:
@@ -1,16 +1,16 @@
|
||||
"""File to hold helper functions for data types related stuff."""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import cast
|
||||
from typing import Union, cast
|
||||
|
||||
from .base import BaseDataType
|
||||
from .floats import Float
|
||||
from .integers import Integer
|
||||
from .integers import Integer, get_bits_to_represent_value_as_integer
|
||||
from .values import BaseValue, ClearValue, EncryptedValue, ScalarValue
|
||||
|
||||
INTEGER_TYPES = (Integer,)
|
||||
FLOAT_TYPES = (Float,)
|
||||
SUPPORTED_TYPES = INTEGER_TYPES + FLOAT_TYPES
|
||||
BASE_DATA_TYPES = INTEGER_TYPES + FLOAT_TYPES
|
||||
|
||||
|
||||
def value_is_encrypted_integer(value_to_check: BaseValue) -> bool:
|
||||
@@ -93,8 +93,8 @@ def find_type_to_hold_both_lossy(
|
||||
Returns:
|
||||
BaseDataType: The dtype able to hold (potentially lossy) dtype1 and dtype2
|
||||
"""
|
||||
assert isinstance(dtype1, SUPPORTED_TYPES), f"Unsupported dtype1: {type(dtype1)}"
|
||||
assert isinstance(dtype2, SUPPORTED_TYPES), f"Unsupported dtype2: {type(dtype2)}"
|
||||
assert isinstance(dtype1, BASE_DATA_TYPES), f"Unsupported dtype1: {type(dtype1)}"
|
||||
assert isinstance(dtype2, BASE_DATA_TYPES), f"Unsupported dtype2: {type(dtype2)}"
|
||||
|
||||
type_to_return: BaseDataType
|
||||
|
||||
@@ -161,3 +161,27 @@ def mix_scalar_values_determine_holding_dtype(value1: BaseValue, value2: BaseVal
|
||||
mixed_value = ClearValue(holding_type)
|
||||
|
||||
return mixed_value
|
||||
|
||||
|
||||
def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType:
|
||||
"""Helper function to determine the BaseDataType to hold the input constant data.
|
||||
|
||||
Args:
|
||||
constant_data (Union[int, float]): The constant data for which to determine the
|
||||
corresponding BaseDataType.
|
||||
|
||||
Returns:
|
||||
BaseDataType: The corresponding BaseDataType
|
||||
"""
|
||||
constant_data_type: BaseDataType
|
||||
assert isinstance(
|
||||
constant_data, (int, float)
|
||||
), f"Unsupported constant data of type {type(constant_data)}"
|
||||
if isinstance(constant_data, int):
|
||||
is_signed = constant_data < 0
|
||||
constant_data_type = Integer(
|
||||
get_bits_to_represent_value_as_integer(constant_data, is_signed), is_signed
|
||||
)
|
||||
elif isinstance(constant_data, float):
|
||||
constant_data_type = Float(64)
|
||||
return constant_data_type
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""Code to wrap and make manipulating networkx graphs easier."""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Iterable, List, Set, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from .data_types.base import BaseDataType
|
||||
from .data_types.dtypes_helpers import get_base_data_type_for_python_constant_data
|
||||
from .data_types.floats import Float
|
||||
from .data_types.integers import make_integer_to_hold
|
||||
from .data_types.integers import Integer, make_integer_to_hold
|
||||
from .representation import intermediate as ir
|
||||
from .tracing import BaseTracer
|
||||
from .tracing.tracing_helpers import create_graph_from_output_tracers
|
||||
@@ -130,7 +132,13 @@ class OPGraph:
|
||||
|
||||
return node_results
|
||||
|
||||
def update_values_with_bounds(self, node_bounds: dict):
|
||||
def update_values_with_bounds(
|
||||
self,
|
||||
node_bounds: dict,
|
||||
get_base_data_type_for_constant_data: Callable[
|
||||
[Any], BaseDataType
|
||||
] = get_base_data_type_for_python_constant_data,
|
||||
):
|
||||
"""Update values with bounds.
|
||||
|
||||
Update nodes inputs and outputs values with data types able to hold data ranges measured
|
||||
@@ -139,6 +147,10 @@ class OPGraph:
|
||||
Args:
|
||||
node_bounds (dict): Dictionary with nodes as keys, holding dicts with a 'min' and 'max'
|
||||
keys. Those bounds will be taken as the data range to be represented, per node.
|
||||
get_base_data_type_for_constant_data (Callable[ [Type], BaseDataType ], optional): This
|
||||
is a callback function to convert data encountered during value updates to
|
||||
BaseDataType. This allows to manage data coming from foreign frameworks without
|
||||
specialising OPGraph. Defaults to get_base_data_type_for_python_constant_data.
|
||||
"""
|
||||
node: ir.IntermediateNode
|
||||
|
||||
@@ -149,9 +161,12 @@ class OPGraph:
|
||||
current_node_bounds["max"],
|
||||
)
|
||||
|
||||
min_data_type = get_base_data_type_for_constant_data(min_bound)
|
||||
max_data_type = get_base_data_type_for_constant_data(max_bound)
|
||||
|
||||
if not isinstance(node, ir.Input):
|
||||
for output_value in node.outputs:
|
||||
if isinstance(min_bound, int) and isinstance(max_bound, int):
|
||||
if isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer):
|
||||
output_value.data_type = make_integer_to_hold(
|
||||
(min_bound, max_bound), force_signed=False
|
||||
)
|
||||
@@ -159,8 +174,8 @@ class OPGraph:
|
||||
output_value.data_type = Float(64)
|
||||
else:
|
||||
# Currently variable inputs are only allowed to be integers
|
||||
assert isinstance(min_bound, int) and isinstance(max_bound, int), (
|
||||
f"Inputs to a graph should be integers, got bounds that were not float, \n"
|
||||
assert isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer), (
|
||||
f"Inputs to a graph should be integers, got bounds that were float, \n"
|
||||
f"min: {min_bound} ({type(min_bound)}), max: {max_bound} ({type(max_bound)})"
|
||||
)
|
||||
node.inputs[0].data_type = make_integer_to_hold(
|
||||
|
||||
@@ -7,7 +7,7 @@ import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
from ..common.data_types.base import BaseDataType
|
||||
from ..common.data_types.dtypes_helpers import SUPPORTED_TYPES
|
||||
from ..common.data_types.dtypes_helpers import BASE_DATA_TYPES
|
||||
from ..common.data_types.floats import Float
|
||||
from ..common.data_types.integers import Integer
|
||||
|
||||
@@ -62,7 +62,7 @@ def convert_common_dtype_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.dty
|
||||
numpy.dtype: The resulting numpy.dtype
|
||||
"""
|
||||
assert isinstance(
|
||||
common_dtype, SUPPORTED_TYPES
|
||||
common_dtype, BASE_DATA_TYPES
|
||||
), f"Unsupported common_dtype: {type(common_dtype)}"
|
||||
type_to_return: numpy.dtype
|
||||
|
||||
|
||||
Reference in New Issue
Block a user