mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(tracing): add output_idx information in edges
- renamed output_index to output_idx in BaseTracer - update tracing and fusing code to manage output_idx correctly - update OPGraph evaluate and update_values_with_bounds to manage output_idx - update tests checking graph validity to have output_idx set properly - the support of actual multi-output nodes is in #81
This commit is contained in:
@@ -89,6 +89,7 @@ def draw_graph(
|
||||
}
|
||||
nx.set_node_attributes(graph, attributes)
|
||||
|
||||
# TODO: #639 adapt drawing routine to manage output_idx
|
||||
for edge in graph.edges(keys=True):
|
||||
idx = graph.edges[edge]["input_idx"]
|
||||
graph.edges[edge]["label"] = f" {idx} " # spaces are there intentionally for a better look
|
||||
|
||||
@@ -61,6 +61,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
|
||||
for node in nx.topological_sort(graph):
|
||||
|
||||
# TODO: #640
|
||||
# This code doesn't work with more than a single output. For more outputs,
|
||||
# we would need to change the way the destination are created: currently,
|
||||
# they only are done by incrementing i
|
||||
|
||||
@@ -45,7 +45,7 @@ class LookupTable:
|
||||
return key.__class__(
|
||||
inputs=[key],
|
||||
traced_computation=traced_computation,
|
||||
output_index=0,
|
||||
output_idx=0,
|
||||
)
|
||||
|
||||
# if not, it means table is indexed with a constant
|
||||
|
||||
@@ -126,13 +126,44 @@ class OPGraph:
|
||||
"""
|
||||
node_results: Dict[IntermediateNode, Any] = {}
|
||||
|
||||
def get_result_of_node_at_index(node: IntermediateNode, output_idx: int) -> Any:
|
||||
"""Get the output result at index output_idx for a node.
|
||||
|
||||
Args:
|
||||
node (IntermediateNode): the node from which we want the output.
|
||||
output_idx (int): which output we want.
|
||||
|
||||
Returns:
|
||||
Any: the output value of the evaluation of node.
|
||||
"""
|
||||
result = node_results[node]
|
||||
# TODO: #81 remove no cover once we have nodes with multiple outputs
|
||||
if isinstance(result, tuple): # pragma: no cover
|
||||
# If the node has multiple outputs (i.e. the result is a tuple), return the
|
||||
# requested output
|
||||
return result[output_idx]
|
||||
# If the result is not a tuple, then the result is the node's only output. Check that
|
||||
# the requested index is 0 (as it's the only valid value) and return the result itself.
|
||||
assert_true(
|
||||
output_idx == 0,
|
||||
f"Unable to get output at index {output_idx} for node {node}.\n"
|
||||
f"Node result: {result}",
|
||||
)
|
||||
return result
|
||||
|
||||
for node in nx.topological_sort(self.graph):
|
||||
if not isinstance(node, Input):
|
||||
curr_inputs = {}
|
||||
for pred_node in self.graph.pred[node]:
|
||||
edges = self.graph.get_edge_data(pred_node, node)
|
||||
curr_inputs.update(
|
||||
{edge["input_idx"]: node_results[pred_node] for edge in edges.values()}
|
||||
{
|
||||
edge["input_idx"]: get_result_of_node_at_index(
|
||||
pred_node,
|
||||
output_idx=edge["output_idx"],
|
||||
)
|
||||
for edge in edges.values()
|
||||
}
|
||||
)
|
||||
node_results[node] = node.evaluate(curr_inputs)
|
||||
else:
|
||||
@@ -225,16 +256,12 @@ class OPGraph:
|
||||
|
||||
node.outputs[0] = deepcopy(node.inputs[0])
|
||||
|
||||
# TODO: #57 manage multiple outputs from a node, probably requires an output_idx when
|
||||
# adding an edge
|
||||
assert_true(len(node.outputs) == 1)
|
||||
|
||||
successors = self.graph.succ[node]
|
||||
for succ in successors:
|
||||
edge_data = self.graph.get_edge_data(node, succ)
|
||||
for edge in edge_data.values():
|
||||
input_idx = edge["input_idx"]
|
||||
succ.inputs[input_idx] = deepcopy(node.outputs[0])
|
||||
input_idx, output_idx = edge["input_idx"], edge["output_idx"]
|
||||
succ.inputs[input_idx] = deepcopy(node.outputs[output_idx])
|
||||
|
||||
def prune_nodes(self):
|
||||
"""Remove unreachable nodes from outputs."""
|
||||
|
||||
@@ -73,10 +73,14 @@ def fuse_float_operations(
|
||||
succ_edge_data = deepcopy(nx_graph.get_edge_data(terminal_node, succ))
|
||||
for edge_key, edge_data in succ_edge_data.items():
|
||||
nx_graph.remove_edge(terminal_node, succ, key=edge_key)
|
||||
nx_graph.add_edge(fused_node, succ, key=edge_key, **edge_data)
|
||||
# fused_node is always a UnivariateFunction so output_idx == 0 always
|
||||
new_edge_data = deepcopy(edge_data)
|
||||
new_edge_data["output_idx"] = 0
|
||||
nx_graph.add_edge(fused_node, succ, key=edge_key, **new_edge_data)
|
||||
|
||||
# Connect the node feeding the subgraph contained in fused_node
|
||||
nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0)
|
||||
# node_before_subgraph has a single integer output currently so output_idx == 0
|
||||
nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0, output_idx=0)
|
||||
|
||||
op_graph.prune_nodes()
|
||||
if compilation_artifacts is not None:
|
||||
@@ -122,6 +126,7 @@ def convert_float_subgraph_to_fused_node(
|
||||
assert_true(len(variable_input_nodes) == 1)
|
||||
|
||||
current_subgraph_variable_input = variable_input_nodes[0]
|
||||
assert_true(len(current_subgraph_variable_input.outputs) == 1)
|
||||
new_input_value = deepcopy(current_subgraph_variable_input.outputs[0])
|
||||
|
||||
nx_graph = op_graph.graph
|
||||
@@ -147,11 +152,14 @@ def convert_float_subgraph_to_fused_node(
|
||||
float_subgraph.remove_edge(
|
||||
current_subgraph_variable_input, node_after_input, key=edge_key
|
||||
)
|
||||
# new_subgraph_variable_input is always an Input so output_idx == 0 always
|
||||
new_edge_data = deepcopy(edge_data)
|
||||
new_edge_data["output_idx"] = 0
|
||||
float_subgraph.add_edge(
|
||||
new_subgraph_variable_input,
|
||||
node_after_input,
|
||||
key=edge_key,
|
||||
**edge_data,
|
||||
**new_edge_data,
|
||||
)
|
||||
|
||||
float_op_subgraph = OPGraph.from_graph(
|
||||
@@ -160,6 +168,8 @@ def convert_float_subgraph_to_fused_node(
|
||||
[terminal_node],
|
||||
)
|
||||
|
||||
assert_true(len(terminal_node.outputs) == 1)
|
||||
|
||||
# Create fused_node
|
||||
fused_node = UnivariateFunction(
|
||||
deepcopy(new_subgraph_variable_input.inputs[0]),
|
||||
|
||||
@@ -19,6 +19,7 @@ class BaseTracer(ABC):
|
||||
|
||||
inputs: List["BaseTracer"]
|
||||
traced_computation: IntermediateNode
|
||||
output_idx: int
|
||||
output: BaseValue
|
||||
_mix_values_func: Callable[..., BaseValue]
|
||||
|
||||
@@ -26,11 +27,12 @@ class BaseTracer(ABC):
|
||||
self,
|
||||
inputs: Iterable["BaseTracer"],
|
||||
traced_computation: IntermediateNode,
|
||||
output_index: int,
|
||||
output_idx: int,
|
||||
) -> None:
|
||||
self.inputs = list(inputs)
|
||||
self.traced_computation = traced_computation
|
||||
self.output = traced_computation.outputs[output_index]
|
||||
self.output_idx = output_idx
|
||||
self.output = traced_computation.outputs[output_idx]
|
||||
|
||||
@abstractmethod
|
||||
def _supports_other_operand(self, other: Any) -> bool:
|
||||
@@ -96,8 +98,8 @@ class BaseTracer(ABC):
|
||||
)
|
||||
|
||||
output_tracers = tuple(
|
||||
self.__class__(sanitized_inputs, traced_computation, output_index)
|
||||
for output_index in range(len(traced_computation.outputs))
|
||||
self.__class__(sanitized_inputs, traced_computation, output_idx)
|
||||
for output_idx in range(len(traced_computation.outputs))
|
||||
)
|
||||
|
||||
return output_tracers
|
||||
|
||||
@@ -115,8 +115,14 @@ def create_graph_from_output_tracers(
|
||||
|
||||
for input_idx, input_tracer in enumerate(tracer.inputs):
|
||||
input_ir_node = input_tracer.traced_computation
|
||||
output_idx = input_tracer.output_idx
|
||||
graph.add_node(input_ir_node)
|
||||
graph.add_edge(input_ir_node, current_ir_node, input_idx=input_idx)
|
||||
graph.add_edge(
|
||||
input_ir_node,
|
||||
current_ir_node,
|
||||
input_idx=input_idx,
|
||||
output_idx=output_idx,
|
||||
)
|
||||
if input_tracer not in visited_tracers:
|
||||
next_tracers.update({input_tracer: None})
|
||||
|
||||
|
||||
@@ -93,9 +93,7 @@ class NPTracer(BaseTracer):
|
||||
output_dtype=output_dtype,
|
||||
op_name=f"astype({normalized_numpy_dtype})",
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
[self], traced_computation=traced_computation, output_index=0
|
||||
)
|
||||
output_tracer = self.__class__([self], traced_computation=traced_computation, output_idx=0)
|
||||
return output_tracer
|
||||
|
||||
@staticmethod
|
||||
@@ -164,7 +162,7 @@ class NPTracer(BaseTracer):
|
||||
output_tracer = cls(
|
||||
input_tracers,
|
||||
traced_computation=traced_computation,
|
||||
output_index=0,
|
||||
output_idx=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
@@ -229,7 +227,7 @@ class NPTracer(BaseTracer):
|
||||
output_tracer = cls(
|
||||
(input_tracers[in_which_input_is_variable],),
|
||||
traced_computation=traced_computation,
|
||||
output_index=0,
|
||||
output_idx=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
@@ -253,7 +251,7 @@ class NPTracer(BaseTracer):
|
||||
output_tracer = self.__class__(
|
||||
args,
|
||||
traced_computation=traced_computation,
|
||||
output_index=0,
|
||||
output_idx=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ def test_lookup_table_encrypted_lookup(test_helpers):
|
||||
# pylint: enable=protected-access
|
||||
ref_graph.add_node(output_arbitrary_function)
|
||||
|
||||
ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0)
|
||||
ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0, output_idx=0)
|
||||
|
||||
# TODO: discuss if this check is enough as == is not overloaded properly for UnivariateFunction
|
||||
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
|
||||
@@ -112,10 +112,10 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
|
||||
output_add = ir.Add((intermediate_arbitrary_function.outputs[0], constant_3.outputs[0]))
|
||||
ref_graph.add_node(output_add)
|
||||
|
||||
ref_graph.add_edge(input_x, intermediate_arbitrary_function, input_idx=0)
|
||||
ref_graph.add_edge(input_x, intermediate_arbitrary_function, input_idx=0, output_idx=0)
|
||||
|
||||
ref_graph.add_edge(intermediate_arbitrary_function, output_add, input_idx=0)
|
||||
ref_graph.add_edge(constant_3, output_add, input_idx=1)
|
||||
ref_graph.add_edge(intermediate_arbitrary_function, output_add, input_idx=0, output_idx=0)
|
||||
ref_graph.add_edge(constant_3, output_add, input_idx=1, output_idx=0)
|
||||
|
||||
# TODO: discuss if this check is enough as == is not overloaded properly for UnivariateFunction
|
||||
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
|
||||
|
||||
@@ -194,7 +194,7 @@ class TestHelpers:
|
||||
def digraphs_are_equivalent(reference: nx.MultiDiGraph, to_compare: nx.MultiDiGraph):
|
||||
"""Check that two digraphs are equivalent without modifications"""
|
||||
# edge_match is a copy of node_match
|
||||
edge_matcher = iso.categorical_multiedge_match("input_idx", None)
|
||||
edge_matcher = iso.categorical_multiedge_match(["input_idx", "output_idx"], [None, None])
|
||||
node_matcher = iso.generic_node_match(
|
||||
"_test_content", None, TestHelpers.nodes_are_equivalent
|
||||
)
|
||||
|
||||
@@ -28,27 +28,27 @@ def test_digraphs_are_equivalent(test_helpers):
|
||||
t_1 = TestNode("Mul")
|
||||
t_2 = TestNode("TLU")
|
||||
|
||||
g_1.add_edge(t_0, t_2, input_idx=0)
|
||||
g_1.add_edge(t_1, t_2, input_idx=1)
|
||||
g_1.add_edge(t_0, t_2, input_idx=0, output_idx=0)
|
||||
g_1.add_edge(t_1, t_2, input_idx=1, output_idx=0)
|
||||
|
||||
t0p = TestNode("Add")
|
||||
t1p = TestNode("Mul")
|
||||
t2p = TestNode("TLU")
|
||||
|
||||
g_2.add_edge(t1p, t2p, input_idx=1)
|
||||
g_2.add_edge(t0p, t2p, input_idx=0)
|
||||
g_2.add_edge(t1p, t2p, input_idx=1, output_idx=0)
|
||||
g_2.add_edge(t0p, t2p, input_idx=0, output_idx=0)
|
||||
|
||||
bad_g2 = nx.MultiDiGraph()
|
||||
|
||||
bad_t0 = TestNode("Not Add")
|
||||
|
||||
bad_g2.add_edge(bad_t0, t_2, input_idx=0)
|
||||
bad_g2.add_edge(t_1, t_2, input_idx=1)
|
||||
bad_g2.add_edge(bad_t0, t_2, input_idx=0, output_idx=0)
|
||||
bad_g2.add_edge(t_1, t_2, input_idx=1, output_idx=0)
|
||||
|
||||
bad_g3 = nx.MultiDiGraph()
|
||||
|
||||
bad_g3.add_edge(t_0, t_2, input_idx=1)
|
||||
bad_g3.add_edge(t_1, t_2, input_idx=0)
|
||||
bad_g3.add_edge(t_0, t_2, input_idx=1, output_idx=0)
|
||||
bad_g3.add_edge(t_1, t_2, input_idx=0, output_idx=0)
|
||||
|
||||
assert test_helpers.digraphs_are_equivalent(g_1, g_2), "Graphs should be equivalent"
|
||||
assert not test_helpers.digraphs_are_equivalent(g_1, bad_g2), "Graphs should not be equivalent"
|
||||
|
||||
@@ -159,11 +159,11 @@ def test_numpy_tracing_binary_op(operation, x, y, test_helpers):
|
||||
ref_graph.add_node(add_node_z)
|
||||
ref_graph.add_node(returned_final_node)
|
||||
|
||||
ref_graph.add_edge(input_x, add_node_z, input_idx=0)
|
||||
ref_graph.add_edge(input_x, add_node_z, input_idx=1)
|
||||
ref_graph.add_edge(input_x, add_node_z, input_idx=0, output_idx=0)
|
||||
ref_graph.add_edge(input_x, add_node_z, input_idx=1, output_idx=0)
|
||||
|
||||
ref_graph.add_edge(add_node_z, returned_final_node, input_idx=0)
|
||||
ref_graph.add_edge(input_y, returned_final_node, input_idx=1)
|
||||
ref_graph.add_edge(add_node_z, returned_final_node, input_idx=0, output_idx=0)
|
||||
ref_graph.add_edge(input_y, returned_final_node, input_idx=1, output_idx=0)
|
||||
|
||||
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user