mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(tracing): implement tracing of matmul
This commit is contained in:
@@ -18,6 +18,7 @@ from ..representation.intermediate import (
|
||||
Dot,
|
||||
IndexConstant,
|
||||
Input,
|
||||
MatMul,
|
||||
Mul,
|
||||
Sub,
|
||||
UnivariateFunction,
|
||||
@@ -32,6 +33,7 @@ IR_NODE_COLOR_MAPPING = {
|
||||
UnivariateFunction: "orange",
|
||||
IndexConstant: "black",
|
||||
Dot: "purple",
|
||||
MatMul: "brown",
|
||||
"UnivariateFunction": "orange",
|
||||
"TLU": "grey",
|
||||
"output": "magenta",
|
||||
|
||||
@@ -78,6 +78,9 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)
|
||||
assert_true(len(outputs) == 1)
|
||||
return "indexing is not supported for the time being"
|
||||
|
||||
elif isinstance(node, intermediate.MatMul): # constraints for matrix multiplication
|
||||
return "matrix multiplication is not supported for the time being"
|
||||
|
||||
else: # pragma: no cover
|
||||
assert_not_reached("Non IntermediateNode object in the OPGraph")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -424,3 +424,49 @@ class Dot(IntermediateNode):
|
||||
|
||||
def label(self) -> str:
|
||||
return "dot"
|
||||
|
||||
|
||||
class MatMul(IntermediateNode):
    """Intermediate node for the matrix product of two rank-2 tensor values.

    The node takes exactly two inputs, both of which must be ``TensorValue``
    instances with ``ndim == 2`` and whose inner dimensions agree. Its single
    output has shape ``(lhs rows, rhs columns)`` and is encrypted as soon as
    one of the operands is encrypted, clear otherwise.
    """

    _n_in: int = 2

    def __init__(
        self,
        inputs: Iterable[BaseValue],
        output_dtype: BaseDataType,
    ) -> None:
        """Validate the two operands and build the single output value.

        Args:
            inputs (Iterable[BaseValue]): the left and right operands, in that order
            output_dtype (BaseDataType): data type given to the output tensor
        """
        super().__init__(inputs)
        assert_true(len(self.inputs) == 2)

        # Walk the operands and stop at the first one that is not a 2D tensor.
        operands_are_matrices = True
        for operand in self.inputs:
            if not (isinstance(operand, TensorValue) and operand.ndim == 2):
                operands_are_matrices = False
                break

        assert_true(
            operands_are_matrices,
            f"MatMul only supports two matrices ({TensorValue.__name__} with ndim == 2)",
        )

        # The casts only narrow the static type for mypy; the check above is
        # what actually verifies both operands are TensorValue instances.
        left = cast(TensorValue, self.inputs[0])
        right = cast(TensorValue, self.inputs[1])

        inner_dimensions_match = left.shape[1] == right.shape[0]
        assert_true(
            inner_dimensions_match,
            f"MatMul between matrices of shapes {left.shape} and {right.shape} is not supported",
        )

        result_shape = (left.shape[0], right.shape[1])

        # An encrypted operand on either side makes the result encrypted.
        if left.is_encrypted or right.is_encrypted:
            result_value = EncryptedTensor(dtype=output_dtype, shape=result_shape)
        else:
            result_value = ClearTensor(dtype=output_dtype, shape=result_shape)

        self.outputs = [result_value]

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Apply the ``@`` operator to the values stored under keys 0 and 1."""
        left_operand = inputs[0]
        right_operand = inputs[1]
        return left_operand @ right_operand

    def label(self) -> str:
        """Return the short label of this node, the ``@`` operator symbol."""
        return "@"
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Union
|
||||
from typing import Any, Callable, Dict, List, Tuple, Union
|
||||
|
||||
import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
@@ -18,6 +18,7 @@ from ..common.data_types.dtypes_helpers import (
|
||||
from ..common.data_types.floats import Float
|
||||
from ..common.data_types.integers import Integer
|
||||
from ..common.debugging.custom_assert import assert_true
|
||||
from ..common.tracing import BaseTracer
|
||||
from ..common.values import BaseValue, TensorValue
|
||||
|
||||
NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
|
||||
@@ -182,9 +183,10 @@ def get_base_value_for_numpy_or_python_constant_data(
|
||||
return constant_data_value
|
||||
|
||||
|
||||
def get_numpy_function_output_dtype(
|
||||
def get_numpy_function_output_dtype_from_input_dtypes(
|
||||
function: Union[numpy.ufunc, Callable],
|
||||
input_dtypes: List[BaseDataType],
|
||||
input_shapes: List[Tuple[int, ...]],
|
||||
) -> List[numpy.dtype]:
|
||||
"""Record the output dtype of a numpy function given some input types.
|
||||
|
||||
@@ -193,6 +195,8 @@ def get_numpy_function_output_dtype(
|
||||
be recorded
|
||||
input_dtypes (List[BaseDataType]): BaseDataTypes in the same order as they will be used with
|
||||
the function inputs
|
||||
input_shapes (List[Tuple[int, ...]]): Shapes in the same order as they will be used with
|
||||
the function inputs
|
||||
|
||||
Returns:
|
||||
List[numpy.dtype]: The ordered numpy dtypes of the function outputs
|
||||
@@ -206,7 +210,12 @@ def get_numpy_function_output_dtype(
|
||||
input_numpy_dtypes = [convert_base_data_type_to_numpy_dtype(dtype) for dtype in input_dtypes]
|
||||
|
||||
dummy_inputs = tuple(
|
||||
dtype.type(1000.0 * numpy.random.random_sample()) for dtype in input_numpy_dtypes
|
||||
(
|
||||
dtype.type(10.0 * numpy.random.random_sample())
|
||||
if shape == ()
|
||||
else numpy.abs(numpy.random.randn(*shape) * 10.0).astype(dtype)
|
||||
)
|
||||
for dtype, shape in zip(input_numpy_dtypes, input_shapes)
|
||||
)
|
||||
|
||||
# We ignore errors as we may call functions with invalid inputs just to get the proper output
|
||||
@@ -220,6 +229,36 @@ def get_numpy_function_output_dtype(
|
||||
return [output.dtype for output in outputs]
|
||||
|
||||
|
||||
def get_numpy_function_output_dtype_from_input_tracers(
    func: Union[numpy.ufunc, Callable],
    *input_tracers: BaseTracer,
) -> List[BaseDataType]:
    """Determine the output data types of a numpy function for traced inputs.

    The dtype and shape of every tracer output are collected (an empty shape
    is used for outputs that are not a TensorValue), handed to
    get_numpy_function_output_dtype_from_input_dtypes, and the numpy dtypes it
    returns are converted with convert_numpy_dtype_to_base_data_type.

    Args:
        func (Union[numpy.ufunc, Callable]): function whose output dtypes are wanted
        *input_tracers (BaseTracer): tracers passed as inputs to the function

    Returns:
        List[BaseDataType]: the converted data type of each output of the function
    """

    operand_dtypes = []
    operand_shapes = []
    for tracer in input_tracers:
        traced_value = tracer.output
        operand_dtypes.append(traced_value.dtype)
        # Only tensor values carry a shape; everything else is treated as a scalar.
        if isinstance(traced_value, TensorValue):
            operand_shapes.append(traced_value.shape)
        else:
            operand_shapes.append(())

    numpy_output_dtypes = get_numpy_function_output_dtype_from_input_dtypes(
        func,
        operand_dtypes,
        operand_shapes,
    )

    return [
        convert_numpy_dtype_to_base_data_type(numpy_dtype) for numpy_dtype in numpy_output_dtypes
    ]
|
||||
|
||||
|
||||
def get_constructor_for_numpy_or_python_constant_data(constant_data: Any):
|
||||
"""Get the constructor for the numpy constant data or python dtype.
|
||||
|
||||
|
||||
@@ -9,14 +9,14 @@ from numpy.typing import DTypeLike
|
||||
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
|
||||
from ..common.debugging.custom_assert import assert_true
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation.intermediate import Constant, Dot, UnivariateFunction
|
||||
from ..common.representation.intermediate import Constant, Dot, MatMul, UnivariateFunction
|
||||
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
|
||||
from ..common.values import BaseValue
|
||||
from .np_dtypes_helpers import (
|
||||
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
|
||||
convert_numpy_dtype_to_base_data_type,
|
||||
get_base_value_for_numpy_or_python_constant_data,
|
||||
get_numpy_function_output_dtype,
|
||||
get_numpy_function_output_dtype_from_input_tracers,
|
||||
)
|
||||
from .np_indexing_helpers import process_indexing_element
|
||||
|
||||
@@ -139,16 +139,6 @@ class NPTracer(BaseTracer):
|
||||
def _make_const_input_tracer(self, constant_data: Any) -> "NPTracer":
|
||||
return self.__class__([], NPConstant(constant_data), 0)
|
||||
|
||||
@staticmethod
def _manage_dtypes(ufunc: Union[numpy.ufunc, Callable], *input_tracers: BaseTracer):
    """Return the converted output data types of ``ufunc`` for the given tracers.

    The dtype of each tracer output is passed to get_numpy_function_output_dtype
    and every numpy dtype it returns is converted with
    convert_numpy_dtype_to_base_data_type.
    """
    operand_dtypes = [tracer.output.dtype for tracer in input_tracers]
    numpy_output_dtypes = get_numpy_function_output_dtype(ufunc, operand_dtypes)
    return [
        convert_numpy_dtype_to_base_data_type(numpy_dtype) for numpy_dtype in numpy_output_dtypes
    ]
|
||||
|
||||
@classmethod
|
||||
def _unary_operator(
|
||||
cls, unary_operator, unary_operator_string, *input_tracers: "NPTracer", **kwargs
|
||||
@@ -159,7 +149,10 @@ class NPTracer(BaseTracer):
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
assert_true(len(input_tracers) == 1)
|
||||
common_output_dtypes = cls._manage_dtypes(unary_operator, *input_tracers)
|
||||
common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers(
|
||||
unary_operator,
|
||||
*input_tracers,
|
||||
)
|
||||
assert_true(len(common_output_dtypes) == 1)
|
||||
|
||||
traced_computation = UnivariateFunction(
|
||||
@@ -211,7 +204,10 @@ class NPTracer(BaseTracer):
|
||||
def arbitrary_func(x, baked_constant, **kwargs):
|
||||
return binary_operator(x, baked_constant, **kwargs)
|
||||
|
||||
common_output_dtypes = cls._manage_dtypes(binary_operator, *input_tracers)
|
||||
common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers(
|
||||
binary_operator,
|
||||
*input_tracers,
|
||||
)
|
||||
assert_true(len(common_output_dtypes) == 1)
|
||||
|
||||
op_kwargs = deepcopy(kwargs)
|
||||
@@ -249,7 +245,7 @@ class NPTracer(BaseTracer):
|
||||
"""
|
||||
assert_true((num_args := len(args)) == 2, f"dot expects 2 inputs got {num_args}")
|
||||
|
||||
common_output_dtypes = self._manage_dtypes(numpy.dot, *args)
|
||||
common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers(numpy.dot, *args)
|
||||
assert_true(len(common_output_dtypes) == 1)
|
||||
|
||||
traced_computation = Dot(
|
||||
@@ -273,6 +269,9 @@ class NPTracer(BaseTracer):
|
||||
|
||||
return BaseTracer.__getitem__(self, item)
|
||||
|
||||
def __matmul__(self, other):
    """Handle ``self @ other`` by forwarding it as a numpy.matmul ufunc call."""
    matmul_operands = (self, other)
    return self.__array_ufunc__(numpy.matmul, "__call__", *matmul_operands)
|
||||
|
||||
# Supported functions are either univariate or bivariate for which one of the two
|
||||
# sources is a constant
|
||||
#
|
||||
@@ -340,7 +339,6 @@ class NPTracer(BaseTracer):
|
||||
numpy.logical_not,
|
||||
numpy.logical_or,
|
||||
numpy.logical_xor,
|
||||
# numpy.matmul,
|
||||
numpy.maximum,
|
||||
numpy.minimum,
|
||||
numpy.negative,
|
||||
@@ -436,9 +434,23 @@ def _on_numpy_multiply(lhs, rhs):
|
||||
return lhs.__mul__(rhs)
|
||||
|
||||
|
||||
def _on_numpy_matmul(lhs, rhs):
    """Build the tracer for ``numpy.matmul(lhs, rhs)``.

    The output dtype is obtained from
    get_numpy_function_output_dtype_from_input_tracers, a MatMul node is created
    from the outputs of both tracers, and a new NPTracer wrapping that node
    (with output index 0) is returned.
    """
    operands = (lhs, rhs)

    output_dtypes = get_numpy_function_output_dtype_from_input_tracers(numpy.matmul, *operands)
    # Exactly one output dtype is expected for a matrix multiplication.
    assert_true(len(output_dtypes) == 1)

    matmul_node = MatMul(
        [operand.output for operand in operands],
        output_dtypes[0],
    )
    return NPTracer(list(operands), matmul_node, output_idx=0)
|
||||
|
||||
|
||||
NPTracer.UFUNC_ROUTING[numpy.add] = _on_numpy_add
|
||||
NPTracer.UFUNC_ROUTING[numpy.subtract] = _on_numpy_subtract
|
||||
NPTracer.UFUNC_ROUTING[numpy.multiply] = _on_numpy_multiply
|
||||
NPTracer.UFUNC_ROUTING[numpy.matmul] = _on_numpy_matmul
|
||||
|
||||
|
||||
def trace_numpy_function(
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Test file for intermediate representation"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy
|
||||
@@ -164,6 +163,18 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
|
||||
),
|
||||
id="IndexConstant, np.array([[1, 2, 3, 4]...[13, 14, 15, 16]])[1:3, 2:0:-1]",
|
||||
),
|
||||
pytest.param(
|
||||
ir.MatMul(
|
||||
[
|
||||
EncryptedTensor(Integer(32, True), shape=(3, 2)),
|
||||
ClearTensor(Integer(32, True), shape=(2, 3)),
|
||||
],
|
||||
Integer(32, True),
|
||||
),
|
||||
[numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)],
|
||||
numpy.array([[9, 12, 15], [19, 26, 33], [29, 40, 51]]),
|
||||
id="MatMul, numpy.arange(1, 7).reshape(3, 2), numpy.arange(1, 7).reshape(2, 3)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_evaluate(
|
||||
|
||||
@@ -17,6 +17,7 @@ from concrete.common.representation.intermediate import (
|
||||
IndexConstant,
|
||||
Input,
|
||||
IntermediateNode,
|
||||
MatMul,
|
||||
Mul,
|
||||
Sub,
|
||||
UnivariateFunction,
|
||||
@@ -167,6 +168,11 @@ def is_equivalent_sub(lhs: Sub, rhs: object) -> bool:
|
||||
return _is_equivalent_to_binary_non_commutative(lhs, rhs)
|
||||
|
||||
|
||||
def is_equivalent_matmul(lhs: MatMul, rhs: object) -> bool:
    """Helper function to check if a MatMul node is equivalent to another object."""
    # Anything that is not a MatMul node can never be equivalent.
    if not isinstance(rhs, MatMul):
        return False
    return is_equivalent_intermediate_node(lhs, rhs)
|
||||
|
||||
|
||||
def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool:
|
||||
"""Helper function to check if an IntermediateNode node is equivalent to an other object."""
|
||||
return (
|
||||
@@ -185,6 +191,7 @@ EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
|
||||
Input: is_equivalent_input,
|
||||
Mul: is_equivalent_mul,
|
||||
Sub: is_equivalent_sub,
|
||||
MatMul: is_equivalent_matmul,
|
||||
}
|
||||
|
||||
_missing_nodes_in_mapping = ALL_IR_NODES - EQUIVALENT_TEST_FUNC.keys()
|
||||
|
||||
@@ -880,6 +880,34 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
|
||||
"return(%7)\n"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x @ numpy.ones(shape=(2, 3), dtype=numpy.uint32),
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
|
||||
(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n"
|
||||
"\n"
|
||||
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>\n" # noqa: E501
|
||||
"%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor<Integer<unsigned, 1 bits>, shape=(2, 3)>\n" # noqa: E501
|
||||
"%2 = MatMul(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(3, 3)>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501
|
||||
"return(%2)\n"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.matmul(x, numpy.ones(shape=(2, 3), dtype=numpy.uint32)),
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
|
||||
(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n"
|
||||
"\n"
|
||||
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>\n" # noqa: E501
|
||||
"%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor<Integer<unsigned, 1 bits>, shape=(2, 3)>\n" # noqa: E501
|
||||
"%2 = MatMul(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(3, 3)>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501
|
||||
"return(%2)\n"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
# pylint: enable=line-too-long,unnecessary-lambda
|
||||
@@ -894,7 +922,7 @@ def test_fail_compile(function, parameters, inputset, match, default_compilation
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
assert str(excinfo.value) == match
|
||||
assert str(excinfo.value) == match, str(excinfo.value)
|
||||
|
||||
|
||||
def test_fail_with_intermediate_signed_values(default_compilation_configuration):
|
||||
|
||||
Reference in New Issue
Block a user