mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor: use a consistent short name for operation graph
This commit is contained in:
@@ -52,7 +52,7 @@ del _missing_nodes_in_mapping
|
||||
|
||||
|
||||
def draw_graph(
|
||||
opgraph: OPGraph,
|
||||
op_graph: OPGraph,
|
||||
show: bool = False,
|
||||
vertical: bool = True,
|
||||
save_to: Optional[Path] = None,
|
||||
@@ -60,7 +60,7 @@ def draw_graph(
|
||||
"""Draws operation graphs and optionally saves/shows the drawing.
|
||||
|
||||
Args:
|
||||
opgraph (OPGraph): the graph to be drawn and optionally saved/shown
|
||||
op_graph (OPGraph): the operation graph to be drawn and optionally saved/shown
|
||||
show (bool): if set to True, the drawing will be shown using matplotlib
|
||||
vertical (bool): if set to True, the orientation will be vertical
|
||||
save_to (Optional[Path]): if specified, the drawn graph will be saved to this path; else
|
||||
@@ -79,8 +79,8 @@ def draw_graph(
|
||||
value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return)
|
||||
return value_to_return
|
||||
|
||||
graph = opgraph.graph
|
||||
output_nodes = set(opgraph.output_nodes.values())
|
||||
graph = op_graph.graph
|
||||
output_nodes = set(op_graph.output_nodes.values())
|
||||
|
||||
attributes = {
|
||||
node: {
|
||||
|
||||
@@ -10,14 +10,14 @@ from ..representation.intermediate import IntermediateNode
|
||||
|
||||
|
||||
def format_operation_graph(
|
||||
opgraph: OPGraph,
|
||||
op_graph: OPGraph,
|
||||
maximum_constant_length: int = 25,
|
||||
highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None,
|
||||
) -> str:
|
||||
"""Format an operation graph.
|
||||
|
||||
Args:
|
||||
opgraph (OPGraph):
|
||||
op_graph (OPGraph):
|
||||
the operation graph to format
|
||||
|
||||
maximum_constant_length (int):
|
||||
@@ -29,7 +29,7 @@ def format_operation_graph(
|
||||
Returns:
|
||||
str: formatted operation graph
|
||||
"""
|
||||
assert_true(isinstance(opgraph, OPGraph))
|
||||
assert_true(isinstance(op_graph, OPGraph))
|
||||
|
||||
# (node, output_index) -> identifier
|
||||
# e.g., id_map[(node1, 0)] = 2 and id_map[(node1, 1)] = 3
|
||||
@@ -49,7 +49,7 @@ def format_operation_graph(
|
||||
# after their type information is added and we only have line numbers, not nodes
|
||||
highlighted_lines: Dict[int, List[str]] = {}
|
||||
|
||||
for node in nx.topological_sort(opgraph.graph):
|
||||
for node in nx.topological_sort(op_graph.graph):
|
||||
# assign a unique id to outputs of node
|
||||
assert_true(len(node.outputs) > 0)
|
||||
for i in range(len(node.outputs)):
|
||||
@@ -61,7 +61,7 @@ def format_operation_graph(
|
||||
|
||||
# extract predecessors and their ids
|
||||
predecessors = []
|
||||
for predecessor, output_idx in opgraph.get_ordered_inputs_of(node):
|
||||
for predecessor, output_idx in op_graph.get_ordered_inputs_of(node):
|
||||
predecessors.append(f"%{id_map[(predecessor, output_idx)]}")
|
||||
|
||||
# start the build the line for the node
|
||||
@@ -116,7 +116,7 @@ def format_operation_graph(
|
||||
# (if there is a single return, it's in the form `return %id`
|
||||
# (otherwise, it's in the form `return (%id1, %id2, ..., %idN)`
|
||||
returns: List[str] = []
|
||||
for node in opgraph.output_nodes.values():
|
||||
for node in op_graph.output_nodes.values():
|
||||
outputs = ", ".join(f"%{id_map[(node, i)]}" for i in range(len(node.outputs)))
|
||||
returns.append(outputs if len(node.outputs) == 1 else f"({outputs})")
|
||||
lines.append("return " + returns[0] if len(returns) == 1 else f"({', '.join(returns)})")
|
||||
|
||||
@@ -13,15 +13,15 @@ from .operator_graph import OPGraph
|
||||
class FHECircuit:
|
||||
"""Class which is the result of compilation."""
|
||||
|
||||
opgraph: OPGraph
|
||||
op_graph: OPGraph
|
||||
engine: CompilerEngine
|
||||
|
||||
def __init__(self, opgraph: OPGraph, engine: CompilerEngine):
|
||||
self.opgraph = opgraph
|
||||
def __init__(self, op_graph: OPGraph, engine: CompilerEngine):
|
||||
self.op_graph = op_graph
|
||||
self.engine = engine
|
||||
|
||||
def __str__(self):
|
||||
return format_operation_graph(self.opgraph)
|
||||
return format_operation_graph(self.op_graph)
|
||||
|
||||
def draw(
|
||||
self,
|
||||
@@ -42,7 +42,7 @@ class FHECircuit:
|
||||
|
||||
"""
|
||||
|
||||
return draw_graph(self.opgraph, show, vertical, save_to)
|
||||
return draw_graph(self.op_graph, show, vertical, save_to)
|
||||
|
||||
def run(self, *args: List[Union[int, numpy.ndarray]]) -> Union[int, numpy.ndarray]:
|
||||
"""Encrypt, evaluate, and decrypt the inputs on the circuit.
|
||||
|
||||
@@ -23,17 +23,17 @@ from .node_converter import IntermediateNodeConverter
|
||||
class OPGraphConverter(ABC):
|
||||
"""Converter of OPGraph to MLIR."""
|
||||
|
||||
def convert(self, opgraph: OPGraph) -> str:
|
||||
def convert(self, op_graph: OPGraph) -> str:
|
||||
"""Convert an operation graph to its corresponding MLIR representation.
|
||||
|
||||
Args:
|
||||
opgraph (OPGraph): the operation graph to be converted
|
||||
op_graph (OPGraph): the operation graph to be converted
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to opgraph
|
||||
str: textual MLIR representation corresponding to given operation graph
|
||||
"""
|
||||
|
||||
additional_conversion_info = self._generate_additional_info_dict(opgraph)
|
||||
additional_conversion_info = self._generate_additional_info_dict(op_graph)
|
||||
|
||||
with Context() as ctx, Location.unknown():
|
||||
zamalang.register_dialects(ctx)
|
||||
@@ -42,25 +42,25 @@ class OPGraphConverter(ABC):
|
||||
with InsertionPoint(module.body):
|
||||
parameters = [
|
||||
value_to_mlir_type(ctx, input_node.outputs[0])
|
||||
for input_node in opgraph.get_ordered_inputs()
|
||||
for input_node in op_graph.get_ordered_inputs()
|
||||
]
|
||||
|
||||
@builtin.FuncOp.from_py_func(*parameters)
|
||||
def main(*arg):
|
||||
ir_to_mlir = {}
|
||||
for arg_num, node in opgraph.input_nodes.items():
|
||||
for arg_num, node in op_graph.input_nodes.items():
|
||||
ir_to_mlir[node] = arg[arg_num]
|
||||
|
||||
for node in nx.topological_sort(opgraph.graph):
|
||||
for node in nx.topological_sort(op_graph.graph):
|
||||
if isinstance(node, Input):
|
||||
continue
|
||||
|
||||
preds = [ir_to_mlir[pred] for pred in opgraph.get_ordered_preds(node)]
|
||||
node_converter = IntermediateNodeConverter(ctx, opgraph, node, preds)
|
||||
preds = [ir_to_mlir[pred] for pred in op_graph.get_ordered_preds(node)]
|
||||
node_converter = IntermediateNodeConverter(ctx, op_graph, node, preds)
|
||||
ir_to_mlir[node] = node_converter.convert(additional_conversion_info)
|
||||
|
||||
results = (
|
||||
ir_to_mlir[output_node] for output_node in opgraph.get_ordered_outputs()
|
||||
ir_to_mlir[output_node] for output_node in op_graph.get_ordered_outputs()
|
||||
)
|
||||
return results
|
||||
|
||||
@@ -68,11 +68,11 @@ class OPGraphConverter(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _generate_additional_info_dict(opgraph: OPGraph) -> Dict[str, Any]:
|
||||
def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]:
|
||||
"""Generate additional conversion info dict for the MLIR converter.
|
||||
|
||||
Args:
|
||||
opgraph (OPGraph): the operation graph from which the additional info will be generated
|
||||
op_graph (OPGraph): the operation graph from which the additional info will be generated
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: dict of additional conversion info
|
||||
|
||||
@@ -41,7 +41,7 @@ class IntermediateNodeConverter:
|
||||
"""Converter of IntermediateNode to MLIR."""
|
||||
|
||||
ctx: Context
|
||||
opgraph: OPGraph
|
||||
op_graph: OPGraph
|
||||
node: IntermediateNode
|
||||
preds: List[OpResult]
|
||||
|
||||
@@ -50,10 +50,10 @@ class IntermediateNodeConverter:
|
||||
one_of_the_inputs_is_a_tensor: bool
|
||||
|
||||
def __init__(
|
||||
self, ctx: Context, opgraph: OPGraph, node: IntermediateNode, preds: List[OpResult]
|
||||
self, ctx: Context, op_graph: OPGraph, node: IntermediateNode, preds: List[OpResult]
|
||||
):
|
||||
self.ctx = ctx
|
||||
self.opgraph = opgraph
|
||||
self.op_graph = op_graph
|
||||
self.node = node
|
||||
self.preds = preds
|
||||
|
||||
@@ -219,7 +219,7 @@ class IntermediateNodeConverter:
|
||||
|
||||
variable_input_indices = [
|
||||
idx
|
||||
for idx, inp in enumerate(self.opgraph.get_ordered_preds(self.node))
|
||||
for idx, inp in enumerate(self.op_graph.get_ordered_preds(self.node))
|
||||
if not isinstance(inp, Constant)
|
||||
]
|
||||
if len(variable_input_indices) != 1: # pragma: no cover
|
||||
|
||||
@@ -66,7 +66,7 @@ class IntermediateNode(ABC):
|
||||
self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])]
|
||||
|
||||
def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str:
|
||||
"""Get the formatted node (used in formatting opgraph).
|
||||
"""Get the formatted node (used in formatting operation graphs).
|
||||
|
||||
Args:
|
||||
predecessors (List[str]): predecessor names to this node
|
||||
@@ -80,7 +80,7 @@ class IntermediateNode(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def text_for_drawing(self) -> str:
|
||||
"""Get the label of the node (used in drawing opgraph).
|
||||
"""Get the label of the node (used in drawing operation graphs).
|
||||
|
||||
Returns:
|
||||
str: the label of the node
|
||||
|
||||
@@ -71,14 +71,14 @@ class NPMLIRConverter(OPGraphConverter):
|
||||
"""Numpy-specific MLIR converter."""
|
||||
|
||||
@staticmethod
|
||||
def _generate_additional_info_dict(opgraph: OPGraph) -> Dict[str, Any]:
|
||||
def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]:
|
||||
additional_conversion_info = {}
|
||||
|
||||
# Disable numpy warnings during conversion to avoid issues during TLU generation
|
||||
with numpy.errstate(all="ignore"):
|
||||
additional_conversion_info["tables"] = {
|
||||
node: generate_deduplicated_tables(node, opgraph.get_ordered_preds(node))
|
||||
for node in opgraph.graph.nodes()
|
||||
node: generate_deduplicated_tables(node, op_graph.get_ordered_preds(node))
|
||||
for node in op_graph.graph.nodes()
|
||||
if isinstance(node, GenericFunction)
|
||||
}
|
||||
|
||||
|
||||
@@ -12,14 +12,14 @@ def test_format_operation_graph_with_multiple_edges(default_compilation_configur
|
||||
def function(x):
|
||||
return x + x
|
||||
|
||||
opgraph = compile_numpy_function_into_op_graph(
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(4, True))},
|
||||
[(i,) for i in range(0, 10)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
formatted_graph = format_operation_graph(opgraph)
|
||||
formatted_graph = format_operation_graph(op_graph)
|
||||
assert (
|
||||
formatted_graph
|
||||
== """
|
||||
@@ -38,15 +38,15 @@ def test_format_operation_graph_with_offending_nodes(default_compilation_configu
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
opgraph = compile_numpy_function_into_op_graph(
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(7, True))},
|
||||
[(i,) for i in range(-5, 5)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]}
|
||||
formatted_graph = format_operation_graph(opgraph, highlighted_nodes=highlighted_nodes).strip()
|
||||
highlighted_nodes = {op_graph.input_nodes[0]: ["foo"]}
|
||||
formatted_graph = format_operation_graph(op_graph, highlighted_nodes=highlighted_nodes).strip()
|
||||
assert (
|
||||
formatted_graph
|
||||
== """
|
||||
@@ -60,8 +60,8 @@ return %2
|
||||
""".strip()
|
||||
)
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]}
|
||||
formatted_graph = format_operation_graph(opgraph, highlighted_nodes=highlighted_nodes).strip()
|
||||
highlighted_nodes = {op_graph.input_nodes[0]: ["foo"], op_graph.output_nodes[0]: ["bar", "baz"]}
|
||||
formatted_graph = format_operation_graph(op_graph, highlighted_nodes=highlighted_nodes).strip()
|
||||
assert (
|
||||
formatted_graph
|
||||
== """
|
||||
|
||||
@@ -17,7 +17,7 @@ def test_circuit_str(default_compilation_configuration):
|
||||
inputset = [(i,) for i in range(2 ** 3)]
|
||||
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
|
||||
|
||||
assert str(circuit) == format_operation_graph(circuit.opgraph)
|
||||
assert str(circuit) == format_operation_graph(circuit.op_graph)
|
||||
|
||||
|
||||
def test_circuit_draw(default_compilation_configuration):
|
||||
@@ -31,8 +31,8 @@ def test_circuit_draw(default_compilation_configuration):
|
||||
inputset = [(i,) for i in range(2 ** 3)]
|
||||
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
|
||||
|
||||
assert filecmp.cmp(circuit.draw(), draw_graph(circuit.opgraph))
|
||||
assert filecmp.cmp(circuit.draw(vertical=False), draw_graph(circuit.opgraph, vertical=False))
|
||||
assert filecmp.cmp(circuit.draw(), draw_graph(circuit.op_graph))
|
||||
assert filecmp.cmp(circuit.draw(vertical=False), draw_graph(circuit.op_graph, vertical=False))
|
||||
|
||||
|
||||
def test_circuit_run(default_compilation_configuration):
|
||||
|
||||
@@ -347,15 +347,15 @@ def test_constant_indexing(
|
||||
for _ in range(10)
|
||||
]
|
||||
|
||||
opgraph = compile_numpy_function_into_op_graph(
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
function_with_indexing,
|
||||
{"x": input_value},
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
assert len(opgraph.output_nodes) == 1
|
||||
output_node = opgraph.output_nodes[0]
|
||||
assert len(op_graph.output_nodes) == 1
|
||||
output_node = op_graph.output_nodes[0]
|
||||
|
||||
assert len(output_node.outputs) == 1
|
||||
assert output_value == output_node.outputs[0]
|
||||
@@ -525,15 +525,15 @@ def test_constant_indexing_with_numpy_integers(
|
||||
for _ in range(10)
|
||||
]
|
||||
|
||||
opgraph = compile_numpy_function_into_op_graph(
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
function_with_indexing,
|
||||
{"x": input_value},
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
assert len(opgraph.output_nodes) == 1
|
||||
output_node = opgraph.output_nodes[0]
|
||||
assert len(op_graph.output_nodes) == 1
|
||||
output_node = op_graph.output_nodes[0]
|
||||
|
||||
assert len(output_node.outputs) == 1
|
||||
assert output_value == output_node.outputs[0]
|
||||
|
||||
Reference in New Issue
Block a user