feat: implement a generic node for functions which change shape

and implement np.transpose with it
and implement np.ravel with it

refs #745
This commit is contained in:
Benoit Chevallier-Mames
2021-10-26 12:11:18 +02:00
committed by Benoit Chevallier
parent 759914dca6
commit 8123a5ef45
9 changed files with 225 additions and 6 deletions

View File

@@ -16,6 +16,7 @@ from ..representation.intermediate import (
Add,
Constant,
Dot,
GenericFunction,
IndexConstant,
Input,
MatMul,
@@ -31,10 +32,12 @@ IR_NODE_COLOR_MAPPING = {
Sub: "yellow",
Mul: "green",
UnivariateFunction: "orange",
GenericFunction: "orange",
IndexConstant: "black",
Dot: "purple",
MatMul: "brown",
"UnivariateFunction": "orange",
"GenericFunction": "orange",
"TLU": "grey",
"output": "magenta",
}

View File

@@ -8,6 +8,7 @@ from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import (
Constant,
GenericFunction,
IndexConstant,
Input,
IntermediateNode,
@@ -91,7 +92,7 @@ def get_printable_graph(
base_name = node.__class__.__name__
if isinstance(node, UnivariateFunction):
if isinstance(node, (UnivariateFunction, GenericFunction)):
base_name = node.op_name
what_to_print = base_name + "("

View File

@@ -69,6 +69,9 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)
if not value_is_scalar(inputs[0]) or not value_is_unsigned_integer(inputs[0]):
return "only unsigned integer scalar lookup tables are supported"
elif isinstance(node, intermediate.GenericFunction): # constraints for generic functions
return f"{node.op_name} is not supported for the time being" # pragma: no cover
elif isinstance(node, intermediate.Dot): # constraints for dot product
assert_true(len(inputs) == 2)
if not value_is_unsigned_integer(inputs[0]) or not value_is_unsigned_integer(inputs[1]):

View File

@@ -470,3 +470,52 @@ class MatMul(IntermediateNode):
def label(self) -> str:
return "@"
class GenericFunction(IntermediateNode):
"""Return the node representing a generic function."""
# The arbitrary_func is not optional but mypy has a long standing bug and is not able to
# understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
# arbitrary_func can take more than one argument but during evaluation the input variable will
# be the first argument passed to it. You can add other constant arguments needed for the proper
# execution of the function through op_args and op_kwargs.
arbitrary_func: Optional[Callable]
op_name: str
op_args: Tuple[Any, ...]
op_kwargs: Dict[str, Any]
op_attributes: Dict[str, Any]
_n_in: int = 1
def __init__(
self,
input_base_value: TensorValue,
arbitrary_func: Callable,
output_dtype: BaseDataType,
output_shape: Tuple,
op_name: Optional[str] = None,
op_args: Optional[Tuple[Any, ...]] = None,
op_kwargs: Optional[Dict[str, Any]] = None,
op_attributes: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__([input_base_value])
assert_true(len(self.inputs) == 1)
self.arbitrary_func = arbitrary_func
self.op_args = op_args if op_args is not None else ()
self.op_kwargs = op_kwargs if op_kwargs is not None else {}
self.op_attributes = op_attributes if op_attributes is not None else {}
self.outputs = [
EncryptedTensor(output_dtype, output_shape)
if self.inputs[0].is_encrypted
else ClearTensor(output_dtype, output_shape)
]
self.op_name = op_name if op_name is not None else self.__class__.__name__
def evaluate(self, inputs: Dict[int, Any]) -> Any:
# This is the continuation of the mypy bug workaround
assert self.arbitrary_func is not None
return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs)
def label(self) -> str:
return self.op_name

View File

@@ -1,17 +1,23 @@
"""numpy tracing utilities."""
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union, cast
import numpy
from numpy.typing import DTypeLike
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
from ..common.debugging.custom_assert import assert_true
from ..common.debugging.custom_assert import assert_false, assert_true
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import Constant, Dot, MatMul, UnivariateFunction
from ..common.representation.intermediate import (
Constant,
Dot,
GenericFunction,
MatMul,
UnivariateFunction,
)
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
from ..common.values import BaseValue
from ..common.values import BaseValue, TensorValue
from .np_dtypes_helpers import (
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
convert_numpy_dtype_to_base_data_type,
@@ -261,6 +267,62 @@ class NPTracer(BaseTracer):
)
return output_tracer
def transpose(self, *args: "NPTracer", **_kwargs) -> "NPTracer":
"""Trace numpy.transpose.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
assert_true((num_args := len(args)) == 1, f"transpose expect 1 input got {num_args}")
first_arg_output = args[0].output
assert_true(isinstance(first_arg_output, TensorValue))
first_arg_output = cast(TensorValue, first_arg_output)
assert_false(first_arg_output.is_scalar)
traced_computation = GenericFunction(
input_base_value=first_arg_output,
arbitrary_func=numpy.transpose,
output_dtype=first_arg_output.dtype,
output_shape=first_arg_output.shape[::-1],
op_kwargs=deepcopy(_kwargs),
op_name="np.transpose",
)
output_tracer = self.__class__(
args,
traced_computation=traced_computation,
output_idx=0,
)
return output_tracer
def ravel(self, *args: "NPTracer", **_kwargs) -> "NPTracer":
"""Trace numpy.ravel.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
assert_true((num_args := len(args)) == 1, f"ravel expect 1 input got {num_args}")
first_arg_output = args[0].output
assert_true(isinstance(first_arg_output, TensorValue))
first_arg_output = cast(TensorValue, first_arg_output)
assert_false(first_arg_output.is_scalar)
traced_computation = GenericFunction(
input_base_value=first_arg_output,
arbitrary_func=numpy.ravel,
output_dtype=first_arg_output.dtype,
output_shape=(numpy.product(first_arg_output.shape),),
op_kwargs=deepcopy(_kwargs),
op_name="np.ravel",
)
output_tracer = self.__class__(
args,
traced_computation=traced_computation,
output_idx=0,
)
return output_tracer
def __getitem__(self, item):
if isinstance(item, tuple):
item = tuple(process_indexing_element(indexing_element) for indexing_element in item)
@@ -373,6 +435,8 @@ class NPTracer(BaseTracer):
FUNC_ROUTING: Dict[Callable, Callable] = {
numpy.dot: dot,
numpy.transpose: transpose,
numpy.ravel: ravel,
}

View File

@@ -175,6 +175,28 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
numpy.array([[9, 12, 15], [19, 26, 33], [29, 40, 51]]),
id="MatMul, numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)",
),
pytest.param(
ir.GenericFunction(
EncryptedTensor(Integer(32, False), shape=(3, 5)),
lambda x: numpy.transpose(x),
Integer(32, False),
output_shape=(5, 3),
),
[numpy.arange(15).reshape(3, 5)],
numpy.array([[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8, 13], [4, 9, 14]]),
id="GenericFunction, x transpose",
),
pytest.param(
ir.GenericFunction(
EncryptedTensor(Integer(32, False), shape=(3, 5)),
lambda x: numpy.ravel(x),
Integer(32, False),
output_shape=(5, 3),
),
[numpy.arange(15).reshape(3, 5)],
numpy.arange(15),
id="GenericFunction, x ravel",
),
],
)
def test_evaluate(
@@ -184,7 +206,7 @@ def test_evaluate(
):
"""Test evaluate methods on IntermediateNodes"""
if isinstance(expected_result, numpy.ndarray):
assert (node.evaluate(input_data) == expected_result).all()
assert numpy.array_equal(node.evaluate(input_data), expected_result)
else:
assert node.evaluate(input_data) == expected_result

View File

@@ -15,6 +15,7 @@ from concrete.common.representation.intermediate import (
Add,
Constant,
Dot,
GenericFunction,
IndexConstant,
Input,
IntermediateNode,
@@ -186,6 +187,7 @@ def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool:
EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
Add: is_equivalent_add,
UnivariateFunction: is_equivalent_arbitrary_function,
GenericFunction: is_equivalent_arbitrary_function,
Constant: is_equivalent_constant,
Dot: is_equivalent_dot,
IndexConstant: is_equivalent_index_constant,

View File

@@ -215,6 +215,40 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
)
@pytest.mark.parametrize(
"lambda_f,params,ref_graph_str",
[
(
lambda x: numpy.transpose(x),
{
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)),
},
"%0 = x\n%1 = np.transpose(%0)\nreturn(%1)\n",
),
(
lambda x: numpy.ravel(x),
{
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)),
},
"%0 = x\n%1 = np.ravel(%0)\nreturn(%1)\n",
),
],
)
def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_str):
"Test get_printable_graph and draw_graph on graphs with generic function"
graph = tracing.trace_numpy_function(lambda_f, params)
draw_graph(graph, show=False)
str_of_the_graph = get_printable_graph(graph)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot \n{str_of_the_graph}"
f"==================\nExpected \n{ref_graph_str}"
f"==================\n"
)
# Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b
# returning 23b), since they are replaced later by the real bitwidths computed on the
# inputset

