diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index d18612cfa..9e86ed569 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -100,13 +100,16 @@ class Input(IntermediateNode): """Node representing an input of the numpy program""" input_name: str + program_input_idx: int def __init__( self, input_value: BaseValue, input_name: str, + program_input_idx: int, ) -> None: super().__init__((input_value,)) assert len(self.inputs) == 1 self.input_name = input_name + self.program_input_idx = program_input_idx self.outputs = [deepcopy(self.inputs[0])] diff --git a/hdk/common/tracing/__init__.py b/hdk/common/tracing/__init__.py index 1818cb5d9..f311b529e 100644 --- a/hdk/common/tracing/__init__.py +++ b/hdk/common/tracing/__init__.py @@ -3,5 +3,6 @@ from .base_tracer import BaseTracer from .tracing_helpers import ( create_graph_from_output_tracers, make_input_tracer, + make_input_tracers, prepare_function_parameters, ) diff --git a/hdk/common/tracing/tracing_helpers.py b/hdk/common/tracing/tracing_helpers.py index e7b1dba42..94bede3ec 100644 --- a/hdk/common/tracing/tracing_helpers.py +++ b/hdk/common/tracing/tracing_helpers.py @@ -1,6 +1,7 @@ """Helper functions for tracing""" +import collections from inspect import signature -from typing import Callable, Dict, Iterable, Set, Tuple, Type +from typing import Callable, Dict, Iterable, OrderedDict, Set, Tuple, Type import networkx as nx from networkx.algorithms.dag import is_directed_acyclic_graph @@ -10,9 +11,30 @@ from ..representation import intermediate as ir from .base_tracer import BaseTracer +def make_input_tracers( + tracer_class: Type[BaseTracer], + function_parameters: OrderedDict[str, BaseValue], +) -> OrderedDict[str, BaseTracer]: + """Helper function to create tracers for a function's parameters + + Args: + tracer_class (Type[BaseTracer]): the class of tracer to create an Input for + function_parameters (OrderedDict[str, BaseValue]): the dictionary with the parameters names + and corresponding Values + + Returns: + OrderedDict[str, BaseTracer]: the dictionary containing the Input Tracers for each parameter + """ + return collections.OrderedDict( + (param_name, make_input_tracer(tracer_class, param_name, input_idx, param)) + for input_idx, (param_name, param) in enumerate(function_parameters.items()) + ) + + def make_input_tracer( tracer_class: Type[BaseTracer], input_name: str, + input_idx: int, input_value: BaseValue, ) -> BaseTracer: """Helper function to create a tracer for an input value @@ -20,18 +42,19 @@ def make_input_tracer( Args: tracer_class (Type[BaseTracer]): the class of tracer to create an Input for input_name (str): the name of the input in the traced function + input_idx (int): the input index in the function parameters input_value (BaseValue): the Value that is an input and needs to be wrapped in an BaseTracer Returns: BaseTracer: The BaseTracer for that input value """ - return tracer_class([], ir.Input(input_value, input_name), 0) + return tracer_class([], ir.Input(input_value, input_name, input_idx), 0) def prepare_function_parameters( function_to_trace: Callable, function_parameters: Dict[str, BaseValue] -) -> Dict[str, BaseValue]: +) -> OrderedDict[str, BaseValue]: """Function to filter the passed function_parameters to trace function_to_trace Args: @@ -42,7 +65,7 @@ def prepare_function_parameters( ValueError: Raised when some parameters are missing to trace function_to_trace Returns: - Dict[str, BaseValue]: filtered function_parameters dictionary + OrderedDict[str, BaseValue]: filtered function_parameters dictionary """ function_signature = signature(function_to_trace) @@ -54,10 +77,11 @@ def prepare_function_parameters( f"that were not provided: {', '.join(sorted(missing_args))}" ) - useless_arguments = function_parameters.keys() - function_signature.parameters.keys() - useful_arguments = function_signature.parameters.keys() - useless_arguments - - return {k: function_parameters[k] for k in useful_arguments} + # This convoluted way of creating the dict is to ensure key order is maintained + return collections.OrderedDict( + (param_name, function_parameters[param_name]) + for param_name in function_signature.parameters.keys() + ) def create_graph_from_output_tracers( diff --git a/hdk/hnumpy/tracing.py b/hdk/hnumpy/tracing.py index c55d97124..8b10835e1 100644 --- a/hdk/hnumpy/tracing.py +++ b/hdk/hnumpy/tracing.py @@ -7,7 +7,7 @@ from ..common.data_types import BaseValue from ..common.tracing import ( BaseTracer, create_graph_from_output_tracers, - make_input_tracer, + make_input_tracers, prepare_function_parameters, ) @@ -32,10 +32,7 @@ def trace_numpy_function( """ function_parameters = prepare_function_parameters(function_to_trace, function_parameters) - input_tracers = { - param_name: make_input_tracer(NPTracer, param_name, param) - for param_name, param in function_parameters.items() - } + input_tracers = make_input_tracers(NPTracer, function_parameters) # We could easily create a graph of NPTracer, but we may end up with dead nodes starting from # the inputs that's why we create the graph starting from the outputs diff --git a/tests/common/tracing/test_tracing_helpers.py b/tests/common/tracing/test_tracing_helpers.py index 20adb02ad..38e57a7fc 100644 --- a/tests/common/tracing/test_tracing_helpers.py +++ b/tests/common/tracing/test_tracing_helpers.py @@ -1,6 +1,6 @@ """Test file for HDK's common tracing helpers""" -from typing import Any, Dict +from typing import Any, Dict, List import pytest @@ -24,3 +24,21 @@ def test_prepare_function_parameters( prepared_dict = prepare_function_parameters(function, function_parameters) assert prepared_dict == ref_dict + + +@pytest.mark.parametrize( + "function,function_parameters,expected_ordered_keys", + [ + (lambda x: None, {"x": None}, ["x"]), + (lambda x, y: None, {"x": None, "y": None}, ["x", "y"]), + (lambda x, y: None, {"y": None, "x": None}, ["x", "y"]), + (lambda z, x, y: None, {"y": None, "z": None, "x": None}, ["z", "x", "y"]), + ], +) +def test_prepare_function_parameters_order( + function, function_parameters: Dict[str, Any], expected_ordered_keys: List[str] +): + """Test prepare_function_parameters output order""" + prepared_dict = prepare_function_parameters(function, function_parameters) + + assert list(prepared_dict.keys()) == expected_ordered_keys diff --git a/tests/hnumpy/test_tracing.py b/tests/hnumpy/test_tracing.py index f9eb2ad3b..46d745374 100644 --- a/tests/hnumpy/test_tracing.py +++ b/tests/hnumpy/test_tracing.py @@ -80,8 +80,8 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers): ref_graph = nx.MultiDiGraph() - input_x = ir.Input(x, input_name="x") - input_y = ir.Input(y, input_name="y") + input_x = ir.Input(x, input_name="x", program_input_idx=0) + input_y = ir.Input(y, input_name="y", program_input_idx=1) add_node_z = ir.Add( (