mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
refactor: remove the content= when adding nodes to a graph
- required by tests but can be done by the testing function itself
This commit is contained in:
@@ -42,7 +42,7 @@ def fuse_float_operations(op_graph: OPGraph):
|
||||
|
||||
fused_node, node_before_subgraph = subgraph_conversion_result
|
||||
|
||||
nx_graph.add_node(fused_node, content=fused_node)
|
||||
nx_graph.add_node(fused_node)
|
||||
|
||||
if terminal_node in op_graph.output_nodes.values():
|
||||
# Output value replace it
|
||||
|
||||
@@ -108,11 +108,11 @@ def create_graph_from_output_tracers(
|
||||
next_tracers: Dict[BaseTracer, None] = dict()
|
||||
for tracer in current_tracers:
|
||||
current_ir_node = tracer.traced_computation
|
||||
graph.add_node(current_ir_node, content=current_ir_node)
|
||||
graph.add_node(current_ir_node)
|
||||
|
||||
for input_idx, input_tracer in enumerate(tracer.inputs):
|
||||
input_ir_node = input_tracer.traced_computation
|
||||
graph.add_node(input_ir_node, content=input_ir_node)
|
||||
graph.add_node(input_ir_node)
|
||||
graph.add_edge(input_ir_node, current_ir_node, input_idx=input_idx)
|
||||
if input_tracer not in visited_tracers:
|
||||
next_tracers.update({input_tracer: None})
|
||||
|
||||
@@ -50,7 +50,7 @@ def test_lookup_table_encrypted_lookup(test_helpers):
|
||||
# (x) - (TLU)
|
||||
|
||||
input_x = ir.Input(input_value=x, input_name="x", program_input_idx=0)
|
||||
ref_graph.add_node(input_x, content=input_x)
|
||||
ref_graph.add_node(input_x)
|
||||
|
||||
output_arbitrary_function = ir.ArbitraryFunction(
|
||||
input_base_value=x,
|
||||
@@ -59,7 +59,7 @@ def test_lookup_table_encrypted_lookup(test_helpers):
|
||||
op_kwargs={"table": deepcopy(table.table)},
|
||||
op_name="TLU",
|
||||
)
|
||||
ref_graph.add_node(output_arbitrary_function, content=output_arbitrary_function)
|
||||
ref_graph.add_node(output_arbitrary_function)
|
||||
|
||||
ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0)
|
||||
|
||||
@@ -87,7 +87,7 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
|
||||
# (3)
|
||||
|
||||
input_x = ir.Input(input_value=x, input_name="x", program_input_idx=0)
|
||||
ref_graph.add_node(input_x, content=input_x)
|
||||
ref_graph.add_node(input_x)
|
||||
|
||||
intermediate_arbitrary_function = ir.ArbitraryFunction(
|
||||
input_base_value=x,
|
||||
@@ -96,13 +96,13 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
|
||||
op_kwargs={"table": deepcopy(table.table)},
|
||||
op_name="TLU",
|
||||
)
|
||||
ref_graph.add_node(intermediate_arbitrary_function, content=intermediate_arbitrary_function)
|
||||
ref_graph.add_node(intermediate_arbitrary_function)
|
||||
|
||||
constant_3 = ir.Constant(3)
|
||||
ref_graph.add_node(constant_3, content=constant_3)
|
||||
ref_graph.add_node(constant_3)
|
||||
|
||||
output_add = ir.Add((intermediate_arbitrary_function.outputs[0], constant_3.outputs[0]))
|
||||
ref_graph.add_node(output_add, content=output_add)
|
||||
ref_graph.add_node(output_add)
|
||||
|
||||
ref_graph.add_edge(input_x, intermediate_arbitrary_function, input_idx=0)
|
||||
|
||||
|
||||
@@ -13,8 +13,16 @@ class TestHelpers:
|
||||
# edge_match is a copy of node_match
|
||||
edge_matcher = iso.categorical_multiedge_match("input_idx", None)
|
||||
node_matcher = iso.generic_node_match(
|
||||
"content", None, lambda lhs, rhs: lhs.is_equivalent_to(rhs)
|
||||
"_test_content", None, lambda lhs, rhs: lhs.is_equivalent_to(rhs)
|
||||
)
|
||||
|
||||
# Set the _test_content for each node in the graphs
|
||||
for node in reference.nodes():
|
||||
reference.add_node(node, _test_content=node)
|
||||
|
||||
for node in to_compare.nodes():
|
||||
to_compare.add_node(node, _test_content=node)
|
||||
|
||||
graphs_are_isomorphic = nx.is_isomorphic(
|
||||
reference,
|
||||
to_compare,
|
||||
|
||||
@@ -31,10 +31,6 @@ def test_digraphs_are_equivalent(test_helpers):
|
||||
g_1.add_edge(t_0, t_2, input_idx=0)
|
||||
g_1.add_edge(t_1, t_2, input_idx=1)
|
||||
|
||||
# This updates the nodes attributes in the graph
|
||||
for node in g_1:
|
||||
g_1.add_node(node, content=node)
|
||||
|
||||
t0p = TestNode("Add")
|
||||
t1p = TestNode("Mul")
|
||||
t2p = TestNode("TLU")
|
||||
@@ -42,10 +38,6 @@ def test_digraphs_are_equivalent(test_helpers):
|
||||
g_2.add_edge(t1p, t2p, input_idx=1)
|
||||
g_2.add_edge(t0p, t2p, input_idx=0)
|
||||
|
||||
# This updates the nodes attributes in the graph
|
||||
for node in g_2:
|
||||
g_2.add_node(node, content=node)
|
||||
|
||||
bad_g2 = nx.MultiDiGraph()
|
||||
|
||||
bad_t0 = TestNode("Not Add")
|
||||
@@ -53,19 +45,11 @@ def test_digraphs_are_equivalent(test_helpers):
|
||||
bad_g2.add_edge(bad_t0, t_2, input_idx=0)
|
||||
bad_g2.add_edge(t_1, t_2, input_idx=1)
|
||||
|
||||
# This updates the nodes attributes in the graph
|
||||
for node in bad_g2:
|
||||
bad_g2.add_node(node, content=node)
|
||||
|
||||
bad_g3 = nx.MultiDiGraph()
|
||||
|
||||
bad_g3.add_edge(t_0, t_2, input_idx=1)
|
||||
bad_g3.add_edge(t_1, t_2, input_idx=0)
|
||||
|
||||
# This updates the nodes attributes in the graph
|
||||
for node in bad_g3:
|
||||
bad_g3.add_node(node, content=node)
|
||||
|
||||
assert test_helpers.digraphs_are_equivalent(g_1, g_2), "Graphs should be equivalent"
|
||||
assert not test_helpers.digraphs_are_equivalent(g_1, bad_g2), "Graphs should not be equivalent"
|
||||
assert not test_helpers.digraphs_are_equivalent(g_2, bad_g2), "Graphs should not be equivalent"
|
||||
|
||||
@@ -100,10 +100,10 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
|
||||
)
|
||||
)
|
||||
|
||||
ref_graph.add_node(input_x, content=input_x)
|
||||
ref_graph.add_node(input_y, content=input_y)
|
||||
ref_graph.add_node(add_node_z, content=add_node_z)
|
||||
ref_graph.add_node(returned_final_node, content=returned_final_node)
|
||||
ref_graph.add_node(input_x)
|
||||
ref_graph.add_node(input_y)
|
||||
ref_graph.add_node(add_node_z)
|
||||
ref_graph.add_node(returned_final_node)
|
||||
|
||||
ref_graph.add_edge(input_x, add_node_z, input_idx=0)
|
||||
ref_graph.add_edge(input_x, add_node_z, input_idx=1)
|
||||
|
||||
Reference in New Issue
Block a user