dev(ordered-inputs): update code to keep the input index in Input IR nodes

- this input index will be useful for MLIR/lower level conversions
- it represents the input index of an Input node when considering the
traced function signature
- update code preparing function parameters to keep signature order
This commit is contained in:
Arthur Meyre
2021-07-28 11:19:43 +02:00
parent 5e5d0477b1
commit d739e6672d
6 changed files with 59 additions and 16 deletions

View File

@@ -100,13 +100,16 @@ class Input(IntermediateNode):
"""Node representing an input of the numpy program"""
input_name: str
program_input_idx: int
def __init__(
self,
input_value: BaseValue,
input_name: str,
program_input_idx: int,
) -> None:
super().__init__((input_value,))
assert len(self.inputs) == 1
self.input_name = input_name
self.program_input_idx = program_input_idx
self.outputs = [deepcopy(self.inputs[0])]

View File

@@ -3,5 +3,6 @@ from .base_tracer import BaseTracer
from .tracing_helpers import (
create_graph_from_output_tracers,
make_input_tracer,
make_input_tracers,
prepare_function_parameters,
)

View File

@@ -1,6 +1,7 @@
"""Helper functions for tracing"""
import collections
from inspect import signature
from typing import Callable, Dict, Iterable, Set, Tuple, Type
from typing import Callable, Dict, Iterable, OrderedDict, Set, Tuple, Type
import networkx as nx
from networkx.algorithms.dag import is_directed_acyclic_graph
@@ -10,9 +11,30 @@ from ..representation import intermediate as ir
from .base_tracer import BaseTracer
def make_input_tracers(
tracer_class: Type[BaseTracer],
function_parameters: OrderedDict[str, BaseValue],
) -> OrderedDict[str, BaseTracer]:
"""Helper function to create tracers for a function's parameters
Args:
tracer_class (Type[BaseTracer]): the class of tracer to create an Input for
function_parameters (OrderedDict[str, BaseValue]): the dictionary with the parameters names
and corresponding Values
Returns:
OrderedDict[str, BaseTracer]: the dictionary containing the Input Tracers for each parameter
"""
return collections.OrderedDict(
(param_name, make_input_tracer(tracer_class, param_name, input_idx, param))
for input_idx, (param_name, param) in enumerate(function_parameters.items())
)
def make_input_tracer(
tracer_class: Type[BaseTracer],
input_name: str,
input_idx: int,
input_value: BaseValue,
) -> BaseTracer:
"""Helper function to create a tracer for an input value
@@ -20,18 +42,19 @@ def make_input_tracer(
Args:
tracer_class (Type[BaseTracer]): the class of tracer to create an Input for
input_name (str): the name of the input in the traced function
input_idx (int): the input index in the function parameters
input_value (BaseValue): the Value that is an input and needs to be wrapped in an
BaseTracer
Returns:
BaseTracer: The BaseTracer for that input value
"""
return tracer_class([], ir.Input(input_value, input_name), 0)
return tracer_class([], ir.Input(input_value, input_name, input_idx), 0)
def prepare_function_parameters(
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
) -> Dict[str, BaseValue]:
) -> OrderedDict[str, BaseValue]:
"""Function to filter the passed function_parameters to trace function_to_trace
Args:
@@ -42,7 +65,7 @@ def prepare_function_parameters(
ValueError: Raised when some parameters are missing to trace function_to_trace
Returns:
Dict[str, BaseValue]: filtered function_parameters dictionary
OrderedDict[str, BaseValue]: filtered function_parameters dictionary
"""
function_signature = signature(function_to_trace)
@@ -54,10 +77,11 @@ def prepare_function_parameters(
f"that were not provided: {', '.join(sorted(missing_args))}"
)
useless_arguments = function_parameters.keys() - function_signature.parameters.keys()
useful_arguments = function_signature.parameters.keys() - useless_arguments
return {k: function_parameters[k] for k in useful_arguments}
# This convoluted way of creating the dict is to ensure key order is maintained
return collections.OrderedDict(
(param_name, function_parameters[param_name])
for param_name in function_signature.parameters.keys()
)
def create_graph_from_output_tracers(

View File

@@ -7,7 +7,7 @@ from ..common.data_types import BaseValue
from ..common.tracing import (
BaseTracer,
create_graph_from_output_tracers,
make_input_tracer,
make_input_tracers,
prepare_function_parameters,
)
@@ -32,10 +32,7 @@ def trace_numpy_function(
"""
function_parameters = prepare_function_parameters(function_to_trace, function_parameters)
input_tracers = {
param_name: make_input_tracer(NPTracer, param_name, param)
for param_name, param in function_parameters.items()
}
input_tracers = make_input_tracers(NPTracer, function_parameters)
# We could easily create a graph of NPTracer, but we may end up with dead nodes starting from
# the inputs that's why we create the graph starting from the outputs

View File

@@ -1,6 +1,6 @@
"""Test file for HDK's common tracing helpers"""
from typing import Any, Dict
from typing import Any, Dict, List
import pytest
@@ -24,3 +24,21 @@ def test_prepare_function_parameters(
prepared_dict = prepare_function_parameters(function, function_parameters)
assert prepared_dict == ref_dict
@pytest.mark.parametrize(
"function,function_parameters,expected_ordered_keys",
[
(lambda x: None, {"x": None}, ["x"]),
(lambda x, y: None, {"x": None, "y": None}, ["x", "y"]),
(lambda x, y: None, {"y": None, "x": None}, ["x", "y"]),
(lambda z, x, y: None, {"y": None, "z": None, "x": None}, ["z", "x", "y"]),
],
)
def test_prepare_function_parameters_order(
function, function_parameters: Dict[str, Any], expected_ordered_keys: List[str]
):
"""Test prepare_function_parameters output order"""
prepared_dict = prepare_function_parameters(function, function_parameters)
assert list(prepared_dict.keys()) == expected_ordered_keys

View File

@@ -80,8 +80,8 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
ref_graph = nx.MultiDiGraph()
input_x = ir.Input(x, input_name="x")
input_y = ir.Input(y, input_name="y")
input_x = ir.Input(x, input_name="x", program_input_idx=0)
input_y = ir.Input(y, input_name="y", program_input_idx=1)
add_node_z = ir.Add(
(