mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(mlir): implement MLIR conversion of MatMul
This commit is contained in:
@@ -28,6 +28,7 @@ from ..representation.intermediate import (
|
||||
Dot,
|
||||
GenericFunction,
|
||||
IntermediateNode,
|
||||
MatMul,
|
||||
Mul,
|
||||
Sub,
|
||||
)
|
||||
@@ -86,27 +87,31 @@ class IntermediateNodeConverter:
|
||||
"""
|
||||
|
||||
if isinstance(self.node, Add):
|
||||
return self.convert_add()
|
||||
result = self.convert_add()
|
||||
|
||||
if isinstance(self.node, Constant):
|
||||
return self.convert_constant()
|
||||
elif isinstance(self.node, Constant):
|
||||
result = self.convert_constant()
|
||||
|
||||
if isinstance(self.node, Dot):
|
||||
return self.convert_dot()
|
||||
elif isinstance(self.node, Dot):
|
||||
result = self.convert_dot()
|
||||
|
||||
if isinstance(self.node, GenericFunction):
|
||||
return self.convert_generic_function(additional_conversion_info)
|
||||
elif isinstance(self.node, GenericFunction):
|
||||
result = self.convert_generic_function(additional_conversion_info)
|
||||
|
||||
if isinstance(self.node, Mul):
|
||||
return self.convert_mul()
|
||||
elif isinstance(self.node, MatMul):
|
||||
result = self.convert_matmul()
|
||||
|
||||
if isinstance(self.node, Sub):
|
||||
return self.convert_sub()
|
||||
elif isinstance(self.node, Mul):
|
||||
result = self.convert_mul()
|
||||
|
||||
# this statement is not covered as unsupported opeations fail on check mlir compatibility
|
||||
raise NotImplementedError(
|
||||
f"{type(self.node)} nodes cannot be converted to MLIR yet"
|
||||
) # pragma: no cover
|
||||
elif isinstance(self.node, Sub):
|
||||
result = self.convert_sub()
|
||||
|
||||
else: # pragma: no cover
|
||||
# this branch is not covered as unsupported opeations fail on check mlir compatibility
|
||||
raise NotImplementedError(f"{type(self.node)} nodes cannot be converted to MLIR yet")
|
||||
|
||||
return result
|
||||
|
||||
def convert_add(self) -> OpResult:
|
||||
"""Convert an Add node to its corresponding MLIR representation.
|
||||
@@ -280,6 +285,37 @@ class IntermediateNodeConverter:
|
||||
|
||||
return result
|
||||
|
||||
def convert_matmul(self) -> OpResult:
|
||||
"""Convert a MatMul node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
assert_true(len(self.node.inputs) == 2)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
|
||||
if self.all_of_the_inputs_are_encrypted or self.node.inputs[0].is_clear:
|
||||
lhs = self.node.inputs[0]
|
||||
rhs = self.node.inputs[1]
|
||||
|
||||
additional_error_info = (
|
||||
" (notice the encrypted value is in the right hand side which is not supported)"
|
||||
if self.node.inputs[0].is_clear
|
||||
else ""
|
||||
)
|
||||
raise NotImplementedError(
|
||||
f"Matrix multiplication between {lhs} and {rhs} cannot be converted to MLIR yet"
|
||||
f"{additional_error_info}",
|
||||
)
|
||||
|
||||
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
preds = self.preds
|
||||
|
||||
result = hlfhelinalg.MatMulEintIntOp(resulting_type, *preds).result
|
||||
|
||||
return result
|
||||
|
||||
def convert_mul(self) -> OpResult:
|
||||
"""Convert a Mul node to its corresponding MLIR representation.
|
||||
|
||||
|
||||
@@ -113,7 +113,9 @@ def check_node_compatibility_with_mlir(
|
||||
return "indexing is not supported for the time being"
|
||||
|
||||
elif isinstance(node, intermediate.MatMul): # constraints for matrix multiplication
|
||||
return "matrix multiplication is not supported for the time being"
|
||||
assert_true(len(inputs) == 2)
|
||||
if not value_is_unsigned_integer(inputs[0]) or not value_is_unsigned_integer(inputs[1]):
|
||||
return "only unsigned integer matrix multiplication is supported"
|
||||
|
||||
else: # pragma: no cover
|
||||
assert_not_reached("Non IntermediateNode object in the OPGraph")
|
||||
|
||||
@@ -493,10 +493,6 @@ class NPTracer(BaseTracer):
|
||||
"""Trace numpy.matmul."""
|
||||
return self.__array_ufunc__(numpy.matmul, "__call__", self, other)
|
||||
|
||||
def matmul(self, other):
|
||||
"""Trace x.matmul."""
|
||||
return self.__array_ufunc__(numpy.matmul, "__call__", self, other)
|
||||
|
||||
# Supported functions are either univariate or bivariate for which one of the two
|
||||
# sources is a constant
|
||||
#
|
||||
|
||||
@@ -65,6 +65,20 @@ from concrete.numpy import compile_numpy_function
|
||||
"EncryptedTensor<uint7, shape=(2,)> "
|
||||
"cannot be converted to MLIR yet",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.ones(shape=(2, 3), dtype=numpy.uint32) @ x,
|
||||
{"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)]
|
||||
+ [(numpy.array([[7, 7], [7, 7], [7, 7]]),)],
|
||||
NotImplementedError,
|
||||
"Matrix multiplication "
|
||||
"between "
|
||||
"ClearTensor<uint6, shape=(2, 3)> "
|
||||
"and "
|
||||
"EncryptedTensor<uint5, shape=(3, 2)> "
|
||||
"cannot be converted to MLIR yet "
|
||||
"(notice the encrypted value is in the right hand side which is not supported)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_fail_node_conversion(
|
||||
|
||||
@@ -993,6 +993,71 @@ def test_compile_and_run_constant_dot_correctness(
|
||||
assert right_circuit.run(*args) == right(*args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lhs_shape,rhs_shape,input_range",
|
||||
[
|
||||
pytest.param(
|
||||
(3, 2),
|
||||
(2, 3),
|
||||
(0, 4),
|
||||
),
|
||||
pytest.param(
|
||||
(1, 2),
|
||||
(2, 1),
|
||||
(0, 4),
|
||||
),
|
||||
pytest.param(
|
||||
(3, 3),
|
||||
(3, 3),
|
||||
(0, 4),
|
||||
),
|
||||
pytest.param(
|
||||
(2, 1),
|
||||
(1, 2),
|
||||
(0, 8),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_matmul_correctness(
|
||||
lhs_shape, rhs_shape, input_range, default_compilation_configuration
|
||||
):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
low, high = input_range
|
||||
|
||||
inputset = [
|
||||
(numpy.zeros(lhs_shape, dtype=numpy.uint32),),
|
||||
(numpy.ones(lhs_shape, dtype=numpy.uint32) * high,),
|
||||
]
|
||||
for _ in range(8):
|
||||
inputset.append((numpy.random.randint(low, high + 1, size=lhs_shape),))
|
||||
|
||||
constant = numpy.random.randint(low, high + 1, size=rhs_shape)
|
||||
|
||||
def using_operator(x):
|
||||
return x @ constant
|
||||
|
||||
def using_function(x):
|
||||
return numpy.matmul(x, constant)
|
||||
|
||||
operator_circuit = compile_numpy_function(
|
||||
using_operator,
|
||||
{"x": EncryptedTensor(UnsignedInteger(3), lhs_shape)},
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
function_circuit = compile_numpy_function(
|
||||
using_function,
|
||||
{"x": EncryptedTensor(UnsignedInteger(3), lhs_shape)},
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
args = (numpy.random.randint(low, high + 1, size=lhs_shape, dtype=numpy.uint8),)
|
||||
assert numpy.array_equal(operator_circuit.run(*args), using_operator(*args))
|
||||
assert numpy.array_equal(function_circuit.run(*args), using_function(*args))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_bits,list_of_arg_names",
|
||||
[
|
||||
@@ -1202,52 +1267,18 @@ return %9
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x @ numpy.ones(shape=(2, 3), dtype=numpy.uint32),
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
|
||||
lambda x: x @ -numpy.ones(shape=(2, 3), dtype=numpy.int32),
|
||||
{"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
|
||||
%1 = [[1 1 1] [1 1 1]] # ClearTensor<uint1, shape=(2, 3)>
|
||||
%2 = matmul(%0, %1) # EncryptedTensor<uint4, shape=(3, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being
|
||||
return %2
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.matmul(x, numpy.ones(shape=(2, 3), dtype=numpy.uint32)),
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
|
||||
%1 = [[1 1 1] [1 1 1]] # ClearTensor<uint1, shape=(2, 3)>
|
||||
%2 = matmul(%0, %1) # EncryptedTensor<uint4, shape=(3, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being
|
||||
return %2
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.matmul(numpy.ones(shape=(2, 3), dtype=numpy.uint32)),
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
|
||||
%1 = [[1 1 1] [1 1 1]] # ClearTensor<uint1, shape=(2, 3)>
|
||||
%2 = matmul(%0, %1) # EncryptedTensor<uint4, shape=(3, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
|
||||
%1 = [[-1 -1 -1] [-1 -1 -1]] # ClearTensor<int2, shape=(2, 3)>
|
||||
%2 = matmul(%0, %1) # EncryptedTensor<int5, shape=(3, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer matrix multiplication is supported
|
||||
return %2
|
||||
|
||||
""".strip() # noqa: E501
|
||||
|
||||
Reference in New Issue
Block a user