View File

@@ -608,3 +608,44 @@ def test_nptracer_unsupported_operands(operation, tracer):
"""Test cases where NPTracer cannot be used with other operands."""
with pytest.raises(TypeError):
tracer = operation(tracer)
@pytest.mark.parametrize(
"function_to_trace,input_value,input_and_expected_output_tuples",
[
(
lambda x: numpy.transpose(x),
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
[
(numpy.arange(4).reshape(2, 2), numpy.array([[0, 2], [1, 3]])),
(numpy.arange(4, 8).reshape(2, 2), numpy.array([[4, 6], [5, 7]])),
],
),
(
lambda x: numpy.transpose(x) + 42,
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
[
(numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(3, 5).transpose()),
],
),
(
lambda x: numpy.ravel(x),
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
[
(numpy.arange(4), numpy.array([0, 1, 2, 3])),
(numpy.arange(4).reshape(2, 2), numpy.array([0, 1, 2, 3])),
],
),
],
)
def test_tracing_generic_function(function_to_trace, input_value, input_and_expected_output_tuples):
"""Test function for managed by GenericFunction node"""
for input_, expected_output in input_and_expected_output_tuples:
op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
output_node = op_graph.output_nodes[0]
node_results = op_graph.evaluate({0: input_})
evaluated_output = node_results[output_node]
assert isinstance(evaluated_output, type(expected_output))
assert numpy.array_equal(expected_output, evaluated_output)