refactor: use a consistent short name for operation graph

This commit is contained in:
Umut
2021-11-16 11:30:54 +03:00
parent 5d31aa4d2c
commit 46a018fd21
10 changed files with 52 additions and 52 deletions

View File

@@ -52,7 +52,7 @@ del _missing_nodes_in_mapping
def draw_graph(
opgraph: OPGraph,
op_graph: OPGraph,
show: bool = False,
vertical: bool = True,
save_to: Optional[Path] = None,
@@ -60,7 +60,7 @@ def draw_graph(
"""Draws operation graphs and optionally saves/shows the drawing.
Args:
opgraph (OPGraph): the graph to be drawn and optionally saved/shown
op_graph (OPGraph): the operation graph to be drawn and optionally saved/shown
show (bool): if set to True, the drawing will be shown using matplotlib
vertical (bool): if set to True, the orientation will be vertical
save_to (Optional[Path]): if specified, the drawn graph will be saved to this path; else
@@ -79,8 +79,8 @@ def draw_graph(
value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return)
return value_to_return
graph = opgraph.graph
output_nodes = set(opgraph.output_nodes.values())
graph = op_graph.graph
output_nodes = set(op_graph.output_nodes.values())
attributes = {
node: {

View File

@@ -10,14 +10,14 @@ from ..representation.intermediate import IntermediateNode
def format_operation_graph(
opgraph: OPGraph,
op_graph: OPGraph,
maximum_constant_length: int = 25,
highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None,
) -> str:
"""Format an operation graph.
Args:
opgraph (OPGraph):
op_graph (OPGraph):
the operation graph to format
maximum_constant_length (int):
@@ -29,7 +29,7 @@ def format_operation_graph(
Returns:
str: formatted operation graph
"""
assert_true(isinstance(opgraph, OPGraph))
assert_true(isinstance(op_graph, OPGraph))
# (node, output_index) -> identifier
# e.g., id_map[(node1, 0)] = 2 and id_map[(node1, 1)] = 3
@@ -49,7 +49,7 @@ def format_operation_graph(
# after their type information is added and we only have line numbers, not nodes
highlighted_lines: Dict[int, List[str]] = {}
for node in nx.topological_sort(opgraph.graph):
for node in nx.topological_sort(op_graph.graph):
# assign a unique id to outputs of node
assert_true(len(node.outputs) > 0)
for i in range(len(node.outputs)):
@@ -61,7 +61,7 @@ def format_operation_graph(
# extract predecessors and their ids
predecessors = []
for predecessor, output_idx in opgraph.get_ordered_inputs_of(node):
for predecessor, output_idx in op_graph.get_ordered_inputs_of(node):
predecessors.append(f"%{id_map[(predecessor, output_idx)]}")
# start the build the line for the node
@@ -116,7 +116,7 @@ def format_operation_graph(
# (if there is a single return, it's in the form `return %id`
# (otherwise, it's in the form `return (%id1, %id2, ..., %idN)`
returns: List[str] = []
for node in opgraph.output_nodes.values():
for node in op_graph.output_nodes.values():
outputs = ", ".join(f"%{id_map[(node, i)]}" for i in range(len(node.outputs)))
returns.append(outputs if len(node.outputs) == 1 else f"({outputs})")
lines.append("return " + returns[0] if len(returns) == 1 else f"({', '.join(returns)})")

View File

@@ -13,15 +13,15 @@ from .operator_graph import OPGraph
class FHECircuit:
"""Class which is the result of compilation."""
opgraph: OPGraph
op_graph: OPGraph
engine: CompilerEngine
def __init__(self, opgraph: OPGraph, engine: CompilerEngine):
self.opgraph = opgraph
def __init__(self, op_graph: OPGraph, engine: CompilerEngine):
self.op_graph = op_graph
self.engine = engine
def __str__(self):
return format_operation_graph(self.opgraph)
return format_operation_graph(self.op_graph)
def draw(
self,
@@ -42,7 +42,7 @@ class FHECircuit:
"""
return draw_graph(self.opgraph, show, vertical, save_to)
return draw_graph(self.op_graph, show, vertical, save_to)
def run(self, *args: List[Union[int, numpy.ndarray]]) -> Union[int, numpy.ndarray]:
"""Encrypt, evaluate, and decrypt the inputs on the circuit.

View File

@@ -23,17 +23,17 @@ from .node_converter import IntermediateNodeConverter
class OPGraphConverter(ABC):
"""Converter of OPGraph to MLIR."""
def convert(self, opgraph: OPGraph) -> str:
def convert(self, op_graph: OPGraph) -> str:
"""Convert an operation graph to its corresponding MLIR representation.
Args:
opgraph (OPGraph): the operation graph to be converted
op_graph (OPGraph): the operation graph to be converted
Returns:
str: textual MLIR representation corresponding to opgraph
str: textual MLIR representation corresponding to given operation graph
"""
additional_conversion_info = self._generate_additional_info_dict(opgraph)
additional_conversion_info = self._generate_additional_info_dict(op_graph)
with Context() as ctx, Location.unknown():
zamalang.register_dialects(ctx)
@@ -42,25 +42,25 @@ class OPGraphConverter(ABC):
with InsertionPoint(module.body):
parameters = [
value_to_mlir_type(ctx, input_node.outputs[0])
for input_node in opgraph.get_ordered_inputs()
for input_node in op_graph.get_ordered_inputs()
]
@builtin.FuncOp.from_py_func(*parameters)
def main(*arg):
ir_to_mlir = {}
for arg_num, node in opgraph.input_nodes.items():
for arg_num, node in op_graph.input_nodes.items():
ir_to_mlir[node] = arg[arg_num]
for node in nx.topological_sort(opgraph.graph):
for node in nx.topological_sort(op_graph.graph):
if isinstance(node, Input):
continue
preds = [ir_to_mlir[pred] for pred in opgraph.get_ordered_preds(node)]
node_converter = IntermediateNodeConverter(ctx, opgraph, node, preds)
preds = [ir_to_mlir[pred] for pred in op_graph.get_ordered_preds(node)]
node_converter = IntermediateNodeConverter(ctx, op_graph, node, preds)
ir_to_mlir[node] = node_converter.convert(additional_conversion_info)
results = (
ir_to_mlir[output_node] for output_node in opgraph.get_ordered_outputs()
ir_to_mlir[output_node] for output_node in op_graph.get_ordered_outputs()
)
return results
@@ -68,11 +68,11 @@ class OPGraphConverter(ABC):
@staticmethod
@abstractmethod
def _generate_additional_info_dict(opgraph: OPGraph) -> Dict[str, Any]:
def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]:
"""Generate additional conversion info dict for the MLIR converter.
Args:
opgraph (OPGraph): the operation graph from which the additional info will be generated
op_graph (OPGraph): the operation graph from which the additional info will be generated
Returns:
Dict[str, Any]: dict of additional conversion info

View File

@@ -41,7 +41,7 @@ class IntermediateNodeConverter:
"""Converter of IntermediateNode to MLIR."""
ctx: Context
opgraph: OPGraph
op_graph: OPGraph
node: IntermediateNode
preds: List[OpResult]
@@ -50,10 +50,10 @@ class IntermediateNodeConverter:
one_of_the_inputs_is_a_tensor: bool
def __init__(
self, ctx: Context, opgraph: OPGraph, node: IntermediateNode, preds: List[OpResult]
self, ctx: Context, op_graph: OPGraph, node: IntermediateNode, preds: List[OpResult]
):
self.ctx = ctx
self.opgraph = opgraph
self.op_graph = op_graph
self.node = node
self.preds = preds
@@ -219,7 +219,7 @@ class IntermediateNodeConverter:
variable_input_indices = [
idx
for idx, inp in enumerate(self.opgraph.get_ordered_preds(self.node))
for idx, inp in enumerate(self.op_graph.get_ordered_preds(self.node))
if not isinstance(inp, Constant)
]
if len(variable_input_indices) != 1: # pragma: no cover

View File

@@ -66,7 +66,7 @@ class IntermediateNode(ABC):
self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])]
def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str:
"""Get the formatted node (used in formatting opgraph).
"""Get the formatted node (used in formatting operation graphs).
Args:
predecessors (List[str]): predecessor names to this node
@@ -80,7 +80,7 @@ class IntermediateNode(ABC):
@abstractmethod
def text_for_drawing(self) -> str:
"""Get the label of the node (used in drawing opgraph).
"""Get the label of the node (used in drawing operation graphs).
Returns:
str: the label of the node

