diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index 1c74ec031..bf9da6dde 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -100,15 +100,14 @@ class Mul(IntermediateNode): class Input(IntermediateNode): """Node representing an input of the numpy program""" + input_name: str + def __init__( self, - inputs: Iterable[BaseValue], - op_args: Optional[Tuple[Any, ...]] = None, - op_kwargs: Optional[Dict[str, Any]] = None, + input_value: BaseValue, + input_name: str, ) -> None: - assert op_args is None, f"Expected op_args to be None, got {op_args}" - assert op_kwargs is None, f"Expected op_kwargs to be None, got {op_kwargs}" - - super().__init__(inputs, op_args=op_args, op_kwargs=op_kwargs) + super().__init__((input_value,)) assert len(self.inputs) == 1 + self.input_name = input_name self.outputs = [deepcopy(self.inputs[0])] diff --git a/hdk/common/tracing/tracing_helpers.py b/hdk/common/tracing/tracing_helpers.py index a101ec5db..e7b1dba42 100644 --- a/hdk/common/tracing/tracing_helpers.py +++ b/hdk/common/tracing/tracing_helpers.py @@ -10,18 +10,23 @@ from ..representation import intermediate as ir from .base_tracer import BaseTracer -def make_input_tracer(tracer_class: Type[BaseTracer], input_value: BaseValue) -> BaseTracer: +def make_input_tracer( + tracer_class: Type[BaseTracer], + input_name: str, + input_value: BaseValue, +) -> BaseTracer: """Helper function to create a tracer for an input value Args: tracer_class (Type[BaseTracer]): the class of tracer to create an Input for + input_name (str): the name of the input in the traced function input_value (BaseValue): the Value that is an input and needs to be wrapped in an BaseTracer Returns: BaseTracer: The BaseTracer for that input value """ - return tracer_class([], ir.Input([input_value]), 0) + return tracer_class([], ir.Input(input_value, input_name), 0) def prepare_function_parameters( diff --git a/hdk/hnumpy/tracing.py b/hdk/hnumpy/tracing.py index 8e1ac38a9..c55d97124 100644 --- a/hdk/hnumpy/tracing.py +++ b/hdk/hnumpy/tracing.py @@ -33,7 +33,7 @@ def trace_numpy_function( function_parameters = prepare_function_parameters(function_to_trace, function_parameters) input_tracers = { - param_name: make_input_tracer(NPTracer, param) + param_name: make_input_tracer(NPTracer, param_name, param) for param_name, param in function_parameters.items() } diff --git a/tests/hnumpy/test_tracing.py b/tests/hnumpy/test_tracing.py index 7d4e8b132..07b4acbc1 100644 --- a/tests/hnumpy/test_tracing.py +++ b/tests/hnumpy/test_tracing.py @@ -80,8 +80,8 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers): ref_graph = nx.MultiDiGraph() - input_x = ir.Input((x,)) - input_y = ir.Input((y,)) + input_x = ir.Input(x, input_name="x") + input_y = ir.Input(y, input_name="y") add_node_z = ir.Add( (