mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: implement a generic node for functions which change shape,
and use it to implement np.transpose and np.ravel. Refs #745
This commit is contained in:
committed by
Benoit Chevallier
parent
759914dca6
commit
8123a5ef45
@@ -16,6 +16,7 @@ from ..representation.intermediate import (
|
||||
Add,
|
||||
Constant,
|
||||
Dot,
|
||||
GenericFunction,
|
||||
IndexConstant,
|
||||
Input,
|
||||
MatMul,
|
||||
@@ -31,10 +32,12 @@ IR_NODE_COLOR_MAPPING = {
|
||||
Sub: "yellow",
|
||||
Mul: "green",
|
||||
UnivariateFunction: "orange",
|
||||
GenericFunction: "orange",
|
||||
IndexConstant: "black",
|
||||
Dot: "purple",
|
||||
MatMul: "brown",
|
||||
"UnivariateFunction": "orange",
|
||||
"GenericFunction": "orange",
|
||||
"TLU": "grey",
|
||||
"output": "magenta",
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ from ..debugging.custom_assert import assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import (
|
||||
Constant,
|
||||
GenericFunction,
|
||||
IndexConstant,
|
||||
Input,
|
||||
IntermediateNode,
|
||||
@@ -91,7 +92,7 @@ def get_printable_graph(
|
||||
|
||||
base_name = node.__class__.__name__
|
||||
|
||||
if isinstance(node, UnivariateFunction):
|
||||
if isinstance(node, (UnivariateFunction, GenericFunction)):
|
||||
base_name = node.op_name
|
||||
|
||||
what_to_print = base_name + "("
|
||||
|
||||
@@ -69,6 +69,9 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)
|
||||
if not value_is_scalar(inputs[0]) or not value_is_unsigned_integer(inputs[0]):
|
||||
return "only unsigned integer scalar lookup tables are supported"
|
||||
|
||||
elif isinstance(node, intermediate.GenericFunction): # constraints for generic functions
|
||||
return f"{node.op_name} is not supported for the time being" # pragma: no cover
|
||||
|
||||
elif isinstance(node, intermediate.Dot): # constraints for dot product
|
||||
assert_true(len(inputs) == 2)
|
||||
if not value_is_unsigned_integer(inputs[0]) or not value_is_unsigned_integer(inputs[1]):
|
||||
|
||||
@@ -470,3 +470,52 @@ class MatMul(IntermediateNode):
|
||||
|
||||
def label(self) -> str:
|
||||
return "@"
|
||||
|
||||
|
||||
class GenericFunction(IntermediateNode):
    """Node representing the application of an arbitrary shape-changing function."""

    # arbitrary_func is in fact mandatory; the Optional annotation only works around a
    # long-standing mypy limitation, see
    # https://github.com/python/mypy/issues/708#issuecomment-605636623
    # During evaluation the traced input is passed as the first positional argument;
    # additional constant arguments can be supplied through op_args and op_kwargs.
    arbitrary_func: Optional[Callable]
    op_name: str
    op_args: Tuple[Any, ...]
    op_kwargs: Dict[str, Any]
    op_attributes: Dict[str, Any]
    _n_in: int = 1

    def __init__(
        self,
        input_base_value: TensorValue,
        arbitrary_func: Callable,
        output_dtype: BaseDataType,
        output_shape: Tuple,
        op_name: Optional[str] = None,
        op_args: Optional[Tuple[Any, ...]] = None,
        op_kwargs: Optional[Dict[str, Any]] = None,
        op_attributes: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__([input_base_value])
        assert_true(len(self.inputs) == 1)

        self.arbitrary_func = arbitrary_func
        # Normalize the optional containers so evaluate() can splat them unconditionally.
        self.op_args = () if op_args is None else op_args
        self.op_kwargs = {} if op_kwargs is None else op_kwargs
        self.op_attributes = {} if op_attributes is None else op_attributes
        self.op_name = self.__class__.__name__ if op_name is None else op_name

        # The output keeps the input's encryption status, with the caller-provided
        # dtype and shape.
        output_value_class = EncryptedTensor if self.inputs[0].is_encrypted else ClearTensor
        self.outputs = [output_value_class(output_dtype, output_shape)]

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        # Continuation of the mypy workaround documented on arbitrary_func.
        assert self.arbitrary_func is not None
        return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs)

    def label(self) -> str:
        return self.op_name
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
"""numpy tracing utilities."""
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union, cast
|
||||
|
||||
import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
|
||||
from ..common.debugging.custom_assert import assert_true
|
||||
from ..common.debugging.custom_assert import assert_false, assert_true
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation.intermediate import Constant, Dot, MatMul, UnivariateFunction
|
||||
from ..common.representation.intermediate import (
|
||||
Constant,
|
||||
Dot,
|
||||
GenericFunction,
|
||||
MatMul,
|
||||
UnivariateFunction,
|
||||
)
|
||||
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
|
||||
from ..common.values import BaseValue
|
||||
from ..common.values import BaseValue, TensorValue
|
||||
from .np_dtypes_helpers import (
|
||||
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
|
||||
convert_numpy_dtype_to_base_data_type,
|
||||
@@ -261,6 +267,62 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def transpose(self, *args: "NPTracer", **_kwargs) -> "NPTracer":
    """Trace numpy.transpose.

    Args:
        args: a single NPTracer holding the (non-scalar) tensor to transpose.
        _kwargs: keyword arguments forwarded to numpy.transpose (may contain ``axes``).

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    assert_true((num_args := len(args)) == 1, f"transpose expect 1 input got {num_args}")

    first_arg_output = args[0].output
    assert_true(isinstance(first_arg_output, TensorValue))
    first_arg_output = cast(TensorValue, first_arg_output)
    assert_false(first_arg_output.is_scalar)

    # Honor an explicit ``axes`` permutation when given; numpy's default behavior
    # (axes=None) is to reverse the axis order. Previously the declared output shape
    # ignored ``axes`` even though the kwargs were forwarded to numpy.transpose.
    axes = _kwargs.get("axes")
    input_shape = first_arg_output.shape
    output_shape = (
        input_shape[::-1] if axes is None else tuple(input_shape[axis] for axis in axes)
    )

    traced_computation = GenericFunction(
        input_base_value=first_arg_output,
        arbitrary_func=numpy.transpose,
        output_dtype=first_arg_output.dtype,
        output_shape=output_shape,
        op_kwargs=deepcopy(_kwargs),
        op_name="np.transpose",
    )
    output_tracer = self.__class__(
        args,
        traced_computation=traced_computation,
        output_idx=0,
    )
    return output_tracer
||||
|
||||
def ravel(self, *args: "NPTracer", **_kwargs) -> "NPTracer":
    """Trace numpy.ravel.

    Args:
        args: a single NPTracer holding the (non-scalar) tensor to flatten.
        _kwargs: keyword arguments forwarded to numpy.ravel.

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    assert_true((num_args := len(args)) == 1, f"ravel expect 1 input got {num_args}")

    first_arg_output = args[0].output
    assert_true(isinstance(first_arg_output, TensorValue))
    first_arg_output = cast(TensorValue, first_arg_output)
    assert_false(first_arg_output.is_scalar)

    # numpy.product is deprecated and removed in NumPy 2.0; numpy.prod is the
    # supported spelling. int() keeps the shape entry a plain Python int.
    flat_size = int(numpy.prod(first_arg_output.shape))

    traced_computation = GenericFunction(
        input_base_value=first_arg_output,
        arbitrary_func=numpy.ravel,
        output_dtype=first_arg_output.dtype,
        output_shape=(flat_size,),
        op_kwargs=deepcopy(_kwargs),
        op_name="np.ravel",
    )
    output_tracer = self.__class__(
        args,
        traced_computation=traced_computation,
        output_idx=0,
    )
    return output_tracer
|
||||
|
||||
def __getitem__(self, item):
|
||||
if isinstance(item, tuple):
|
||||
item = tuple(process_indexing_element(indexing_element) for indexing_element in item)
|
||||
@@ -373,6 +435,8 @@ class NPTracer(BaseTracer):
|
||||
|
||||
FUNC_ROUTING: Dict[Callable, Callable] = {
|
||||
numpy.dot: dot,
|
||||
numpy.transpose: transpose,
|
||||
numpy.ravel: ravel,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -175,6 +175,28 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
|
||||
numpy.array([[9, 12, 15], [19, 26, 33], [29, 40, 51]]),
|
||||
id="MatMul, numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)",
|
||||
),
|
||||
pytest.param(
|
||||
ir.GenericFunction(
|
||||
EncryptedTensor(Integer(32, False), shape=(3, 5)),
|
||||
lambda x: numpy.transpose(x),
|
||||
Integer(32, False),
|
||||
output_shape=(5, 3),
|
||||
),
|
||||
[numpy.arange(15).reshape(3, 5)],
|
||||
numpy.array([[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8, 13], [4, 9, 14]]),
|
||||
id="GenericFunction, x transpose",
|
||||
),
|
||||
pytest.param(
|
||||
ir.GenericFunction(
|
||||
EncryptedTensor(Integer(32, False), shape=(3, 5)),
|
||||
lambda x: numpy.ravel(x),
|
||||
Integer(32, False),
|
||||
output_shape=(5, 3),
|
||||
),
|
||||
[numpy.arange(15).reshape(3, 5)],
|
||||
numpy.arange(15),
|
||||
id="GenericFunction, x ravel",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_evaluate(
|
||||
@@ -184,7 +206,7 @@ def test_evaluate(
|
||||
):
|
||||
"""Test evaluate methods on IntermediateNodes"""
|
||||
if isinstance(expected_result, numpy.ndarray):
|
||||
assert (node.evaluate(input_data) == expected_result).all()
|
||||
assert numpy.array_equal(node.evaluate(input_data), expected_result)
|
||||
else:
|
||||
assert node.evaluate(input_data) == expected_result
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from concrete.common.representation.intermediate import (
|
||||
Add,
|
||||
Constant,
|
||||
Dot,
|
||||
GenericFunction,
|
||||
IndexConstant,
|
||||
Input,
|
||||
IntermediateNode,
|
||||
@@ -186,6 +187,7 @@ def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool:
|
||||
EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
|
||||
Add: is_equivalent_add,
|
||||
UnivariateFunction: is_equivalent_arbitrary_function,
|
||||
GenericFunction: is_equivalent_arbitrary_function,
|
||||
Constant: is_equivalent_constant,
|
||||
Dot: is_equivalent_dot,
|
||||
IndexConstant: is_equivalent_index_constant,
|
||||
|
||||
@@ -215,6 +215,40 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "lambda_f,params,ref_graph_str",
    [
        (
            lambda x: numpy.transpose(x),
            {"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5))},
            "%0 = x\n%1 = np.transpose(%0)\nreturn(%1)\n",
        ),
        (
            lambda x: numpy.ravel(x),
            {"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5))},
            "%0 = x\n%1 = np.ravel(%0)\nreturn(%1)\n",
        ),
    ],
)
def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_str):
    """Check get_printable_graph and draw_graph on graphs containing a GenericFunction."""
    op_graph = tracing.trace_numpy_function(lambda_f, params)

    # Drawing must not raise even when the graph holds generic-function nodes.
    draw_graph(op_graph, show=False)

    printed_graph = get_printable_graph(op_graph)

    assert printed_graph == ref_graph_str, (
        f"\n==================\nGot \n{printed_graph}"
        f"==================\nExpected \n{ref_graph_str}"
        f"==================\n"
    )
|
||||
|
||||
|
||||
# Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b
|
||||
# returning 23b), since they are replaced later by the real bitwidths computed on the
|
||||
# inputset
|
||||
|
||||
@@ -608,3 +608,44 @@ def test_nptracer_unsupported_operands(operation, tracer):
|
||||
"""Test cases where NPTracer cannot be used with other operands."""
|
||||
with pytest.raises(TypeError):
|
||||
tracer = operation(tracer)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "function_to_trace,input_value,input_and_expected_output_tuples",
    [
        (
            lambda x: numpy.transpose(x),
            EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
            [
                (numpy.arange(4).reshape(2, 2), numpy.array([[0, 2], [1, 3]])),
                (numpy.arange(4, 8).reshape(2, 2), numpy.array([[4, 6], [5, 7]])),
            ],
        ),
        (
            lambda x: numpy.transpose(x) + 42,
            EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
            [
                (numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(3, 5).transpose()),
            ],
        ),
        (
            lambda x: numpy.ravel(x),
            EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
            [
                (numpy.arange(4), numpy.array([0, 1, 2, 3])),
                (numpy.arange(4).reshape(2, 2), numpy.array([0, 1, 2, 3])),
            ],
        ),
    ],
)
def test_tracing_generic_function(function_to_trace, input_value, input_and_expected_output_tuples):
    """Check tracing and evaluation of functions managed by a GenericFunction node."""
    for sample_input, expected in input_and_expected_output_tuples:
        # Re-trace per sample, as in the original test, to exercise tracing each time.
        op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
        result_node = op_graph.output_nodes[0]

        evaluation_results = op_graph.evaluate({0: sample_input})
        actual = evaluation_results[result_node]
        assert isinstance(actual, type(expected))
        assert numpy.array_equal(expected, actual)
|
||||
|
||||
Reference in New Issue
Block a user