diff --git a/concrete/common/debugging/drawing.py b/concrete/common/debugging/drawing.py index c348e5c5d..664c38348 100644 --- a/concrete/common/debugging/drawing.py +++ b/concrete/common/debugging/drawing.py @@ -18,6 +18,7 @@ from ..representation.intermediate import ( Dot, IndexConstant, Input, + MatMul, Mul, Sub, UnivariateFunction, @@ -32,6 +33,7 @@ IR_NODE_COLOR_MAPPING = { UnivariateFunction: "orange", IndexConstant: "black", Dot: "purple", + MatMul: "brown", "UnivariateFunction": "orange", "TLU": "grey", "output": "magenta", diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index 39405af8e..a40de7557 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -78,6 +78,9 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool) assert_true(len(outputs) == 1) return "indexing is not supported for the time being" + elif isinstance(node, intermediate.MatMul): # constraints for matrix multiplication + return "matrix multiplication is not supported for the time being" + else: # pragma: no cover assert_not_reached("Non IntermediateNode object in the OPGraph") diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index 884801b47..e4493ce74 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections import deque from copy import deepcopy -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast from loguru import logger @@ -424,3 +424,49 @@ class Dot(IntermediateNode): def label(self) -> str: return "dot" + + +class MatMul(IntermediateNode): + """Return the node representing a matrix multiplication.""" + + _n_in: int = 2 + + def __init__( + self, + inputs: Iterable[BaseValue], + output_dtype: BaseDataType, + ) -> None: + super().__init__(inputs) + assert_true(len(self.inputs) == 2) + + assert_true( + all( + isinstance(input_value, TensorValue) and input_value.ndim == 2 + for input_value in self.inputs + ), + f"MatMul only supports two matrices ({TensorValue.__name__} with ndim == 2)", + ) + + # regular assertions are for mypy to see the inputs are TensorValue + lhs = cast(TensorValue, self.inputs[0]) + rhs = cast(TensorValue, self.inputs[1]) + + assert_true( + lhs.shape[1] == rhs.shape[0], + f"MatMul between matrices of shapes {lhs.shape} and {rhs.shape} " f"is not supported", + ) + + output_shape = (lhs.shape[0], rhs.shape[1]) + output_value = ( + EncryptedTensor(dtype=output_dtype, shape=output_shape) + if (lhs.is_encrypted or rhs.is_encrypted) + else ClearTensor(dtype=output_dtype, shape=output_shape) + ) + + self.outputs = [output_value] + + def evaluate(self, inputs: Dict[int, Any]) -> Any: + return inputs[0] @ inputs[1] + + def label(self) -> str: + return "@" diff --git a/concrete/numpy/np_dtypes_helpers.py b/concrete/numpy/np_dtypes_helpers.py index 78ac6a9bc..9ff9db7af 100644 --- a/concrete/numpy/np_dtypes_helpers.py +++ b/concrete/numpy/np_dtypes_helpers.py @@ -2,7 +2,7 @@ from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, Tuple, Union import numpy from numpy.typing import DTypeLike @@ -18,6 +18,7 @@ from ..common.data_types.dtypes_helpers import ( from ..common.data_types.floats import Float from ..common.data_types.integers import Integer from ..common.debugging.custom_assert import assert_true +from ..common.tracing import BaseTracer from ..common.values import BaseValue, TensorValue NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = { @@ -182,9 +183,10 @@ def get_base_value_for_numpy_or_python_constant_data( return constant_data_value -def get_numpy_function_output_dtype( +def get_numpy_function_output_dtype_from_input_dtypes( function: Union[numpy.ufunc, Callable], input_dtypes: List[BaseDataType], + input_shapes: List[Tuple[int, ...]], ) -> List[numpy.dtype]: """Record the output dtype of a numpy function given some input types. @@ -193,6 +195,8 @@ def get_numpy_function_output_dtype( be recorded input_dtypes (List[BaseDataType]): BaseDataTypes in the same order as they will be used with the function inputs + input_shapes (List[Tuple[int, ...]]): Shapes in the same order as they will be used with + the function inputs Returns: List[numpy.dtype]: The ordered numpy dtypes of the function outputs @@ -206,7 +210,12 @@ def get_numpy_function_output_dtype( input_numpy_dtypes = [convert_base_data_type_to_numpy_dtype(dtype) for dtype in input_dtypes] dummy_inputs = tuple( - dtype.type(1000.0 * numpy.random.random_sample()) for dtype in input_numpy_dtypes + ( + dtype.type(10.0 * numpy.random.random_sample()) + if shape == () + else numpy.abs(numpy.random.randn(*shape) * 10.0).astype(dtype) + ) + for dtype, shape in zip(input_numpy_dtypes, input_shapes) ) # We ignore errors as we may call functions with invalid inputs just to get the proper output @@ -220,6 +229,36 @@ def get_numpy_function_output_dtype( return [output.dtype for output in outputs] +def get_numpy_function_output_dtype_from_input_tracers( + func: Union[numpy.ufunc, Callable], + *input_tracers: BaseTracer, +) -> List[BaseDataType]: + """Determine output dtypes for a numpy function. + + This function is responsible for determining the output dtype + of a numpy function after inputs with specific dtypes are passed to it. + + Args: + func (Union[numpy.ufunc, Callable]): function that is being managed + *input_tracers (BaseTracer): inputs to the function + + Returns: + List[numpy.dtype]: appropriate BaseDataType for each output of the function + """ + + input_shapes = [ + input_tracer.output.shape if isinstance(input_tracer.output, TensorValue) else () + for input_tracer in input_tracers + ] + output_dtypes = get_numpy_function_output_dtype_from_input_dtypes( + func, + [input_tracer.output.dtype for input_tracer in input_tracers], + input_shapes, + ) + common_output_dtypes = [convert_numpy_dtype_to_base_data_type(dtype) for dtype in output_dtypes] + return common_output_dtypes + + def get_constructor_for_numpy_or_python_constant_data(constant_data: Any): """Get the constructor for the numpy constant data or python dtype. diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index dbb2f25b6..8b5488d20 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -9,14 +9,14 @@ from numpy.typing import DTypeLike from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype from ..common.debugging.custom_assert import assert_true from ..common.operator_graph import OPGraph -from ..common.representation.intermediate import Constant, Dot, UnivariateFunction +from ..common.representation.intermediate import Constant, Dot, MatMul, UnivariateFunction from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters from ..common.values import BaseValue from .np_dtypes_helpers import ( SUPPORTED_NUMPY_DTYPES_CLASS_TYPES, convert_numpy_dtype_to_base_data_type, get_base_value_for_numpy_or_python_constant_data, - get_numpy_function_output_dtype, + get_numpy_function_output_dtype_from_input_tracers, ) from .np_indexing_helpers import process_indexing_element @@ -139,16 +139,6 @@ class NPTracer(BaseTracer): def _make_const_input_tracer(self, constant_data: Any) -> "NPTracer": return self.__class__([], NPConstant(constant_data), 0) - @staticmethod - def _manage_dtypes(ufunc: Union[numpy.ufunc, Callable], *input_tracers: BaseTracer): - output_dtypes = get_numpy_function_output_dtype( - ufunc, [input_tracer.output.dtype for input_tracer in input_tracers] - ) - common_output_dtypes = [ - convert_numpy_dtype_to_base_data_type(dtype) for dtype in output_dtypes - ] - return common_output_dtypes - @classmethod def _unary_operator( cls, unary_operator, unary_operator_string, *input_tracers: "NPTracer", **kwargs @@ -159,7 +149,10 @@ class NPTracer(BaseTracer): NPTracer: The output NPTracer containing the traced function """ assert_true(len(input_tracers) == 1) - common_output_dtypes = cls._manage_dtypes(unary_operator, *input_tracers) + common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers( + unary_operator, + *input_tracers, + ) assert_true(len(common_output_dtypes) == 1) traced_computation = UnivariateFunction( @@ -211,7 +204,10 @@ class NPTracer(BaseTracer): def arbitrary_func(x, baked_constant, **kwargs): return binary_operator(x, baked_constant, **kwargs) - common_output_dtypes = cls._manage_dtypes(binary_operator, *input_tracers) + common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers( + binary_operator, + *input_tracers, + ) assert_true(len(common_output_dtypes) == 1) op_kwargs = deepcopy(kwargs) @@ -249,7 +245,7 @@ class NPTracer(BaseTracer): """ assert_true((num_args := len(args)) == 2, f"dot expects 2 inputs got {num_args}") - common_output_dtypes = self._manage_dtypes(numpy.dot, *args) + common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers(numpy.dot, *args) assert_true(len(common_output_dtypes) == 1) traced_computation = Dot( @@ -273,6 +269,9 @@ class NPTracer(BaseTracer): return BaseTracer.__getitem__(self, item) + def __matmul__(self, other): + return self.__array_ufunc__(numpy.matmul, "__call__", self, other) + # Supported functions are either univariate or bivariate for which one of the two # sources is a constant # @@ -340,7 +339,6 @@ class NPTracer(BaseTracer): numpy.logical_not, numpy.logical_or, numpy.logical_xor, - # numpy.matmul, numpy.maximum, numpy.minimum, numpy.negative, @@ -436,9 +434,23 @@ def _on_numpy_multiply(lhs, rhs): return lhs.__mul__(rhs) +def _on_numpy_matmul(lhs, rhs): + common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers( + numpy.matmul, lhs, rhs + ) + assert_true(len(common_output_dtypes) == 1) + + traced_computation = MatMul( + [lhs.output, rhs.output], + common_output_dtypes[0], + ) + return NPTracer([lhs, rhs], traced_computation, output_idx=0) + + NPTracer.UFUNC_ROUTING[numpy.add] = _on_numpy_add NPTracer.UFUNC_ROUTING[numpy.subtract] = _on_numpy_subtract NPTracer.UFUNC_ROUTING[numpy.multiply] = _on_numpy_multiply +NPTracer.UFUNC_ROUTING[numpy.matmul] = _on_numpy_matmul def trace_numpy_function( diff --git a/tests/common/representation/test_intermediate.py b/tests/common/representation/test_intermediate.py index 0228525e8..00e5905ef 100644 --- a/tests/common/representation/test_intermediate.py +++ b/tests/common/representation/test_intermediate.py @@ -1,5 +1,4 @@ """Test file for intermediate representation""" - from copy import deepcopy import numpy @@ -164,6 +163,18 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En ), id="IndexConstant, np.array([[1, 2, 3, 4]...[13, 14, 15, 16]])[1:3, 2:0:-1]", ), + pytest.param( + ir.MatMul( + [ + EncryptedTensor(Integer(32, True), shape=(3, 2)), + ClearTensor(Integer(32, True), shape=(2, 3)), + ], + Integer(32, True), + ), + [numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)], + numpy.array([[9, 12, 15], [19, 26, 33], [29, 40, 51]]), + id="MatMul, numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)", + ), ], ) def test_evaluate( diff --git a/tests/conftest.py b/tests/conftest.py index 56cfd08bb..e947978f6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,7 @@ from concrete.common.representation.intermediate import ( IndexConstant, Input, IntermediateNode, + MatMul, Mul, Sub, UnivariateFunction, @@ -167,6 +168,11 @@ def is_equivalent_sub(lhs: Sub, rhs: object) -> bool: return _is_equivalent_to_binary_non_commutative(lhs, rhs) +def is_equivalent_matmul(lhs: MatMul, rhs: object) -> bool: + """Helper function to check if a MatMul node is equivalent to an other object.""" + return isinstance(rhs, MatMul) and is_equivalent_intermediate_node(lhs, rhs) + + def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool: """Helper function to check if an IntermediateNode node is equivalent to an other object.""" return ( @@ -185,6 +191,7 @@ EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = { Input: is_equivalent_input, Mul: is_equivalent_mul, Sub: is_equivalent_sub, + MatMul: is_equivalent_matmul, } _missing_nodes_in_mapping = ALL_IR_NODES - EQUIVALENT_TEST_FUNC.keys() diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 2be41db05..aa9ffcf87 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -880,6 +880,34 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura "return(%7)\n" ), ), + pytest.param( + lambda x: x @ numpy.ones(shape=(2, 3), dtype=numpy.uint32), + {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))}, + [(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)], + ( + "function you are trying to compile isn't supported for MLIR lowering\n" + "\n" + "%0 = x # EncryptedTensor, shape=(3, 2)>\n" # noqa: E501 + "%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor, shape=(2, 3)>\n" # noqa: E501 + "%2 = MatMul(%0, %1) # EncryptedTensor, shape=(3, 3)>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501 + "return(%2)\n" + ), + ), + pytest.param( + lambda x: numpy.matmul(x, numpy.ones(shape=(2, 3), dtype=numpy.uint32)), + {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))}, + [(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)], + ( + "function you are trying to compile isn't supported for MLIR lowering\n" + "\n" + "%0 = x # EncryptedTensor, shape=(3, 2)>\n" # noqa: E501 + "%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor, shape=(2, 3)>\n" # noqa: E501 + "%2 = MatMul(%0, %1) # EncryptedTensor, shape=(3, 3)>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501 + "return(%2)\n" + ), + ), ], ) # pylint: enable=line-too-long,unnecessary-lambda @@ -894,7 +922,7 @@ def test_fail_compile(function, parameters, inputset, match, default_compilation default_compilation_configuration, ) - assert str(excinfo.value) == match + assert str(excinfo.value) == match, str(excinfo.value) def test_fail_with_intermediate_signed_values(default_compilation_configuration):