refactor: rename get_ordered_inputs_of to be more in line with what it does

- rename to get_ordered_preds_and_inputs_of
This commit is contained in:
Arthur Meyre
2021-11-25 09:38:06 +01:00
parent 39557720ca
commit 7909a4899f
3 changed files with 7 additions and 5 deletions

View File

@@ -61,7 +61,7 @@ def format_operation_graph(
# extract predecessors and their ids
predecessors = []
for predecessor, output_idx in op_graph.get_ordered_inputs_of(node):
for predecessor, output_idx in op_graph.get_ordered_preds_and_inputs_of(node):
predecessors.append(f"%{id_map[(predecessor, output_idx)]}")
# start the build the line for the node

View File

@@ -134,7 +134,7 @@ class IntermediateNodeConverter:
if isinstance(self.node, (Add, Mul, Sub, Dot)):
if self.one_of_the_inputs_is_a_tensor and not self.all_of_the_inputs_are_tensors:
to_be_converted = []
for (pred, output) in self.op_graph.get_ordered_inputs_of(self.node):
for (pred, output) in self.op_graph.get_ordered_preds_and_inputs_of(self.node):
inp = pred.outputs[output]
if isinstance(inp, TensorValue) and inp.is_scalar:
to_be_converted.append(self.nodes_to_mlir_names[pred])

View File

@@ -131,14 +131,16 @@ class OPGraph:
idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values())
return [idx_to_pred[i] for i in range(len(idx_to_pred))]
def get_ordered_inputs_of(self, node: IntermediateNode) -> List[Tuple[IntermediateNode, int]]:
"""Get node inputs ordered by their indices.
def get_ordered_preds_and_inputs_of(
self, node: IntermediateNode
) -> List[Tuple[IntermediateNode, int]]:
"""Get node preds and inputs ordered by their indices.
Args:
node (IntermediateNode): the node for which we want the ordered inputs
Returns:
List[Tuple[IntermediateNode, int]]: the ordered list of inputs
List[Tuple[IntermediateNode, int]]: the ordered list of preds and inputs
"""
idx_to_inp: Dict[int, Tuple[IntermediateNode, int]] = {}