refactor: remove the content= when adding nodes to a graph

- required by tests but can be done by the testing function itself
This commit is contained in:
Arthur Meyre
2021-08-26 11:01:05 +02:00
parent 9a3e15e89a
commit 31259e556c
6 changed files with 22 additions and 30 deletions

View File

@@ -42,7 +42,7 @@ def fuse_float_operations(op_graph: OPGraph):
fused_node, node_before_subgraph = subgraph_conversion_result
nx_graph.add_node(fused_node, content=fused_node)
nx_graph.add_node(fused_node)
if terminal_node in op_graph.output_nodes.values():
# Output value replace it

View File

@@ -108,11 +108,11 @@ def create_graph_from_output_tracers(
next_tracers: Dict[BaseTracer, None] = dict()
for tracer in current_tracers:
current_ir_node = tracer.traced_computation
graph.add_node(current_ir_node, content=current_ir_node)
graph.add_node(current_ir_node)
for input_idx, input_tracer in enumerate(tracer.inputs):
input_ir_node = input_tracer.traced_computation
graph.add_node(input_ir_node, content=input_ir_node)
graph.add_node(input_ir_node)
graph.add_edge(input_ir_node, current_ir_node, input_idx=input_idx)
if input_tracer not in visited_tracers:
next_tracers.update({input_tracer: None})

View File

@@ -50,7 +50,7 @@ def test_lookup_table_encrypted_lookup(test_helpers):
# (x) - (TLU)
input_x = ir.Input(input_value=x, input_name="x", program_input_idx=0)
ref_graph.add_node(input_x, content=input_x)
ref_graph.add_node(input_x)
output_arbitrary_function = ir.ArbitraryFunction(
input_base_value=x,
@@ -59,7 +59,7 @@ def test_lookup_table_encrypted_lookup(test_helpers):
op_kwargs={"table": deepcopy(table.table)},
op_name="TLU",
)
ref_graph.add_node(output_arbitrary_function, content=output_arbitrary_function)
ref_graph.add_node(output_arbitrary_function)
ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0)
@@ -87,7 +87,7 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
# (3)
input_x = ir.Input(input_value=x, input_name="x", program_input_idx=0)
ref_graph.add_node(input_x, content=input_x)
ref_graph.add_node(input_x)
intermediate_arbitrary_function = ir.ArbitraryFunction(
input_base_value=x,
@@ -96,13 +96,13 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
op_kwargs={"table": deepcopy(table.table)},
op_name="TLU",
)
ref_graph.add_node(intermediate_arbitrary_function, content=intermediate_arbitrary_function)
ref_graph.add_node(intermediate_arbitrary_function)
constant_3 = ir.Constant(3)
ref_graph.add_node(constant_3, content=constant_3)
ref_graph.add_node(constant_3)
output_add = ir.Add((intermediate_arbitrary_function.outputs[0], constant_3.outputs[0]))
ref_graph.add_node(output_add, content=output_add)
ref_graph.add_node(output_add)
ref_graph.add_edge(input_x, intermediate_arbitrary_function, input_idx=0)

View File

@@ -13,8 +13,16 @@ class TestHelpers:
# edge_match is a copy of node_match
edge_matcher = iso.categorical_multiedge_match("input_idx", None)
node_matcher = iso.generic_node_match(
"content", None, lambda lhs, rhs: lhs.is_equivalent_to(rhs)
"_test_content", None, lambda lhs, rhs: lhs.is_equivalent_to(rhs)
)
# Set the _test_content for each node in the graphs
for node in reference.nodes():
reference.add_node(node, _test_content=node)
for node in to_compare.nodes():
to_compare.add_node(node, _test_content=node)
graphs_are_isomorphic = nx.is_isomorphic(
reference,
to_compare,

View File

@@ -31,10 +31,6 @@ def test_digraphs_are_equivalent(test_helpers):
g_1.add_edge(t_0, t_2, input_idx=0)
g_1.add_edge(t_1, t_2, input_idx=1)
# This updates the nodes attributes in the graph
for node in g_1:
g_1.add_node(node, content=node)
t0p = TestNode("Add")
t1p = TestNode("Mul")
t2p = TestNode("TLU")
@@ -42,10 +38,6 @@ def test_digraphs_are_equivalent(test_helpers):
g_2.add_edge(t1p, t2p, input_idx=1)
g_2.add_edge(t0p, t2p, input_idx=0)
# This updates the nodes attributes in the graph
for node in g_2:
g_2.add_node(node, content=node)
bad_g2 = nx.MultiDiGraph()
bad_t0 = TestNode("Not Add")
@@ -53,19 +45,11 @@ def test_digraphs_are_equivalent(test_helpers):
bad_g2.add_edge(bad_t0, t_2, input_idx=0)
bad_g2.add_edge(t_1, t_2, input_idx=1)
# This updates the nodes attributes in the graph
for node in bad_g2:
bad_g2.add_node(node, content=node)
bad_g3 = nx.MultiDiGraph()
bad_g3.add_edge(t_0, t_2, input_idx=1)
bad_g3.add_edge(t_1, t_2, input_idx=0)
# This updates the nodes attributes in the graph
for node in bad_g3:
bad_g3.add_node(node, content=node)
assert test_helpers.digraphs_are_equivalent(g_1, g_2), "Graphs should be equivalent"
assert not test_helpers.digraphs_are_equivalent(g_1, bad_g2), "Graphs should not be equivalent"
assert not test_helpers.digraphs_are_equivalent(g_2, bad_g2), "Graphs should not be equivalent"

View File

@@ -100,10 +100,10 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
)
)
ref_graph.add_node(input_x, content=input_x)
ref_graph.add_node(input_y, content=input_y)
ref_graph.add_node(add_node_z, content=add_node_z)
ref_graph.add_node(returned_final_node, content=returned_final_node)
ref_graph.add_node(input_x)
ref_graph.add_node(input_y)
ref_graph.add_node(add_node_z)
ref_graph.add_node(returned_final_node)
ref_graph.add_edge(input_x, add_node_z, input_idx=0)
ref_graph.add_edge(input_x, add_node_z, input_idx=1)