diff --git a/concrete/common/mlir/node_converter.py b/concrete/common/mlir/node_converter.py index 80affd9a7..35024b44d 100644 --- a/concrete/common/mlir/node_converter.py +++ b/concrete/common/mlir/node_converter.py @@ -432,24 +432,20 @@ class IntermediateNodeConverter: assert_true(len(self.node.inputs) == 2) assert_true(len(self.node.outputs) == 1) - if self.all_of_the_inputs_are_encrypted or self.node.inputs[0].is_clear: + if self.all_of_the_inputs_are_encrypted: lhs = self.node.inputs[0] rhs = self.node.inputs[1] - - additional_error_info = ( - " (notice the encrypted value is in the right hand side which is not supported)" - if self.node.inputs[0].is_clear - else "" - ) raise NotImplementedError( - f"Matrix multiplication between {lhs} and {rhs} cannot be converted to MLIR yet" - f"{additional_error_info}", + f"Matrix multiplication between {lhs} and {rhs} cannot be converted to MLIR yet", ) resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) preds = self.preds - result = hlfhelinalg.MatMulEintIntOp(resulting_type, *preds).result + if self.node.inputs[0].is_clear: + result = hlfhelinalg.MatMulIntEintOp(resulting_type, *preds).result + else: + result = hlfhelinalg.MatMulEintIntOp(resulting_type, *preds).result return result diff --git a/tests/common/mlir/test_node_converter.py b/tests/common/mlir/test_node_converter.py index 9a7db6a44..521be627e 100644 --- a/tests/common/mlir/test_node_converter.py +++ b/tests/common/mlir/test_node_converter.py @@ -66,18 +66,26 @@ from concrete.numpy import compile_numpy_function "cannot be converted to MLIR yet", ), pytest.param( - lambda x: numpy.ones(shape=(2, 3), dtype=numpy.uint32) @ x, - {"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2))}, - [numpy.random.randint(0, 2 ** 3, size=(3, 2)) for i in range(10)] - + [numpy.array([[7, 7], [7, 7], [7, 7]])], + lambda x, y: x @ y, + { + "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), + "y": EncryptedTensor(UnsignedInteger(3), shape=(2, 1)), + }, + [ + ( + numpy.random.randint(0, 2 ** 3, size=(3, 2)), + numpy.random.randint(0, 2 ** 3, size=(2, 1)), + ) + for i in range(10) + ] + + [(numpy.array([[7, 7], [7, 7], [7, 7]]), numpy.array([[7], [7]]))], NotImplementedError, "Matrix multiplication " "between " - "ClearTensor " + "EncryptedTensor " "and " - "EncryptedTensor " - "cannot be converted to MLIR yet " - "(notice the encrypted value is in the right hand side which is not supported)", + "EncryptedTensor " + "cannot be converted to MLIR yet", ), ], ) diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 69ee69f84..1b36510b5 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -1343,37 +1343,65 @@ def test_compile_and_run_matmul_correctness( low, high = input_range - inputset = [ + lhs_inputset = [ numpy.zeros(lhs_shape, dtype=numpy.uint32), numpy.ones(lhs_shape, dtype=numpy.uint32) * high, ] + rhs_inputset = [ + numpy.zeros(rhs_shape, dtype=numpy.uint32), + numpy.ones(rhs_shape, dtype=numpy.uint32) * high, + ] for _ in range(8): - inputset.append(numpy.random.randint(low, high + 1, size=lhs_shape)) + lhs_inputset.append(numpy.random.randint(low, high + 1, size=lhs_shape)) + rhs_inputset.append(numpy.random.randint(low, high + 1, size=rhs_shape)) - constant = numpy.random.randint(low, high + 1, size=rhs_shape) + left_constant = numpy.random.randint(low, high + 1, size=lhs_shape) + right_constant = numpy.random.randint(low, high + 1, size=rhs_shape) - def using_operator(x): - return x @ constant + def using_operator_left(x): + return x @ right_constant - def using_function(x): - return numpy.matmul(x, constant) + def using_function_left(x): + return numpy.matmul(x, right_constant) - operator_circuit = compile_numpy_function( - using_operator, + def using_operator_right(x): + return left_constant @ x + + def using_function_right(x): + return numpy.matmul(left_constant, x) + + operator_left_circuit = compile_numpy_function( + using_operator_left, {"x": EncryptedTensor(UnsignedInteger(3), lhs_shape)}, - inputset, + lhs_inputset, default_compilation_configuration, ) - function_circuit = compile_numpy_function( - using_function, + function_left_circuit = compile_numpy_function( + using_function_left, {"x": EncryptedTensor(UnsignedInteger(3), lhs_shape)}, - inputset, + lhs_inputset, + default_compilation_configuration, + ) + operator_right_circuit = compile_numpy_function( + using_operator_right, + {"x": EncryptedTensor(UnsignedInteger(3), rhs_shape)}, + rhs_inputset, + default_compilation_configuration, + ) + function_right_circuit = compile_numpy_function( + using_function_right, + {"x": EncryptedTensor(UnsignedInteger(3), rhs_shape)}, + rhs_inputset, default_compilation_configuration, ) - args = (numpy.random.randint(low, high + 1, size=lhs_shape, dtype=numpy.uint8),) - check_array_equality(operator_circuit.run(*args), using_operator(*args)) - check_array_equality(function_circuit.run(*args), using_function(*args)) + lhs_arg = numpy.random.randint(low, high + 1, size=lhs_shape, dtype=numpy.uint8) + check_array_equality(operator_left_circuit.run(lhs_arg), using_operator_left(lhs_arg)) + check_array_equality(function_left_circuit.run(lhs_arg), using_function_left(lhs_arg)) + + rhs_arg = numpy.random.randint(low, high + 1, size=rhs_shape, dtype=numpy.uint8) + check_array_equality(operator_right_circuit.run(rhs_arg), using_operator_right(rhs_arg)) + check_array_equality(function_right_circuit.run(rhs_arg), using_function_right(rhs_arg)) @pytest.mark.parametrize(