mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(bounds): add function to update OPGraph IR nodes in and output values
- this allows to have tighter data types by sticking to the smallest types able to represent the ranges passed as argument - update test_dataset_eval to check the output Value's data_type is updated
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
"""Code to wrap and make manipulating networkx graphs easier"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Iterable, Mapping
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from .data_types.integers import make_integer_to_hold_ints
|
||||
from .representation import intermediate as ir
|
||||
from .tracing import BaseTracer
|
||||
from .tracing.tracing_helpers import create_graph_from_output_tracers
|
||||
@@ -57,3 +59,40 @@ class OPGraph:
|
||||
node_results[node] = node.evaluate({0: inputs[node.program_input_idx]})
|
||||
|
||||
return node_results
|
||||
|
||||
def update_values_with_bounds(self, node_bounds: dict):
|
||||
"""Update nodes inputs and outputs values with data types able to hold data ranges measured
|
||||
and passed in nodes_bounds
|
||||
|
||||
Args:
|
||||
node_bounds (dict): Dictionary with nodes as keys, holding dicts with a 'min' and 'max'
|
||||
keys. Those bounds will be taken as the data range to be represented, per node.
|
||||
"""
|
||||
|
||||
node: ir.IntermediateNode
|
||||
|
||||
for node in self.graph.nodes():
|
||||
current_node_bounds = node_bounds[node]
|
||||
min_bound, max_bound = current_node_bounds["min"], current_node_bounds["max"]
|
||||
|
||||
if not isinstance(node, ir.Input):
|
||||
for output_value in node.outputs:
|
||||
output_value.data_type = make_integer_to_hold_ints(
|
||||
(min_bound, max_bound), force_signed=False
|
||||
)
|
||||
else:
|
||||
node.inputs[0].data_type = make_integer_to_hold_ints(
|
||||
(min_bound, max_bound), force_signed=False
|
||||
)
|
||||
node.outputs[0] = deepcopy(node.inputs[0])
|
||||
|
||||
# TODO: #57 manage multiple outputs from a node, probably requires an output_idx when
|
||||
# adding an edge
|
||||
assert len(node.outputs) == 1
|
||||
|
||||
successors = self.graph.succ[node]
|
||||
for succ in successors:
|
||||
edge_data = self.graph.get_edge_data(node, succ)
|
||||
for edge in edge_data.values():
|
||||
input_idx = edge["input_idx"]
|
||||
succ.inputs[input_idx] = deepcopy(node.outputs[0])
|
||||
|
||||
@@ -9,77 +9,93 @@ from hdk.hnumpy.tracing import trace_numpy_function
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,expected_output_bounds",
|
||||
"function,input_ranges,expected_output_bounds,expected_output_data_type",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
((-10, 10), (-10, 10)),
|
||||
(-20, 20),
|
||||
Integer(6, is_signed=True),
|
||||
id="x + y, (-10, 10), (-10, 10), (-20, 20)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-14, 7),
|
||||
Integer(5, is_signed=True),
|
||||
id="x + y, (-10, 2), (-4, 5), (-14, 9)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x - y,
|
||||
((-10, 10), (-10, 10)),
|
||||
(-20, 20),
|
||||
Integer(6, is_signed=True),
|
||||
id="x - y, (-10, 10), (-10, 10), (-20, 20)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x - y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-15, 6),
|
||||
Integer(5, is_signed=True),
|
||||
id="x - y, (-10, 2), (-4, 5), (-15, 6)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
((-10, 10), (-10, 10)),
|
||||
(-100, 100),
|
||||
Integer(8, is_signed=True),
|
||||
id="x * y, (-10, 10), (-10, 10), (-100, 100)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-50, 40),
|
||||
Integer(7, is_signed=True),
|
||||
id="x * y, (-10, 2), (-4, 5), (-50, 40)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + x + y,
|
||||
((-10, 10), (-10, 10)),
|
||||
(-30, 30),
|
||||
Integer(6, is_signed=True),
|
||||
id="x + x + y, (-10, 10), (-10, 10), (-30, 30)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x - x + y,
|
||||
((-10, 10), (-10, 10)),
|
||||
(-10, 10),
|
||||
Integer(5, is_signed=True),
|
||||
id="x - x + y, (-10, 10), (-10, 10), (-10, 10)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x - x + y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-4, 5),
|
||||
Integer(4, is_signed=True),
|
||||
id="x - x + y, (-10, 2), (-4, 5), (-4, 5)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y - x,
|
||||
((-10, 10), (-10, 10)),
|
||||
(-110, 110),
|
||||
Integer(8, is_signed=True),
|
||||
id="x * y - x, (-10, 10), (-10, 10), (-110, 110)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y - x,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-40, 50),
|
||||
Integer(7, is_signed=True),
|
||||
id="x * y - x, (-10, 2), (-4, 5), (-40, 50),",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_eval_op_graph_bounds_on_dataset(function, input_ranges, expected_output_bounds):
|
||||
def test_eval_op_graph_bounds_on_dataset(
|
||||
function,
|
||||
input_ranges,
|
||||
expected_output_bounds,
|
||||
expected_output_data_type: Integer,
|
||||
):
|
||||
"""Test function for eval_op_graph_bounds_on_dataset"""
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
@@ -99,3 +115,8 @@ def test_eval_op_graph_bounds_on_dataset(function, input_ranges, expected_output
|
||||
output_node_bounds = node_bounds[output_node]
|
||||
|
||||
assert (output_node_bounds["min"], output_node_bounds["max"]) == expected_output_bounds
|
||||
|
||||
assert EncryptedValue(Integer(64, True)) == output_node.outputs[0]
|
||||
op_graph.update_values_with_bounds(node_bounds)
|
||||
|
||||
assert expected_output_data_type == output_node.outputs[0].data_type
|
||||
|
||||
Reference in New Issue
Block a user