dev(ir): make Input ir node accept a name

This commit is contained in:
Arthur Meyre
2021-07-23 16:05:31 +02:00
parent f910f1fa9c
commit a56a0dbf0c
4 changed files with 16 additions and 12 deletions

View File

@@ -100,15 +100,14 @@ class Mul(IntermediateNode):
class Input(IntermediateNode):
"""Node representing an input of the numpy program"""
input_name: str
def __init__(
self,
inputs: Iterable[BaseValue],
op_args: Optional[Tuple[Any, ...]] = None,
op_kwargs: Optional[Dict[str, Any]] = None,
input_value: BaseValue,
input_name: str,
) -> None:
assert op_args is None, f"Expected op_args to be None, got {op_args}"
assert op_kwargs is None, f"Expected op_kwargs to be None, got {op_kwargs}"
super().__init__(inputs, op_args=op_args, op_kwargs=op_kwargs)
super().__init__((input_value,))
assert len(self.inputs) == 1
self.input_name = input_name
self.outputs = [deepcopy(self.inputs[0])]

View File

@@ -10,18 +10,23 @@ from ..representation import intermediate as ir
from .base_tracer import BaseTracer
def make_input_tracer(tracer_class: Type[BaseTracer], input_value: BaseValue) -> BaseTracer:
def make_input_tracer(
tracer_class: Type[BaseTracer],
input_name: str,
input_value: BaseValue,
) -> BaseTracer:
"""Helper function to create a tracer for an input value
Args:
tracer_class (Type[BaseTracer]): the class of tracer to create an Input for
input_name (str): the name of the input in the traced function
input_value (BaseValue): the Value that is an input and needs to be wrapped in an
BaseTracer
Returns:
BaseTracer: The BaseTracer for that input value
"""
return tracer_class([], ir.Input([input_value]), 0)
return tracer_class([], ir.Input(input_value, input_name), 0)
def prepare_function_parameters(

View File

@@ -33,7 +33,7 @@ def trace_numpy_function(
function_parameters = prepare_function_parameters(function_to_trace, function_parameters)
input_tracers = {
param_name: make_input_tracer(NPTracer, param)
param_name: make_input_tracer(NPTracer, param_name, param)
for param_name, param in function_parameters.items()
}

View File

@@ -80,8 +80,8 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
ref_graph = nx.MultiDiGraph()
input_x = ir.Input((x,))
input_y = ir.Input((y,))
input_x = ir.Input(x, input_name="x")
input_y = ir.Input(y, input_name="y")
add_node_z = ir.Add(
(