feat(tracing): implement tracing of matmul

This commit is contained in:
Umut
2021-10-26 13:14:02 +03:00
parent 118e6454b7
commit eedbe0606b
8 changed files with 170 additions and 22 deletions

View File

@@ -18,6 +18,7 @@ from ..representation.intermediate import (
Dot,
IndexConstant,
Input,
MatMul,
Mul,
Sub,
UnivariateFunction,
@@ -32,6 +33,7 @@ IR_NODE_COLOR_MAPPING = {
UnivariateFunction: "orange",
IndexConstant: "black",
Dot: "purple",
MatMul: "brown",
"UnivariateFunction": "orange",
"TLU": "grey",
"output": "magenta",

View File

@@ -78,6 +78,9 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)
assert_true(len(outputs) == 1)
return "indexing is not supported for the time being"
elif isinstance(node, intermediate.MatMul): # constraints for matrix multiplication
return "matrix multiplication is not supported for the time being"
else: # pragma: no cover
assert_not_reached("Non IntermediateNode object in the OPGraph")

View File

@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from collections import deque
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast
from loguru import logger
@@ -424,3 +424,49 @@ class Dot(IntermediateNode):
def label(self) -> str:
return "dot"
class MatMul(IntermediateNode):
    """Intermediate node for the product of two matrices (tensors with ndim == 2)."""

    _n_in: int = 2

    def __init__(
        self,
        inputs: Iterable[BaseValue],
        output_dtype: BaseDataType,
    ) -> None:
        """Check both operands and build the single output value of the node.

        Args:
            inputs (Iterable[BaseValue]): the two operands, left-hand side first
            output_dtype (BaseDataType): data type given to the output tensor
        """
        super().__init__(inputs)
        assert_true(len(self.inputs) == 2)

        operands_are_matrices = all(
            isinstance(operand, TensorValue) and operand.ndim == 2 for operand in self.inputs
        )
        assert_true(
            operands_are_matrices,
            f"MatMul only supports two matrices ({TensorValue.__name__} with ndim == 2)",
        )

        # the casts are only for mypy, the assertion above already checked the input types
        left = cast(TensorValue, self.inputs[0])
        right = cast(TensorValue, self.inputs[1])

        # inner dimensions have to agree: (n, k) @ (k, m)
        assert_true(
            left.shape[1] == right.shape[0],
            f"MatMul between matrices of shapes {left.shape} and {right.shape} is not supported",
        )

        result_shape = (left.shape[0], right.shape[1])

        # the result is encrypted as soon as one of the operands is encrypted
        if left.is_encrypted or right.is_encrypted:
            result = EncryptedTensor(dtype=output_dtype, shape=result_shape)
        else:
            result = ClearTensor(dtype=output_dtype, shape=result_shape)

        self.outputs = [result]

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Apply the ``@`` operator to the values stored under keys 0 and 1 of inputs."""
        left_data, right_data = inputs[0], inputs[1]
        return left_data @ right_data

    def label(self) -> str:
        """Return the short label of the node."""
        return "@"

View File

@@ -2,7 +2,7 @@
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Union
from typing import Any, Callable, Dict, List, Tuple, Union
import numpy
from numpy.typing import DTypeLike
@@ -18,6 +18,7 @@ from ..common.data_types.dtypes_helpers import (
from ..common.data_types.floats import Float
from ..common.data_types.integers import Integer
from ..common.debugging.custom_assert import assert_true
from ..common.tracing import BaseTracer
from ..common.values import BaseValue, TensorValue
NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
@@ -182,9 +183,10 @@ def get_base_value_for_numpy_or_python_constant_data(
return constant_data_value
def get_numpy_function_output_dtype(
def get_numpy_function_output_dtype_from_input_dtypes(
function: Union[numpy.ufunc, Callable],
input_dtypes: List[BaseDataType],
input_shapes: List[Tuple[int, ...]],
) -> List[numpy.dtype]:
"""Record the output dtype of a numpy function given some input types.
@@ -193,6 +195,8 @@ def get_numpy_function_output_dtype(
be recorded
input_dtypes (List[BaseDataType]): BaseDataTypes in the same order as they will be used with
the function inputs
input_shapes (List[Tuple[int, ...]]): Shapes in the same order as they will be used with
the function inputs
Returns:
List[numpy.dtype]: The ordered numpy dtypes of the function outputs
@@ -206,7 +210,12 @@ def get_numpy_function_output_dtype(
input_numpy_dtypes = [convert_base_data_type_to_numpy_dtype(dtype) for dtype in input_dtypes]
dummy_inputs = tuple(
dtype.type(1000.0 * numpy.random.random_sample()) for dtype in input_numpy_dtypes
(
dtype.type(10.0 * numpy.random.random_sample())
if shape == ()
else numpy.abs(numpy.random.randn(*shape) * 10.0).astype(dtype)
)
for dtype, shape in zip(input_numpy_dtypes, input_shapes)
)
# We ignore errors as we may call functions with invalid inputs just to get the proper output
@@ -220,6 +229,36 @@ def get_numpy_function_output_dtype(
return [output.dtype for output in outputs]
def get_numpy_function_output_dtype_from_input_tracers(
    func: Union[numpy.ufunc, Callable],
    *input_tracers: BaseTracer,
) -> List[BaseDataType]:
    """Determine the output dtypes of a numpy function applied to traced inputs.

    The dtype and the shape of every tracer output are collected (the shape is ``()`` when the
    output is not a TensorValue) and forwarded to
    get_numpy_function_output_dtype_from_input_dtypes. The numpy dtypes obtained this way are
    then converted with convert_numpy_dtype_to_base_data_type.

    Args:
        func (Union[numpy.ufunc, Callable]): function that is being managed
        *input_tracers (BaseTracer): inputs to the function, in call order

    Returns:
        List[BaseDataType]: appropriate BaseDataType for each output of the function
    """
    input_dtypes: List[BaseDataType] = []
    input_shapes: List[Tuple[int, ...]] = []

    for tracer in input_tracers:
        traced_value = tracer.output
        input_dtypes.append(traced_value.dtype)
        # values that are not tensors are treated as scalars, i.e. with an empty shape
        input_shapes.append(traced_value.shape if isinstance(traced_value, TensorValue) else ())

    numpy_output_dtypes = get_numpy_function_output_dtype_from_input_dtypes(
        func,
        input_dtypes,
        input_shapes,
    )

    return list(map(convert_numpy_dtype_to_base_data_type, numpy_output_dtypes))
def get_constructor_for_numpy_or_python_constant_data(constant_data: Any):
"""Get the constructor for the numpy constant data or python dtype.

View File

@@ -9,14 +9,14 @@ from numpy.typing import DTypeLike
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
from ..common.debugging.custom_assert import assert_true
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import Constant, Dot, UnivariateFunction
from ..common.representation.intermediate import Constant, Dot, MatMul, UnivariateFunction
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
from ..common.values import BaseValue
from .np_dtypes_helpers import (
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
convert_numpy_dtype_to_base_data_type,
get_base_value_for_numpy_or_python_constant_data,
get_numpy_function_output_dtype,
get_numpy_function_output_dtype_from_input_tracers,
)
from .np_indexing_helpers import process_indexing_element
@@ -139,16 +139,6 @@ class NPTracer(BaseTracer):
def _make_const_input_tracer(self, constant_data: Any) -> "NPTracer":
return self.__class__([], NPConstant(constant_data), 0)
@staticmethod
def _manage_dtypes(ufunc: Union[numpy.ufunc, Callable], *input_tracers: BaseTracer):
output_dtypes = get_numpy_function_output_dtype(
ufunc, [input_tracer.output.dtype for input_tracer in input_tracers]
)
common_output_dtypes = [
convert_numpy_dtype_to_base_data_type(dtype) for dtype in output_dtypes
]
return common_output_dtypes
@classmethod
def _unary_operator(
cls, unary_operator, unary_operator_string, *input_tracers: "NPTracer", **kwargs
@@ -159,7 +149,10 @@ class NPTracer(BaseTracer):
NPTracer: The output NPTracer containing the traced function
"""
assert_true(len(input_tracers) == 1)
common_output_dtypes = cls._manage_dtypes(unary_operator, *input_tracers)
common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers(
unary_operator,
*input_tracers,
)
assert_true(len(common_output_dtypes) == 1)
traced_computation = UnivariateFunction(
@@ -211,7 +204,10 @@ class NPTracer(BaseTracer):
def arbitrary_func(x, baked_constant, **kwargs):
return binary_operator(x, baked_constant, **kwargs)
common_output_dtypes = cls._manage_dtypes(binary_operator, *input_tracers)
common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers(
binary_operator,
*input_tracers,
)
assert_true(len(common_output_dtypes) == 1)
op_kwargs = deepcopy(kwargs)
@@ -249,7 +245,7 @@ class NPTracer(BaseTracer):
"""
assert_true((num_args := len(args)) == 2, f"dot expects 2 inputs got {num_args}")
common_output_dtypes = self._manage_dtypes(numpy.dot, *args)
common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers(numpy.dot, *args)
assert_true(len(common_output_dtypes) == 1)
traced_computation = Dot(
@@ -273,6 +269,9 @@ class NPTracer(BaseTracer):
return BaseTracer.__getitem__(self, item)
def __matmul__(self, other):
    """Handle ``self @ other`` by dispatching numpy.matmul through __array_ufunc__."""
    operands = (self, other)
    return self.__array_ufunc__(numpy.matmul, "__call__", *operands)
# Supported functions are either univariate or bivariate for which one of the two
# sources is a constant
#
@@ -340,7 +339,6 @@ class NPTracer(BaseTracer):
numpy.logical_not,
numpy.logical_or,
numpy.logical_xor,
# numpy.matmul,
numpy.maximum,
numpy.minimum,
numpy.negative,
@@ -436,9 +434,23 @@ def _on_numpy_multiply(lhs, rhs):
return lhs.__mul__(rhs)
def _on_numpy_matmul(lhs, rhs):
    """Create the tracer holding a MatMul node for numpy.matmul applied to lhs and rhs."""
    operands = (lhs, rhs)

    output_dtypes = get_numpy_function_output_dtype_from_input_tracers(numpy.matmul, *operands)
    assert_true(len(output_dtypes) == 1)

    matmul_node = MatMul(
        [operand.output for operand in operands],
        output_dtypes[0],
    )
    return NPTracer(list(operands), matmul_node, output_idx=0)
NPTracer.UFUNC_ROUTING[numpy.add] = _on_numpy_add
NPTracer.UFUNC_ROUTING[numpy.subtract] = _on_numpy_subtract
NPTracer.UFUNC_ROUTING[numpy.multiply] = _on_numpy_multiply
NPTracer.UFUNC_ROUTING[numpy.matmul] = _on_numpy_matmul
def trace_numpy_function(

View File

@@ -1,5 +1,4 @@
"""Test file for intermediate representation"""
from copy import deepcopy
import numpy
@@ -164,6 +163,18 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
),
id="IndexConstant, np.array([[1, 2, 3, 4]...[13, 14, 15, 16]])[1:3, 2:0:-1]",
),
pytest.param(
ir.MatMul(
[
EncryptedTensor(Integer(32, True), shape=(3, 2)),
ClearTensor(Integer(32, True), shape=(2, 3)),
],
Integer(32, True),
),
[numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)],
numpy.array([[9, 12, 15], [19, 26, 33], [29, 40, 51]]),
id="MatMul, numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)",
),
],
)
def test_evaluate(

View File

@@ -17,6 +17,7 @@ from concrete.common.representation.intermediate import (
IndexConstant,
Input,
IntermediateNode,
MatMul,
Mul,
Sub,
UnivariateFunction,
@@ -167,6 +168,11 @@ def is_equivalent_sub(lhs: Sub, rhs: object) -> bool:
return _is_equivalent_to_binary_non_commutative(lhs, rhs)
def is_equivalent_matmul(lhs: MatMul, rhs: object) -> bool:
    """Helper function to check if a MatMul node is equivalent to another object."""
    # anything that is not a MatMul cannot be equivalent, no need to compare further
    if not isinstance(rhs, MatMul):
        return False
    return is_equivalent_intermediate_node(lhs, rhs)
def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool:
"""Helper function to check if an IntermediateNode node is equivalent to an other object."""
return (
@@ -185,6 +191,7 @@ EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
Input: is_equivalent_input,
Mul: is_equivalent_mul,
Sub: is_equivalent_sub,
MatMul: is_equivalent_matmul,
}
_missing_nodes_in_mapping = ALL_IR_NODES - EQUIVALENT_TEST_FUNC.keys()

View File

@@ -880,6 +880,34 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
"return(%7)\n"
),
),
pytest.param(
lambda x: x @ numpy.ones(shape=(2, 3), dtype=numpy.uint32),
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>\n" # noqa: E501
"%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor<Integer<unsigned, 1 bits>, shape=(2, 3)>\n" # noqa: E501
"%2 = MatMul(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(3, 3)>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501
"return(%2)\n"
),
),
pytest.param(
lambda x: numpy.matmul(x, numpy.ones(shape=(2, 3), dtype=numpy.uint32)),
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>\n" # noqa: E501
"%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor<Integer<unsigned, 1 bits>, shape=(2, 3)>\n" # noqa: E501
"%2 = MatMul(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(3, 3)>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501
"return(%2)\n"
),
),
],
)
# pylint: enable=line-too-long,unnecessary-lambda
@@ -894,7 +922,7 @@ def test_fail_compile(function, parameters, inputset, match, default_compilation
default_compilation_configuration,
)
assert str(excinfo.value) == match
assert str(excinfo.value) == match, str(excinfo.value)
def test_fail_with_intermediate_signed_values(default_compilation_configuration):