From b971f6b913833372acb95e97bd99667a03dad01d Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Thu, 23 Sep 2021 10:50:05 +0200 Subject: [PATCH] refactor: remove as ir imports to avoid breaking sphinx docs links --- concrete/common/common_helpers.py | 10 ++-- concrete/common/compilation/artifacts.py | 8 +-- concrete/common/debugging/drawing.py | 28 +++++++---- concrete/common/debugging/printing.py | 8 +-- concrete/common/extensions/table.py | 4 +- concrete/common/mlir/converters.py | 14 +++--- concrete/common/mlir/mlir_converter.py | 4 +- concrete/common/operator_graph.py | 52 +++++++++---------- concrete/common/optimization/topological.py | 56 ++++++++++----------- concrete/common/tracing/base_tracer.py | 25 +++++---- concrete/common/tracing/tracing_helpers.py | 4 +- concrete/numpy/compile.py | 4 +- 12 files changed, 115 insertions(+), 102 deletions(-) diff --git a/concrete/common/common_helpers.py b/concrete/common/common_helpers.py index 53b0989f8..7aa55eda4 100644 --- a/concrete/common/common_helpers.py +++ b/concrete/common/common_helpers.py @@ -5,7 +5,7 @@ from typing import List, Optional from .data_types.integers import Integer from .debugging import custom_assert from .operator_graph import OPGraph -from .representation import intermediate as ir +from .representation.intermediate import IntermediateNode def is_a_power_of_2(x: int) -> bool: @@ -22,11 +22,11 @@ def is_a_power_of_2(x: int) -> bool: return x > 0 and (x & (x - 1)) == 0 -def ir_nodes_has_integer_input_and_output(node: ir.IntermediateNode) -> bool: +def ir_nodes_has_integer_input_and_output(node: IntermediateNode) -> bool: """Check if an ir node has Integer inputs and outputs. Args: - node (ir.IntermediateNode): Node to check + node (IntermediateNode): Node to check Returns: bool: True if all input and output values hold Integers @@ -40,13 +40,13 @@ def ir_nodes_has_integer_input_and_output(node: ir.IntermediateNode) -> bool: # long run probably def check_op_graph_is_integer_program( op_graph: OPGraph, - offending_nodes_out: Optional[List[ir.IntermediateNode]] = None, + offending_nodes_out: Optional[List[IntermediateNode]] = None, ) -> bool: """Check if an op_graph inputs, outputs and intermediate values are Integers. Args: op_graph (OPGraph): The OPGraph to check - offending_nodes_out (Optional[List[ir.IntermediateNode]]): Optionally pass a list that will + offending_nodes_out (Optional[List[IntermediateNode]]): Optionally pass a list that will be populated with offending nodes, the list will be cleared before being filled Returns: diff --git a/concrete/common/compilation/artifacts.py b/concrete/common/compilation/artifacts.py index c1af6be45..88746a10d 100644 --- a/concrete/common/compilation/artifacts.py +++ b/concrete/common/compilation/artifacts.py @@ -12,7 +12,7 @@ from PIL import Image from ..debugging import custom_assert, draw_graph, get_printable_graph from ..operator_graph import OPGraph -from ..representation import intermediate as ir +from ..representation.intermediate import IntermediateNode from ..values import BaseValue DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts") @@ -30,7 +30,7 @@ class CompilationArtifacts: textual_representations_of_operation_graphs: Dict[str, str] final_operation_graph: Optional[OPGraph] - bounds_of_the_final_operation_graph: Optional[Dict[ir.IntermediateNode, Dict[str, Any]]] + bounds_of_the_final_operation_graph: Optional[Dict[IntermediateNode, Dict[str, Any]]] mlir_of_the_final_operation_graph: Optional[str] def __init__(self, output_directory: Path = DEFAULT_OUTPUT_DIRECTORY): @@ -92,11 +92,11 @@ class CompilationArtifacts: self.final_operation_graph = operation_graph - def add_final_operation_graph_bounds(self, bounds: Dict[ir.IntermediateNode, Dict[str, Any]]): + def add_final_operation_graph_bounds(self, bounds: Dict[IntermediateNode, Dict[str, Any]]): """Add the bounds of the final operation graph to the artifacts. Args: - bounds (Dict[ir.IntermediateNode, Dict[str, Any]]): the bound dictionary + bounds (Dict[IntermediateNode, Dict[str, Any]]): the bound dictionary Returns: None diff --git a/concrete/common/debugging/drawing.py b/concrete/common/debugging/drawing.py index cb088c24b..c06115996 100644 --- a/concrete/common/debugging/drawing.py +++ b/concrete/common/debugging/drawing.py @@ -11,17 +11,25 @@ from PIL import Image from ..debugging.custom_assert import custom_assert from ..operator_graph import OPGraph -from ..representation import intermediate as ir -from ..representation.intermediate import ALL_IR_NODES +from ..representation.intermediate import ( + ALL_IR_NODES, + Add, + ArbitraryFunction, + Constant, + Dot, + Input, + Mul, + Sub, +) IR_NODE_COLOR_MAPPING = { - ir.Input: "blue", - ir.Constant: "cyan", - ir.Add: "red", - ir.Sub: "yellow", - ir.Mul: "green", - ir.ArbitraryFunction: "orange", - ir.Dot: "purple", + Input: "blue", + Constant: "cyan", + Add: "red", + Sub: "yellow", + Mul: "green", + ArbitraryFunction: "orange", + Dot: "purple", "ArbitraryFunction": "orange", "TLU": "grey", "output": "magenta", @@ -63,7 +71,7 @@ def draw_graph( value_to_return = IR_NODE_COLOR_MAPPING[type(node)] if node in output_nodes: value_to_return = IR_NODE_COLOR_MAPPING["output"] - elif isinstance(node, ir.ArbitraryFunction): + elif isinstance(node, ArbitraryFunction): value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return) return value_to_return diff --git a/concrete/common/debugging/printing.py b/concrete/common/debugging/printing.py index 3b3123f82..f82fb8203 100644 --- a/concrete/common/debugging/printing.py +++ b/concrete/common/debugging/printing.py @@ -6,7 +6,7 @@ import networkx as nx from ..debugging.custom_assert import custom_assert from ..operator_graph import OPGraph -from ..representation import intermediate as ir +from ..representation.intermediate import ArbitraryFunction, Constant, Input def output_data_type_to_string(node): @@ -49,15 +49,15 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: # they only are done by incrementing i custom_assert(len(node.outputs) == 1) - if isinstance(node, ir.Input): + if isinstance(node, Input): what_to_print = node.input_name - elif isinstance(node, ir.Constant): + elif isinstance(node, Constant): what_to_print = f"Constant({node.constant_data})" else: base_name = node.__class__.__name__ - if isinstance(node, ir.ArbitraryFunction): + if isinstance(node, ArbitraryFunction): base_name = node.op_name what_to_print = base_name + "(" diff --git a/concrete/common/extensions/table.py b/concrete/common/extensions/table.py index 8fc3eac87..971a4309f 100644 --- a/concrete/common/extensions/table.py +++ b/concrete/common/extensions/table.py @@ -6,7 +6,7 @@ from typing import Iterable, Tuple, Union from ..common_helpers import is_a_power_of_2 from ..data_types.base import BaseDataType from ..data_types.integers import make_integer_to_hold -from ..representation import intermediate as ir +from ..representation.intermediate import ArbitraryFunction from ..tracing.base_tracer import BaseTracer @@ -35,7 +35,7 @@ class LookupTable: # we need to create an `ArbitraryFunction` node # because the result will be determined during the runtime if isinstance(key, BaseTracer): - traced_computation = ir.ArbitraryFunction( + traced_computation = ArbitraryFunction( input_base_value=key.output, arbitrary_func=LookupTable._checked_indexing, output_dtype=self.output_dtype, diff --git a/concrete/common/mlir/converters.py b/concrete/common/mlir/converters.py index 746013839..b9bc1452a 100644 --- a/concrete/common/mlir/converters.py +++ b/concrete/common/mlir/converters.py @@ -22,7 +22,7 @@ from ..data_types.dtypes_helpers import ( value_is_encrypted_tensor_integer, ) from ..debugging.custom_assert import custom_assert -from ..representation import intermediate as ir +from ..representation.intermediate import Add, ArbitraryFunction, Constant, Dot, Mul, Sub def add(node, preds, ir_to_mlir_node, ctx): @@ -189,12 +189,12 @@ def dot(node, preds, ir_to_mlir_node, ctx): V0_OPSET_CONVERSION_FUNCTIONS = { - ir.Add: add, - ir.Sub: sub, - ir.Mul: mul, - ir.Constant: constant, - ir.ArbitraryFunction: apply_lut, - ir.Dot: dot, + Add: add, + Sub: sub, + Mul: mul, + Constant: constant, + ArbitraryFunction: apply_lut, + Dot: dot, } # pylint: enable=no-name-in-module,no-member diff --git a/concrete/common/mlir/mlir_converter.py b/concrete/common/mlir/mlir_converter.py index 7273268a3..aec6b7906 100644 --- a/concrete/common/mlir/mlir_converter.py +++ b/concrete/common/mlir/mlir_converter.py @@ -20,7 +20,7 @@ from ..data_types.dtypes_helpers import ( ) from ..debugging.custom_assert import custom_assert from ..operator_graph import OPGraph -from ..representation import intermediate as ir +from ..representation.intermediate import Input class MLIRConverter: @@ -151,7 +151,7 @@ class MLIRConverter: for arg_num, node in op_graph.input_nodes.items(): ir_to_mlir_node[node] = arg[arg_num] for node in nx.topological_sort(op_graph.graph): - if isinstance(node, ir.Input): + if isinstance(node, Input): continue mlir_op = self.conversion_functions.get(type(node), None) if mlir_op is None: # pragma: no cover diff --git a/concrete/common/operator_graph.py b/concrete/common/operator_graph.py index 1c5ee104e..7d8464cb9 100644 --- a/concrete/common/operator_graph.py +++ b/concrete/common/operator_graph.py @@ -13,7 +13,7 @@ from .data_types.dtypes_helpers import ( from .data_types.floats import Float from .data_types.integers import Integer, make_integer_to_hold from .debugging.custom_assert import custom_assert -from .representation import intermediate as ir +from .representation.intermediate import Input, IntermediateNode from .tracing import BaseTracer from .tracing.tracing_helpers import create_graph_from_output_tracers @@ -22,25 +22,25 @@ class OPGraph: """Class to make work with nx graphs easier.""" graph: nx.MultiDiGraph - input_nodes: Dict[int, ir.Input] - output_nodes: Dict[int, ir.IntermediateNode] + input_nodes: Dict[int, Input] + output_nodes: Dict[int, IntermediateNode] def __init__( self, graph: nx.MultiDiGraph, - input_nodes: Dict[int, ir.Input], - output_nodes: Dict[int, ir.IntermediateNode], + input_nodes: Dict[int, Input], + output_nodes: Dict[int, IntermediateNode], ) -> None: custom_assert( len(input_nodes) > 0, "Got a graph without input nodes which is not supported" ) custom_assert( - all(isinstance(node, ir.Input) for node in input_nodes.values()), - "Got input nodes that were not ir.Input, which is not supported", + all(isinstance(node, Input) for node in input_nodes.values()), + "Got input nodes that were not Input, which is not supported", ) custom_assert( - all(isinstance(node, ir.IntermediateNode) for node in output_nodes.values()), - "Got output nodes which were not ir.IntermediateNode, which is not supported", + all(isinstance(node, IntermediateNode) for node in output_nodes.values()), + "Got output nodes which were not IntermediateNode, which is not supported", ) self.graph = graph @@ -75,7 +75,7 @@ class OPGraph: input_nodes = { node.program_input_idx: node for node in graph.nodes() - if len(graph.pred[node]) == 0 and isinstance(node, ir.Input) + if len(graph.pred[node]) == 0 and isinstance(node, Input) } output_nodes = { output_idx: tracer.traced_computation @@ -86,50 +86,50 @@ class OPGraph: @staticmethod def from_graph( graph: nx.MultiDiGraph, - input_nodes: Iterable[ir.Input], - output_nodes: Iterable[ir.IntermediateNode], + input_nodes: Iterable[Input], + output_nodes: Iterable[IntermediateNode], ) -> "OPGraph": """Construct OPGraph from an existing networkx MultiDiGraph. Args: graph (nx.MultiDiGraph): The networkx MultiDiGraph to use. - input_nodes (Iterable[ir.Input]): The input nodes of the MultiDiGraph. - output_nodes (Iterable[ir.IntermediateNode]): The output nodes of the MultiDiGraph. + input_nodes (Iterable[Input]): The input nodes of the MultiDiGraph. + output_nodes (Iterable[IntermediateNode]): The output nodes of the MultiDiGraph. Returns: OPGraph: The resulting OPGraph. """ return OPGraph(graph, dict(enumerate(input_nodes)), dict(enumerate(output_nodes))) - def get_ordered_inputs(self) -> List[ir.Input]: + def get_ordered_inputs(self) -> List[Input]: """Get the input nodes of the graph, ordered by their index. Returns: - List[ir.Input]: ordered input nodes + List[Input]: ordered input nodes """ return [self.input_nodes[idx] for idx in range(len(self.input_nodes))] - def get_ordered_outputs(self) -> List[ir.IntermediateNode]: + def get_ordered_outputs(self) -> List[IntermediateNode]: """Get the output nodes of the graph, ordered by their index. Returns: - List[ir.IntermediateNode]: ordered input nodes + List[IntermediateNode]: ordered input nodes """ return [self.output_nodes[idx] for idx in range(len(self.output_nodes))] - def evaluate(self, inputs: Dict[int, Any]) -> Dict[ir.IntermediateNode, Any]: + def evaluate(self, inputs: Dict[int, Any]) -> Dict[IntermediateNode, Any]: """Evaluate a graph and get intermediate values for all nodes. Args: inputs (Dict[int, Any]): The inputs to the program Returns: - Dict[ir.IntermediateNode, Any]: Dictionary with node as keys and resulting values + Dict[IntermediateNode, Any]: Dictionary with node as keys and resulting values """ - node_results: Dict[ir.IntermediateNode, Any] = {} + node_results: Dict[IntermediateNode, Any] = {} for node in nx.topological_sort(self.graph): - if not isinstance(node, ir.Input): + if not isinstance(node, Input): curr_inputs = {} for pred_node in self.graph.pred[node]: edges = self.graph.get_edge_data(pred_node, node) @@ -168,7 +168,7 @@ class OPGraph: callback function to determine the type constructor of the data encountered while updating the graph bounds. Defaults to get_type_constructor_python_constant_data. """ - node: ir.IntermediateNode + node: IntermediateNode for node in self.graph.nodes(): current_node_bounds = node_bounds[node] @@ -193,7 +193,7 @@ class OPGraph: data_type_constructor = max_data_type_constructor - if not isinstance(node, ir.Input): + if not isinstance(node, Input): for output_value in node.outputs: if isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer): output_value.data_type = make_integer_to_hold( @@ -242,9 +242,9 @@ class OPGraph: """Remove unreachable nodes from outputs.""" current_nodes = set(self.output_nodes.values()) - useful_nodes: Set[ir.IntermediateNode] = set() + useful_nodes: Set[IntermediateNode] = set() while current_nodes: - next_nodes: Set[ir.IntermediateNode] = set() + next_nodes: Set[IntermediateNode] = set() useful_nodes.update(current_nodes) for node in current_nodes: next_nodes.update(self.graph.pred[node]) diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py index 7646c38b3..18a7f75e4 100644 --- a/concrete/common/optimization/topological.py +++ b/concrete/common/optimization/topological.py @@ -9,7 +9,7 @@ from ..data_types.floats import Float from ..data_types.integers import Integer from ..debugging.custom_assert import custom_assert from ..operator_graph import OPGraph -from ..representation import intermediate as ir +from ..representation.intermediate import ArbitraryFunction, Constant, Input, IntermediateNode def fuse_float_operations( @@ -26,7 +26,7 @@ def fuse_float_operations( """ nx_graph = op_graph.graph - processed_terminal_nodes: Set[ir.IntermediateNode] = set() + processed_terminal_nodes: Set[IntermediateNode] = set() number_of_fuse = 0 while True: float_subgraph_search_result = find_float_subgraph_with_unique_terminal_node( @@ -56,7 +56,7 @@ def fuse_float_operations( if terminal_node in op_graph.output_nodes.values(): # Output value replace it # As the graph changes recreate the output_node_to_idx dict - output_node_to_idx: Dict[ir.IntermediateNode, List[int]] = { + output_node_to_idx: Dict[IntermediateNode, List[int]] = { out_node: [] for out_node in op_graph.output_nodes.values() } for output_idx, output_node in op_graph.output_nodes.items(): @@ -87,21 +87,21 @@ def fuse_float_operations( def convert_float_subgraph_to_fused_node( op_graph: OPGraph, - float_subgraph_start_nodes: Set[ir.IntermediateNode], - terminal_node: ir.IntermediateNode, - subgraph_all_nodes: Set[ir.IntermediateNode], -) -> Optional[Tuple[ir.ArbitraryFunction, ir.IntermediateNode]]: + float_subgraph_start_nodes: Set[IntermediateNode], + terminal_node: IntermediateNode, + subgraph_all_nodes: Set[IntermediateNode], +) -> Optional[Tuple[ArbitraryFunction, IntermediateNode]]: """Convert a float subgraph to an equivalent fused ArbitraryFunction node. Args: op_graph (OPGraph): The OPGraph the float subgraph is part of. - float_subgraph_start_nodes (Set[ir.IntermediateNode]): The nodes starting the float subgraph + float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the float subgraph in `op_graph`. - terminal_node (ir.IntermediateNode): The node ending the float subgraph. - subgraph_all_nodes (Set[ir.IntermediateNode]): All the nodes in the float subgraph. + terminal_node (IntermediateNode): The node ending the float subgraph. + subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph. Returns: - Optional[Tuple[ir.ArbitraryFunction, ir.IntermediateNode]]: None if the float subgraph + Optional[Tuple[ArbitraryFunction, IntermediateNode]]: None if the float subgraph cannot be fused, otherwise returns a tuple containing the fused node and the node whose output must be plugged as the input to the subgraph. """ @@ -111,7 +111,7 @@ def convert_float_subgraph_to_fused_node( # Only one variable input node, find which node feeds its input non_constant_start_nodes = [ - node for node in float_subgraph_start_nodes if not isinstance(node, ir.Constant) + node for node in float_subgraph_start_nodes if not isinstance(node, Constant) ] custom_assert(len(non_constant_start_nodes) == 1) @@ -126,7 +126,7 @@ def convert_float_subgraph_to_fused_node( float_subgraph = nx.MultiDiGraph(nx_graph.subgraph(subgraph_all_nodes)) - new_subgraph_variable_input = ir.Input(new_input_value, "float_subgraph_input", 0) + new_subgraph_variable_input = Input(new_input_value, "float_subgraph_input", 0) float_subgraph.add_node(new_subgraph_variable_input) for node_after_input in nodes_after_input_set: @@ -155,7 +155,7 @@ def convert_float_subgraph_to_fused_node( ) # Create fused_node - fused_node = ir.ArbitraryFunction( + fused_node = ArbitraryFunction( deepcopy(new_subgraph_variable_input.inputs[0]), lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate({0: x})[ terminal_node @@ -176,8 +176,8 @@ def convert_float_subgraph_to_fused_node( def find_float_subgraph_with_unique_terminal_node( nx_graph: nx.MultiDiGraph, - processed_terminal_nodes: Set[ir.IntermediateNode], -) -> Optional[Tuple[Set[ir.IntermediateNode], ir.IntermediateNode, Set[ir.IntermediateNode]]]: + processed_terminal_nodes: Set[IntermediateNode], +) -> Optional[Tuple[Set[IntermediateNode], IntermediateNode, Set[IntermediateNode]]]: """Find a subgraph of the graph with float computations. The subgraph has a single terminal node with a single Integer output and has a single variable @@ -185,24 +185,24 @@ def find_float_subgraph_with_unique_terminal_node( Args: nx_graph (nx.MultiDiGraph): The networkx graph to search in. - processed_terminal_nodes (Set[ir.IntermediateNode]): The set of terminal nodes for which + processed_terminal_nodes (Set[IntermediateNode]): The set of terminal nodes for which subgraphs have already been searched, those will be skipped. Returns: - Optional[Tuple[Set[ir.IntermediateNode], ir.IntermediateNode, Set[ir.IntermediateNode]]]: + Optional[Tuple[Set[IntermediateNode], IntermediateNode, Set[IntermediateNode]]]: None if there are no float subgraphs to process in `nx_graph`. Otherwise returns a tuple containing the set of nodes beginning a float subgraph, the terminal node of the subgraph and the set of all the nodes in the subgraph. """ - def is_float_to_single_int_node(node: ir.IntermediateNode) -> bool: + def is_float_to_single_int_node(node: IntermediateNode) -> bool: return ( any(isinstance(input_.data_type, Float) for input_ in node.inputs) and len(node.outputs) == 1 and isinstance(node.outputs[0].data_type, Integer) ) - def single_int_output_node(node: ir.IntermediateNode) -> bool: + def single_int_output_node(node: IntermediateNode) -> bool: return len(node.outputs) == 1 and isinstance(node.outputs[0].data_type, Integer) float_subgraphs_terminal_nodes = ( @@ -211,7 +211,7 @@ def find_float_subgraph_with_unique_terminal_node( if is_float_to_single_int_node(node) and node not in processed_terminal_nodes ) - terminal_node: ir.IntermediateNode + terminal_node: IntermediateNode try: terminal_node = next(float_subgraphs_terminal_nodes) @@ -220,10 +220,10 @@ def find_float_subgraph_with_unique_terminal_node( # Use dict as ordered set current_nodes = {terminal_node: None} - float_subgraph_start_nodes: Set[ir.IntermediateNode] = set() - subgraph_all_nodes: Set[ir.IntermediateNode] = set() + float_subgraph_start_nodes: Set[IntermediateNode] = set() + subgraph_all_nodes: Set[IntermediateNode] = set() while current_nodes: - next_nodes: Dict[ir.IntermediateNode, None] = {} + next_nodes: Dict[IntermediateNode, None] = {} for node in current_nodes: subgraph_all_nodes.add(node) predecessors = nx_graph.pred[node] @@ -240,16 +240,16 @@ def find_float_subgraph_with_unique_terminal_node( def subgraph_has_unique_variable_input( - float_subgraph_start_nodes: Set[ir.IntermediateNode], + float_subgraph_start_nodes: Set[IntermediateNode], ) -> bool: """Check that only one of the nodes starting the subgraph is variable. Args: - float_subgraph_start_nodes (Set[ir.IntermediateNode]): The nodes starting the subgraph. + float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the subgraph. Returns: - bool: True if only one of the nodes is not an ir.Constant + bool: True if only one of the nodes is not an Constant """ # Only one input to the subgraph where computations are done in floats is variable, this # is the only case we can manage with ArbitraryFunction fusing - return sum(not isinstance(node, ir.Constant) for node in float_subgraph_start_nodes) == 1 + return sum(not isinstance(node, Constant) for node in float_subgraph_start_nodes) == 1 diff --git a/concrete/common/tracing/base_tracer.py b/concrete/common/tracing/base_tracer.py index 094c8f80d..1bd9ad747 100644 --- a/concrete/common/tracing/base_tracer.py +++ b/concrete/common/tracing/base_tracer.py @@ -4,8 +4,13 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Iterable, List, Tuple, Type, Union from ..debugging.custom_assert import custom_assert -from ..representation import intermediate as ir -from ..representation.intermediate import IR_MIX_VALUES_FUNC_ARG_NAME +from ..representation.intermediate import ( + IR_MIX_VALUES_FUNC_ARG_NAME, + Add, + IntermediateNode, + Mul, + Sub, +) from ..values import BaseValue @@ -13,14 +18,14 @@ class BaseTracer(ABC): """Base class for implementing tracers.""" inputs: List["BaseTracer"] - traced_computation: ir.IntermediateNode + traced_computation: IntermediateNode output: BaseValue _mix_values_func: Callable[..., BaseValue] def __init__( self, inputs: Iterable["BaseTracer"], - traced_computation: ir.IntermediateNode, + traced_computation: IntermediateNode, output_index: int, ) -> None: self.inputs = list(inputs) @@ -62,14 +67,14 @@ class BaseTracer(ABC): def instantiate_output_tracers( self, inputs: Iterable[Union["BaseTracer", Any]], - computation_to_trace: Type[ir.IntermediateNode], + computation_to_trace: Type[IntermediateNode], ) -> Tuple["BaseTracer", ...]: """Instantiate all output BaseTracer for a given computation. Args: inputs (Iterable[Union[BaseTracer, Any]]): Previous BaseTracer or data used as inputs for a new node. - computation_to_trace (Type[ir.IntermediateNode]): The IntermediateNode class + computation_to_trace (Type[IntermediateNode]): The IntermediateNode class to instantiate for the computation being traced Returns: @@ -103,7 +108,7 @@ class BaseTracer(ABC): result_tracer = self.instantiate_output_tracers( [self, other], - ir.Add, + Add, ) custom_assert(len(result_tracer) == 1) @@ -120,7 +125,7 @@ class BaseTracer(ABC): result_tracer = self.instantiate_output_tracers( [self, other], - ir.Sub, + Sub, ) custom_assert(len(result_tracer) == 1) @@ -132,7 +137,7 @@ class BaseTracer(ABC): result_tracer = self.instantiate_output_tracers( [other, self], - ir.Sub, + Sub, ) custom_assert(len(result_tracer) == 1) @@ -144,7 +149,7 @@ class BaseTracer(ABC): result_tracer = self.instantiate_output_tracers( [self, other], - ir.Mul, + Mul, ) custom_assert(len(result_tracer) == 1) diff --git a/concrete/common/tracing/tracing_helpers.py b/concrete/common/tracing/tracing_helpers.py index 712926270..a2c894312 100644 --- a/concrete/common/tracing/tracing_helpers.py +++ b/concrete/common/tracing/tracing_helpers.py @@ -7,7 +7,7 @@ import networkx as nx from networkx.algorithms.dag import is_directed_acyclic_graph from ..debugging.custom_assert import custom_assert -from ..representation import intermediate as ir +from ..representation.intermediate import Input from ..values import BaseValue from .base_tracer import BaseTracer @@ -50,7 +50,7 @@ def make_input_tracer( Returns: BaseTracer: The BaseTracer for that input value """ - return tracer_class([], ir.Input(input_value, input_name, input_idx), 0) + return tracer_class([], Input(input_value, input_name, input_idx), 0) def prepare_function_parameters( diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index ea4d3925e..fd1074083 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -19,7 +19,7 @@ from ..common.mlir.utils import ( ) from ..common.operator_graph import OPGraph from ..common.optimization.topological import fuse_float_operations -from ..common.representation import intermediate as ir +from ..common.representation.intermediate import IntermediateNode from ..common.values import BaseValue from ..numpy.tracing import trace_numpy_function from .np_dtypes_helpers import ( @@ -99,7 +99,7 @@ def _compile_numpy_function_into_op_graph_internal( fuse_float_operations(op_graph, compilation_artifacts) # TODO: To be removed once we support more than integers - offending_non_integer_nodes: List[ir.IntermediateNode] = [] + offending_non_integer_nodes: List[IntermediateNode] = [] op_grap_is_int_prog = check_op_graph_is_integer_program(op_graph, offending_non_integer_nodes) if not op_grap_is_int_prog: raise ValueError(