View File

@@ -71,14 +71,14 @@ class NPMLIRConverter(OPGraphConverter):
"""Numpy-specific MLIR converter."""
@staticmethod
def _generate_additional_info_dict(opgraph: OPGraph) -> Dict[str, Any]:
def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]:
additional_conversion_info = {}
# Disable numpy warnings during conversion to avoid issues during TLU generation
with numpy.errstate(all="ignore"):
additional_conversion_info["tables"] = {
node: generate_deduplicated_tables(node, opgraph.get_ordered_preds(node))
for node in opgraph.graph.nodes()
node: generate_deduplicated_tables(node, op_graph.get_ordered_preds(node))
for node in op_graph.graph.nodes()
if isinstance(node, GenericFunction)
}

View File

@@ -12,14 +12,14 @@ def test_format_operation_graph_with_multiple_edges(default_compilation_configur
def function(x):
return x + x
opgraph = compile_numpy_function_into_op_graph(
op_graph = compile_numpy_function_into_op_graph(
function,
{"x": EncryptedScalar(Integer(4, True))},
[(i,) for i in range(0, 10)],
default_compilation_configuration,
)
formatted_graph = format_operation_graph(opgraph)
formatted_graph = format_operation_graph(op_graph)
assert (
formatted_graph
== """
@@ -38,15 +38,15 @@ def test_format_operation_graph_with_offending_nodes(default_compilation_configu
def function(x):
return x + 42
opgraph = compile_numpy_function_into_op_graph(
op_graph = compile_numpy_function_into_op_graph(
function,
{"x": EncryptedScalar(Integer(7, True))},
[(i,) for i in range(-5, 5)],
default_compilation_configuration,
)
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]}
formatted_graph = format_operation_graph(opgraph, highlighted_nodes=highlighted_nodes).strip()
highlighted_nodes = {op_graph.input_nodes[0]: ["foo"]}
formatted_graph = format_operation_graph(op_graph, highlighted_nodes=highlighted_nodes).strip()
assert (
formatted_graph
== """
@@ -60,8 +60,8 @@ return %2
""".strip()
)
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]}
formatted_graph = format_operation_graph(opgraph, highlighted_nodes=highlighted_nodes).strip()
highlighted_nodes = {op_graph.input_nodes[0]: ["foo"], op_graph.output_nodes[0]: ["bar", "baz"]}
formatted_graph = format_operation_graph(op_graph, highlighted_nodes=highlighted_nodes).strip()
assert (
formatted_graph
== """

View File

@@ -17,7 +17,7 @@ def test_circuit_str(default_compilation_configuration):
inputset = [(i,) for i in range(2 ** 3)]
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
assert str(circuit) == format_operation_graph(circuit.opgraph)
assert str(circuit) == format_operation_graph(circuit.op_graph)
def test_circuit_draw(default_compilation_configuration):
@@ -31,8 +31,8 @@ def test_circuit_draw(default_compilation_configuration):
inputset = [(i,) for i in range(2 ** 3)]
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
assert filecmp.cmp(circuit.draw(), draw_graph(circuit.opgraph))
assert filecmp.cmp(circuit.draw(vertical=False), draw_graph(circuit.opgraph, vertical=False))
assert filecmp.cmp(circuit.draw(), draw_graph(circuit.op_graph))
assert filecmp.cmp(circuit.draw(vertical=False), draw_graph(circuit.op_graph, vertical=False))
def test_circuit_run(default_compilation_configuration):

View File

@@ -347,15 +347,15 @@ def test_constant_indexing(
for _ in range(10)
]
opgraph = compile_numpy_function_into_op_graph(
op_graph = compile_numpy_function_into_op_graph(
function_with_indexing,
{"x": input_value},
inputset,
default_compilation_configuration,
)
assert len(opgraph.output_nodes) == 1
output_node = opgraph.output_nodes[0]
assert len(op_graph.output_nodes) == 1
output_node = op_graph.output_nodes[0]
assert len(output_node.outputs) == 1
assert output_value == output_node.outputs[0]
@@ -525,15 +525,15 @@ def test_constant_indexing_with_numpy_integers(
for _ in range(10)
]
opgraph = compile_numpy_function_into_op_graph(
op_graph = compile_numpy_function_into_op_graph(
function_with_indexing,
{"x": input_value},
inputset,
default_compilation_configuration,
)
assert len(opgraph.output_nodes) == 1
output_node = opgraph.output_nodes[0]
assert len(op_graph.output_nodes) == 1
output_node = op_graph.output_nodes[0]
assert len(output_node.outputs) == 1
assert output_value == output_node.outputs[0]