mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
dev(floats): add the possibility to have constant floats in a program
- update ConstantInput to manage floats - update OPGraph update_values_with_bounds to manage floats - update test code to manage cases where output could be a float - add test cases with float inputs
This commit is contained in:
@@ -19,15 +19,21 @@ def eval_op_graph_bounds_on_dataset(op_graph: OPGraph, dataset: Iterator[Tuple[A
|
||||
Dict: dict containing the bounds for each node from op_graph, stored with the node as key
|
||||
and a dict with keys "min" and "max" as value
|
||||
"""
|
||||
|
||||
def check_dataset_input_is_valid(data_to_check):
|
||||
assert len(data_to_check) == len(op_graph.input_nodes), (
|
||||
f"Got input data from dataset of len: {len(data_to_check)}, "
|
||||
f"function being evaluated has {len(op_graph.input_nodes)} inputs, please make "
|
||||
f"sure your data generator returns valid tuples of input values"
|
||||
)
|
||||
# TODO: change this to be more generic and check coherence between the input data type and
|
||||
# the corresponding Input ir node expected data type
|
||||
assert all(
|
||||
isinstance(val, int) for val in data_to_check
|
||||
), "For now dataset evaluation only support int as inputs, please check your dataset"
|
||||
|
||||
first_input_data = dict(enumerate(next(dataset)))
|
||||
|
||||
# Check the dataset is well-formed
|
||||
assert len(first_input_data) == len(op_graph.input_nodes), (
|
||||
f"Got input data from dataset of len: {len(first_input_data)}, function being evaluated has"
|
||||
f" only {len(op_graph.input_nodes)} inputs, please make sure your data generator returns"
|
||||
f" valid tuples of input values"
|
||||
)
|
||||
|
||||
check_dataset_input_is_valid(first_input_data.values())
|
||||
first_output = op_graph.evaluate(first_input_data)
|
||||
|
||||
node_bounds = {
|
||||
@@ -36,19 +42,9 @@ def eval_op_graph_bounds_on_dataset(op_graph: OPGraph, dataset: Iterator[Tuple[A
|
||||
}
|
||||
|
||||
for input_data in dataset:
|
||||
|
||||
next_input_data = dict(enumerate(input_data))
|
||||
|
||||
# Check the dataset is well-formed
|
||||
assert len(next_input_data) == len(op_graph.input_nodes), (
|
||||
f"Got input data from dataset of len: {len(next_input_data)},"
|
||||
f" function being evaluated has"
|
||||
f" only {len(op_graph.input_nodes)} inputs, please make sure"
|
||||
f" your data generator returns"
|
||||
f" valid tuples of input values"
|
||||
)
|
||||
|
||||
current_output = op_graph.evaluate(next_input_data)
|
||||
current_input_data = dict(enumerate(input_data))
|
||||
check_dataset_input_is_valid(current_input_data.values())
|
||||
current_output = op_graph.evaluate(current_input_data)
|
||||
for node, value in current_output.items():
|
||||
node_bounds[node]["min"] = min(node_bounds[node]["min"], value)
|
||||
node_bounds[node]["max"] = max(node_bounds[node]["max"], value)
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, Mapping
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from .data_types.floats import Float
|
||||
from .data_types.integers import make_integer_to_hold_ints
|
||||
from .representation import intermediate as ir
|
||||
from .tracing import BaseTracer
|
||||
@@ -71,10 +72,18 @@ class OPGraph:
|
||||
|
||||
if not isinstance(node, ir.Input):
|
||||
for output_value in node.outputs:
|
||||
output_value.data_type = make_integer_to_hold_ints(
|
||||
(min_bound, max_bound), force_signed=False
|
||||
)
|
||||
if isinstance(min_bound, int) and isinstance(max_bound, int):
|
||||
output_value.data_type = make_integer_to_hold_ints(
|
||||
(min_bound, max_bound), force_signed=False
|
||||
)
|
||||
else:
|
||||
output_value.data_type = Float(64)
|
||||
else:
|
||||
# Currently variable inputs are only allowed to be integers
|
||||
assert isinstance(min_bound, int) and isinstance(max_bound, int), (
|
||||
f"Inputs to a graph should be integers, got bounds that were not float, \n"
|
||||
f"min: {min_bound} ({type(min_bound)}), max: {max_bound} ({type(max_bound)})"
|
||||
)
|
||||
node.inputs[0].data_type = make_integer_to_hold_ints(
|
||||
(min_bound, max_bound), force_signed=False
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
|
||||
from ..data_types import BaseValue
|
||||
from ..data_types.dtypes_helpers import mix_values_determine_holding_dtype
|
||||
from ..data_types.floats import Float
|
||||
from ..data_types.integers import Integer, get_bits_to_represent_int
|
||||
from ..data_types.scalars import Scalars
|
||||
from ..data_types.values import ClearValue
|
||||
@@ -153,12 +154,16 @@ class ConstantInput(IntermediateNode):
|
||||
super().__init__([])
|
||||
self.constant_data = constant_data
|
||||
|
||||
# TODO: manage other cases, we can't call get_bits_to_represent_int
|
||||
assert isinstance(constant_data, int)
|
||||
is_signed = constant_data < 0
|
||||
self.outputs = [
|
||||
ClearValue(Integer(get_bits_to_represent_int(constant_data, is_signed), is_signed))
|
||||
]
|
||||
assert isinstance(
|
||||
constant_data, (int, float)
|
||||
), "Only int and float are support for constant input"
|
||||
if isinstance(constant_data, int):
|
||||
is_signed = constant_data < 0
|
||||
self.outputs = [
|
||||
ClearValue(Integer(get_bits_to_represent_int(constant_data, is_signed), is_signed))
|
||||
]
|
||||
elif isinstance(constant_data, float):
|
||||
self.outputs = [ClearValue(Float(64))]
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
return self.constant_data
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Tuple
|
||||
import pytest
|
||||
|
||||
from hdk.common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset
|
||||
from hdk.common.data_types.floats import Float
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import EncryptedValue
|
||||
from hdk.hnumpy.tracing import trace_numpy_function
|
||||
@@ -27,6 +28,13 @@ from hdk.hnumpy.tracing import trace_numpy_function
|
||||
Integer(5, is_signed=True),
|
||||
id="x + y, (-10, 2), (-4, 5), (-14, 7)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y + 1.7,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-12.3, 8.7),
|
||||
Float(64),
|
||||
id="x + y + 1.7, (-10, 2), (-4, 5), (-12.3, 8.7)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y + 1,
|
||||
((-10, 2), (-4, 5)),
|
||||
@@ -67,21 +75,42 @@ from hdk.hnumpy.tracing import trace_numpy_function
|
||||
((-10, 2), (-4, 5)),
|
||||
(-57, -36),
|
||||
Integer(7, is_signed=True),
|
||||
id="x - y, (-10, 2), (-4, 5), (-57, -36)",
|
||||
id="x - y - 42, (-10, 2), (-4, 5), (-57, -36)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x - y - 41.5,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-56.5, -35.5),
|
||||
Float(64),
|
||||
id="x - y - 41.5, (-10, 2), (-4, 5), (-56.5, -35.5)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: 3 - x + y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-3, 18),
|
||||
Integer(6, is_signed=True),
|
||||
id="x - y, (-10, 2), (-4, 5), (-3, 18)",
|
||||
id="3 - x + y, (-10, 2), (-4, 5), (-3, 18)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: 2.8 - x + y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-3.2, 17.8),
|
||||
Float(64),
|
||||
id="2.8 - x + y, (-10, 2), (-4, 5), (-3.2, 17.8)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (-13) - x + y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-19, 2),
|
||||
Integer(6, is_signed=True),
|
||||
id="x - y, (-10, 2), (-4, 5), (-16, 2)",
|
||||
id="(-13) - x + y, (-10, 2), (-4, 5), (-19, 2)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (-13.5) - x + y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-19.5, 1.5),
|
||||
Float(64),
|
||||
id="(-13.5) - x + y, (-10, 2), (-4, 5), (-19.5, 1.5)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
@@ -102,7 +131,14 @@ from hdk.hnumpy.tracing import trace_numpy_function
|
||||
((-10, 2), (-4, 5)),
|
||||
(-150, 120),
|
||||
Integer(9, is_signed=True),
|
||||
id="x * y, (-10, 2), (-4, 5), (-150, 120)",
|
||||
id="(3 * x) * y, (-10, 2), (-4, 5), (-150, 120)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (3.0 * x) * y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-150.0, 120.0),
|
||||
Float(64),
|
||||
id="(3.0 * x) * y, (-10, 2), (-4, 5), (-150.0, 120.0)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x * 11) * y,
|
||||
@@ -116,7 +152,14 @@ from hdk.hnumpy.tracing import trace_numpy_function
|
||||
((-10, 2), (-4, 5)),
|
||||
(-440, 550),
|
||||
Integer(11, is_signed=True),
|
||||
id="x * y, (-10, 2), (-4, 5), (-440, 550)",
|
||||
id="(x * (-11)) * y, (-10, 2), (-4, 5), (-440, 550)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x * (-11.0)) * y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-440.0, 550.0),
|
||||
Float(64),
|
||||
id="(x * (-11.0)) * y, (-10, 2), (-4, 5), (-440.0, 550.0)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + x + y,
|
||||
@@ -187,12 +230,36 @@ def test_eval_op_graph_bounds_on_dataset(
|
||||
((0, 2), (13, 14)),
|
||||
(Integer(2, is_signed=False), Integer(4, is_signed=False)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x + 1.5, y + 9.6),
|
||||
((-1, 1), (3, 4)),
|
||||
((0.5, 2.5), (12.6, 13.6)),
|
||||
(Float(64), Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x + y + 1, x * y + 42),
|
||||
((-1, 1), (3, 4)),
|
||||
((3, 6), (38, 46)),
|
||||
(Integer(3, is_signed=False), Integer(6, is_signed=False)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x + y + 0.4, x * y + 41.7),
|
||||
((-1, 1), (3, 4)),
|
||||
((2.4, 5.4), (37.7, 45.7)),
|
||||
(Float(64), Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x + y + 1, x * y + 41.7),
|
||||
((-1, 1), (3, 4)),
|
||||
((3, 6), (37.7, 45.7)),
|
||||
(Integer(3, is_signed=False), Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x + y + 0.4, x * y + 42),
|
||||
((-1, 1), (3, 4)),
|
||||
((2.4, 5.4), (38, 46)),
|
||||
(Float(64), Integer(6, is_signed=False)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_eval_op_graph_bounds_on_dataset_multiple_output(
|
||||
@@ -218,11 +285,8 @@ def test_eval_op_graph_bounds_on_dataset_multiple_output(
|
||||
|
||||
for i, output_node in op_graph.output_nodes.items():
|
||||
output_node_bounds = node_bounds[output_node]
|
||||
|
||||
assert (output_node_bounds["min"], output_node_bounds["max"]) == expected_output_bounds[i]
|
||||
|
||||
assert EncryptedValue(Integer(64, True)) == output_node.outputs[0]
|
||||
|
||||
op_graph.update_values_with_bounds(node_bounds)
|
||||
|
||||
for i, output_node in op_graph.output_nodes.items():
|
||||
|
||||
Reference in New Issue
Block a user