mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
dev(ir): make Input ir node accept a name
This commit is contained in:
@@ -100,15 +100,14 @@ class Mul(IntermediateNode):
|
||||
class Input(IntermediateNode):
|
||||
"""Node representing an input of the numpy program"""
|
||||
|
||||
input_name: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs: Iterable[BaseValue],
|
||||
op_args: Optional[Tuple[Any, ...]] = None,
|
||||
op_kwargs: Optional[Dict[str, Any]] = None,
|
||||
input_value: BaseValue,
|
||||
input_name: str,
|
||||
) -> None:
|
||||
assert op_args is None, f"Expected op_args to be None, got {op_args}"
|
||||
assert op_kwargs is None, f"Expected op_kwargs to be None, got {op_kwargs}"
|
||||
|
||||
super().__init__(inputs, op_args=op_args, op_kwargs=op_kwargs)
|
||||
super().__init__((input_value,))
|
||||
assert len(self.inputs) == 1
|
||||
self.input_name = input_name
|
||||
self.outputs = [deepcopy(self.inputs[0])]
|
||||
|
||||
@@ -10,18 +10,23 @@ from ..representation import intermediate as ir
|
||||
from .base_tracer import BaseTracer
|
||||
|
||||
|
||||
def make_input_tracer(tracer_class: Type[BaseTracer], input_value: BaseValue) -> BaseTracer:
|
||||
def make_input_tracer(
|
||||
tracer_class: Type[BaseTracer],
|
||||
input_name: str,
|
||||
input_value: BaseValue,
|
||||
) -> BaseTracer:
|
||||
"""Helper function to create a tracer for an input value
|
||||
|
||||
Args:
|
||||
tracer_class (Type[BaseTracer]): the class of tracer to create an Input for
|
||||
input_name (str): the name of the input in the traced function
|
||||
input_value (BaseValue): the Value that is an input and needs to be wrapped in an
|
||||
BaseTracer
|
||||
|
||||
Returns:
|
||||
BaseTracer: The BaseTracer for that input value
|
||||
"""
|
||||
return tracer_class([], ir.Input([input_value]), 0)
|
||||
return tracer_class([], ir.Input(input_value, input_name), 0)
|
||||
|
||||
|
||||
def prepare_function_parameters(
|
||||
|
||||
@@ -33,7 +33,7 @@ def trace_numpy_function(
|
||||
function_parameters = prepare_function_parameters(function_to_trace, function_parameters)
|
||||
|
||||
input_tracers = {
|
||||
param_name: make_input_tracer(NPTracer, param)
|
||||
param_name: make_input_tracer(NPTracer, param_name, param)
|
||||
for param_name, param in function_parameters.items()
|
||||
}
|
||||
|
||||
|
||||
@@ -80,8 +80,8 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
|
||||
|
||||
ref_graph = nx.MultiDiGraph()
|
||||
|
||||
input_x = ir.Input((x,))
|
||||
input_y = ir.Input((y,))
|
||||
input_x = ir.Input(x, input_name="x")
|
||||
input_y = ir.Input(y, input_name="y")
|
||||
|
||||
add_node_z = ir.Add(
|
||||
(
|
||||
|
||||
Reference in New Issue
Block a user