mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
dev(ordered-inputs): update code to keep the input index in Input IR nodes
- this input index will be useful for MLIR/lower level conversions - it represents the input index of an Input node when considering the traced function signature - update code preparing function parameters to keep signature order
This commit is contained in:
@@ -100,13 +100,16 @@ class Input(IntermediateNode):
|
||||
"""Node representing an input of the numpy program"""
|
||||
|
||||
input_name: str
|
||||
program_input_idx: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_value: BaseValue,
|
||||
input_name: str,
|
||||
program_input_idx: int,
|
||||
) -> None:
|
||||
super().__init__((input_value,))
|
||||
assert len(self.inputs) == 1
|
||||
self.input_name = input_name
|
||||
self.program_input_idx = program_input_idx
|
||||
self.outputs = [deepcopy(self.inputs[0])]
|
||||
|
||||
@@ -3,5 +3,6 @@ from .base_tracer import BaseTracer
|
||||
from .tracing_helpers import (
|
||||
create_graph_from_output_tracers,
|
||||
make_input_tracer,
|
||||
make_input_tracers,
|
||||
prepare_function_parameters,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Helper functions for tracing"""
|
||||
import collections
|
||||
from inspect import signature
|
||||
from typing import Callable, Dict, Iterable, Set, Tuple, Type
|
||||
from typing import Callable, Dict, Iterable, OrderedDict, Set, Tuple, Type
|
||||
|
||||
import networkx as nx
|
||||
from networkx.algorithms.dag import is_directed_acyclic_graph
|
||||
@@ -10,9 +11,30 @@ from ..representation import intermediate as ir
|
||||
from .base_tracer import BaseTracer
|
||||
|
||||
|
||||
def make_input_tracers(
|
||||
tracer_class: Type[BaseTracer],
|
||||
function_parameters: OrderedDict[str, BaseValue],
|
||||
) -> OrderedDict[str, BaseTracer]:
|
||||
"""Helper function to create tracers for a function's parameters
|
||||
|
||||
Args:
|
||||
tracer_class (Type[BaseTracer]): the class of tracer to create an Input for
|
||||
function_parameters (OrderedDict[str, BaseValue]): the dictionary with the parameters names
|
||||
and corresponding Values
|
||||
|
||||
Returns:
|
||||
OrderedDict[str, BaseTracer]: the dictionary containing the Input Tracers for each parameter
|
||||
"""
|
||||
return collections.OrderedDict(
|
||||
(param_name, make_input_tracer(tracer_class, param_name, input_idx, param))
|
||||
for input_idx, (param_name, param) in enumerate(function_parameters.items())
|
||||
)
|
||||
|
||||
|
||||
def make_input_tracer(
|
||||
tracer_class: Type[BaseTracer],
|
||||
input_name: str,
|
||||
input_idx: int,
|
||||
input_value: BaseValue,
|
||||
) -> BaseTracer:
|
||||
"""Helper function to create a tracer for an input value
|
||||
@@ -20,18 +42,19 @@ def make_input_tracer(
|
||||
Args:
|
||||
tracer_class (Type[BaseTracer]): the class of tracer to create an Input for
|
||||
input_name (str): the name of the input in the traced function
|
||||
input_idx (int): the input index in the function parameters
|
||||
input_value (BaseValue): the Value that is an input and needs to be wrapped in an
|
||||
BaseTracer
|
||||
|
||||
Returns:
|
||||
BaseTracer: The BaseTracer for that input value
|
||||
"""
|
||||
return tracer_class([], ir.Input(input_value, input_name), 0)
|
||||
return tracer_class([], ir.Input(input_value, input_name, input_idx), 0)
|
||||
|
||||
|
||||
def prepare_function_parameters(
|
||||
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
|
||||
) -> Dict[str, BaseValue]:
|
||||
) -> OrderedDict[str, BaseValue]:
|
||||
"""Function to filter the passed function_parameters to trace function_to_trace
|
||||
|
||||
Args:
|
||||
@@ -42,7 +65,7 @@ def prepare_function_parameters(
|
||||
ValueError: Raised when some parameters are missing to trace function_to_trace
|
||||
|
||||
Returns:
|
||||
Dict[str, BaseValue]: filtered function_parameters dictionary
|
||||
OrderedDict[str, BaseValue]: filtered function_parameters dictionary
|
||||
"""
|
||||
function_signature = signature(function_to_trace)
|
||||
|
||||
@@ -54,10 +77,11 @@ def prepare_function_parameters(
|
||||
f"that were not provided: {', '.join(sorted(missing_args))}"
|
||||
)
|
||||
|
||||
useless_arguments = function_parameters.keys() - function_signature.parameters.keys()
|
||||
useful_arguments = function_signature.parameters.keys() - useless_arguments
|
||||
|
||||
return {k: function_parameters[k] for k in useful_arguments}
|
||||
# This convoluted way of creating the dict is to ensure key order is maintained
|
||||
return collections.OrderedDict(
|
||||
(param_name, function_parameters[param_name])
|
||||
for param_name in function_signature.parameters.keys()
|
||||
)
|
||||
|
||||
|
||||
def create_graph_from_output_tracers(
|
||||
|
||||
@@ -7,7 +7,7 @@ from ..common.data_types import BaseValue
|
||||
from ..common.tracing import (
|
||||
BaseTracer,
|
||||
create_graph_from_output_tracers,
|
||||
make_input_tracer,
|
||||
make_input_tracers,
|
||||
prepare_function_parameters,
|
||||
)
|
||||
|
||||
@@ -32,10 +32,7 @@ def trace_numpy_function(
|
||||
"""
|
||||
function_parameters = prepare_function_parameters(function_to_trace, function_parameters)
|
||||
|
||||
input_tracers = {
|
||||
param_name: make_input_tracer(NPTracer, param_name, param)
|
||||
for param_name, param in function_parameters.items()
|
||||
}
|
||||
input_tracers = make_input_tracers(NPTracer, function_parameters)
|
||||
|
||||
# We could easily create a graph of NPTracer, but we may end up with dead nodes starting from
|
||||
# the inputs that's why we create the graph starting from the outputs
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Test file for HDK's common tracing helpers"""
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -24,3 +24,21 @@ def test_prepare_function_parameters(
|
||||
prepared_dict = prepare_function_parameters(function, function_parameters)
|
||||
|
||||
assert prepared_dict == ref_dict
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,function_parameters,expected_ordered_keys",
|
||||
[
|
||||
(lambda x: None, {"x": None}, ["x"]),
|
||||
(lambda x, y: None, {"x": None, "y": None}, ["x", "y"]),
|
||||
(lambda x, y: None, {"y": None, "x": None}, ["x", "y"]),
|
||||
(lambda z, x, y: None, {"y": None, "z": None, "x": None}, ["z", "x", "y"]),
|
||||
],
|
||||
)
|
||||
def test_prepare_function_parameters_order(
|
||||
function, function_parameters: Dict[str, Any], expected_ordered_keys: List[str]
|
||||
):
|
||||
"""Test prepare_function_parameters output order"""
|
||||
prepared_dict = prepare_function_parameters(function, function_parameters)
|
||||
|
||||
assert list(prepared_dict.keys()) == expected_ordered_keys
|
||||
|
||||
@@ -80,8 +80,8 @@ def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
|
||||
|
||||
ref_graph = nx.MultiDiGraph()
|
||||
|
||||
input_x = ir.Input(x, input_name="x")
|
||||
input_y = ir.Input(y, input_name="y")
|
||||
input_x = ir.Input(x, input_name="x", program_input_idx=0)
|
||||
input_y = ir.Input(y, input_name="y", program_input_idx=1)
|
||||
|
||||
add_node_z = ir.Add(
|
||||
(
|
||||
|
||||
Reference in New Issue
Block a user